Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-01-09 23:28:07 -05:00)

Merge branch 'dev' into codex/add-edit-video-and-transcribe-video-blocks
@@ -9,11 +9,13 @@

# Platform - Backend
!autogpt_platform/backend/backend/
!autogpt_platform/backend/test/e2e_test_data.py
!autogpt_platform/backend/migrations/
!autogpt_platform/backend/schema.prisma
!autogpt_platform/backend/pyproject.toml
!autogpt_platform/backend/poetry.lock
!autogpt_platform/backend/README.md
!autogpt_platform/backend/.env

# Platform - Market
!autogpt_platform/market/market/

@@ -26,6 +28,7 @@

# Platform - Frontend
!autogpt_platform/frontend/src/
!autogpt_platform/frontend/public/
!autogpt_platform/frontend/scripts/
!autogpt_platform/frontend/package.json
!autogpt_platform/frontend/pnpm-lock.yaml
!autogpt_platform/frontend/tsconfig.json

@@ -33,6 +36,7 @@

## config
!autogpt_platform/frontend/*.config.*
!autogpt_platform/frontend/.env.*
!autogpt_platform/frontend/.env

# Classic - AutoGPT
!classic/original_autogpt/autogpt/
.github/PULL_REQUEST_TEMPLATE.md (vendored, 3 lines changed)
@@ -24,7 +24,8 @@

</details>

#### For configuration changes:

- [ ] `.env.example` is updated or already compatible with my changes
- [ ] `.env.default` is updated or already compatible with my changes
- [ ] `docker-compose.yml` is updated or already compatible with my changes
- [ ] I have included a list of my configuration changes in the PR description (under **Changes**)
.github/copilot-instructions.md (vendored, new file, 322 lines)
@@ -0,0 +1,322 @@
# GitHub Copilot Instructions for AutoGPT

This file provides comprehensive onboarding information for the GitHub Copilot coding agent to work efficiently with the AutoGPT repository.

## Repository Overview

**AutoGPT** is a powerful platform for creating, deploying, and managing continuous AI agents that automate complex workflows. This is a large monorepo (~150MB) containing multiple components:

- **AutoGPT Platform** (`autogpt_platform/`) - Main focus: Modern AI agent platform (Polyform Shield License)
- **Classic AutoGPT** (`classic/`) - Legacy agent system (MIT License)
- **Documentation** (`docs/`) - MkDocs-based documentation site
- **Infrastructure** - Docker configurations, CI/CD, and development tools

**Primary Languages & Frameworks:**

- **Backend**: Python 3.10-3.13, FastAPI, Prisma ORM, PostgreSQL, RabbitMQ
- **Frontend**: TypeScript, Next.js 15, React, Tailwind CSS, Radix UI
- **Development**: Docker, Poetry, pnpm, Playwright, Storybook
## Build and Validation Instructions

### Essential Setup Commands

**Always run these commands in the correct directory and in this order:**

1. **Initial Setup** (required once):

   ```bash
   # Clone and enter repository
   git clone <repo> && cd AutoGPT

   # Start all services (database, redis, rabbitmq, clamav)
   cd autogpt_platform && docker compose --profile local up deps --build --detach
   ```

2. **Backend Setup** (always run before backend development):

   ```bash
   cd autogpt_platform/backend
   poetry install                  # Install dependencies
   poetry run prisma migrate dev   # Run database migrations
   poetry run prisma generate      # Generate Prisma client
   ```

3. **Frontend Setup** (always run before frontend development):

   ```bash
   cd autogpt_platform/frontend
   pnpm install  # Install dependencies
   ```

### Runtime Requirements

**Critical:** Always ensure Docker services are running before starting development:

```bash
cd autogpt_platform && docker compose --profile local up deps --build --detach
```

**Python Version:** Use Python 3.11 (required; managed by Poetry via pyproject.toml)
**Node.js Version:** Use Node.js 21+ with the pnpm package manager
### Development Commands

**Backend Development:**

```bash
cd autogpt_platform/backend
poetry run serve                   # Start development server (port 8000)
poetry run test                    # Run all tests (requires ~5 minutes)
poetry run pytest path/to/test.py  # Run specific test
poetry run format                  # Format code (Black + isort) - always run first
poetry run lint                    # Lint code (ruff) - run after format
```

**Frontend Development:**

```bash
cd autogpt_platform/frontend
pnpm dev        # Start development server (port 3000) - use for active development
pnpm build      # Build for production (only needed for E2E tests or deployment)
pnpm test       # Run Playwright E2E tests (requires build first)
pnpm test-ui    # Run tests with UI
pnpm format     # Format and lint code
pnpm storybook  # Start component development server
```

### Testing Strategy

**Backend Tests:**

- **Block Tests**: `poetry run pytest backend/blocks/test/test_block.py -xvs` (validates all blocks)
- **Specific Block**: `poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[BlockName]' -xvs`
- **Snapshot Tests**: Use `--snapshot-update` when output changes; always review the result with `git diff`

**Frontend Tests:**

- **E2E Tests**: Always run `pnpm dev` before `pnpm test` (Playwright requires a running instance)
- **Component Tests**: Use Storybook for isolated component development

### Critical Validation Steps

**Before committing changes:**

1. Run `poetry run format` (backend) and `pnpm format` (frontend)
2. Ensure all tests pass in modified areas
3. Verify Docker services are still running
4. Check that database migrations apply cleanly

**Common Issues & Workarounds:**

- **Prisma issues**: Run `poetry run prisma generate` after schema changes
- **Permission errors**: Ensure Docker has proper permissions
- **Port conflicts**: Check the `docker-compose.yml` file for the current list of exposed ports; you can list all mapped ports with `docker compose ps`
- **Test timeouts**: Backend tests can take 5+ minutes; use the `-x` flag to stop on the first failure
## Project Layout & Architecture

### Core Architecture

**AutoGPT Platform** (`autogpt_platform/`):

- `backend/` - FastAPI server with async support
  - `backend/backend/` - Core API logic
  - `backend/blocks/` - Agent execution blocks
  - `backend/data/` - Database models and schemas
  - `schema.prisma` - Database schema definition
- `frontend/` - Next.js application
  - `src/app/` - App Router pages and layouts
  - `src/components/` - Reusable React components
  - `src/lib/` - Utilities and configurations
- `autogpt_libs/` - Shared Python utilities
- `docker-compose.yml` - Development stack orchestration

**Key Configuration Files:**

- `pyproject.toml` - Python dependencies and tooling
- `package.json` - Node.js dependencies and scripts
- `schema.prisma` - Database schema and migrations
- `next.config.mjs` - Next.js configuration
- `tailwind.config.ts` - Styling configuration

### Security & Middleware

**Cache Protection**: Backend includes middleware preventing sensitive data caching in browsers/proxies
**Authentication**: JWT-based with Supabase integration
**User ID Validation**: All data access requires user ID checks - verify this for any `data/*.py` changes

### Development Workflow

**GitHub Actions**: Multiple CI/CD workflows in `.github/workflows/`

- `platform-backend-ci.yml` - Backend testing and validation
- `platform-frontend-ci.yml` - Frontend testing and validation
- `platform-fullstack-ci.yml` - End-to-end integration tests

**Pre-commit Hooks**: Run linting and formatting checks
**Conventional Commits**: Use format `type(scope): description` (e.g., `feat(backend): add API`)

### Key Source Files

**Backend Entry Points:**

- `backend/backend/server/server.py` - FastAPI application setup
- `backend/backend/data/` - Database models and user management
- `backend/blocks/` - Agent execution blocks and logic

**Frontend Entry Points:**

- `frontend/src/app/layout.tsx` - Root application layout
- `frontend/src/app/page.tsx` - Home page
- `frontend/src/lib/supabase/` - Authentication and database client

**Protected Routes**: Update `frontend/lib/supabase/middleware.ts` when adding protected routes

### Agent Block System

Agents are built using a visual block-based system where each block performs a single action. Blocks are defined in `backend/blocks/` and must include:

- Block definition with input/output schemas
- Execution logic with proper error handling
- Tests validating functionality

### Database & ORM

**Prisma ORM** with a PostgreSQL backend, including pgvector for embeddings:

- Schema in `schema.prisma`
- Migrations in `backend/migrations/`
- Always run `prisma migrate dev` and `prisma generate` after schema changes
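For orientation, here is a minimal sketch of what a query through the generated Prisma Python client typically looks like; the `user` model accessor and its fields are assumptions for illustration and should be checked against the actual `schema.prisma`.

```python
# Hypothetical sketch of prisma-client-py usage; the `user` model and its fields
# are assumed for illustration, not taken from the real schema.prisma.
import asyncio

from prisma import Prisma


async def main() -> None:
    db = Prisma()
    await db.connect()  # uses DATABASE_URL from the environment
    user = await db.user.find_unique(where={"id": "some-user-id"})
    print(user)
    await db.disconnect()


asyncio.run(main())
```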
## Environment Configuration

### Configuration Files Priority Order

1. **Backend**: `/backend/.env.default` → `/backend/.env` (user overrides)
2. **Frontend**: `/frontend/.env.default` → `/frontend/.env` (user overrides)
3. **Platform**: `/.env.default` (Supabase/shared) → `/.env` (user overrides)
4. Docker Compose `environment:` sections override file-based config
5. Shell environment variables have highest precedence
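As a rough illustration of that precedence (not the platform's actual config loader), python-dotenv can be layered so that already-set values are never overwritten:

```python
# Illustrative sketch only: mirrors the documented precedence
# (shell/Docker env > .env > .env.default); not the platform's real loader.
import os

from dotenv import load_dotenv

# load_dotenv() without override=True never replaces values that are already set,
# so loading higher-priority sources first preserves the precedence order.
load_dotenv(".env")          # user overrides (lose only to real environment variables)
load_dotenv(".env.default")  # shipped defaults fill in anything still missing

print(os.getenv("DATABASE_URL"))
```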
### Docker Environment Setup

- All services use hardcoded defaults (no `${VARIABLE}` substitutions)
- The `env_file` directive loads variables INTO containers at runtime
- Backend/Frontend services use YAML anchors for consistent configuration
- Copy `.env.default` files to `.env` for local development customization

## Advanced Development Patterns

### Adding New Blocks

1. Create file in `/backend/backend/blocks/`
2. Inherit from `Block` base class with input/output schemas
3. Implement `run` method with proper error handling
4. Generate block UUID using `uuid.uuid4()`
5. Register in block registry
6. Write tests alongside block implementation
7. Consider how inputs/outputs connect with other blocks in graph editor
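A rough sketch of the shape a new block usually takes follows; the import path, constructor arguments, and yield-based output protocol below are assumptions, so copy the exact conventions from an existing block in `backend/blocks/` rather than from this example.

```python
# Rough sketch only: verify the base-class import path, constructor arguments,
# and output protocol against an existing block in backend/blocks/.
from backend.data.block import Block, BlockOutput, BlockSchema  # assumed import path


class WordCountBlock(Block):
    class Input(BlockSchema):
        text: str

    class Output(BlockSchema):
        word_count: int

    def __init__(self):
        super().__init__(
            id="00000000-0000-0000-0000-000000000000",  # generate once with uuid.uuid4()
            description="Counts the words in the input text.",
            input_schema=WordCountBlock.Input,
            output_schema=WordCountBlock.Output,
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        # Each yielded pair is (output_name, value) and must match the Output schema.
        yield "word_count", len(input_data.text.split())
```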
### API Development

1. Update routes in `/backend/backend/server/routers/`
2. Add/update Pydantic models in same directory
3. Write tests alongside route files
4. For `data/*.py` changes, validate user ID checks
5. Run `poetry run test` to verify changes
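A hedged sketch of that pattern (router plus Pydantic model plus a user-scoped lookup); the auth dependency and data-layer helper here are placeholders, not the platform's real names.

```python
# Sketch of the route + Pydantic model + user ID check pattern; the auth dependency
# and data-layer helper are placeholders, not the platform's actual implementations.
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel

router = APIRouter()


class GraphSummary(BaseModel):
    id: str
    name: str


def get_current_user_id() -> str:
    """Placeholder for the real JWT/Supabase auth dependency."""
    raise NotImplementedError


async def get_graph_for_user(graph_id: str, user_id: str) -> Optional[GraphSummary]:
    """Placeholder for a data-layer call that is always scoped by user_id."""
    return None


@router.get("/graphs/{graph_id}", response_model=GraphSummary)
async def get_graph(graph_id: str, user_id: str = Depends(get_current_user_id)):
    graph = await get_graph_for_user(graph_id, user_id=user_id)
    if graph is None:
        # Scoping the query by user_id is the user ID check the guide calls for.
        raise HTTPException(status_code=404, detail="Graph not found")
    return graph
```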
### Frontend Development

**📖 Complete Frontend Guide**: See `autogpt_platform/frontend/CONTRIBUTING.md` and `autogpt_platform/frontend/.cursorrules` for comprehensive patterns and conventions.

**Quick Reference:**

**Component Structure:**

- Separate render logic from data/behavior
- Structure: `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
- Exception: Small components (3-4 lines of logic) can be inline
- Render-only components can be direct files without folders

**Data Fetching:**

- Use generated API hooks from `@/app/api/__generated__/endpoints/`
- Generated via Orval from backend OpenAPI spec
- Pattern: `use{Method}{Version}{OperationName}`
- Example: `useGetV2ListLibraryAgents`
- Regenerate with: `pnpm generate:api`
- **Never** use deprecated `BackendAPI` or `src/lib/autogpt-server-api/*`

**Code Conventions:**

- Use function declarations for components and handlers (not arrow functions)
- Only arrow functions for small inline lambdas (map, filter, etc.)
- Components: `PascalCase`, Hooks: `camelCase` with `use` prefix
- No barrel files or `index.ts` re-exports
- Minimal comments (code should be self-documenting)

**Styling:**

- Use Tailwind CSS utilities only
- Use design system components from `src/components/` (atoms, molecules, organisms)
- Never use `src/components/__legacy__/*`
- Only use Phosphor Icons (`@phosphor-icons/react`)
- Prefer design tokens over hardcoded values

**Error Handling:**

- Render errors: Use `<ErrorCard />` component
- Mutation errors: Display with toast notifications
- Manual exceptions: Use `Sentry.captureException()`
- Global error boundaries already configured

**Testing:**

- Add/update Storybook stories for UI components (`pnpm storybook`)
- Run Playwright E2E tests with `pnpm test`
- Verify in Chromatic after PR

**Architecture:**

- Default to client components (`"use client"`)
- Server components only for SEO or extreme TTFB needs
- Use React Query for server state (via generated hooks)
- Co-locate UI state in components/hooks

### Security Guidelines

**Cache Protection Middleware** (`/backend/backend/server/middleware/security.py`):

- Default: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses allow-list approach for cacheable paths (static assets, health checks, public pages)
- Prevents sensitive data caching in browsers/proxies
- Add new cacheable endpoints to `CACHEABLE_PATHS`
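A simplified sketch of what such an allow-list middleware can look like; this is not the actual `security.py` implementation, and the paths in `CACHEABLE_PATHS` below are made up.

```python
# Simplified sketch of an allow-list cache-control middleware; not the actual
# backend/server/middleware/security.py implementation.
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware

CACHEABLE_PATHS = ("/health", "/static")  # illustrative allow list


class CacheControlMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        # Default-deny: every path outside the allow list gets a no-store policy.
        if not request.url.path.startswith(CACHEABLE_PATHS):
            response.headers["Cache-Control"] = (
                "no-store, no-cache, must-revalidate, private"
            )
        return response


app = FastAPI()
app.add_middleware(CacheControlMiddleware)
```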
### CI/CD Alignment

The repository has comprehensive CI workflows that test:

- **Backend**: Python 3.11-3.13, services (Redis/RabbitMQ/ClamAV), Prisma migrations, Poetry lock validation
- **Frontend**: Node.js 21, pnpm, Playwright with Docker Compose stack, API schema validation
- **Integration**: Full-stack type checking and E2E testing

Match these patterns when developing locally - the Copilot setup environment mirrors these CI configurations.

## Collaboration with Other AI Assistants

This repository is actively developed with assistance from Claude (via CLAUDE.md files). When working on this codebase:

- Check for existing CLAUDE.md files that provide additional context
- Follow established patterns and conventions already in the codebase
- Maintain consistency with existing code style and architecture
- Consider that changes may be reviewed and extended by both human developers and AI assistants

## Trust These Instructions

These instructions are comprehensive and tested. Only perform additional searches if:

1. Information here is incomplete for your specific task
2. You encounter errors not covered by the workarounds
3. You need to understand implementation details not covered above

For detailed platform development patterns, refer to `autogpt_platform/CLAUDE.md` and `AGENTS.md` in the repository root.
.github/workflows/claude-ci-failure-auto-fix.yml (vendored, new file, 97 lines)
@@ -0,0 +1,97 @@
name: Auto Fix CI Failures

on:
  workflow_run:
    workflows: ["CI"]
    types:
      - completed

permissions:
  contents: write
  pull-requests: write
  actions: read
  issues: write
  id-token: write # Required for OIDC token exchange

jobs:
  auto-fix:
    if: |
      github.event.workflow_run.conclusion == 'failure' &&
      github.event.workflow_run.pull_requests[0] &&
      !startsWith(github.event.workflow_run.head_branch, 'claude-auto-fix-ci-')
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event.workflow_run.head_branch }}
          fetch-depth: 0
          token: ${{ secrets.GITHUB_TOKEN }}

      - name: Setup git identity
        run: |
          git config --global user.email "claude[bot]@users.noreply.github.com"
          git config --global user.name "claude[bot]"

      - name: Create fix branch
        id: branch
        run: |
          BRANCH_NAME="claude-auto-fix-ci-${{ github.event.workflow_run.head_branch }}-${{ github.run_id }}"
          git checkout -b "$BRANCH_NAME"
          echo "branch_name=$BRANCH_NAME" >> $GITHUB_OUTPUT

      - name: Get CI failure details
        id: failure_details
        uses: actions/github-script@v7
        with:
          script: |
            const run = await github.rest.actions.getWorkflowRun({
              owner: context.repo.owner,
              repo: context.repo.repo,
              run_id: ${{ github.event.workflow_run.id }}
            });

            const jobs = await github.rest.actions.listJobsForWorkflowRun({
              owner: context.repo.owner,
              repo: context.repo.repo,
              run_id: ${{ github.event.workflow_run.id }}
            });

            const failedJobs = jobs.data.jobs.filter(job => job.conclusion === 'failure');

            let errorLogs = [];
            for (const job of failedJobs) {
              const logs = await github.rest.actions.downloadJobLogsForWorkflowRun({
                owner: context.repo.owner,
                repo: context.repo.repo,
                job_id: job.id
              });
              errorLogs.push({
                jobName: job.name,
                logs: logs.data
              });
            }

            return {
              runUrl: run.data.html_url,
              failedJobs: failedJobs.map(j => j.name),
              errorLogs: errorLogs
            };

      - name: Fix CI failures with Claude
        id: claude
        uses: anthropics/claude-code-action@v1
        with:
          prompt: |
            /fix-ci
            Failed CI Run: ${{ fromJSON(steps.failure_details.outputs.result).runUrl }}
            Failed Jobs: ${{ join(fromJSON(steps.failure_details.outputs.result).failedJobs, ', ') }}
            PR Number: ${{ github.event.workflow_run.pull_requests[0].number }}
            Branch Name: ${{ steps.branch.outputs.branch_name }}
            Base Branch: ${{ github.event.workflow_run.head_branch }}
            Repository: ${{ github.repository }}

            Error logs:
            ${{ toJSON(fromJSON(steps.failure_details.outputs.result).errorLogs) }}
          anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
          claude_args: "--allowedTools 'Edit,MultiEdit,Write,Read,Glob,Grep,LS,Bash(git:*),Bash(bun:*),Bash(npm:*),Bash(npx:*),Bash(gh:*)'"
.github/workflows/claude-dependabot.yml (vendored, new file, 379 lines)
@@ -0,0 +1,379 @@
|
||||
# Claude Dependabot PR Review Workflow
|
||||
#
|
||||
# This workflow automatically runs Claude analysis on Dependabot PRs to:
|
||||
# - Identify dependency changes and their versions
|
||||
# - Look up changelogs for updated packages
|
||||
# - Assess breaking changes and security impacts
|
||||
# - Provide actionable recommendations for the development team
|
||||
#
|
||||
# Triggered on: Dependabot PRs (opened, synchronize)
|
||||
# Requirements: ANTHROPIC_API_KEY secret must be configured
|
||||
|
||||
name: Claude Dependabot PR Review
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, synchronize]
|
||||
|
||||
jobs:
|
||||
dependabot-review:
|
||||
# Only run on Dependabot PRs
|
||||
if: github.actor == 'dependabot[bot]'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: read
|
||||
issues: read
|
||||
id-token: write
|
||||
actions: read # Required for CI access
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
# Backend Python/Poetry setup (mirrors platform-backend-ci.yml)
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11" # Use standard version matching CI
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
# Extract Poetry version from backend/poetry.lock (matches CI)
|
||||
cd autogpt_platform/backend
|
||||
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
echo "Found Poetry version ${HEAD_POETRY_VERSION} in backend/poetry.lock"
|
||||
|
||||
# Install Poetry
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
||||
|
||||
# Add Poetry to PATH
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Check poetry.lock
|
||||
working-directory: autogpt_platform/backend
|
||||
run: |
|
||||
poetry lock
|
||||
if ! git diff --quiet --ignore-matching-lines="^# " poetry.lock; then
|
||||
echo "Warning: poetry.lock not up to date, but continuing for setup"
|
||||
git checkout poetry.lock # Reset for clean setup
|
||||
fi
|
||||
|
||||
- name: Install Python dependencies
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry install
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set pnpm store directory
|
||||
run: |
|
||||
pnpm config set store-dir ~/.pnpm-store
|
||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache frontend dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install JavaScript dependencies
|
||||
working-directory: autogpt_platform/frontend
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
# Install Playwright browsers for frontend testing
|
||||
# NOTE: Disabled to save ~1 minute of setup time. Re-enable if Copilot needs browser automation (e.g., for MCP)
|
||||
# - name: Install Playwright browsers
|
||||
# working-directory: autogpt_platform/frontend
|
||||
# run: pnpm playwright install --with-deps chromium
|
||||
|
||||
# Docker setup for development environment
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Copy default environment files
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Copy default environment files for development
|
||||
cp .env.default .env
|
||||
cp backend/.env.default backend/.env
|
||||
cp frontend/.env.default frontend/.env
|
||||
|
||||
# Phase 1: Cache and load Docker images for faster setup
|
||||
- name: Set up Docker image cache
|
||||
id: docker-cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/docker-cache
|
||||
# Use a versioned key for cache invalidation when image list changes
|
||||
key: docker-images-v2-${{ runner.os }}-${{ hashFiles('.github/workflows/copilot-setup-steps.yml') }}
|
||||
restore-keys: |
|
||||
docker-images-v2-${{ runner.os }}-
|
||||
docker-images-v1-${{ runner.os }}-
|
||||
|
||||
- name: Load or pull Docker images
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
mkdir -p ~/docker-cache
|
||||
|
||||
# Define image list for easy maintenance
|
||||
IMAGES=(
|
||||
"redis:latest"
|
||||
"rabbitmq:management"
|
||||
"clamav/clamav-debian:latest"
|
||||
"busybox:latest"
|
||||
"kong:2.8.1"
|
||||
"supabase/gotrue:v2.170.0"
|
||||
"supabase/postgres:15.8.1.049"
|
||||
"supabase/postgres-meta:v0.86.1"
|
||||
"supabase/studio:20250224-d10db0f"
|
||||
)
|
||||
|
||||
# Check if any cached tar files exist (more reliable than cache-hit)
|
||||
if ls ~/docker-cache/*.tar 1> /dev/null 2>&1; then
|
||||
echo "Docker cache found, loading images in parallel..."
|
||||
for image in "${IMAGES[@]}"; do
|
||||
# Convert image name to filename (replace : and / with -)
|
||||
filename=$(echo "$image" | tr ':/' '--')
|
||||
if [ -f ~/docker-cache/${filename}.tar ]; then
|
||||
echo "Loading $image..."
|
||||
docker load -i ~/docker-cache/${filename}.tar || echo "Warning: Failed to load $image from cache" &
|
||||
fi
|
||||
done
|
||||
wait
|
||||
echo "All cached images loaded"
|
||||
else
|
||||
echo "No Docker cache found, pulling images in parallel..."
|
||||
# Pull all images in parallel
|
||||
for image in "${IMAGES[@]}"; do
|
||||
docker pull "$image" &
|
||||
done
|
||||
wait
|
||||
|
||||
# Only save cache on main branches (not PRs) to avoid cache pollution
|
||||
if [[ "${{ github.ref }}" == "refs/heads/master" ]] || [[ "${{ github.ref }}" == "refs/heads/dev" ]]; then
|
||||
echo "Saving Docker images to cache in parallel..."
|
||||
for image in "${IMAGES[@]}"; do
|
||||
# Convert image name to filename (replace : and / with -)
|
||||
filename=$(echo "$image" | tr ':/' '--')
|
||||
echo "Saving $image..."
|
||||
docker save -o ~/docker-cache/${filename}.tar "$image" || echo "Warning: Failed to save $image" &
|
||||
done
|
||||
wait
|
||||
echo "Docker image cache saved"
|
||||
else
|
||||
echo "Skipping cache save for PR/feature branch"
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "Docker images ready for use"
|
||||
|
||||
# Phase 2: Build migrate service with GitHub Actions cache
|
||||
- name: Build migrate Docker image with cache
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Build the migrate image with buildx for GHA caching
|
||||
docker buildx build \
|
||||
--cache-from type=gha \
|
||||
--cache-to type=gha,mode=max \
|
||||
--target migrate \
|
||||
--tag autogpt_platform-migrate:latest \
|
||||
--load \
|
||||
-f backend/Dockerfile \
|
||||
..
|
||||
|
||||
# Start services using pre-built images
|
||||
- name: Start Docker services for development
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Start essential services (migrate image already built with correct tag)
|
||||
docker compose --profile local up deps --no-build --detach
|
||||
echo "Waiting for services to be ready..."
|
||||
|
||||
# Wait for database to be ready
|
||||
echo "Checking database readiness..."
|
||||
timeout 30 sh -c 'until docker compose exec -T db pg_isready -U postgres 2>/dev/null; do
|
||||
echo " Waiting for database..."
|
||||
sleep 2
|
||||
done' && echo "✅ Database is ready" || echo "⚠️ Database ready check timeout after 30s, continuing..."
|
||||
|
||||
# Check migrate service status
|
||||
echo "Checking migration status..."
|
||||
docker compose ps migrate || echo " Migrate service not visible in ps output"
|
||||
|
||||
# Wait for migrate service to complete
|
||||
echo "Waiting for migrations to complete..."
|
||||
timeout 30 bash -c '
|
||||
ATTEMPTS=0
|
||||
while [ $ATTEMPTS -lt 15 ]; do
|
||||
ATTEMPTS=$((ATTEMPTS + 1))
|
||||
|
||||
# Check using docker directly (more reliable than docker compose ps)
|
||||
CONTAINER_STATUS=$(docker ps -a --filter "label=com.docker.compose.service=migrate" --format "{{.Status}}" | head -1)
|
||||
|
||||
if [ -z "$CONTAINER_STATUS" ]; then
|
||||
echo " Attempt $ATTEMPTS: Migrate container not found yet..."
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Exited (0)"; then
|
||||
echo "✅ Migrations completed successfully"
|
||||
docker compose logs migrate --tail=5 2>/dev/null || true
|
||||
exit 0
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Exited ([1-9]"; then
|
||||
EXIT_CODE=$(echo "$CONTAINER_STATUS" | grep -oE "Exited \([0-9]+\)" | grep -oE "[0-9]+")
|
||||
echo "❌ Migrations failed with exit code: $EXIT_CODE"
|
||||
echo "Migration logs:"
|
||||
docker compose logs migrate --tail=20 2>/dev/null || true
|
||||
exit 1
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Up"; then
|
||||
echo " Attempt $ATTEMPTS: Migrate container is running... ($CONTAINER_STATUS)"
|
||||
else
|
||||
echo " Attempt $ATTEMPTS: Migrate container status: $CONTAINER_STATUS"
|
||||
fi
|
||||
|
||||
sleep 2
|
||||
done
|
||||
|
||||
echo "⚠️ Timeout: Could not determine migration status after 30 seconds"
|
||||
echo "Final container check:"
|
||||
docker ps -a --filter "label=com.docker.compose.service=migrate" || true
|
||||
echo "Migration logs (if available):"
|
||||
docker compose logs migrate --tail=10 2>/dev/null || echo " No logs available"
|
||||
' || echo "⚠️ Migration check completed with warnings, continuing..."
|
||||
|
||||
# Brief wait for other services to stabilize
|
||||
echo "Waiting 5 seconds for other services to stabilize..."
|
||||
sleep 5
|
||||
|
||||
# Verify installations and provide environment info
|
||||
- name: Verify setup and show environment info
|
||||
run: |
|
||||
echo "=== Python Setup ==="
|
||||
python --version
|
||||
poetry --version
|
||||
|
||||
echo "=== Node.js Setup ==="
|
||||
node --version
|
||||
pnpm --version
|
||||
|
||||
echo "=== Additional Tools ==="
|
||||
docker --version
|
||||
docker compose version
|
||||
gh --version || true
|
||||
|
||||
echo "=== Services Status ==="
|
||||
cd autogpt_platform
|
||||
docker compose ps || true
|
||||
|
||||
echo "=== Backend Dependencies ==="
|
||||
cd backend
|
||||
poetry show | head -10 || true
|
||||
|
||||
echo "=== Frontend Dependencies ==="
|
||||
cd ../frontend
|
||||
pnpm list --depth=0 | head -10 || true
|
||||
|
||||
echo "=== Environment Files ==="
|
||||
ls -la ../.env* || true
|
||||
ls -la .env* || true
|
||||
ls -la ../backend/.env* || true
|
||||
|
||||
echo "✅ AutoGPT Platform development environment setup complete!"
|
||||
echo "🚀 Ready for development with Docker services running"
|
||||
echo "📝 Backend server: poetry run serve (port 8000)"
|
||||
echo "🌐 Frontend server: pnpm dev (port 3000)"
|
||||
|
||||
|
||||
- name: Run Claude Dependabot Analysis
|
||||
id: claude_review
|
||||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
claude_args: |
|
||||
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
|
||||
prompt: |
|
||||
You are Claude, an AI assistant specialized in reviewing Dependabot dependency update PRs.
|
||||
|
||||
Your primary tasks are:
|
||||
1. **Analyze the dependency changes** in this Dependabot PR
|
||||
2. **Look up changelogs** for all updated dependencies to understand what changed
|
||||
3. **Identify breaking changes** and assess potential impact on the AutoGPT codebase
|
||||
4. **Provide actionable recommendations** for the development team
|
||||
|
||||
## Analysis Process:
|
||||
|
||||
1. **Identify Changed Dependencies**:
|
||||
- Use git diff to see what dependencies were updated
|
||||
- Parse package.json, poetry.lock, requirements files, etc.
|
||||
- List all package versions: old → new
|
||||
|
||||
2. **Changelog Research**:
|
||||
- For each updated dependency, look up its changelog/release notes
|
||||
- Use WebFetch to access GitHub releases, NPM package pages, PyPI project pages. The pr should also have some details
|
||||
- Focus on versions between the old and new versions
|
||||
- Identify: breaking changes, deprecations, security fixes, new features
|
||||
|
||||
3. **Breaking Change Assessment**:
|
||||
- Categorize changes: BREAKING, MAJOR, MINOR, PATCH, SECURITY
|
||||
- Assess impact on AutoGPT's usage patterns
|
||||
- Check if AutoGPT uses affected APIs/features
|
||||
- Look for migration guides or upgrade instructions
|
||||
|
||||
4. **Codebase Impact Analysis**:
|
||||
- Search the AutoGPT codebase for usage of changed APIs
|
||||
- Identify files that might be affected by breaking changes
|
||||
- Check test files for deprecated usage patterns
|
||||
- Look for configuration changes needed
|
||||
|
||||
## Output Format:
|
||||
|
||||
Provide a comprehensive review comment with:
|
||||
|
||||
### 🔍 Dependency Analysis Summary
|
||||
- List of updated packages with version changes
|
||||
- Overall risk assessment (LOW/MEDIUM/HIGH)
|
||||
|
||||
### 📋 Detailed Changelog Review
|
||||
For each updated dependency:
|
||||
- **Package**: name (old_version → new_version)
|
||||
- **Changes**: Summary of key changes
|
||||
- **Breaking Changes**: List any breaking changes
|
||||
- **Security Fixes**: Note security improvements
|
||||
- **Migration Notes**: Any upgrade steps needed
|
||||
|
||||
### ⚠️ Impact Assessment
|
||||
- **Breaking Changes Found**: Yes/No with details
|
||||
- **Affected Files**: List AutoGPT files that may need updates
|
||||
- **Test Impact**: Any tests that may need updating
|
||||
- **Configuration Changes**: Required config updates
|
||||
|
||||
### 🛠️ Recommendations
|
||||
- **Action Required**: What the team should do
|
||||
- **Testing Focus**: Areas to test thoroughly
|
||||
- **Follow-up Tasks**: Any additional work needed
|
||||
- **Merge Recommendation**: APPROVE/REVIEW_NEEDED/HOLD
|
||||
|
||||
### 📚 Useful Links
|
||||
- Links to relevant changelogs, migration guides, documentation
|
||||
|
||||
Be thorough but concise. Focus on actionable insights that help the development team make informed decisions about the dependency updates.
|
||||
.github/workflows/claude.yml (vendored, 284 lines changed)
@@ -30,18 +30,296 @@ jobs:
|
||||
github.event.issue.author_association == 'COLLABORATOR'
|
||||
)
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
contents: write
|
||||
pull-requests: read
|
||||
issues: read
|
||||
id-token: write
|
||||
actions: read # Required for CI access
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
# Backend Python/Poetry setup (mirrors platform-backend-ci.yml)
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11" # Use standard version matching CI
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
# Extract Poetry version from backend/poetry.lock (matches CI)
|
||||
cd autogpt_platform/backend
|
||||
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
echo "Found Poetry version ${HEAD_POETRY_VERSION} in backend/poetry.lock"
|
||||
|
||||
# Install Poetry
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
||||
|
||||
# Add Poetry to PATH
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Check poetry.lock
|
||||
working-directory: autogpt_platform/backend
|
||||
run: |
|
||||
poetry lock
|
||||
if ! git diff --quiet --ignore-matching-lines="^# " poetry.lock; then
|
||||
echo "Warning: poetry.lock not up to date, but continuing for setup"
|
||||
git checkout poetry.lock # Reset for clean setup
|
||||
fi
|
||||
|
||||
- name: Install Python dependencies
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry install
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set pnpm store directory
|
||||
run: |
|
||||
pnpm config set store-dir ~/.pnpm-store
|
||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache frontend dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install JavaScript dependencies
|
||||
working-directory: autogpt_platform/frontend
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
# Install Playwright browsers for frontend testing
|
||||
# NOTE: Disabled to save ~1 minute of setup time. Re-enable if Copilot needs browser automation (e.g., for MCP)
|
||||
# - name: Install Playwright browsers
|
||||
# working-directory: autogpt_platform/frontend
|
||||
# run: pnpm playwright install --with-deps chromium
|
||||
|
||||
# Docker setup for development environment
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Copy default environment files
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Copy default environment files for development
|
||||
cp .env.default .env
|
||||
cp backend/.env.default backend/.env
|
||||
cp frontend/.env.default frontend/.env
|
||||
|
||||
# Phase 1: Cache and load Docker images for faster setup
|
||||
- name: Set up Docker image cache
|
||||
id: docker-cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/docker-cache
|
||||
# Use a versioned key for cache invalidation when image list changes
|
||||
key: docker-images-v2-${{ runner.os }}-${{ hashFiles('.github/workflows/copilot-setup-steps.yml') }}
|
||||
restore-keys: |
|
||||
docker-images-v2-${{ runner.os }}-
|
||||
docker-images-v1-${{ runner.os }}-
|
||||
|
||||
- name: Load or pull Docker images
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
mkdir -p ~/docker-cache
|
||||
|
||||
# Define image list for easy maintenance
|
||||
IMAGES=(
|
||||
"redis:latest"
|
||||
"rabbitmq:management"
|
||||
"clamav/clamav-debian:latest"
|
||||
"busybox:latest"
|
||||
"kong:2.8.1"
|
||||
"supabase/gotrue:v2.170.0"
|
||||
"supabase/postgres:15.8.1.049"
|
||||
"supabase/postgres-meta:v0.86.1"
|
||||
"supabase/studio:20250224-d10db0f"
|
||||
)
|
||||
|
||||
# Check if any cached tar files exist (more reliable than cache-hit)
|
||||
if ls ~/docker-cache/*.tar 1> /dev/null 2>&1; then
|
||||
echo "Docker cache found, loading images in parallel..."
|
||||
for image in "${IMAGES[@]}"; do
|
||||
# Convert image name to filename (replace : and / with -)
|
||||
filename=$(echo "$image" | tr ':/' '--')
|
||||
if [ -f ~/docker-cache/${filename}.tar ]; then
|
||||
echo "Loading $image..."
|
||||
docker load -i ~/docker-cache/${filename}.tar || echo "Warning: Failed to load $image from cache" &
|
||||
fi
|
||||
done
|
||||
wait
|
||||
echo "All cached images loaded"
|
||||
else
|
||||
echo "No Docker cache found, pulling images in parallel..."
|
||||
# Pull all images in parallel
|
||||
for image in "${IMAGES[@]}"; do
|
||||
docker pull "$image" &
|
||||
done
|
||||
wait
|
||||
|
||||
# Only save cache on main branches (not PRs) to avoid cache pollution
|
||||
if [[ "${{ github.ref }}" == "refs/heads/master" ]] || [[ "${{ github.ref }}" == "refs/heads/dev" ]]; then
|
||||
echo "Saving Docker images to cache in parallel..."
|
||||
for image in "${IMAGES[@]}"; do
|
||||
# Convert image name to filename (replace : and / with -)
|
||||
filename=$(echo "$image" | tr ':/' '--')
|
||||
echo "Saving $image..."
|
||||
docker save -o ~/docker-cache/${filename}.tar "$image" || echo "Warning: Failed to save $image" &
|
||||
done
|
||||
wait
|
||||
echo "Docker image cache saved"
|
||||
else
|
||||
echo "Skipping cache save for PR/feature branch"
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "Docker images ready for use"
|
||||
|
||||
# Phase 2: Build migrate service with GitHub Actions cache
|
||||
- name: Build migrate Docker image with cache
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Build the migrate image with buildx for GHA caching
|
||||
docker buildx build \
|
||||
--cache-from type=gha \
|
||||
--cache-to type=gha,mode=max \
|
||||
--target migrate \
|
||||
--tag autogpt_platform-migrate:latest \
|
||||
--load \
|
||||
-f backend/Dockerfile \
|
||||
..
|
||||
|
||||
# Start services using pre-built images
|
||||
- name: Start Docker services for development
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Start essential services (migrate image already built with correct tag)
|
||||
docker compose --profile local up deps --no-build --detach
|
||||
echo "Waiting for services to be ready..."
|
||||
|
||||
# Wait for database to be ready
|
||||
echo "Checking database readiness..."
|
||||
timeout 30 sh -c 'until docker compose exec -T db pg_isready -U postgres 2>/dev/null; do
|
||||
echo " Waiting for database..."
|
||||
sleep 2
|
||||
done' && echo "✅ Database is ready" || echo "⚠️ Database ready check timeout after 30s, continuing..."
|
||||
|
||||
# Check migrate service status
|
||||
echo "Checking migration status..."
|
||||
docker compose ps migrate || echo " Migrate service not visible in ps output"
|
||||
|
||||
# Wait for migrate service to complete
|
||||
echo "Waiting for migrations to complete..."
|
||||
timeout 30 bash -c '
|
||||
ATTEMPTS=0
|
||||
while [ $ATTEMPTS -lt 15 ]; do
|
||||
ATTEMPTS=$((ATTEMPTS + 1))
|
||||
|
||||
# Check using docker directly (more reliable than docker compose ps)
|
||||
CONTAINER_STATUS=$(docker ps -a --filter "label=com.docker.compose.service=migrate" --format "{{.Status}}" | head -1)
|
||||
|
||||
if [ -z "$CONTAINER_STATUS" ]; then
|
||||
echo " Attempt $ATTEMPTS: Migrate container not found yet..."
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Exited (0)"; then
|
||||
echo "✅ Migrations completed successfully"
|
||||
docker compose logs migrate --tail=5 2>/dev/null || true
|
||||
exit 0
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Exited ([1-9]"; then
|
||||
EXIT_CODE=$(echo "$CONTAINER_STATUS" | grep -oE "Exited \([0-9]+\)" | grep -oE "[0-9]+")
|
||||
echo "❌ Migrations failed with exit code: $EXIT_CODE"
|
||||
echo "Migration logs:"
|
||||
docker compose logs migrate --tail=20 2>/dev/null || true
|
||||
exit 1
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Up"; then
|
||||
echo " Attempt $ATTEMPTS: Migrate container is running... ($CONTAINER_STATUS)"
|
||||
else
|
||||
echo " Attempt $ATTEMPTS: Migrate container status: $CONTAINER_STATUS"
|
||||
fi
|
||||
|
||||
sleep 2
|
||||
done
|
||||
|
||||
echo "⚠️ Timeout: Could not determine migration status after 30 seconds"
|
||||
echo "Final container check:"
|
||||
docker ps -a --filter "label=com.docker.compose.service=migrate" || true
|
||||
echo "Migration logs (if available):"
|
||||
docker compose logs migrate --tail=10 2>/dev/null || echo " No logs available"
|
||||
' || echo "⚠️ Migration check completed with warnings, continuing..."
|
||||
|
||||
# Brief wait for other services to stabilize
|
||||
echo "Waiting 5 seconds for other services to stabilize..."
|
||||
sleep 5
|
||||
|
||||
# Verify installations and provide environment info
|
||||
- name: Verify setup and show environment info
|
||||
run: |
|
||||
echo "=== Python Setup ==="
|
||||
python --version
|
||||
poetry --version
|
||||
|
||||
echo "=== Node.js Setup ==="
|
||||
node --version
|
||||
pnpm --version
|
||||
|
||||
echo "=== Additional Tools ==="
|
||||
docker --version
|
||||
docker compose version
|
||||
gh --version || true
|
||||
|
||||
echo "=== Services Status ==="
|
||||
cd autogpt_platform
|
||||
docker compose ps || true
|
||||
|
||||
echo "=== Backend Dependencies ==="
|
||||
cd backend
|
||||
poetry show | head -10 || true
|
||||
|
||||
echo "=== Frontend Dependencies ==="
|
||||
cd ../frontend
|
||||
pnpm list --depth=0 | head -10 || true
|
||||
|
||||
echo "=== Environment Files ==="
|
||||
ls -la ../.env* || true
|
||||
ls -la .env* || true
|
||||
ls -la ../backend/.env* || true
|
||||
|
||||
echo "✅ AutoGPT Platform development environment setup complete!"
|
||||
echo "🚀 Ready for development with Docker services running"
|
||||
echo "📝 Backend server: poetry run serve (port 8000)"
|
||||
echo "🌐 Frontend server: pnpm dev (port 3000)"
|
||||
|
||||
- name: Run Claude Code
|
||||
id: claude
|
||||
uses: anthropics/claude-code-action@beta
|
||||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
claude_args: |
|
||||
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*), Bash(gh pr edit:*)"
|
||||
--model opus
|
||||
additional_permissions: |
|
||||
actions: read
|
||||
|
||||
.github/workflows/copilot-setup-steps.yml (vendored, new file, 302 lines)
@@ -0,0 +1,302 @@
|
||||
name: "Copilot Setup Steps"
|
||||
|
||||
# Automatically run the setup steps when they are changed to allow for easy validation, and
|
||||
# allow manual testing through the repository's "Actions" tab
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
paths:
|
||||
- .github/workflows/copilot-setup-steps.yml
|
||||
pull_request:
|
||||
paths:
|
||||
- .github/workflows/copilot-setup-steps.yml
|
||||
|
||||
jobs:
|
||||
# The job MUST be called `copilot-setup-steps` or it will not be picked up by Copilot.
|
||||
copilot-setup-steps:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
|
||||
# Set the permissions to the lowest permissions possible needed for your steps.
|
||||
# Copilot will be given its own token for its operations.
|
||||
permissions:
|
||||
# If you want to clone the repository as part of your setup steps, for example to install dependencies, you'll need the `contents: read` permission. If you don't clone the repository in your setup steps, Copilot will do this for you automatically after the steps complete.
|
||||
contents: read
|
||||
|
||||
# You can define any steps you want, and they will run before the agent starts.
|
||||
# If you do not check out your code, Copilot will do this for you.
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
submodules: true
|
||||
|
||||
# Backend Python/Poetry setup (mirrors platform-backend-ci.yml)
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11" # Use standard version matching CI
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
# Extract Poetry version from backend/poetry.lock (matches CI)
|
||||
cd autogpt_platform/backend
|
||||
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
echo "Found Poetry version ${HEAD_POETRY_VERSION} in backend/poetry.lock"
|
||||
|
||||
# Install Poetry
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
||||
|
||||
# Add Poetry to PATH
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Check poetry.lock
|
||||
working-directory: autogpt_platform/backend
|
||||
run: |
|
||||
poetry lock
|
||||
if ! git diff --quiet --ignore-matching-lines="^# " poetry.lock; then
|
||||
echo "Warning: poetry.lock not up to date, but continuing for setup"
|
||||
git checkout poetry.lock # Reset for clean setup
|
||||
fi
|
||||
|
||||
- name: Install Python dependencies
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry install
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set pnpm store directory
|
||||
run: |
|
||||
pnpm config set store-dir ~/.pnpm-store
|
||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache frontend dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install JavaScript dependencies
|
||||
working-directory: autogpt_platform/frontend
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
# Install Playwright browsers for frontend testing
|
||||
# NOTE: Disabled to save ~1 minute of setup time. Re-enable if Copilot needs browser automation (e.g., for MCP)
|
||||
# - name: Install Playwright browsers
|
||||
# working-directory: autogpt_platform/frontend
|
||||
# run: pnpm playwright install --with-deps chromium
|
||||
|
||||
# Docker setup for development environment
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Copy default environment files
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Copy default environment files for development
|
||||
cp .env.default .env
|
||||
cp backend/.env.default backend/.env
|
||||
cp frontend/.env.default frontend/.env
|
||||
|
||||
# Phase 1: Cache and load Docker images for faster setup
|
||||
- name: Set up Docker image cache
|
||||
id: docker-cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/docker-cache
|
||||
# Use a versioned key for cache invalidation when image list changes
|
||||
key: docker-images-v2-${{ runner.os }}-${{ hashFiles('.github/workflows/copilot-setup-steps.yml') }}
|
||||
restore-keys: |
|
||||
docker-images-v2-${{ runner.os }}-
|
||||
docker-images-v1-${{ runner.os }}-
|
||||
|
||||
- name: Load or pull Docker images
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
mkdir -p ~/docker-cache
|
||||
|
||||
# Define image list for easy maintenance
|
||||
IMAGES=(
|
||||
"redis:latest"
|
||||
"rabbitmq:management"
|
||||
"clamav/clamav-debian:latest"
|
||||
"busybox:latest"
|
||||
"kong:2.8.1"
|
||||
"supabase/gotrue:v2.170.0"
|
||||
"supabase/postgres:15.8.1.049"
|
||||
"supabase/postgres-meta:v0.86.1"
|
||||
"supabase/studio:20250224-d10db0f"
|
||||
)
|
||||
|
||||
# Check if any cached tar files exist (more reliable than cache-hit)
|
||||
if ls ~/docker-cache/*.tar 1> /dev/null 2>&1; then
|
||||
echo "Docker cache found, loading images in parallel..."
|
||||
for image in "${IMAGES[@]}"; do
|
||||
# Convert image name to filename (replace : and / with -)
|
||||
filename=$(echo "$image" | tr ':/' '--')
|
||||
if [ -f ~/docker-cache/${filename}.tar ]; then
|
||||
echo "Loading $image..."
|
||||
docker load -i ~/docker-cache/${filename}.tar || echo "Warning: Failed to load $image from cache" &
|
||||
fi
|
||||
done
|
||||
wait
|
||||
echo "All cached images loaded"
|
||||
else
|
||||
echo "No Docker cache found, pulling images in parallel..."
|
||||
# Pull all images in parallel
|
||||
for image in "${IMAGES[@]}"; do
|
||||
docker pull "$image" &
|
||||
done
|
||||
wait
|
||||
|
||||
# Only save cache on main branches (not PRs) to avoid cache pollution
|
||||
if [[ "${{ github.ref }}" == "refs/heads/master" ]] || [[ "${{ github.ref }}" == "refs/heads/dev" ]]; then
|
||||
echo "Saving Docker images to cache in parallel..."
|
||||
for image in "${IMAGES[@]}"; do
|
||||
# Convert image name to filename (replace : and / with -)
|
||||
filename=$(echo "$image" | tr ':/' '--')
|
||||
echo "Saving $image..."
|
||||
docker save -o ~/docker-cache/${filename}.tar "$image" || echo "Warning: Failed to save $image" &
|
||||
done
|
||||
wait
|
||||
echo "Docker image cache saved"
|
||||
else
|
||||
echo "Skipping cache save for PR/feature branch"
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "Docker images ready for use"
|
||||
|
||||
# Phase 2: Build migrate service with GitHub Actions cache
|
||||
- name: Build migrate Docker image with cache
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Build the migrate image with buildx for GHA caching
|
||||
docker buildx build \
|
||||
--cache-from type=gha \
|
||||
--cache-to type=gha,mode=max \
|
||||
--target migrate \
|
||||
--tag autogpt_platform-migrate:latest \
|
||||
--load \
|
||||
-f backend/Dockerfile \
|
||||
..
|
||||
|
||||
# Start services using pre-built images
|
||||
- name: Start Docker services for development
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Start essential services (migrate image already built with correct tag)
|
||||
docker compose --profile local up deps --no-build --detach
|
||||
echo "Waiting for services to be ready..."
|
||||
|
||||
# Wait for database to be ready
|
||||
echo "Checking database readiness..."
|
||||
timeout 30 sh -c 'until docker compose exec -T db pg_isready -U postgres 2>/dev/null; do
|
||||
echo " Waiting for database..."
|
||||
sleep 2
|
||||
done' && echo "✅ Database is ready" || echo "⚠️ Database ready check timeout after 30s, continuing..."
|
||||
|
||||
# Check migrate service status
|
||||
echo "Checking migration status..."
|
||||
docker compose ps migrate || echo " Migrate service not visible in ps output"
|
||||
|
||||
# Wait for migrate service to complete
|
||||
echo "Waiting for migrations to complete..."
|
||||
timeout 30 bash -c '
|
||||
ATTEMPTS=0
|
||||
while [ $ATTEMPTS -lt 15 ]; do
|
||||
ATTEMPTS=$((ATTEMPTS + 1))
|
||||
|
||||
# Check using docker directly (more reliable than docker compose ps)
|
||||
CONTAINER_STATUS=$(docker ps -a --filter "label=com.docker.compose.service=migrate" --format "{{.Status}}" | head -1)
|
||||
|
||||
if [ -z "$CONTAINER_STATUS" ]; then
|
||||
echo " Attempt $ATTEMPTS: Migrate container not found yet..."
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Exited (0)"; then
|
||||
echo "✅ Migrations completed successfully"
|
||||
docker compose logs migrate --tail=5 2>/dev/null || true
|
||||
exit 0
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Exited ([1-9]"; then
|
||||
EXIT_CODE=$(echo "$CONTAINER_STATUS" | grep -oE "Exited \([0-9]+\)" | grep -oE "[0-9]+")
|
||||
echo "❌ Migrations failed with exit code: $EXIT_CODE"
|
||||
echo "Migration logs:"
|
||||
docker compose logs migrate --tail=20 2>/dev/null || true
|
||||
exit 1
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Up"; then
|
||||
echo " Attempt $ATTEMPTS: Migrate container is running... ($CONTAINER_STATUS)"
|
||||
else
|
||||
echo " Attempt $ATTEMPTS: Migrate container status: $CONTAINER_STATUS"
|
||||
fi
|
||||
|
||||
sleep 2
|
||||
done
|
||||
|
||||
echo "⚠️ Timeout: Could not determine migration status after 30 seconds"
|
||||
echo "Final container check:"
|
||||
docker ps -a --filter "label=com.docker.compose.service=migrate" || true
|
||||
echo "Migration logs (if available):"
|
||||
docker compose logs migrate --tail=10 2>/dev/null || echo " No logs available"
|
||||
' || echo "⚠️ Migration check completed with warnings, continuing..."
|
||||
|
||||
# Brief wait for other services to stabilize
|
||||
echo "Waiting 5 seconds for other services to stabilize..."
|
||||
sleep 5
|
||||
|
||||
# Verify installations and provide environment info
|
||||
- name: Verify setup and show environment info
|
||||
run: |
|
||||
echo "=== Python Setup ==="
|
||||
python --version
|
||||
poetry --version
|
||||
|
||||
echo "=== Node.js Setup ==="
|
||||
node --version
|
||||
pnpm --version
|
||||
|
||||
echo "=== Additional Tools ==="
|
||||
docker --version
|
||||
docker compose version
|
||||
gh --version || true
|
||||
|
||||
echo "=== Services Status ==="
|
||||
cd autogpt_platform
|
||||
docker compose ps || true
|
||||
|
||||
echo "=== Backend Dependencies ==="
|
||||
cd backend
|
||||
poetry show | head -10 || true
|
||||
|
||||
echo "=== Frontend Dependencies ==="
|
||||
cd ../frontend
|
||||
pnpm list --depth=0 | head -10 || true
|
||||
|
||||
echo "=== Environment Files ==="
|
||||
ls -la ../.env* || true
|
||||
ls -la .env* || true
|
||||
ls -la ../backend/.env* || true
|
||||
|
||||
echo "✅ AutoGPT Platform development environment setup complete!"
|
||||
echo "🚀 Ready for development with Docker services running"
|
||||
echo "📝 Backend server: poetry run serve (port 8000)"
|
||||
echo "🌐 Frontend server: pnpm dev (port 3000)"
|
||||
@@ -5,6 +5,13 @@ on:
|
||||
branches: [ dev ]
|
||||
paths:
|
||||
- 'autogpt_platform/**'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
git_ref:
|
||||
description: 'Git ref (branch/tag) of AutoGPT to deploy'
|
||||
required: true
|
||||
default: 'master'
|
||||
type: string
|
||||
|
||||
permissions:
|
||||
contents: 'read'
|
||||
@@ -19,6 +26,8 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.inputs.git_ref || github.ref_name }}
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
@@ -48,4 +57,4 @@ jobs:
|
||||
token: ${{ secrets.DEPLOY_TOKEN }}
|
||||
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
||||
event-type: build_deploy_dev
|
||||
client-payload: '{"ref": "${{ github.ref }}", "sha": "${{ github.sha }}", "repository": "${{ github.repository }}"}'
|
||||
client-payload: '{"ref": "${{ github.event.inputs.git_ref || github.ref }}", "repository": "${{ github.repository }}"}'
|
||||
|
||||
@@ -3,6 +3,7 @@ name: AutoGPT Platform - Deploy Prod Environment
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: 'read'
|
||||
@@ -17,6 +18,8 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.ref_name || 'master' }}
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
@@ -36,7 +39,7 @@ jobs:
|
||||
DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
|
||||
DIRECT_URL: ${{ secrets.BACKEND_DATABASE_URL }}
|
||||
|
||||
|
||||
|
||||
trigger:
|
||||
needs: migrate
|
||||
runs-on: ubuntu-latest
|
||||
@@ -47,4 +50,5 @@ jobs:
|
||||
token: ${{ secrets.DEPLOY_TOKEN }}
|
||||
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
||||
event-type: build_deploy_prod
|
||||
client-payload: '{"ref": "${{ github.ref }}", "sha": "${{ github.sha }}", "repository": "${{ github.repository }}"}'
|
||||
client-payload: |
|
||||
{"ref": "${{ github.ref_name || 'master' }}", "repository": "${{ github.repository }}"}
|
||||
.github/workflows/platform-backend-ci.yml (14 changes)
@@ -32,14 +32,12 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.11"]
|
||||
python-version: ["3.11", "3.12", "3.13"]
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
services:
|
||||
redis:
|
||||
image: bitnami/redis:6.2
|
||||
env:
|
||||
REDIS_PASSWORD: testpassword
|
||||
image: redis:latest
|
||||
ports:
|
||||
- 6379:6379
|
||||
rabbitmq:
|
||||
@@ -190,9 +188,9 @@ jobs:
|
||||
- name: Run pytest with coverage
|
||||
run: |
|
||||
if [[ "${{ runner.debug }}" == "1" ]]; then
|
||||
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG test
|
||||
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG
|
||||
else
|
||||
poetry run pytest -s -vv test
|
||||
poetry run pytest -s -vv
|
||||
fi
|
||||
if: success() || (failure() && steps.lint.outcome == 'failure')
|
||||
env:
|
||||
@@ -201,10 +199,10 @@ jobs:
|
||||
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
SUPABASE_URL: ${{ steps.supabase.outputs.API_URL }}
|
||||
SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
|
||||
SUPABASE_JWT_SECRET: ${{ steps.supabase.outputs.JWT_SECRET }}
|
||||
JWT_VERIFY_KEY: ${{ steps.supabase.outputs.JWT_SECRET }}
|
||||
REDIS_HOST: "localhost"
|
||||
REDIS_PORT: "6379"
|
||||
REDIS_PASSWORD: "testpassword"
|
||||
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
|
||||
|
||||
env:
|
||||
CI: true
|
||||
|
||||
.github/workflows/platform-frontend-ci.yml (199 changes)
@@ -18,66 +18,99 @@ defaults:
|
||||
working-directory: autogpt_platform/frontend
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
setup:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
cache-key: ${{ steps.cache-key.outputs.key }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
node-version: "22.18.0"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Generate cache key
|
||||
id: cache-key
|
||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ steps.cache-key.outputs.key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Run lint
|
||||
run: pnpm lint
|
||||
|
||||
type-check:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Generate API client
|
||||
run: pnpm generate:api-client
|
||||
|
||||
- name: Run tsc check
|
||||
run: pnpm type-check
|
||||
|
||||
chromatic:
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
# Only run on dev branch pushes or PRs targeting dev
|
||||
if: github.ref == 'refs/heads/dev' || github.base_ref == 'dev'
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
node-version: "22.18.0"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
@@ -88,13 +121,13 @@ jobs:
|
||||
onlyChanged: true
|
||||
workingDir: autogpt_platform/frontend
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
exitOnceUploaded: true
|
||||
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: big-boi
|
||||
needs: setup
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
browser: [chromium, webkit]
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -105,52 +138,98 @@ jobs:
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
node-version: "22.18.0"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Free Disk Space (Ubuntu)
|
||||
uses: jlumbroso/free-disk-space@main
|
||||
with:
|
||||
large-packages: false # slow
|
||||
docker-images: false # limited benefit
|
||||
|
||||
- name: Copy default supabase .env
|
||||
run: |
|
||||
cp ../.env.example ../.env
|
||||
cp ../.env.default ../.env
|
||||
|
||||
- name: Copy backend .env
|
||||
run: |
|
||||
cp ../backend/.env.example ../backend/.env
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Cache Docker layers
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: /tmp/.buildx-cache
|
||||
key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-buildx-frontend-test-
|
||||
|
||||
- name: Run docker compose
|
||||
run: |
|
||||
docker compose -f ../docker-compose.yml up -d
|
||||
NEXT_PUBLIC_PW_TEST=true docker compose -f ../docker-compose.yml up -d
|
||||
env:
|
||||
DOCKER_BUILDKIT: 1
|
||||
BUILDX_CACHE_FROM: type=local,src=/tmp/.buildx-cache
|
||||
BUILDX_CACHE_TO: type=local,dest=/tmp/.buildx-cache-new,mode=max
|
||||
|
||||
- name: Move cache
|
||||
run: |
|
||||
rm -rf /tmp/.buildx-cache
|
||||
if [ -d "/tmp/.buildx-cache-new" ]; then
|
||||
mv /tmp/.buildx-cache-new /tmp/.buildx-cache
|
||||
fi
|
||||
|
||||
- name: Wait for services to be ready
|
||||
run: |
|
||||
echo "Waiting for rest_server to be ready..."
|
||||
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
||||
echo "Waiting for database to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
|
||||
|
||||
- name: Create E2E test data
|
||||
run: |
|
||||
echo "Creating E2E test data..."
|
||||
# First try to run the script from inside the container
|
||||
if docker compose -f ../docker-compose.yml exec -T rest_server test -f /app/autogpt_platform/backend/test/e2e_test_data.py; then
|
||||
echo "✅ Found e2e_test_data.py in container, running it..."
|
||||
docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python backend/test/e2e_test_data.py" || {
|
||||
echo "❌ E2E test data creation failed!"
|
||||
docker compose -f ../docker-compose.yml logs --tail=50 rest_server
|
||||
exit 1
|
||||
}
|
||||
else
|
||||
echo "⚠️ e2e_test_data.py not found in container, copying and running..."
|
||||
# Copy the script into the container and run it
|
||||
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.yml ps -q rest_server):/tmp/e2e_test_data.py || {
|
||||
echo "❌ Failed to copy script to container"
|
||||
exit 1
|
||||
}
|
||||
docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
|
||||
echo "❌ E2E test data creation failed!"
|
||||
docker compose -f ../docker-compose.yml logs --tail=50 rest_server
|
||||
exit 1
|
||||
}
|
||||
fi
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Setup .env
|
||||
run: cp .env.example .env
|
||||
|
||||
- name: Build frontend
|
||||
run: pnpm build --turbo
|
||||
# uses Turbopack, much faster and safe enough for a test pipeline
|
||||
|
||||
- name: Install Browser '${{ matrix.browser }}'
|
||||
run: pnpm playwright install --with-deps ${{ matrix.browser }}
|
||||
- name: Install Browser 'chromium'
|
||||
run: pnpm playwright install --with-deps chromium
|
||||
|
||||
- name: Run Playwright tests
|
||||
run: pnpm test:no-build --project=${{ matrix.browser }}
|
||||
run: pnpm test:no-build
|
||||
|
||||
- name: Upload Playwright artifacts
|
||||
if: failure()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-report
|
||||
path: playwright-report
|
||||
|
||||
- name: Print Final Docker Compose logs
|
||||
if: always()
|
||||
run: docker compose -f ../docker-compose.yml logs
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: ${{ !cancelled() }}
|
||||
with:
|
||||
name: playwright-report-${{ matrix.browser }}
|
||||
path: playwright-report/
|
||||
retention-days: 30
|
||||
|
||||
.github/workflows/platform-fullstack-ci.yml (new file, 132 lines)
@@ -0,0 +1,132 @@
|
||||
name: AutoGPT Platform - Frontend CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [master, dev]
|
||||
paths:
|
||||
- ".github/workflows/platform-fullstack-ci.yml"
|
||||
- "autogpt_platform/**"
|
||||
pull_request:
|
||||
paths:
|
||||
- ".github/workflows/platform-fullstack-ci.yml"
|
||||
- "autogpt_platform/**"
|
||||
merge_group:
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: autogpt_platform/frontend
|
||||
|
||||
jobs:
|
||||
setup:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
cache-key: ${{ steps.cache-key.outputs.key }}
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Generate cache key
|
||||
id: cache-key
|
||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ steps.cache-key.outputs.key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
types:
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22.18.0"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Copy default supabase .env
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
|
||||
- name: Copy backend .env
|
||||
run: |
|
||||
cp ../backend/.env.default ../backend/.env
|
||||
|
||||
- name: Run docker compose
|
||||
run: |
|
||||
docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Setup .env
|
||||
run: cp .env.default .env
|
||||
|
||||
- name: Wait for services to be ready
|
||||
run: |
|
||||
echo "Waiting for rest_server to be ready..."
|
||||
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
||||
echo "Waiting for database to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
|
||||
|
||||
- name: Generate API queries
|
||||
run: pnpm generate:api:force
|
||||
|
||||
- name: Check for API schema changes
|
||||
run: |
|
||||
if ! git diff --exit-code src/app/api/openapi.json; then
|
||||
echo "❌ API schema changes detected in src/app/api/openapi.json"
|
||||
echo ""
|
||||
echo "The openapi.json file has been modified after running 'pnpm generate:api-all'."
|
||||
echo "This usually means changes have been made in the BE endpoints without updating the Frontend."
|
||||
echo "The API schema is now out of sync with the Front-end queries."
|
||||
echo ""
|
||||
echo "To fix this:"
|
||||
echo "1. Pull the backend 'docker compose pull && docker compose up -d --build --force-recreate'"
|
||||
echo "2. Run 'pnpm generate:api' locally"
|
||||
echo "3. Run 'pnpm types' locally"
|
||||
echo "4. Fix any TypeScript errors that may have been introduced"
|
||||
echo "5. Commit and push your changes"
|
||||
echo ""
|
||||
exit 1
|
||||
else
|
||||
echo "✅ No API schema changes detected"
|
||||
fi
|
||||
|
||||
- name: Run Typescript checks
|
||||
run: pnpm types
|
||||
.gitignore (7 changes)
@@ -5,6 +5,8 @@ classic/original_autogpt/*.json
|
||||
auto_gpt_workspace/*
|
||||
*.mpeg
|
||||
.env
|
||||
# Root .env files
|
||||
/.env
|
||||
azure.yaml
|
||||
.vscode
|
||||
.idea/*
|
||||
@@ -121,7 +123,6 @@ celerybeat.pid
|
||||
|
||||
# Environments
|
||||
.direnv/
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv*/
|
||||
@@ -177,6 +178,4 @@ autogpt_platform/backend/settings.py
|
||||
*.ign.*
|
||||
.test-contents
|
||||
.claude/settings.local.json
|
||||
|
||||
# Auto generated client
|
||||
autogpt_platform/frontend/src/api/__generated__
|
||||
/autogpt_platform/backend/logs
|
||||
|
||||
@@ -1,6 +1,3 @@
|
||||
[pr_reviewer]
|
||||
num_code_suggestions=0
|
||||
|
||||
[pr_code_suggestions]
|
||||
commitable_code_suggestions=false
|
||||
num_code_suggestions=0
|
||||
|
||||
@@ -235,7 +235,7 @@ repos:
|
||||
hooks:
|
||||
- id: tsc
|
||||
name: Typecheck - AutoGPT Platform - Frontend
|
||||
entry: bash -c 'cd autogpt_platform/frontend && pnpm type-check'
|
||||
entry: bash -c 'cd autogpt_platform/frontend && pnpm types'
|
||||
files: ^autogpt_platform/frontend/
|
||||
types: [file]
|
||||
language: system
|
||||
|
||||
.vscode/launch.json (6 changes)
@@ -6,7 +6,7 @@
|
||||
"type": "node-terminal",
|
||||
"request": "launch",
|
||||
"cwd": "${workspaceFolder}/autogpt_platform/frontend",
|
||||
"command": "yarn dev"
|
||||
"command": "pnpm dev"
|
||||
},
|
||||
{
|
||||
"name": "Frontend: Client Side",
|
||||
@@ -19,12 +19,12 @@
|
||||
"type": "node-terminal",
|
||||
|
||||
"request": "launch",
|
||||
"command": "yarn dev",
|
||||
"command": "pnpm dev",
|
||||
"cwd": "${workspaceFolder}/autogpt_platform/frontend",
|
||||
"serverReadyAction": {
|
||||
"pattern": "- Local:.+(https?://.+)",
|
||||
"uriFormat": "%s",
|
||||
"action": "debugWithEdge"
|
||||
"action": "debugWithChrome"
|
||||
}
|
||||
},
|
||||
{
|
||||
|
||||
LICENSE (195 changes)
@@ -1,6 +1,197 @@
|
||||
All portions of this repository are under one of two licenses. The majority of the AutoGPT repository is under the MIT License below. The autogpt_platform folder is under the
|
||||
Polyform Shield License.
|
||||
All portions of this repository are under one of two licenses.
|
||||
|
||||
- Everything inside the autogpt_platform folder is under the Polyform Shield License.
|
||||
- Everything outside the autogpt_platform folder is under the MIT License.
|
||||
|
||||
More info:
|
||||
|
||||
**Polyform Shield License:**
|
||||
Code and content within the `autogpt_platform` folder is licensed under the Polyform Shield License. This new project is our in-development platform for building, deploying, and managing agents.
|
||||
Read more about this effort here: https://agpt.co/blog/introducing-the-autogpt-platform
|
||||
|
||||
**MIT License:**
|
||||
All other portions of the AutoGPT repository (i.e., everything outside the `autogpt_platform` folder) are licensed under the MIT License. This includes:
|
||||
- The Original, stand-alone AutoGPT Agent
|
||||
- Forge: https://github.com/Significant-Gravitas/AutoGPT/tree/master/classic/forge
|
||||
- AG Benchmark: https://github.com/Significant-Gravitas/AutoGPT/tree/master/classic/benchmark
|
||||
- AutoGPT Classic GUI: https://github.com/Significant-Gravitas/AutoGPT/tree/master/classic/frontend.
|
||||
|
||||
We also publish additional work under the MIT Licence in other repositories, such as GravitasML (https://github.com/Significant-Gravitas/gravitasml) which is developed for and used in the AutoGPT Platform, and our [Code Ability](https://github.com/Significant-Gravitas/AutoGPT-Code-Ability) project.
|
||||
|
||||
Both licences are available to read below:
|
||||
|
||||
=====================================================
|
||||
-----------------------------------------------------
|
||||
=====================================================
|
||||
|
||||
# PolyForm Shield License 1.0.0
|
||||
|
||||
<https://polyformproject.org/licenses/shield/1.0.0>
|
||||
|
||||
## Acceptance
|
||||
|
||||
In order to get any license under these terms, you must agree
|
||||
to them as both strict obligations and conditions to all
|
||||
your licenses.
|
||||
|
||||
## Copyright License
|
||||
|
||||
The licensor grants you a copyright license for the
|
||||
software to do everything you might do with the software
|
||||
that would otherwise infringe the licensor's copyright
|
||||
in it for any permitted purpose. However, you may
|
||||
only distribute the software according to [Distribution
|
||||
License](#distribution-license) and make changes or new works
|
||||
based on the software according to [Changes and New Works
|
||||
License](#changes-and-new-works-license).
|
||||
|
||||
## Distribution License
|
||||
|
||||
The licensor grants you an additional copyright license
|
||||
to distribute copies of the software. Your license
|
||||
to distribute covers distributing the software with
|
||||
changes and new works permitted by [Changes and New Works
|
||||
License](#changes-and-new-works-license).
|
||||
|
||||
## Notices
|
||||
|
||||
You must ensure that anyone who gets a copy of any part of
|
||||
the software from you also gets a copy of these terms or the
|
||||
URL for them above, as well as copies of any plain-text lines
|
||||
beginning with `Required Notice:` that the licensor provided
|
||||
with the software. For example:
|
||||
|
||||
> Required Notice: Copyright Yoyodyne, Inc. (http://example.com)
|
||||
|
||||
## Changes and New Works License
|
||||
|
||||
The licensor grants you an additional copyright license to
|
||||
make changes and new works based on the software for any
|
||||
permitted purpose.
|
||||
|
||||
## Patent License
|
||||
|
||||
The licensor grants you a patent license for the software that
|
||||
covers patent claims the licensor can license, or becomes able
|
||||
to license, that you would infringe by using the software.
|
||||
|
||||
## Noncompete
|
||||
|
||||
Any purpose is a permitted purpose, except for providing any
|
||||
product that competes with the software or any product the
|
||||
licensor or any of its affiliates provides using the software.
|
||||
|
||||
## Competition
|
||||
|
||||
Goods and services compete even when they provide functionality
|
||||
through different kinds of interfaces or for different technical
|
||||
platforms. Applications can compete with services, libraries
|
||||
with plugins, frameworks with development tools, and so on,
|
||||
even if they're written in different programming languages
|
||||
or for different computer architectures. Goods and services
|
||||
compete even when provided free of charge. If you market a
|
||||
product as a practical substitute for the software or another
|
||||
product, it definitely competes.
|
||||
|
||||
## New Products
|
||||
|
||||
If you are using the software to provide a product that does
|
||||
not compete, but the licensor or any of its affiliates brings
|
||||
your product into competition by providing a new version of
|
||||
the software or another product using the software, you may
|
||||
continue using versions of the software available under these
|
||||
terms beforehand to provide your competing product, but not
|
||||
any later versions.
|
||||
|
||||
## Discontinued Products
|
||||
|
||||
You may begin using the software to compete with a product
|
||||
or service that the licensor or any of its affiliates has
|
||||
stopped providing, unless the licensor includes a plain-text
|
||||
line beginning with `Licensor Line of Business:` with the
|
||||
software that mentions that line of business. For example:
|
||||
|
||||
> Licensor Line of Business: YoyodyneCMS Content Management
|
||||
System (http://example.com/cms)
|
||||
|
||||
## Sales of Business
|
||||
|
||||
If the licensor or any of its affiliates sells a line of
|
||||
business developing the software or using the software
|
||||
to provide a product, the buyer can also enforce
|
||||
[Noncompete](#noncompete) for that product.
|
||||
|
||||
## Fair Use
|
||||
|
||||
You may have "fair use" rights for the software under the
|
||||
law. These terms do not limit them.
|
||||
|
||||
## No Other Rights
|
||||
|
||||
These terms do not allow you to sublicense or transfer any of
|
||||
your licenses to anyone else, or prevent the licensor from
|
||||
granting licenses to anyone else. These terms do not imply
|
||||
any other licenses.
|
||||
|
||||
## Patent Defense
|
||||
|
||||
If you make any written claim that the software infringes or
|
||||
contributes to infringement of any patent, your patent license
|
||||
for the software granted under these terms ends immediately. If
|
||||
your company makes such a claim, your patent license ends
|
||||
immediately for work on behalf of your company.
|
||||
|
||||
## Violations
|
||||
|
||||
The first time you are notified in writing that you have
|
||||
violated any of these terms, or done anything with the software
|
||||
not covered by your licenses, your licenses can nonetheless
|
||||
continue if you come into full compliance with these terms,
|
||||
and take practical steps to correct past violations, within
|
||||
32 days of receiving notice. Otherwise, all your licenses
|
||||
end immediately.
|
||||
|
||||
## No Liability
|
||||
|
||||
***As far as the law allows, the software comes as is, without
|
||||
any warranty or condition, and the licensor will not be liable
|
||||
to you for any damages arising out of these terms or the use
|
||||
or nature of the software, under any kind of legal claim.***
|
||||
|
||||
## Definitions
|
||||
|
||||
The **licensor** is the individual or entity offering these
|
||||
terms, and the **software** is the software the licensor makes
|
||||
available under these terms.
|
||||
|
||||
A **product** can be a good or service, or a combination
|
||||
of them.
|
||||
|
||||
**You** refers to the individual or entity agreeing to these
|
||||
terms.
|
||||
|
||||
**Your company** is any legal entity, sole proprietorship,
|
||||
or other kind of organization that you work for, plus all
|
||||
its affiliates.
|
||||
|
||||
**Affiliates** means the other organizations than an
|
||||
organization has control over, is under the control of, or is
|
||||
under common control with.
|
||||
|
||||
**Control** means ownership of substantially all the assets of
|
||||
an entity, or the power to direct its management and policies
|
||||
by vote, contract, or otherwise. Control can be direct or
|
||||
indirect.
|
||||
|
||||
**Your licenses** are all the licenses granted to you for the
|
||||
software under these terms.
|
||||
|
||||
**Use** means anything you do with the software requiring one
|
||||
of your licenses.
|
||||
|
||||
=====================================================
|
||||
-----------------------------------------------------
|
||||
=====================================================
|
||||
|
||||
MIT License
|
||||
|
||||
|
||||
README.md (57 changes)
@@ -1,16 +1,25 @@
|
||||
# AutoGPT: Build, Deploy, and Run AI Agents
|
||||
|
||||
[](https://discord.gg/autogpt)  
|
||||
[](https://discord.gg/autogpt)  
|
||||
[](https://twitter.com/Auto_GPT)  
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
|
||||
<!-- Keep these links. Translations will automatically update with the README. -->
|
||||
[Deutsch](https://zdoc.app/de/Significant-Gravitas/AutoGPT) |
|
||||
[Español](https://zdoc.app/es/Significant-Gravitas/AutoGPT) |
|
||||
[français](https://zdoc.app/fr/Significant-Gravitas/AutoGPT) |
|
||||
[日本語](https://zdoc.app/ja/Significant-Gravitas/AutoGPT) |
|
||||
[한국어](https://zdoc.app/ko/Significant-Gravitas/AutoGPT) |
|
||||
[Português](https://zdoc.app/pt/Significant-Gravitas/AutoGPT) |
|
||||
[Русский](https://zdoc.app/ru/Significant-Gravitas/AutoGPT) |
|
||||
[中文](https://zdoc.app/zh/Significant-Gravitas/AutoGPT)
|
||||
|
||||
**AutoGPT** is a powerful platform that allows you to create, deploy, and manage continuous AI agents that automate complex workflows.
|
||||
|
||||
## Hosting Options
|
||||
- Download to self-host
|
||||
- [Join the Waitlist](https://bit.ly/3ZDijAI) for the cloud-hosted beta
|
||||
- Download to self-host (Free!)
|
||||
- [Join the Waitlist](https://bit.ly/3ZDijAI) for the cloud-hosted beta (Closed Beta - Public release Coming Soon!)
|
||||
|
||||
## How to Setup for Self-Hosting
|
||||
## How to Self-Host the AutoGPT Platform
|
||||
> [!NOTE]
|
||||
> Setting up and hosting the AutoGPT Platform yourself is a technical process.
|
||||
> If you'd rather something that just works, we recommend [joining the waitlist](https://bit.ly/3ZDijAI) for the cloud-hosted beta.
|
||||
@@ -50,6 +59,24 @@ We've moved to a fully maintained and regularly updated documentation site.
|
||||
|
||||
This tutorial assumes you have Docker, VSCode, git and npm installed.
|
||||
|
||||
---
|
||||
|
||||
#### ⚡ Quick Setup with One-Line Script (Recommended for Local Hosting)
|
||||
|
||||
Skip the manual steps and get started in minutes using our automatic setup script.
|
||||
|
||||
For macOS/Linux:
|
||||
```
|
||||
curl -fsSL https://setup.agpt.co/install.sh -o install.sh && bash install.sh
|
||||
```
|
||||
|
||||
For Windows (PowerShell):
|
||||
```
|
||||
powershell -c "iwr https://setup.agpt.co/install.bat -o install.bat; ./install.bat"
|
||||
```
|
||||
|
||||
This will install dependencies, configure Docker, and launch your local instance — all in one go.
|
||||
|
||||
### 🧱 AutoGPT Frontend
|
||||
|
||||
The AutoGPT frontend is where users interact with our powerful AI automation platform. It offers multiple ways to engage with and leverage our AI agents. This is the interface where you'll bring your AI automation ideas to life:
|
||||
@@ -96,7 +123,17 @@ Here are two examples of what you can do with AutoGPT:
|
||||
These examples show just a glimpse of what you can achieve with AutoGPT! You can create customized workflows to build agents for any use case.
|
||||
|
||||
---
|
||||
### Mission and Licensing
|
||||
|
||||
### **License Overview:**
|
||||
|
||||
🛡️ **Polyform Shield License:**
|
||||
All code and content within the `autogpt_platform` folder is licensed under the Polyform Shield License. This new project is our in-development platform for building, deploying, and managing agents.</br>_[Read more about this effort](https://agpt.co/blog/introducing-the-autogpt-platform)_
|
||||
|
||||
🦉 **MIT License:**
|
||||
All other portions of the AutoGPT repository (i.e., everything outside the `autogpt_platform` folder) are licensed under the MIT License. This includes the original stand-alone AutoGPT Agent, along with projects such as [Forge](https://github.com/Significant-Gravitas/AutoGPT/tree/master/classic/forge), [agbenchmark](https://github.com/Significant-Gravitas/AutoGPT/tree/master/classic/benchmark) and the [AutoGPT Classic GUI](https://github.com/Significant-Gravitas/AutoGPT/tree/master/classic/frontend).</br>We also publish additional work under the MIT Licence in other repositories, such as [GravitasML](https://github.com/Significant-Gravitas/gravitasml) which is developed for and used in the AutoGPT Platform. See also our MIT Licenced [Code Ability](https://github.com/Significant-Gravitas/AutoGPT-Code-Ability) project.
|
||||
|
||||
---
|
||||
### Mission
|
||||
Our mission is to provide the tools, so that you can focus on what matters:
|
||||
|
||||
- 🏗️ **Building** - Lay the foundation for something amazing.
|
||||
@@ -109,14 +146,6 @@ Be part of the revolution! **AutoGPT** is here to stay, at the forefront of AI i
|
||||
 | 
|
||||
**🚀 [Contributing](CONTRIBUTING.md)**
|
||||
|
||||
**Licensing:**
|
||||
|
||||
MIT License: The majority of the AutoGPT repository is under the MIT License.
|
||||
|
||||
Polyform Shield License: This license applies to the autogpt_platform folder.
|
||||
|
||||
For more information, see https://agpt.co/blog/introducing-the-autogpt-platform
|
||||
|
||||
---
|
||||
## 🤖 AutoGPT Classic
|
||||
> Below is information about the classic version of AutoGPT.
|
||||
|
||||
@@ -5,6 +5,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
|
||||
## Repository Overview
|
||||
|
||||
AutoGPT Platform is a monorepo containing:
|
||||
|
||||
- **Backend** (`/backend`): Python FastAPI server with async support
|
||||
- **Frontend** (`/frontend`): Next.js React application
|
||||
- **Shared Libraries** (`/autogpt_libs`): Common Python utilities
|
||||
@@ -12,6 +13,7 @@ AutoGPT Platform is a monorepo containing:
|
||||
## Essential Commands
|
||||
|
||||
### Backend Development
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
cd backend && poetry install
|
||||
@@ -31,11 +33,18 @@ poetry run test
|
||||
# Run specific test
|
||||
poetry run pytest path/to/test_file.py::test_function_name
|
||||
|
||||
# Run block tests (tests that validate all blocks work correctly)
|
||||
poetry run pytest backend/blocks/test/test_block.py -xvs
|
||||
|
||||
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
|
||||
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
|
||||
|
||||
# Lint and format
|
||||
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
|
||||
poetry run format # Black + isort
|
||||
poetry run lint # ruff
|
||||
```
|
||||
|
||||
More details can be found in TESTING.md
|
||||
|
||||
#### Creating/Updating Snapshots
|
||||
@@ -48,31 +57,49 @@ poetry run pytest path/to/test.py --snapshot-update
|
||||
|
||||
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
|
||||
|
||||
|
||||
### Frontend Development
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
cd frontend && npm install
|
||||
cd frontend && pnpm i
|
||||
|
||||
# Generate API client from OpenAPI spec
|
||||
pnpm generate:api
|
||||
|
||||
# Start development server
|
||||
npm run dev
|
||||
pnpm dev
|
||||
|
||||
# Run E2E tests
|
||||
npm run test
|
||||
pnpm test
|
||||
|
||||
# Run Storybook for component development
|
||||
npm run storybook
|
||||
pnpm storybook
|
||||
|
||||
# Build production
|
||||
npm run build
|
||||
pnpm build
|
||||
|
||||
# Format and lint
|
||||
pnpm format
|
||||
|
||||
# Type checking
|
||||
npm run type-check
|
||||
pnpm types
|
||||
```
|
||||
|
||||
**📖 Complete Guide**: See `/frontend/CONTRIBUTING.md` and `/frontend/.cursorrules` for comprehensive frontend patterns.
|
||||
|
||||
**Key Frontend Conventions:**
|
||||
|
||||
- Separate render logic from data/behavior in components
|
||||
- Use generated API hooks from `@/app/api/__generated__/endpoints/`
|
||||
- Use function declarations (not arrow functions) for components/handlers
|
||||
- Use design system components from `src/components/` (atoms, molecules, organisms)
|
||||
- Only use Phosphor Icons
|
||||
- Never use `src/components/__legacy__/*` or deprecated `BackendAPI`
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
### Backend Architecture
|
||||
|
||||
- **API Layer**: FastAPI with REST and WebSocket endpoints
|
||||
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
|
||||
- **Queue System**: RabbitMQ for async task processing
|
||||
@@ -81,13 +108,20 @@ npm run type-check
|
||||
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
||||
|
||||
### Frontend Architecture
|
||||
- **Framework**: Next.js App Router with React Server Components
|
||||
- **State Management**: React hooks + Supabase client for real-time updates
|
||||
|
||||
- **Framework**: Next.js 15 App Router (client-first approach)
|
||||
- **Data Fetching**: Type-safe generated API hooks via Orval + React Query
|
||||
- **State Management**: React Query for server state, co-located UI state in components/hooks
|
||||
- **Component Structure**: Separate render logic (`.tsx`) from business logic (`use*.ts` hooks)
|
||||
- **Workflow Builder**: Visual graph editor using @xyflow/react
|
||||
- **UI Components**: Radix UI primitives with Tailwind CSS styling
|
||||
- **UI Components**: shadcn/ui (Radix UI primitives) with Tailwind CSS styling
|
||||
- **Icons**: Phosphor Icons only
|
||||
- **Feature Flags**: LaunchDarkly integration
|
||||
- **Error Handling**: ErrorCard for render errors, toast for mutations, Sentry for exceptions
|
||||
- **Testing**: Playwright for E2E, Storybook for component development
|
||||
|
||||
### Key Concepts
|
||||
|
||||
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
|
||||
2. **Blocks**: Reusable components in `/backend/blocks/` that perform specific tasks
|
||||
3. **Integrations**: OAuth and API connections stored per user
|
||||
@@ -95,13 +129,16 @@ npm run type-check
|
||||
5. **Virus Scanning**: ClamAV integration for file upload security
|
||||
|
||||
### Testing Approach
|
||||
|
||||
- Backend uses pytest with snapshot testing for API responses
|
||||
- Test files are colocated with source files (`*_test.py`)
|
||||
- Frontend uses Playwright for E2E tests
|
||||
- Component testing via Storybook
|
||||
|
||||
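To make the backend testing bullets concrete, here is a minimal sketch of a colocated snapshot test. It assumes a pytest-snapshot-style `snapshot` fixture and a hypothetical `routes` module; the repo's actual fixtures and helpers are documented in TESTING.md.

```python
# routes_test.py, colocated with the route module it covers (names are illustrative)
import json

from fastapi.testclient import TestClient

from .routes import app  # hypothetical module exposing the FastAPI app


def test_health_snapshot(snapshot):
    # `snapshot` comes from the pytest-snapshot plugin (assumed here); the first
    # run with --snapshot-update writes the file, later runs compare against it.
    client = TestClient(app)
    response = client.get("/health")

    snapshot.snapshot_dir = "snapshots"
    snapshot.assert_match(
        json.dumps(response.json(), indent=2, sort_keys=True),
        "health_response.json",
    )
```

This is the kind of snapshot that `poetry run pytest path/to/test.py --snapshot-update` (see above) regenerates; review the resulting diff before committing.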
### Database Schema
|
||||
|
||||
Key models (defined in `/backend/schema.prisma`):
|
||||
|
||||
- `User`: Authentication and profile data
|
||||
- `AgentGraph`: Workflow definitions with version control
|
||||
- `AgentGraphExecution`: Execution history and results
|
||||
@@ -109,38 +146,130 @@ Key models (defined in `/backend/schema.prisma`):
|
||||
- `StoreListing`: Marketplace listings for sharing agents
|
||||
|
||||
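As a point of reference, the models above can be queried through the Prisma client generated from `schema.prisma`. The sketch below follows prisma-client-py conventions (lowercased model accessors); anything beyond the model names listed above is an assumption.

```python
import asyncio

from prisma import Prisma  # prisma-client-py, generated from schema.prisma


async def main() -> None:
    db = Prisma()
    await db.connect()
    try:
        # Model accessors are the lowercased model names from schema.prisma.
        graphs = await db.agentgraph.find_many(take=5)
        executions = await db.agentgraphexecution.find_many(take=5)
        print(f"{len(graphs)} graphs, {len(executions)} executions")
    finally:
        await db.disconnect()


if __name__ == "__main__":
    asyncio.run(main())
```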
### Environment Configuration
|
||||
- Backend: `.env` file in `/backend`
|
||||
- Frontend: `.env.local` file in `/frontend`
|
||||
- Both require Supabase credentials and API keys for various services
|
||||
|
||||
#### Configuration Files
|
||||
|
||||
- **Backend**: `/backend/.env.default` (defaults) → `/backend/.env` (user overrides)
|
||||
- **Frontend**: `/frontend/.env.default` (defaults) → `/frontend/.env` (user overrides)
|
||||
- **Platform**: `/.env.default` (Supabase/shared defaults) → `/.env` (user overrides)
|
||||
|
||||
#### Docker Environment Loading Order
|
||||
|
||||
1. `.env.default` files provide base configuration (tracked in git)
|
||||
2. `.env` files provide user-specific overrides (gitignored)
|
||||
3. Docker Compose `environment:` sections provide service-specific overrides
|
||||
4. Shell environment variables have highest precedence
|
||||
|
||||
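The layering above can be illustrated with a small, self-contained sketch; it only mirrors the precedence order and is not how Docker Compose resolves variables internally.

```python
import os
from pathlib import Path


def load_env_file(path: str) -> dict[str, str]:
    """Parse simple KEY=VALUE lines, ignoring blanks and comments."""
    values: dict[str, str] = {}
    file = Path(path)
    if not file.exists():
        return values
    for line in file.read_text().splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        values[key.strip()] = value.strip()
    return values


# Later sources override earlier ones, mirroring the order described above:
config: dict[str, str] = {}
config.update(load_env_file(".env.default"))  # 1. tracked defaults
config.update(load_env_file(".env"))          # 2. user overrides (gitignored)
config.update({"REDIS_HOST": "redis"})        # 3. compose `environment:` (example value)
config.update(os.environ)                     # 4. shell environment wins
```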
#### Key Points
|
||||
|
||||
- All services use hardcoded defaults in docker-compose files (no `${VARIABLE}` substitutions)
|
||||
- The `env_file` directive loads variables INTO containers at runtime
|
||||
- Backend/Frontend services use YAML anchors for consistent configuration
|
||||
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
|
||||
|
||||
### Common Development Tasks
|
||||
|
||||
**Adding a new block:**
|
||||
|
||||
Follow the comprehensive [Block SDK Guide](../../../docs/content/platform/block-sdk-guide.md) which covers:
|
||||
|
||||
- Provider configuration with `ProviderBuilder`
|
||||
- Block schema definition
|
||||
- Authentication (API keys, OAuth, webhooks)
|
||||
- Testing and validation
|
||||
- File organization
|
||||
|
||||
Quick steps:
|
||||
|
||||
1. Create new file in `/backend/backend/blocks/`
|
||||
2. Inherit from `Block` base class
|
||||
3. Define input/output schemas
|
||||
4. Implement `run` method
|
||||
5. Register in block registry
|
||||
2. Configure provider using `ProviderBuilder` in `_config.py`
|
||||
3. Inherit from `Block` base class
|
||||
4. Define input/output schemas using `BlockSchema`
|
||||
5. Implement async `run` method
|
||||
6. Generate unique block ID using `uuid.uuid4()`
|
||||
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
|
||||
|
||||
Note: when making many new blocks, analyze the interfaces of each block and consider whether they would work well together in a graph-based editor or would struggle to connect productively (e.g., do the inputs and outputs tie together well?).
|
||||
|
||||
If you get any pushback or hit complex block conditions, check the new_blocks guide in the docs.
|
||||
|
||||
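For orientation, a compact sketch of what the quick steps above produce. The import paths and constructor arguments mirror existing blocks under `/backend/backend/blocks/`, but treat them as assumptions and defer to the Block SDK Guide for the authoritative interface.

```python
# shout_text_block.py, a hypothetical block for illustration only
from backend.data.block import Block, BlockOutput, BlockSchema  # assumed import path
from backend.data.model import SchemaField  # assumed import path


class ShoutTextBlock(Block):
    class Input(BlockSchema):
        text: str = SchemaField(description="Text to shout")

    class Output(BlockSchema):
        shouted: str = SchemaField(description="Upper-cased text")

    def __init__(self):
        super().__init__(
            # Generate once with uuid.uuid4() and commit the literal value,
            # so the block ID stays stable across releases.
            id="00000000-0000-0000-0000-000000000000",  # placeholder
            description="Upper-cases the input text",
            input_schema=ShoutTextBlock.Input,
            output_schema=ShoutTextBlock.Output,
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        yield "shouted", input_data.text.upper()
```

The block test suite from step 7 exercises every registered block, so a new block like this gets validated automatically once the provider is wired up in `_config.py`.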
**Modifying the API:**
|
||||
|
||||
1. Update route in `/backend/backend/server/routers/`
|
||||
2. Add/update Pydantic models in same directory
|
||||
3. Write tests alongside the route file
|
||||
4. Run `poetry run test` to verify
|
||||
|
||||
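A minimal sketch of steps 1 and 2; the router module, route path, and model are hypothetical.

```python
# /backend/backend/server/routers/widgets.py (hypothetical example router)
from fastapi import APIRouter
from pydantic import BaseModel

router = APIRouter(prefix="/widgets", tags=["widgets"])


class WidgetResponse(BaseModel):
    id: str
    name: str


@router.get("/{widget_id}", response_model=WidgetResponse)
async def get_widget(widget_id: str) -> WidgetResponse:
    # Real handlers would load from the database layer instead of echoing.
    return WidgetResponse(id=widget_id, name="example")
```

A colocated `widgets_test.py` would then cover step 3 before running `poetry run test`.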
**Frontend feature development:**
|
||||
1. Components go in `/frontend/src/components/`
|
||||
2. Use existing UI components from `/frontend/src/components/ui/`
|
||||
3. Add Storybook stories for new components
|
||||
4. Test with Playwright if user-facing
|
||||
|
||||
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
||||
|
||||
1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx`
|
||||
- Add `usePageName.ts` hook for logic
|
||||
- Put sub-components in local `components/` folder
|
||||
2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
|
||||
- Use design system components from `src/components/` (atoms, molecules, organisms)
|
||||
- Never use `src/components/__legacy__/*`
|
||||
3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/`
|
||||
- Regenerate with `pnpm generate:api`
|
||||
- Pattern: `use{Method}{Version}{OperationName}`
|
||||
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
|
||||
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
|
||||
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
|
||||
|
||||
### Security Implementation
|
||||
|
||||
**Cache Protection Middleware:**
|
||||
|
||||
- Located in `/backend/backend/server/middleware/security.py`
|
||||
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
||||
- Uses an allow list approach - only explicitly permitted paths can be cached
|
||||
- Cacheable paths include: static assets (`/static/*`, `/_next/static/*`), health checks, public store pages, documentation
|
||||
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
|
||||
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
|
||||
- Applied to both main API server and external API applications
|
||||
|
||||
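A hedged sketch of the allow-list pattern described above; it is not the actual `security.py` implementation, and the paths listed are illustrative.

```python
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware

# Hypothetical allow list; the real one lives in CACHEABLE_PATHS in security.py.
CACHEABLE_PATHS = ("/static/", "/_next/static/", "/health", "/docs")


class CacheProtectionMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        if not any(request.url.path.startswith(p) for p in CACHEABLE_PATHS):
            # Default: forbid caching so tokens and user data never persist
            # in browser or proxy caches.
            response.headers["Cache-Control"] = (
                "no-store, no-cache, must-revalidate, private"
            )
        return response


app = FastAPI()
app.add_middleware(CacheProtectionMiddleware)
```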
### Creating Pull Requests
|
||||
|
||||
- Create the PR against the `dev` branch of the repository.
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`).
- Use conventional commit messages (see below).
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description.
- Run the GitHub pre-commit hooks to ensure code quality.
|
||||
|
||||
### Reviewing/Revising Pull Requests
|
||||
|
||||
- When the user runs /pr-comments or tries to fetch them, also run gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews to get the reviews
|
||||
- Use gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews/[review_id]/comments to get the review contents
|
||||
- Use gh api /repos/Significant-Gravitas/AutoGPT/issues/9924/comments to get the pr specific comments
|
||||
|
||||
### Conventional Commits
|
||||
|
||||
Use this format for commit messages and Pull Request titles:
|
||||
|
||||
**Conventional Commit Types:**
|
||||
|
||||
- `feat`: Introduces a new feature to the codebase
|
||||
- `fix`: Patches a bug in the codebase
|
||||
- `refactor`: Code change that neither fixes a bug nor adds a feature; also applies to removing features
|
||||
- `ci`: Changes to CI configuration
|
||||
- `docs`: Documentation-only changes
|
||||
- `dx`: Improvements to the developer experience
|
||||
|
||||
**Recommended Base Scopes:**
|
||||
|
||||
- `platform`: Changes affecting both frontend and backend
|
||||
- `frontend`
|
||||
- `backend`
|
||||
- `infra`
|
||||
- `blocks`: Modifications/additions of individual blocks
|
||||
|
||||
**Subscope Examples:**
|
||||
|
||||
- `backend/executor`
|
||||
- `backend/db`
|
||||
- `frontend/builder` (includes changes to the block UI component)
|
||||
- `infra/prod`
|
||||
|
||||
Use these scopes and subscopes for clarity and consistency in commit messages.
|
||||
|
||||
autogpt_platform/Makefile (new file, 57 lines)
@@ -0,0 +1,57 @@
|
||||
.PHONY: start-core stop-core logs-core format lint migrate run-backend run-frontend
|
||||
|
||||
# Run just Supabase + Redis + RabbitMQ
|
||||
start-core:
|
||||
docker compose up -d deps
|
||||
|
||||
# Stop core services
|
||||
stop-core:
|
||||
docker compose stop deps
|
||||
|
||||
reset-db:
|
||||
rm -rf db/docker/volumes/db/data
|
||||
cd backend && poetry run prisma migrate deploy
|
||||
cd backend && poetry run prisma generate
|
||||
|
||||
# View logs for core services
|
||||
logs-core:
|
||||
docker compose logs -f deps
|
||||
|
||||
# Run formatting and linting for backend and frontend
|
||||
format:
|
||||
cd backend && poetry run format
|
||||
cd frontend && pnpm format
|
||||
cd frontend && pnpm lint
|
||||
|
||||
init-env:
|
||||
cp -n .env.default .env || true
|
||||
cd backend && cp -n .env.default .env || true
|
||||
cd frontend && cp -n .env.default .env || true
|
||||
|
||||
|
||||
# Run migrations for backend
|
||||
migrate:
|
||||
cd backend && poetry run prisma migrate deploy
|
||||
cd backend && poetry run prisma generate
|
||||
|
||||
run-backend:
|
||||
cd backend && poetry run app
|
||||
|
||||
run-frontend:
|
||||
cd frontend && pnpm dev
|
||||
|
||||
test-data:
|
||||
cd backend && poetry run python test/test_data_creator.py
|
||||
|
||||
help:
|
||||
@echo "Usage: make <target>"
|
||||
@echo "Targets:"
|
||||
@echo " start-core - Start just the core services (Supabase, Redis, RabbitMQ) in background"
|
||||
@echo " stop-core - Stop the core services"
|
||||
@echo " reset-db - Reset the database by deleting the volume"
|
||||
@echo " logs-core - Tail the logs for core services"
|
||||
@echo " format - Format & lint backend (Python) and frontend (TypeScript) code"
|
||||
@echo " migrate - Run backend database migrations"
|
||||
@echo " run-backend - Run the backend FastAPI server"
|
||||
@echo " run-frontend - Run the frontend Next.js development server"
|
||||
@echo " test-data - Run the test data creator"
|
||||
@@ -8,7 +8,6 @@ Welcome to the AutoGPT Platform - a powerful system for creating and running AI
|
||||
|
||||
- Docker
|
||||
- Docker Compose V2 (comes with Docker Desktop, or can be installed separately)
|
||||
- Node.js & NPM (for running the frontend application)
|
||||
|
||||
### Running the System
|
||||
|
||||
@@ -24,10 +23,10 @@ To run the AutoGPT Platform, follow these steps:
|
||||
2. Run the following command:
|
||||
|
||||
```
|
||||
cp .env.example .env
|
||||
cp .env.default .env
|
||||
```
|
||||
|
||||
This command will copy the `.env.example` file to `.env`. You can modify the `.env` file to add your own environment variables.
|
||||
This command will copy the `.env.default` file to `.env`. You can modify the `.env` file to add your own environment variables.
|
||||
|
||||
3. Run the following command:
|
||||
|
||||
@@ -37,44 +36,38 @@ To run the AutoGPT Platform, follow these steps:
|
||||
|
||||
This command will start all the necessary backend services defined in the `docker-compose.yml` file in detached mode.
|
||||
|
||||
4. Navigate to `frontend` within the `autogpt_platform` directory:
|
||||
4. After all the services are in a ready state, open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
|
||||
|
||||
```
|
||||
cd frontend
|
||||
```
|
||||
### Running Just Core services
|
||||
|
||||
You will need to run your frontend application separately on your local machine.
|
||||
You can now run the following to enable just the core services.
|
||||
|
||||
5. Run the following command:
|
||||
```
|
||||
# For help
|
||||
make help
|
||||
|
||||
```
|
||||
cp .env.example .env.local
|
||||
```
|
||||
# Run just Supabase + Redis + RabbitMQ
|
||||
make start-core
|
||||
|
||||
This command will copy the `.env.example` file to `.env.local` in the `frontend` directory. You can modify the `.env.local` within this folder to add your own environment variables for the frontend application.
|
||||
# Stop core services
|
||||
make stop-core
|
||||
|
||||
6. Run the following command:
|
||||
# View logs from core services
|
||||
make logs-core
|
||||
|
||||
Enable corepack and install dependencies by running:
|
||||
# Run formatting and linting for backend and frontend
|
||||
make format
|
||||
|
||||
```
|
||||
corepack enable
|
||||
pnpm i
|
||||
```
|
||||
# Run migrations for backend database
|
||||
make migrate
|
||||
|
||||
Generate the API client (this step is required before running the frontend):
|
||||
# Run backend server
|
||||
make run-backend
|
||||
|
||||
```
|
||||
pnpm generate:api-client
|
||||
```
|
||||
# Run frontend development server
|
||||
make run-frontend
|
||||
|
||||
Then start the frontend application in development mode:
|
||||
|
||||
```
|
||||
pnpm dev
|
||||
```
|
||||
|
||||
7. Open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
|
||||
```
|
||||
|
||||
### Docker Compose Commands
|
||||
|
||||
@@ -177,20 +170,21 @@ The platform includes scripts for generating and managing the API client:
|
||||
|
||||
- `pnpm fetch:openapi`: Fetches the OpenAPI specification from the backend service (requires backend to be running on port 8006)
|
||||
- `pnpm generate:api-client`: Generates the TypeScript API client from the OpenAPI specification using Orval
|
||||
- `pnpm generate:api-all`: Runs both fetch and generate commands in sequence
|
||||
- `pnpm generate:api`: Runs both fetch and generate commands in sequence
|
||||
|
||||
#### Manual API Client Updates
|
||||
|
||||
If you need to update the API client after making changes to the backend API:
|
||||
|
||||
1. Ensure the backend services are running:
|
||||
|
||||
```
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
2. Generate the updated API client:
|
||||
```
|
||||
pnpm generate:api-all
|
||||
pnpm generate:api
|
||||
```
|
||||
|
||||
This will fetch the latest OpenAPI specification and regenerate the TypeScript client code.
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
import hashlib
|
||||
import secrets
|
||||
from typing import NamedTuple
|
||||
|
||||
|
||||
class APIKeyContainer(NamedTuple):
|
||||
"""Container for API key parts."""
|
||||
|
||||
raw: str
|
||||
prefix: str
|
||||
postfix: str
|
||||
hash: str
|
||||
|
||||
|
||||
class APIKeyManager:
|
||||
PREFIX: str = "agpt_"
|
||||
PREFIX_LENGTH: int = 8
|
||||
POSTFIX_LENGTH: int = 8
|
||||
|
||||
def generate_api_key(self) -> APIKeyContainer:
|
||||
"""Generate a new API key with all its parts."""
|
||||
raw_key = f"{self.PREFIX}{secrets.token_urlsafe(32)}"
|
||||
return APIKeyContainer(
|
||||
raw=raw_key,
|
||||
prefix=raw_key[: self.PREFIX_LENGTH],
|
||||
postfix=raw_key[-self.POSTFIX_LENGTH :],
|
||||
hash=hashlib.sha256(raw_key.encode()).hexdigest(),
|
||||
)
|
||||
|
||||
def verify_api_key(self, provided_key: str, stored_hash: str) -> bool:
|
||||
"""Verify if a provided API key matches the stored hash."""
|
||||
if not provided_key.startswith(self.PREFIX):
|
||||
return False
|
||||
provided_hash = hashlib.sha256(provided_key.encode()).hexdigest()
|
||||
return secrets.compare_digest(provided_hash, stored_hash)
|
||||
@@ -0,0 +1,78 @@
|
||||
import hashlib
|
||||
import secrets
|
||||
from typing import NamedTuple
|
||||
|
||||
from cryptography.hazmat.primitives.kdf.scrypt import Scrypt
|
||||
|
||||
|
||||
class APIKeyContainer(NamedTuple):
|
||||
"""Container for API key parts."""
|
||||
|
||||
key: str
|
||||
head: str
|
||||
tail: str
|
||||
hash: str
|
||||
salt: str
|
||||
|
||||
|
||||
class APIKeySmith:
|
||||
PREFIX: str = "agpt_"
|
||||
HEAD_LENGTH: int = 8
|
||||
TAIL_LENGTH: int = 8
|
||||
|
||||
def generate_key(self) -> APIKeyContainer:
|
||||
"""Generate a new API key with secure hashing."""
|
||||
raw_key = f"{self.PREFIX}{secrets.token_urlsafe(32)}"
|
||||
hash, salt = self.hash_key(raw_key)
|
||||
|
||||
return APIKeyContainer(
|
||||
key=raw_key,
|
||||
head=raw_key[: self.HEAD_LENGTH],
|
||||
tail=raw_key[-self.TAIL_LENGTH :],
|
||||
hash=hash,
|
||||
salt=salt,
|
||||
)
|
||||
|
||||
def verify_key(
|
||||
self, provided_key: str, known_hash: str, known_salt: str | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
Verify an API key against a known hash (+ salt).
|
||||
Supports verifying both legacy SHA256 and secure Scrypt hashes.
|
||||
"""
|
||||
if not provided_key.startswith(self.PREFIX):
|
||||
return False
|
||||
|
||||
# Handle legacy SHA256 hashes (migration support)
|
||||
if known_salt is None:
|
||||
legacy_hash = hashlib.sha256(provided_key.encode()).hexdigest()
|
||||
return secrets.compare_digest(legacy_hash, known_hash)
|
||||
|
||||
try:
|
||||
salt_bytes = bytes.fromhex(known_salt)
|
||||
provided_hash = self._hash_key_with_salt(provided_key, salt_bytes)
|
||||
return secrets.compare_digest(provided_hash, known_hash)
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
def hash_key(self, raw_key: str) -> tuple[str, str]:
|
||||
"""Migrate a legacy hash to secure hash format."""
|
||||
salt = self._generate_salt()
|
||||
hash = self._hash_key_with_salt(raw_key, salt)
|
||||
return hash, salt.hex()
|
||||
|
||||
def _generate_salt(self) -> bytes:
|
||||
"""Generate a random salt for hashing."""
|
||||
return secrets.token_bytes(32)
|
||||
|
||||
def _hash_key_with_salt(self, raw_key: str, salt: bytes) -> str:
|
||||
"""Hash API key using Scrypt with salt."""
|
||||
kdf = Scrypt(
|
||||
length=32,
|
||||
salt=salt,
|
||||
n=2**14, # CPU/memory cost parameter
|
||||
r=8, # Block size parameter
|
||||
p=1, # Parallelization parameter
|
||||
)
|
||||
key_hash = kdf.derive(raw_key.encode())
|
||||
return key_hash.hex()
|
||||
@@ -0,0 +1,79 @@
|
||||
import hashlib
|
||||
|
||||
from autogpt_libs.api_key.keysmith import APIKeySmith
|
||||
|
||||
|
||||
def test_generate_api_key():
|
||||
keysmith = APIKeySmith()
|
||||
key = keysmith.generate_key()
|
||||
|
||||
assert key.key.startswith(keysmith.PREFIX)
|
||||
assert key.head == key.key[: keysmith.HEAD_LENGTH]
|
||||
assert key.tail == key.key[-keysmith.TAIL_LENGTH :]
|
||||
assert len(key.hash) == 64 # 32 bytes hex encoded
|
||||
assert len(key.salt) == 64 # 32 bytes hex encoded
|
||||
|
||||
|
||||
def test_verify_new_secure_key():
|
||||
keysmith = APIKeySmith()
|
||||
key = keysmith.generate_key()
|
||||
|
||||
# Test correct key validates
|
||||
assert keysmith.verify_key(key.key, key.hash, key.salt) is True
|
||||
|
||||
# Test wrong key fails
|
||||
wrong_key = f"{keysmith.PREFIX}wrongkey123"
|
||||
assert keysmith.verify_key(wrong_key, key.hash, key.salt) is False
|
||||
|
||||
|
||||
def test_verify_legacy_key():
|
||||
keysmith = APIKeySmith()
|
||||
legacy_key = f"{keysmith.PREFIX}legacykey123"
|
||||
legacy_hash = hashlib.sha256(legacy_key.encode()).hexdigest()
|
||||
|
||||
# Test legacy key validates without salt
|
||||
assert keysmith.verify_key(legacy_key, legacy_hash) is True
|
||||
|
||||
# Test wrong legacy key fails
|
||||
wrong_key = f"{keysmith.PREFIX}wronglegacy"
|
||||
assert keysmith.verify_key(wrong_key, legacy_hash) is False
|
||||
|
||||
|
||||
def test_rehash_existing_key():
|
||||
keysmith = APIKeySmith()
|
||||
legacy_key = f"{keysmith.PREFIX}migratekey123"
|
||||
|
||||
# Migrate the legacy key
|
||||
new_hash, new_salt = keysmith.hash_key(legacy_key)
|
||||
|
||||
# Verify migrated key works
|
||||
assert keysmith.verify_key(legacy_key, new_hash, new_salt) is True
|
||||
|
||||
# Verify different key fails with migrated hash
|
||||
wrong_key = f"{keysmith.PREFIX}wrongkey"
|
||||
assert keysmith.verify_key(wrong_key, new_hash, new_salt) is False
|
||||
|
||||
|
||||
def test_invalid_key_prefix():
|
||||
keysmith = APIKeySmith()
|
||||
key = keysmith.generate_key()
|
||||
|
||||
# Test key without proper prefix fails
|
||||
invalid_key = "invalid_prefix_key"
|
||||
assert keysmith.verify_key(invalid_key, key.hash, key.salt) is False
|
||||
|
||||
|
||||
def test_secure_hash_requires_salt():
|
||||
keysmith = APIKeySmith()
|
||||
key = keysmith.generate_key()
|
||||
|
||||
# Secure hash without salt should fail
|
||||
assert keysmith.verify_key(key.key, key.hash) is False
|
||||
|
||||
|
||||
def test_invalid_salt_format():
|
||||
keysmith = APIKeySmith()
|
||||
key = keysmith.generate_key()
|
||||
|
||||
# Invalid salt format should fail gracefully
|
||||
assert keysmith.verify_key(key.key, key.hash, "invalid_hex") is False
|
||||
@@ -1,13 +1,19 @@
|
||||
from .depends import requires_admin_user, requires_user
|
||||
from .jwt_utils import parse_jwt_token
|
||||
from .middleware import APIKeyValidator, auth_middleware
|
||||
from .config import verify_settings
|
||||
from .dependencies import (
|
||||
get_optional_user_id,
|
||||
get_user_id,
|
||||
requires_admin_user,
|
||||
requires_user,
|
||||
)
|
||||
from .helpers import add_auth_responses_to_openapi
|
||||
from .models import User
|
||||
|
||||
__all__ = [
|
||||
"parse_jwt_token",
|
||||
"requires_user",
|
||||
"verify_settings",
|
||||
"get_user_id",
|
||||
"requires_admin_user",
|
||||
"APIKeyValidator",
|
||||
"auth_middleware",
|
||||
"requires_user",
|
||||
"get_optional_user_id",
|
||||
"add_auth_responses_to_openapi",
|
||||
"User",
|
||||
]
|
||||
|
||||
@@ -1,15 +1,90 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from jwt.algorithms import get_default_algorithms, has_crypto
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthConfigError(ValueError):
|
||||
"""Raised when authentication configuration is invalid."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
ALGO_RECOMMENDATION = (
|
||||
"We highly recommend using an asymmetric algorithm such as ES256, "
|
||||
"because when leaked, a shared secret would allow anyone to "
|
||||
"forge valid tokens and impersonate users. "
|
||||
"More info: https://supabase.com/docs/guides/auth/signing-keys#choosing-the-right-signing-algorithm" # noqa
|
||||
)
|
||||
|
||||
|
||||
class Settings:
|
||||
def __init__(self):
|
||||
self.JWT_SECRET_KEY: str = os.getenv("SUPABASE_JWT_SECRET", "")
|
||||
self.ENABLE_AUTH: bool = os.getenv("ENABLE_AUTH", "false").lower() == "true"
|
||||
self.JWT_ALGORITHM: str = "HS256"
|
||||
self.JWT_VERIFY_KEY: str = os.getenv(
|
||||
"JWT_VERIFY_KEY", os.getenv("SUPABASE_JWT_SECRET", "")
|
||||
).strip()
|
||||
self.JWT_ALGORITHM: str = os.getenv("JWT_SIGN_ALGORITHM", "HS256").strip()
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
return bool(self.JWT_SECRET_KEY)
|
||||
self.validate()
|
||||
|
||||
def validate(self):
|
||||
if not self.JWT_VERIFY_KEY:
|
||||
raise AuthConfigError(
|
||||
"JWT_VERIFY_KEY must be set. "
|
||||
"An empty JWT secret would allow anyone to forge valid tokens."
|
||||
)
|
||||
|
||||
if len(self.JWT_VERIFY_KEY) < 32:
|
||||
logger.warning(
|
||||
"⚠️ JWT_VERIFY_KEY appears weak (less than 32 characters). "
|
||||
"Consider using a longer, cryptographically secure secret."
|
||||
)
|
||||
|
||||
supported_algorithms = get_default_algorithms().keys()
|
||||
|
||||
if not has_crypto:
|
||||
logger.warning(
|
||||
"⚠️ Asymmetric JWT verification is not available "
|
||||
"because the 'cryptography' package is not installed. "
|
||||
+ ALGO_RECOMMENDATION
|
||||
)
|
||||
|
||||
if (
|
||||
self.JWT_ALGORITHM not in supported_algorithms
|
||||
or self.JWT_ALGORITHM == "none"
|
||||
):
|
||||
raise AuthConfigError(
|
||||
f"Invalid JWT_SIGN_ALGORITHM: '{self.JWT_ALGORITHM}'. "
|
||||
"Supported algorithms are listed on "
|
||||
"https://pyjwt.readthedocs.io/en/stable/algorithms.html"
|
||||
)
|
||||
|
||||
if self.JWT_ALGORITHM.startswith("HS"):
|
||||
logger.warning(
|
||||
f"⚠️ JWT_SIGN_ALGORITHM is set to '{self.JWT_ALGORITHM}', "
|
||||
"a symmetric shared-key signature algorithm. " + ALGO_RECOMMENDATION
|
||||
)
|
||||
|
||||
|
||||
settings = Settings()
|
||||
_settings: Settings = None # type: ignore
|
||||
|
||||
|
||||
def get_settings() -> Settings:
|
||||
global _settings
|
||||
|
||||
if not _settings:
|
||||
_settings = Settings()
|
||||
|
||||
return _settings
|
||||
|
||||
|
||||
def verify_settings() -> None:
|
||||
global _settings
|
||||
|
||||
if not _settings:
|
||||
_settings = Settings() # calls validation indirectly
|
||||
return
|
||||
|
||||
_settings.validate()
|
||||
|
||||
306
autogpt_platform/autogpt_libs/autogpt_libs/auth/config_test.py
Normal file
306
autogpt_platform/autogpt_libs/autogpt_libs/auth/config_test.py
Normal file
@@ -0,0 +1,306 @@
|
||||
"""
|
||||
Comprehensive tests for auth configuration to ensure 100% line and branch coverage.
|
||||
These tests verify critical security checks preventing JWT token forgery.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from autogpt_libs.auth.config import AuthConfigError, Settings
|
||||
|
||||
|
||||
def test_environment_variable_precedence(mocker: MockerFixture):
|
||||
"""Test that environment variables take precedence over defaults."""
|
||||
secret = "environment-secret-key-with-proper-length-123456"
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == secret
|
||||
|
||||
|
||||
def test_environment_variable_backwards_compatible(mocker: MockerFixture):
|
||||
"""Test that SUPABASE_JWT_SECRET is read if JWT_VERIFY_KEY is not set."""
|
||||
secret = "environment-secret-key-with-proper-length-123456"
|
||||
mocker.patch.dict(os.environ, {"SUPABASE_JWT_SECRET": secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == secret
|
||||
|
||||
|
||||
def test_auth_config_error_inheritance():
|
||||
"""Test that AuthConfigError is properly defined as an Exception."""
|
||||
assert issubclass(AuthConfigError, Exception)
|
||||
error = AuthConfigError("test message")
|
||||
assert str(error) == "test message"
|
||||
|
||||
|
||||
def test_settings_static_after_creation(mocker: MockerFixture):
|
||||
"""Test that settings maintain their values after creation."""
|
||||
secret = "immutable-secret-key-with-proper-length-12345"
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
original_secret = settings.JWT_VERIFY_KEY
|
||||
|
||||
# Changing environment after creation shouldn't affect settings
|
||||
os.environ["JWT_VERIFY_KEY"] = "different-secret"
|
||||
|
||||
assert settings.JWT_VERIFY_KEY == original_secret
|
||||
|
||||
|
||||
def test_settings_load_with_valid_secret(mocker: MockerFixture):
|
||||
"""Test auth enabled with a valid JWT secret."""
|
||||
valid_secret = "a" * 32 # 32 character secret
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": valid_secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == valid_secret
|
||||
|
||||
|
||||
def test_settings_load_with_strong_secret(mocker: MockerFixture):
|
||||
"""Test auth enabled with a cryptographically strong secret."""
|
||||
strong_secret = "super-secret-jwt-token-with-at-least-32-characters-long"
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": strong_secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == strong_secret
|
||||
assert len(settings.JWT_VERIFY_KEY) >= 32
|
||||
|
||||
|
||||
def test_secret_empty_raises_error(mocker: MockerFixture):
|
||||
"""Test that auth enabled with empty secret raises AuthConfigError."""
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": ""}, clear=True)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
Settings()
|
||||
assert "JWT_VERIFY_KEY" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_secret_missing_raises_error(mocker: MockerFixture):
|
||||
"""Test that auth enabled without secret env var raises AuthConfigError."""
|
||||
mocker.patch.dict(os.environ, {}, clear=True)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
Settings()
|
||||
assert "JWT_VERIFY_KEY" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("secret", [" ", " ", "\t", "\n", " \t\n "])
|
||||
def test_secret_only_whitespace_raises_error(mocker: MockerFixture, secret: str):
|
||||
"""Test that auth enabled with whitespace-only secret raises error."""
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret}, clear=True)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
Settings()
|
||||
|
||||
|
||||
def test_secret_weak_logs_warning(
|
||||
mocker: MockerFixture, caplog: pytest.LogCaptureFixture
|
||||
):
|
||||
"""Test that weak JWT secret triggers warning log."""
|
||||
weak_secret = "short" # Less than 32 characters
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": weak_secret}, clear=True)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == weak_secret
|
||||
assert "key appears weak" in caplog.text.lower()
|
||||
assert "less than 32 characters" in caplog.text
|
||||
|
||||
|
||||
def test_secret_31_char_logs_warning(
|
||||
mocker: MockerFixture, caplog: pytest.LogCaptureFixture
|
||||
):
|
||||
"""Test that 31-character secret triggers warning (boundary test)."""
|
||||
secret_31 = "a" * 31 # Exactly 31 characters
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret_31}, clear=True)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
settings = Settings()
|
||||
assert len(settings.JWT_VERIFY_KEY) == 31
|
||||
assert "key appears weak" in caplog.text.lower()
|
||||
|
||||
|
||||
def test_secret_32_char_no_warning(
|
||||
mocker: MockerFixture, caplog: pytest.LogCaptureFixture
|
||||
):
|
||||
"""Test that 32-character secret does not trigger warning (boundary test)."""
|
||||
secret_32 = "a" * 32 # Exactly 32 characters
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret_32}, clear=True)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
settings = Settings()
|
||||
assert len(settings.JWT_VERIFY_KEY) == 32
|
||||
assert "JWT secret appears weak" not in caplog.text
|
||||
|
||||
|
||||
def test_secret_whitespace_stripped(mocker: MockerFixture):
|
||||
"""Test that JWT secret whitespace is stripped."""
|
||||
secret = "a" * 32
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": f" {secret} "}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == secret
|
||||
|
||||
|
||||
def test_secret_with_special_characters(mocker: MockerFixture):
|
||||
"""Test JWT secret with special characters."""
|
||||
special_secret = "!@#$%^&*()_+-=[]{}|;:,.<>?`~" + "a" * 10 # 40 chars total
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": special_secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == special_secret
|
||||
|
||||
|
||||
def test_secret_with_unicode(mocker: MockerFixture):
|
||||
"""Test JWT secret with unicode characters."""
|
||||
unicode_secret = "秘密🔐キー" + "a" * 25 # Ensure >32 bytes
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": unicode_secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == unicode_secret
|
||||
|
||||
|
||||
def test_secret_very_long(mocker: MockerFixture):
|
||||
"""Test JWT secret with excessive length."""
|
||||
long_secret = "a" * 1000 # 1000 character secret
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": long_secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == long_secret
|
||||
assert len(settings.JWT_VERIFY_KEY) == 1000
|
||||
|
||||
|
||||
def test_secret_with_newline(mocker: MockerFixture):
|
||||
"""Test JWT secret containing newlines."""
|
||||
multiline_secret = "secret\nwith\nnewlines" + "a" * 20
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": multiline_secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == multiline_secret
|
||||
|
||||
|
||||
def test_secret_base64_encoded(mocker: MockerFixture):
|
||||
"""Test JWT secret that looks like base64."""
|
||||
base64_secret = "dGhpc19pc19hX3NlY3JldF9rZXlfd2l0aF9wcm9wZXJfbGVuZ3Ro"
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": base64_secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == base64_secret
|
||||
|
||||
|
||||
def test_secret_numeric_only(mocker: MockerFixture):
|
||||
"""Test JWT secret with only numbers."""
|
||||
numeric_secret = "1234567890" * 4 # 40 character numeric secret
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": numeric_secret}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_VERIFY_KEY == numeric_secret
|
||||
|
||||
|
||||
def test_algorithm_default_hs256(mocker: MockerFixture):
|
||||
"""Test that JWT algorithm defaults to HS256."""
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": "a" * 32}, clear=True)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_ALGORITHM == "HS256"
|
||||
|
||||
|
||||
def test_algorithm_whitespace_stripped(mocker: MockerFixture):
|
||||
"""Test that JWT algorithm whitespace is stripped."""
|
||||
secret = "a" * 32
|
||||
mocker.patch.dict(
|
||||
os.environ,
|
||||
{"JWT_VERIFY_KEY": secret, "JWT_SIGN_ALGORITHM": " HS256 "},
|
||||
clear=True,
|
||||
)
|
||||
|
||||
settings = Settings()
|
||||
assert settings.JWT_ALGORITHM == "HS256"
|
||||
|
||||
|
||||
def test_no_crypto_warning(mocker: MockerFixture, caplog: pytest.LogCaptureFixture):
|
||||
"""Test warning when crypto package is not available."""
|
||||
secret = "a" * 32
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": secret}, clear=True)
|
||||
|
||||
# Mock has_crypto to return False
|
||||
mocker.patch("autogpt_libs.auth.config.has_crypto", False)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
Settings()
|
||||
assert "Asymmetric JWT verification is not available" in caplog.text
|
||||
assert "cryptography" in caplog.text
|
||||
|
||||
|
||||
def test_algorithm_invalid_raises_error(mocker: MockerFixture):
|
||||
"""Test that invalid JWT algorithm raises AuthConfigError."""
|
||||
secret = "a" * 32
|
||||
mocker.patch.dict(
|
||||
os.environ,
|
||||
{"JWT_VERIFY_KEY": secret, "JWT_SIGN_ALGORITHM": "INVALID_ALG"},
|
||||
clear=True,
|
||||
)
|
||||
|
||||
with pytest.raises(AuthConfigError) as exc_info:
|
||||
Settings()
|
||||
assert "Invalid JWT_SIGN_ALGORITHM" in str(exc_info.value)
|
||||
assert "INVALID_ALG" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_algorithm_none_raises_error(mocker: MockerFixture):
|
||||
"""Test that 'none' algorithm raises AuthConfigError."""
|
||||
secret = "a" * 32
|
||||
mocker.patch.dict(
|
||||
os.environ,
|
||||
{"JWT_VERIFY_KEY": secret, "JWT_SIGN_ALGORITHM": "none"},
|
||||
clear=True,
|
||||
)
|
||||
|
||||
with pytest.raises(AuthConfigError) as exc_info:
|
||||
Settings()
|
||||
assert "Invalid JWT_SIGN_ALGORITHM" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("algorithm", ["HS256", "HS384", "HS512"])
|
||||
def test_algorithm_symmetric_warning(
|
||||
mocker: MockerFixture, caplog: pytest.LogCaptureFixture, algorithm: str
|
||||
):
|
||||
"""Test warning for symmetric algorithms (HS256, HS384, HS512)."""
|
||||
secret = "a" * 32
|
||||
mocker.patch.dict(
|
||||
os.environ,
|
||||
{"JWT_VERIFY_KEY": secret, "JWT_SIGN_ALGORITHM": algorithm},
|
||||
clear=True,
|
||||
)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
settings = Settings()
|
||||
assert algorithm in caplog.text
|
||||
assert "symmetric shared-key signature algorithm" in caplog.text
|
||||
assert settings.JWT_ALGORITHM == algorithm
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"algorithm",
|
||||
["ES256", "ES384", "ES512", "RS256", "RS384", "RS512", "PS256", "PS384", "PS512"],
|
||||
)
|
||||
def test_algorithm_asymmetric_no_warning(
|
||||
mocker: MockerFixture, caplog: pytest.LogCaptureFixture, algorithm: str
|
||||
):
|
||||
"""Test that asymmetric algorithms do not trigger warning."""
|
||||
secret = "a" * 32
|
||||
mocker.patch.dict(
|
||||
os.environ,
|
||||
{"JWT_VERIFY_KEY": secret, "JWT_SIGN_ALGORITHM": algorithm},
|
||||
clear=True,
|
||||
)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
settings = Settings()
|
||||
# Should not contain the symmetric algorithm warning
|
||||
assert "symmetric shared-key signature algorithm" not in caplog.text
|
||||
assert settings.JWT_ALGORITHM == algorithm
|
||||
117
autogpt_platform/autogpt_libs/autogpt_libs/auth/dependencies.py
Normal file
117
autogpt_platform/autogpt_libs/autogpt_libs/auth/dependencies.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""
|
||||
FastAPI dependency functions for JWT-based authentication and authorization.
|
||||
|
||||
These are the high-level dependency functions used in route definitions.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import fastapi
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from .jwt_utils import get_jwt_payload, verify_user
|
||||
from .models import User
|
||||
|
||||
optional_bearer = HTTPBearer(auto_error=False)
|
||||
|
||||
# Header name for admin impersonation
|
||||
IMPERSONATION_HEADER_NAME = "X-Act-As-User-Id"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_optional_user_id(
|
||||
credentials: HTTPAuthorizationCredentials | None = fastapi.Security(
|
||||
optional_bearer
|
||||
),
|
||||
) -> str | None:
|
||||
"""
|
||||
Attempts to extract the user ID ("sub" claim) from a Bearer JWT if provided.
|
||||
|
||||
This dependency allows for both authenticated and anonymous access. If a valid bearer token is
|
||||
supplied, it parses the JWT and extracts the user ID. If the token is missing or invalid, it returns None,
|
||||
treating the request as anonymous.
|
||||
|
||||
Args:
|
||||
credentials: Optional HTTPAuthorizationCredentials object from FastAPI Security dependency.
|
||||
|
||||
Returns:
|
||||
The user ID (str) extracted from the JWT "sub" claim, or None if no valid token is present.
|
||||
"""
|
||||
if not credentials:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Parse JWT token to get user ID
|
||||
from autogpt_libs.auth.jwt_utils import parse_jwt_token
|
||||
|
||||
payload = parse_jwt_token(credentials.credentials)
|
||||
return payload.get("sub")
|
||||
except Exception as e:
|
||||
logger.debug(f"Auth token validation failed (anonymous access): {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def requires_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
|
||||
"""
|
||||
FastAPI dependency that requires a valid authenticated user.
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 for authentication failures
|
||||
"""
|
||||
return verify_user(jwt_payload, admin_only=False)
|
||||
|
||||
|
||||
async def requires_admin_user(
|
||||
jwt_payload: dict = fastapi.Security(get_jwt_payload),
|
||||
) -> User:
|
||||
"""
|
||||
FastAPI dependency that requires a valid admin user.
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 for authentication failures, 403 for insufficient permissions
|
||||
"""
|
||||
return verify_user(jwt_payload, admin_only=True)
|
||||
|
||||
|
||||
async def get_user_id(
|
||||
request: fastapi.Request, jwt_payload: dict = fastapi.Security(get_jwt_payload)
|
||||
) -> str:
|
||||
"""
|
||||
FastAPI dependency that returns the ID of the authenticated user.
|
||||
|
||||
Supports admin impersonation via X-Act-As-User-Id header:
|
||||
- If the header is present and user is admin, returns the impersonated user ID
|
||||
- Otherwise returns the authenticated user's own ID
|
||||
- Logs all impersonation actions for audit trail
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 for authentication failures or missing user ID
|
||||
HTTPException: 403 if non-admin tries to use impersonation
|
||||
"""
|
||||
# Get the authenticated user's ID from JWT
|
||||
user_id = jwt_payload.get("sub")
|
||||
if not user_id:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=401, detail="User ID not found in token"
|
||||
)
|
||||
|
||||
# Check for admin impersonation header
|
||||
impersonate_header = request.headers.get(IMPERSONATION_HEADER_NAME, "").strip()
|
||||
if impersonate_header:
|
||||
# Verify the authenticated user is an admin
|
||||
authenticated_user = verify_user(jwt_payload, admin_only=False)
|
||||
if authenticated_user.role != "admin":
|
||||
raise fastapi.HTTPException(
|
||||
status_code=403, detail="Only admin users can impersonate other users"
|
||||
)
|
||||
|
||||
# Log the impersonation for audit trail
|
||||
logger.info(
|
||||
f"Admin impersonation: {authenticated_user.user_id} ({authenticated_user.email}) "
|
||||
f"acting as user {impersonate_header} for requesting {request.method} {request.url}"
|
||||
)
|
||||
|
||||
return impersonate_header
|
||||
|
||||
return user_id
|
||||
@@ -0,0 +1,554 @@
|
||||
"""
|
||||
Comprehensive integration tests for authentication dependencies.
|
||||
Tests the full authentication flow from HTTP requests to user validation.
|
||||
"""
|
||||
|
||||
import os
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, HTTPException, Request, Security
|
||||
from fastapi.testclient import TestClient
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from autogpt_libs.auth.dependencies import (
|
||||
get_user_id,
|
||||
requires_admin_user,
|
||||
requires_user,
|
||||
)
|
||||
from autogpt_libs.auth.models import User
|
||||
|
||||
|
||||
class TestAuthDependencies:
|
||||
"""Test suite for authentication dependency functions."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create a test FastAPI application."""
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/user")
|
||||
def get_user_endpoint(user: User = Security(requires_user)):
|
||||
return {"user_id": user.user_id, "role": user.role}
|
||||
|
||||
@app.get("/admin")
|
||||
def get_admin_endpoint(user: User = Security(requires_admin_user)):
|
||||
return {"user_id": user.user_id, "role": user.role}
|
||||
|
||||
@app.get("/user-id")
|
||||
def get_user_id_endpoint(user_id: str = Security(get_user_id)):
|
||||
return {"user_id": user_id}
|
||||
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def client(self, app):
|
||||
"""Create a test client."""
|
||||
return TestClient(app)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_requires_user_with_valid_jwt_payload(self, mocker: MockerFixture):
|
||||
"""Test requires_user with valid JWT payload."""
|
||||
jwt_payload = {"sub": "user-123", "role": "user", "email": "user@example.com"}
|
||||
|
||||
# Mock get_jwt_payload to return our test payload
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
user = await requires_user(jwt_payload)
|
||||
assert isinstance(user, User)
|
||||
assert user.user_id == "user-123"
|
||||
assert user.role == "user"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_requires_user_with_admin_jwt_payload(self, mocker: MockerFixture):
|
||||
"""Test requires_user accepts admin users."""
|
||||
jwt_payload = {
|
||||
"sub": "admin-456",
|
||||
"role": "admin",
|
||||
"email": "admin@example.com",
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
user = await requires_user(jwt_payload)
|
||||
assert user.user_id == "admin-456"
|
||||
assert user.role == "admin"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_requires_user_missing_sub(self):
|
||||
"""Test requires_user with missing user ID."""
|
||||
jwt_payload = {"role": "user", "email": "user@example.com"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await requires_user(jwt_payload)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "User ID not found" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_requires_user_empty_sub(self):
|
||||
"""Test requires_user with empty user ID."""
|
||||
jwt_payload = {"sub": "", "role": "user"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await requires_user(jwt_payload)
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_requires_admin_user_with_admin(self, mocker: MockerFixture):
|
||||
"""Test requires_admin_user with admin role."""
|
||||
jwt_payload = {
|
||||
"sub": "admin-789",
|
||||
"role": "admin",
|
||||
"email": "admin@example.com",
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
user = await requires_admin_user(jwt_payload)
|
||||
assert user.user_id == "admin-789"
|
||||
assert user.role == "admin"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_requires_admin_user_with_regular_user(self):
|
||||
"""Test requires_admin_user rejects regular users."""
|
||||
jwt_payload = {"sub": "user-123", "role": "user", "email": "user@example.com"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await requires_admin_user(jwt_payload)
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "Admin access required" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_requires_admin_user_missing_role(self):
|
||||
"""Test requires_admin_user with missing role."""
|
||||
jwt_payload = {"sub": "user-123", "email": "user@example.com"}
|
||||
|
||||
with pytest.raises(KeyError):
|
||||
await requires_admin_user(jwt_payload)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_id_with_valid_payload(self, mocker: MockerFixture):
|
||||
"""Test get_user_id extracts user ID correctly."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {}
|
||||
jwt_payload = {"sub": "user-id-xyz", "role": "user"}
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
user_id = await get_user_id(request, jwt_payload)
|
||||
assert user_id == "user-id-xyz"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_id_missing_sub(self):
|
||||
"""Test get_user_id with missing user ID."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {}
|
||||
jwt_payload = {"role": "user"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_user_id(request, jwt_payload)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "User ID not found" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_id_none_sub(self):
|
||||
"""Test get_user_id with None user ID."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {}
|
||||
jwt_payload = {"sub": None, "role": "user"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_user_id(request, jwt_payload)
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
class TestAuthDependenciesIntegration:
|
||||
"""Integration tests for auth dependencies with FastAPI."""
|
||||
|
||||
acceptable_jwt_secret = "test-secret-with-proper-length-123456"
|
||||
|
||||
@pytest.fixture
|
||||
def create_token(self, mocker: MockerFixture):
|
||||
"""Helper to create JWT tokens."""
|
||||
import jwt
|
||||
|
||||
mocker.patch.dict(
|
||||
os.environ,
|
||||
{"JWT_VERIFY_KEY": self.acceptable_jwt_secret},
|
||||
clear=True,
|
||||
)
|
||||
|
||||
def _create_token(payload, secret=self.acceptable_jwt_secret):
|
||||
return jwt.encode(payload, secret, algorithm="HS256")
|
||||
|
||||
return _create_token
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_endpoint_auth_enabled_no_token(self):
|
||||
"""Test endpoints require token when auth is enabled."""
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint(user: User = Security(requires_user)):
|
||||
return {"user_id": user.user_id}
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Should fail without auth header
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_endpoint_with_valid_token(self, create_token):
|
||||
"""Test endpoint with valid JWT token."""
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint(user: User = Security(requires_user)):
|
||||
return {"user_id": user.user_id, "role": user.role}
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
token = create_token(
|
||||
{"sub": "test-user", "role": "user", "aud": "authenticated"},
|
||||
secret=self.acceptable_jwt_secret,
|
||||
)
|
||||
|
||||
response = client.get("/test", headers={"Authorization": f"Bearer {token}"})
|
||||
assert response.status_code == 200
|
||||
assert response.json()["user_id"] == "test-user"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_endpoint_requires_admin_role(self, create_token):
|
||||
"""Test admin endpoint rejects non-admin users."""
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/admin")
|
||||
def admin_endpoint(user: User = Security(requires_admin_user)):
|
||||
return {"user_id": user.user_id}
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Regular user token
|
||||
user_token = create_token(
|
||||
{"sub": "regular-user", "role": "user", "aud": "authenticated"},
|
||||
secret=self.acceptable_jwt_secret,
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/admin", headers={"Authorization": f"Bearer {user_token}"}
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
# Admin token
|
||||
admin_token = create_token(
|
||||
{"sub": "admin-user", "role": "admin", "aud": "authenticated"},
|
||||
secret=self.acceptable_jwt_secret,
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/admin", headers={"Authorization": f"Bearer {admin_token}"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["user_id"] == "admin-user"
|
||||
|
||||
|
||||
class TestAuthDependenciesEdgeCases:
|
||||
"""Edge case tests for authentication dependencies."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dependency_with_complex_payload(self):
|
||||
"""Test dependencies handle complex JWT payloads."""
|
||||
complex_payload = {
|
||||
"sub": "user-123",
|
||||
"role": "admin",
|
||||
"email": "test@example.com",
|
||||
"app_metadata": {"provider": "email", "providers": ["email"]},
|
||||
"user_metadata": {
|
||||
"full_name": "Test User",
|
||||
"avatar_url": "https://example.com/avatar.jpg",
|
||||
},
|
||||
"aud": "authenticated",
|
||||
"iat": 1234567890,
|
||||
"exp": 9999999999,
|
||||
}
|
||||
|
||||
user = await requires_user(complex_payload)
|
||||
assert user.user_id == "user-123"
|
||||
assert user.email == "test@example.com"
|
||||
|
||||
admin = await requires_admin_user(complex_payload)
|
||||
assert admin.role == "admin"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dependency_with_unicode_in_payload(self):
|
||||
"""Test dependencies handle unicode in JWT payloads."""
|
||||
unicode_payload = {
|
||||
"sub": "user-😀-123",
|
||||
"role": "user",
|
||||
"email": "测试@example.com",
|
||||
"name": "日本語",
|
||||
}
|
||||
|
||||
user = await requires_user(unicode_payload)
|
||||
assert "😀" in user.user_id
|
||||
assert user.email == "测试@example.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dependency_with_null_values(self):
|
||||
"""Test dependencies handle null values in payload."""
|
||||
null_payload = {
|
||||
"sub": "user-123",
|
||||
"role": "user",
|
||||
"email": None,
|
||||
"phone": None,
|
||||
"metadata": None,
|
||||
}
|
||||
|
||||
user = await requires_user(null_payload)
|
||||
assert user.user_id == "user-123"
|
||||
assert user.email is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_requests_isolation(self):
|
||||
"""Test that concurrent requests don't interfere with each other."""
|
||||
payload1 = {"sub": "user-1", "role": "user"}
|
||||
payload2 = {"sub": "user-2", "role": "admin"}
|
||||
|
||||
# Simulate concurrent processing
|
||||
user1 = await requires_user(payload1)
|
||||
user2 = await requires_admin_user(payload2)
|
||||
|
||||
assert user1.user_id == "user-1"
|
||||
assert user2.user_id == "user-2"
|
||||
assert user1.role == "user"
|
||||
assert user2.role == "admin"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"payload,expected_error,admin_only",
|
||||
[
|
||||
(None, "Authorization header is missing", False),
|
||||
({}, "User ID not found", False),
|
||||
({"sub": ""}, "User ID not found", False),
|
||||
({"role": "user"}, "User ID not found", False),
|
||||
({"sub": "user", "role": "user"}, "Admin access required", True),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_dependency_error_cases(
|
||||
self, payload, expected_error: str, admin_only: bool
|
||||
):
|
||||
"""Test that errors propagate correctly through dependencies."""
|
||||
# Import verify_user to test it directly since dependencies use FastAPI Security
|
||||
from autogpt_libs.auth.jwt_utils import verify_user
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
verify_user(payload, admin_only=admin_only)
|
||||
assert expected_error in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dependency_valid_user(self):
|
||||
"""Test valid user case for dependency."""
|
||||
# Import verify_user to test it directly since dependencies use FastAPI Security
|
||||
from autogpt_libs.auth.jwt_utils import verify_user
|
||||
|
||||
# Valid case
|
||||
user = verify_user({"sub": "user", "role": "user"}, admin_only=False)
|
||||
assert user.user_id == "user"
|
||||
|
||||
|
||||
class TestAdminImpersonation:
|
||||
"""Test suite for admin user impersonation functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_impersonation_success(self, mocker: MockerFixture):
|
||||
"""Test admin successfully impersonating another user."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"X-Act-As-User-Id": "target-user-123"}
|
||||
jwt_payload = {
|
||||
"sub": "admin-456",
|
||||
"role": "admin",
|
||||
"email": "admin@example.com",
|
||||
}
|
||||
|
||||
# Mock verify_user to return admin user data
|
||||
mock_verify_user = mocker.patch("autogpt_libs.auth.dependencies.verify_user")
|
||||
mock_verify_user.return_value = Mock(
|
||||
user_id="admin-456", email="admin@example.com", role="admin"
|
||||
)
|
||||
|
||||
# Mock logger to verify audit logging
|
||||
mock_logger = mocker.patch("autogpt_libs.auth.dependencies.logger")
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
|
||||
user_id = await get_user_id(request, jwt_payload)
|
||||
|
||||
# Should return the impersonated user ID
|
||||
assert user_id == "target-user-123"
|
||||
|
||||
# Should log the impersonation attempt
|
||||
mock_logger.info.assert_called_once()
|
||||
log_call = mock_logger.info.call_args[0][0]
|
||||
assert "Admin impersonation:" in log_call
|
||||
assert "admin@example.com" in log_call
|
||||
assert "target-user-123" in log_call
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_admin_impersonation_attempt(self, mocker: MockerFixture):
|
||||
"""Test non-admin user attempting impersonation returns 403."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"X-Act-As-User-Id": "target-user-123"}
|
||||
jwt_payload = {
|
||||
"sub": "regular-user",
|
||||
"role": "user",
|
||||
"email": "user@example.com",
|
||||
}
|
||||
|
||||
# Mock verify_user to return regular user data
|
||||
mock_verify_user = mocker.patch("autogpt_libs.auth.dependencies.verify_user")
|
||||
mock_verify_user.return_value = Mock(
|
||||
user_id="regular-user", email="user@example.com", role="user"
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_user_id(request, jwt_payload)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "Only admin users can impersonate other users" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_impersonation_empty_header(self, mocker: MockerFixture):
|
||||
"""Test impersonation with empty header falls back to regular user ID."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"X-Act-As-User-Id": ""}
|
||||
jwt_payload = {
|
||||
"sub": "admin-456",
|
||||
"role": "admin",
|
||||
"email": "admin@example.com",
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
|
||||
user_id = await get_user_id(request, jwt_payload)
|
||||
|
||||
# Should fall back to the admin's own user ID
|
||||
assert user_id == "admin-456"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_impersonation_missing_header(self, mocker: MockerFixture):
|
||||
"""Test normal behavior when impersonation header is missing."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {} # No impersonation header
|
||||
jwt_payload = {
|
||||
"sub": "admin-456",
|
||||
"role": "admin",
|
||||
"email": "admin@example.com",
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
|
||||
user_id = await get_user_id(request, jwt_payload)
|
||||
|
||||
# Should return the admin's own user ID
|
||||
assert user_id == "admin-456"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_impersonation_audit_logging_details(self, mocker: MockerFixture):
|
||||
"""Test that impersonation audit logging includes all required details."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"X-Act-As-User-Id": "victim-user-789"}
|
||||
jwt_payload = {
|
||||
"sub": "admin-999",
|
||||
"role": "admin",
|
||||
"email": "superadmin@company.com",
|
||||
}
|
||||
|
||||
# Mock verify_user to return admin user data
|
||||
mock_verify_user = mocker.patch("autogpt_libs.auth.dependencies.verify_user")
|
||||
mock_verify_user.return_value = Mock(
|
||||
user_id="admin-999", email="superadmin@company.com", role="admin"
|
||||
)
|
||||
|
||||
# Mock logger to capture audit trail
|
||||
mock_logger = mocker.patch("autogpt_libs.auth.dependencies.logger")
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
|
||||
user_id = await get_user_id(request, jwt_payload)
|
||||
|
||||
# Verify all audit details are logged
|
||||
assert user_id == "victim-user-789"
|
||||
mock_logger.info.assert_called_once()
|
||||
|
||||
log_message = mock_logger.info.call_args[0][0]
|
||||
assert "Admin impersonation:" in log_message
|
||||
assert "superadmin@company.com" in log_message
|
||||
assert "victim-user-789" in log_message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_impersonation_header_case_sensitivity(self, mocker: MockerFixture):
|
||||
"""Test that impersonation header is case-sensitive."""
|
||||
request = Mock(spec=Request)
|
||||
# Use wrong case - should not trigger impersonation
|
||||
request.headers = {"x-act-as-user-id": "target-user-123"}
|
||||
jwt_payload = {
|
||||
"sub": "admin-456",
|
||||
"role": "admin",
|
||||
"email": "admin@example.com",
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
|
||||
user_id = await get_user_id(request, jwt_payload)
|
||||
|
||||
# Should fall back to admin's own ID (header case mismatch)
|
||||
assert user_id == "admin-456"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_impersonation_with_whitespace_header(self, mocker: MockerFixture):
|
||||
"""Test impersonation with whitespace in header value."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"X-Act-As-User-Id": " target-user-123 "}
|
||||
jwt_payload = {
|
||||
"sub": "admin-456",
|
||||
"role": "admin",
|
||||
"email": "admin@example.com",
|
||||
}
|
||||
|
||||
# Mock verify_user to return admin user data
|
||||
mock_verify_user = mocker.patch("autogpt_libs.auth.dependencies.verify_user")
|
||||
mock_verify_user.return_value = Mock(
|
||||
user_id="admin-456", email="admin@example.com", role="admin"
|
||||
)
|
||||
|
||||
# Mock logger
|
||||
mock_logger = mocker.patch("autogpt_libs.auth.dependencies.logger")
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
|
||||
user_id = await get_user_id(request, jwt_payload)
|
||||
|
||||
# Should strip whitespace and impersonate successfully
|
||||
assert user_id == "target-user-123"
|
||||
mock_logger.info.assert_called_once()
|
||||
@@ -1,46 +0,0 @@
|
||||
import fastapi
|
||||
|
||||
from .config import settings
|
||||
from .middleware import auth_middleware
|
||||
from .models import DEFAULT_USER_ID, User
|
||||
|
||||
|
||||
def requires_user(payload: dict = fastapi.Depends(auth_middleware)) -> User:
|
||||
return verify_user(payload, admin_only=False)
|
||||
|
||||
|
||||
def requires_admin_user(
|
||||
payload: dict = fastapi.Depends(auth_middleware),
|
||||
) -> User:
|
||||
return verify_user(payload, admin_only=True)
|
||||
|
||||
|
||||
def verify_user(payload: dict | None, admin_only: bool) -> User:
|
||||
if not payload:
|
||||
if settings.ENABLE_AUTH:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=401, detail="Authorization header is missing"
|
||||
)
|
||||
# This handles the case when authentication is disabled
|
||||
payload = {"sub": DEFAULT_USER_ID, "role": "admin"}
|
||||
|
||||
user_id = payload.get("sub")
|
||||
|
||||
if not user_id:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=401, detail="User ID not found in token"
|
||||
)
|
||||
|
||||
if admin_only and payload["role"] != "admin":
|
||||
raise fastapi.HTTPException(status_code=403, detail="Admin access required")
|
||||
|
||||
return User.from_payload(payload)
|
||||
|
||||
|
||||
def get_user_id(payload: dict = fastapi.Depends(auth_middleware)) -> str:
|
||||
user_id = payload.get("sub")
|
||||
if not user_id:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=401, detail="User ID not found in token"
|
||||
)
|
||||
return user_id
|
||||
@@ -1,68 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from .depends import requires_admin_user, requires_user, verify_user
|
||||
|
||||
|
||||
def test_verify_user_no_payload():
|
||||
user = verify_user(None, admin_only=False)
|
||||
assert user.user_id == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
assert user.role == "admin"
|
||||
|
||||
|
||||
def test_verify_user_no_user_id():
|
||||
with pytest.raises(Exception):
|
||||
verify_user({"role": "admin"}, admin_only=False)
|
||||
|
||||
|
||||
def test_verify_user_not_admin():
|
||||
with pytest.raises(Exception):
|
||||
verify_user(
|
||||
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "user"},
|
||||
admin_only=True,
|
||||
)
|
||||
|
||||
|
||||
def test_verify_user_with_admin_role():
|
||||
user = verify_user(
|
||||
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "admin"},
|
||||
admin_only=True,
|
||||
)
|
||||
assert user.user_id == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
assert user.role == "admin"
|
||||
|
||||
|
||||
def test_verify_user_with_user_role():
|
||||
user = verify_user(
|
||||
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "user"},
|
||||
admin_only=False,
|
||||
)
|
||||
assert user.user_id == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
assert user.role == "user"
|
||||
|
||||
|
||||
def test_requires_user():
|
||||
user = requires_user(
|
||||
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "user"}
|
||||
)
|
||||
assert user.user_id == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
assert user.role == "user"
|
||||
|
||||
|
||||
def test_requires_user_no_user_id():
|
||||
with pytest.raises(Exception):
|
||||
requires_user({"role": "user"})
|
||||
|
||||
|
||||
def test_requires_admin_user():
|
||||
user = requires_admin_user(
|
||||
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "admin"}
|
||||
)
|
||||
assert user.user_id == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
assert user.role == "admin"
|
||||
|
||||
|
||||
def test_requires_admin_user_not_admin():
|
||||
with pytest.raises(Exception):
|
||||
requires_admin_user(
|
||||
{"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "role": "user"}
|
||||
)
|
||||
68
autogpt_platform/autogpt_libs/autogpt_libs/auth/helpers.py
Normal file
68
autogpt_platform/autogpt_libs/autogpt_libs/auth/helpers.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
|
||||
from .jwt_utils import bearer_jwt_auth
|
||||
|
||||
|
||||
def add_auth_responses_to_openapi(app: FastAPI) -> None:
|
||||
"""
|
||||
Set up custom OpenAPI schema generation that adds 401 responses
|
||||
to all authenticated endpoints.
|
||||
|
||||
This is needed when using HTTPBearer with auto_error=False to get proper
|
||||
401 responses instead of 403, but FastAPI only automatically adds security
|
||||
responses when auto_error=True.
|
||||
"""
|
||||
|
||||
def custom_openapi():
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
|
||||
openapi_schema = get_openapi(
|
||||
title=app.title,
|
||||
version=app.version,
|
||||
description=app.description,
|
||||
routes=app.routes,
|
||||
)
|
||||
|
||||
# Add 401 response to all endpoints that have security requirements
|
||||
for path, methods in openapi_schema["paths"].items():
|
||||
for method, details in methods.items():
|
||||
security_schemas = [
|
||||
schema
|
||||
for auth_option in details.get("security", [])
|
||||
for schema in auth_option.keys()
|
||||
]
|
||||
if bearer_jwt_auth.scheme_name not in security_schemas:
|
||||
continue
|
||||
|
||||
if "responses" not in details:
|
||||
details["responses"] = {}
|
||||
|
||||
details["responses"]["401"] = {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
}
|
||||
|
||||
# Ensure #/components/responses exists
|
||||
if "components" not in openapi_schema:
|
||||
openapi_schema["components"] = {}
|
||||
if "responses" not in openapi_schema["components"]:
|
||||
openapi_schema["components"]["responses"] = {}
|
||||
|
||||
# Define 401 response
|
||||
openapi_schema["components"]["responses"]["HTTP401NotAuthenticatedError"] = {
|
||||
"description": "Authentication required",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {"detail": {"type": "string"}},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
app.openapi = custom_openapi
|
||||
435
autogpt_platform/autogpt_libs/autogpt_libs/auth/helpers_test.py
Normal file
435
autogpt_platform/autogpt_libs/autogpt_libs/auth/helpers_test.py
Normal file
@@ -0,0 +1,435 @@
|
||||
"""
|
||||
Comprehensive tests for auth helpers module to achieve 100% coverage.
|
||||
Tests OpenAPI schema generation and authentication response handling.
|
||||
"""
|
||||
|
||||
from unittest import mock
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
|
||||
from autogpt_libs.auth.helpers import add_auth_responses_to_openapi
|
||||
from autogpt_libs.auth.jwt_utils import bearer_jwt_auth
|
||||
|
||||
|
||||
def test_add_auth_responses_to_openapi_basic():
|
||||
"""Test adding 401 responses to OpenAPI schema."""
|
||||
app = FastAPI(title="Test App", version="1.0.0")
|
||||
|
||||
# Add some test endpoints with authentication
|
||||
from fastapi import Depends
|
||||
|
||||
from autogpt_libs.auth.dependencies import requires_user
|
||||
|
||||
@app.get("/protected", dependencies=[Depends(requires_user)])
|
||||
def protected_endpoint():
|
||||
return {"message": "Protected"}
|
||||
|
||||
@app.get("/public")
|
||||
def public_endpoint():
|
||||
return {"message": "Public"}
|
||||
|
||||
# Apply the OpenAPI customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
# Get the OpenAPI schema
|
||||
schema = app.openapi()
|
||||
|
||||
# Verify basic schema properties
|
||||
assert schema["info"]["title"] == "Test App"
|
||||
assert schema["info"]["version"] == "1.0.0"
|
||||
|
||||
# Verify 401 response component is added
|
||||
assert "components" in schema
|
||||
assert "responses" in schema["components"]
|
||||
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
|
||||
|
||||
# Verify 401 response structure
|
||||
error_response = schema["components"]["responses"]["HTTP401NotAuthenticatedError"]
|
||||
assert error_response["description"] == "Authentication required"
|
||||
assert "application/json" in error_response["content"]
|
||||
assert "schema" in error_response["content"]["application/json"]
|
||||
|
||||
# Verify schema properties
|
||||
response_schema = error_response["content"]["application/json"]["schema"]
|
||||
assert response_schema["type"] == "object"
|
||||
assert "detail" in response_schema["properties"]
|
||||
assert response_schema["properties"]["detail"]["type"] == "string"
|
||||
|
||||
|
||||
def test_add_auth_responses_to_openapi_with_security():
|
||||
"""Test that 401 responses are added only to secured endpoints."""
|
||||
app = FastAPI()
|
||||
|
||||
# Mock endpoint with security
|
||||
from fastapi import Security
|
||||
|
||||
from autogpt_libs.auth.dependencies import get_user_id
|
||||
|
||||
@app.get("/secured")
|
||||
def secured_endpoint(user_id: str = Security(get_user_id)):
|
||||
return {"user_id": user_id}
|
||||
|
||||
@app.post("/also-secured")
|
||||
def another_secured(user_id: str = Security(get_user_id)):
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.get("/unsecured")
|
||||
def unsecured_endpoint():
|
||||
return {"public": True}
|
||||
|
||||
# Apply OpenAPI customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
# Get schema
|
||||
schema = app.openapi()
|
||||
|
||||
# Check that secured endpoints have 401 responses
|
||||
if "/secured" in schema["paths"]:
|
||||
if "get" in schema["paths"]["/secured"]:
|
||||
secured_get = schema["paths"]["/secured"]["get"]
|
||||
if "responses" in secured_get:
|
||||
assert "401" in secured_get["responses"]
|
||||
assert (
|
||||
secured_get["responses"]["401"]["$ref"]
|
||||
== "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
)
|
||||
|
||||
if "/also-secured" in schema["paths"]:
|
||||
if "post" in schema["paths"]["/also-secured"]:
|
||||
secured_post = schema["paths"]["/also-secured"]["post"]
|
||||
if "responses" in secured_post:
|
||||
assert "401" in secured_post["responses"]
|
||||
|
||||
# Check that unsecured endpoint does not have 401 response
|
||||
if "/unsecured" in schema["paths"]:
|
||||
if "get" in schema["paths"]["/unsecured"]:
|
||||
unsecured_get = schema["paths"]["/unsecured"]["get"]
|
||||
if "responses" in unsecured_get:
|
||||
assert "401" not in unsecured_get.get("responses", {})
|
||||
|
||||
|
||||
def test_add_auth_responses_to_openapi_cached_schema():
|
||||
"""Test that OpenAPI schema is cached after first generation."""
|
||||
app = FastAPI()
|
||||
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
# Get schema twice
|
||||
schema1 = app.openapi()
|
||||
schema2 = app.openapi()
|
||||
|
||||
# Should return the same cached object
|
||||
assert schema1 is schema2
|
||||
|
||||
|
||||
def test_add_auth_responses_to_openapi_existing_responses():
|
||||
"""Test handling endpoints that already have responses defined."""
|
||||
app = FastAPI()
|
||||
|
||||
from fastapi import Security
|
||||
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
@app.get(
|
||||
"/with-responses",
|
||||
responses={
|
||||
200: {"description": "Success"},
|
||||
404: {"description": "Not found"},
|
||||
},
|
||||
)
|
||||
def endpoint_with_responses(jwt: dict = Security(get_jwt_payload)):
|
||||
return {"data": "test"}
|
||||
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
schema = app.openapi()
|
||||
|
||||
# Check that existing responses are preserved and 401 is added
|
||||
if "/with-responses" in schema["paths"]:
|
||||
if "get" in schema["paths"]["/with-responses"]:
|
||||
responses = schema["paths"]["/with-responses"]["get"].get("responses", {})
|
||||
# Original responses should be preserved
|
||||
if "200" in responses:
|
||||
assert responses["200"]["description"] == "Success"
|
||||
if "404" in responses:
|
||||
assert responses["404"]["description"] == "Not found"
|
||||
# 401 should be added
|
||||
if "401" in responses:
|
||||
assert (
|
||||
responses["401"]["$ref"]
|
||||
== "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
)
|
||||
|
||||
|
||||
def test_add_auth_responses_to_openapi_no_security_endpoints():
|
||||
"""Test with app that has no secured endpoints."""
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/public1")
|
||||
def public1():
|
||||
return {"message": "public1"}
|
||||
|
||||
@app.post("/public2")
|
||||
def public2():
|
||||
return {"message": "public2"}
|
||||
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
schema = app.openapi()
|
||||
|
||||
# Component should still be added for consistency
|
||||
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
|
||||
|
||||
# But no endpoints should have 401 responses
|
||||
for path in schema["paths"].values():
|
||||
for method in path.values():
|
||||
if isinstance(method, dict) and "responses" in method:
|
||||
assert "401" not in method["responses"]
|
||||
|
||||
|
||||
def test_add_auth_responses_to_openapi_multiple_security_schemes():
|
||||
"""Test endpoints with multiple security requirements."""
|
||||
app = FastAPI()
|
||||
|
||||
from fastapi import Security
|
||||
|
||||
from autogpt_libs.auth.dependencies import requires_admin_user, requires_user
|
||||
from autogpt_libs.auth.models import User
|
||||
|
||||
@app.get("/multi-auth")
|
||||
def multi_auth(
|
||||
user: User = Security(requires_user),
|
||||
admin: User = Security(requires_admin_user),
|
||||
):
|
||||
return {"status": "super secure"}
|
||||
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
schema = app.openapi()
|
||||
|
||||
# Should have 401 response
|
||||
if "/multi-auth" in schema["paths"]:
|
||||
if "get" in schema["paths"]["/multi-auth"]:
|
||||
responses = schema["paths"]["/multi-auth"]["get"].get("responses", {})
|
||||
if "401" in responses:
|
||||
assert (
|
||||
responses["401"]["$ref"]
|
||||
== "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
)
|
||||
|
||||
|
||||
def test_add_auth_responses_to_openapi_empty_components():
|
||||
"""Test when OpenAPI schema has no components section initially."""
|
||||
app = FastAPI()
|
||||
|
||||
# Mock get_openapi to return schema without components
|
||||
original_get_openapi = get_openapi
|
||||
|
||||
def mock_get_openapi(*args, **kwargs):
|
||||
schema = original_get_openapi(*args, **kwargs)
|
||||
# Remove components if it exists
|
||||
if "components" in schema:
|
||||
del schema["components"]
|
||||
return schema
|
||||
|
||||
with mock.patch("autogpt_libs.auth.helpers.get_openapi", mock_get_openapi):
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
schema = app.openapi()
|
||||
|
||||
# Components should be created
|
||||
assert "components" in schema
|
||||
assert "responses" in schema["components"]
|
||||
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
|
||||
|
||||
|
||||
def test_add_auth_responses_to_openapi_all_http_methods():
|
||||
"""Test that all HTTP methods are handled correctly."""
|
||||
app = FastAPI()
|
||||
|
||||
from fastapi import Security
|
||||
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
@app.get("/resource")
|
||||
def get_resource(jwt: dict = Security(get_jwt_payload)):
|
||||
return {"method": "GET"}
|
||||
|
||||
@app.post("/resource")
|
||||
def post_resource(jwt: dict = Security(get_jwt_payload)):
|
||||
return {"method": "POST"}
|
||||
|
||||
@app.put("/resource")
|
||||
def put_resource(jwt: dict = Security(get_jwt_payload)):
|
||||
return {"method": "PUT"}
|
||||
|
||||
@app.patch("/resource")
|
||||
def patch_resource(jwt: dict = Security(get_jwt_payload)):
|
||||
return {"method": "PATCH"}
|
||||
|
||||
@app.delete("/resource")
|
||||
def delete_resource(jwt: dict = Security(get_jwt_payload)):
|
||||
return {"method": "DELETE"}
|
||||
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
schema = app.openapi()
|
||||
|
||||
# All methods should have 401 response
|
||||
if "/resource" in schema["paths"]:
|
||||
for method in ["get", "post", "put", "patch", "delete"]:
|
||||
if method in schema["paths"]["/resource"]:
|
||||
method_spec = schema["paths"]["/resource"][method]
|
||||
if "responses" in method_spec:
|
||||
assert "401" in method_spec["responses"]
|
||||
|
||||
|
||||
def test_bearer_jwt_auth_scheme_config():
|
||||
"""Test that bearer_jwt_auth is configured correctly."""
|
||||
assert bearer_jwt_auth.scheme_name == "HTTPBearerJWT"
|
||||
assert bearer_jwt_auth.auto_error is False
|
||||
|
||||
|
||||
def test_add_auth_responses_with_no_routes():
|
||||
"""Test OpenAPI generation with app that has no routes."""
|
||||
app = FastAPI(title="Empty App")
|
||||
|
||||
# Apply customization to empty app
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
schema = app.openapi()
|
||||
|
||||
# Should still have basic structure
|
||||
assert schema["info"]["title"] == "Empty App"
|
||||
assert "components" in schema
|
||||
assert "responses" in schema["components"]
|
||||
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
|
||||
|
||||
|
||||
def test_custom_openapi_function_replacement():
|
||||
"""Test that the custom openapi function properly replaces the default."""
|
||||
app = FastAPI()
|
||||
|
||||
# Store original function
|
||||
original_openapi = app.openapi
|
||||
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
# Function should be replaced
|
||||
assert app.openapi != original_openapi
|
||||
assert callable(app.openapi)
|
||||
|
||||
|
||||
def test_endpoint_without_responses_section():
|
||||
"""Test endpoint that has security but no responses section initially."""
|
||||
app = FastAPI()
|
||||
|
||||
from fastapi import Security
|
||||
from fastapi.openapi.utils import get_openapi as original_get_openapi
|
||||
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
# Create endpoint
|
||||
@app.get("/no-responses")
|
||||
def endpoint_without_responses(jwt: dict = Security(get_jwt_payload)):
|
||||
return {"data": "test"}
|
||||
|
||||
# Mock get_openapi to remove responses from the endpoint
|
||||
def mock_get_openapi(*args, **kwargs):
|
||||
schema = original_get_openapi(*args, **kwargs)
|
||||
# Remove responses from our endpoint to trigger line 40
|
||||
if "/no-responses" in schema.get("paths", {}):
|
||||
if "get" in schema["paths"]["/no-responses"]:
|
||||
# Delete responses to force the code to create it
|
||||
if "responses" in schema["paths"]["/no-responses"]["get"]:
|
||||
del schema["paths"]["/no-responses"]["get"]["responses"]
|
||||
return schema
|
||||
|
||||
with mock.patch("autogpt_libs.auth.helpers.get_openapi", mock_get_openapi):
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
# Get schema and verify 401 was added
|
||||
schema = app.openapi()
|
||||
|
||||
# The endpoint should now have 401 response
|
||||
if "/no-responses" in schema["paths"]:
|
||||
if "get" in schema["paths"]["/no-responses"]:
|
||||
responses = schema["paths"]["/no-responses"]["get"].get("responses", {})
|
||||
assert "401" in responses
|
||||
assert (
|
||||
responses["401"]["$ref"]
|
||||
== "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
)
|
||||
|
||||
|
||||
def test_components_with_existing_responses():
|
||||
"""Test when components already has a responses section."""
|
||||
app = FastAPI()
|
||||
|
||||
# Mock get_openapi to return schema with existing components/responses
|
||||
from fastapi.openapi.utils import get_openapi as original_get_openapi
|
||||
|
||||
def mock_get_openapi(*args, **kwargs):
|
||||
schema = original_get_openapi(*args, **kwargs)
|
||||
# Add existing components/responses
|
||||
if "components" not in schema:
|
||||
schema["components"] = {}
|
||||
schema["components"]["responses"] = {
|
||||
"ExistingResponse": {"description": "An existing response"}
|
||||
}
|
||||
return schema
|
||||
|
||||
with mock.patch("autogpt_libs.auth.helpers.get_openapi", mock_get_openapi):
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
schema = app.openapi()
|
||||
|
||||
# Both responses should exist
|
||||
assert "ExistingResponse" in schema["components"]["responses"]
|
||||
assert "HTTP401NotAuthenticatedError" in schema["components"]["responses"]
|
||||
|
||||
# Verify our 401 response structure
|
||||
error_response = schema["components"]["responses"][
|
||||
"HTTP401NotAuthenticatedError"
|
||||
]
|
||||
assert error_response["description"] == "Authentication required"
|
||||
|
||||
|
||||
def test_openapi_schema_persistence():
|
||||
"""Test that modifications to OpenAPI schema persist correctly."""
|
||||
app = FastAPI()
|
||||
|
||||
from fastapi import Security
|
||||
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint(jwt: dict = Security(get_jwt_payload)):
|
||||
return {"test": True}
|
||||
|
||||
# Apply customization
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
# Get schema multiple times
|
||||
schema1 = app.openapi()
|
||||
|
||||
# Modify the cached schema (shouldn't affect future calls)
|
||||
schema1["info"]["title"] = "Modified Title"
|
||||
|
||||
# Clear cache and get again
|
||||
app.openapi_schema = None
|
||||
schema2 = app.openapi()
|
||||
|
||||
# Should regenerate with original title
|
||||
assert schema2["info"]["title"] == app.title
|
||||
assert schema2["info"]["title"] != "Modified Title"
|
||||
@@ -1,11 +1,48 @@
|
||||
from typing import Any, Dict
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import jwt
|
||||
from fastapi import HTTPException, Security
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from .config import settings
|
||||
from .config import get_settings
|
||||
from .models import User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Bearer token authentication scheme
|
||||
bearer_jwt_auth = HTTPBearer(
|
||||
bearerFormat="jwt", scheme_name="HTTPBearerJWT", auto_error=False
|
||||
)
|
||||
|
||||
|
||||
def parse_jwt_token(token: str) -> Dict[str, Any]:
|
||||
async def get_jwt_payload(
|
||||
credentials: HTTPAuthorizationCredentials | None = Security(bearer_jwt_auth),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Extract and validate JWT payload from HTTP Authorization header.
|
||||
|
||||
This is the core authentication function that handles:
|
||||
- Reading the `Authorization` header to obtain the JWT token
|
||||
- Verifying the JWT token's signature
|
||||
- Decoding the JWT token's payload
|
||||
|
||||
:param credentials: HTTP Authorization credentials from bearer token
|
||||
:return: JWT payload dictionary
|
||||
:raises HTTPException: 401 if authentication fails
|
||||
"""
|
||||
if not credentials:
|
||||
raise HTTPException(status_code=401, detail="Authorization header is missing")
|
||||
|
||||
try:
|
||||
payload = parse_jwt_token(credentials.credentials)
|
||||
logger.debug("Token decoded successfully")
|
||||
return payload
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=401, detail=str(e))
|
||||
|
||||
|
||||
def parse_jwt_token(token: str) -> dict[str, Any]:
|
||||
"""
|
||||
Parse and validate a JWT token.
|
||||
|
||||
@@ -13,10 +50,11 @@ def parse_jwt_token(token: str) -> Dict[str, Any]:
|
||||
:return: The decoded payload
|
||||
:raises ValueError: If the token is invalid or expired
|
||||
"""
|
||||
settings = get_settings()
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.JWT_SECRET_KEY,
|
||||
settings.JWT_VERIFY_KEY,
|
||||
algorithms=[settings.JWT_ALGORITHM],
|
||||
audience="authenticated",
|
||||
)
|
||||
@@ -25,3 +63,18 @@ def parse_jwt_token(token: str) -> Dict[str, Any]:
|
||||
raise ValueError("Token has expired")
|
||||
except jwt.InvalidTokenError as e:
|
||||
raise ValueError(f"Invalid token: {str(e)}")
|
||||
|
||||
|
||||
def verify_user(jwt_payload: dict | None, admin_only: bool) -> User:
|
||||
if jwt_payload is None:
|
||||
raise HTTPException(status_code=401, detail="Authorization header is missing")
|
||||
|
||||
user_id = jwt_payload.get("sub")
|
||||
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User ID not found in token")
|
||||
|
||||
if admin_only and jwt_payload["role"] != "admin":
|
||||
raise HTTPException(status_code=403, detail="Admin access required")
|
||||
|
||||
return User.from_payload(jwt_payload)
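For orientation, a minimal usage sketch of how these two helpers fit together in a FastAPI route; the app and endpoint names here are hypothetical and not part of this change:

```python
from fastapi import FastAPI, Security

from autogpt_libs.auth.jwt_utils import get_jwt_payload, verify_user

app = FastAPI()


@app.get("/me")
async def read_me(jwt: dict = Security(get_jwt_payload)):
    # verify_user raises HTTPException 401/403 if the payload is missing,
    # has no user ID, or the role is insufficient for admin-only access.
    user = verify_user(jwt, admin_only=False)
    return {"user_id": user.user_id, "role": user.role}
```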
|
||||
|
||||
@@ -0,0 +1,308 @@
|
||||
"""
|
||||
Comprehensive tests for JWT token parsing and validation.
|
||||
Ensures 100% line and branch coverage for JWT security functions.
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from autogpt_libs.auth import config, jwt_utils
|
||||
from autogpt_libs.auth.config import Settings
|
||||
from autogpt_libs.auth.models import User
|
||||
|
||||
MOCK_JWT_SECRET = "test-secret-key-with-at-least-32-characters"
|
||||
TEST_USER_PAYLOAD = {
|
||||
"sub": "test-user-id",
|
||||
"role": "user",
|
||||
"aud": "authenticated",
|
||||
"email": "test@example.com",
|
||||
}
|
||||
TEST_ADMIN_PAYLOAD = {
|
||||
"sub": "admin-user-id",
|
||||
"role": "admin",
|
||||
"aud": "authenticated",
|
||||
"email": "admin@example.com",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_config(mocker: MockerFixture):
|
||||
mocker.patch.dict(os.environ, {"JWT_VERIFY_KEY": MOCK_JWT_SECRET}, clear=True)
|
||||
mocker.patch.object(config, "_settings", Settings())
|
||||
yield
|
||||
|
||||
|
||||
def create_token(payload, secret=None, algorithm="HS256"):
|
||||
"""Helper to create JWT tokens."""
|
||||
if secret is None:
|
||||
secret = MOCK_JWT_SECRET
|
||||
return jwt.encode(payload, secret, algorithm=algorithm)
|
||||
|
||||
|
||||
def test_parse_jwt_token_valid():
|
||||
"""Test parsing a valid JWT token."""
|
||||
token = create_token(TEST_USER_PAYLOAD)
|
||||
result = jwt_utils.parse_jwt_token(token)
|
||||
|
||||
assert result["sub"] == "test-user-id"
|
||||
assert result["role"] == "user"
|
||||
assert result["aud"] == "authenticated"
|
||||
|
||||
|
||||
def test_parse_jwt_token_expired():
|
||||
"""Test parsing an expired JWT token."""
|
||||
expired_payload = {
|
||||
**TEST_USER_PAYLOAD,
|
||||
"exp": datetime.now(timezone.utc) - timedelta(hours=1),
|
||||
}
|
||||
token = create_token(expired_payload)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
jwt_utils.parse_jwt_token(token)
|
||||
assert "Token has expired" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_parse_jwt_token_invalid_signature():
|
||||
"""Test parsing a token with invalid signature."""
|
||||
# Create token with different secret
|
||||
token = create_token(TEST_USER_PAYLOAD, secret="wrong-secret")
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
jwt_utils.parse_jwt_token(token)
|
||||
assert "Invalid token" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_parse_jwt_token_malformed():
|
||||
"""Test parsing a malformed token."""
|
||||
malformed_tokens = [
|
||||
"not.a.token",
|
||||
"invalid",
|
||||
"",
|
||||
# Header only
|
||||
"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9",
|
||||
# No signature
|
||||
"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ0ZXN0In0",
|
||||
]
|
||||
|
||||
for token in malformed_tokens:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
jwt_utils.parse_jwt_token(token)
|
||||
assert "Invalid token" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_parse_jwt_token_wrong_audience():
|
||||
"""Test parsing a token with wrong audience."""
|
||||
wrong_aud_payload = {**TEST_USER_PAYLOAD, "aud": "wrong-audience"}
|
||||
token = create_token(wrong_aud_payload)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
jwt_utils.parse_jwt_token(token)
|
||||
assert "Invalid token" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_parse_jwt_token_missing_audience():
|
||||
"""Test parsing a token without audience claim."""
|
||||
no_aud_payload = {k: v for k, v in TEST_USER_PAYLOAD.items() if k != "aud"}
|
||||
token = create_token(no_aud_payload)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
jwt_utils.parse_jwt_token(token)
|
||||
assert "Invalid token" in str(exc_info.value)
|
||||
|
||||
|
||||
async def test_get_jwt_payload_with_valid_token():
|
||||
"""Test extracting JWT payload with valid bearer token."""
|
||||
token = create_token(TEST_USER_PAYLOAD)
|
||||
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
|
||||
|
||||
result = await jwt_utils.get_jwt_payload(credentials)
|
||||
assert result["sub"] == "test-user-id"
|
||||
assert result["role"] == "user"
|
||||
|
||||
|
||||
async def test_get_jwt_payload_no_credentials():
|
||||
"""Test JWT payload when no credentials provided."""
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await jwt_utils.get_jwt_payload(None)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Authorization header is missing" in exc_info.value.detail
|
||||
|
||||
|
||||
async def test_get_jwt_payload_invalid_token():
|
||||
"""Test JWT payload extraction with invalid token."""
|
||||
credentials = HTTPAuthorizationCredentials(
|
||||
scheme="Bearer", credentials="invalid.token.here"
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await jwt_utils.get_jwt_payload(credentials)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Invalid token" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_verify_user_with_valid_user():
|
||||
"""Test verifying a valid user."""
|
||||
user = jwt_utils.verify_user(TEST_USER_PAYLOAD, admin_only=False)
|
||||
assert isinstance(user, User)
|
||||
assert user.user_id == "test-user-id"
|
||||
assert user.role == "user"
|
||||
assert user.email == "test@example.com"
|
||||
|
||||
|
||||
def test_verify_user_with_admin():
|
||||
"""Test verifying an admin user."""
|
||||
user = jwt_utils.verify_user(TEST_ADMIN_PAYLOAD, admin_only=True)
|
||||
assert isinstance(user, User)
|
||||
assert user.user_id == "admin-user-id"
|
||||
assert user.role == "admin"
|
||||
|
||||
|
||||
def test_verify_user_admin_only_with_regular_user():
|
||||
"""Test verifying regular user when admin is required."""
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
jwt_utils.verify_user(TEST_USER_PAYLOAD, admin_only=True)
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "Admin access required" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_verify_user_no_payload():
|
||||
"""Test verifying user with no payload."""
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
jwt_utils.verify_user(None, admin_only=False)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Authorization header is missing" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_verify_user_missing_sub():
|
||||
"""Test verifying user with payload missing 'sub' field."""
|
||||
invalid_payload = {"role": "user", "email": "test@example.com"}
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
jwt_utils.verify_user(invalid_payload, admin_only=False)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "User ID not found in token" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_verify_user_empty_sub():
|
||||
"""Test verifying user with empty 'sub' field."""
|
||||
invalid_payload = {"sub": "", "role": "user"}
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
jwt_utils.verify_user(invalid_payload, admin_only=False)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "User ID not found in token" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_verify_user_none_sub():
|
||||
"""Test verifying user with None 'sub' field."""
|
||||
invalid_payload = {"sub": None, "role": "user"}
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
jwt_utils.verify_user(invalid_payload, admin_only=False)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "User ID not found in token" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_verify_user_missing_role_admin_check():
|
||||
"""Test verifying admin when role field is missing."""
|
||||
no_role_payload = {"sub": "user-id"}
|
||||
with pytest.raises(KeyError):
|
||||
# This will raise KeyError when checking payload["role"]
|
||||
jwt_utils.verify_user(no_role_payload, admin_only=True)
|
||||
|
||||
|
||||
# ======================== EDGE CASES ======================== #
|
||||
|
||||
|
||||
def test_jwt_with_additional_claims():
|
||||
"""Test JWT token with additional custom claims."""
|
||||
extra_claims_payload = {
|
||||
"sub": "user-id",
|
||||
"role": "user",
|
||||
"aud": "authenticated",
|
||||
"custom_claim": "custom_value",
|
||||
"permissions": ["read", "write"],
|
||||
"metadata": {"key": "value"},
|
||||
}
|
||||
token = create_token(extra_claims_payload)
|
||||
|
||||
result = jwt_utils.parse_jwt_token(token)
|
||||
assert result["sub"] == "user-id"
|
||||
assert result["custom_claim"] == "custom_value"
|
||||
assert result["permissions"] == ["read", "write"]
|
||||
|
||||
|
||||
def test_jwt_with_numeric_sub():
|
||||
"""Test JWT token with numeric user ID."""
|
||||
payload = {
|
||||
"sub": 12345, # Numeric ID
|
||||
"role": "user",
|
||||
"aud": "authenticated",
|
||||
}
|
||||
# Should convert to string internally
|
||||
user = jwt_utils.verify_user(payload, admin_only=False)
|
||||
assert user.user_id == 12345
|
||||
|
||||
|
||||
def test_jwt_with_very_long_sub():
|
||||
"""Test JWT token with very long user ID."""
|
||||
long_id = "a" * 1000
|
||||
payload = {
|
||||
"sub": long_id,
|
||||
"role": "user",
|
||||
"aud": "authenticated",
|
||||
}
|
||||
user = jwt_utils.verify_user(payload, admin_only=False)
|
||||
assert user.user_id == long_id
|
||||
|
||||
|
||||
def test_jwt_with_special_characters_in_claims():
|
||||
"""Test JWT token with special characters in claims."""
|
||||
payload = {
|
||||
"sub": "user@example.com/special-chars!@#$%",
|
||||
"role": "admin",
|
||||
"aud": "authenticated",
|
||||
"email": "test+special@example.com",
|
||||
}
|
||||
user = jwt_utils.verify_user(payload, admin_only=True)
|
||||
assert "special-chars!@#$%" in user.user_id
|
||||
|
||||
|
||||
def test_jwt_with_future_iat():
|
||||
"""Test JWT token with issued-at time in future."""
|
||||
future_payload = {
|
||||
"sub": "user-id",
|
||||
"role": "user",
|
||||
"aud": "authenticated",
|
||||
"iat": datetime.now(timezone.utc) + timedelta(hours=1),
|
||||
}
|
||||
token = create_token(future_payload)
|
||||
|
||||
# PyJWT validates iat claim and should reject future tokens
|
||||
with pytest.raises(ValueError, match="not yet valid"):
|
||||
jwt_utils.parse_jwt_token(token)
|
||||
|
||||
|
||||
def test_jwt_with_different_algorithms():
|
||||
"""Test that only HS256 algorithm is accepted."""
|
||||
payload = {
|
||||
"sub": "user-id",
|
||||
"role": "user",
|
||||
"aud": "authenticated",
|
||||
}
|
||||
|
||||
# Try different algorithms
|
||||
algorithms = ["HS384", "HS512", "none"]
|
||||
for algo in algorithms:
|
||||
if algo == "none":
|
||||
# Special case for 'none' algorithm (security vulnerability if accepted)
|
||||
token = create_token(payload, "", algorithm="none")
|
||||
else:
|
||||
token = create_token(payload, algorithm=algo)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
jwt_utils.parse_jwt_token(token)
|
||||
assert "Invalid token" in str(exc_info.value)
|
||||
@@ -1,140 +0,0 @@
|
||||
import inspect
|
||||
import logging
|
||||
import secrets
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from fastapi import HTTPException, Request, Security
|
||||
from fastapi.security import APIKeyHeader, HTTPBearer
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED
|
||||
|
||||
from .config import settings
|
||||
from .jwt_utils import parse_jwt_token
|
||||
|
||||
security = HTTPBearer()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def auth_middleware(request: Request):
|
||||
if not settings.ENABLE_AUTH:
|
||||
# If authentication is disabled, allow the request to proceed
|
||||
logger.warning("Auth disabled")
|
||||
return {}
|
||||
|
||||
security = HTTPBearer()
|
||||
credentials = await security(request)
|
||||
|
||||
if not credentials:
|
||||
raise HTTPException(status_code=401, detail="Authorization header is missing")
|
||||
|
||||
try:
|
||||
payload = parse_jwt_token(credentials.credentials)
|
||||
request.state.user = payload
|
||||
logger.debug("Token decoded successfully")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=401, detail=str(e))
|
||||
return payload
|
||||
|
||||
|
||||
class APIKeyValidator:
|
||||
"""
|
||||
Configurable API key validator that supports custom validation functions
|
||||
for FastAPI applications.
|
||||
|
||||
This class provides a flexible way to implement API key authentication with optional
|
||||
custom validation logic. It can be used for simple token matching
|
||||
or more complex validation scenarios like database lookups.
|
||||
|
||||
Examples:
|
||||
Simple token validation:
|
||||
```python
|
||||
validator = APIKeyValidator(
|
||||
header_name="X-API-Key",
|
||||
expected_token="your-secret-token"
|
||||
)
|
||||
|
||||
@app.get("/protected", dependencies=[Depends(validator.get_dependency())])
|
||||
def protected_endpoint():
|
||||
return {"message": "Access granted"}
|
||||
```
|
||||
|
||||
Custom validation with database lookup:
|
||||
```python
|
||||
async def validate_with_db(api_key: str):
|
||||
api_key_obj = await db.get_api_key(api_key)
|
||||
return api_key_obj if api_key_obj and api_key_obj.is_active else None
|
||||
|
||||
validator = APIKeyValidator(
|
||||
header_name="X-API-Key",
|
||||
validate_fn=validate_with_db
|
||||
)
|
||||
```
|
||||
|
||||
Args:
|
||||
header_name (str): The name of the header containing the API key
|
||||
expected_token (Optional[str]): The expected API key value for simple token matching
|
||||
validate_fn (Optional[Callable]): Custom validation function that takes an API key
|
||||
string and returns a boolean or object. Can be async.
|
||||
error_status (int): HTTP status code to use for validation errors
|
||||
error_message (str): Error message to return when validation fails
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
header_name: str,
|
||||
expected_token: Optional[str] = None,
|
||||
validate_fn: Optional[Callable[[str], bool]] = None,
|
||||
error_status: int = HTTP_401_UNAUTHORIZED,
|
||||
error_message: str = "Invalid API key",
|
||||
):
|
||||
# Create the APIKeyHeader as a class property
|
||||
self.security_scheme = APIKeyHeader(name=header_name)
|
||||
self.expected_token = expected_token
|
||||
self.custom_validate_fn = validate_fn
|
||||
self.error_status = error_status
|
||||
self.error_message = error_message
|
||||
|
||||
async def default_validator(self, api_key: str) -> bool:
|
||||
if not self.expected_token:
|
||||
raise ValueError(
|
||||
"Expected Token Required to be set when uisng API Key Validator default validation"
|
||||
)
|
||||
return secrets.compare_digest(api_key, self.expected_token)
|
||||
|
||||
async def __call__(
|
||||
self, request: Request, api_key: str = Security(APIKeyHeader)
|
||||
) -> Any:
|
||||
if api_key is None:
|
||||
raise HTTPException(status_code=self.error_status, detail="Missing API key")
|
||||
|
||||
# Use custom validation if provided, otherwise use default equality check
|
||||
validator = self.custom_validate_fn or self.default_validator
|
||||
result = (
|
||||
await validator(api_key)
|
||||
if inspect.iscoroutinefunction(validator)
|
||||
else validator(api_key)
|
||||
)
|
||||
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=self.error_status, detail=self.error_message
|
||||
)
|
||||
|
||||
# Store validation result in request state if it's not just a boolean
|
||||
if result is not True:
|
||||
request.state.api_key = result
|
||||
|
||||
return result
|
||||
|
||||
def get_dependency(self):
|
||||
"""
|
||||
Returns a callable dependency that FastAPI will recognize as a security scheme
|
||||
"""
|
||||
|
||||
async def validate_api_key(
|
||||
request: Request, api_key: str = Security(self.security_scheme)
|
||||
) -> Any:
|
||||
return await self(request, api_key)
|
||||
|
||||
# This helps FastAPI recognize it as a security dependency
|
||||
validate_api_key.__name__ = f"validate_{self.security_scheme.model.name}"
|
||||
return validate_api_key
|
||||
@@ -1,166 +0,0 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
from functools import wraps
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar, Union, cast
|
||||
|
||||
import ldclient
|
||||
from fastapi import HTTPException
|
||||
from ldclient import Context, LDClient
|
||||
from ldclient.config import Config
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from .config import SETTINGS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def get_client() -> LDClient:
|
||||
"""Get the LaunchDarkly client singleton."""
|
||||
return ldclient.get()
|
||||
|
||||
|
||||
def initialize_launchdarkly() -> None:
|
||||
sdk_key = SETTINGS.launch_darkly_sdk_key
|
||||
logger.debug(
|
||||
f"Initializing LaunchDarkly with SDK key: {'present' if sdk_key else 'missing'}"
|
||||
)
|
||||
|
||||
if not sdk_key:
|
||||
logger.warning("LaunchDarkly SDK key not configured")
|
||||
return
|
||||
|
||||
config = Config(sdk_key)
|
||||
ldclient.set_config(config)
|
||||
|
||||
if ldclient.get().is_initialized():
|
||||
logger.info("LaunchDarkly client initialized successfully")
|
||||
else:
|
||||
logger.error("LaunchDarkly client failed to initialize")
|
||||
|
||||
|
||||
def shutdown_launchdarkly() -> None:
|
||||
"""Shutdown the LaunchDarkly client."""
|
||||
if ldclient.get().is_initialized():
|
||||
ldclient.get().close()
|
||||
logger.info("LaunchDarkly client closed successfully")
|
||||
|
||||
|
||||
def create_context(
|
||||
user_id: str, additional_attributes: Optional[Dict[str, Any]] = None
|
||||
) -> Context:
|
||||
"""Create LaunchDarkly context with optional additional attributes."""
|
||||
builder = Context.builder(str(user_id)).kind("user")
|
||||
if additional_attributes:
|
||||
for key, value in additional_attributes.items():
|
||||
builder.set(key, value)
|
||||
return builder.build()
|
||||
|
||||
|
||||
def feature_flag(
|
||||
flag_key: str,
|
||||
default: bool = False,
|
||||
) -> Callable[
|
||||
[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
|
||||
]:
|
||||
"""
|
||||
Decorator for feature flag protected endpoints.
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[P, Union[T, Awaitable[T]]],
|
||||
) -> Callable[P, Union[T, Awaitable[T]]]:
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
try:
|
||||
user_id = kwargs.get("user_id")
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required")
|
||||
|
||||
if not get_client().is_initialized():
|
||||
logger.warning(
|
||||
f"LaunchDarkly not initialized, using default={default}"
|
||||
)
|
||||
is_enabled = default
|
||||
else:
|
||||
context = create_context(str(user_id))
|
||||
is_enabled = get_client().variation(flag_key, context, default)
|
||||
|
||||
if not is_enabled:
|
||||
raise HTTPException(status_code=404, detail="Feature not available")
|
||||
|
||||
result = func(*args, **kwargs)
|
||||
if asyncio.iscoroutine(result):
|
||||
return await result
|
||||
return cast(T, result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error evaluating feature flag {flag_key}: {e}")
|
||||
raise
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
try:
|
||||
user_id = kwargs.get("user_id")
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required")
|
||||
|
||||
if not get_client().is_initialized():
|
||||
logger.warning(
|
||||
f"LaunchDarkly not initialized, using default={default}"
|
||||
)
|
||||
is_enabled = default
|
||||
else:
|
||||
context = create_context(str(user_id))
|
||||
is_enabled = get_client().variation(flag_key, context, default)
|
||||
|
||||
if not is_enabled:
|
||||
raise HTTPException(status_code=404, detail="Feature not available")
|
||||
|
||||
return cast(T, func(*args, **kwargs))
|
||||
except Exception as e:
|
||||
logger.error(f"Error evaluating feature flag {flag_key}: {e}")
|
||||
raise
|
||||
|
||||
return cast(
|
||||
Callable[P, Union[T, Awaitable[T]]],
|
||||
async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper,
|
||||
)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def percentage_rollout(
|
||||
flag_key: str,
|
||||
default: bool = False,
|
||||
) -> Callable[
|
||||
[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
|
||||
]:
|
||||
"""Decorator for percentage-based rollouts."""
|
||||
return feature_flag(flag_key, default)
|
||||
|
||||
|
||||
def beta_feature(
|
||||
flag_key: Optional[str] = None,
|
||||
unauthorized_response: Any = {"message": "Not available in beta"},
|
||||
) -> Callable[
|
||||
[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
|
||||
]:
|
||||
"""Decorator for beta features."""
|
||||
actual_key = f"beta-{flag_key}" if flag_key else "beta"
|
||||
return feature_flag(actual_key, False)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def mock_flag_variation(flag_key: str, return_value: Any):
|
||||
"""Context manager for testing feature flags."""
|
||||
original_variation = get_client().variation
|
||||
get_client().variation = lambda key, context, default: (
|
||||
return_value if key == flag_key else original_variation(key, context, default)
|
||||
)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
get_client().variation = original_variation
|
||||
@@ -1,45 +0,0 @@
|
||||
import pytest
|
||||
from ldclient import LDClient
|
||||
|
||||
from autogpt_libs.feature_flag.client import feature_flag, mock_flag_variation
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ld_client(mocker):
|
||||
client = mocker.Mock(spec=LDClient)
|
||||
mocker.patch("ldclient.get", return_value=client)
|
||||
client.is_initialized.return_value = True
|
||||
return client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_feature_flag_enabled(ld_client):
|
||||
ld_client.variation.return_value = True
|
||||
|
||||
@feature_flag("test-flag")
|
||||
async def test_function(user_id: str):
|
||||
return "success"
|
||||
|
||||
result = test_function(user_id="test-user")
|
||||
assert result == "success"
|
||||
ld_client.variation.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_feature_flag_unauthorized_response(ld_client):
|
||||
ld_client.variation.return_value = False
|
||||
|
||||
@feature_flag("test-flag")
|
||||
async def test_function(user_id: str):
|
||||
return "success"
|
||||
|
||||
result = test_function(user_id="test-user")
|
||||
assert result == {"error": "disabled"}
|
||||
|
||||
|
||||
def test_mock_flag_variation(ld_client):
|
||||
with mock_flag_variation("test-flag", True):
|
||||
assert ld_client.variation("test-flag", None, False)
|
||||
|
||||
with mock_flag_variation("test-flag", False):
|
||||
assert ld_client.variation("test-flag", None, False)
|
||||
@@ -1,15 +0,0 @@
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
launch_darkly_sdk_key: str = Field(
|
||||
default="",
|
||||
description="The Launch Darkly SDK key",
|
||||
validation_alias="LAUNCH_DARKLY_SDK_KEY",
|
||||
)
|
||||
|
||||
model_config = SettingsConfigDict(case_sensitive=True, extra="ignore")
|
||||
|
||||
|
||||
SETTINGS = Settings()
|
||||
@@ -1,7 +1,10 @@
|
||||
"""Logging module for Auto-GPT."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
@@ -10,6 +13,15 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from .filters import BelowLevelFilter
|
||||
from .formatters import AGPTFormatter
|
||||
|
||||
# Configure global socket timeout and gRPC keepalive to prevent deadlocks
|
||||
# This must be done at import time before any gRPC connections are established
|
||||
socket.setdefaulttimeout(30) # 30-second socket timeout
|
||||
|
||||
# Enable gRPC keepalive to detect dead connections faster
|
||||
os.environ.setdefault("GRPC_KEEPALIVE_TIME_MS", "30000") # 30 seconds
|
||||
os.environ.setdefault("GRPC_KEEPALIVE_TIMEOUT_MS", "5000") # 5 seconds
|
||||
os.environ.setdefault("GRPC_KEEPALIVE_PERMIT_WITHOUT_CALLS", "true")
|
||||
|
||||
LOG_DIR = Path(__file__).parent.parent.parent.parent / "logs"
|
||||
LOG_FILE = "activity.log"
|
||||
DEBUG_LOG_FILE = "debug.log"
|
||||
@@ -79,42 +91,39 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
|
||||
Note: This function is typically called at the start of the application
|
||||
to set up the logging infrastructure.
|
||||
"""
|
||||
|
||||
config = LoggingConfig()
|
||||
log_handlers: list[logging.Handler] = []
|
||||
|
||||
structured_logging = config.enable_cloud_logging or force_cloud_logging
|
||||
|
||||
# Console output handlers
|
||||
stdout = logging.StreamHandler(stream=sys.stdout)
|
||||
stdout.setLevel(config.level)
|
||||
stdout.addFilter(BelowLevelFilter(logging.WARNING))
|
||||
if config.level == logging.DEBUG:
|
||||
stdout.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
|
||||
else:
|
||||
stdout.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
|
||||
if not structured_logging:
|
||||
stdout = logging.StreamHandler(stream=sys.stdout)
|
||||
stdout.setLevel(config.level)
|
||||
stdout.addFilter(BelowLevelFilter(logging.WARNING))
|
||||
if config.level == logging.DEBUG:
|
||||
stdout.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
|
||||
else:
|
||||
stdout.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
|
||||
|
||||
stderr = logging.StreamHandler()
|
||||
stderr.setLevel(logging.WARNING)
|
||||
if config.level == logging.DEBUG:
|
||||
stderr.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
|
||||
else:
|
||||
stderr.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
|
||||
stderr = logging.StreamHandler()
|
||||
stderr.setLevel(logging.WARNING)
|
||||
if config.level == logging.DEBUG:
|
||||
stderr.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
|
||||
else:
|
||||
stderr.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
|
||||
|
||||
log_handlers += [stdout, stderr]
|
||||
log_handlers += [stdout, stderr]
|
||||
|
||||
# Cloud logging setup
|
||||
if config.enable_cloud_logging or force_cloud_logging:
|
||||
import google.cloud.logging
|
||||
from google.cloud.logging.handlers import CloudLoggingHandler
|
||||
from google.cloud.logging_v2.handlers.transports.sync import SyncTransport
|
||||
else:
|
||||
# Use Google Cloud Structured Log Handler. Log entries are printed to stdout
|
||||
# in a JSON format which is automatically picked up by Google Cloud Logging.
|
||||
from google.cloud.logging.handlers import StructuredLogHandler
|
||||
|
||||
client = google.cloud.logging.Client()
|
||||
cloud_handler = CloudLoggingHandler(
|
||||
client,
|
||||
name="autogpt_logs",
|
||||
transport=SyncTransport,
|
||||
)
|
||||
cloud_handler.setLevel(config.level)
|
||||
log_handlers.append(cloud_handler)
|
||||
structured_log_handler = StructuredLogHandler(stream=sys.stdout)
|
||||
structured_log_handler.setLevel(config.level)
|
||||
log_handlers.append(structured_log_handler)
|
||||
|
||||
# File logging setup
|
||||
if config.enable_file_logging:
|
||||
@@ -125,8 +134,13 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
|
||||
print(f"Log directory: {config.log_dir}")
|
||||
|
||||
# Activity log handler (INFO and above)
|
||||
activity_log_handler = logging.FileHandler(
|
||||
config.log_dir / LOG_FILE, "a", "utf-8"
|
||||
# Security fix: Use RotatingFileHandler with size limits to prevent disk exhaustion
|
||||
activity_log_handler = RotatingFileHandler(
|
||||
config.log_dir / LOG_FILE,
|
||||
mode="a",
|
||||
encoding="utf-8",
|
||||
maxBytes=10 * 1024 * 1024, # 10MB per file
|
||||
backupCount=3, # Keep 3 backup files (40MB total)
|
||||
)
|
||||
activity_log_handler.setLevel(config.level)
|
||||
activity_log_handler.setFormatter(
|
||||
@@ -136,8 +150,13 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
|
||||
|
||||
if config.level == logging.DEBUG:
|
||||
# Debug log handler (all levels)
|
||||
debug_log_handler = logging.FileHandler(
|
||||
config.log_dir / DEBUG_LOG_FILE, "a", "utf-8"
|
||||
# Security fix: Use RotatingFileHandler with size limits
|
||||
debug_log_handler = RotatingFileHandler(
|
||||
config.log_dir / DEBUG_LOG_FILE,
|
||||
mode="a",
|
||||
encoding="utf-8",
|
||||
maxBytes=10 * 1024 * 1024, # 10MB per file
|
||||
backupCount=3, # Keep 3 backup files (40MB total)
|
||||
)
|
||||
debug_log_handler.setLevel(logging.DEBUG)
|
||||
debug_log_handler.setFormatter(
|
||||
@@ -146,8 +165,13 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
|
||||
log_handlers.append(debug_log_handler)
|
||||
|
||||
# Error log handler (ERROR and above)
|
||||
error_log_handler = logging.FileHandler(
|
||||
config.log_dir / ERROR_LOG_FILE, "a", "utf-8"
|
||||
# Security fix: Use RotatingFileHandler with size limits
|
||||
error_log_handler = RotatingFileHandler(
|
||||
config.log_dir / ERROR_LOG_FILE,
|
||||
mode="a",
|
||||
encoding="utf-8",
|
||||
maxBytes=10 * 1024 * 1024, # 10MB per file
|
||||
backupCount=3, # Keep 3 backup files (40MB total)
|
||||
)
|
||||
error_log_handler.setLevel(logging.ERROR)
|
||||
error_log_handler.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT, no_color=True))
|
||||
@@ -155,7 +179,13 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
|
||||
|
||||
# Configure the root logger
|
||||
logging.basicConfig(
|
||||
format=DEBUG_LOG_FORMAT if config.level == logging.DEBUG else SIMPLE_LOG_FORMAT,
|
||||
format=(
|
||||
"%(levelname)s %(message)s"
|
||||
if structured_logging
|
||||
else (
|
||||
DEBUG_LOG_FORMAT if config.level == logging.DEBUG else SIMPLE_LOG_FORMAT
|
||||
)
|
||||
),
|
||||
level=config.level,
|
||||
handlers=log_handlers,
|
||||
)
|
||||
|
||||
@@ -1,39 +1,5 @@
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import uvicorn.config
|
||||
from colorama import Fore
|
||||
|
||||
|
||||
def remove_color_codes(s: str) -> str:
|
||||
return re.sub(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])", "", s)
|
||||
|
||||
|
||||
def fmt_kwargs(kwargs: dict) -> str:
|
||||
return ", ".join(f"{n}={repr(v)}" for n, v in kwargs.items())
|
||||
|
||||
|
||||
def print_attribute(
|
||||
title: str, value: Any, title_color: str = Fore.GREEN, value_color: str = ""
|
||||
) -> None:
|
||||
logger = logging.getLogger()
|
||||
logger.info(
|
||||
str(value),
|
||||
extra={
|
||||
"title": f"{title.rstrip(':')}:",
|
||||
"title_color": title_color,
|
||||
"color": value_color,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def generate_uvicorn_config():
|
||||
"""
|
||||
Generates a uvicorn logging config that silences uvicorn's default logging and tells it to use the native logging module.
|
||||
"""
|
||||
log_config = dict(uvicorn.config.LOGGING_CONFIG)
|
||||
log_config["loggers"]["uvicorn"] = {"handlers": []}
|
||||
log_config["loggers"]["uvicorn.error"] = {"handlers": []}
|
||||
log_config["loggers"]["uvicorn.access"] = {"handlers": []}
|
||||
return log_config
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
@@ -13,8 +15,8 @@ class RateLimitSettings(BaseSettings):
|
||||
default="6379", description="Redis port", validation_alias="REDIS_PORT"
|
||||
)
|
||||
|
||||
redis_password: str = Field(
|
||||
default="password",
|
||||
redis_password: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Redis password",
|
||||
validation_alias="REDIS_PASSWORD",
|
||||
)
|
||||
|
||||
@@ -11,7 +11,7 @@ class RateLimiter:
|
||||
self,
|
||||
redis_host: str = RATE_LIMIT_SETTINGS.redis_host,
|
||||
redis_port: str = RATE_LIMIT_SETTINGS.redis_port,
|
||||
redis_password: str = RATE_LIMIT_SETTINGS.redis_password,
|
||||
redis_password: str | None = RATE_LIMIT_SETTINGS.redis_password,
|
||||
requests_per_minute: int = RATE_LIMIT_SETTINGS.requests_per_minute,
|
||||
):
|
||||
self.redis = Redis(
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
import inspect
|
||||
import threading
|
||||
from typing import Awaitable, Callable, ParamSpec, TypeVar, cast, overload
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
@overload
|
||||
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def thread_cached(func: Callable[P, R]) -> Callable[P, R]: ...
|
||||
|
||||
|
||||
def thread_cached(
|
||||
func: Callable[P, R] | Callable[P, Awaitable[R]],
|
||||
) -> Callable[P, R] | Callable[P, Awaitable[R]]:
|
||||
thread_local = threading.local()
|
||||
|
||||
def _clear():
|
||||
if hasattr(thread_local, "cache"):
|
||||
del thread_local.cache
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
|
||||
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
cache = getattr(thread_local, "cache", None)
|
||||
if cache is None:
|
||||
cache = thread_local.cache = {}
|
||||
key = (args, tuple(sorted(kwargs.items())))
|
||||
if key not in cache:
|
||||
cache[key] = await cast(Callable[P, Awaitable[R]], func)(
|
||||
*args, **kwargs
|
||||
)
|
||||
return cache[key]
|
||||
|
||||
setattr(async_wrapper, "clear_cache", _clear)
|
||||
return async_wrapper
|
||||
|
||||
else:
|
||||
|
||||
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
cache = getattr(thread_local, "cache", None)
|
||||
if cache is None:
|
||||
cache = thread_local.cache = {}
|
||||
key = (args, tuple(sorted(kwargs.items())))
|
||||
if key not in cache:
|
||||
cache[key] = func(*args, **kwargs)
|
||||
return cache[key]
|
||||
|
||||
setattr(sync_wrapper, "clear_cache", _clear)
|
||||
return sync_wrapper
|
||||
|
||||
|
||||
def clear_thread_cache(func: Callable) -> None:
|
||||
if clear := getattr(func, "clear_cache", None):
|
||||
clear()
|
||||
2027 autogpt_platform/autogpt_libs/poetry.lock generated
File diff suppressed because it is too large
@@ -1,29 +1,33 @@
|
||||
[tool.poetry]
|
||||
name = "autogpt-libs"
|
||||
version = "0.2.0"
|
||||
description = "Shared libraries across NextGen AutoGPT"
|
||||
authors = ["Aarushi <aarushik93@gmail.com>"]
|
||||
description = "Shared libraries across AutoGPT Platform"
|
||||
authors = ["AutoGPT team <info@agpt.co>"]
|
||||
readme = "README.md"
|
||||
packages = [{ include = "autogpt_libs" }]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<4.0"
|
||||
colorama = "^0.4.6"
|
||||
cryptography = "^45.0"
|
||||
expiringdict = "^1.2.2"
|
||||
fastapi = "^0.116.1"
|
||||
google-cloud-logging = "^3.12.1"
|
||||
pydantic = "^2.11.4"
|
||||
pydantic-settings = "^2.9.1"
|
||||
pyjwt = "^2.10.1"
|
||||
pytest-asyncio = "^0.26.0"
|
||||
pytest-mock = "^3.14.0"
|
||||
supabase = "^2.15.1"
|
||||
launchdarkly-server-sdk = "^9.11.1"
|
||||
fastapi = "^0.115.12"
|
||||
uvicorn = "^0.34.3"
|
||||
launchdarkly-server-sdk = "^9.12.0"
|
||||
pydantic = "^2.11.7"
|
||||
pydantic-settings = "^2.10.1"
|
||||
pyjwt = { version = "^2.10.1", extras = ["crypto"] }
|
||||
redis = "^6.2.0"
|
||||
supabase = "^2.16.0"
|
||||
uvicorn = "^0.35.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
redis = "^5.2.1"
|
||||
ruff = "^0.11.10"
|
||||
pyright = "^1.1.404"
|
||||
pytest = "^8.4.1"
|
||||
pytest-asyncio = "^1.1.0"
|
||||
pytest-mock = "^3.14.1"
|
||||
pytest-cov = "^6.2.1"
|
||||
ruff = "^0.12.11"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
52 autogpt_platform/backend/.dockerignore Normal file
@@ -0,0 +1,52 @@
|
||||
# Development and testing files
|
||||
**/__pycache__
|
||||
**/*.pyc
|
||||
**/*.pyo
|
||||
**/*.pyd
|
||||
**/.Python
|
||||
**/env/
|
||||
**/venv/
|
||||
**/.venv/
|
||||
**/pip-log.txt
|
||||
**/.pytest_cache/
|
||||
**/test-results/
|
||||
**/snapshots/
|
||||
**/test/
|
||||
|
||||
# IDE and editor files
|
||||
**/.vscode/
|
||||
**/.idea/
|
||||
**/*.swp
|
||||
**/*.swo
|
||||
*~
|
||||
|
||||
# OS files
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Logs
|
||||
**/*.log
|
||||
**/logs/
|
||||
|
||||
# Git
|
||||
.git/
|
||||
.gitignore
|
||||
|
||||
# Documentation
|
||||
**/*.md
|
||||
!README.md
|
||||
|
||||
# Local development files
|
||||
.env
|
||||
.env.local
|
||||
**/.env.test
|
||||
|
||||
# Build artifacts
|
||||
**/dist/
|
||||
**/build/
|
||||
**/target/
|
||||
|
||||
# Docker files (avoid recursion)
|
||||
Dockerfile*
|
||||
docker-compose*
|
||||
.dockerignore
|
||||
@@ -1,3 +1,9 @@
|
||||
# Backend Configuration
|
||||
# This file contains environment variables that MUST be set for the AutoGPT platform
|
||||
# Variables with working defaults in settings.py are not included here
|
||||
|
||||
## ===== REQUIRED DATABASE CONFIGURATION ===== ##
|
||||
# PostgreSQL Database Connection
|
||||
DB_USER=postgres
|
||||
DB_PASS=your-super-secret-and-long-postgres-password
|
||||
DB_NAME=postgres
|
||||
@@ -11,71 +17,48 @@ DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME
|
||||
DIRECT_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
|
||||
PRISMA_SCHEMA="postgres/schema.prisma"
|
||||
|
||||
# EXECUTOR
|
||||
NUM_GRAPH_WORKERS=10
|
||||
|
||||
BACKEND_CORS_ALLOW_ORIGINS=["http://localhost:3000"]
|
||||
|
||||
# generate using `from cryptography.fernet import Fernet;Fernet.generate_key().decode()`
|
||||
ENCRYPTION_KEY='dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw='
|
||||
UNSUBSCRIBE_SECRET_KEY = 'HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio='
|
||||
|
||||
## ===== REQUIRED SERVICE CREDENTIALS ===== ##
|
||||
# Redis Configuration
|
||||
REDIS_HOST=localhost
|
||||
REDIS_PORT=6379
|
||||
REDIS_PASSWORD=password
|
||||
# REDIS_PASSWORD=
|
||||
|
||||
ENABLE_CREDIT=false
|
||||
STRIPE_API_KEY=
|
||||
STRIPE_WEBHOOK_SECRET=
|
||||
|
||||
# What environment things should be logged under: local dev or prod
|
||||
APP_ENV=local
|
||||
# What environment to behave as: "local" or "cloud"
|
||||
BEHAVE_AS=local
|
||||
PYRO_HOST=localhost
|
||||
SENTRY_DSN=
|
||||
|
||||
# Email For Postmark so we can send emails
|
||||
POSTMARK_SERVER_API_TOKEN=
|
||||
POSTMARK_SENDER_EMAIL=invalid@invalid.com
|
||||
POSTMARK_WEBHOOK_TOKEN=
|
||||
|
||||
## User auth with Supabase is required for any of the 3rd party integrations with auth to work.
|
||||
ENABLE_AUTH=true
|
||||
SUPABASE_URL=http://localhost:8000
|
||||
SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
|
||||
SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
|
||||
# RabbitMQ credentials -- Used for communication between services
|
||||
RABBITMQ_HOST=localhost
|
||||
RABBITMQ_PORT=5672
|
||||
# RabbitMQ Credentials
|
||||
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
|
||||
RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
|
||||
|
||||
## GCS bucket is required for marketplace and library functionality
|
||||
# Supabase Authentication
|
||||
SUPABASE_URL=http://localhost:8000
|
||||
SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
|
||||
JWT_VERIFY_KEY=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
|
||||
## ===== REQUIRED SECURITY KEYS ===== ##
|
||||
# Generate using: from cryptography.fernet import Fernet;Fernet.generate_key().decode()
|
||||
ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
|
||||
UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
|
||||
|
||||
## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
|
||||
# Platform URLs (set these for webhooks and OAuth to work)
|
||||
PLATFORM_BASE_URL=http://localhost:8000
|
||||
FRONTEND_BASE_URL=http://localhost:3000
|
||||
|
||||
# Media Storage (required for marketplace and library functionality)
|
||||
MEDIA_GCS_BUCKET_NAME=
|
||||
|
||||
## For local development, you may need to set FRONTEND_BASE_URL for the OAuth flow
|
||||
## for integrations to work. Defaults to the value of PLATFORM_BASE_URL if not set.
|
||||
# FRONTEND_BASE_URL=http://localhost:3000
|
||||
## ===== API KEYS AND OAUTH CREDENTIALS ===== ##
|
||||
# All API keys below are optional - only add what you need
|
||||
|
||||
## PLATFORM_BASE_URL must be set to a *publicly accessible* URL pointing to your backend
|
||||
## to use the platform's webhook-related functionality.
|
||||
## If you are developing locally, you can use something like ngrok to get a public URL
|
||||
## and tunnel it to your locally running backend.
|
||||
PLATFORM_BASE_URL=http://localhost:3000
|
||||
|
||||
## Cloudflare Turnstile (CAPTCHA) Configuration
|
||||
## Get these from the Cloudflare Turnstile dashboard: https://dash.cloudflare.com/?to=/:account/turnstile
|
||||
## This is the backend secret key
|
||||
TURNSTILE_SECRET_KEY=
|
||||
## This is the verify URL
|
||||
TURNSTILE_VERIFY_URL=https://challenges.cloudflare.com/turnstile/v0/siteverify
|
||||
|
||||
## == INTEGRATION CREDENTIALS == ##
|
||||
# Each set of server side credentials is required for the corresponding 3rd party
|
||||
# integration to work.
|
||||
# AI/LLM Services
|
||||
OPENAI_API_KEY=
|
||||
ANTHROPIC_API_KEY=
|
||||
GROQ_API_KEY=
|
||||
LLAMA_API_KEY=
|
||||
AIML_API_KEY=
|
||||
V0_API_KEY=
|
||||
OPEN_ROUTER_API_KEY=
|
||||
NVIDIA_API_KEY=
|
||||
|
||||
# OAuth Credentials
|
||||
# For the OAuth callback URL, use <your_frontend_url>/auth/integrations/oauth_callback,
|
||||
# e.g. http://localhost:3000/auth/integrations/oauth_callback
|
||||
|
||||
@@ -83,9 +66,13 @@ TURNSTILE_VERIFY_URL=https://challenges.cloudflare.com/turnstile/v0/siteverify
|
||||
GITHUB_CLIENT_ID=
|
||||
GITHUB_CLIENT_SECRET=
|
||||
|
||||
# Notion OAuth App server credentials - https://developers.notion.com/docs/authorization
|
||||
# Configure a public integration
|
||||
NOTION_CLIENT_ID=
|
||||
NOTION_CLIENT_SECRET=
|
||||
|
||||
# Google OAuth App server credentials - https://console.cloud.google.com/apis/credentials, and enable gmail api and set scopes
|
||||
# https://console.cloud.google.com/apis/credentials/consent ?project=<your_project_id>

# You'll need to add/enable the following scopes (minimum):
# https://console.developers.google.com/apis/api/gmail.googleapis.com/overview ?project=<your_project_id>
# https://console.cloud.google.com/apis/library/sheets.googleapis.com/ ?project=<your_project_id>
@@ -121,87 +108,68 @@ LINEAR_CLIENT_SECRET=
TODOIST_CLIENT_ID=
TODOIST_CLIENT_SECRET=

## ===== OPTIONAL API KEYS ===== ##
NOTION_CLIENT_ID=
NOTION_CLIENT_SECRET=

# LLM
OPENAI_API_KEY=
ANTHROPIC_API_KEY=
AIML_API_KEY=
GROQ_API_KEY=
OPEN_ROUTER_API_KEY=
LLAMA_API_KEY=
# Discord OAuth App credentials
# 1. Go to https://discord.com/developers/applications
# 2. Create a new application
# 3. Go to OAuth2 section and add redirect URI: http://localhost:3000/auth/integrations/oauth_callback
# 4. Copy Client ID and Client Secret below
DISCORD_CLIENT_ID=
DISCORD_CLIENT_SECRET=

# Reddit
# Go to https://www.reddit.com/prefs/apps and create a new app
# Choose "script" for the type
# Fill in the redirect uri as <your_frontend_url>/auth/integrations/oauth_callback, e.g. http://localhost:3000/auth/integrations/oauth_callback
REDDIT_CLIENT_ID=
REDDIT_CLIENT_SECRET=
REDDIT_USER_AGENT="AutoGPT:1.0 (by /u/autogpt)"

# Discord
DISCORD_BOT_TOKEN=
# Payment Processing
STRIPE_API_KEY=
STRIPE_WEBHOOK_SECRET=

# SMTP/Email
SMTP_SERVER=
SMTP_PORT=
SMTP_USERNAME=
SMTP_PASSWORD=
# Email Service (for sending notifications and confirmations)
POSTMARK_SERVER_API_TOKEN=
POSTMARK_SENDER_EMAIL=invalid@invalid.com
POSTMARK_WEBHOOK_TOKEN=

# D-ID
# Error Tracking
SENTRY_DSN=

# Feature Flags
LAUNCH_DARKLY_SDK_KEY=

# Content Generation & Media
DID_API_KEY=
FAL_API_KEY=
IDEOGRAM_API_KEY=
REPLICATE_API_KEY=
REVID_API_KEY=
SCREENSHOTONE_API_KEY=
UNREAL_SPEECH_API_KEY=

# Open Weather Map
# Data & Search Services
E2B_API_KEY=
EXA_API_KEY=
JINA_API_KEY=
MEM0_API_KEY=
OPENWEATHERMAP_API_KEY=

# SMTP
SMTP_SERVER=
SMTP_PORT=
SMTP_USERNAME=
SMTP_PASSWORD=

# Medium
MEDIUM_API_KEY=
MEDIUM_AUTHOR_ID=

# Google Maps
GOOGLE_MAPS_API_KEY=

# Replicate
REPLICATE_API_KEY=
# Communication Services
DISCORD_BOT_TOKEN=
MEDIUM_API_KEY=
MEDIUM_AUTHOR_ID=
SMTP_SERVER=
SMTP_PORT=
SMTP_USERNAME=
SMTP_PASSWORD=

# Ideogram
IDEOGRAM_API_KEY=

# Fal
FAL_API_KEY=

# Exa
EXA_API_KEY=

# E2B
E2B_API_KEY=

# Mem0
MEM0_API_KEY=

# Nvidia
NVIDIA_API_KEY=

# Apollo
# Business & Marketing Tools
APOLLO_API_KEY=

# SmartLead
ENRICHLAYER_API_KEY=
AYRSHARE_API_KEY=
AYRSHARE_JWT_KEY=
SMARTLEAD_API_KEY=

# ZeroBounce
ZEROBOUNCE_API_KEY=

## ===== OPTIONAL API KEYS END ===== ##

# Logging Configuration
LOG_LEVEL=INFO
ENABLE_CLOUD_LOGGING=false
ENABLE_FILE_LOGGING=false
# Use to manually set the log directory
# LOG_DIR=./logs
# Other Services
AUTOMOD_API_KEY=

11
autogpt_platform/backend/.gitignore
vendored
@@ -1,3 +1,4 @@
.env
database.db
database.db-journal
dev.db
@@ -8,4 +9,12 @@ secrets/*
!secrets/.gitkeep

*.ignore.*
*.ign.*
*.ign.*

# Load test results and reports
load-tests/*_RESULTS.md
load-tests/*_REPORT.md
load-tests/results/
load-tests/*.json
load-tests/*.log
load-tests/node_modules/*

@@ -1,31 +1,43 @@
FROM python:3.11.10-slim-bookworm AS builder
FROM debian:13-slim AS builder

# Set environment variables
ENV PYTHONDONTWRITEBYTECODE 1
ENV PYTHONUNBUFFERED 1
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1
ENV DEBIAN_FRONTEND=noninteractive

WORKDIR /app

RUN echo 'Acquire::http::Pipeline-Depth 0;\nAcquire::http::No-Cache true;\nAcquire::BrokenProxy true;\n' > /etc/apt/apt.conf.d/99fixbadproxy

RUN apt-get update --allow-releaseinfo-change --fix-missing
# Install Node.js repository key and setup
RUN apt-get update --allow-releaseinfo-change --fix-missing \
    && apt-get install -y curl ca-certificates gnupg \
    && mkdir -p /etc/apt/keyrings \
    && curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg \
    && echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_20.x nodistro main" | tee /etc/apt/sources.list.d/nodesource.list

# Install build dependencies
RUN apt-get install -y build-essential
RUN apt-get install -y libpq5
RUN apt-get install -y libz-dev
RUN apt-get install -y libssl-dev
RUN apt-get install -y postgresql-client
# Update package list and install Python, Node.js, and build dependencies
RUN apt-get update \
    && apt-get install -y \
    python3.13 \
    python3.13-dev \
    python3.13-venv \
    python3-pip \
    build-essential \
    libpq5 \
    libz-dev \
    libssl-dev \
    postgresql-client \
    nodejs \
    && rm -rf /var/lib/apt/lists/*

ENV POETRY_HOME=/opt/poetry
ENV POETRY_NO_INTERACTION=1
ENV POETRY_VIRTUALENVS_CREATE=false
ENV POETRY_VIRTUALENVS_CREATE=true
ENV POETRY_VIRTUALENVS_IN_PROJECT=true
ENV PATH=/opt/poetry/bin:$PATH

# Upgrade pip and setuptools to fix security vulnerabilities
RUN pip3 install --upgrade pip setuptools

RUN pip3 install poetry
RUN pip3 install poetry --break-system-packages

# Copy and install dependencies
COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
@@ -35,29 +47,38 @@ RUN poetry install --no-ansi --no-root

# Generate Prisma client
COPY autogpt_platform/backend/schema.prisma ./
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
RUN poetry run prisma generate

FROM python:3.11.10-slim-bookworm AS server_dependencies
FROM debian:13-slim AS server_dependencies

WORKDIR /app

ENV POETRY_HOME=/opt/poetry \
    POETRY_NO_INTERACTION=1 \
    POETRY_VIRTUALENVS_CREATE=false
    POETRY_VIRTUALENVS_CREATE=true \
    POETRY_VIRTUALENVS_IN_PROJECT=true \
    DEBIAN_FRONTEND=noninteractive
ENV PATH=/opt/poetry/bin:$PATH

# Upgrade pip and setuptools to fix security vulnerabilities
RUN pip3 install --upgrade pip setuptools
# Install Python without upgrading system-managed packages
RUN apt-get update && apt-get install -y \
    python3.13 \
    python3-pip \
    && rm -rf /var/lib/apt/lists/*

# Copy only necessary files from builder
COPY --from=builder /app /app
COPY --from=builder /usr/local/lib/python3.11 /usr/local/lib/python3.11
COPY --from=builder /usr/local/bin /usr/local/bin
# Copy Prisma binaries
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
# Copy Node.js installation for Prisma
COPY --from=builder /usr/bin/node /usr/bin/node
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
COPY --from=builder /usr/bin/npm /usr/bin/npm
COPY --from=builder /usr/bin/npx /usr/bin/npx
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries

ENV PATH="/app/.venv/bin:$PATH"
ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"

RUN mkdir -p /app/autogpt_platform/autogpt_libs
RUN mkdir -p /app/autogpt_platform/backend
@@ -68,6 +89,13 @@ COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.tom

WORKDIR /app/autogpt_platform/backend

FROM server_dependencies AS migrate

# Migration stage only needs schema and migrations - much lighter than full backend
COPY autogpt_platform/backend/schema.prisma /app/autogpt_platform/backend/
COPY autogpt_platform/backend/backend/data/partial_types.py /app/autogpt_platform/backend/backend/data/partial_types.py
COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migrations

FROM server_dependencies AS server

COPY autogpt_platform/backend /app/autogpt_platform/backend

@@ -132,17 +132,58 @@ def test_endpoint_success(snapshot: Snapshot):

### Testing with Authentication

For the main API routes that use JWT authentication, auth is provided by the `autogpt_libs.auth` module. If the test actually uses the `user_id`, the recommended approach for testing is to mock the `get_jwt_payload` function, which underpins all higher-level auth functions used in the API (`requires_user`, `requires_admin_user`, `get_user_id`).

If the test doesn't need the `user_id` specifically, mocking is not necessary, since auth is disabled during tests anyway (see `conftest.py`).

#### Using Global Auth Fixtures

Two global auth fixtures are provided by `backend/server/conftest.py`:

- `mock_jwt_user` - Regular user with `test_user_id` ("test-user-id")
- `mock_jwt_admin` - Admin user with `admin_user_id` ("admin-user-id")

These provide the easiest way to set up authentication mocking in test modules:

```python
def override_auth_middleware():
    return {"sub": "test-user-id"}
import fastapi
import fastapi.testclient
import pytest
from backend.server.v2.myroute import router

def override_get_user_id():
    return "test-user-id"
app = fastapi.FastAPI()
app.include_router(router)
client = fastapi.testclient.TestClient(app)

app.dependency_overrides[auth_middleware] = override_auth_middleware
app.dependency_overrides[get_user_id] = override_get_user_id
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
    """Setup auth overrides for all tests in this module"""
    from autogpt_libs.auth.jwt_utils import get_jwt_payload

    app.dependency_overrides[get_jwt_payload] = mock_jwt_user['get_jwt_payload']
    yield
    app.dependency_overrides.clear()
```

For admin-only endpoints, use `mock_jwt_admin` instead:

```python
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_admin):
    """Setup auth overrides for admin tests"""
    from autogpt_libs.auth.jwt_utils import get_jwt_payload

    app.dependency_overrides[get_jwt_payload] = mock_jwt_admin['get_jwt_payload']
    yield
    app.dependency_overrides.clear()
```

The IDs are also available separately as fixtures (a usage sketch follows the list):

- `test_user_id`
- `admin_user_id`
- `target_user_id` (for admin <-> user operations)

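A rough usage sketch combining these ID fixtures with the module-level auth override above (the `/api/profile` route here is a hypothetical endpoint used only for illustration):

```python
def test_get_own_profile(test_user_id):
    # setup_app_auth (defined above, autouse=True) has already mocked the JWT for this module
    response = client.get("/api/profile")  # hypothetical route, for illustration only
    assert response.status_code == 200
    # The authenticated subject should be the mocked test user
    assert response.json()["user_id"] == test_user_id
```
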
### Mocking External Services

```python
@@ -153,10 +194,10 @@ def test_external_api_call(mocker, snapshot):
        "backend.services.external_api.call",
        return_value=mock_response
    )

    response = client.post("/api/process")
    assert response.status_code == 200

    snapshot.snapshot_dir = "snapshots"
    snapshot.assert_match(
        json.dumps(response.json(), indent=2, sort_keys=True),
@@ -187,6 +228,17 @@ def test_external_api_call(mocker, snapshot):
- Use `async def` with `@pytest.mark.asyncio` for testing async functions directly

### 5. Fixtures

#### Global Fixtures (conftest.py)

Authentication fixtures are available globally from `conftest.py`:

- `mock_jwt_user` - Standard user authentication
- `mock_jwt_admin` - Admin user authentication
- `configured_snapshot` - Pre-configured snapshot fixture

#### Custom Fixtures

Create reusable fixtures for common test data:

```python
@@ -202,9 +254,18 @@ def test_create_user(sample_user, snapshot):
    # ... test implementation
```

#### Test Isolation

All tests must use fixtures that ensure proper isolation:

- Authentication overrides are automatically cleaned up after each test
- Database connections are properly managed with cleanup
- Mock objects are reset between tests

## CI/CD Integration

The GitHub Actions workflow automatically runs tests on:

- Pull requests
- Pushes to main branch

@@ -216,16 +277,19 @@ Snapshot tests work in CI by:

## Troubleshooting

### Snapshot Mismatches

- Review the diff carefully
- If changes are expected: `poetry run pytest --snapshot-update`
- If changes are unexpected: Fix the code causing the difference

### Async Test Issues

- Ensure async functions use `@pytest.mark.asyncio`
- Use `AsyncMock` for mocking async functions (see the sketch below)
- FastAPI TestClient handles async automatically

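A minimal, self-contained sketch of that `AsyncMock` pattern (the `fetch_profile` coroutine is a stand-in, not a real service in this codebase):

```python
from unittest.mock import AsyncMock

import pytest


@pytest.mark.asyncio
async def test_async_dependency_is_awaited():
    # Stand-in async dependency; in a real test this would be patched in with mocker.patch(...)
    fetch_profile = AsyncMock(return_value={"name": "Test User"})

    result = await fetch_profile("user-123")

    assert result["name"] == "Test User"
    fetch_profile.assert_awaited_once_with("user-123")
```
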
### Import Errors

- Check that all dependencies are in `pyproject.toml`
- Run `poetry install` to ensure dependencies are installed
- Verify import paths are correct
@@ -234,4 +298,4 @@ Snapshot tests work in CI by:

Snapshot testing provides a powerful way to ensure API responses remain consistent. Combined with traditional assertions, it creates a robust test suite that catches regressions while remaining maintainable.

Remember: Good tests are as important as good code!
Remember: Good tests are as important as good code!

150
autogpt_platform/backend/backend/TEST_DATA_README.md
Normal file
@@ -0,0 +1,150 @@
# Test Data Scripts

This directory contains scripts for creating and updating test data in the AutoGPT Platform database, specifically designed to test the materialized views for the store functionality.

## Scripts

### test_data_creator.py
Creates a comprehensive set of test data including:
- Users with profiles
- Agent graphs, nodes, and executions
- Store listings with multiple versions
- Reviews and ratings
- Library agents
- Integration webhooks
- Onboarding data
- Credit transactions

**Image/Video Domains Used:**
- Images: `picsum.photos` (for all image URLs)
- Videos: `youtube.com` (for store listing videos)

### test_data_updater.py
Updates existing test data to simulate real-world changes:
- Adds new agent graph executions
- Creates new store listing reviews
- Updates store listing versions
- Adds credit transactions
- Refreshes materialized views

### check_db.py
Tests and verifies materialized views functionality:
- Checks pg_cron job status (for automatic refresh)
- Displays current materialized view counts
- Adds test data (executions and reviews)
- Creates store listings if none exist
- Manually refreshes materialized views
- Compares before/after counts to verify updates
- Provides a summary of test results

## Materialized Views

The scripts test three key database views:

1. **mv_agent_run_counts**: Tracks execution counts by agent
2. **mv_review_stats**: Tracks review statistics (count, average rating) by store listing
3. **StoreAgent**: A view that combines store listing data with execution counts and ratings for display

The materialized views (mv_agent_run_counts and mv_review_stats) are automatically refreshed every 15 minutes via pg_cron, or can be manually refreshed using the `refresh_store_materialized_views()` function.

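For a manual refresh from Python rather than SQL, a minimal sketch of what these scripts effectively do might look like the following (assuming the generated Prisma Python client and a configured `DATABASE_URL`; this is illustrative, not the scripts' exact code):

```python
import asyncio

from prisma import Prisma


async def refresh_store_views() -> None:
    db = Prisma()
    await db.connect()
    # refresh_store_materialized_views() is the SQL helper described above
    await db.execute_raw("SELECT refresh_store_materialized_views();")
    await db.disconnect()


if __name__ == "__main__":
    asyncio.run(refresh_store_views())
```
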
## Usage

### Prerequisites

1. Ensure the database is running:
```bash
docker compose up -d
# or for test database:
docker compose -f docker-compose.test.yaml --env-file ../.env up -d
```

2. Run database migrations:
```bash
poetry run prisma migrate deploy
```

### Running the Scripts

#### Option 1: Use the helper script (from backend directory)
```bash
poetry run python run_test_data.py
```

#### Option 2: Run individually
```bash
# From backend/test directory:
# Create initial test data
poetry run python test_data_creator.py

# Update data to test materialized view changes
poetry run python test_data_updater.py

# From backend directory:
# Test materialized views functionality
poetry run python check_db.py

# Check store data status
poetry run python check_store_data.py
```

#### Option 3: Use the shell script (from backend directory)
```bash
./run_test_data_scripts.sh
```

### Manual Materialized View Refresh

To manually refresh the materialized views:
```sql
SELECT refresh_store_materialized_views();
```

## Configuration

The scripts use the database configuration from your `.env` file:
- `DATABASE_URL`: PostgreSQL connection string
- Database should have the platform schema

## Data Generation Limits

Configured in `test_data_creator.py`:
- 100 users
- 100 agent blocks
- 1-5 graphs per user
- 2-5 nodes per graph
- 1-5 presets per user
- 1-10 library agents per user
- 1-20 executions per graph
- 1-5 reviews per store listing version

## Notes

- All image URLs use `picsum.photos` for consistency with Next.js image configuration
- The scripts create realistic relationships between entities
- Materialized views are refreshed at the end of each script
- Data is designed to test both happy paths and edge cases

## Troubleshooting

### Reviews and StoreAgent view showing 0

If `check_db.py` shows that reviews remain at 0 and the StoreAgent view shows 0 store agents:

1. **No store listings exist**: The script will automatically create test store listings if none exist
2. **No approved versions**: Store listings need approved versions to appear in the StoreAgent view
3. **Check with `check_store_data.py`**: This script provides detailed information about:
   - Total store listings
   - Store listing versions by status
   - Existing reviews
   - StoreAgent view contents
   - Agent graph executions

### pg_cron not installed

The warning "pg_cron extension is not installed" is normal in local development environments. The materialized views can still be refreshed manually using the `refresh_store_materialized_views()` function, which all scripts do automatically.

### Common Issues

- **Type errors with None values**: Fixed in the latest version of check_db.py by using `or 0` for nullable numeric fields
- **Missing relations**: Ensure you're using the correct field names (e.g., `StoreListing` not `storeListing` in includes)
- **Column name mismatches**: The database uses camelCase for column names (e.g., `agentGraphId` not `agent_graph_id`)

@@ -1,6 +1,10 @@
import logging
from typing import TYPE_CHECKING

from dotenv import load_dotenv

load_dotenv()

if TYPE_CHECKING:
    from backend.util.process import AppProcess

@@ -38,12 +42,12 @@ def main(**kwargs):
    from backend.server.ws_api import WebsocketServer

    run_processes(
        DatabaseManager(),
        ExecutionManager(),
        DatabaseManager().set_log_level("warning"),
        Scheduler(),
        NotificationManager(),
        WebsocketServer(),
        AgentServer(),
        ExecutionManager(),
        **kwargs,
    )

@@ -1,27 +1,45 @@
import functools
import importlib
import logging
import os
import re
from pathlib import Path
from typing import TYPE_CHECKING, TypeVar

from backend.util.cache import cached

logger = logging.getLogger(__name__)


if TYPE_CHECKING:
    from backend.data.block import Block

T = TypeVar("T")


@functools.cache
@cached(ttl_seconds=3600)
def load_all_blocks() -> dict[str, type["Block"]]:
    from backend.data.block import Block
    from backend.util.settings import Config

    # Check if example blocks should be loaded from settings
    config = Config()
    load_examples = config.enable_example_blocks

    # Dynamically load all modules under backend.blocks
    current_dir = Path(__file__).parent
    modules = [
        str(f.relative_to(current_dir))[:-3].replace(os.path.sep, ".")
        for f in current_dir.rglob("*.py")
        if f.is_file() and f.name != "__init__.py"
    ]
    modules = []
    for f in current_dir.rglob("*.py"):
        if not f.is_file() or f.name == "__init__.py" or f.name.startswith("test_"):
            continue

        # Skip examples directory if not enabled
        relative_path = f.relative_to(current_dir)
        if not load_examples and relative_path.parts[0] == "examples":
            continue

        module_path = str(relative_path)[:-3].replace(os.path.sep, ".")
        modules.append(module_path)

    for module in modules:
        if not re.match("^[a-z0-9_.]+$", module):
            raise ValueError(
@@ -86,7 +104,15 @@ def load_all_blocks() -> dict[str, type["Block"]]:

            available_blocks[block.id] = block_cls

    return available_blocks
    # Filter out blocks with incomplete auth configs, e.g. missing OAuth server secrets
    from backend.data.block import is_block_auth_configured

    filtered_blocks = {}
    for block_id, block_cls in available_blocks.items():
        if is_block_auth_configured(block_cls):
            filtered_blocks[block_id] = block_cls

    return filtered_blocks


__all__ = ["load_all_blocks"]

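For context on the `@cached(ttl_seconds=3600)` decorator used above: `backend.util.cache` itself is not part of this diff, so the following is only an illustrative sketch of a TTL cache for a zero-argument loader, not the project's actual implementation:

```python
import functools
import time


def cached(ttl_seconds: int):
    """Illustrative TTL cache: re-run the wrapped zero-argument function once the TTL expires."""

    def decorator(func):
        value = None
        expires_at = 0.0

        @functools.wraps(func)
        def wrapper():
            nonlocal value, expires_at
            now = time.monotonic()
            if value is None or now >= expires_at:
                value = func()
                expires_at = now + ttl_seconds
            return value

        return wrapper

    return decorator
```
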
@@ -1,36 +1,38 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockInput,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockSchemaInput,
|
||||
BlockType,
|
||||
get_block,
|
||||
)
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util import json
|
||||
from backend.data.execution import ExecutionStatus, NodesInputMasks
|
||||
from backend.data.model import NodeExecutionStats, SchemaField
|
||||
from backend.util.json import validate_with_jsonschema
|
||||
from backend.util.retry import func_retry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentExecutorBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
class Input(BlockSchemaInput):
|
||||
user_id: str = SchemaField(description="User ID")
|
||||
graph_id: str = SchemaField(description="Graph ID")
|
||||
graph_version: int = SchemaField(description="Graph Version")
|
||||
agent_name: Optional[str] = SchemaField(
|
||||
default=None, description="Name to display in the Builder UI"
|
||||
)
|
||||
|
||||
inputs: BlockInput = SchemaField(description="Input data for the graph")
|
||||
input_schema: dict = SchemaField(description="Input schema for the graph")
|
||||
output_schema: dict = SchemaField(description="Output schema for the graph")
|
||||
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = SchemaField(
|
||||
nodes_input_masks: Optional[NodesInputMasks] = SchemaField(
|
||||
default=None, hidden=True
|
||||
)
|
||||
|
||||
@@ -49,9 +51,10 @@ class AgentExecutorBlock(Block):
|
||||
|
||||
@classmethod
|
||||
def get_mismatch_error(cls, data: BlockInput) -> str | None:
|
||||
return json.validate_with_jsonschema(cls.get_input_schema(data), data)
|
||||
return validate_with_jsonschema(cls.get_input_schema(data), data)
|
||||
|
||||
class Output(BlockSchema):
|
||||
# Use BlockSchema to avoid automatic error field that could clash with graph outputs
|
||||
pass
|
||||
|
||||
def __init__(self):
|
||||
@@ -64,7 +67,13 @@ class AgentExecutorBlock(Block):
|
||||
categories={BlockCategory.AGENT},
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
graph_exec_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
|
||||
from backend.executor import utils as execution_utils
|
||||
|
||||
@@ -74,7 +83,18 @@ class AgentExecutorBlock(Block):
|
||||
user_id=input_data.user_id,
|
||||
inputs=input_data.inputs,
|
||||
nodes_input_masks=input_data.nodes_input_masks,
|
||||
use_db_query=False,
|
||||
parent_graph_exec_id=graph_exec_id,
|
||||
is_sub_graph=True, # AgentExecutorBlock executions are always sub-graphs
|
||||
)
|
||||
|
||||
logger = execution_utils.LogMetadata(
|
||||
logger=_logger,
|
||||
user_id=input_data.user_id,
|
||||
graph_eid=graph_exec.id,
|
||||
graph_id=input_data.graph_id,
|
||||
node_eid="*",
|
||||
node_id="*",
|
||||
block_name=self.name,
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -83,21 +103,17 @@ class AgentExecutorBlock(Block):
|
||||
graph_version=input_data.graph_version,
|
||||
graph_exec_id=graph_exec.id,
|
||||
user_id=input_data.user_id,
|
||||
logger=logger,
|
||||
):
|
||||
yield name, data
|
||||
except asyncio.CancelledError:
|
||||
except BaseException as e:
|
||||
await self._stop(
|
||||
graph_exec_id=graph_exec.id,
|
||||
user_id=input_data.user_id,
|
||||
logger=logger,
|
||||
)
|
||||
logger.warning(
|
||||
f"Execution of graph {input_data.graph_id} version {input_data.graph_version} was cancelled."
|
||||
)
|
||||
await execution_utils.stop_graph_execution(
|
||||
graph_exec.id, use_db_query=False
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Execution of graph {input_data.graph_id} version {input_data.graph_version} failed: {e}, stopping execution."
|
||||
)
|
||||
await execution_utils.stop_graph_execution(
|
||||
graph_exec.id, use_db_query=False
|
||||
f"Execution of graph {input_data.graph_id}v{input_data.graph_version} failed: {e.__class__.__name__} {str(e)}; execution is stopped."
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -107,6 +123,7 @@ class AgentExecutorBlock(Block):
|
||||
graph_version: int,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
logger,
|
||||
) -> BlockOutput:
|
||||
|
||||
from backend.data.execution import ExecutionEventType
|
||||
@@ -116,6 +133,7 @@ class AgentExecutorBlock(Block):
|
||||
|
||||
log_id = f"Graph #{graph_id}-V{graph_version}, exec-id: {graph_exec_id}"
|
||||
logger.info(f"Starting execution of {log_id}")
|
||||
yielded_node_exec_ids = set()
|
||||
|
||||
async for event in event_bus.listen(
|
||||
user_id=user_id,
|
||||
@@ -135,12 +153,26 @@ class AgentExecutorBlock(Block):
|
||||
if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
|
||||
# If the graph execution is COMPLETED, TERMINATED, or FAILED,
|
||||
# we can stop listening for further events.
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
extra_cost=event.stats.cost if event.stats else 0,
|
||||
extra_steps=event.stats.node_exec_count if event.stats else 0,
|
||||
)
|
||||
)
|
||||
break
|
||||
|
||||
logger.debug(
|
||||
f"Execution {log_id} produced input {event.input_data} output {event.output_data}"
|
||||
)
|
||||
|
||||
if event.node_exec_id in yielded_node_exec_ids:
|
||||
logger.warning(
|
||||
f"{log_id} received duplicate event for node execution {event.node_exec_id}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
yielded_node_exec_ids.add(event.node_exec_id)
|
||||
|
||||
if not event.block_id:
|
||||
logger.warning(f"{log_id} received event without block_id {event}")
|
||||
continue
|
||||
@@ -159,3 +191,25 @@ class AgentExecutorBlock(Block):
|
||||
f"Execution {log_id} produced {output_name}: {output_data}"
|
||||
)
|
||||
yield output_name, output_data
|
||||
|
||||
@func_retry
|
||||
async def _stop(
|
||||
self,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
logger,
|
||||
) -> None:
|
||||
from backend.executor import utils as execution_utils
|
||||
|
||||
log_id = f"Graph exec-id: {graph_exec_id}"
|
||||
logger.info(f"Stopping execution of {log_id}")
|
||||
|
||||
try:
|
||||
await execution_utils.stop_graph_execution(
|
||||
graph_exec_id=graph_exec_id,
|
||||
user_id=user_id,
|
||||
wait_timeout=3600,
|
||||
)
|
||||
logger.info(f"Execution {log_id} stopped successfully.")
|
||||
except TimeoutError as e:
|
||||
logger.error(f"Execution {log_id} stop timed out: {e}")
|
||||
|
||||
219
autogpt_platform/backend/backend/blocks/ai_condition.py
Normal file
219
autogpt_platform/backend/backend/blocks/ai_condition.py
Normal file
@@ -0,0 +1,219 @@
|
||||
from typing import Any
|
||||
|
||||
from backend.blocks.llm import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
AIBlockBase,
|
||||
AICredentials,
|
||||
AICredentialsField,
|
||||
LlmModel,
|
||||
LLMResponse,
|
||||
llm_call,
|
||||
)
|
||||
from backend.data.block import (
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField
|
||||
|
||||
|
||||
class AIConditionBlock(AIBlockBase):
|
||||
"""
|
||||
An AI-powered condition block that uses natural language to evaluate conditions.
|
||||
|
||||
This block allows users to define conditions in plain English (e.g., "the input is an email address",
|
||||
"the input is a city in the USA") and uses AI to determine if the input satisfies the condition.
|
||||
It provides the same yes/no data pass-through functionality as the standard ConditionBlock.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
input_value: Any = SchemaField(
|
||||
description="The input value to evaluate with the AI condition",
|
||||
placeholder="Enter the value to be evaluated (text, number, or any data)",
|
||||
)
|
||||
condition: str = SchemaField(
|
||||
description="A plaintext English description of the condition to evaluate",
|
||||
placeholder="E.g., 'the input is the body of an email', 'the input is a City in the USA', 'the input is an error or a refusal'",
|
||||
)
|
||||
yes_value: Any = SchemaField(
|
||||
description="(Optional) Value to output if the condition is true. If not provided, input_value will be used.",
|
||||
placeholder="Leave empty to use input_value, or enter a specific value",
|
||||
default=None,
|
||||
)
|
||||
no_value: Any = SchemaField(
|
||||
description="(Optional) Value to output if the condition is false. If not provided, input_value will be used.",
|
||||
placeholder="Leave empty to use input_value, or enter a specific value",
|
||||
default=None,
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for evaluating the condition.",
|
||||
advanced=False,
|
||||
)
|
||||
credentials: AICredentials = AICredentialsField()
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
result: bool = SchemaField(
|
||||
description="The result of the AI condition evaluation (True or False)"
|
||||
)
|
||||
yes_output: Any = SchemaField(
|
||||
description="The output value if the condition is true"
|
||||
)
|
||||
no_output: Any = SchemaField(
|
||||
description="The output value if the condition is false"
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the AI evaluation is uncertain or fails"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="553ec5b8-6c45-4299-8d75-b394d05f72ff",
|
||||
input_schema=AIConditionBlock.Input,
|
||||
output_schema=AIConditionBlock.Output,
|
||||
description="Uses AI to evaluate natural language conditions and provide conditional outputs",
|
||||
categories={BlockCategory.AI, BlockCategory.LOGIC},
|
||||
test_input={
|
||||
"input_value": "john@example.com",
|
||||
"condition": "the input is an email address",
|
||||
"yes_value": "Valid email",
|
||||
"no_value": "Not an email",
|
||||
"model": LlmModel.GPT4O,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("result", True),
|
||||
("yes_output", "Valid email"),
|
||||
],
|
||||
test_mock={
|
||||
"llm_call": lambda *args, **kwargs: LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[],
|
||||
response="true",
|
||||
tool_calls=None,
|
||||
prompt_tokens=50,
|
||||
completion_tokens=10,
|
||||
reasoning=None,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
async def llm_call(
|
||||
self,
|
||||
credentials: APIKeyCredentials,
|
||||
llm_model: LlmModel,
|
||||
prompt: list,
|
||||
max_tokens: int,
|
||||
) -> LLMResponse:
|
||||
"""Wrapper method for llm_call to enable mocking in tests."""
|
||||
return await llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=llm_model,
|
||||
prompt=prompt,
|
||||
force_json_output=False,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
"""
|
||||
Evaluate the AI condition and return appropriate outputs.
|
||||
"""
|
||||
# Prepare the yes and no values, using input_value as default
|
||||
yes_value = (
|
||||
input_data.yes_value
|
||||
if input_data.yes_value is not None
|
||||
else input_data.input_value
|
||||
)
|
||||
no_value = (
|
||||
input_data.no_value
|
||||
if input_data.no_value is not None
|
||||
else input_data.input_value
|
||||
)
|
||||
|
||||
# Convert input_value to string for AI evaluation
|
||||
input_str = str(input_data.input_value)
|
||||
|
||||
# Create the prompt for AI evaluation
|
||||
prompt = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You are an AI assistant that evaluates conditions based on input data. "
|
||||
"You must respond with only 'true' or 'false' (lowercase) to indicate whether "
|
||||
"the given condition is met by the input value. Be accurate and consider the "
|
||||
"context and meaning of both the input and the condition."
|
||||
),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
f"Input value: {input_str}\n"
|
||||
f"Condition to evaluate: {input_data.condition}\n\n"
|
||||
f"Does the input value satisfy the condition? Respond with only 'true' or 'false'."
|
||||
),
|
||||
},
|
||||
]
|
||||
|
||||
# Call the LLM
|
||||
try:
|
||||
response = await self.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=input_data.model,
|
||||
prompt=prompt,
|
||||
max_tokens=10, # We only expect a true/false response
|
||||
)
|
||||
|
||||
# Extract the boolean result from the response
|
||||
response_text = response.response.strip().lower()
|
||||
if response_text == "true":
|
||||
result = True
|
||||
elif response_text == "false":
|
||||
result = False
|
||||
else:
|
||||
# If the response is not clear, try to interpret it using word boundaries
|
||||
import re
|
||||
|
||||
# Use word boundaries to avoid false positives like 'untrue' or '10'
|
||||
tokens = set(re.findall(r"\b(true|false|yes|no|1|0)\b", response_text))
|
||||
|
||||
if tokens == {"true"} or tokens == {"yes"} or tokens == {"1"}:
|
||||
result = True
|
||||
elif tokens == {"false"} or tokens == {"no"} or tokens == {"0"}:
|
||||
result = False
|
||||
else:
|
||||
# Unclear or conflicting response - default to False and yield error
|
||||
result = False
|
||||
yield "error", f"Unclear AI response: '{response.response}'"
|
||||
|
||||
# Update internal stats
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
input_token_count=response.prompt_tokens,
|
||||
output_token_count=response.completion_tokens,
|
||||
)
|
||||
)
|
||||
self.prompt = response.prompt
|
||||
|
||||
except Exception as e:
|
||||
# In case of any error, default to False to be safe
|
||||
result = False
|
||||
# Log the error but don't fail the block execution
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.error(f"AI condition evaluation failed: {str(e)}")
|
||||
yield "error", f"AI evaluation failed: {str(e)}"
|
||||
|
||||
# Yield results
|
||||
yield "result", result
|
||||
|
||||
if result:
|
||||
yield "yes_output", yes_value
|
||||
else:
|
||||
yield "no_output", no_value
|
||||
197
autogpt_platform/backend/backend/blocks/ai_image_customizer.py
Normal file
197
autogpt_platform/backend/backend/blocks/ai_image_customizer.py
Normal file
@@ -0,0 +1,197 @@
|
||||
import asyncio
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
from replicate.client import Client as ReplicateClient
|
||||
from replicate.helpers import FileOutput
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.file import MediaFileType, store_media_file
|
||||
|
||||
|
||||
class GeminiImageModel(str, Enum):
|
||||
NANO_BANANA = "google/nano-banana"
|
||||
NANO_BANANA_PRO = "google/nano-banana-pro"
|
||||
|
||||
|
||||
class AspectRatio(str, Enum):
|
||||
MATCH_INPUT_IMAGE = "match_input_image"
|
||||
ASPECT_1_1 = "1:1"
|
||||
ASPECT_2_3 = "2:3"
|
||||
ASPECT_3_2 = "3:2"
|
||||
ASPECT_3_4 = "3:4"
|
||||
ASPECT_4_3 = "4:3"
|
||||
ASPECT_4_5 = "4:5"
|
||||
ASPECT_5_4 = "5:4"
|
||||
ASPECT_9_16 = "9:16"
|
||||
ASPECT_16_9 = "16:9"
|
||||
ASPECT_21_9 = "21:9"
|
||||
|
||||
|
||||
class OutputFormat(str, Enum):
|
||||
JPG = "jpg"
|
||||
PNG = "png"
|
||||
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="replicate",
|
||||
api_key=SecretStr("mock-replicate-api-key"),
|
||||
title="Mock Replicate API key",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
class AIImageCustomizerBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REPLICATE], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description="Replicate API key with permissions for Google Gemini image models",
|
||||
)
|
||||
prompt: str = SchemaField(
|
||||
description="A text description of the image you want to generate",
|
||||
title="Prompt",
|
||||
)
|
||||
model: GeminiImageModel = SchemaField(
|
||||
description="The AI model to use for image generation and editing",
|
||||
default=GeminiImageModel.NANO_BANANA,
|
||||
title="Model",
|
||||
)
|
||||
images: list[MediaFileType] = SchemaField(
|
||||
description="Optional list of input images to reference or modify",
|
||||
default=[],
|
||||
title="Input Images",
|
||||
)
|
||||
aspect_ratio: AspectRatio = SchemaField(
|
||||
description="Aspect ratio of the generated image",
|
||||
default=AspectRatio.MATCH_INPUT_IMAGE,
|
||||
title="Aspect Ratio",
|
||||
)
|
||||
output_format: OutputFormat = SchemaField(
|
||||
description="Format of the output image",
|
||||
default=OutputFormat.PNG,
|
||||
title="Output Format",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
image_url: MediaFileType = SchemaField(description="URL of the generated image")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d76bbe4c-930e-4894-8469-b66775511f71",
|
||||
description=(
|
||||
"Generate and edit custom images using Google's Nano-Banana model from Gemini 2.5. "
|
||||
"Provide a prompt and optional reference images to create or modify images."
|
||||
),
|
||||
categories={BlockCategory.AI, BlockCategory.MULTIMEDIA},
|
||||
input_schema=AIImageCustomizerBlock.Input,
|
||||
output_schema=AIImageCustomizerBlock.Output,
|
||||
test_input={
|
||||
"prompt": "Make the scene more vibrant and colorful",
|
||||
"model": GeminiImageModel.NANO_BANANA,
|
||||
"images": [],
|
||||
"aspect_ratio": AspectRatio.MATCH_INPUT_IMAGE,
|
||||
"output_format": OutputFormat.JPG,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
("image_url", "https://replicate.delivery/generated-image.jpg"),
|
||||
],
|
||||
test_mock={
|
||||
"run_model": lambda *args, **kwargs: MediaFileType(
|
||||
"https://replicate.delivery/generated-image.jpg"
|
||||
),
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
# Convert local file paths to Data URIs (base64) so Replicate can access them
|
||||
processed_images = await asyncio.gather(
|
||||
*(
|
||||
store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=img,
|
||||
user_id=user_id,
|
||||
return_content=True,
|
||||
)
|
||||
for img in input_data.images
|
||||
)
|
||||
)
|
||||
|
||||
result = await self.run_model(
|
||||
api_key=credentials.api_key,
|
||||
model_name=input_data.model.value,
|
||||
prompt=input_data.prompt,
|
||||
images=processed_images,
|
||||
aspect_ratio=input_data.aspect_ratio.value,
|
||||
output_format=input_data.output_format.value,
|
||||
)
|
||||
yield "image_url", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
async def run_model(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
images: list[MediaFileType],
|
||||
aspect_ratio: str,
|
||||
output_format: str,
|
||||
) -> MediaFileType:
|
||||
client = ReplicateClient(api_token=api_key.get_secret_value())
|
||||
|
||||
input_params: dict = {
|
||||
"prompt": prompt,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
"output_format": output_format,
|
||||
}
|
||||
|
||||
# Add images to input if provided (API expects "image_input" parameter)
|
||||
if images:
|
||||
input_params["image_input"] = [str(img) for img in images]
|
||||
|
||||
output: FileOutput | str = await client.async_run( # type: ignore
|
||||
model_name,
|
||||
input=input_params,
|
||||
wait=False,
|
||||
)
|
||||
|
||||
if isinstance(output, FileOutput):
|
||||
return MediaFileType(output.url)
|
||||
if isinstance(output, str):
|
||||
return MediaFileType(output)
|
||||
|
||||
raise ValueError("No output received from the model")
|
||||
@@ -5,7 +5,7 @@ from pydantic import SecretStr
|
||||
from replicate.client import Client as ReplicateClient
|
||||
from replicate.helpers import FileOutput
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockSchema
|
||||
from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
@@ -60,6 +60,14 @@ SIZE_TO_RECRAFT_DIMENSIONS = {
|
||||
ImageSize.TALL: "1024x1536",
|
||||
}
|
||||
|
||||
SIZE_TO_NANO_BANANA_RATIO = {
|
||||
ImageSize.SQUARE: "1:1",
|
||||
ImageSize.LANDSCAPE: "4:3",
|
||||
ImageSize.PORTRAIT: "3:4",
|
||||
ImageSize.WIDE: "16:9",
|
||||
ImageSize.TALL: "9:16",
|
||||
}
|
||||
|
||||
|
||||
class ImageStyle(str, Enum):
|
||||
"""
|
||||
@@ -98,10 +106,11 @@ class ImageGenModel(str, Enum):
|
||||
FLUX_ULTRA = "Flux 1.1 Pro Ultra"
|
||||
RECRAFT = "Recraft v3"
|
||||
SD3_5 = "Stable Diffusion 3.5 Medium"
|
||||
NANO_BANANA_PRO = "Nano Banana Pro"
|
||||
|
||||
|
||||
class AIImageGeneratorBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REPLICATE], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
@@ -135,9 +144,8 @@ class AIImageGeneratorBlock(Block):
|
||||
title="Image Style",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class Output(BlockSchemaOutput):
|
||||
image_url: str = SchemaField(description="URL of the generated image")
|
||||
error: str = SchemaField(description="Error message if generation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -262,6 +270,20 @@ class AIImageGeneratorBlock(Block):
|
||||
)
|
||||
return output
|
||||
|
||||
elif input_data.model == ImageGenModel.NANO_BANANA_PRO:
|
||||
# Use Nano Banana Pro (Google Gemini 3 Pro Image)
|
||||
input_params = {
|
||||
"prompt": modified_prompt,
|
||||
"aspect_ratio": SIZE_TO_NANO_BANANA_RATIO[input_data.size],
|
||||
"resolution": "2K", # Default to 2K for good quality/cost balance
|
||||
"output_format": "jpg",
|
||||
"safety_filter_level": "block_only_high", # Most permissive
|
||||
}
|
||||
output = await self._run_client(
|
||||
credentials, "google/nano-banana-pro", input_params
|
||||
)
|
||||
return output
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to generate image: {str(e)}")
|
||||
|
||||
|
||||
@@ -6,7 +6,13 @@ from typing import Literal
|
||||
from pydantic import SecretStr
|
||||
from replicate.client import Client as ReplicateClient
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
@@ -54,7 +60,7 @@ class NormalizationStrategy(str, Enum):
|
||||
|
||||
|
||||
class AIMusicGeneratorBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REPLICATE], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
@@ -107,9 +113,8 @@ class AIMusicGeneratorBlock(Block):
|
||||
title="Normalization Strategy",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class Output(BlockSchemaOutput):
|
||||
result: str = SchemaField(description="URL of the generated audio file")
|
||||
error: str = SchemaField(description="Error message if the model run failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -166,7 +171,7 @@ class AIMusicGeneratorBlock(Block):
|
||||
output_format=input_data.output_format,
|
||||
normalization_strategy=input_data.normalization_strategy,
|
||||
)
|
||||
if result and result != "No output received":
|
||||
if result and isinstance(result, str) and result.startswith("http"):
|
||||
yield "result", result
|
||||
return
|
||||
else:
|
||||
|
||||
@@ -6,7 +6,13 @@ from typing import Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
@@ -53,6 +59,7 @@ class AudioTrack(str, Enum):
|
||||
REFRESHER = ("Refresher",)
|
||||
TOURIST = ("Tourist",)
|
||||
TWIN_TYCHES = ("Twin Tyches",)
|
||||
DONT_STOP_ME_ABSTRACT_FUTURE_BASS = ("Dont Stop Me Abstract Future Bass",)
|
||||
|
||||
@property
|
||||
def audio_url(self):
|
||||
@@ -78,6 +85,7 @@ class AudioTrack(str, Enum):
|
||||
AudioTrack.REFRESHER: "https://cdn.tfrv.xyz/audio/refresher.mp3",
|
||||
AudioTrack.TOURIST: "https://cdn.tfrv.xyz/audio/tourist.mp3",
|
||||
AudioTrack.TWIN_TYCHES: "https://cdn.tfrv.xyz/audio/twin-tynches.mp3",
|
||||
AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS: "https://cdn.revid.ai/audio/_dont-stop-me-abstract-future-bass.mp3",
|
||||
}
|
||||
return audio_urls[self]
|
||||
|
||||
@@ -105,6 +113,7 @@ class GenerationPreset(str, Enum):
|
||||
MOVIE = ("Movie",)
|
||||
STYLIZED_ILLUSTRATION = ("Stylized Illustration",)
|
||||
MANGA = ("Manga",)
|
||||
DEFAULT = ("DEFAULT",)
|
||||
|
||||
|
||||
class Voice(str, Enum):
|
||||
@@ -114,6 +123,7 @@ class Voice(str, Enum):
|
||||
JESSICA = "Jessica"
|
||||
CHARLOTTE = "Charlotte"
|
||||
CALLUM = "Callum"
|
||||
EVA = "Eva"
|
||||
|
||||
@property
|
||||
def voice_id(self):
|
||||
@@ -124,6 +134,7 @@ class Voice(str, Enum):
|
||||
Voice.JESSICA: "cgSgspJ2msm6clMCkdW9",
|
||||
Voice.CHARLOTTE: "XB0fDUnXU5powFXDhCwa",
|
||||
Voice.CALLUM: "N2lVS1w4EtoT3dr4eOWO",
|
||||
Voice.EVA: "FGY2WhTYpPnrIDTdsKH5",
|
||||
}
|
||||
return voice_id_map[self]
|
||||
|
||||
@@ -141,7 +152,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AIShortformVideoCreatorBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
"""Creates a short‑form text‑to‑video clip using stock or AI imagery."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REVID], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
@@ -180,44 +193,11 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
placeholder=VisualMediaType.STOCK_VIDEOS,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class Output(BlockSchemaOutput):
|
||||
video_url: str = SchemaField(description="The URL of the created video")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="361697fb-0c4f-4feb-aed3-8320c88c771b",
|
||||
description="Creates a shortform video using revid.ai",
|
||||
categories={BlockCategory.SOCIAL, BlockCategory.AI},
|
||||
input_schema=AIShortformVideoCreatorBlock.Input,
|
||||
output_schema=AIShortformVideoCreatorBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"script": "[close-up of a cat] Meow!",
|
||||
"ratio": "9 / 16",
|
||||
"resolution": "720p",
|
||||
"frame_rate": 60,
|
||||
"generation_preset": GenerationPreset.LEONARDO,
|
||||
"background_music": AudioTrack.HIGHWAY_NOCTURNE,
|
||||
"voice": Voice.LILY,
|
||||
"video_style": VisualMediaType.STOCK_VIDEOS,
|
||||
},
|
||||
test_output=(
|
||||
"video_url",
|
||||
"https://example.com/video.mp4",
|
||||
),
|
||||
test_mock={
|
||||
"create_webhook": lambda: (
|
||||
"test_uuid",
|
||||
"https://webhook.site/test_uuid",
|
||||
),
|
||||
"create_video": lambda api_key, payload: {"pid": "test_pid"},
|
||||
"wait_for_video": lambda api_key, pid, webhook_token, max_wait_time=1000: "https://example.com/video.mp4",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def create_webhook(self):
|
||||
async def create_webhook(self) -> tuple[str, str]:
|
||||
"""Create a new webhook URL for receiving notifications."""
|
||||
url = "https://webhook.site/token"
|
||||
headers = {"Accept": "application/json", "Content-Type": "application/json"}
|
||||
response = await Requests().post(url, headers=headers)
|
||||
@@ -225,6 +205,7 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"
|
||||
|
||||
async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
|
||||
"""Create a video using the Revid API."""
|
||||
url = "https://www.revid.ai/api/public/v2/render"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = await Requests().post(url, json=payload, headers=headers)
|
||||
@@ -234,6 +215,7 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
return response.json()
|
||||
|
||||
async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
|
||||
"""Check the status of a video creation job."""
|
||||
url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = await Requests().get(url, headers=headers)
|
||||
@@ -243,9 +225,9 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
pid: str,
|
||||
webhook_token: str,
|
||||
max_wait_time: int = 1000,
|
||||
) -> str:
|
||||
"""Wait for video creation to complete and return the video URL."""
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < max_wait_time:
|
||||
status = await self.check_video_status(api_key, pid)
|
||||
@@ -266,6 +248,40 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
logger.error("Video creation timed out")
|
||||
raise TimeoutError("Video creation timed out")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="361697fb-0c4f-4feb-aed3-8320c88c771b",
|
||||
description="Creates a shortform video using revid.ai",
|
||||
categories={BlockCategory.SOCIAL, BlockCategory.AI},
|
||||
input_schema=AIShortformVideoCreatorBlock.Input,
|
||||
output_schema=AIShortformVideoCreatorBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"script": "[close-up of a cat] Meow!",
|
||||
"ratio": "9 / 16",
|
||||
"resolution": "720p",
|
||||
"frame_rate": 60,
|
||||
"generation_preset": GenerationPreset.LEONARDO,
|
||||
"background_music": AudioTrack.HIGHWAY_NOCTURNE,
|
||||
"voice": Voice.LILY,
|
||||
"video_style": VisualMediaType.STOCK_VIDEOS,
|
||||
},
|
||||
test_output=("video_url", "https://example.com/video.mp4"),
|
||||
test_mock={
|
||||
"create_webhook": lambda *args, **kwargs: (
|
||||
"test_uuid",
|
||||
"https://webhook.site/test_uuid",
|
||||
),
|
||||
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
||||
"check_video_status": lambda *args, **kwargs: {
|
||||
"status": "ready",
|
||||
"videoUrl": "https://example.com/video.mp4",
|
||||
},
|
||||
"wait_for_video": lambda *args, **kwargs: "https://example.com/video.mp4",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
@@ -273,20 +289,18 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
webhook_token, webhook_url = await self.create_webhook()
|
||||
logger.debug(f"Webhook URL: {webhook_url}")
|
||||
|
||||
audio_url = input_data.background_music.audio_url
|
||||
|
||||
payload = {
|
||||
"frameRate": input_data.frame_rate,
|
||||
"resolution": input_data.resolution,
|
||||
"frameDurationMultiplier": 18,
|
||||
"webhook": webhook_url,
|
||||
"webhook": None,
|
||||
"creationParams": {
|
||||
"mediaType": input_data.video_style,
|
||||
"captionPresetName": "Wrap 1",
|
||||
"selectedVoice": input_data.voice.voice_id,
|
||||
"hasEnhancedGeneration": True,
|
||||
"generationPreset": input_data.generation_preset.name,
|
||||
"selectedAudio": input_data.background_music,
|
||||
"selectedAudio": input_data.background_music.value,
|
||||
"origin": "/create",
|
||||
"inputText": input_data.script,
|
||||
"flowType": "text-to-video",
|
||||
@@ -302,7 +316,7 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
"selectedStoryStyle": {"value": "custom", "label": "Custom"},
|
||||
"hasToGenerateVideos": input_data.video_style
|
||||
!= VisualMediaType.STOCK_VIDEOS,
|
||||
"audioUrl": audio_url,
|
||||
"audioUrl": input_data.background_music.audio_url,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -319,8 +333,368 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
logger.debug(
|
||||
f"Video created with project ID: {pid}. Waiting for completion..."
|
||||
)
|
||||
video_url = await self.wait_for_video(
|
||||
credentials.api_key, pid, webhook_token
|
||||
)
|
||||
video_url = await self.wait_for_video(credentials.api_key, pid)
|
||||
logger.debug(f"Video ready: {video_url}")
|
||||
yield "video_url", video_url
|
||||
|
||||
|
||||
class AIAdMakerVideoCreatorBlock(Block):
|
||||
"""Generates a 30‑second vertical AI advert using optional user‑supplied imagery."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REVID], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description="Credentials for Revid.ai API access.",
|
||||
)
|
||||
script: str = SchemaField(
|
||||
description="Short advertising copy. Line breaks create new scenes.",
|
||||
placeholder="Introducing Foobar – [show product photo] the gadget that does it all.",
|
||||
)
|
||||
ratio: str = SchemaField(description="Aspect ratio", default="9 / 16")
|
||||
target_duration: int = SchemaField(
|
||||
description="Desired length of the ad in seconds.", default=30
|
||||
)
|
||||
voice: Voice = SchemaField(
|
||||
description="Narration voice", default=Voice.EVA, placeholder=Voice.EVA
|
||||
)
|
||||
background_music: AudioTrack = SchemaField(
|
||||
description="Background track",
|
||||
default=AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS,
|
||||
)
|
||||
input_media_urls: list[str] = SchemaField(
|
||||
description="List of image URLs to feature in the advert.", default=[]
|
||||
)
|
||||
use_only_provided_media: bool = SchemaField(
|
||||
description="Restrict visuals to supplied images only.", default=True
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
video_url: str = SchemaField(description="URL of the finished advert")
|
||||
|
||||
async def create_webhook(self) -> tuple[str, str]:
|
||||
"""Create a new webhook URL for receiving notifications."""
|
||||
url = "https://webhook.site/token"
|
||||
headers = {"Accept": "application/json", "Content-Type": "application/json"}
|
||||
response = await Requests().post(url, headers=headers)
|
||||
webhook_data = response.json()
|
||||
return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"
|
||||
|
||||
async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
|
||||
"""Create a video using the Revid API."""
|
||||
url = "https://www.revid.ai/api/public/v2/render"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = await Requests().post(url, json=payload, headers=headers)
|
||||
logger.debug(
|
||||
f"API Response Status Code: {response.status}, Content: {response.text}"
|
||||
)
|
||||
return response.json()
|
||||
|
||||
async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
|
||||
"""Check the status of a video creation job."""
|
||||
url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = await Requests().get(url, headers=headers)
|
||||
return response.json()
|
||||
|
||||
async def wait_for_video(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
pid: str,
|
||||
max_wait_time: int = 1000,
|
||||
) -> str:
|
||||
"""Wait for video creation to complete and return the video URL."""
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < max_wait_time:
|
||||
status = await self.check_video_status(api_key, pid)
|
||||
logger.debug(f"Video status: {status}")
|
||||
|
||||
if status.get("status") == "ready" and "videoUrl" in status:
|
||||
return status["videoUrl"]
|
||||
elif status.get("status") == "error":
|
||||
error_message = status.get("error", "Unknown error occurred")
|
||||
logger.error(f"Video creation failed: {error_message}")
|
||||
raise ValueError(f"Video creation failed: {error_message}")
|
||||
elif status.get("status") in ["FAILED", "CANCELED"]:
|
||||
logger.error(f"Video creation failed: {status.get('message')}")
|
||||
raise ValueError(f"Video creation failed: {status.get('message')}")
|
||||
|
||||
await asyncio.sleep(10)
|
||||
|
||||
logger.error("Video creation timed out")
|
||||
raise TimeoutError("Video creation timed out")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="58bd2a19-115d-4fd1-8ca4-13b9e37fa6a0",
|
||||
description="Creates an AI‑generated 30‑second advert (text + images)",
|
||||
categories={BlockCategory.MARKETING, BlockCategory.AI},
|
||||
input_schema=AIAdMakerVideoCreatorBlock.Input,
|
||||
output_schema=AIAdMakerVideoCreatorBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"script": "Test product launch!",
|
||||
"input_media_urls": [
|
||||
"https://cdn.revid.ai/uploads/1747076315114-image.png",
|
||||
],
|
||||
},
|
||||
test_output=("video_url", "https://example.com/ad.mp4"),
|
||||
test_mock={
|
||||
"create_webhook": lambda *args, **kwargs: (
|
||||
"test_uuid",
|
||||
"https://webhook.site/test_uuid",
|
||||
),
|
||||
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
||||
"check_video_status": lambda *args, **kwargs: {
|
||||
"status": "ready",
|
||||
"videoUrl": "https://example.com/ad.mp4",
|
||||
},
|
||||
"wait_for_video": lambda *args, **kwargs: "https://example.com/ad.mp4",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
|
||||
webhook_token, webhook_url = await self.create_webhook()
|
||||
|
||||
payload = {
|
||||
"webhook": webhook_url,
|
||||
"creationParams": {
|
||||
"targetDuration": input_data.target_duration,
|
||||
"ratio": input_data.ratio,
|
||||
"mediaType": "aiVideo",
|
||||
"inputText": input_data.script,
|
||||
"flowType": "text-to-video",
|
||||
"slug": "ai-ad-generator",
|
||||
"slugNew": "",
|
||||
"isCopiedFrom": False,
|
||||
"hasToGenerateVoice": True,
|
||||
"hasToTranscript": False,
|
||||
"hasToSearchMedia": True,
|
||||
"hasAvatar": False,
|
||||
"hasWebsiteRecorder": False,
|
||||
"hasTextSmallAtBottom": False,
|
||||
"selectedAudio": input_data.background_music.value,
|
||||
"selectedVoice": input_data.voice.voice_id,
|
||||
"selectedAvatar": "https://cdn.revid.ai/avatars/young-woman.mp4",
|
||||
"selectedAvatarType": "video/mp4",
|
||||
"websiteToRecord": "",
|
||||
"hasToGenerateCover": True,
|
||||
"nbGenerations": 1,
|
||||
"disableCaptions": False,
|
||||
"mediaMultiplier": "medium",
|
||||
"characters": [],
|
||||
"captionPresetName": "Revid",
|
||||
"sourceType": "contentScraping",
|
||||
"selectedStoryStyle": {"value": "custom", "label": "General"},
|
||||
"generationPreset": "DEFAULT",
|
||||
"hasToGenerateMusic": False,
|
||||
"isOptimizedForChinese": False,
|
||||
"generationUserPrompt": "",
|
||||
"enableNsfwFilter": False,
|
||||
"addStickers": False,
|
||||
"typeMovingImageAnim": "dynamic",
|
||||
"hasToGenerateSoundEffects": False,
|
||||
"forceModelType": "gpt-image-1",
|
||||
"selectedCharacters": [],
|
||||
"lang": "",
|
||||
"voiceSpeed": 1,
|
||||
"disableAudio": False,
|
||||
"disableVoice": False,
|
||||
"useOnlyProvidedMedia": input_data.use_only_provided_media,
|
||||
"imageGenerationModel": "ultra",
|
||||
"videoGenerationModel": "pro",
|
||||
"hasEnhancedGeneration": True,
|
||||
"hasEnhancedGenerationPro": True,
|
||||
"inputMedias": [
|
||||
{"url": url, "title": "", "type": "image"}
|
||||
for url in input_data.input_media_urls
|
||||
],
|
||||
"hasToGenerateVideos": True,
|
||||
"audioUrl": input_data.background_music.audio_url,
|
||||
"watermark": None,
|
||||
},
|
||||
}
|
||||
|
||||
response = await self.create_video(credentials.api_key, payload)
|
||||
pid = response.get("pid")
|
||||
if not pid:
|
||||
raise RuntimeError("Failed to create video: No project ID returned")
|
||||
|
||||
video_url = await self.wait_for_video(credentials.api_key, pid)
|
||||
yield "video_url", video_url
|
||||
|
||||
|
||||
class AIScreenshotToVideoAdBlock(Block):
|
||||
"""Creates an advert where the supplied screenshot is narrated by an AI avatar."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REVID], Literal["api_key"]
|
||||
] = CredentialsField(description="Revid.ai API key")
|
||||
script: str = SchemaField(
|
||||
description="Narration that will accompany the screenshot.",
|
||||
placeholder="Check out these amazing stats!",
|
||||
)
|
||||
screenshot_url: str = SchemaField(
|
||||
description="Screenshot or image URL to showcase."
|
||||
)
|
||||
ratio: str = SchemaField(default="9 / 16")
|
||||
target_duration: int = SchemaField(default=30)
|
||||
voice: Voice = SchemaField(default=Voice.EVA)
|
||||
background_music: AudioTrack = SchemaField(
|
||||
default=AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
video_url: str = SchemaField(description="Rendered video URL")
|
||||
|
||||
async def create_webhook(self) -> tuple[str, str]:
|
||||
"""Create a new webhook URL for receiving notifications."""
|
||||
url = "https://webhook.site/token"
|
||||
headers = {"Accept": "application/json", "Content-Type": "application/json"}
|
||||
response = await Requests().post(url, headers=headers)
|
||||
webhook_data = response.json()
|
||||
return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"
|
||||
|
||||
async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
|
||||
"""Create a video using the Revid API."""
|
||||
url = "https://www.revid.ai/api/public/v2/render"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = await Requests().post(url, json=payload, headers=headers)
|
||||
logger.debug(
|
||||
f"API Response Status Code: {response.status}, Content: {response.text}"
|
||||
)
|
||||
return response.json()
|
||||
|
||||
async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
|
||||
"""Check the status of a video creation job."""
|
||||
url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = await Requests().get(url, headers=headers)
|
||||
return response.json()
|
||||
|
||||
async def wait_for_video(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
pid: str,
|
||||
max_wait_time: int = 1000,
|
||||
) -> str:
|
||||
"""Wait for video creation to complete and return the video URL."""
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < max_wait_time:
|
||||
status = await self.check_video_status(api_key, pid)
|
||||
logger.debug(f"Video status: {status}")
|
||||
|
||||
if status.get("status") == "ready" and "videoUrl" in status:
|
||||
return status["videoUrl"]
|
||||
elif status.get("status") == "error":
|
||||
error_message = status.get("error", "Unknown error occurred")
|
||||
logger.error(f"Video creation failed: {error_message}")
|
||||
raise ValueError(f"Video creation failed: {error_message}")
|
||||
elif status.get("status") in ["FAILED", "CANCELED"]:
|
||||
logger.error(f"Video creation failed: {status.get('message')}")
|
||||
raise ValueError(f"Video creation failed: {status.get('message')}")
|
||||
|
||||
await asyncio.sleep(10)
|
||||
|
||||
logger.error("Video creation timed out")
|
||||
raise TimeoutError("Video creation timed out")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0f3e4635-e810-43d9-9e81-49e6f4e83b7c",
|
||||
description="Turns a screenshot into an engaging, avatar‑narrated video advert.",
|
||||
categories={BlockCategory.AI, BlockCategory.MARKETING},
|
||||
input_schema=AIScreenshotToVideoAdBlock.Input,
|
||||
output_schema=AIScreenshotToVideoAdBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"script": "Amazing numbers!",
|
||||
"screenshot_url": "https://cdn.revid.ai/uploads/1747080376028-image.png",
|
||||
},
|
||||
test_output=("video_url", "https://example.com/screenshot.mp4"),
|
||||
test_mock={
|
||||
"create_webhook": lambda *args, **kwargs: (
|
||||
"test_uuid",
|
||||
"https://webhook.site/test_uuid",
|
||||
),
|
||||
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
||||
"check_video_status": lambda *args, **kwargs: {
|
||||
"status": "ready",
|
||||
"videoUrl": "https://example.com/screenshot.mp4",
|
||||
},
|
||||
"wait_for_video": lambda *args, **kwargs: "https://example.com/screenshot.mp4",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
|
||||
webhook_token, webhook_url = await self.create_webhook()
|
||||
|
||||
payload = {
|
||||
"webhook": webhook_url,
|
||||
"creationParams": {
|
||||
"targetDuration": input_data.target_duration,
|
||||
"ratio": input_data.ratio,
|
||||
"mediaType": "aiVideo",
|
||||
"hasAvatar": True,
|
||||
"removeAvatarBackground": True,
|
||||
"inputText": input_data.script,
|
||||
"flowType": "text-to-video",
|
||||
"slug": "ai-ad-generator",
|
||||
"slugNew": "screenshot-to-video-ad",
|
||||
"isCopiedFrom": "ai-ad-generator",
|
||||
"hasToGenerateVoice": True,
|
||||
"hasToTranscript": False,
|
||||
"hasToSearchMedia": True,
|
||||
"hasWebsiteRecorder": False,
|
||||
"hasTextSmallAtBottom": False,
|
||||
"selectedAudio": input_data.background_music.value,
|
||||
"selectedVoice": input_data.voice.voice_id,
|
||||
"selectedAvatar": "https://cdn.revid.ai/avatars/young-woman.mp4",
|
||||
"selectedAvatarType": "video/mp4",
|
||||
"websiteToRecord": "",
|
||||
"hasToGenerateCover": True,
|
||||
"nbGenerations": 1,
|
||||
"disableCaptions": False,
|
||||
"mediaMultiplier": "medium",
|
||||
"characters": [],
|
||||
"captionPresetName": "Revid",
|
||||
"sourceType": "contentScraping",
|
||||
"selectedStoryStyle": {"value": "custom", "label": "General"},
|
||||
"generationPreset": "DEFAULT",
|
||||
"hasToGenerateMusic": False,
|
||||
"isOptimizedForChinese": False,
|
||||
"generationUserPrompt": "",
|
||||
"enableNsfwFilter": False,
|
||||
"addStickers": False,
|
||||
"typeMovingImageAnim": "dynamic",
|
||||
"hasToGenerateSoundEffects": False,
|
||||
"forceModelType": "gpt-image-1",
|
||||
"selectedCharacters": [],
|
||||
"lang": "",
|
||||
"voiceSpeed": 1,
|
||||
"disableAudio": False,
|
||||
"disableVoice": False,
|
||||
"useOnlyProvidedMedia": True,
|
||||
"imageGenerationModel": "ultra",
|
||||
"videoGenerationModel": "ultra",
|
||||
"hasEnhancedGeneration": True,
|
||||
"hasEnhancedGenerationPro": True,
|
||||
"inputMedias": [
|
||||
{"url": input_data.screenshot_url, "title": "", "type": "image"}
|
||||
],
|
||||
"hasToGenerateVideos": True,
|
||||
"audioUrl": input_data.background_music.audio_url,
|
||||
"watermark": None,
|
||||
},
|
||||
}
|
||||
|
||||
response = await self.create_video(credentials.api_key, payload)
|
||||
pid = response.get("pid")
|
||||
if not pid:
|
||||
raise RuntimeError("Failed to create video: No project ID returned")
|
||||
|
||||
video_url = await self.wait_for_video(credentials.api_key, pid)
|
||||
yield "video_url", video_url
|
||||
|
||||
1435  autogpt_platform/backend/backend/blocks/airtable/_api.py  (new file)
File diff suppressed because it is too large
323  autogpt_platform/backend/backend/blocks/airtable/_api_test.py  (new file)
@@ -0,0 +1,323 @@
|
||||
from os import getenv
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.sdk import APIKeyCredentials, SecretStr
|
||||
|
||||
from ._api import (
|
||||
TableFieldType,
|
||||
WebhookFilters,
|
||||
WebhookSpecification,
|
||||
create_base,
|
||||
create_field,
|
||||
create_record,
|
||||
create_table,
|
||||
create_webhook,
|
||||
delete_multiple_records,
|
||||
delete_record,
|
||||
delete_webhook,
|
||||
get_record,
|
||||
list_bases,
|
||||
list_records,
|
||||
list_webhook_payloads,
|
||||
update_field,
|
||||
update_multiple_records,
|
||||
update_record,
|
||||
update_table,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_update_table():
|
||||
|
||||
key = getenv("AIRTABLE_API_KEY")
|
||||
if not key:
|
||||
return pytest.skip("AIRTABLE_API_KEY is not set")
|
||||
|
||||
credentials = APIKeyCredentials(
|
||||
provider="airtable",
|
||||
api_key=SecretStr(key),
|
||||
)
|
||||
postfix = uuid4().hex[:4]
|
||||
workspace_id = "wsphuHmfllg7V3Brd"
|
||||
response = await create_base(credentials, workspace_id, "API Testing Base")
|
||||
assert response is not None, f"Checking create base response: {response}"
|
||||
assert (
|
||||
response.get("id") is not None
|
||||
), f"Checking create base response id: {response}"
|
||||
base_id = response.get("id")
|
||||
assert base_id is not None, f"Checking create base response id: {base_id}"
|
||||
|
||||
response = await list_bases(credentials)
|
||||
assert response is not None, f"Checking list bases response: {response}"
|
||||
assert "API Testing Base" in [
|
||||
base.get("name") for base in response.get("bases", [])
|
||||
], f"Checking list bases response bases: {response}"
|
||||
|
||||
table_name = f"test_table_{postfix}"
|
||||
table_fields = [{"name": "test_field", "type": "singleLineText"}]
|
||||
table = await create_table(credentials, base_id, table_name, table_fields)
|
||||
assert table.get("name") == table_name
|
||||
|
||||
table_id = table.get("id")
|
||||
|
||||
assert table_id is not None
|
||||
|
||||
table_name = f"test_table_updated_{postfix}"
|
||||
table_description = "test_description_updated"
|
||||
table = await update_table(
|
||||
credentials,
|
||||
base_id,
|
||||
table_id,
|
||||
table_name=table_name,
|
||||
table_description=table_description,
|
||||
)
|
||||
assert table.get("name") == table_name
|
||||
assert table.get("description") == table_description
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_field_type():
|
||||
|
||||
key = getenv("AIRTABLE_API_KEY")
|
||||
if not key:
|
||||
return pytest.skip("AIRTABLE_API_KEY is not set")
|
||||
|
||||
credentials = APIKeyCredentials(
|
||||
provider="airtable",
|
||||
api_key=SecretStr(key),
|
||||
)
|
||||
postfix = uuid4().hex[:4]
|
||||
base_id = "appZPxegHEU3kDc1S"
|
||||
table_name = f"test_table_{postfix}"
|
||||
table_fields = [{"name": "test_field", "type": "notValid"}]
|
||||
with pytest.raises(AssertionError):
|
||||
await create_table(credentials, base_id, table_name, table_fields)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_and_update_field():
|
||||
key = getenv("AIRTABLE_API_KEY")
|
||||
if not key:
|
||||
return pytest.skip("AIRTABLE_API_KEY is not set")
|
||||
|
||||
credentials = APIKeyCredentials(
|
||||
provider="airtable",
|
||||
api_key=SecretStr(key),
|
||||
)
|
||||
postfix = uuid4().hex[:4]
|
||||
base_id = "appZPxegHEU3kDc1S"
|
||||
table_name = f"test_table_{postfix}"
|
||||
table_fields = [{"name": "test_field", "type": "singleLineText"}]
|
||||
table = await create_table(credentials, base_id, table_name, table_fields)
|
||||
assert table.get("name") == table_name
|
||||
|
||||
table_id = table.get("id")
|
||||
|
||||
assert table_id is not None
|
||||
|
||||
field_name = f"test_field_{postfix}"
|
||||
field_type = TableFieldType.SINGLE_LINE_TEXT
|
||||
field = await create_field(credentials, base_id, table_id, field_type, field_name)
|
||||
assert field.get("name") == field_name
|
||||
|
||||
field_id = field.get("id")
|
||||
|
||||
assert field_id is not None
|
||||
assert isinstance(field_id, str)
|
||||
|
||||
field_name = f"test_field_updated_{postfix}"
|
||||
field = await update_field(credentials, base_id, table_id, field_id, field_name)
|
||||
assert field.get("name") == field_name
|
||||
|
||||
field_description = "test_description_updated"
|
||||
field = await update_field(
|
||||
credentials, base_id, table_id, field_id, description=field_description
|
||||
)
|
||||
assert field.get("description") == field_description
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_management():
|
||||
key = getenv("AIRTABLE_API_KEY")
|
||||
if not key:
|
||||
return pytest.skip("AIRTABLE_API_KEY is not set")
|
||||
|
||||
credentials = APIKeyCredentials(
|
||||
provider="airtable",
|
||||
api_key=SecretStr(key),
|
||||
)
|
||||
postfix = uuid4().hex[:4]
|
||||
base_id = "appZPxegHEU3kDc1S"
|
||||
table_name = f"test_table_{postfix}"
|
||||
table_fields = [{"name": "test_field", "type": "singleLineText"}]
|
||||
table = await create_table(credentials, base_id, table_name, table_fields)
|
||||
assert table.get("name") == table_name
|
||||
|
||||
table_id = table.get("id")
|
||||
assert table_id is not None
|
||||
|
||||
# Create a record
|
||||
record_fields = {"test_field": "test_value"}
|
||||
record = await create_record(credentials, base_id, table_id, fields=record_fields)
|
||||
fields = record.get("fields")
|
||||
assert fields is not None
|
||||
assert isinstance(fields, dict)
|
||||
assert fields.get("test_field") == "test_value"
|
||||
|
||||
record_id = record.get("id")
|
||||
|
||||
assert record_id is not None
|
||||
assert isinstance(record_id, str)
|
||||
|
||||
# Get a record
|
||||
record = await get_record(credentials, base_id, table_id, record_id)
|
||||
fields = record.get("fields")
|
||||
assert fields is not None
|
||||
assert isinstance(fields, dict)
|
||||
assert fields.get("test_field") == "test_value"
|
||||
|
||||
# Update a record
|
||||
record_fields = {"test_field": "test_value_updated"}
|
||||
record = await update_record(
|
||||
credentials, base_id, table_id, record_id, fields=record_fields
|
||||
)
|
||||
fields = record.get("fields")
|
||||
assert fields is not None
|
||||
assert isinstance(fields, dict)
|
||||
assert fields.get("test_field") == "test_value_updated"
|
||||
|
||||
# Delete a record
|
||||
record = await delete_record(credentials, base_id, table_id, record_id)
|
||||
assert record is not None
|
||||
assert record.get("id") == record_id
|
||||
assert record.get("deleted")
|
||||
|
||||
# Create 2 records
|
||||
records = [
|
||||
{"fields": {"test_field": "test_value_1"}},
|
||||
{"fields": {"test_field": "test_value_2"}},
|
||||
]
|
||||
response = await create_record(credentials, base_id, table_id, records=records)
|
||||
created_records = response.get("records")
|
||||
assert created_records is not None
|
||||
assert isinstance(created_records, list)
|
||||
assert len(created_records) == 2, f"Created records: {created_records}"
|
||||
first_record = created_records[0] # type: ignore
|
||||
second_record = created_records[1] # type: ignore
|
||||
first_record_id = first_record.get("id")
|
||||
second_record_id = second_record.get("id")
|
||||
assert first_record_id is not None
|
||||
assert second_record_id is not None
|
||||
assert first_record_id != second_record_id
|
||||
first_fields = first_record.get("fields")
|
||||
second_fields = second_record.get("fields")
|
||||
assert first_fields is not None
|
||||
assert second_fields is not None
|
||||
assert first_fields.get("test_field") == "test_value_1" # type: ignore
|
||||
assert second_fields.get("test_field") == "test_value_2" # type: ignore
|
||||
|
||||
# List records
|
||||
response = await list_records(credentials, base_id, table_id)
|
||||
records = response.get("records")
|
||||
assert records is not None
|
||||
assert len(records) == 2, f"Records: {records}"
|
||||
assert isinstance(records, list), f"Type of records: {type(records)}"
|
||||
|
||||
# Update multiple records
|
||||
records = [
|
||||
{"id": first_record_id, "fields": {"test_field": "test_value_1_updated"}},
|
||||
{"id": second_record_id, "fields": {"test_field": "test_value_2_updated"}},
|
||||
]
|
||||
response = await update_multiple_records(
|
||||
credentials, base_id, table_id, records=records
|
||||
)
|
||||
updated_records = response.get("records")
|
||||
assert updated_records is not None
|
||||
assert len(updated_records) == 2, f"Updated records: {updated_records}"
|
||||
assert isinstance(
|
||||
updated_records, list
|
||||
), f"Type of updated records: {type(updated_records)}"
|
||||
first_updated = updated_records[0] # type: ignore
|
||||
second_updated = updated_records[1] # type: ignore
|
||||
first_updated_fields = first_updated.get("fields")
|
||||
second_updated_fields = second_updated.get("fields")
|
||||
assert first_updated_fields is not None
|
||||
assert second_updated_fields is not None
|
||||
assert first_updated_fields.get("test_field") == "test_value_1_updated" # type: ignore
|
||||
assert second_updated_fields.get("test_field") == "test_value_2_updated" # type: ignore
|
||||
|
||||
# Delete multiple records
|
||||
assert isinstance(first_record_id, str)
|
||||
assert isinstance(second_record_id, str)
|
||||
response = await delete_multiple_records(
|
||||
credentials, base_id, table_id, records=[first_record_id, second_record_id]
|
||||
)
|
||||
deleted_records = response.get("records")
|
||||
assert deleted_records is not None
|
||||
assert len(deleted_records) == 2, f"Deleted records: {deleted_records}"
|
||||
assert isinstance(
|
||||
deleted_records, list
|
||||
), f"Type of deleted records: {type(deleted_records)}"
|
||||
first_deleted = deleted_records[0] # type: ignore
|
||||
second_deleted = deleted_records[1] # type: ignore
|
||||
assert first_deleted.get("deleted")
|
||||
assert second_deleted.get("deleted")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_management():
|
||||
key = getenv("AIRTABLE_API_KEY")
|
||||
if not key:
|
||||
return pytest.skip("AIRTABLE_API_KEY is not set")
|
||||
|
||||
credentials = APIKeyCredentials(
|
||||
provider="airtable",
|
||||
api_key=SecretStr(key),
|
||||
)
|
||||
postfix = uuid4().hex[:4]
|
||||
base_id = "appZPxegHEU3kDc1S"
|
||||
table_name = f"test_table_{postfix}"
|
||||
table_fields = [{"name": "test_field", "type": "singleLineText"}]
|
||||
table = await create_table(credentials, base_id, table_name, table_fields)
|
||||
assert table.get("name") == table_name
|
||||
|
||||
table_id = table.get("id")
|
||||
assert table_id is not None
|
||||
webhook_specification = WebhookSpecification(
|
||||
filters=WebhookFilters(
|
||||
dataTypes=["tableData", "tableFields", "tableMetadata"],
|
||||
changeTypes=["add", "update", "remove"],
|
||||
)
|
||||
)
|
||||
response = await create_webhook(credentials, base_id, webhook_specification)
|
||||
assert response is not None, f"Checking create webhook response: {response}"
|
||||
assert (
|
||||
response.get("id") is not None
|
||||
), f"Checking create webhook response id: {response}"
|
||||
assert (
|
||||
response.get("macSecretBase64") is not None
|
||||
), f"Checking create webhook response macSecretBase64: {response}"
|
||||
|
||||
webhook_id = response.get("id")
|
||||
assert webhook_id is not None, f"Webhook ID: {webhook_id}"
|
||||
assert isinstance(webhook_id, str)
|
||||
|
||||
response = await create_record(
|
||||
credentials, base_id, table_id, fields={"test_field": "test_value"}
|
||||
)
|
||||
assert response is not None, f"Checking create record response: {response}"
|
||||
assert (
|
||||
response.get("id") is not None
|
||||
), f"Checking create record response id: {response}"
|
||||
fields = response.get("fields")
|
||||
assert fields is not None, f"Checking create record response fields: {response}"
|
||||
assert (
|
||||
fields.get("test_field") == "test_value"
|
||||
), f"Checking create record response fields test_field: {response}"
|
||||
|
||||
response = await list_webhook_payloads(credentials, base_id, webhook_id)
|
||||
assert response is not None, f"Checking list webhook payloads response: {response}"
|
||||
|
||||
response = await delete_webhook(credentials, base_id, webhook_id)
|
||||
32  autogpt_platform/backend/backend/blocks/airtable/_config.py  (new file)
@@ -0,0 +1,32 @@
"""
Shared configuration for all Airtable blocks using the SDK pattern.
"""

from backend.sdk import BlockCostType, ProviderBuilder

from ._oauth import AirtableOAuthHandler, AirtableScope
from ._webhook import AirtableWebhookManager

# Configure the Airtable provider with API key authentication
airtable = (
    ProviderBuilder("airtable")
    .with_api_key("AIRTABLE_API_KEY", "Airtable Personal Access Token")
    .with_webhook_manager(AirtableWebhookManager)
    .with_base_cost(1, BlockCostType.RUN)
    .with_oauth(
        AirtableOAuthHandler,
        scopes=[
            v.value
            for v in [
                AirtableScope.DATA_RECORDS_READ,
                AirtableScope.DATA_RECORDS_WRITE,
                AirtableScope.SCHEMA_BASES_READ,
                AirtableScope.SCHEMA_BASES_WRITE,
                AirtableScope.WEBHOOK_MANAGE,
            ]
        ],
        client_id_env_var="AIRTABLE_CLIENT_ID",
        client_secret_env_var="AIRTABLE_CLIENT_SECRET",
    )
    .build()
)
185  autogpt_platform/backend/backend/blocks/airtable/_oauth.py  (new file)
@@ -0,0 +1,185 @@
|
||||
"""
|
||||
Airtable OAuth handler implementation.
|
||||
"""
|
||||
|
||||
import time
|
||||
from enum import Enum
|
||||
from logging import getLogger
|
||||
from typing import Optional
|
||||
|
||||
from backend.sdk import BaseOAuthHandler, OAuth2Credentials, ProviderName, SecretStr
|
||||
|
||||
from ._api import (
|
||||
OAuthTokenResponse,
|
||||
make_oauth_authorize_url,
|
||||
oauth_exchange_code_for_tokens,
|
||||
oauth_refresh_tokens,
|
||||
)
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class AirtableScope(str, Enum):
|
||||
# Basic scopes
|
||||
DATA_RECORDS_READ = "data.records:read"
|
||||
DATA_RECORDS_WRITE = "data.records:write"
|
||||
DATA_RECORD_COMMENTS_READ = "data.recordComments:read"
|
||||
DATA_RECORD_COMMENTS_WRITE = "data.recordComments:write"
|
||||
SCHEMA_BASES_READ = "schema.bases:read"
|
||||
SCHEMA_BASES_WRITE = "schema.bases:write"
|
||||
WEBHOOK_MANAGE = "webhook:manage"
|
||||
BLOCK_MANAGE = "block:manage"
|
||||
USER_EMAIL_READ = "user.email:read"
|
||||
|
||||
# Enterprise member scopes
|
||||
ENTERPRISE_GROUPS_READ = "enterprise.groups:read"
|
||||
WORKSPACES_AND_BASES_READ = "workspacesAndBases:read"
|
||||
WORKSPACES_AND_BASES_WRITE = "workspacesAndBases:write"
|
||||
WORKSPACES_AND_BASES_SHARES_MANAGE = "workspacesAndBases.shares:manage"
|
||||
|
||||
# Enterprise admin scopes
|
||||
ENTERPRISE_SCIM_USERS_AND_GROUPS_MANAGE = "enterprise.scim.usersAndGroups:manage"
|
||||
ENTERPRISE_AUDIT_LOGS_READ = "enterprise.auditLogs:read"
|
||||
ENTERPRISE_CHANGE_EVENTS_READ = "enterprise.changeEvents:read"
|
||||
ENTERPRISE_EXPORTS_MANAGE = "enterprise.exports:manage"
|
||||
ENTERPRISE_ACCOUNT_READ = "enterprise.account:read"
|
||||
ENTERPRISE_ACCOUNT_WRITE = "enterprise.account:write"
|
||||
ENTERPRISE_USER_READ = "enterprise.user:read"
|
||||
ENTERPRISE_USER_WRITE = "enterprise.user:write"
|
||||
ENTERPRISE_GROUPS_MANAGE = "enterprise.groups:manage"
|
||||
WORKSPACES_AND_BASES_MANAGE = "workspacesAndBases:manage"
|
||||
HYPERDB_RECORDS_READ = "hyperDB.records:read"
|
||||
HYPERDB_RECORDS_WRITE = "hyperDB.records:write"
|
||||
|
||||
|
||||
class AirtableOAuthHandler(BaseOAuthHandler):
|
||||
"""
|
||||
OAuth2 handler for Airtable with PKCE support.
|
||||
"""
|
||||
|
||||
PROVIDER_NAME = ProviderName("airtable")
|
||||
DEFAULT_SCOPES = [
|
||||
v.value
|
||||
for v in [
|
||||
AirtableScope.DATA_RECORDS_READ,
|
||||
AirtableScope.DATA_RECORDS_WRITE,
|
||||
AirtableScope.SCHEMA_BASES_READ,
|
||||
AirtableScope.SCHEMA_BASES_WRITE,
|
||||
AirtableScope.WEBHOOK_MANAGE,
|
||||
]
|
||||
]
|
||||
|
||||
def __init__(self, client_id: str, client_secret: Optional[str], redirect_uri: str):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
self.scopes = self.DEFAULT_SCOPES
|
||||
self.auth_base_url = "https://airtable.com/oauth2/v1/authorize"
|
||||
self.token_url = "https://airtable.com/oauth2/v1/token"
|
||||
|
||||
def get_login_url(
|
||||
self, scopes: list[str], state: str, code_challenge: Optional[str]
|
||||
) -> str:
|
||||
logger.debug("Generating Airtable OAuth login URL")
|
||||
# Generate code_challenge if not provided (PKCE is required)
|
||||
if not scopes:
|
||||
logger.debug("No scopes provided, using default scopes")
|
||||
scopes = self.scopes
|
||||
|
||||
logger.debug(f"Using scopes: {scopes}")
|
||||
logger.debug(f"State: {state}")
|
||||
logger.debug(f"Code challenge: {code_challenge}")
|
||||
if not code_challenge:
|
||||
logger.error("Code challenge is required but none was provided")
|
||||
raise ValueError("No code challenge provided")
|
||||
|
||||
try:
|
||||
url = make_oauth_authorize_url(
|
||||
self.client_id, self.redirect_uri, scopes, state, code_challenge
|
||||
)
|
||||
logger.debug(f"Generated OAuth URL: {url}")
|
||||
return url
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate OAuth URL: {str(e)}")
|
||||
raise
|
||||
|
||||
async def exchange_code_for_tokens(
|
||||
self, code: str, scopes: list[str], code_verifier: Optional[str]
|
||||
) -> OAuth2Credentials:
|
||||
logger.debug("Exchanging authorization code for tokens")
|
||||
logger.debug(f"Code: {code[:4]}...") # Log first 4 chars only for security
|
||||
logger.debug(f"Scopes: {scopes}")
|
||||
if not code_verifier:
|
||||
logger.error("Code verifier is required but none was provided")
|
||||
raise ValueError("No code verifier provided")
|
||||
|
||||
try:
|
||||
response: OAuthTokenResponse = await oauth_exchange_code_for_tokens(
|
||||
client_id=self.client_id,
|
||||
code=code,
|
||||
code_verifier=code_verifier.encode("utf-8"),
|
||||
redirect_uri=self.redirect_uri,
|
||||
client_secret=self.client_secret,
|
||||
)
|
||||
logger.info("Successfully exchanged code for tokens")
|
||||
|
||||
credentials = OAuth2Credentials(
|
||||
access_token=SecretStr(response.access_token),
|
||||
refresh_token=SecretStr(response.refresh_token),
|
||||
access_token_expires_at=int(time.time()) + response.expires_in,
|
||||
refresh_token_expires_at=int(time.time()) + response.refresh_expires_in,
|
||||
provider=self.PROVIDER_NAME,
|
||||
scopes=scopes,
|
||||
)
|
||||
logger.debug(f"Access token expires in {response.expires_in} seconds")
|
||||
logger.debug(
|
||||
f"Refresh token expires in {response.refresh_expires_in} seconds"
|
||||
)
|
||||
return credentials
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to exchange code for tokens: {str(e)}")
|
||||
raise
|
||||
|
||||
async def _refresh_tokens(
|
||||
self, credentials: OAuth2Credentials
|
||||
) -> OAuth2Credentials:
|
||||
logger.debug("Attempting to refresh OAuth tokens")
|
||||
|
||||
if credentials.refresh_token is None:
|
||||
logger.error("Cannot refresh tokens - no refresh token available")
|
||||
raise ValueError("No refresh token available")
|
||||
|
||||
try:
|
||||
response: OAuthTokenResponse = await oauth_refresh_tokens(
|
||||
client_id=self.client_id,
|
||||
refresh_token=credentials.refresh_token.get_secret_value(),
|
||||
client_secret=self.client_secret,
|
||||
)
|
||||
logger.info("Successfully refreshed tokens")
|
||||
|
||||
new_credentials = OAuth2Credentials(
|
||||
id=credentials.id,
|
||||
access_token=SecretStr(response.access_token),
|
||||
refresh_token=SecretStr(response.refresh_token),
|
||||
access_token_expires_at=int(time.time()) + response.expires_in,
|
||||
refresh_token_expires_at=int(time.time()) + response.refresh_expires_in,
|
||||
provider=self.PROVIDER_NAME,
|
||||
scopes=self.scopes,
|
||||
)
|
||||
logger.debug(f"New access token expires in {response.expires_in} seconds")
|
||||
logger.debug(
|
||||
f"New refresh token expires in {response.refresh_expires_in} seconds"
|
||||
)
|
||||
return new_credentials
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to refresh tokens: {str(e)}")
|
||||
raise
|
||||
|
||||
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
|
||||
logger.debug("Token revocation requested")
|
||||
logger.info(
|
||||
"Airtable doesn't provide a token revocation endpoint - tokens will expire naturally after 60 minutes"
|
||||
)
|
||||
return False
|
||||
154  autogpt_platform/backend/backend/blocks/airtable/_webhook.py  (new file)
@@ -0,0 +1,154 @@
"""
Webhook management for Airtable blocks.
"""

import hashlib
import hmac
import logging
from enum import Enum

from backend.sdk import (
    BaseWebhooksManager,
    Credentials,
    ProviderName,
    Webhook,
    update_webhook,
)

from ._api import (
    WebhookFilters,
    WebhookSpecification,
    create_webhook,
    delete_webhook,
    list_webhook_payloads,
)

logger = logging.getLogger(__name__)


class AirtableWebhookEvent(str, Enum):
    TABLE_DATA = "tableData"
    TABLE_FIELDS = "tableFields"
    TABLE_METADATA = "tableMetadata"


class AirtableWebhookManager(BaseWebhooksManager):
    """Webhook manager for Airtable API."""

    PROVIDER_NAME = ProviderName("airtable")

    @classmethod
    async def validate_payload(
        cls, webhook: Webhook, request, credentials: Credentials | None
    ) -> tuple[dict, str]:
        """Validate incoming webhook payload and signature."""

        if not credentials:
            raise ValueError("Missing credentials in webhook metadata")

        payload = await request.json()

        # Verify webhook signature using HMAC-SHA256
        if webhook.secret:
            mac_secret = webhook.config.get("mac_secret")
            if mac_secret:
                # Get the raw body for signature verification
                body = await request.body()

                # Calculate expected signature
                mac_secret_decoded = mac_secret.encode()
                hmac_obj = hmac.new(mac_secret_decoded, body, hashlib.sha256)
                expected_mac = f"hmac-sha256={hmac_obj.hexdigest()}"

                # Get signature from headers
                signature = request.headers.get("X-Airtable-Content-MAC")

                if signature and not hmac.compare_digest(signature, expected_mac):
                    raise ValueError("Invalid webhook signature")

        # Validate payload structure
        required_fields = ["base", "webhook", "timestamp"]
        if not all(field in payload for field in required_fields):
            raise ValueError("Invalid webhook payload structure")

        if "id" not in payload["base"] or "id" not in payload["webhook"]:
            raise ValueError("Missing required IDs in webhook payload")
        base_id = payload["base"]["id"]
        webhook_id = payload["webhook"]["id"]

        # get payload request parameters
        cursor = webhook.config.get("cursor", 1)

        response = await list_webhook_payloads(credentials, base_id, webhook_id, cursor)

        # update webhook config
        await update_webhook(
            webhook.id,
            config={"base_id": base_id, "cursor": response.cursor},
        )

        event_type = "notification"
        return response.model_dump(), event_type

    async def _register_webhook(
        self,
        credentials: Credentials,
        webhook_type: str,
        resource: str,
        events: list[str],
        ingress_url: str,
        secret: str,
    ) -> tuple[str, dict]:
        """Register webhook with Airtable API."""

        # Parse resource to get base_id and table_id/name
        # Resource format: "{base_id}/{table_id_or_name}"
        parts = resource.split("/", 1)
        if len(parts) != 2:
            raise ValueError("Resource must be in format: {base_id}/{table_id_or_name}")

        base_id, table_id_or_name = parts

        # Prepare webhook specification
        webhook_specification = WebhookSpecification(
            filters=WebhookFilters(
                dataTypes=events,
            )
        )

        # Create webhook
        webhook_data = await create_webhook(
            credentials=credentials,
            base_id=base_id,
            webhook_specification=webhook_specification,
            notification_url=ingress_url,
        )

        webhook_id = webhook_data["id"]
        mac_secret = webhook_data.get("macSecretBase64")

        return webhook_id, {
            "webhook_id": webhook_id,
            "base_id": base_id,
            "table_id_or_name": table_id_or_name,
            "events": events,
            "mac_secret": mac_secret,
            "cursor": 1,
            "expiration_time": webhook_data.get("expirationTime"),
        }

    async def _deregister_webhook(
        self, webhook: Webhook, credentials: Credentials
    ) -> None:
        """Deregister webhook from Airtable API."""

        base_id = webhook.config.get("base_id")
        webhook_id = webhook.config.get("webhook_id")

        if not base_id:
            raise ValueError("Missing base_id in webhook metadata")

        if not webhook_id:
            raise ValueError("Missing webhook_id in webhook metadata")

        await delete_webhook(credentials, base_id, webhook_id)
157  autogpt_platform/backend/backend/blocks/airtable/bases.py  (new file)
@@ -0,0 +1,157 @@
|
||||
"""
|
||||
Airtable base operation blocks.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._api import create_base, get_base_tables, list_bases
|
||||
from ._config import airtable
|
||||
|
||||
|
||||
class AirtableCreateBaseBlock(Block):
|
||||
"""
|
||||
Creates a new base in an Airtable workspace, or returns existing base if one with the same name exists.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
workspace_id: str = SchemaField(
|
||||
description="The workspace ID where the base will be created"
|
||||
)
|
||||
name: str = SchemaField(description="The name of the new base")
|
||||
find_existing: bool = SchemaField(
|
||||
description="If true, return existing base with same name instead of creating duplicate",
|
||||
default=True,
|
||||
)
|
||||
tables: list[dict] = SchemaField(
|
||||
description="At least one table and field must be specified. Array of table objects to create in the base. Each table should have 'name' and 'fields' properties",
|
||||
default=[
|
||||
{
|
||||
"description": "Default table",
|
||||
"name": "Default table",
|
||||
"fields": [
|
||||
{
|
||||
"name": "ID",
|
||||
"type": "number",
|
||||
"description": "Auto-incrementing ID field",
|
||||
"options": {"precision": 0},
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
base_id: str = SchemaField(description="The ID of the created or found base")
|
||||
tables: list[dict] = SchemaField(description="Array of table objects")
|
||||
table: dict = SchemaField(description="A single table object")
|
||||
was_created: bool = SchemaField(
|
||||
description="True if a new base was created, False if existing was found",
|
||||
default=True,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f59b88a8-54ce-4676-a508-fd614b4e8dce",
|
||||
description="Create or find a base in Airtable",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# If find_existing is true, check if a base with this name already exists
|
||||
if input_data.find_existing:
|
||||
# List all bases to check for existing one with same name
|
||||
# Note: Airtable API doesn't have a direct search, so we need to list and filter
|
||||
existing_bases = await list_bases(credentials)
|
||||
|
||||
for base in existing_bases.get("bases", []):
|
||||
if base.get("name") == input_data.name:
|
||||
# Base already exists, return it
|
||||
base_id = base.get("id")
|
||||
yield "base_id", base_id
|
||||
yield "was_created", False
|
||||
|
||||
# Get the tables for this base
|
||||
try:
|
||||
tables = await get_base_tables(credentials, base_id)
|
||||
yield "tables", tables
|
||||
for table in tables:
|
||||
yield "table", table
|
||||
except Exception:
|
||||
# If we can't get tables, return empty list
|
||||
yield "tables", []
|
||||
return
|
||||
|
||||
# No existing base found or find_existing is false, create new one
|
||||
data = await create_base(
|
||||
credentials,
|
||||
input_data.workspace_id,
|
||||
input_data.name,
|
||||
input_data.tables,
|
||||
)
|
||||
|
||||
yield "base_id", data.get("id", None)
|
||||
yield "was_created", True
|
||||
yield "tables", data.get("tables", [])
|
||||
for table in data.get("tables", []):
|
||||
yield "table", table
|
||||
|
||||
|
||||
class AirtableListBasesBlock(Block):
|
||||
"""
|
||||
Lists all bases in an Airtable workspace that the user has access to.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
trigger: str = SchemaField(
|
||||
description="Trigger the block to run - value is ignored", default="manual"
|
||||
)
|
||||
offset: str = SchemaField(
|
||||
description="Pagination offset from previous request", default=""
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
bases: list[dict] = SchemaField(description="Array of base objects")
|
||||
offset: Optional[str] = SchemaField(
|
||||
description="Offset for next page (null if no more bases)", default=None
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="4bd8d466-ed5d-4e44-8083-97f25a8044e7",
|
||||
description="List all bases in Airtable",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
data = await list_bases(
|
||||
credentials,
|
||||
offset=input_data.offset if input_data.offset else None,
|
||||
)
|
||||
|
||||
yield "bases", data.get("bases", [])
|
||||
yield "offset", data.get("offset", None)
|
||||
380  autogpt_platform/backend/backend/blocks/airtable/records.py  (new file)
@@ -0,0 +1,380 @@
|
||||
"""
|
||||
Airtable record operation blocks.
|
||||
"""
|
||||
|
||||
from typing import Optional, cast
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._api import (
|
||||
create_record,
|
||||
delete_multiple_records,
|
||||
get_record,
|
||||
get_table_schema,
|
||||
list_records,
|
||||
normalize_records,
|
||||
update_multiple_records,
|
||||
)
|
||||
from ._config import airtable
|
||||
|
||||
|
||||
class AirtableListRecordsBlock(Block):
|
||||
"""
|
||||
Lists records from an Airtable table with optional filtering, sorting, and pagination.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
table_id_or_name: str = SchemaField(description="Table ID or name")
|
||||
filter_formula: str = SchemaField(
|
||||
description="Airtable formula to filter records", default=""
|
||||
)
|
||||
view: str = SchemaField(description="View ID or name to use", default="")
|
||||
sort: list[dict] = SchemaField(
|
||||
description="Sort configuration (array of {field, direction})", default=[]
|
||||
)
|
||||
max_records: int = SchemaField(
|
||||
description="Maximum number of records to return", default=100
|
||||
)
|
||||
page_size: int = SchemaField(
|
||||
description="Number of records per page (max 100)", default=100
|
||||
)
|
||||
offset: str = SchemaField(
|
||||
description="Pagination offset from previous request", default=""
|
||||
)
|
||||
return_fields: list[str] = SchemaField(
|
||||
description="Specific fields to return (comma-separated)", default=[]
|
||||
)
|
||||
normalize_output: bool = SchemaField(
|
||||
description="Normalize output to include all fields with proper empty values (disable to skip schema fetch and get raw Airtable response)",
|
||||
default=True,
|
||||
)
|
||||
include_field_metadata: bool = SchemaField(
|
||||
description="Include field type and configuration metadata (requires normalize_output=true)",
|
||||
default=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
records: list[dict] = SchemaField(description="Array of record objects")
|
||||
offset: Optional[str] = SchemaField(
|
||||
description="Offset for next page (null if no more records)", default=None
|
||||
)
|
||||
field_metadata: Optional[dict] = SchemaField(
|
||||
description="Field type and configuration metadata (only when include_field_metadata=true)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="588a9fde-5733-4da7-b03c-35f5671e960f",
|
||||
description="List records from an Airtable table",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
data = await list_records(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
input_data.table_id_or_name,
|
||||
filter_by_formula=(
|
||||
input_data.filter_formula if input_data.filter_formula else None
|
||||
),
|
||||
view=input_data.view if input_data.view else None,
|
||||
sort=input_data.sort if input_data.sort else None,
|
||||
max_records=input_data.max_records if input_data.max_records else None,
|
||||
page_size=min(input_data.page_size, 100) if input_data.page_size else None,
|
||||
offset=input_data.offset if input_data.offset else None,
|
||||
fields=input_data.return_fields if input_data.return_fields else None,
|
||||
)
|
||||
|
||||
records = data.get("records", [])
|
||||
|
||||
# Normalize output if requested
|
||||
if input_data.normalize_output:
|
||||
# Fetch table schema
|
||||
table_schema = await get_table_schema(
|
||||
credentials, input_data.base_id, input_data.table_id_or_name
|
||||
)
|
||||
|
||||
# Normalize the records
|
||||
normalized_data = await normalize_records(
|
||||
records,
|
||||
table_schema,
|
||||
include_field_metadata=input_data.include_field_metadata,
|
||||
)
|
||||
|
||||
yield "records", normalized_data["records"]
|
||||
yield "offset", data.get("offset", None)
|
||||
|
||||
if (
|
||||
input_data.include_field_metadata
|
||||
and "field_metadata" in normalized_data
|
||||
):
|
||||
yield "field_metadata", normalized_data["field_metadata"]
|
||||
else:
|
||||
yield "records", records
|
||||
yield "offset", data.get("offset", None)
|
||||
|
||||
|
||||
class AirtableGetRecordBlock(Block):
|
||||
"""
|
||||
Retrieves a single record from an Airtable table by its ID.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
table_id_or_name: str = SchemaField(description="Table ID or name")
|
||||
record_id: str = SchemaField(description="The record ID to retrieve")
|
||||
normalize_output: bool = SchemaField(
|
||||
description="Normalize output to include all fields with proper empty values (disable to skip schema fetch and get raw Airtable response)",
|
||||
default=True,
|
||||
)
|
||||
include_field_metadata: bool = SchemaField(
|
||||
description="Include field type and configuration metadata (requires normalize_output=true)",
|
||||
default=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
id: str = SchemaField(description="The record ID")
|
||||
fields: dict = SchemaField(description="The record fields")
|
||||
created_time: str = SchemaField(description="The record created time")
|
||||
field_metadata: Optional[dict] = SchemaField(
|
||||
description="Field type and configuration metadata (only when include_field_metadata=true)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c29c5cbf-0aff-40f9-bbb5-f26061792d2b",
|
||||
description="Get a single record from Airtable",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
record = await get_record(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
input_data.table_id_or_name,
|
||||
input_data.record_id,
|
||||
)
|
||||
|
||||
# Normalize output if requested
|
||||
if input_data.normalize_output:
|
||||
# Fetch table schema
|
||||
table_schema = await get_table_schema(
|
||||
credentials, input_data.base_id, input_data.table_id_or_name
|
||||
)
|
||||
|
||||
# Normalize the single record (wrap in list and unwrap result)
|
||||
normalized_data = await normalize_records(
|
||||
[record],
|
||||
table_schema,
|
||||
include_field_metadata=input_data.include_field_metadata,
|
||||
)
|
||||
|
||||
normalized_record = normalized_data["records"][0]
|
||||
yield "id", normalized_record.get("id", None)
|
||||
yield "fields", normalized_record.get("fields", None)
|
||||
yield "created_time", normalized_record.get("createdTime", None)
|
||||
|
||||
if (
|
||||
input_data.include_field_metadata
|
||||
and "field_metadata" in normalized_data
|
||||
):
|
||||
yield "field_metadata", normalized_data["field_metadata"]
|
||||
else:
|
||||
yield "id", record.get("id", None)
|
||||
yield "fields", record.get("fields", None)
|
||||
yield "created_time", record.get("createdTime", None)
|
||||
|
||||
|
||||
class AirtableCreateRecordsBlock(Block):
|
||||
"""
|
||||
Creates one or more records in an Airtable table.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
table_id_or_name: str = SchemaField(description="Table ID or name")
|
||||
records: list[dict] = SchemaField(
|
||||
description="Array of records to create (each with 'fields' object)"
|
||||
)
|
||||
skip_normalization: bool = SchemaField(
|
||||
description="Skip output normalization to get raw Airtable response (faster but may have missing fields)",
|
||||
default=False,
|
||||
)
|
||||
typecast: bool = SchemaField(
|
||||
description="Automatically convert string values to appropriate types",
|
||||
default=False,
|
||||
)
|
||||
return_fields_by_field_id: bool | None = SchemaField(
|
||||
description="Return fields by field ID",
|
||||
default=None,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
records: list[dict] = SchemaField(description="Array of created record objects")
|
||||
details: dict = SchemaField(description="Details of the created records")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="42527e98-47b6-44ce-ac0e-86b4883721d3",
|
||||
description="Create records in an Airtable table",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
data = await create_record(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
input_data.table_id_or_name,
|
||||
records=[{"fields": record} for record in input_data.records],
|
||||
typecast=input_data.typecast if input_data.typecast else None,
|
||||
return_fields_by_field_id=input_data.return_fields_by_field_id,
|
||||
)
|
||||
result_records = cast(list[dict], data.get("records", []))
|
||||
|
||||
# Normalize output unless explicitly disabled
|
||||
if not input_data.skip_normalization and result_records:
|
||||
# Fetch table schema
|
||||
table_schema = await get_table_schema(
|
||||
credentials, input_data.base_id, input_data.table_id_or_name
|
||||
)
|
||||
|
||||
# Normalize the records
|
||||
normalized_data = await normalize_records(
|
||||
result_records, table_schema, include_field_metadata=False
|
||||
)
|
||||
result_records = normalized_data["records"]
|
||||
|
||||
yield "records", result_records
|
||||
details = data.get("details", None)
|
||||
if details:
|
||||
yield "details", details
|
||||
|
||||
|
||||
class AirtableUpdateRecordsBlock(Block):
|
||||
"""
|
||||
Updates one or more existing records in an Airtable table.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
table_id_or_name: str = SchemaField(
|
||||
description="Table ID or name - It's better to use the table ID instead of the name"
|
||||
)
|
||||
records: list[dict] = SchemaField(
|
||||
description="Array of records to update (each with 'id' and 'fields')"
|
||||
)
|
||||
typecast: bool | None = SchemaField(
|
||||
description="Automatically convert string values to appropriate types",
|
||||
default=None,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
records: list[dict] = SchemaField(description="Array of updated record objects")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="6e7d2590-ac2b-4b5d-b08c-fc039cd77e1f",
|
||||
description="Update records in an Airtable table",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# The update_multiple_records API expects records with id and fields
|
||||
data = await update_multiple_records(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
input_data.table_id_or_name,
|
||||
records=input_data.records,
|
||||
typecast=input_data.typecast if input_data.typecast else None,
|
||||
return_fields_by_field_id=False, # Use field names, not IDs
|
||||
)
|
||||
|
||||
yield "records", data.get("records", [])
|
||||
|
||||
|
||||
class AirtableDeleteRecordsBlock(Block):
|
||||
"""
|
||||
Deletes one or more records from an Airtable table.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
table_id_or_name: str = SchemaField(
|
||||
description="Table ID or name - It's better to use the table ID instead of the name"
|
||||
)
|
||||
record_ids: list[str] = SchemaField(
|
||||
description="Array of upto 10 record IDs to delete"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
records: list[dict] = SchemaField(description="Array of deletion results")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="93e22b8b-3642-4477-aefb-1c0929a4a3a6",
|
||||
description="Delete records from an Airtable table",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
if len(input_data.record_ids) > 10:
|
||||
yield "error", "Only upto 10 record IDs can be deleted at a time"
|
||||
else:
|
||||
data = await delete_multiple_records(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
input_data.table_id_or_name,
|
||||
input_data.record_ids,
|
||||
)
|
||||
|
||||
yield "records", data.get("records", [])
|
||||
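Because the delete block rejects anything above 10 record IDs, larger clean-ups have to be chunked upstream. A minimal, hypothetical sketch of how a caller could pre-chunk an ID list before wiring it into AirtableDeleteRecordsBlock (the helper name is illustrative and not part of this PR):

```python
# Illustrative only: split an arbitrary ID list into batches of at most 10,
# matching the limit enforced by AirtableDeleteRecordsBlock above.
def chunk_record_ids(record_ids: list[str], batch_size: int = 10) -> list[list[str]]:
    return [record_ids[i : i + batch_size] for i in range(0, len(record_ids), batch_size)]

# chunk_record_ids([f"rec{i}" for i in range(25)]) -> batches of 10, 10 and 5
```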
253 autogpt_platform/backend/backend/blocks/airtable/schema.py Normal file
@@ -0,0 +1,253 @@
|
||||
"""
|
||||
Airtable schema and table management blocks.
|
||||
"""
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._api import TableFieldType, create_field, create_table, update_field, update_table
|
||||
from ._config import airtable
|
||||
|
||||
|
||||
class AirtableListSchemaBlock(Block):
|
||||
"""
|
||||
Retrieves the complete schema of an Airtable base, including all tables,
|
||||
fields, and views.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
base_schema: dict = SchemaField(
|
||||
description="Complete base schema with tables, fields, and views"
|
||||
)
|
||||
tables: list[dict] = SchemaField(description="Array of table objects")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="64291d3c-99b5-47b7-a976-6d94293cdb2d",
|
||||
description="Get the complete schema of an Airtable base",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Get base schema
|
||||
response = await Requests().get(
|
||||
f"https://api.airtable.com/v0/meta/bases/{input_data.base_id}/tables",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
yield "base_schema", data
|
||||
yield "tables", data.get("tables", [])
|
||||
|
||||
|
||||
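The schema block forwards Airtable's base-schema response verbatim and also splits out its tables array. A hedged sketch of pulling field names out of that output downstream; the response shape follows Airtable's metadata API, and only the "tables" key is relied on by the block itself:

```python
# base_schema is the dict yielded above; each table entry carries its own "fields" list.
base_schema = {
    "tables": [
        {"id": "tbl123", "name": "Tasks", "fields": [{"id": "fld1", "name": "Name", "type": "singleLineText"}]}
    ]
}
field_names_by_table = {
    table["name"]: [field["name"] for field in table.get("fields", [])]
    for table in base_schema.get("tables", [])
}
# -> {"Tasks": ["Name"]}
```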
class AirtableCreateTableBlock(Block):
|
||||
"""
|
||||
Creates a new table in an Airtable base with specified fields and views.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
table_name: str = SchemaField(description="The name of the table to create")
|
||||
table_fields: list[dict] = SchemaField(
|
||||
description="Table fields with name, type, and options",
|
||||
default=[{"name": "Name", "type": "singleLineText"}],
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
table: dict = SchemaField(description="Created table object")
|
||||
table_id: str = SchemaField(description="ID of the created table")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="fcc20ced-d817-42ea-9b40-c35e7bf34b4f",
|
||||
description="Create a new table in an Airtable base",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
table_data = await create_table(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
input_data.table_name,
|
||||
input_data.table_fields,
|
||||
)
|
||||
|
||||
yield "table", table_data
|
||||
yield "table_id", table_data.get("id", "")
|
||||
|
||||
|
||||
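For reference, the table_fields input is a list of Airtable field definitions. The block's default only uses singleLineText; the richer value below is a hedged sketch whose extra field types and option keys follow Airtable's public field model and are not defined anywhere in this PR:

```python
# Hypothetical table_fields value for AirtableCreateTableBlock; only
# "singleLineText" comes from the block's default, the rest is illustrative.
table_fields = [
    {"name": "Name", "type": "singleLineText"},
    {"name": "Notes", "type": "multilineText"},
    {"name": "Done", "type": "checkbox", "options": {"icon": "check", "color": "greenBright"}},
]
```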
class AirtableUpdateTableBlock(Block):
|
||||
"""
|
||||
Updates an existing table's properties such as name or description.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
table_id: str = SchemaField(description="The table ID to update")
|
||||
table_name: str | None = SchemaField(
|
||||
description="The name of the table to update", default=None
|
||||
)
|
||||
table_description: str | None = SchemaField(
|
||||
description="The description of the table to update", default=None
|
||||
)
|
||||
date_dependency: dict | None = SchemaField(
|
||||
description="The date dependency of the table to update", default=None
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
table: dict = SchemaField(description="Updated table object")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="34077c5f-f962-49f2-9ec6-97c67077013a",
|
||||
description="Update table properties",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
table_data = await update_table(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
input_data.table_id,
|
||||
input_data.table_name,
|
||||
input_data.table_description,
|
||||
input_data.date_dependency,
|
||||
)
|
||||
|
||||
yield "table", table_data
|
||||
|
||||
|
||||
class AirtableCreateFieldBlock(Block):
|
||||
"""
|
||||
Adds a new field (column) to an existing Airtable table.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
table_id: str = SchemaField(description="The table ID to add field to")
|
||||
field_type: TableFieldType = SchemaField(
|
||||
description="The type of the field to create",
|
||||
default=TableFieldType.SINGLE_LINE_TEXT,
|
||||
advanced=False,
|
||||
)
|
||||
name: str = SchemaField(description="The name of the field to create")
|
||||
description: str | None = SchemaField(
|
||||
description="The description of the field to create", default=None
|
||||
)
|
||||
options: dict[str, str] | None = SchemaField(
|
||||
description="The options of the field to create", default=None
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
field: dict = SchemaField(description="Created field object")
|
||||
field_id: str = SchemaField(description="ID of the created field")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="6c98a32f-dbf9-45d8-a2a8-5e97e8326351",
|
||||
description="Add a new field to an Airtable table",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
field_data = await create_field(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
input_data.table_id,
|
||||
input_data.field_type,
|
||||
input_data.name,
|
||||
)
|
||||
|
||||
yield "field", field_data
|
||||
yield "field_id", field_data.get("id", "")
|
||||
|
||||
|
||||
class AirtableUpdateFieldBlock(Block):
|
||||
"""
|
||||
Updates an existing field's properties in an Airtable table.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
table_id: str = SchemaField(description="The table ID containing the field")
|
||||
field_id: str = SchemaField(description="The field ID to update")
|
||||
name: str | None = SchemaField(
|
||||
description="The name of the field to update", default=None, advanced=False
|
||||
)
|
||||
description: str | None = SchemaField(
|
||||
description="The description of the field to update",
|
||||
default=None,
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
field: dict = SchemaField(description="Updated field object")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f46ac716-3b18-4da1-92e4-34ca9a464d48",
|
||||
description="Update field properties in an Airtable table",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
field_data = await update_field(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
input_data.table_id,
|
||||
input_data.field_id,
|
||||
input_data.name,
|
||||
input_data.description,
|
||||
)
|
||||
|
||||
yield "field", field_data
|
||||
114 autogpt_platform/backend/backend/blocks/airtable/triggers.py Normal file
@@ -0,0 +1,114 @@
|
||||
from backend.sdk import (
|
||||
BaseModel,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
BlockWebhookConfig,
|
||||
CredentialsMetaInput,
|
||||
ProviderName,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._api import WebhookPayload
|
||||
from ._config import airtable
|
||||
|
||||
|
||||
class AirtableEventSelector(BaseModel):
|
||||
"""
|
||||
Selects the Airtable webhook event to trigger on.
|
||||
"""
|
||||
|
||||
tableData: bool = True
|
||||
tableFields: bool = True
|
||||
tableMetadata: bool = True
|
||||
|
||||
|
||||
class AirtableWebhookTriggerBlock(Block):
|
||||
"""
|
||||
Starts a flow whenever Airtable emits a webhook event.
|
||||
|
||||
A thin wrapper that forwards the payloads one at a time to the next block.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="Airtable base ID")
|
||||
table_id_or_name: str = SchemaField(description="Airtable table ID or name")
|
||||
payload: dict = SchemaField(hidden=True, default_factory=dict)
|
||||
events: AirtableEventSelector = SchemaField(
|
||||
description="Airtable webhook event filter"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
payload: WebhookPayload = SchemaField(description="Airtable webhook payload")
|
||||
|
||||
def __init__(self):
|
||||
example_payload = {
|
||||
"payloads": [
|
||||
{
|
||||
"timestamp": "2022-02-01T21:25:05.663Z",
|
||||
"baseTransactionNumber": 4,
|
||||
"actionMetadata": {
|
||||
"source": "client",
|
||||
"sourceMetadata": {
|
||||
"user": {
|
||||
"id": "usr00000000000000",
|
||||
"email": "foo@bar.com",
|
||||
"permissionLevel": "create",
|
||||
}
|
||||
},
|
||||
},
|
||||
"payloadFormat": "v0",
|
||||
}
|
||||
],
|
||||
"cursor": 5,
|
||||
"mightHaveMore": False,
|
||||
}
|
||||
|
||||
super().__init__(
|
||||
# NOTE: Set `disabled=True` here to switch this block off while the webhook system is being finalised.
|
||||
disabled=False,
|
||||
id="d0180ce6-ccb9-48c7-8256-b39e93e62801",
|
||||
description="Starts a flow whenever Airtable emits a webhook event",
|
||||
categories={BlockCategory.INPUT, BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
block_type=BlockType.WEBHOOK,
|
||||
webhook_config=BlockWebhookConfig(
|
||||
provider=ProviderName("airtable"),
|
||||
webhook_type="not-used",
|
||||
event_filter_input="events",
|
||||
event_format="{event}",
|
||||
resource_format="{base_id}/{table_id_or_name}",
|
||||
),
|
||||
test_input={
|
||||
"credentials": airtable.get_test_credentials().model_dump(),
|
||||
"base_id": "app1234567890",
|
||||
"table_id_or_name": "table1234567890",
|
||||
"events": AirtableEventSelector(
|
||||
tableData=True,
|
||||
tableFields=True,
|
||||
tableMetadata=False,
|
||||
).model_dump(),
|
||||
"payload": example_payload,
|
||||
},
|
||||
test_credentials=airtable.get_test_credentials(),
|
||||
test_output=[
|
||||
(
|
||||
"payload",
|
||||
WebhookPayload.model_validate(example_payload["payloads"][0]),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
if len(input_data.payload.get("payloads", [])) > 0:
|
||||
for item in input_data.payload["payloads"]:
|
||||
yield "payload", WebhookPayload.model_validate(item)
|
||||
else:
|
||||
yield "error", "No valid payloads found in webhook payload"
|
||||
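The trigger's run method pushes each entry of payload["payloads"] through WebhookPayload.model_validate, so a malformed entry raises instead of being passed along silently. A hedged standalone sketch of that validation step, with a try/except the block itself does not have:

```python
from pydantic import ValidationError

from backend.blocks.airtable._api import WebhookPayload  # model added in this PR


def extract_payloads(raw_body: dict) -> list[WebhookPayload]:
    """Validate each webhook payload the same way the trigger block does."""
    validated = []
    for item in raw_body.get("payloads", []):
        try:
            # Same call as the block's run(); the block lets a failure raise.
            validated.append(WebhookPayload.model_validate(item))
        except ValidationError as exc:
            print(f"skipping malformed payload: {exc}")
    return validated
```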
@@ -4,6 +4,7 @@ from typing import List
|
||||
from backend.blocks.apollo._auth import ApolloCredentials
|
||||
from backend.blocks.apollo.models import (
|
||||
Contact,
|
||||
EnrichPersonRequest,
|
||||
Organization,
|
||||
SearchOrganizationsRequest,
|
||||
SearchOrganizationsResponse,
|
||||
@@ -110,3 +111,21 @@ class ApolloClient:
|
||||
return (
|
||||
organizations[: query.max_results] if query.max_results else organizations
|
||||
)
|
||||
|
||||
async def enrich_person(self, query: EnrichPersonRequest) -> Contact:
|
||||
"""Enrich a person's data including email & phone reveal"""
|
||||
response = await self.requests.post(
|
||||
f"{self.API_URL}/people/match",
|
||||
headers=self._get_headers(),
|
||||
json=query.model_dump(),
|
||||
params={
|
||||
"reveal_personal_emails": "true",
|
||||
},
|
||||
)
|
||||
data = response.json()
|
||||
if "person" not in data:
|
||||
raise ValueError(f"Person not found or enrichment failed: {data}")
|
||||
|
||||
contact = Contact(**data["person"])
|
||||
contact.email = contact.email or "-"
|
||||
return contact
|
||||
|
||||
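A hedged usage sketch for the new enrich_person method; credentials stands in for an ApolloCredentials instance supplied by the block framework, and the field values are placeholders:

```python
import asyncio

from backend.blocks.apollo._api import ApolloClient
from backend.blocks.apollo.models import EnrichPersonRequest


async def demo(credentials) -> None:
    # Match by name + employer and reveal the email; enrich_person raises
    # ValueError when Apollo finds no match.
    query = EnrichPersonRequest(first_name="John", last_name="Doe", company="Google")
    contact = await ApolloClient(credentials).enrich_person(query)
    print(contact.email)  # "-" when Apollo matches the person but has no email


# asyncio.run(demo(credentials))  # credentials: an ApolloCredentials instance
```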
@@ -23,9 +23,9 @@ class BaseModel(OriginalBaseModel):
|
||||
class PrimaryPhone(BaseModel):
|
||||
"""A primary phone in Apollo"""
|
||||
|
||||
number: str = ""
|
||||
source: str = ""
|
||||
sanitized_number: str = ""
|
||||
number: Optional[str] = ""
|
||||
source: Optional[str] = ""
|
||||
sanitized_number: Optional[str] = ""
|
||||
|
||||
|
||||
class SenorityLevels(str, Enum):
|
||||
@@ -56,102 +56,102 @@ class ContactEmailStatuses(str, Enum):
|
||||
class RuleConfigStatus(BaseModel):
|
||||
"""A rule config status in Apollo"""
|
||||
|
||||
_id: str = ""
|
||||
created_at: str = ""
|
||||
rule_action_config_id: str = ""
|
||||
rule_config_id: str = ""
|
||||
status_cd: str = ""
|
||||
updated_at: str = ""
|
||||
id: str = ""
|
||||
key: str = ""
|
||||
_id: Optional[str] = ""
|
||||
created_at: Optional[str] = ""
|
||||
rule_action_config_id: Optional[str] = ""
|
||||
rule_config_id: Optional[str] = ""
|
||||
status_cd: Optional[str] = ""
|
||||
updated_at: Optional[str] = ""
|
||||
id: Optional[str] = ""
|
||||
key: Optional[str] = ""
|
||||
|
||||
|
||||
class ContactCampaignStatus(BaseModel):
|
||||
"""A contact campaign status in Apollo"""
|
||||
|
||||
id: str = ""
|
||||
emailer_campaign_id: str = ""
|
||||
send_email_from_user_id: str = ""
|
||||
inactive_reason: str = ""
|
||||
status: str = ""
|
||||
added_at: str = ""
|
||||
added_by_user_id: str = ""
|
||||
finished_at: str = ""
|
||||
paused_at: str = ""
|
||||
auto_unpause_at: str = ""
|
||||
send_email_from_email_address: str = ""
|
||||
send_email_from_email_account_id: str = ""
|
||||
manually_set_unpause: str = ""
|
||||
failure_reason: str = ""
|
||||
current_step_id: str = ""
|
||||
in_response_to_emailer_message_id: str = ""
|
||||
cc_emails: str = ""
|
||||
bcc_emails: str = ""
|
||||
to_emails: str = ""
|
||||
id: Optional[str] = ""
|
||||
emailer_campaign_id: Optional[str] = ""
|
||||
send_email_from_user_id: Optional[str] = ""
|
||||
inactive_reason: Optional[str] = ""
|
||||
status: Optional[str] = ""
|
||||
added_at: Optional[str] = ""
|
||||
added_by_user_id: Optional[str] = ""
|
||||
finished_at: Optional[str] = ""
|
||||
paused_at: Optional[str] = ""
|
||||
auto_unpause_at: Optional[str] = ""
|
||||
send_email_from_email_address: Optional[str] = ""
|
||||
send_email_from_email_account_id: Optional[str] = ""
|
||||
manually_set_unpause: Optional[str] = ""
|
||||
failure_reason: Optional[str] = ""
|
||||
current_step_id: Optional[str] = ""
|
||||
in_response_to_emailer_message_id: Optional[str] = ""
|
||||
cc_emails: Optional[str] = ""
|
||||
bcc_emails: Optional[str] = ""
|
||||
to_emails: Optional[str] = ""
|
||||
|
||||
|
||||
class Account(BaseModel):
|
||||
"""An account in Apollo"""
|
||||
|
||||
id: str = ""
|
||||
name: str = ""
|
||||
website_url: str = ""
|
||||
blog_url: str = ""
|
||||
angellist_url: str = ""
|
||||
linkedin_url: str = ""
|
||||
twitter_url: str = ""
|
||||
facebook_url: str = ""
|
||||
primary_phone: PrimaryPhone = PrimaryPhone()
|
||||
languages: list[str]
|
||||
alexa_ranking: int = 0
|
||||
phone: str = ""
|
||||
linkedin_uid: str = ""
|
||||
founded_year: int = 0
|
||||
publicly_traded_symbol: str = ""
|
||||
publicly_traded_exchange: str = ""
|
||||
logo_url: str = ""
|
||||
chrunchbase_url: str = ""
|
||||
primary_domain: str = ""
|
||||
domain: str = ""
|
||||
team_id: str = ""
|
||||
organization_id: str = ""
|
||||
account_stage_id: str = ""
|
||||
source: str = ""
|
||||
original_source: str = ""
|
||||
creator_id: str = ""
|
||||
owner_id: str = ""
|
||||
created_at: str = ""
|
||||
phone_status: str = ""
|
||||
hubspot_id: str = ""
|
||||
salesforce_id: str = ""
|
||||
crm_owner_id: str = ""
|
||||
parent_account_id: str = ""
|
||||
sanitized_phone: str = ""
|
||||
id: Optional[str] = ""
|
||||
name: Optional[str] = ""
|
||||
website_url: Optional[str] = ""
|
||||
blog_url: Optional[str] = ""
|
||||
angellist_url: Optional[str] = ""
|
||||
linkedin_url: Optional[str] = ""
|
||||
twitter_url: Optional[str] = ""
|
||||
facebook_url: Optional[str] = ""
|
||||
primary_phone: Optional[PrimaryPhone] = PrimaryPhone()
|
||||
languages: Optional[list[str]] = []
|
||||
alexa_ranking: Optional[int] = 0
|
||||
phone: Optional[str] = ""
|
||||
linkedin_uid: Optional[str] = ""
|
||||
founded_year: Optional[int] = 0
|
||||
publicly_traded_symbol: Optional[str] = ""
|
||||
publicly_traded_exchange: Optional[str] = ""
|
||||
logo_url: Optional[str] = ""
|
||||
chrunchbase_url: Optional[str] = ""
|
||||
primary_domain: Optional[str] = ""
|
||||
domain: Optional[str] = ""
|
||||
team_id: Optional[str] = ""
|
||||
organization_id: Optional[str] = ""
|
||||
account_stage_id: Optional[str] = ""
|
||||
source: Optional[str] = ""
|
||||
original_source: Optional[str] = ""
|
||||
creator_id: Optional[str] = ""
|
||||
owner_id: Optional[str] = ""
|
||||
created_at: Optional[str] = ""
|
||||
phone_status: Optional[str] = ""
|
||||
hubspot_id: Optional[str] = ""
|
||||
salesforce_id: Optional[str] = ""
|
||||
crm_owner_id: Optional[str] = ""
|
||||
parent_account_id: Optional[str] = ""
|
||||
sanitized_phone: Optional[str] = ""
|
||||
# no listed type on the API docs
|
||||
account_playbook_statues: list[Any] = []
|
||||
account_rule_config_statuses: list[RuleConfigStatus] = []
|
||||
existence_level: str = ""
|
||||
label_ids: list[str] = []
|
||||
typed_custom_fields: Any
|
||||
custom_field_errors: Any
|
||||
modality: str = ""
|
||||
source_display_name: str = ""
|
||||
salesforce_record_id: str = ""
|
||||
crm_record_url: str = ""
|
||||
account_playbook_statues: Optional[list[Any]] = []
|
||||
account_rule_config_statuses: Optional[list[RuleConfigStatus]] = []
|
||||
existence_level: Optional[str] = ""
|
||||
label_ids: Optional[list[str]] = []
|
||||
typed_custom_fields: Optional[Any] = {}
|
||||
custom_field_errors: Optional[Any] = {}
|
||||
modality: Optional[str] = ""
|
||||
source_display_name: Optional[str] = ""
|
||||
salesforce_record_id: Optional[str] = ""
|
||||
crm_record_url: Optional[str] = ""
|
||||
|
||||
|
||||
class ContactEmail(BaseModel):
|
||||
"""A contact email in Apollo"""
|
||||
|
||||
email: str = ""
|
||||
email_md5: str = ""
|
||||
email_sha256: str = ""
|
||||
email_status: str = ""
|
||||
email_source: str = ""
|
||||
extrapolated_email_confidence: str = ""
|
||||
position: int = 0
|
||||
email_from_customer: str = ""
|
||||
free_domain: bool = True
|
||||
email: Optional[str] = ""
|
||||
email_md5: Optional[str] = ""
|
||||
email_sha256: Optional[str] = ""
|
||||
email_status: Optional[str] = ""
|
||||
email_source: Optional[str] = ""
|
||||
extrapolated_email_confidence: Optional[str] = ""
|
||||
position: Optional[int] = 0
|
||||
email_from_customer: Optional[str] = ""
|
||||
free_domain: Optional[bool] = True
|
||||
|
||||
|
||||
class EmploymentHistory(BaseModel):
|
||||
@@ -164,40 +164,40 @@ class EmploymentHistory(BaseModel):
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
_id: Optional[str] = None
|
||||
created_at: Optional[str] = None
|
||||
current: Optional[bool] = None
|
||||
degree: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
emails: Optional[str] = None
|
||||
end_date: Optional[str] = None
|
||||
grade_level: Optional[str] = None
|
||||
kind: Optional[str] = None
|
||||
major: Optional[str] = None
|
||||
organization_id: Optional[str] = None
|
||||
organization_name: Optional[str] = None
|
||||
raw_address: Optional[str] = None
|
||||
start_date: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
updated_at: Optional[str] = None
|
||||
id: Optional[str] = None
|
||||
key: Optional[str] = None
|
||||
_id: Optional[str] = ""
|
||||
created_at: Optional[str] = ""
|
||||
current: Optional[bool] = False
|
||||
degree: Optional[str] = ""
|
||||
description: Optional[str] = ""
|
||||
emails: Optional[str] = ""
|
||||
end_date: Optional[str] = ""
|
||||
grade_level: Optional[str] = ""
|
||||
kind: Optional[str] = ""
|
||||
major: Optional[str] = ""
|
||||
organization_id: Optional[str] = ""
|
||||
organization_name: Optional[str] = ""
|
||||
raw_address: Optional[str] = ""
|
||||
start_date: Optional[str] = ""
|
||||
title: Optional[str] = ""
|
||||
updated_at: Optional[str] = ""
|
||||
id: Optional[str] = ""
|
||||
key: Optional[str] = ""
|
||||
|
||||
|
||||
class Breadcrumb(BaseModel):
|
||||
"""A breadcrumb in Apollo"""
|
||||
|
||||
label: Optional[str] = "N/A"
|
||||
signal_field_name: Optional[str] = "N/A"
|
||||
value: str | list | None = "N/A"
|
||||
display_name: Optional[str] = "N/A"
|
||||
label: Optional[str] = ""
|
||||
signal_field_name: Optional[str] = ""
|
||||
value: str | list | None = ""
|
||||
display_name: Optional[str] = ""
|
||||
|
||||
|
||||
class TypedCustomField(BaseModel):
|
||||
"""A typed custom field in Apollo"""
|
||||
|
||||
id: Optional[str] = "N/A"
|
||||
value: Optional[str] = "N/A"
|
||||
id: Optional[str] = ""
|
||||
value: Optional[str] = ""
|
||||
|
||||
|
||||
class Pagination(BaseModel):
|
||||
@@ -219,23 +219,23 @@ class Pagination(BaseModel):
|
||||
class DialerFlags(BaseModel):
|
||||
"""A dialer flags in Apollo"""
|
||||
|
||||
country_name: str = ""
|
||||
country_enabled: bool
|
||||
high_risk_calling_enabled: bool
|
||||
potential_high_risk_number: bool
|
||||
country_name: Optional[str] = ""
|
||||
country_enabled: Optional[bool] = True
|
||||
high_risk_calling_enabled: Optional[bool] = True
|
||||
potential_high_risk_number: Optional[bool] = True
|
||||
|
||||
|
||||
class PhoneNumber(BaseModel):
|
||||
"""A phone number in Apollo"""
|
||||
|
||||
raw_number: str = ""
|
||||
sanitized_number: str = ""
|
||||
type: str = ""
|
||||
position: int = 0
|
||||
status: str = ""
|
||||
dnc_status: str = ""
|
||||
dnc_other_info: str = ""
|
||||
dailer_flags: DialerFlags = DialerFlags(
|
||||
raw_number: Optional[str] = ""
|
||||
sanitized_number: Optional[str] = ""
|
||||
type: Optional[str] = ""
|
||||
position: Optional[int] = 0
|
||||
status: Optional[str] = ""
|
||||
dnc_status: Optional[str] = ""
|
||||
dnc_other_info: Optional[str] = ""
|
||||
dailer_flags: Optional[DialerFlags] = DialerFlags(
|
||||
country_name="",
|
||||
country_enabled=True,
|
||||
high_risk_calling_enabled=True,
|
||||
@@ -253,33 +253,31 @@ class Organization(BaseModel):
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
id: Optional[str] = "N/A"
|
||||
name: Optional[str] = "N/A"
|
||||
website_url: Optional[str] = "N/A"
|
||||
blog_url: Optional[str] = "N/A"
|
||||
angellist_url: Optional[str] = "N/A"
|
||||
linkedin_url: Optional[str] = "N/A"
|
||||
twitter_url: Optional[str] = "N/A"
|
||||
facebook_url: Optional[str] = "N/A"
|
||||
primary_phone: Optional[PrimaryPhone] = PrimaryPhone(
|
||||
number="N/A", source="N/A", sanitized_number="N/A"
|
||||
)
|
||||
languages: list[str] = []
|
||||
id: Optional[str] = ""
|
||||
name: Optional[str] = ""
|
||||
website_url: Optional[str] = ""
|
||||
blog_url: Optional[str] = ""
|
||||
angellist_url: Optional[str] = ""
|
||||
linkedin_url: Optional[str] = ""
|
||||
twitter_url: Optional[str] = ""
|
||||
facebook_url: Optional[str] = ""
|
||||
primary_phone: Optional[PrimaryPhone] = PrimaryPhone()
|
||||
languages: Optional[list[str]] = []
|
||||
alexa_ranking: Optional[int] = 0
|
||||
phone: Optional[str] = "N/A"
|
||||
linkedin_uid: Optional[str] = "N/A"
|
||||
phone: Optional[str] = ""
|
||||
linkedin_uid: Optional[str] = ""
|
||||
founded_year: Optional[int] = 0
|
||||
publicly_traded_symbol: Optional[str] = "N/A"
|
||||
publicly_traded_exchange: Optional[str] = "N/A"
|
||||
logo_url: Optional[str] = "N/A"
|
||||
chrunchbase_url: Optional[str] = "N/A"
|
||||
primary_domain: Optional[str] = "N/A"
|
||||
sanitized_phone: Optional[str] = "N/A"
|
||||
owned_by_organization_id: Optional[str] = "N/A"
|
||||
intent_strength: Optional[str] = "N/A"
|
||||
show_intent: bool = True
|
||||
publicly_traded_symbol: Optional[str] = ""
|
||||
publicly_traded_exchange: Optional[str] = ""
|
||||
logo_url: Optional[str] = ""
|
||||
chrunchbase_url: Optional[str] = ""
|
||||
primary_domain: Optional[str] = ""
|
||||
sanitized_phone: Optional[str] = ""
|
||||
owned_by_organization_id: Optional[str] = ""
|
||||
intent_strength: Optional[str] = ""
|
||||
show_intent: Optional[bool] = True
|
||||
has_intent_signal_account: Optional[bool] = True
|
||||
intent_signal_account: Optional[str] = "N/A"
|
||||
intent_signal_account: Optional[str] = ""
|
||||
|
||||
|
||||
class Contact(BaseModel):
|
||||
@@ -292,95 +290,95 @@ class Contact(BaseModel):
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
contact_roles: list[Any] = []
|
||||
id: Optional[str] = None
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
linkedin_url: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
contact_stage_id: Optional[str] = None
|
||||
owner_id: Optional[str] = None
|
||||
creator_id: Optional[str] = None
|
||||
person_id: Optional[str] = None
|
||||
email_needs_tickling: bool = True
|
||||
organization_name: Optional[str] = None
|
||||
source: Optional[str] = None
|
||||
original_source: Optional[str] = None
|
||||
organization_id: Optional[str] = None
|
||||
headline: Optional[str] = None
|
||||
photo_url: Optional[str] = None
|
||||
present_raw_address: Optional[str] = None
|
||||
linkededin_uid: Optional[str] = None
|
||||
extrapolated_email_confidence: Optional[float] = None
|
||||
salesforce_id: Optional[str] = None
|
||||
salesforce_lead_id: Optional[str] = None
|
||||
salesforce_contact_id: Optional[str] = None
|
||||
saleforce_account_id: Optional[str] = None
|
||||
crm_owner_id: Optional[str] = None
|
||||
created_at: Optional[str] = None
|
||||
emailer_campaign_ids: list[str] = []
|
||||
direct_dial_status: Optional[str] = None
|
||||
direct_dial_enrichment_failed_at: Optional[str] = None
|
||||
email_status: Optional[str] = None
|
||||
email_source: Optional[str] = None
|
||||
account_id: Optional[str] = None
|
||||
last_activity_date: Optional[str] = None
|
||||
hubspot_vid: Optional[str] = None
|
||||
hubspot_company_id: Optional[str] = None
|
||||
crm_id: Optional[str] = None
|
||||
sanitized_phone: Optional[str] = None
|
||||
merged_crm_ids: Optional[str] = None
|
||||
updated_at: Optional[str] = None
|
||||
queued_for_crm_push: bool = True
|
||||
suggested_from_rule_engine_config_id: Optional[str] = None
|
||||
email_unsubscribed: Optional[str] = None
|
||||
label_ids: list[Any] = []
|
||||
has_pending_email_arcgate_request: bool = True
|
||||
has_email_arcgate_request: bool = True
|
||||
existence_level: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
email_from_customer: Optional[str] = None
|
||||
typed_custom_fields: list[TypedCustomField] = []
|
||||
custom_field_errors: Any = None
|
||||
salesforce_record_id: Optional[str] = None
|
||||
crm_record_url: Optional[str] = None
|
||||
email_status_unavailable_reason: Optional[str] = None
|
||||
email_true_status: Optional[str] = None
|
||||
updated_email_true_status: bool = True
|
||||
contact_rule_config_statuses: list[RuleConfigStatus] = []
|
||||
source_display_name: Optional[str] = None
|
||||
twitter_url: Optional[str] = None
|
||||
contact_campaign_statuses: list[ContactCampaignStatus] = []
|
||||
state: Optional[str] = None
|
||||
city: Optional[str] = None
|
||||
country: Optional[str] = None
|
||||
account: Optional[Account] = None
|
||||
contact_emails: list[ContactEmail] = []
|
||||
organization: Optional[Organization] = None
|
||||
employment_history: list[EmploymentHistory] = []
|
||||
time_zone: Optional[str] = None
|
||||
intent_strength: Optional[str] = None
|
||||
show_intent: bool = True
|
||||
phone_numbers: list[PhoneNumber] = []
|
||||
account_phone_note: Optional[str] = None
|
||||
free_domain: bool = True
|
||||
is_likely_to_engage: bool = True
|
||||
email_domain_catchall: bool = True
|
||||
contact_job_change_event: Optional[str] = None
|
||||
contact_roles: Optional[list[Any]] = []
|
||||
id: Optional[str] = ""
|
||||
first_name: Optional[str] = ""
|
||||
last_name: Optional[str] = ""
|
||||
name: Optional[str] = ""
|
||||
linkedin_url: Optional[str] = ""
|
||||
title: Optional[str] = ""
|
||||
contact_stage_id: Optional[str] = ""
|
||||
owner_id: Optional[str] = ""
|
||||
creator_id: Optional[str] = ""
|
||||
person_id: Optional[str] = ""
|
||||
email_needs_tickling: Optional[bool] = True
|
||||
organization_name: Optional[str] = ""
|
||||
source: Optional[str] = ""
|
||||
original_source: Optional[str] = ""
|
||||
organization_id: Optional[str] = ""
|
||||
headline: Optional[str] = ""
|
||||
photo_url: Optional[str] = ""
|
||||
present_raw_address: Optional[str] = ""
|
||||
linkededin_uid: Optional[str] = ""
|
||||
extrapolated_email_confidence: Optional[float] = 0.0
|
||||
salesforce_id: Optional[str] = ""
|
||||
salesforce_lead_id: Optional[str] = ""
|
||||
salesforce_contact_id: Optional[str] = ""
|
||||
saleforce_account_id: Optional[str] = ""
|
||||
crm_owner_id: Optional[str] = ""
|
||||
created_at: Optional[str] = ""
|
||||
emailer_campaign_ids: Optional[list[str]] = []
|
||||
direct_dial_status: Optional[str] = ""
|
||||
direct_dial_enrichment_failed_at: Optional[str] = ""
|
||||
email_status: Optional[str] = ""
|
||||
email_source: Optional[str] = ""
|
||||
account_id: Optional[str] = ""
|
||||
last_activity_date: Optional[str] = ""
|
||||
hubspot_vid: Optional[str] = ""
|
||||
hubspot_company_id: Optional[str] = ""
|
||||
crm_id: Optional[str] = ""
|
||||
sanitized_phone: Optional[str] = ""
|
||||
merged_crm_ids: Optional[str] = ""
|
||||
updated_at: Optional[str] = ""
|
||||
queued_for_crm_push: Optional[bool] = True
|
||||
suggested_from_rule_engine_config_id: Optional[str] = ""
|
||||
email_unsubscribed: Optional[str] = ""
|
||||
label_ids: Optional[list[Any]] = []
|
||||
has_pending_email_arcgate_request: Optional[bool] = True
|
||||
has_email_arcgate_request: Optional[bool] = True
|
||||
existence_level: Optional[str] = ""
|
||||
email: Optional[str] = ""
|
||||
email_from_customer: Optional[str] = ""
|
||||
typed_custom_fields: Optional[list[TypedCustomField]] = []
|
||||
custom_field_errors: Optional[Any] = {}
|
||||
salesforce_record_id: Optional[str] = ""
|
||||
crm_record_url: Optional[str] = ""
|
||||
email_status_unavailable_reason: Optional[str] = ""
|
||||
email_true_status: Optional[str] = ""
|
||||
updated_email_true_status: Optional[bool] = True
|
||||
contact_rule_config_statuses: Optional[list[RuleConfigStatus]] = []
|
||||
source_display_name: Optional[str] = ""
|
||||
twitter_url: Optional[str] = ""
|
||||
contact_campaign_statuses: Optional[list[ContactCampaignStatus]] = []
|
||||
state: Optional[str] = ""
|
||||
city: Optional[str] = ""
|
||||
country: Optional[str] = ""
|
||||
account: Optional[Account] = Account()
|
||||
contact_emails: Optional[list[ContactEmail]] = []
|
||||
organization: Optional[Organization] = Organization()
|
||||
employment_history: Optional[list[EmploymentHistory]] = []
|
||||
time_zone: Optional[str] = ""
|
||||
intent_strength: Optional[str] = ""
|
||||
show_intent: Optional[bool] = True
|
||||
phone_numbers: Optional[list[PhoneNumber]] = []
|
||||
account_phone_note: Optional[str] = ""
|
||||
free_domain: Optional[bool] = True
|
||||
is_likely_to_engage: Optional[bool] = True
|
||||
email_domain_catchall: Optional[bool] = True
|
||||
contact_job_change_event: Optional[str] = ""
|
||||
|
||||
|
||||
class SearchOrganizationsRequest(BaseModel):
|
||||
"""Request for Apollo's search organizations API"""
|
||||
|
||||
organization_num_empoloyees_range: list[int] = SchemaField(
|
||||
organization_num_employees_range: Optional[list[int]] = SchemaField(
|
||||
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
|
||||
|
||||
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
|
||||
default=[0, 1000000],
|
||||
)
|
||||
|
||||
organization_locations: list[str] = SchemaField(
|
||||
organization_locations: Optional[list[str]] = SchemaField(
|
||||
description="""The location of the company headquarters. You can search across cities, US states, and countries.
|
||||
|
||||
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, any Boston-based companies will not appear in your search results, even if they match other parameters.
|
||||
@@ -389,28 +387,30 @@ To exclude companies based on location, use the organization_not_locations param
|
||||
""",
|
||||
default_factory=list,
|
||||
)
|
||||
organizations_not_locations: list[str] = SchemaField(
|
||||
organizations_not_locations: Optional[list[str]] = SchemaField(
|
||||
description="""Exclude companies from search results based on the location of the company headquarters. You can use cities, US states, and countries as locations to exclude.
|
||||
|
||||
This parameter is useful for ensuring you do not prospect in an undesirable territory. For example, if you use ireland as a value, no Ireland-based companies will appear in your search results.
|
||||
""",
|
||||
default_factory=list,
|
||||
)
|
||||
q_organization_keyword_tags: list[str] = SchemaField(
|
||||
description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry."""
|
||||
q_organization_keyword_tags: Optional[list[str]] = SchemaField(
|
||||
description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry.""",
|
||||
default_factory=list,
|
||||
)
|
||||
q_organization_name: str = SchemaField(
|
||||
q_organization_name: Optional[str] = SchemaField(
|
||||
description="""Filter search results to include a specific company name.
|
||||
|
||||
If the value you enter for this parameter does not match with a company's name, the company will not appear in search results, even if it matches other parameters. Partial matches are accepted. For example, if you filter by the value marketing, a company called NY Marketing Unlimited would still be eligible as a search result, but NY Market Analysis would not be eligible."""
|
||||
If the value you enter for this parameter does not match with a company's name, the company will not appear in search results, even if it matches other parameters. Partial matches are accepted. For example, if you filter by the value marketing, a company called NY Marketing Unlimited would still be eligible as a search result, but NY Market Analysis would not be eligible.""",
|
||||
default="",
|
||||
)
|
||||
organization_ids: list[str] = SchemaField(
|
||||
organization_ids: Optional[list[str]] = SchemaField(
|
||||
description="""The Apollo IDs for the companies you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
|
||||
|
||||
To find IDs, identify the values for organization_id when you call this endpoint.""",
|
||||
default_factory=list,
|
||||
)
|
||||
max_results: int = SchemaField(
|
||||
max_results: Optional[int] = SchemaField(
|
||||
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
|
||||
default=100,
|
||||
ge=1,
|
||||
@@ -435,11 +435,11 @@ Use the page parameter to search the different pages of data.""",
|
||||
class SearchOrganizationsResponse(BaseModel):
|
||||
"""Response from Apollo's search organizations API"""
|
||||
|
||||
breadcrumbs: list[Breadcrumb] = []
|
||||
partial_results_only: bool = True
|
||||
has_join: bool = True
|
||||
disable_eu_prospecting: bool = True
|
||||
partial_results_limit: int = 0
|
||||
breadcrumbs: Optional[list[Breadcrumb]] = []
|
||||
partial_results_only: Optional[bool] = True
|
||||
has_join: Optional[bool] = True
|
||||
disable_eu_prospecting: Optional[bool] = True
|
||||
partial_results_limit: Optional[int] = 0
|
||||
pagination: Pagination = Pagination(
|
||||
page=0, per_page=0, total_entries=0, total_pages=0
|
||||
)
|
||||
@@ -447,14 +447,14 @@ class SearchOrganizationsResponse(BaseModel):
|
||||
accounts: list[Any] = []
|
||||
organizations: list[Organization] = []
|
||||
models_ids: list[str] = []
|
||||
num_fetch_result: Optional[str] = "N/A"
|
||||
derived_params: Optional[str] = "N/A"
|
||||
num_fetch_result: Optional[str] = ""
|
||||
derived_params: Optional[str] = ""
|
||||
|
||||
|
||||
class SearchPeopleRequest(BaseModel):
|
||||
"""Request for Apollo's search people API"""
|
||||
|
||||
person_titles: list[str] = SchemaField(
|
||||
person_titles: Optional[list[str]] = SchemaField(
|
||||
description="""Job titles held by the people you want to find. For a person to be included in search results, they only need to match 1 of the job titles you add. Adding more job titles expands your search results.
|
||||
|
||||
Results also include job titles with the same terms, even if they are not exact matches. For example, searching for marketing manager might return people with the job title content marketing manager.
|
||||
@@ -464,13 +464,13 @@ Use this parameter in combination with the person_seniorities[] parameter to fin
|
||||
default_factory=list,
|
||||
placeholder="marketing manager",
|
||||
)
|
||||
person_locations: list[str] = SchemaField(
|
||||
person_locations: Optional[list[str]] = SchemaField(
|
||||
description="""The location where people live. You can search across cities, US states, and countries.
|
||||
|
||||
To find people based on the headquarters locations of their current employer, use the organization_locations parameter.""",
|
||||
default_factory=list,
|
||||
)
|
||||
person_seniorities: list[SenorityLevels] = SchemaField(
|
||||
person_seniorities: Optional[list[SenorityLevels]] = SchemaField(
|
||||
description="""The job seniority that people hold within their current employer. This enables you to find people that currently hold positions at certain reporting levels, such as Director level or senior IC level.
|
||||
|
||||
For a person to be included in search results, they only need to match 1 of the seniorities you add. Adding more seniorities expands your search results.
|
||||
@@ -480,7 +480,7 @@ Searches only return results based on their current job title, so searching for
|
||||
Use this parameter in combination with the person_titles[] parameter to find people based on specific job functions and seniority levels.""",
|
||||
default_factory=list,
|
||||
)
|
||||
organization_locations: list[str] = SchemaField(
|
||||
organization_locations: Optional[list[str]] = SchemaField(
|
||||
description="""The location of the company headquarters for a person's current employer. You can search across cities, US states, and countries.
|
||||
|
||||
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, people that work for the Boston-based company will not appear in your results, even if they match other parameters.
|
||||
@@ -488,7 +488,7 @@ If a company has several office locations, results are still based on the headqu
|
||||
To find people based on their personal location, use the person_locations parameter.""",
|
||||
default_factory=list,
|
||||
)
|
||||
q_organization_domains: list[str] = SchemaField(
|
||||
q_organization_domains: Optional[list[str]] = SchemaField(
|
||||
description="""The domain name for the person's employer. This can be the current employer or a previous employer. Do not include www., the @ symbol, or similar.
|
||||
|
||||
You can add multiple domains to search across companies.
|
||||
@@ -496,23 +496,23 @@ You can add multiple domains to search across companies.
|
||||
Examples: apollo.io and microsoft.com""",
|
||||
default_factory=list,
|
||||
)
|
||||
contact_email_statuses: list[ContactEmailStatuses] = SchemaField(
|
||||
contact_email_statuses: Optional[list[ContactEmailStatuses]] = SchemaField(
|
||||
description="""The email statuses for the people you want to find. You can add multiple statuses to expand your search.""",
|
||||
default_factory=list,
|
||||
)
|
||||
organization_ids: list[str] = SchemaField(
|
||||
organization_ids: Optional[list[str]] = SchemaField(
|
||||
description="""The Apollo IDs for the companies (employers) you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
|
||||
|
||||
To find IDs, call the Organization Search endpoint and identify the values for organization_id.""",
|
||||
default_factory=list,
|
||||
)
|
||||
organization_num_empoloyees_range: list[int] = SchemaField(
|
||||
organization_num_employees_range: Optional[list[int]] = SchemaField(
|
||||
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
|
||||
|
||||
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
|
||||
default_factory=list,
|
||||
)
|
||||
q_keywords: str = SchemaField(
|
||||
q_keywords: Optional[str] = SchemaField(
|
||||
description="""A string of words over which we want to filter the results""",
|
||||
default="",
|
||||
)
|
||||
@@ -528,7 +528,7 @@ Use this parameter in combination with the per_page parameter to make search res
|
||||
Use the page parameter to search the different pages of data.""",
|
||||
default=100,
|
||||
)
|
||||
max_results: int = SchemaField(
|
||||
max_results: Optional[int] = SchemaField(
|
||||
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
|
||||
default=100,
|
||||
ge=1,
|
||||
@@ -547,16 +547,61 @@ class SearchPeopleResponse(BaseModel):
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
breadcrumbs: list[Breadcrumb] = []
|
||||
partial_results_only: bool = True
|
||||
has_join: bool = True
|
||||
disable_eu_prospecting: bool = True
|
||||
partial_results_limit: int = 0
|
||||
breadcrumbs: Optional[list[Breadcrumb]] = []
|
||||
partial_results_only: Optional[bool] = True
|
||||
has_join: Optional[bool] = True
|
||||
disable_eu_prospecting: Optional[bool] = True
|
||||
partial_results_limit: Optional[int] = 0
|
||||
pagination: Pagination = Pagination(
|
||||
page=0, per_page=0, total_entries=0, total_pages=0
|
||||
)
|
||||
contacts: list[Contact] = []
|
||||
people: list[Contact] = []
|
||||
model_ids: list[str] = []
|
||||
num_fetch_result: Optional[str] = "N/A"
|
||||
derived_params: Optional[str] = "N/A"
|
||||
num_fetch_result: Optional[str] = ""
|
||||
derived_params: Optional[str] = ""
|
||||
|
||||
|
||||
class EnrichPersonRequest(BaseModel):
|
||||
"""Request for Apollo's person enrichment API"""
|
||||
|
||||
person_id: Optional[str] = SchemaField(
|
||||
description="Apollo person ID to enrich (most accurate method)",
|
||||
default="",
|
||||
)
|
||||
first_name: Optional[str] = SchemaField(
|
||||
description="First name of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
last_name: Optional[str] = SchemaField(
|
||||
description="Last name of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
name: Optional[str] = SchemaField(
|
||||
description="Full name of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
email: Optional[str] = SchemaField(
|
||||
description="Email address of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
domain: Optional[str] = SchemaField(
|
||||
description="Company domain of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
company: Optional[str] = SchemaField(
|
||||
description="Company name of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
linkedin_url: Optional[str] = SchemaField(
|
||||
description="LinkedIn URL of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
organization_id: Optional[str] = SchemaField(
|
||||
description="Apollo organization ID of the person's company",
|
||||
default="",
|
||||
)
|
||||
title: Optional[str] = SchemaField(
|
||||
description="Job title of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
|
||||
@@ -10,15 +10,21 @@ from backend.blocks.apollo.models import (
|
||||
PrimaryPhone,
|
||||
SearchOrganizationsRequest,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
|
||||
|
||||
class SearchOrganizationsBlock(Block):
|
||||
"""Search for organizations in Apollo"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
organization_num_empoloyees_range: list[int] = SchemaField(
|
||||
class Input(BlockSchemaInput):
|
||||
organization_num_employees_range: list[int] = SchemaField(
|
||||
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
|
||||
|
||||
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
|
||||
@@ -65,11 +71,11 @@ To find IDs, identify the values for organization_id when you call this endpoint
|
||||
le=50000,
|
||||
advanced=True,
|
||||
)
|
||||
credentials: ApolloCredentialsInput = SchemaField(
|
||||
credentials: ApolloCredentialsInput = CredentialsField(
|
||||
description="Apollo credentials",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class Output(BlockSchemaOutput):
|
||||
organizations: list[Organization] = SchemaField(
|
||||
description="List of organizations found",
|
||||
default_factory=list,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import asyncio
|
||||
|
||||
from backend.blocks.apollo._api import ApolloClient
|
||||
from backend.blocks.apollo._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
@@ -8,17 +10,24 @@ from backend.blocks.apollo._auth import (
|
||||
from backend.blocks.apollo.models import (
|
||||
Contact,
|
||||
ContactEmailStatuses,
|
||||
EnrichPersonRequest,
|
||||
SearchPeopleRequest,
|
||||
SenorityLevels,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
|
||||
|
||||
class SearchPeopleBlock(Block):
|
||||
"""Search for people in Apollo"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
class Input(BlockSchemaInput):
|
||||
person_titles: list[str] = SchemaField(
|
||||
description="""Job titles held by the people you want to find. For a person to be included in search results, they only need to match 1 of the job titles you add. Adding more job titles expands your search results.
|
||||
|
||||
@@ -77,7 +86,7 @@ class SearchPeopleBlock(Block):
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
organization_num_empoloyees_range: list[int] = SchemaField(
|
||||
organization_num_employees_range: list[int] = SchemaField(
|
||||
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
|
||||
|
||||
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
|
||||
@@ -90,26 +99,27 @@ class SearchPeopleBlock(Block):
|
||||
advanced=False,
|
||||
)
|
||||
max_results: int = SchemaField(
|
||||
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
|
||||
default=100,
|
||||
description="""The maximum number of results to return. If you don't specify this parameter, the default is 25. Limited to 500 to prevent overspending.""",
|
||||
default=25,
|
||||
ge=1,
|
||||
le=50000,
|
||||
le=500,
|
||||
advanced=True,
|
||||
)
|
||||
enrich_info: bool = SchemaField(
|
||||
description="""Whether to enrich contacts with detailed information including real email addresses. This will double the search cost.""",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
credentials: ApolloCredentialsInput = SchemaField(
|
||||
credentials: ApolloCredentialsInput = CredentialsField(
|
||||
description="Apollo credentials",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class Output(BlockSchemaOutput):
|
||||
people: list[Contact] = SchemaField(
|
||||
description="List of people found",
|
||||
default_factory=list,
|
||||
)
|
||||
person: Contact = SchemaField(
|
||||
title="Person",
|
||||
description="Each found person, one at a time",
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the search failed",
|
||||
default="",
|
||||
@@ -125,87 +135,6 @@ class SearchPeopleBlock(Block):
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
(
|
||||
"person",
|
||||
Contact(
|
||||
contact_roles=[],
|
||||
id="1",
|
||||
name="John Doe",
|
||||
first_name="John",
|
||||
last_name="Doe",
|
||||
linkedin_url="https://www.linkedin.com/in/johndoe",
|
||||
title="Software Engineer",
|
||||
organization_name="Google",
|
||||
organization_id="123456",
|
||||
contact_stage_id="1",
|
||||
owner_id="1",
|
||||
creator_id="1",
|
||||
person_id="1",
|
||||
email_needs_tickling=True,
|
||||
source="apollo",
|
||||
original_source="apollo",
|
||||
headline="Software Engineer",
|
||||
photo_url="https://www.linkedin.com/in/johndoe",
|
||||
present_raw_address="123 Main St, Anytown, USA",
|
||||
linkededin_uid="123456",
|
||||
extrapolated_email_confidence=0.8,
|
||||
salesforce_id="123456",
|
||||
salesforce_lead_id="123456",
|
||||
salesforce_contact_id="123456",
|
||||
saleforce_account_id="123456",
|
||||
crm_owner_id="123456",
|
||||
created_at="2021-01-01",
|
||||
emailer_campaign_ids=[],
|
||||
direct_dial_status="active",
|
||||
direct_dial_enrichment_failed_at="2021-01-01",
|
||||
email_status="active",
|
||||
email_source="apollo",
|
||||
account_id="123456",
|
||||
last_activity_date="2021-01-01",
|
||||
hubspot_vid="123456",
|
||||
hubspot_company_id="123456",
|
||||
crm_id="123456",
|
||||
sanitized_phone="123456",
|
||||
merged_crm_ids="123456",
|
||||
updated_at="2021-01-01",
|
||||
queued_for_crm_push=True,
|
||||
suggested_from_rule_engine_config_id="123456",
|
||||
email_unsubscribed=None,
|
||||
label_ids=[],
|
||||
has_pending_email_arcgate_request=True,
|
||||
has_email_arcgate_request=True,
|
||||
existence_level=None,
|
||||
email=None,
|
||||
email_from_customer=None,
|
||||
typed_custom_fields=[],
|
||||
custom_field_errors=None,
|
||||
salesforce_record_id=None,
|
||||
crm_record_url=None,
|
||||
email_status_unavailable_reason=None,
|
||||
email_true_status=None,
|
||||
updated_email_true_status=True,
|
||||
contact_rule_config_statuses=[],
|
||||
source_display_name=None,
|
||||
twitter_url=None,
|
||||
contact_campaign_statuses=[],
|
||||
state=None,
|
||||
city=None,
|
||||
country=None,
|
||||
account=None,
|
||||
contact_emails=[],
|
||||
organization=None,
|
||||
employment_history=[],
|
||||
time_zone=None,
|
||||
intent_strength=None,
|
||||
show_intent=True,
|
||||
phone_numbers=[],
|
||||
account_phone_note=None,
|
||||
free_domain=True,
|
||||
is_likely_to_engage=True,
|
||||
email_domain_catchall=True,
|
||||
contact_job_change_event=None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"people",
|
||||
[
|
||||
@@ -380,6 +309,34 @@ class SearchPeopleBlock(Block):
|
||||
client = ApolloClient(credentials)
|
||||
return await client.search_people(query)
|
||||
|
||||
@staticmethod
|
||||
async def enrich_person(
|
||||
query: EnrichPersonRequest, credentials: ApolloCredentials
|
||||
) -> Contact:
|
||||
client = ApolloClient(credentials)
|
||||
return await client.enrich_person(query)
|
||||
|
||||
@staticmethod
|
||||
def merge_contact_data(original: Contact, enriched: Contact) -> Contact:
|
||||
"""
|
||||
Merge contact data from original search with enriched data.
|
||||
Enriched data complements original data, only filling in missing values.
|
||||
"""
|
||||
merged_data = original.model_dump()
|
||||
enriched_data = enriched.model_dump()
|
||||
|
||||
# Only update fields that are None, empty string, empty list, or default values in original
|
||||
for key, enriched_value in enriched_data.items():
|
||||
# Skip if enriched value is None, empty string, or empty list
|
||||
if enriched_value is None or enriched_value == "" or enriched_value == []:
|
||||
continue
|
||||
|
||||
# Update if original value is None, empty string, empty list, or zero
|
||||
if not merged_data.get(key):
|
||||
merged_data[key] = enriched_value
|
||||
|
||||
return Contact(**merged_data)
|
||||
|
||||
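To make the merge semantics concrete, a small hedged illustration; the field values are invented and the behaviour follows the complement-only merge described in the docstring above:

```python
from backend.blocks.apollo.models import Contact

original = Contact(id="1", name="John Doe", email="", title="Software Engineer")
enriched = Contact(id="1", name="John Doe", email="john.doe@gmail.com", title="Staff Engineer")

merged = SearchPeopleBlock.merge_contact_data(original, enriched)
assert merged.email == "john.doe@gmail.com"  # empty original field gets filled in
assert merged.title == "Software Engineer"   # non-empty original field is kept
```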
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
@@ -390,6 +347,23 @@ class SearchPeopleBlock(Block):
|
||||
|
||||
query = SearchPeopleRequest(**input_data.model_dump())
|
||||
people = await self.search_people(query, credentials)
|
||||
for person in people:
|
||||
yield "person", person
|
||||
|
||||
# Enrich with detailed info if requested
|
||||
if input_data.enrich_info:
|
||||
|
||||
async def enrich_or_fallback(person: Contact):
|
||||
try:
|
||||
enrich_query = EnrichPersonRequest(person_id=person.id)
|
||||
enriched_person = await self.enrich_person(
|
||||
enrich_query, credentials
|
||||
)
|
||||
# Merge enriched data with original data, complementing instead of replacing
|
||||
return self.merge_contact_data(person, enriched_person)
|
||||
except Exception:
|
||||
return person # If enrichment fails, use original person data
|
||||
|
||||
people = await asyncio.gather(
|
||||
*(enrich_or_fallback(person) for person in people)
|
||||
)
|
||||
|
||||
yield "people", people
|
||||
|
||||
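The enrichment pass above fans out one enrich_person call per contact via asyncio.gather, and enrich_or_fallback swallows failures so a single bad enrichment never loses the original search hit. If Apollo rate limits ever become a concern, a hedged variant with bounded concurrency could look like the sketch below; the semaphore is not part of this PR:

```python
import asyncio

from backend.blocks.apollo.models import EnrichPersonRequest


async def enrich_all(people, credentials, max_concurrency: int = 5):
    sem = asyncio.Semaphore(max_concurrency)

    async def enrich_one(person):
        async with sem:  # cap simultaneous Apollo calls
            try:
                enriched = await SearchPeopleBlock.enrich_person(
                    EnrichPersonRequest(person_id=person.id), credentials
                )
                return SearchPeopleBlock.merge_contact_data(person, enriched)
            except Exception:
                return person  # fall back to the un-enriched contact

    return await asyncio.gather(*(enrich_one(p) for p in people))
```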
144 autogpt_platform/backend/backend/blocks/apollo/person.py Normal file
@@ -0,0 +1,144 @@
|
||||
from backend.blocks.apollo._api import ApolloClient
|
||||
from backend.blocks.apollo._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
ApolloCredentials,
|
||||
ApolloCredentialsInput,
|
||||
)
|
||||
from backend.blocks.apollo.models import Contact, EnrichPersonRequest
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
|
||||
|
||||
class GetPersonDetailBlock(Block):
|
||||
"""Get detailed person data with Apollo API, including email reveal"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
person_id: str = SchemaField(
|
||||
description="Apollo person ID to enrich (most accurate method)",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
first_name: str = SchemaField(
|
||||
description="First name of the person to enrich",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
last_name: str = SchemaField(
|
||||
description="Last name of the person to enrich",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
name: str = SchemaField(
|
||||
description="Full name of the person to enrich (alternative to first_name + last_name)",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
email: str = SchemaField(
|
||||
description="Known email address of the person (helps with matching)",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
domain: str = SchemaField(
|
||||
description="Company domain of the person (e.g., 'google.com')",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
company: str = SchemaField(
|
||||
description="Company name of the person",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
linkedin_url: str = SchemaField(
|
||||
description="LinkedIn URL of the person",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
organization_id: str = SchemaField(
|
||||
description="Apollo organization ID of the person's company",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
title: str = SchemaField(
|
||||
description="Job title of the person to enrich",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
credentials: ApolloCredentialsInput = CredentialsField(
|
||||
description="Apollo credentials",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
contact: Contact = SchemaField(
|
||||
description="Enriched contact information",
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if enrichment failed",
|
||||
default="",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3b18d46c-3db6-42ae-a228-0ba441bdd176",
|
||||
description="Get detailed person data with Apollo API, including email reveal",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=GetPersonDetailBlock.Input,
|
||||
output_schema=GetPersonDetailBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"first_name": "John",
|
||||
"last_name": "Doe",
|
||||
"company": "Google",
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"contact",
|
||||
Contact(
|
||||
id="1",
|
||||
name="John Doe",
|
||||
first_name="John",
|
||||
last_name="Doe",
|
||||
email="john.doe@gmail.com",
|
||||
title="Software Engineer",
|
||||
organization_name="Google",
|
||||
linkedin_url="https://www.linkedin.com/in/johndoe",
|
||||
),
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"enrich_person": lambda query, credentials: Contact(
|
||||
id="1",
|
||||
name="John Doe",
|
||||
first_name="John",
|
||||
last_name="Doe",
|
||||
email="john.doe@gmail.com",
|
||||
title="Software Engineer",
|
||||
organization_name="Google",
|
||||
linkedin_url="https://www.linkedin.com/in/johndoe",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def enrich_person(
|
||||
query: EnrichPersonRequest, credentials: ApolloCredentials
|
||||
) -> Contact:
|
||||
client = ApolloClient(credentials)
|
||||
return await client.enrich_person(query)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: ApolloCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
query = EnrichPersonRequest(**input_data.model_dump())
|
||||
yield "contact", await self.enrich_person(query, credentials)
|
||||
15
autogpt_platform/backend/backend/blocks/ayrshare/__init__.py
Normal file
@@ -0,0 +1,15 @@
AYRSHARE_BLOCK_IDS = [
"cbd52c2a-06d2-43ed-9560-6576cc163283", # PostToBlueskyBlock
"3352f512-3524-49ed-a08f-003042da2fc1", # PostToFacebookBlock
"9e8f844e-b4a5-4b25-80f2-9e1dd7d67625", # PostToXBlock
"589af4e4-507f-42fd-b9ac-a67ecef25811", # PostToLinkedInBlock
"89b02b96-a7cb-46f4-9900-c48b32fe1552", # PostToInstagramBlock
"0082d712-ff1b-4c3d-8a8d-6c7721883b83", # PostToYouTubeBlock
"c7733580-3c72-483e-8e47-a8d58754d853", # PostToRedditBlock
"47bc74eb-4af2-452c-b933-af377c7287df", # PostToTelegramBlock
"2c38c783-c484-4503-9280-ef5d1d345a7e", # PostToGMBBlock
"3ca46e05-dbaa-4afb-9e95-5a429c4177e6", # PostToPinterestBlock
"7faf4b27-96b0-4f05-bf64-e0de54ae74e1", # PostToTikTokBlock
"f8c3b2e1-9d4a-4e5f-8c7b-6a9e8d2f1c3b", # PostToThreadsBlock
"a9d7f854-2c83-4e96-b3a1-7f2e9c5d4b8e", # PostToSnapchatBlock
]
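A small sketch (not part of the PR) that sanity-checks the ID registry above: every entry should parse as a UUID and there should be no duplicates. The `AYRSHARE_BLOCK_IDS` name comes from the file above; only the first two IDs are repeated here to keep the example short.

```python
# Illustrative sanity check: verify registered Ayrshare block IDs are
# well-formed UUIDs and contain no duplicates.
import uuid

AYRSHARE_BLOCK_IDS = [
    "cbd52c2a-06d2-43ed-9560-6576cc163283",  # PostToBlueskyBlock
    "3352f512-3524-49ed-a08f-003042da2fc1",  # PostToFacebookBlock
]


def validate_block_ids(block_ids: list[str]) -> None:
    # uuid.UUID raises ValueError for any malformed ID string
    parsed = [uuid.UUID(block_id) for block_id in block_ids]
    if len(set(parsed)) != len(parsed):
        raise ValueError("Duplicate block IDs registered")


validate_block_ids(AYRSHARE_BLOCK_IDS)
```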
152
autogpt_platform/backend/backend/blocks/ayrshare/_util.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.block import BlockSchemaInput
|
||||
from backend.data.model import SchemaField, UserIntegrations
|
||||
from backend.integrations.ayrshare import AyrshareClient
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
from backend.util.exceptions import MissingConfigError
|
||||
|
||||
|
||||
async def get_profile_key(user_id: str):
|
||||
user_integrations: UserIntegrations = (
|
||||
await get_database_manager_async_client().get_user_integrations(user_id)
|
||||
)
|
||||
return user_integrations.managed_credentials.ayrshare_profile_key
|
||||
|
||||
|
||||
class BaseAyrshareInput(BlockSchemaInput):
|
||||
"""Base input model for Ayrshare social media posts with common fields."""
|
||||
|
||||
post: str = SchemaField(
|
||||
description="The post text to be published", default="", advanced=False
|
||||
)
|
||||
media_urls: list[str] = SchemaField(
|
||||
description="Optional list of media URLs to include. Set is_video in advanced settings to true if you want to upload videos.",
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
is_video: bool = SchemaField(
|
||||
description="Whether the media is a video", default=False, advanced=True
|
||||
)
|
||||
schedule_date: Optional[datetime] = SchemaField(
|
||||
description="UTC datetime for scheduling (YYYY-MM-DDThh:mm:ssZ)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
disable_comments: bool = SchemaField(
|
||||
description="Whether to disable comments", default=False, advanced=True
|
||||
)
|
||||
shorten_links: bool = SchemaField(
|
||||
description="Whether to shorten links", default=False, advanced=True
|
||||
)
|
||||
unsplash: Optional[str] = SchemaField(
|
||||
description="Unsplash image configuration", default=None, advanced=True
|
||||
)
|
||||
requires_approval: bool = SchemaField(
|
||||
description="Whether to enable approval workflow",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
random_post: bool = SchemaField(
|
||||
description="Whether to generate random post text",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
random_media_url: bool = SchemaField(
|
||||
description="Whether to generate random media", default=False, advanced=True
|
||||
)
|
||||
notes: Optional[str] = SchemaField(
|
||||
description="Additional notes for the post", default=None, advanced=True
|
||||
)
|
||||
|
||||
|
||||
class CarouselItem(BaseModel):
|
||||
"""Model for Facebook carousel items."""
|
||||
|
||||
name: str = Field(..., description="The name of the item")
|
||||
link: str = Field(..., description="The link of the item")
|
||||
picture: str = Field(..., description="The picture URL of the item")
|
||||
|
||||
|
||||
class CallToAction(BaseModel):
|
||||
"""Model for Google My Business Call to Action."""
|
||||
|
||||
action_type: str = Field(
|
||||
..., description="Type of action (book, order, shop, learn_more, sign_up, call)"
|
||||
)
|
||||
url: Optional[str] = Field(
|
||||
description="URL for the action (not required for 'call' action)"
|
||||
)
|
||||
|
||||
|
||||
class EventDetails(BaseModel):
|
||||
"""Model for Google My Business Event details."""
|
||||
|
||||
title: str = Field(..., description="Event title")
|
||||
start_date: str = Field(..., description="Event start date (ISO format)")
|
||||
end_date: str = Field(..., description="Event end date (ISO format)")
|
||||
|
||||
|
||||
class OfferDetails(BaseModel):
|
||||
"""Model for Google My Business Offer details."""
|
||||
|
||||
title: str = Field(..., description="Offer title")
|
||||
start_date: str = Field(..., description="Offer start date (ISO format)")
|
||||
end_date: str = Field(..., description="Offer end date (ISO format)")
|
||||
coupon_code: str = Field(..., description="Coupon code (max 58 characters)")
|
||||
redeem_online_url: str = Field(..., description="URL to redeem the offer")
|
||||
terms_conditions: str = Field(..., description="Terms and conditions")
|
||||
|
||||
|
||||
class InstagramUserTag(BaseModel):
|
||||
"""Model for Instagram user tags."""
|
||||
|
||||
username: str = Field(..., description="Instagram username (without @)")
|
||||
x: Optional[float] = Field(description="X coordinate (0.0-1.0) for image posts")
|
||||
y: Optional[float] = Field(description="Y coordinate (0.0-1.0) for image posts")
|
||||
|
||||
|
||||
class LinkedInTargeting(BaseModel):
|
||||
"""Model for LinkedIn audience targeting."""
|
||||
|
||||
countries: Optional[list[str]] = Field(
|
||||
description="Country codes (e.g., ['US', 'IN', 'DE', 'GB'])"
|
||||
)
|
||||
seniorities: Optional[list[str]] = Field(
|
||||
description="Seniority levels (e.g., ['Senior', 'VP'])"
|
||||
)
|
||||
degrees: Optional[list[str]] = Field(description="Education degrees")
|
||||
fields_of_study: Optional[list[str]] = Field(description="Fields of study")
|
||||
industries: Optional[list[str]] = Field(description="Industry categories")
|
||||
job_functions: Optional[list[str]] = Field(description="Job function categories")
|
||||
staff_count_ranges: Optional[list[str]] = Field(description="Company size ranges")
|
||||
|
||||
|
||||
class PinterestCarouselOption(BaseModel):
|
||||
"""Model for Pinterest carousel image options."""
|
||||
|
||||
title: Optional[str] = Field(description="Image title")
|
||||
link: Optional[str] = Field(description="External destination link for the image")
|
||||
description: Optional[str] = Field(description="Image description")
|
||||
|
||||
|
||||
class YouTubeTargeting(BaseModel):
|
||||
"""Model for YouTube country targeting."""
|
||||
|
||||
block: Optional[list[str]] = Field(
|
||||
description="Country codes to block (e.g., ['US', 'CA'])"
|
||||
)
|
||||
allow: Optional[list[str]] = Field(
|
||||
description="Country codes to allow (e.g., ['GB', 'AU'])"
|
||||
)
|
||||
|
||||
|
||||
def create_ayrshare_client():
|
||||
"""Create an Ayrshare client instance."""
|
||||
try:
|
||||
return AyrshareClient()
|
||||
except MissingConfigError:
|
||||
return None
|
||||
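The helper above converts a missing `AYRSHARE_API_KEY` into a `None` return instead of raising, and every block then checks for `None` before posting. A hedged, standalone sketch of that calling pattern follows; the classes here are stand-ins for `backend.integrations.ayrshare.AyrshareClient` and `backend.util.exceptions.MissingConfigError`, and the `api_key` parameter is added only so the sketch runs on its own.

```python
# Standalone sketch of the optional-client pattern used by the blocks:
# configuration errors become a None return, and callers report a setup
# problem instead of crashing mid-run.
from typing import Optional


class MissingConfigError(Exception):
    """Stand-in for backend.util.exceptions.MissingConfigError."""


class AyrshareClient:
    """Stand-in for backend.integrations.ayrshare.AyrshareClient."""

    def __init__(self, api_key: Optional[str] = None):
        if not api_key:
            raise MissingConfigError("AYRSHARE_API_KEY is not set")
        self.api_key = api_key


def create_ayrshare_client(api_key: Optional[str] = None) -> Optional[AyrshareClient]:
    try:
        return AyrshareClient(api_key)
    except MissingConfigError:
        return None


client = create_ayrshare_client(api_key=None)
if not client:
    print("Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY.")
```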
@@ -0,0 +1,114 @@
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
class PostToBlueskyBlock(Block):
|
||||
"""Block for posting to Bluesky with Bluesky-specific options."""
|
||||
|
||||
class Input(BaseAyrshareInput):
|
||||
"""Input schema for Bluesky posts."""
|
||||
|
||||
# Override post field to include character limit information
|
||||
post: str = SchemaField(
|
||||
description="The post text to be published (max 300 characters for Bluesky)",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# Override media_urls to include Bluesky-specific constraints
|
||||
media_urls: list[str] = SchemaField(
|
||||
description="Optional list of media URLs to include. Bluesky supports up to 4 images or 1 video.",
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# Bluesky-specific options
|
||||
alt_text: list[str] = SchemaField(
|
||||
description="Alt text for each media item (accessibility)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
disabled=True,
|
||||
id="cbd52c2a-06d2-43ed-9560-6576cc163283",
|
||||
description="Post to Bluesky using Ayrshare",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
block_type=BlockType.AYRSHARE,
|
||||
input_schema=PostToBlueskyBlock.Input,
|
||||
output_schema=PostToBlueskyBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: "PostToBlueskyBlock.Input",
|
||||
*,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Bluesky with Bluesky-specific options."""
|
||||
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
return
|
||||
|
||||
# Validate character limit for Bluesky
|
||||
if len(input_data.post) > 300:
|
||||
yield "error", f"Post text exceeds Bluesky's 300 character limit ({len(input_data.post)} characters)"
|
||||
return
|
||||
|
||||
# Validate media constraints for Bluesky
|
||||
if len(input_data.media_urls) > 4:
|
||||
yield "error", "Bluesky supports a maximum of 4 images or 1 video"
|
||||
return
|
||||
|
||||
# Convert datetime to ISO format if provided
|
||||
iso_date = (
|
||||
input_data.schedule_date.isoformat() if input_data.schedule_date else None
|
||||
)
|
||||
|
||||
# Build Bluesky-specific options
|
||||
bluesky_options = {}
|
||||
if input_data.alt_text:
|
||||
bluesky_options["altText"] = input_data.alt_text
|
||||
|
||||
response = await client.create_post(
|
||||
post=input_data.post,
|
||||
platforms=[SocialPlatform.BLUESKY],
|
||||
media_urls=input_data.media_urls,
|
||||
is_video=input_data.is_video,
|
||||
schedule_date=iso_date,
|
||||
disable_comments=input_data.disable_comments,
|
||||
shorten_links=input_data.shorten_links,
|
||||
unsplash=input_data.unsplash,
|
||||
requires_approval=input_data.requires_approval,
|
||||
random_post=input_data.random_post,
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
bluesky_options=bluesky_options if bluesky_options else None,
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
for p in response.postIds:
|
||||
yield "post", p
|
||||
@@ -0,0 +1,212 @@
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._util import (
|
||||
BaseAyrshareInput,
|
||||
CarouselItem,
|
||||
create_ayrshare_client,
|
||||
get_profile_key,
|
||||
)
|
||||
|
||||
|
||||
class PostToFacebookBlock(Block):
|
||||
"""Block for posting to Facebook with Facebook-specific options."""
|
||||
|
||||
class Input(BaseAyrshareInput):
|
||||
"""Input schema for Facebook posts."""
|
||||
|
||||
# Facebook-specific options
|
||||
is_carousel: bool = SchemaField(
|
||||
description="Whether to post a carousel", default=False, advanced=True
|
||||
)
|
||||
carousel_link: str = SchemaField(
|
||||
description="The URL for the 'See More At' button in the carousel",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
carousel_items: list[CarouselItem] = SchemaField(
|
||||
description="List of carousel items with name, link and picture URLs. Min 2, max 10 items.",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
is_reels: bool = SchemaField(
|
||||
description="Whether to post to Facebook Reels",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
reels_title: str = SchemaField(
|
||||
description="Title for the Reels video (max 255 chars)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
reels_thumbnail: str = SchemaField(
|
||||
description="Thumbnail URL for Reels video (JPEG/PNG, <10MB)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
is_story: bool = SchemaField(
|
||||
description="Whether to post as a Facebook Story",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
media_captions: list[str] = SchemaField(
|
||||
description="Captions for each media item",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
location_id: str = SchemaField(
|
||||
description="Facebook Page ID or name for location tagging",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
age_min: int = SchemaField(
|
||||
description="Minimum age for audience targeting (13,15,18,21,25)",
|
||||
default=0,
|
||||
advanced=True,
|
||||
)
|
||||
target_countries: list[str] = SchemaField(
|
||||
description="List of country codes to target (max 25)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
alt_text: list[str] = SchemaField(
|
||||
description="Alt text for each media item",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
video_title: str = SchemaField(
|
||||
description="Title for video post", default="", advanced=True
|
||||
)
|
||||
video_thumbnail: str = SchemaField(
|
||||
description="Thumbnail URL for video post", default="", advanced=True
|
||||
)
|
||||
is_draft: bool = SchemaField(
|
||||
description="Save as draft in Meta Business Suite",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
scheduled_publish_date: str = SchemaField(
|
||||
description="Schedule publish time in Meta Business Suite (UTC)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
preview_link: str = SchemaField(
|
||||
description="URL for custom link preview", default="", advanced=True
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
disabled=True,
|
||||
id="3352f512-3524-49ed-a08f-003042da2fc1",
|
||||
description="Post to Facebook using Ayrshare",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
block_type=BlockType.AYRSHARE,
|
||||
input_schema=PostToFacebookBlock.Input,
|
||||
output_schema=PostToFacebookBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: "PostToFacebookBlock.Input",
|
||||
*,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Facebook with Facebook-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
return
|
||||
|
||||
# Convert datetime to ISO format if provided
|
||||
iso_date = (
|
||||
input_data.schedule_date.isoformat() if input_data.schedule_date else None
|
||||
)
|
||||
|
||||
# Build Facebook-specific options
|
||||
facebook_options = {}
|
||||
if input_data.is_carousel:
|
||||
facebook_options["isCarousel"] = True
|
||||
if input_data.carousel_link:
|
||||
facebook_options["carouselLink"] = input_data.carousel_link
|
||||
if input_data.carousel_items:
|
||||
facebook_options["carouselItems"] = [
|
||||
item.dict() for item in input_data.carousel_items
|
||||
]
|
||||
|
||||
if input_data.is_reels:
|
||||
facebook_options["isReels"] = True
|
||||
if input_data.reels_title:
|
||||
facebook_options["reelsTitle"] = input_data.reels_title
|
||||
if input_data.reels_thumbnail:
|
||||
facebook_options["reelsThumbnail"] = input_data.reels_thumbnail
|
||||
|
||||
if input_data.is_story:
|
||||
facebook_options["isStory"] = True
|
||||
|
||||
if input_data.media_captions:
|
||||
facebook_options["mediaCaptions"] = input_data.media_captions
|
||||
|
||||
if input_data.location_id:
|
||||
facebook_options["locationId"] = input_data.location_id
|
||||
|
||||
if input_data.age_min > 0:
|
||||
facebook_options["ageMin"] = input_data.age_min
|
||||
|
||||
if input_data.target_countries:
|
||||
facebook_options["targetCountries"] = input_data.target_countries
|
||||
|
||||
if input_data.alt_text:
|
||||
facebook_options["altText"] = input_data.alt_text
|
||||
|
||||
if input_data.video_title:
|
||||
facebook_options["videoTitle"] = input_data.video_title
|
||||
|
||||
if input_data.video_thumbnail:
|
||||
facebook_options["videoThumbnail"] = input_data.video_thumbnail
|
||||
|
||||
if input_data.is_draft:
|
||||
facebook_options["isDraft"] = True
|
||||
|
||||
if input_data.scheduled_publish_date:
|
||||
facebook_options["scheduledPublishDate"] = input_data.scheduled_publish_date
|
||||
|
||||
if input_data.preview_link:
|
||||
facebook_options["previewLink"] = input_data.preview_link
|
||||
|
||||
response = await client.create_post(
|
||||
post=input_data.post,
|
||||
platforms=[SocialPlatform.FACEBOOK],
|
||||
media_urls=input_data.media_urls,
|
||||
is_video=input_data.is_video,
|
||||
schedule_date=iso_date,
|
||||
disable_comments=input_data.disable_comments,
|
||||
shorten_links=input_data.shorten_links,
|
||||
unsplash=input_data.unsplash,
|
||||
requires_approval=input_data.requires_approval,
|
||||
random_post=input_data.random_post,
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
facebook_options=facebook_options if facebook_options else None,
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
for p in response.postIds:
|
||||
yield "post", p
|
||||
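The `run` methods in these blocks build their platform option dicts through long chains of `if input_data.x: options["someKey"] = input_data.x`. Purely as an illustrative alternative (not code from the PR; the field and key names below are examples), the same mapping can be expressed table-driven:

```python
# Illustrative, table-driven version of the repeated
# "if set, copy field into the camelCase options dict" pattern above.
from typing import Any

FACEBOOK_OPTION_KEYS = {
    "carousel_link": "carouselLink",
    "reels_title": "reelsTitle",
    "location_id": "locationId",
    "video_title": "videoTitle",
}


def build_options(input_values: dict[str, Any], key_map: dict[str, str]) -> dict[str, Any]:
    """Copy only the fields that were actually set (truthy) into API option names."""
    return {
        api_key: value
        for field, api_key in key_map.items()
        if (value := input_values.get(field))
    }


print(build_options({"reels_title": "Demo reel", "location_id": ""}, FACEBOOK_OPTION_KEYS))
# -> {'reelsTitle': 'Demo reel'}
```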
210
autogpt_platform/backend/backend/blocks/ayrshare/post_to_gmb.py
Normal file
@@ -0,0 +1,210 @@
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
class PostToGMBBlock(Block):
|
||||
"""Block for posting to Google My Business with GMB-specific options."""
|
||||
|
||||
class Input(BaseAyrshareInput):
|
||||
"""Input schema for Google My Business posts."""
|
||||
|
||||
# Override media_urls to include GMB-specific constraints
|
||||
media_urls: list[str] = SchemaField(
|
||||
description="Optional list of media URLs. GMB supports only one image or video per post.",
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# GMB-specific options
|
||||
is_photo_video: bool = SchemaField(
|
||||
description="Whether this is a photo/video post (appears in Photos section)",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
photo_category: str = SchemaField(
|
||||
description="Category for photo/video: cover, profile, logo, exterior, interior, product, at_work, food_and_drink, menu, common_area, rooms, teams",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
# Call to action options (flattened from CallToAction object)
|
||||
call_to_action_type: str = SchemaField(
|
||||
description="Type of action button: 'book', 'order', 'shop', 'learn_more', 'sign_up', or 'call'",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
call_to_action_url: str = SchemaField(
|
||||
description="URL for the action button (not required for 'call' action)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
# Event details options (flattened from EventDetails object)
|
||||
event_title: str = SchemaField(
|
||||
description="Event title for event posts",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
event_start_date: str = SchemaField(
|
||||
description="Event start date in ISO format (e.g., '2024-03-15T09:00:00Z')",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
event_end_date: str = SchemaField(
|
||||
description="Event end date in ISO format (e.g., '2024-03-15T17:00:00Z')",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
# Offer details options (flattened from OfferDetails object)
|
||||
offer_title: str = SchemaField(
|
||||
description="Offer title for promotional posts",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
offer_start_date: str = SchemaField(
|
||||
description="Offer start date in ISO format (e.g., '2024-03-15T00:00:00Z')",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
offer_end_date: str = SchemaField(
|
||||
description="Offer end date in ISO format (e.g., '2024-04-15T23:59:59Z')",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
offer_coupon_code: str = SchemaField(
|
||||
description="Coupon code for the offer (max 58 characters)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
offer_redeem_online_url: str = SchemaField(
|
||||
description="URL where customers can redeem the offer online",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
offer_terms_conditions: str = SchemaField(
|
||||
description="Terms and conditions for the offer",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
disabled=True,
|
||||
id="2c38c783-c484-4503-9280-ef5d1d345a7e",
|
||||
description="Post to Google My Business using Ayrshare",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
block_type=BlockType.AYRSHARE,
|
||||
input_schema=PostToGMBBlock.Input,
|
||||
output_schema=PostToGMBBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: "PostToGMBBlock.Input", *, user_id: str, **kwargs
|
||||
) -> BlockOutput:
|
||||
"""Post to Google My Business with GMB-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
return
|
||||
|
||||
# Validate GMB constraints
|
||||
if len(input_data.media_urls) > 1:
|
||||
yield "error", "Google My Business supports only one image or video per post"
|
||||
return
|
||||
|
||||
# Validate offer coupon code length
|
||||
if input_data.offer_coupon_code and len(input_data.offer_coupon_code) > 58:
|
||||
yield "error", "GMB offer coupon code cannot exceed 58 characters"
|
||||
return
|
||||
|
||||
# Convert datetime to ISO format if provided
|
||||
iso_date = (
|
||||
input_data.schedule_date.isoformat() if input_data.schedule_date else None
|
||||
)
|
||||
|
||||
# Build GMB-specific options
|
||||
gmb_options = {}
|
||||
|
||||
# Photo/Video post options
|
||||
if input_data.is_photo_video:
|
||||
gmb_options["isPhotoVideo"] = True
|
||||
if input_data.photo_category:
|
||||
gmb_options["category"] = input_data.photo_category
|
||||
|
||||
# Call to Action (from flattened fields)
|
||||
if input_data.call_to_action_type:
|
||||
cta_dict = {"actionType": input_data.call_to_action_type}
|
||||
# URL not required for 'call' action type
|
||||
if (
|
||||
input_data.call_to_action_type != "call"
|
||||
and input_data.call_to_action_url
|
||||
):
|
||||
cta_dict["url"] = input_data.call_to_action_url
|
||||
gmb_options["callToAction"] = cta_dict
|
||||
|
||||
# Event details (from flattened fields)
|
||||
if (
|
||||
input_data.event_title
|
||||
and input_data.event_start_date
|
||||
and input_data.event_end_date
|
||||
):
|
||||
gmb_options["event"] = {
|
||||
"title": input_data.event_title,
|
||||
"startDate": input_data.event_start_date,
|
||||
"endDate": input_data.event_end_date,
|
||||
}
|
||||
|
||||
# Offer details (from flattened fields)
|
||||
if (
|
||||
input_data.offer_title
|
||||
and input_data.offer_start_date
|
||||
and input_data.offer_end_date
|
||||
and input_data.offer_coupon_code
|
||||
and input_data.offer_redeem_online_url
|
||||
and input_data.offer_terms_conditions
|
||||
):
|
||||
gmb_options["offer"] = {
|
||||
"title": input_data.offer_title,
|
||||
"startDate": input_data.offer_start_date,
|
||||
"endDate": input_data.offer_end_date,
|
||||
"couponCode": input_data.offer_coupon_code,
|
||||
"redeemOnlineUrl": input_data.offer_redeem_online_url,
|
||||
"termsConditions": input_data.offer_terms_conditions,
|
||||
}
|
||||
|
||||
response = await client.create_post(
|
||||
post=input_data.post,
|
||||
platforms=[SocialPlatform.GOOGLE_MY_BUSINESS],
|
||||
media_urls=input_data.media_urls,
|
||||
is_video=input_data.is_video,
|
||||
schedule_date=iso_date,
|
||||
disable_comments=input_data.disable_comments,
|
||||
shorten_links=input_data.shorten_links,
|
||||
unsplash=input_data.unsplash,
|
||||
requires_approval=input_data.requires_approval,
|
||||
random_post=input_data.random_post,
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
gmb_options=gmb_options if gmb_options else None,
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
for p in response.postIds:
|
||||
yield "post", p
|
||||
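For reference, a hedged example of the `gmb_options` payload the GMB block above assembles for an event post; the values are invented, while the key names mirror the block code.

```python
# Example of the gmb_options payload built above for an event post.
# Values are invented; key names mirror the block code.
input_event = {
    "event_title": "Grand Opening",
    "event_start_date": "2024-03-15T09:00:00Z",
    "event_end_date": "2024-03-15T17:00:00Z",
}

gmb_options = {}
if (
    input_event["event_title"]
    and input_event["event_start_date"]
    and input_event["event_end_date"]
):
    gmb_options["event"] = {
        "title": input_event["event_title"],
        "startDate": input_event["event_start_date"],
        "endDate": input_event["event_end_date"],
    }

print(gmb_options)
# -> {'event': {'title': 'Grand Opening', 'startDate': '2024-03-15T09:00:00Z',
#               'endDate': '2024-03-15T17:00:00Z'}}
```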
@@ -0,0 +1,249 @@
|
||||
from typing import Any
|
||||
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._util import (
|
||||
BaseAyrshareInput,
|
||||
InstagramUserTag,
|
||||
create_ayrshare_client,
|
||||
get_profile_key,
|
||||
)
|
||||
|
||||
|
||||
class PostToInstagramBlock(Block):
|
||||
"""Block for posting to Instagram with Instagram-specific options."""
|
||||
|
||||
class Input(BaseAyrshareInput):
|
||||
"""Input schema for Instagram posts."""
|
||||
|
||||
# Override post field to include Instagram-specific information
|
||||
post: str = SchemaField(
|
||||
description="The post text (max 2,200 chars, up to 30 hashtags, 3 @mentions)",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# Override media_urls to include Instagram-specific constraints
|
||||
media_urls: list[str] = SchemaField(
|
||||
description="Optional list of media URLs. Instagram supports up to 10 images/videos in a carousel.",
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# Instagram-specific options
|
||||
is_story: bool | None = SchemaField(
|
||||
description="Whether to post as Instagram Story (24-hour expiration)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# ------- REELS OPTIONS -------
|
||||
share_reels_feed: bool | None = SchemaField(
|
||||
description="Whether Reel should appear in both Feed and Reels tabs",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
audio_name: str | None = SchemaField(
|
||||
description="Audio name for Reels (e.g., 'The Weeknd - Blinding Lights')",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
thumbnail: str | None = SchemaField(
|
||||
description="Thumbnail URL for Reel video", default=None, advanced=True
|
||||
)
|
||||
thumbnail_offset: int | None = SchemaField(
|
||||
description="Thumbnail frame offset in milliseconds (default: 0)",
|
||||
default=0,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# ------- POST OPTIONS -------
|
||||
|
||||
alt_text: list[str] = SchemaField(
|
||||
description="Alt text for each media item (up to 1,000 chars each, accessibility feature), each item in the list corresponds to a media item in the media_urls list",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
location_id: str | None = SchemaField(
|
||||
description="Facebook Page ID or name for location tagging (e.g., '7640348500' or '@guggenheimmuseum')",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
user_tags: list[dict[str, Any]] = SchemaField(
|
||||
description="List of users to tag with coordinates for images",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
collaborators: list[str] = SchemaField(
|
||||
description="Instagram usernames to invite as collaborators (max 3, public accounts only)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
auto_resize: bool | None = SchemaField(
|
||||
description="Auto-resize images to 1080x1080px for Instagram",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="89b02b96-a7cb-46f4-9900-c48b32fe1552",
|
||||
description="Post to Instagram using Ayrshare. Requires a Business or Creator Instagram Account connected with a Facebook Page",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
block_type=BlockType.AYRSHARE,
|
||||
input_schema=PostToInstagramBlock.Input,
|
||||
output_schema=PostToInstagramBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: "PostToInstagramBlock.Input",
|
||||
*,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Instagram with Instagram-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
return
|
||||
|
||||
# Validate Instagram constraints
|
||||
if len(input_data.post) > 2200:
|
||||
yield "error", f"Instagram post text exceeds 2,200 character limit ({len(input_data.post)} characters)"
|
||||
return
|
||||
|
||||
if len(input_data.media_urls) > 10:
|
||||
yield "error", "Instagram supports a maximum of 10 images/videos in a carousel"
|
||||
return
|
||||
|
||||
if len(input_data.collaborators) > 3:
|
||||
yield "error", "Instagram supports a maximum of 3 collaborators"
|
||||
return
|
||||
|
||||
# Validate that if any reel option is set, all required reel options are set
|
||||
reel_options = [
|
||||
input_data.share_reels_feed,
|
||||
input_data.audio_name,
|
||||
input_data.thumbnail,
|
||||
]
|
||||
|
||||
if any(reel_options) and not all(reel_options):
|
||||
yield "error", "When posting a reel, all reel options must be set: share_reels_feed, audio_name, and either thumbnail or thumbnail_offset"
|
||||
return
|
||||
|
||||
# Count hashtags and mentions
|
||||
hashtag_count = input_data.post.count("#")
|
||||
mention_count = input_data.post.count("@")
|
||||
|
||||
if hashtag_count > 30:
|
||||
yield "error", f"Instagram allows maximum 30 hashtags ({hashtag_count} found)"
|
||||
return
|
||||
|
||||
if mention_count > 3:
|
||||
yield "error", f"Instagram allows maximum 3 @mentions ({mention_count} found)"
|
||||
return
|
||||
|
||||
# Convert datetime to ISO format if provided
|
||||
iso_date = (
|
||||
input_data.schedule_date.isoformat() if input_data.schedule_date else None
|
||||
)
|
||||
|
||||
# Build Instagram-specific options
|
||||
instagram_options = {}
|
||||
|
||||
# Stories
|
||||
if input_data.is_story:
|
||||
instagram_options["stories"] = True
|
||||
|
||||
# Reels options
|
||||
if input_data.share_reels_feed is not None:
|
||||
instagram_options["shareReelsFeed"] = input_data.share_reels_feed
|
||||
|
||||
if input_data.audio_name:
|
||||
instagram_options["audioName"] = input_data.audio_name
|
||||
|
||||
if input_data.thumbnail:
|
||||
instagram_options["thumbNail"] = input_data.thumbnail
|
||||
elif input_data.thumbnail_offset and input_data.thumbnail_offset > 0:
|
||||
instagram_options["thumbNailOffset"] = input_data.thumbnail_offset
|
||||
|
||||
# Alt text
|
||||
if input_data.alt_text:
|
||||
# Validate alt text length
|
||||
for i, alt in enumerate(input_data.alt_text):
|
||||
if len(alt) > 1000:
|
||||
yield "error", f"Alt text {i+1} exceeds 1,000 character limit ({len(alt)} characters)"
|
||||
return
|
||||
instagram_options["altText"] = input_data.alt_text
|
||||
|
||||
# Location
|
||||
if input_data.location_id:
|
||||
instagram_options["locationId"] = input_data.location_id
|
||||
|
||||
# User tags
|
||||
if input_data.user_tags:
|
||||
user_tags_list = []
|
||||
for tag in input_data.user_tags:
|
||||
try:
|
||||
tag_obj = InstagramUserTag(**tag)
|
||||
except Exception as e:
|
||||
yield "error", f"Invalid user tag: {e}, tages need to be a dictionary with a 3 items: username (str), x (float) and y (float)"
|
||||
return
|
||||
tag_dict: dict[str, float | str] = {"username": tag_obj.username}
|
||||
if tag_obj.x is not None and tag_obj.y is not None:
|
||||
# Validate coordinates
|
||||
if not (0.0 <= tag_obj.x <= 1.0) or not (0.0 <= tag_obj.y <= 1.0):
|
||||
yield "error", f"User tag coordinates must be between 0.0 and 1.0 (user: {tag_obj.username})"
|
||||
return
|
||||
tag_dict["x"] = tag_obj.x
|
||||
tag_dict["y"] = tag_obj.y
|
||||
user_tags_list.append(tag_dict)
|
||||
instagram_options["userTags"] = user_tags_list
|
||||
|
||||
# Collaborators
|
||||
if input_data.collaborators:
|
||||
instagram_options["collaborators"] = input_data.collaborators
|
||||
|
||||
# Auto resize
|
||||
if input_data.auto_resize:
|
||||
instagram_options["autoResize"] = True
|
||||
|
||||
response = await client.create_post(
|
||||
post=input_data.post,
|
||||
platforms=[SocialPlatform.INSTAGRAM],
|
||||
media_urls=input_data.media_urls,
|
||||
is_video=input_data.is_video,
|
||||
schedule_date=iso_date,
|
||||
disable_comments=input_data.disable_comments,
|
||||
shorten_links=input_data.shorten_links,
|
||||
unsplash=input_data.unsplash,
|
||||
requires_approval=input_data.requires_approval,
|
||||
random_post=input_data.random_post,
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
instagram_options=instagram_options if instagram_options else None,
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
for p in response.postIds:
|
||||
yield "post", p
|
||||
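The user-tag validation above expects each `user_tags` entry to be a dict that `InstagramUserTag` can parse, with coordinates in the 0.0–1.0 range. A standalone sketch of the expected input shape follows; the model mirrors `InstagramUserTag` from `_util.py`, with explicit `None` defaults added so the sketch runs on its own, and the sample usernames/coordinates are invented.

```python
# Standalone sketch of the user_tags input shape the Instagram block accepts.
from typing import Optional

from pydantic import BaseModel, Field


class InstagramUserTag(BaseModel):
    username: str = Field(..., description="Instagram username (without @)")
    x: Optional[float] = Field(default=None, description="X coordinate (0.0-1.0) for image posts")
    y: Optional[float] = Field(default=None, description="Y coordinate (0.0-1.0) for image posts")


user_tags = [
    {"username": "someuser", "x": 0.25, "y": 0.75},
    {"username": "another_user", "x": 0.5, "y": 0.5},
]

for tag in user_tags:
    tag_obj = InstagramUserTag(**tag)
    if tag_obj.x is not None and tag_obj.y is not None:
        # Same coordinate check as the block's validation
        assert 0.0 <= tag_obj.x <= 1.0 and 0.0 <= tag_obj.y <= 1.0
```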
@@ -0,0 +1,222 @@
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
class PostToLinkedInBlock(Block):
|
||||
"""Block for posting to LinkedIn with LinkedIn-specific options."""
|
||||
|
||||
class Input(BaseAyrshareInput):
|
||||
"""Input schema for LinkedIn posts."""
|
||||
|
||||
# Override post field to include LinkedIn-specific information
|
||||
post: str = SchemaField(
|
||||
description="The post text (max 3,000 chars, hashtags supported with #)",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# Override media_urls to include LinkedIn-specific constraints
|
||||
media_urls: list[str] = SchemaField(
|
||||
description="Optional list of media URLs. LinkedIn supports up to 9 images, videos, or documents (PPT, PPTX, DOC, DOCX, PDF <100MB, <300 pages).",
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# LinkedIn-specific options
|
||||
visibility: str = SchemaField(
|
||||
description="Post visibility: 'public' (default), 'connections' (personal only), 'loggedin'",
|
||||
default="public",
|
||||
advanced=True,
|
||||
)
|
||||
alt_text: list[str] = SchemaField(
|
||||
description="Alt text for each image (accessibility feature, not supported for videos/documents)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
titles: list[str] = SchemaField(
|
||||
description="Title/caption for each image or video",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
document_title: str = SchemaField(
|
||||
description="Title for document posts (max 400 chars, uses filename if not specified)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
thumbnail: str = SchemaField(
|
||||
description="Thumbnail URL for video (PNG/JPG, same dimensions as video, <10MB)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
# LinkedIn targeting options (flattened from LinkedInTargeting object)
|
||||
targeting_countries: list[str] | None = SchemaField(
|
||||
description="Country codes for targeting (e.g., ['US', 'IN', 'DE', 'GB']). Requires 300+ followers in target audience.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
targeting_seniorities: list[str] | None = SchemaField(
|
||||
description="Seniority levels for targeting (e.g., ['Senior', 'VP']). Requires 300+ followers in target audience.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
targeting_degrees: list[str] | None = SchemaField(
|
||||
description="Education degrees for targeting. Requires 300+ followers in target audience.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
targeting_fields_of_study: list[str] | None = SchemaField(
|
||||
description="Fields of study for targeting. Requires 300+ followers in target audience.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
targeting_industries: list[str] | None = SchemaField(
|
||||
description="Industry categories for targeting. Requires 300+ followers in target audience.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
targeting_job_functions: list[str] | None = SchemaField(
|
||||
description="Job function categories for targeting. Requires 300+ followers in target audience.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
targeting_staff_count_ranges: list[str] | None = SchemaField(
|
||||
description="Company size ranges for targeting. Requires 300+ followers in target audience.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="589af4e4-507f-42fd-b9ac-a67ecef25811",
|
||||
description="Post to LinkedIn using Ayrshare",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
block_type=BlockType.AYRSHARE,
|
||||
input_schema=PostToLinkedInBlock.Input,
|
||||
output_schema=PostToLinkedInBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: "PostToLinkedInBlock.Input",
|
||||
*,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to LinkedIn with LinkedIn-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
return
|
||||
|
||||
# Validate LinkedIn constraints
|
||||
if len(input_data.post) > 3000:
|
||||
yield "error", f"LinkedIn post text exceeds 3,000 character limit ({len(input_data.post)} characters)"
|
||||
return
|
||||
|
||||
if len(input_data.media_urls) > 9:
|
||||
yield "error", "LinkedIn supports a maximum of 9 images/videos/documents"
|
||||
return
|
||||
|
||||
if input_data.document_title and len(input_data.document_title) > 400:
|
||||
yield "error", f"LinkedIn document title exceeds 400 character limit ({len(input_data.document_title)} characters)"
|
||||
return
|
||||
|
||||
# Validate visibility option
|
||||
valid_visibility = ["public", "connections", "loggedin"]
|
||||
if input_data.visibility not in valid_visibility:
|
||||
yield "error", f"LinkedIn visibility must be one of: {', '.join(valid_visibility)}"
|
||||
return
|
||||
|
||||
# Check for document extensions
|
||||
document_extensions = [".ppt", ".pptx", ".doc", ".docx", ".pdf"]
|
||||
has_documents = any(
|
||||
any(url.lower().endswith(ext) for ext in document_extensions)
|
||||
for url in input_data.media_urls
|
||||
)
|
||||
|
||||
# Convert datetime to ISO format if provided
|
||||
iso_date = (
|
||||
input_data.schedule_date.isoformat() if input_data.schedule_date else None
|
||||
)
|
||||
|
||||
# Build LinkedIn-specific options
|
||||
linkedin_options = {}
|
||||
|
||||
# Visibility
|
||||
if input_data.visibility != "public":
|
||||
linkedin_options["visibility"] = input_data.visibility
|
||||
|
||||
# Alt text (not supported for videos or documents)
|
||||
if input_data.alt_text and not has_documents:
|
||||
linkedin_options["altText"] = input_data.alt_text
|
||||
|
||||
# Titles/captions
|
||||
if input_data.titles:
|
||||
linkedin_options["titles"] = input_data.titles
|
||||
|
||||
# Document title
|
||||
if input_data.document_title and has_documents:
|
||||
linkedin_options["title"] = input_data.document_title
|
||||
|
||||
# Video thumbnail
|
||||
if input_data.thumbnail:
|
||||
linkedin_options["thumbNail"] = input_data.thumbnail
|
||||
|
||||
# Audience targeting (from flattened fields)
|
||||
targeting_dict = {}
|
||||
if input_data.targeting_countries:
|
||||
targeting_dict["countries"] = input_data.targeting_countries
|
||||
if input_data.targeting_seniorities:
|
||||
targeting_dict["seniorities"] = input_data.targeting_seniorities
|
||||
if input_data.targeting_degrees:
|
||||
targeting_dict["degrees"] = input_data.targeting_degrees
|
||||
if input_data.targeting_fields_of_study:
|
||||
targeting_dict["fieldsOfStudy"] = input_data.targeting_fields_of_study
|
||||
if input_data.targeting_industries:
|
||||
targeting_dict["industries"] = input_data.targeting_industries
|
||||
if input_data.targeting_job_functions:
|
||||
targeting_dict["jobFunctions"] = input_data.targeting_job_functions
|
||||
if input_data.targeting_staff_count_ranges:
|
||||
targeting_dict["staffCountRanges"] = input_data.targeting_staff_count_ranges
|
||||
|
||||
if targeting_dict:
|
||||
linkedin_options["targeting"] = targeting_dict
|
||||
|
||||
response = await client.create_post(
|
||||
post=input_data.post,
|
||||
platforms=[SocialPlatform.LINKEDIN],
|
||||
media_urls=input_data.media_urls,
|
||||
is_video=input_data.is_video,
|
||||
schedule_date=iso_date,
|
||||
disable_comments=input_data.disable_comments,
|
||||
shorten_links=input_data.shorten_links,
|
||||
unsplash=input_data.unsplash,
|
||||
requires_approval=input_data.requires_approval,
|
||||
random_post=input_data.random_post,
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
linkedin_options=linkedin_options if linkedin_options else None,
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
for p in response.postIds:
|
||||
yield "post", p
|
||||
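A hedged sketch of how the flattened `targeting_*` inputs above collapse into the single `targeting` dict attached to `linkedin_options`; the values are invented, and only three of the seven fields are shown to keep it short.

```python
# Example: flattened targeting_* inputs -> the single "targeting" dict
# attached to linkedin_options. Values are invented; keys mirror the code above.
input_values = {
    "targeting_countries": ["US", "GB"],
    "targeting_seniorities": ["Senior", "VP"],
    "targeting_industries": None,  # unset fields are simply skipped
}

targeting_dict = {}
if input_values["targeting_countries"]:
    targeting_dict["countries"] = input_values["targeting_countries"]
if input_values["targeting_seniorities"]:
    targeting_dict["seniorities"] = input_values["targeting_seniorities"]
if input_values["targeting_industries"]:
    targeting_dict["industries"] = input_values["targeting_industries"]

linkedin_options = {"targeting": targeting_dict} if targeting_dict else {}
print(linkedin_options)
# -> {'targeting': {'countries': ['US', 'GB'], 'seniorities': ['Senior', 'VP']}}
```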
@@ -0,0 +1,214 @@
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._util import (
|
||||
BaseAyrshareInput,
|
||||
PinterestCarouselOption,
|
||||
create_ayrshare_client,
|
||||
get_profile_key,
|
||||
)
|
||||
|
||||
|
||||
class PostToPinterestBlock(Block):
|
||||
"""Block for posting to Pinterest with Pinterest-specific options."""
|
||||
|
||||
class Input(BaseAyrshareInput):
|
||||
"""Input schema for Pinterest posts."""
|
||||
|
||||
# Override post field to include Pinterest-specific information
|
||||
post: str = SchemaField(
|
||||
description="Pin description (max 500 chars, links not clickable - use link field instead)",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# Override media_urls to include Pinterest-specific constraints
|
||||
media_urls: list[str] = SchemaField(
|
||||
description="Required image/video URLs. Pinterest requires at least one image. Videos need thumbnail. Up to 5 images for carousel.",
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# Pinterest-specific options
|
||||
pin_title: str = SchemaField(
|
||||
description="Pin title displayed in 'Add your title' section (max 100 chars)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
link: str = SchemaField(
|
||||
description="Clickable destination URL when users click the pin (max 2048 chars)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
board_id: str = SchemaField(
|
||||
description="Pinterest Board ID to post to (from /user/details endpoint, uses default board if not specified)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
note: str = SchemaField(
|
||||
description="Private note for the pin (only visible to you and board collaborators)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
thumbnail: str = SchemaField(
|
||||
description="Required thumbnail URL for video pins (must have valid image Content-Type)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
carousel_options: list[PinterestCarouselOption] = SchemaField(
|
||||
description="Options for each image in carousel (title, link, description per image)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
alt_text: list[str] = SchemaField(
|
||||
description="Alt text for each image/video (max 500 chars each, accessibility feature)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
disabled=True,
|
||||
id="3ca46e05-dbaa-4afb-9e95-5a429c4177e6",
|
||||
description="Post to Pinterest using Ayrshare",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
block_type=BlockType.AYRSHARE,
|
||||
input_schema=PostToPinterestBlock.Input,
|
||||
output_schema=PostToPinterestBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: "PostToPinterestBlock.Input",
|
||||
*,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Pinterest with Pinterest-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
return
|
||||
|
||||
# Validate Pinterest constraints
|
||||
if len(input_data.post) > 500:
|
||||
yield "error", f"Pinterest pin description exceeds 500 character limit ({len(input_data.post)} characters)"
|
||||
return
|
||||
|
||||
if len(input_data.pin_title) > 100:
|
||||
yield "error", f"Pinterest pin title exceeds 100 character limit ({len(input_data.pin_title)} characters)"
|
||||
return
|
||||
|
||||
if len(input_data.link) > 2048:
|
||||
yield "error", f"Pinterest link URL exceeds 2048 character limit ({len(input_data.link)} characters)"
|
||||
return
|
||||
|
||||
if len(input_data.media_urls) == 0:
|
||||
yield "error", "Pinterest requires at least one image or video"
|
||||
return
|
||||
|
||||
if len(input_data.media_urls) > 5:
|
||||
yield "error", "Pinterest supports a maximum of 5 images in a carousel"
|
||||
return
|
||||
|
||||
# Check if video is included and thumbnail is provided
|
||||
video_extensions = [".mp4", ".mov", ".avi", ".mkv", ".wmv", ".flv", ".webm"]
|
||||
has_video = any(
|
||||
any(url.lower().endswith(ext) for ext in video_extensions)
|
||||
for url in input_data.media_urls
|
||||
)
|
||||
|
||||
if (has_video or input_data.is_video) and not input_data.thumbnail:
|
||||
yield "error", "Pinterest video pins require a thumbnail URL"
|
||||
return
|
||||
|
||||
# Validate alt text length
|
||||
for i, alt in enumerate(input_data.alt_text):
|
||||
if len(alt) > 500:
|
||||
yield "error", f"Pinterest alt text {i+1} exceeds 500 character limit ({len(alt)} characters)"
|
||||
return
|
||||
|
||||
# Convert datetime to ISO format if provided
|
||||
iso_date = (
|
||||
input_data.schedule_date.isoformat() if input_data.schedule_date else None
|
||||
)
|
||||
|
||||
# Build Pinterest-specific options
|
||||
pinterest_options = {}
|
||||
|
||||
# Pin title
|
||||
if input_data.pin_title:
|
||||
pinterest_options["title"] = input_data.pin_title
|
||||
|
||||
# Clickable link
|
||||
if input_data.link:
|
||||
pinterest_options["link"] = input_data.link
|
||||
|
||||
# Board ID
|
||||
if input_data.board_id:
|
||||
pinterest_options["boardId"] = input_data.board_id
|
||||
|
||||
# Private note
|
||||
if input_data.note:
|
||||
pinterest_options["note"] = input_data.note
|
||||
|
||||
# Video thumbnail
|
||||
if input_data.thumbnail:
|
||||
pinterest_options["thumbNail"] = input_data.thumbnail
|
||||
|
||||
# Carousel options
|
||||
if input_data.carousel_options:
|
||||
carousel_list = []
|
||||
for option in input_data.carousel_options:
|
||||
carousel_dict = {}
|
||||
if option.title:
|
||||
carousel_dict["title"] = option.title
|
||||
if option.link:
|
||||
carousel_dict["link"] = option.link
|
||||
if option.description:
|
||||
carousel_dict["description"] = option.description
|
||||
if carousel_dict: # Only add if not empty
|
||||
carousel_list.append(carousel_dict)
|
||||
if carousel_list:
|
||||
pinterest_options["carouselOptions"] = carousel_list
|
||||
|
||||
# Alt text
|
||||
if input_data.alt_text:
|
||||
pinterest_options["altText"] = input_data.alt_text
|
||||
|
||||
response = await client.create_post(
|
||||
post=input_data.post,
|
||||
platforms=[SocialPlatform.PINTEREST],
|
||||
media_urls=input_data.media_urls,
|
||||
is_video=input_data.is_video,
|
||||
schedule_date=iso_date,
|
||||
disable_comments=input_data.disable_comments,
|
||||
shorten_links=input_data.shorten_links,
|
||||
unsplash=input_data.unsplash,
|
||||
requires_approval=input_data.requires_approval,
|
||||
random_post=input_data.random_post,
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
pinterest_options=pinterest_options if pinterest_options else None,
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
for p in response.postIds:
|
||||
yield "post", p
|
||||
@@ -0,0 +1,69 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaOutput,
    BlockType,
    SchemaField,
)

from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key


class PostToRedditBlock(Block):
    """Block for posting to Reddit."""

    class Input(BaseAyrshareInput):
        """Input schema for Reddit posts."""

        pass  # Uses all base fields

    class Output(BlockSchemaOutput):
        post_result: PostResponse = SchemaField(description="The result of the post")
        post: PostIds = SchemaField(description="The result of the post")

    def __init__(self):
        super().__init__(
            disabled=True,
            id="c7733580-3c72-483e-8e47-a8d58754d853",
            description="Post to Reddit using Ayrshare",
            categories={BlockCategory.SOCIAL},
            block_type=BlockType.AYRSHARE,
            input_schema=PostToRedditBlock.Input,
            output_schema=PostToRedditBlock.Output,
        )

    async def run(
        self, input_data: "PostToRedditBlock.Input", *, user_id: str, **kwargs
    ) -> BlockOutput:
        profile_key = await get_profile_key(user_id)
        if not profile_key:
            yield "error", "Please link a social account via Ayrshare"
            return
        client = create_ayrshare_client()
        if not client:
            yield "error", "Ayrshare integration is not configured."
            return
        iso_date = (
            input_data.schedule_date.isoformat() if input_data.schedule_date else None
        )
        response = await client.create_post(
            post=input_data.post,
            platforms=[SocialPlatform.REDDIT],
            media_urls=input_data.media_urls,
            is_video=input_data.is_video,
            schedule_date=iso_date,
            disable_comments=input_data.disable_comments,
            shorten_links=input_data.shorten_links,
            unsplash=input_data.unsplash,
            requires_approval=input_data.requires_approval,
            random_post=input_data.random_post,
            random_media_url=input_data.random_media_url,
            notes=input_data.notes,
            profile_key=profile_key.get_secret_value(),
        )
        yield "post_result", response
        if response.postIds:
            for p in response.postIds:
                yield "post", p
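Every `run` method in these blocks is an async generator yielding `(output_name, value)` pairs, with `"error"` acting as an early exit. A hedged sketch of how a caller might drive one outside the platform's executor, assuming a stub input object; the real executor also handles credentials and output routing:

```python
import asyncio

async def demo(block, input_data, user_id: str):
    # Each Ayrshare block yields ("output_name", value) tuples;
    # an "error" output carries a user-facing message and ends the run.
    async for name, value in block.run(input_data, user_id=user_id):
        if name == "error":
            print("failed:", value)
            return
        print(name, "->", value)

# asyncio.run(demo(PostToRedditBlock(), some_input, "user-123"))  # illustrative only
```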
@@ -0,0 +1,129 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaOutput,
    BlockType,
    SchemaField,
)

from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key


class PostToSnapchatBlock(Block):
    """Block for posting to Snapchat with Snapchat-specific options."""

    class Input(BaseAyrshareInput):
        """Input schema for Snapchat posts."""

        # Override post field to include Snapchat-specific information
        post: str = SchemaField(
            description="The post text (optional for video-only content)",
            default="",
            advanced=False,
        )

        # Override media_urls to include Snapchat-specific constraints
        media_urls: list[str] = SchemaField(
            description="Required video URL for Snapchat posts. Snapchat only supports video content.",
            default_factory=list,
            advanced=False,
        )

        # Snapchat-specific options
        story_type: str = SchemaField(
            description="Type of Snapchat content: 'story' (24-hour Stories), 'saved_story' (Saved Stories), or 'spotlight' (Spotlight posts)",
            default="story",
            advanced=True,
        )
        video_thumbnail: str = SchemaField(
            description="Thumbnail URL for video content (optional, auto-generated if not provided)",
            default="",
            advanced=True,
        )

    class Output(BlockSchemaOutput):
        post_result: PostResponse = SchemaField(description="The result of the post")
        post: PostIds = SchemaField(description="The result of the post")

    def __init__(self):
        super().__init__(
            disabled=True,
            id="a9d7f854-2c83-4e96-b3a1-7f2e9c5d4b8e",
            description="Post to Snapchat using Ayrshare",
            categories={BlockCategory.SOCIAL},
            block_type=BlockType.AYRSHARE,
            input_schema=PostToSnapchatBlock.Input,
            output_schema=PostToSnapchatBlock.Output,
        )

    async def run(
        self,
        input_data: "PostToSnapchatBlock.Input",
        *,
        user_id: str,
        **kwargs,
    ) -> BlockOutput:
        """Post to Snapchat with Snapchat-specific options."""
        profile_key = await get_profile_key(user_id)
        if not profile_key:
            yield "error", "Please link a social account via Ayrshare"
            return

        client = create_ayrshare_client()
        if not client:
            yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
            return

        # Validate Snapchat constraints
        if not input_data.media_urls:
            yield "error", "Snapchat requires at least one video URL"
            return

        if len(input_data.media_urls) > 1:
            yield "error", "Snapchat supports only one video per post"
            return

        # Validate story type
        valid_story_types = ["story", "saved_story", "spotlight"]
        if input_data.story_type not in valid_story_types:
            yield "error", f"Snapchat story type must be one of: {', '.join(valid_story_types)}"
            return

        # Convert datetime to ISO format if provided
        iso_date = (
            input_data.schedule_date.isoformat() if input_data.schedule_date else None
        )

        # Build Snapchat-specific options
        snapchat_options = {}

        # Story type
        if input_data.story_type != "story":
            snapchat_options["storyType"] = input_data.story_type

        # Video thumbnail
        if input_data.video_thumbnail:
            snapchat_options["videoThumbnail"] = input_data.video_thumbnail

        response = await client.create_post(
            post=input_data.post,
            platforms=[SocialPlatform.SNAPCHAT],
            media_urls=input_data.media_urls,
            is_video=True,  # Snapchat only supports video
            schedule_date=iso_date,
            disable_comments=input_data.disable_comments,
            shorten_links=input_data.shorten_links,
            unsplash=input_data.unsplash,
            requires_approval=input_data.requires_approval,
            random_post=input_data.random_post,
            random_media_url=input_data.random_media_url,
            notes=input_data.notes,
            snapchat_options=snapchat_options if snapchat_options else None,
            profile_key=profile_key.get_secret_value(),
        )
        yield "post_result", response
        if response.postIds:
            for p in response.postIds:
                yield "post", p
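The Snapchat block keeps `story_type` as a plain string checked against a list at run time, whereas the TikTok and YouTube blocks below use `str`-based Enums so invalid values are rejected at the schema level. A hedged sketch of that alternative for comparison; the enum name here is illustrative, not part of the codebase:

```python
from enum import Enum


class SnapchatStoryType(str, Enum):  # illustrative only; the block uses a str field
    STORY = "story"
    SAVED_STORY = "saved_story"
    SPOTLIGHT = "spotlight"


def validate_story_type(value: str) -> SnapchatStoryType:
    # Equivalent to the whitelist check in run(), but reusable and typed.
    try:
        return SnapchatStoryType(value)
    except ValueError:
        raise ValueError(
            "Snapchat story type must be one of: "
            + ", ".join(t.value for t in SnapchatStoryType)
        )
```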
@@ -0,0 +1,116 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaOutput,
    BlockType,
    SchemaField,
)

from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key


class PostToTelegramBlock(Block):
    """Block for posting to Telegram with Telegram-specific options."""

    class Input(BaseAyrshareInput):
        """Input schema for Telegram posts."""

        # Override post field to include Telegram-specific information
        post: str = SchemaField(
            description="The post text (empty string allowed). Use @handle to mention other Telegram users.",
            default="",
            advanced=False,
        )

        # Override media_urls to include Telegram-specific constraints
        media_urls: list[str] = SchemaField(
            description="Optional list of media URLs. For animated GIFs, only one URL is allowed. Telegram will auto-preview links unless image/video is included.",
            default_factory=list,
            advanced=False,
        )

        # Override is_video to include GIF-specific information
        is_video: bool = SchemaField(
            description="Whether the media is a video. Set to true for animated GIFs that don't end in .gif/.GIF extension.",
            default=False,
            advanced=True,
        )

    class Output(BlockSchemaOutput):
        post_result: PostResponse = SchemaField(description="The result of the post")
        post: PostIds = SchemaField(description="The result of the post")

    def __init__(self):
        super().__init__(
            disabled=True,
            id="47bc74eb-4af2-452c-b933-af377c7287df",
            description="Post to Telegram using Ayrshare",
            categories={BlockCategory.SOCIAL},
            block_type=BlockType.AYRSHARE,
            input_schema=PostToTelegramBlock.Input,
            output_schema=PostToTelegramBlock.Output,
        )

    async def run(
        self,
        input_data: "PostToTelegramBlock.Input",
        *,
        user_id: str,
        **kwargs,
    ) -> BlockOutput:
        """Post to Telegram with Telegram-specific validation."""
        profile_key = await get_profile_key(user_id)
        if not profile_key:
            yield "error", "Please link a social account via Ayrshare"
            return

        client = create_ayrshare_client()
        if not client:
            yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
            return

        # Validate Telegram constraints
        # Check for animated GIFs - only one URL allowed
        gif_extensions = [".gif", ".GIF"]
        has_gif = any(
            any(url.endswith(ext) for ext in gif_extensions)
            for url in input_data.media_urls
        )

        if has_gif and len(input_data.media_urls) > 1:
            yield "error", "Telegram animated GIFs support only one URL per post"
            return

        # Auto-detect if we need to set is_video for GIFs without proper extension
        detected_is_video = input_data.is_video
        if input_data.media_urls and not has_gif and not input_data.is_video:
            # Check if this might be a GIF without proper extension
            # This is just informational - user needs to set is_video manually
            pass

        # Convert datetime to ISO format if provided
        iso_date = (
            input_data.schedule_date.isoformat() if input_data.schedule_date else None
        )

        response = await client.create_post(
            post=input_data.post,
            platforms=[SocialPlatform.TELEGRAM],
            media_urls=input_data.media_urls,
            is_video=detected_is_video,
            schedule_date=iso_date,
            disable_comments=input_data.disable_comments,
            shorten_links=input_data.shorten_links,
            unsplash=input_data.unsplash,
            requires_approval=input_data.requires_approval,
            random_post=input_data.random_post,
            random_media_url=input_data.random_media_url,
            notes=input_data.notes,
            profile_key=profile_key.get_secret_value(),
        )
        yield "post_result", response
        if response.postIds:
            for p in response.postIds:
                yield "post", p
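The GIF check above only matches the literal suffixes `.gif` and `.GIF`, so a mixed-case URL such as `clip.Gif` slips through. A minimal, hedged sketch of a case-insensitive variant, shown only as an alternative to the list-based check:

```python
def is_gif_url(url: str) -> bool:
    # Case-insensitive version of the suffix check in run(); URLs with query
    # strings after the extension are still not detected, same as the block.
    return url.lower().endswith(".gif")


urls = ["https://example.com/clip.Gif", "https://example.com/photo.jpg"]
has_gif = any(is_gif_url(u) for u in urls)  # True
```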
@@ -0,0 +1,111 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaOutput,
    BlockType,
    SchemaField,
)

from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key


class PostToThreadsBlock(Block):
    """Block for posting to Threads with Threads-specific options."""

    class Input(BaseAyrshareInput):
        """Input schema for Threads posts."""

        # Override post field to include Threads-specific information
        post: str = SchemaField(
            description="The post text (max 500 chars, empty string allowed). Only 1 hashtag allowed. Use @handle to mention users.",
            default="",
            advanced=False,
        )

        # Override media_urls to include Threads-specific constraints
        media_urls: list[str] = SchemaField(
            description="Optional list of media URLs. Supports up to 20 images/videos in a carousel. Auto-preview links unless media is included.",
            default_factory=list,
            advanced=False,
        )

    class Output(BlockSchemaOutput):
        post_result: PostResponse = SchemaField(description="The result of the post")
        post: PostIds = SchemaField(description="The result of the post")

    def __init__(self):
        super().__init__(
            disabled=True,
            id="f8c3b2e1-9d4a-4e5f-8c7b-6a9e8d2f1c3b",
            description="Post to Threads using Ayrshare",
            categories={BlockCategory.SOCIAL},
            block_type=BlockType.AYRSHARE,
            input_schema=PostToThreadsBlock.Input,
            output_schema=PostToThreadsBlock.Output,
        )

    async def run(
        self,
        input_data: "PostToThreadsBlock.Input",
        *,
        user_id: str,
        **kwargs,
    ) -> BlockOutput:
        """Post to Threads with Threads-specific validation."""
        profile_key = await get_profile_key(user_id)
        if not profile_key:
            yield "error", "Please link a social account via Ayrshare"
            return

        client = create_ayrshare_client()
        if not client:
            yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
            return

        # Validate Threads constraints
        if len(input_data.post) > 500:
            yield "error", f"Threads post text exceeds 500 character limit ({len(input_data.post)} characters)"
            return

        if len(input_data.media_urls) > 20:
            yield "error", "Threads supports a maximum of 20 images/videos in a carousel"
            return

        # Count hashtags (only 1 allowed)
        hashtag_count = input_data.post.count("#")
        if hashtag_count > 1:
            yield "error", f"Threads allows only 1 hashtag per post ({hashtag_count} found)"
            return

        # Convert datetime to ISO format if provided
        iso_date = (
            input_data.schedule_date.isoformat() if input_data.schedule_date else None
        )

        # Build Threads-specific options
        threads_options = {}
        # Note: Based on the documentation, Threads doesn't seem to have specific options
        # beyond the standard ones. The main constraints are validation-based.

        response = await client.create_post(
            post=input_data.post,
            platforms=[SocialPlatform.THREADS],
            media_urls=input_data.media_urls,
            is_video=input_data.is_video,
            schedule_date=iso_date,
            disable_comments=input_data.disable_comments,
            shorten_links=input_data.shorten_links,
            unsplash=input_data.unsplash,
            requires_approval=input_data.requires_approval,
            random_post=input_data.random_post,
            random_media_url=input_data.random_media_url,
            notes=input_data.notes,
            threads_options=threads_options if threads_options else None,
            profile_key=profile_key.get_secret_value(),
        )
        yield "post_result", response
        if response.postIds:
            for p in response.postIds:
                yield "post", p
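The hashtag limit above is enforced with `post.count("#")`, which also counts `#` characters inside URL fragments. A hedged regex sketch of a stricter count, offered only as an alternative to the simple character count:

```python
import re

HASHTAG = re.compile(r"(?<!\S)#\w+")  # "#" starting a token, followed by word chars


def count_hashtags(text: str) -> int:
    # Unlike text.count("#"), this ignores fragments such as
    # "https://example.com/page#section".
    return len(HASHTAG.findall(text))


count_hashtags("New drop #launch https://shop.example/#deals")  # -> 1
```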
@@ -0,0 +1,243 @@
|
||||
from enum import Enum
|
||||
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
class TikTokVisibility(str, Enum):
|
||||
PUBLIC = "public"
|
||||
PRIVATE = "private"
|
||||
FOLLOWERS = "followers"
|
||||
|
||||
|
||||
class PostToTikTokBlock(Block):
|
||||
"""Block for posting to TikTok with TikTok-specific options."""
|
||||
|
||||
class Input(BaseAyrshareInput):
|
||||
"""Input schema for TikTok posts."""
|
||||
|
||||
# Override post field to include TikTok-specific information
|
||||
post: str = SchemaField(
|
||||
description="The post text (max 2,200 chars, empty string allowed). Use @handle to mention users. Line breaks will be ignored.",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# Override media_urls to include TikTok-specific constraints
|
||||
media_urls: list[str] = SchemaField(
|
||||
description="Required media URLs. Either 1 video OR up to 35 images (JPG/JPEG/WEBP only). Cannot mix video and images.",
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# TikTok-specific options
|
||||
auto_add_music: bool = SchemaField(
|
||||
description="Whether to automatically add recommended music to the post. If you set this field to true, you can change the music later in the TikTok app.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
disable_comments: bool = SchemaField(
|
||||
description="Disable comments on the published post",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
disable_duet: bool = SchemaField(
|
||||
description="Disable duets on published video (video only)",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
disable_stitch: bool = SchemaField(
|
||||
description="Disable stitch on published video (video only)",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
is_ai_generated: bool = SchemaField(
|
||||
description="If you enable the toggle, your video will be labeled as “Creator labeled as AI-generated” once posted and can’t be changed. The “Creator labeled as AI-generated” label indicates that the content was completely AI-generated or significantly edited with AI.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
is_branded_content: bool = SchemaField(
|
||||
description="Whether to enable the Branded Content toggle. If this field is set to true, the video will be labeled as Branded Content, indicating you are in a paid partnership with a brand. A “Paid partnership” label will be attached to the video.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
is_brand_organic: bool = SchemaField(
|
||||
description="Whether to enable the Brand Organic Content toggle. If this field is set to true, the video will be labeled as Brand Organic Content, indicating you are promoting yourself or your own business. A “Promotional content” label will be attached to the video.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
image_cover_index: int = SchemaField(
|
||||
description="Index of image to use as cover (0-based, image posts only)",
|
||||
default=0,
|
||||
advanced=True,
|
||||
)
|
||||
title: str = SchemaField(
|
||||
description="Title for image posts", default="", advanced=True
|
||||
)
|
||||
thumbnail_offset: int = SchemaField(
|
||||
description="Video thumbnail frame offset in milliseconds (video only)",
|
||||
default=0,
|
||||
advanced=True,
|
||||
)
|
||||
        visibility: TikTokVisibility = SchemaField(
            description="Post visibility: 'public', 'private', or 'followers'",
            default=TikTokVisibility.PUBLIC,
            advanced=True,
        )
|
||||
draft: bool = SchemaField(
|
||||
description="Create as draft post (video only)",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="7faf4b27-96b0-4f05-bf64-e0de54ae74e1",
|
||||
description="Post to TikTok using Ayrshare",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
block_type=BlockType.AYRSHARE,
|
||||
input_schema=PostToTikTokBlock.Input,
|
||||
output_schema=PostToTikTokBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: "PostToTikTokBlock.Input", *, user_id: str, **kwargs
|
||||
) -> BlockOutput:
|
||||
"""Post to TikTok with TikTok-specific validation and options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
return
|
||||
|
||||
# Validate TikTok constraints
|
||||
if len(input_data.post) > 2200:
|
||||
yield "error", f"TikTok post text exceeds 2,200 character limit ({len(input_data.post)} characters)"
|
||||
return
|
||||
|
||||
if not input_data.media_urls:
|
||||
yield "error", "TikTok requires at least one media URL (either 1 video or up to 35 images)"
|
||||
return
|
||||
|
||||
# Check for video vs image constraints
|
||||
video_extensions = [".mp4", ".mov", ".avi", ".mkv", ".wmv", ".flv", ".webm"]
|
||||
image_extensions = [".jpg", ".jpeg", ".webp"]
|
||||
|
||||
has_video = input_data.is_video or any(
|
||||
any(url.lower().endswith(ext) for ext in video_extensions)
|
||||
for url in input_data.media_urls
|
||||
)
|
||||
|
||||
has_images = any(
|
||||
any(url.lower().endswith(ext) for ext in image_extensions)
|
||||
for url in input_data.media_urls
|
||||
)
|
||||
|
||||
if has_video and has_images:
|
||||
yield "error", "TikTok does not support mixing video and images in the same post"
|
||||
return
|
||||
|
||||
if has_video and len(input_data.media_urls) > 1:
|
||||
yield "error", "TikTok supports only 1 video per post"
|
||||
return
|
||||
|
||||
if has_images and len(input_data.media_urls) > 35:
|
||||
yield "error", "TikTok supports a maximum of 35 images per post"
|
||||
return
|
||||
|
||||
# Validate image cover index
|
||||
if has_images and input_data.image_cover_index >= len(input_data.media_urls):
|
||||
yield "error", f"Image cover index {input_data.image_cover_index} is out of range (max: {len(input_data.media_urls) - 1})"
|
||||
return
|
||||
|
||||
# Check for PNG files (not supported)
|
||||
has_png = any(url.lower().endswith(".png") for url in input_data.media_urls)
|
||||
if has_png:
|
||||
yield "error", "TikTok does not support PNG files. Please use JPG, JPEG, or WEBP for images."
|
||||
return
|
||||
|
||||
# Convert datetime to ISO format if provided
|
||||
iso_date = (
|
||||
input_data.schedule_date.isoformat() if input_data.schedule_date else None
|
||||
)
|
||||
|
||||
# Build TikTok-specific options
|
||||
tiktok_options = {}
|
||||
|
||||
# Common options
|
||||
if input_data.auto_add_music and has_images:
|
||||
tiktok_options["autoAddMusic"] = True
|
||||
|
||||
if input_data.disable_comments:
|
||||
tiktok_options["disableComments"] = True
|
||||
|
||||
if input_data.is_branded_content:
|
||||
tiktok_options["isBrandedContent"] = True
|
||||
|
||||
if input_data.is_brand_organic:
|
||||
tiktok_options["isBrandOrganic"] = True
|
||||
|
||||
# Video-specific options
|
||||
if has_video:
|
||||
if input_data.disable_duet:
|
||||
tiktok_options["disableDuet"] = True
|
||||
|
||||
if input_data.disable_stitch:
|
||||
tiktok_options["disableStitch"] = True
|
||||
|
||||
if input_data.is_ai_generated:
|
||||
tiktok_options["isAIGenerated"] = True
|
||||
|
||||
if input_data.thumbnail_offset > 0:
|
||||
tiktok_options["thumbNailOffset"] = input_data.thumbnail_offset
|
||||
|
||||
if input_data.draft:
|
||||
tiktok_options["draft"] = True
|
||||
|
||||
# Image-specific options
|
||||
if has_images:
|
||||
if input_data.image_cover_index > 0:
|
||||
tiktok_options["imageCoverIndex"] = input_data.image_cover_index
|
||||
|
||||
if input_data.title:
|
||||
tiktok_options["title"] = input_data.title
|
||||
|
||||
if input_data.visibility != TikTokVisibility.PUBLIC:
|
||||
tiktok_options["visibility"] = input_data.visibility.value
|
||||
|
||||
response = await client.create_post(
|
||||
post=input_data.post,
|
||||
platforms=[SocialPlatform.TIKTOK],
|
||||
media_urls=input_data.media_urls,
|
||||
is_video=has_video,
|
||||
schedule_date=iso_date,
|
||||
disable_comments=input_data.disable_comments,
|
||||
shorten_links=input_data.shorten_links,
|
||||
unsplash=input_data.unsplash,
|
||||
requires_approval=input_data.requires_approval,
|
||||
random_post=input_data.random_post,
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
tiktok_options=tiktok_options if tiktok_options else None,
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
for p in response.postIds:
|
||||
yield "post", p
|
||||
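The TikTok block decides between the video and image code paths by inspecting URL extensions plus the explicit `is_video` flag. A hedged, self-contained sketch of that classification so the mixed-media and PNG checks above are easier to follow; URLs with query strings would not match, which is a limitation the block shares:

```python
VIDEO_EXTS = (".mp4", ".mov", ".avi", ".mkv", ".wmv", ".flv", ".webm")
IMAGE_EXTS = (".jpg", ".jpeg", ".webp")


def classify_media(urls: list[str], is_video_flag: bool = False) -> tuple[bool, bool]:
    """Return (has_video, has_images) the way the block does: by extension,
    with the explicit is_video flag forcing the video path."""
    has_video = is_video_flag or any(u.lower().endswith(VIDEO_EXTS) for u in urls)
    has_images = any(u.lower().endswith(IMAGE_EXTS) for u in urls)
    return has_video, has_images


classify_media(["https://cdn.example/a.JPG", "https://cdn.example/b.webp"])
# -> (False, True); adding an .mp4 to that list would trip the mixed-media error.
```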
autogpt_platform/backend/backend/blocks/ayrshare/post_to_x.py (new file, 241 lines)
@@ -0,0 +1,241 @@
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
class PostToXBlock(Block):
|
||||
"""Block for posting to X / Twitter with Twitter-specific options."""
|
||||
|
||||
class Input(BaseAyrshareInput):
|
||||
"""Input schema for X / Twitter posts."""
|
||||
|
||||
# Override post field to include X-specific information
|
||||
post: str = SchemaField(
|
||||
description="The post text (max 280 chars, up to 25,000 for Premium users). Use @handle to mention users. Use \\n\\n for thread breaks.",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# Override media_urls to include X-specific constraints
|
||||
media_urls: list[str] = SchemaField(
|
||||
description="Optional list of media URLs. X supports up to 4 images or videos per tweet. Auto-preview links unless media is included.",
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# X-specific options
|
||||
reply_to_id: str | None = SchemaField(
|
||||
description="ID of the tweet to reply to",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
quote_tweet_id: str | None = SchemaField(
|
||||
description="ID of the tweet to quote (low-level Tweet ID)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
poll_options: list[str] = SchemaField(
|
||||
description="Poll options (2-4 choices)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
poll_duration: int = SchemaField(
|
||||
description="Poll duration in minutes (1-10080)",
|
||||
default=1440,
|
||||
advanced=True,
|
||||
)
|
||||
alt_text: list[str] = SchemaField(
|
||||
description="Alt text for each image (max 1,000 chars each, not supported for videos)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
is_thread: bool = SchemaField(
|
||||
description="Whether to automatically break post into thread based on line breaks",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
thread_number: bool = SchemaField(
|
||||
description="Add thread numbers (1/n format) to each thread post",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
thread_media_urls: list[str] = SchemaField(
|
||||
description="Media URLs for thread posts (one per thread, use 'null' to skip)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
long_post: bool = SchemaField(
|
||||
description="Force long form post (requires Premium X account)",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
long_video: bool = SchemaField(
|
||||
description="Enable long video upload (requires approval and Business/Enterprise plan)",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
subtitle_url: str = SchemaField(
|
||||
description="URL to SRT subtitle file for videos (must be HTTPS and end in .srt)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
subtitle_language: str = SchemaField(
|
||||
description="Language code for subtitles (default: 'en')",
|
||||
default="en",
|
||||
advanced=True,
|
||||
)
|
||||
subtitle_name: str = SchemaField(
|
||||
description="Name of caption track (max 150 chars, default: 'English')",
|
||||
default="English",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="9e8f844e-b4a5-4b25-80f2-9e1dd7d67625",
|
||||
description="Post to X / Twitter using Ayrshare",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
block_type=BlockType.AYRSHARE,
|
||||
input_schema=PostToXBlock.Input,
|
||||
output_schema=PostToXBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: "PostToXBlock.Input",
|
||||
*,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to X / Twitter with enhanced X-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
return
|
||||
|
||||
# Validate X constraints
|
||||
if not input_data.long_post and len(input_data.post) > 280:
|
||||
yield "error", f"X post text exceeds 280 character limit ({len(input_data.post)} characters). Enable 'long_post' for Premium accounts."
|
||||
return
|
||||
|
||||
if input_data.long_post and len(input_data.post) > 25000:
|
||||
yield "error", f"X long post text exceeds 25,000 character limit ({len(input_data.post)} characters)"
|
||||
return
|
||||
|
||||
if len(input_data.media_urls) > 4:
|
||||
yield "error", "X supports a maximum of 4 images or videos per tweet"
|
||||
return
|
||||
|
||||
# Validate poll options
|
||||
if input_data.poll_options:
|
||||
if len(input_data.poll_options) < 2 or len(input_data.poll_options) > 4:
|
||||
yield "error", "X polls require 2-4 options"
|
||||
return
|
||||
|
||||
if input_data.poll_duration < 1 or input_data.poll_duration > 10080:
|
||||
yield "error", "X poll duration must be between 1 and 10,080 minutes (7 days)"
|
||||
return
|
||||
|
||||
# Validate alt text
|
||||
if input_data.alt_text:
|
||||
for i, alt in enumerate(input_data.alt_text):
|
||||
if len(alt) > 1000:
|
||||
yield "error", f"X alt text {i+1} exceeds 1,000 character limit ({len(alt)} characters)"
|
||||
return
|
||||
|
||||
# Validate subtitle settings
|
||||
if input_data.subtitle_url:
|
||||
if not input_data.subtitle_url.startswith(
|
||||
"https://"
|
||||
) or not input_data.subtitle_url.endswith(".srt"):
|
||||
yield "error", "Subtitle URL must start with https:// and end with .srt"
|
||||
return
|
||||
|
||||
if len(input_data.subtitle_name) > 150:
|
||||
yield "error", f"Subtitle name exceeds 150 character limit ({len(input_data.subtitle_name)} characters)"
|
||||
return
|
||||
|
||||
# Convert datetime to ISO format if provided
|
||||
iso_date = (
|
||||
input_data.schedule_date.isoformat() if input_data.schedule_date else None
|
||||
)
|
||||
|
||||
# Build X-specific options
|
||||
twitter_options = {}
|
||||
|
||||
# Basic options
|
||||
if input_data.reply_to_id:
|
||||
twitter_options["replyToId"] = input_data.reply_to_id
|
||||
|
||||
if input_data.quote_tweet_id:
|
||||
twitter_options["quoteTweetId"] = input_data.quote_tweet_id
|
||||
|
||||
if input_data.long_post:
|
||||
twitter_options["longPost"] = True
|
||||
|
||||
if input_data.long_video:
|
||||
twitter_options["longVideo"] = True
|
||||
|
||||
# Poll options
|
||||
if input_data.poll_options:
|
||||
twitter_options["poll"] = {
|
||||
"duration": input_data.poll_duration,
|
||||
"options": input_data.poll_options,
|
||||
}
|
||||
|
||||
# Alt text for images
|
||||
if input_data.alt_text:
|
||||
twitter_options["altText"] = input_data.alt_text
|
||||
|
||||
# Thread options
|
||||
if input_data.is_thread:
|
||||
twitter_options["thread"] = True
|
||||
|
||||
if input_data.thread_number:
|
||||
twitter_options["threadNumber"] = True
|
||||
|
||||
if input_data.thread_media_urls:
|
||||
twitter_options["mediaUrls"] = input_data.thread_media_urls
|
||||
|
||||
# Video subtitle options
|
||||
if input_data.subtitle_url:
|
||||
twitter_options["subTitleUrl"] = input_data.subtitle_url
|
||||
twitter_options["subTitleLanguage"] = input_data.subtitle_language
|
||||
twitter_options["subTitleName"] = input_data.subtitle_name
|
||||
|
||||
response = await client.create_post(
|
||||
post=input_data.post,
|
||||
platforms=[SocialPlatform.TWITTER],
|
||||
media_urls=input_data.media_urls,
|
||||
is_video=input_data.is_video,
|
||||
schedule_date=iso_date,
|
||||
disable_comments=input_data.disable_comments,
|
||||
shorten_links=input_data.shorten_links,
|
||||
unsplash=input_data.unsplash,
|
||||
requires_approval=input_data.requires_approval,
|
||||
random_post=input_data.random_post,
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
twitter_options=twitter_options if twitter_options else None,
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
for p in response.postIds:
|
||||
yield "post", p
|
||||
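The X block nests poll settings under `twitter_options["poll"]` after range-checking the option count and duration. A hedged sketch of that payload construction in isolation; the key names mirror the code above, the values are illustrative:

```python
def build_poll(options: list[str], duration_minutes: int = 1440) -> dict:
    # Same constraints the block enforces before posting.
    if not 2 <= len(options) <= 4:
        raise ValueError("X polls require 2-4 options")
    if not 1 <= duration_minutes <= 10080:
        raise ValueError("X poll duration must be between 1 and 10,080 minutes")
    return {"poll": {"duration": duration_minutes, "options": options}}


build_poll(["Yes", "No"], 60)
# -> {"poll": {"duration": 60, "options": ["Yes", "No"]}}
```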
@@ -0,0 +1,310 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
class YouTubeVisibility(str, Enum):
|
||||
PRIVATE = "private"
|
||||
PUBLIC = "public"
|
||||
UNLISTED = "unlisted"
|
||||
|
||||
|
||||
class PostToYouTubeBlock(Block):
|
||||
"""Block for posting to YouTube with YouTube-specific options."""
|
||||
|
||||
class Input(BaseAyrshareInput):
|
||||
"""Input schema for YouTube posts."""
|
||||
|
||||
# Override post field to include YouTube-specific information
|
||||
post: str = SchemaField(
|
||||
description="Video description (max 5,000 chars, empty string allowed). Cannot contain < or > characters.",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# Override media_urls to include YouTube-specific constraints
|
||||
media_urls: list[str] = SchemaField(
|
||||
description="Required video URL. YouTube only supports 1 video per post.",
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# YouTube-specific required options
|
||||
title: str = SchemaField(
|
||||
description="Video title (max 100 chars, required). Cannot contain < or > characters.",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# YouTube-specific optional options
|
||||
        visibility: YouTubeVisibility = SchemaField(
            description="Video visibility: 'private' (default), 'public', or 'unlisted'",
            default=YouTubeVisibility.PRIVATE,
            advanced=False,
        )
|
||||
thumbnail: str | None = SchemaField(
|
||||
description="Thumbnail URL (JPEG/PNG under 2MB, must end in .png/.jpg/.jpeg). Requires phone verification.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
playlist_id: str | None = SchemaField(
|
||||
description="Playlist ID to add video (user must own playlist)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
tags: list[str] | None = SchemaField(
|
||||
description="Video tags (min 2 chars each, max 500 chars total)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
made_for_kids: bool | None = SchemaField(
|
||||
description="Self-declared kids content", default=None, advanced=True
|
||||
)
|
||||
is_shorts: bool | None = SchemaField(
|
||||
description="Post as YouTube Short (max 3 minutes, adds #shorts)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
notify_subscribers: bool | None = SchemaField(
|
||||
description="Send notification to subscribers", default=None, advanced=True
|
||||
)
|
||||
category_id: int | None = SchemaField(
|
||||
description="Video category ID (e.g., 24 = Entertainment)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
contains_synthetic_media: bool | None = SchemaField(
|
||||
description="Disclose realistic AI/synthetic content",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
publish_at: str | None = SchemaField(
|
||||
description="UTC publish time (YouTube controlled, format: 2022-10-08T21:18:36Z)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
# YouTube targeting options (flattened from YouTubeTargeting object)
|
||||
targeting_block_countries: list[str] | None = SchemaField(
|
||||
description="Country codes to block from viewing (e.g., ['US', 'CA'])",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
targeting_allow_countries: list[str] | None = SchemaField(
|
||||
description="Country codes to allow viewing (e.g., ['GB', 'AU'])",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
subtitle_url: str | None = SchemaField(
|
||||
description="URL to SRT or SBV subtitle file (must be HTTPS and end in .srt/.sbv, under 100MB)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
subtitle_language: str | None = SchemaField(
|
||||
description="Language code for subtitles (default: 'en')",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
subtitle_name: str | None = SchemaField(
|
||||
description="Name of caption track (max 150 chars, default: 'English')",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0082d712-ff1b-4c3d-8a8d-6c7721883b83",
|
||||
description="Post to YouTube using Ayrshare",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
block_type=BlockType.AYRSHARE,
|
||||
input_schema=PostToYouTubeBlock.Input,
|
||||
output_schema=PostToYouTubeBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: "PostToYouTubeBlock.Input",
|
||||
*,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to YouTube with YouTube-specific validation and options."""
|
||||
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
return
|
||||
|
||||
# Validate YouTube constraints
|
||||
if not input_data.title:
|
||||
yield "error", "YouTube requires a video title"
|
||||
return
|
||||
|
||||
if len(input_data.title) > 100:
|
||||
yield "error", f"YouTube title exceeds 100 character limit ({len(input_data.title)} characters)"
|
||||
return
|
||||
|
||||
if len(input_data.post) > 5000:
|
||||
yield "error", f"YouTube description exceeds 5,000 character limit ({len(input_data.post)} characters)"
|
||||
return
|
||||
|
||||
# Check for forbidden characters
|
||||
forbidden_chars = ["<", ">"]
|
||||
for char in forbidden_chars:
|
||||
if char in input_data.title:
|
||||
yield "error", f"YouTube title cannot contain '{char}' character"
|
||||
return
|
||||
if char in input_data.post:
|
||||
yield "error", f"YouTube description cannot contain '{char}' character"
|
||||
return
|
||||
|
||||
if not input_data.media_urls:
|
||||
yield "error", "YouTube requires exactly one video URL"
|
||||
return
|
||||
|
||||
if len(input_data.media_urls) > 1:
|
||||
yield "error", "YouTube supports only 1 video per post"
|
||||
return
|
||||
|
||||
# Validate visibility option
|
||||
valid_visibility = ["private", "public", "unlisted"]
|
||||
if input_data.visibility not in valid_visibility:
|
||||
yield "error", f"YouTube visibility must be one of: {', '.join(valid_visibility)}"
|
||||
return
|
||||
|
||||
# Validate thumbnail URL format
|
||||
if input_data.thumbnail:
|
||||
valid_extensions = [".png", ".jpg", ".jpeg"]
|
||||
if not any(
|
||||
input_data.thumbnail.lower().endswith(ext) for ext in valid_extensions
|
||||
):
|
||||
yield "error", "YouTube thumbnail must end in .png, .jpg, or .jpeg"
|
||||
return
|
||||
|
||||
# Validate tags
|
||||
if input_data.tags:
|
||||
total_tag_length = sum(len(tag) for tag in input_data.tags)
|
||||
if total_tag_length > 500:
|
||||
yield "error", f"YouTube tags total length exceeds 500 characters ({total_tag_length} characters)"
|
||||
return
|
||||
|
||||
for tag in input_data.tags:
|
||||
if len(tag) < 2:
|
||||
yield "error", f"YouTube tag '{tag}' is too short (minimum 2 characters)"
|
||||
return
|
||||
|
||||
# Validate subtitle URL
|
||||
if input_data.subtitle_url:
|
||||
if not input_data.subtitle_url.startswith("https://"):
|
||||
yield "error", "YouTube subtitle URL must start with https://"
|
||||
return
|
||||
|
||||
valid_subtitle_extensions = [".srt", ".sbv"]
|
||||
if not any(
|
||||
input_data.subtitle_url.lower().endswith(ext)
|
||||
for ext in valid_subtitle_extensions
|
||||
):
|
||||
yield "error", "YouTube subtitle URL must end in .srt or .sbv"
|
||||
return
|
||||
|
||||
if input_data.subtitle_name and len(input_data.subtitle_name) > 150:
|
||||
yield "error", f"YouTube subtitle name exceeds 150 character limit ({len(input_data.subtitle_name)} characters)"
|
||||
return
|
||||
|
||||
# Validate publish_at format if provided
|
||||
if input_data.publish_at and input_data.schedule_date:
|
||||
yield "error", "Cannot use both 'publish_at' and 'schedule_date'. Use 'publish_at' for YouTube-controlled publishing."
|
||||
return
|
||||
|
||||
# Convert datetime to ISO format if provided (only if not using publish_at)
|
||||
iso_date = None
|
||||
if not input_data.publish_at and input_data.schedule_date:
|
||||
iso_date = input_data.schedule_date.isoformat()
|
||||
|
||||
# Build YouTube-specific options
|
||||
youtube_options: dict[str, Any] = {"title": input_data.title}
|
||||
|
||||
# Basic options
|
||||
if input_data.visibility != "private":
|
||||
youtube_options["visibility"] = input_data.visibility
|
||||
|
||||
if input_data.thumbnail:
|
||||
youtube_options["thumbNail"] = input_data.thumbnail
|
||||
|
||||
if input_data.playlist_id:
|
||||
youtube_options["playListId"] = input_data.playlist_id
|
||||
|
||||
if input_data.tags:
|
||||
youtube_options["tags"] = input_data.tags
|
||||
|
||||
if input_data.made_for_kids:
|
||||
youtube_options["madeForKids"] = True
|
||||
|
||||
if input_data.is_shorts:
|
||||
youtube_options["shorts"] = True
|
||||
|
||||
if not input_data.notify_subscribers:
|
||||
youtube_options["notifySubscribers"] = False
|
||||
|
||||
if input_data.category_id and input_data.category_id > 0:
|
||||
youtube_options["categoryId"] = input_data.category_id
|
||||
|
||||
if input_data.contains_synthetic_media:
|
||||
youtube_options["containsSyntheticMedia"] = True
|
||||
|
||||
if input_data.publish_at:
|
||||
youtube_options["publishAt"] = input_data.publish_at
|
||||
|
||||
# Country targeting (from flattened fields)
|
||||
targeting_dict = {}
|
||||
if input_data.targeting_block_countries:
|
||||
targeting_dict["block"] = input_data.targeting_block_countries
|
||||
if input_data.targeting_allow_countries:
|
||||
targeting_dict["allow"] = input_data.targeting_allow_countries
|
||||
|
||||
if targeting_dict:
|
||||
youtube_options["targeting"] = targeting_dict
|
||||
|
||||
# Subtitle options
|
||||
if input_data.subtitle_url:
|
||||
youtube_options["subTitleUrl"] = input_data.subtitle_url
|
||||
youtube_options["subTitleLanguage"] = input_data.subtitle_language
|
||||
youtube_options["subTitleName"] = input_data.subtitle_name
|
||||
|
||||
response = await client.create_post(
|
||||
post=input_data.post,
|
||||
platforms=[SocialPlatform.YOUTUBE],
|
||||
media_urls=input_data.media_urls,
|
||||
is_video=True, # YouTube only supports videos
|
||||
schedule_date=iso_date,
|
||||
disable_comments=input_data.disable_comments,
|
||||
shorten_links=input_data.shorten_links,
|
||||
unsplash=input_data.unsplash,
|
||||
requires_approval=input_data.requires_approval,
|
||||
random_post=input_data.random_post,
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
youtube_options=youtube_options,
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
for p in response.postIds:
|
||||
yield "post", p
|
||||
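YouTube scheduling is the one place where two mechanisms compete: `publish_at` (YouTube-controlled, passed via `youtube_options`) and the generic `schedule_date`. A hedged sketch of the exclusivity rule the block applies, extracted into a helper for clarity; the helper name is illustrative:

```python
from datetime import datetime
from typing import Optional


def resolve_schedule(
    publish_at: Optional[str], schedule_date: Optional[datetime]
) -> Optional[str]:
    # Mirrors the block: both set at once is an error; publish_at wins and is
    # sent through youtube_options, so no ISO schedule_date is produced.
    if publish_at and schedule_date:
        raise ValueError("Cannot use both 'publish_at' and 'schedule_date'")
    if publish_at:
        return None
    return schedule_date.isoformat() if schedule_date else None
```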
autogpt_platform/backend/backend/blocks/baas/_api.py (new file, 205 lines)
@@ -0,0 +1,205 @@
|
||||
"""
|
||||
Meeting BaaS API client module.
|
||||
All API calls centralized for consistency and maintainability.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from backend.sdk import Requests
|
||||
|
||||
|
||||
class MeetingBaasAPI:
|
||||
"""Client for Meeting BaaS API endpoints."""
|
||||
|
||||
BASE_URL = "https://api.meetingbaas.com"
|
||||
|
||||
def __init__(self, api_key: str):
|
||||
"""Initialize API client with authentication key."""
|
||||
self.api_key = api_key
|
||||
self.headers = {"x-meeting-baas-api-key": api_key}
|
||||
self.requests = Requests()
|
||||
|
||||
# Bot Management Endpoints
|
||||
|
||||
async def join_meeting(
|
||||
self,
|
||||
bot_name: str,
|
||||
meeting_url: str,
|
||||
reserved: bool = False,
|
||||
bot_image: Optional[str] = None,
|
||||
entry_message: Optional[str] = None,
|
||||
start_time: Optional[int] = None,
|
||||
speech_to_text: Optional[Dict[str, Any]] = None,
|
||||
webhook_url: Optional[str] = None,
|
||||
automatic_leave: Optional[Dict[str, Any]] = None,
|
||||
extra: Optional[Dict[str, Any]] = None,
|
||||
recording_mode: str = "speaker_view",
|
||||
streaming: Optional[Dict[str, Any]] = None,
|
||||
deduplication_key: Optional[str] = None,
|
||||
zoom_sdk_id: Optional[str] = None,
|
||||
zoom_sdk_pwd: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Deploy a bot to join and record a meeting.
|
||||
|
||||
POST /bots
|
||||
"""
|
||||
body = {
|
||||
"bot_name": bot_name,
|
||||
"meeting_url": meeting_url,
|
||||
"reserved": reserved,
|
||||
"recording_mode": recording_mode,
|
||||
}
|
||||
|
||||
# Add optional fields if provided
|
||||
if bot_image is not None:
|
||||
body["bot_image"] = bot_image
|
||||
if entry_message is not None:
|
||||
body["entry_message"] = entry_message
|
||||
if start_time is not None:
|
||||
body["start_time"] = start_time
|
||||
if speech_to_text is not None:
|
||||
body["speech_to_text"] = speech_to_text
|
||||
if webhook_url is not None:
|
||||
body["webhook_url"] = webhook_url
|
||||
if automatic_leave is not None:
|
||||
body["automatic_leave"] = automatic_leave
|
||||
if extra is not None:
|
||||
body["extra"] = extra
|
||||
if streaming is not None:
|
||||
body["streaming"] = streaming
|
||||
if deduplication_key is not None:
|
||||
body["deduplication_key"] = deduplication_key
|
||||
if zoom_sdk_id is not None:
|
||||
body["zoom_sdk_id"] = zoom_sdk_id
|
||||
if zoom_sdk_pwd is not None:
|
||||
body["zoom_sdk_pwd"] = zoom_sdk_pwd
|
||||
|
||||
response = await self.requests.post(
|
||||
f"{self.BASE_URL}/bots",
|
||||
headers=self.headers,
|
||||
json=body,
|
||||
)
|
||||
return response.json()
|
||||
|
||||
async def leave_meeting(self, bot_id: str) -> bool:
|
||||
"""
|
||||
Remove a bot from an ongoing meeting.
|
||||
|
||||
DELETE /bots/{uuid}
|
||||
"""
|
||||
response = await self.requests.delete(
|
||||
f"{self.BASE_URL}/bots/{bot_id}",
|
||||
headers=self.headers,
|
||||
)
|
||||
return response.status in [200, 204]
|
||||
|
||||
async def retranscribe(
|
||||
self,
|
||||
bot_uuid: str,
|
||||
speech_to_text: Optional[Dict[str, Any]] = None,
|
||||
webhook_url: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Re-run transcription on a bot's audio.
|
||||
|
||||
POST /bots/retranscribe
|
||||
"""
|
||||
body: Dict[str, Any] = {"bot_uuid": bot_uuid}
|
||||
|
||||
if speech_to_text is not None:
|
||||
body["speech_to_text"] = speech_to_text
|
||||
if webhook_url is not None:
|
||||
body["webhook_url"] = webhook_url
|
||||
|
||||
response = await self.requests.post(
|
||||
f"{self.BASE_URL}/bots/retranscribe",
|
||||
headers=self.headers,
|
||||
json=body,
|
||||
)
|
||||
|
||||
if response.status == 202:
|
||||
return {"accepted": True}
|
||||
return response.json()
|
||||
|
||||
# Data Retrieval Endpoints
|
||||
|
||||
async def get_meeting_data(
|
||||
self, bot_id: str, include_transcripts: bool = True
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Retrieve meeting data including recording and transcripts.
|
||||
|
||||
GET /bots/meeting_data
|
||||
"""
|
||||
params = {
|
||||
"bot_id": bot_id,
|
||||
"include_transcripts": str(include_transcripts).lower(),
|
||||
}
|
||||
|
||||
response = await self.requests.get(
|
||||
f"{self.BASE_URL}/bots/meeting_data",
|
||||
headers=self.headers,
|
||||
params=params,
|
||||
)
|
||||
return response.json()
|
||||
|
||||
async def get_screenshots(self, bot_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Retrieve screenshots captured during a meeting.
|
||||
|
||||
GET /bots/{uuid}/screenshots
|
||||
"""
|
||||
response = await self.requests.get(
|
||||
f"{self.BASE_URL}/bots/{bot_id}/screenshots",
|
||||
headers=self.headers,
|
||||
)
|
||||
result = response.json()
|
||||
# Ensure we return a list
|
||||
if isinstance(result, list):
|
||||
return result
|
||||
return []
|
||||
|
||||
async def delete_data(self, bot_id: str) -> bool:
|
||||
"""
|
||||
Delete a bot's recorded data.
|
||||
|
||||
POST /bots/{uuid}/delete_data
|
||||
"""
|
||||
response = await self.requests.post(
|
||||
f"{self.BASE_URL}/bots/{bot_id}/delete_data",
|
||||
headers=self.headers,
|
||||
)
|
||||
return response.status == 200
|
||||
|
||||
async def list_bots_with_metadata(
|
||||
self,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
sort_by: Optional[str] = None,
|
||||
sort_order: Optional[str] = None,
|
||||
filter_by: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
List bots with metadata including IDs, names, and meeting details.
|
||||
|
||||
GET /bots/bots_with_metadata
|
||||
"""
|
||||
params = {}
|
||||
if limit is not None:
|
||||
params["limit"] = limit
|
||||
if offset is not None:
|
||||
params["offset"] = offset
|
||||
if sort_by is not None:
|
||||
params["sort_by"] = sort_by
|
||||
if sort_order is not None:
|
||||
params["sort_order"] = sort_order
|
||||
if filter_by is not None:
|
||||
params.update(filter_by)
|
||||
|
||||
response = await self.requests.get(
|
||||
f"{self.BASE_URL}/bots/bots_with_metadata",
|
||||
headers=self.headers,
|
||||
params=params,
|
||||
)
|
||||
return response.json()
|
||||
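A hedged end-to-end sketch of the `MeetingBaasAPI` client above: deploy a bot, fetch the recording later, then purge the data. The meeting URL is a placeholder and the "wait for the meeting to finish" step is elided; in the platform these calls are split across separate blocks and completion is signalled via webhooks:

```python
import asyncio

from ._api import MeetingBaasAPI  # as imported by the bot blocks below


async def record_meeting(api_key: str, meeting_url: str) -> None:
    api = MeetingBaasAPI(api_key)
    join = await api.join_meeting(bot_name="Notetaker", meeting_url=meeting_url)
    bot_id = join.get("bot_id", "")

    # ... wait for the meeting to end (webhooks are the real mechanism) ...
    data = await api.get_meeting_data(bot_id=bot_id, include_transcripts=True)
    print("recording:", data.get("mp4", ""))

    await api.delete_data(bot_id)


# asyncio.run(record_meeting("<MEETING_BAAS_API_KEY>", "https://meet.example/abc"))
```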
autogpt_platform/backend/backend/blocks/baas/_config.py (new file, 13 lines)
@@ -0,0 +1,13 @@
"""
Shared configuration for all Meeting BaaS blocks using the SDK pattern.
"""

from backend.sdk import BlockCostType, ProviderBuilder

# Configure the Meeting BaaS provider with API key authentication
baas = (
    ProviderBuilder("baas")
    .with_api_key("MEETING_BAAS_API_KEY", "Meeting BaaS API Key")
    .with_base_cost(5, BlockCostType.RUN)  # Higher cost for meeting recording service
    .build()
)
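A hedged sketch of how the `baas` provider object is consumed, matching the pattern in bots.py below: `credentials_field()` goes into the block's input schema and the platform injects `APIKeyCredentials` at run time. The class and function names here are illustrative only:

```python
from backend.sdk import APIKeyCredentials, BlockSchemaInput, CredentialsMetaInput

from ._config import baas  # the provider built above


class ExampleBaasInput(BlockSchemaInput):  # illustrative, not a real block schema
    credentials: CredentialsMetaInput = baas.credentials_field(
        description="Meeting BaaS API credentials"
    )


def extract_key(credentials: APIKeyCredentials) -> str:
    # Blocks read the resolved secret like this inside run().
    return credentials.api_key.get_secret_value()
```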
autogpt_platform/backend/backend/blocks/baas/bots.py (new file, 218 lines)
@@ -0,0 +1,218 @@
|
||||
"""
|
||||
Meeting BaaS bot (recording) blocks.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._api import MeetingBaasAPI
|
||||
from ._config import baas
|
||||
|
||||
|
||||
class BaasBotJoinMeetingBlock(Block):
|
||||
"""
|
||||
Deploy a bot immediately or at a scheduled start_time to join and record a meeting.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
meeting_url: str = SchemaField(
|
||||
description="The URL of the meeting the bot should join"
|
||||
)
|
||||
bot_name: str = SchemaField(
|
||||
description="Display name for the bot in the meeting"
|
||||
)
|
||||
bot_image: str = SchemaField(
|
||||
description="URL to an image for the bot's avatar (16:9 ratio recommended)",
|
||||
default="",
|
||||
)
|
||||
entry_message: str = SchemaField(
|
||||
description="Chat message the bot will post upon entry", default=""
|
||||
)
|
||||
reserved: bool = SchemaField(
|
||||
description="Use a reserved bot slot (joins 4 min before meeting)",
|
||||
default=False,
|
||||
)
|
||||
start_time: Optional[int] = SchemaField(
|
||||
description="Unix timestamp (ms) when bot should join", default=None
|
||||
)
|
||||
webhook_url: str | None = SchemaField(
|
||||
description="URL to receive webhook events for this bot", default=None
|
||||
)
|
||||
timeouts: dict = SchemaField(
|
||||
description="Automatic leave timeouts configuration", default={}
|
||||
)
|
||||
extra: dict = SchemaField(
|
||||
description="Custom metadata to attach to the bot", default={}
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
bot_id: str = SchemaField(description="UUID of the deployed bot")
|
||||
join_response: dict = SchemaField(
|
||||
description="Full response from join operation"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="377d1a6a-a99b-46cf-9af3-1d1b12758e04",
|
||||
description="Deploy a bot to join and record a meeting",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
api = MeetingBaasAPI(api_key)
|
||||
|
||||
# Call API with all parameters
|
||||
data = await api.join_meeting(
|
||||
bot_name=input_data.bot_name,
|
||||
meeting_url=input_data.meeting_url,
|
||||
reserved=input_data.reserved,
|
||||
bot_image=input_data.bot_image if input_data.bot_image else None,
|
||||
entry_message=(
|
||||
input_data.entry_message if input_data.entry_message else None
|
||||
),
|
||||
start_time=input_data.start_time,
|
||||
speech_to_text={"provider": "Default"},
|
||||
webhook_url=input_data.webhook_url if input_data.webhook_url else None,
|
||||
automatic_leave=input_data.timeouts if input_data.timeouts else None,
|
||||
extra=input_data.extra if input_data.extra else None,
|
||||
)
|
||||
|
||||
yield "bot_id", data.get("bot_id", "")
|
||||
yield "join_response", data
|
||||
|
||||
|
||||
class BaasBotLeaveMeetingBlock(Block):
|
||||
"""
|
||||
Force the bot to exit the call.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
bot_id: str = SchemaField(description="UUID of the bot to remove from meeting")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
left: bool = SchemaField(description="Whether the bot successfully left")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="bf77d128-8b25-4280-b5c7-2d553ba7e482",
|
||||
description="Remove a bot from an ongoing meeting",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
api = MeetingBaasAPI(api_key)
|
||||
|
||||
# Leave meeting
|
||||
left = await api.leave_meeting(input_data.bot_id)
|
||||
|
||||
yield "left", left
|
||||
|
||||
|
||||
class BaasBotFetchMeetingDataBlock(Block):
|
||||
"""
|
||||
Pull MP4 URL, transcript & metadata for a completed meeting.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
bot_id: str = SchemaField(description="UUID of the bot whose data to fetch")
|
||||
include_transcripts: bool = SchemaField(
|
||||
description="Include transcript data in response", default=True
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
mp4_url: str = SchemaField(
|
||||
description="URL to download the meeting recording (time-limited)"
|
||||
)
|
||||
transcript: list = SchemaField(description="Meeting transcript data")
|
||||
metadata: dict = SchemaField(description="Meeting metadata and bot information")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="ea7c1309-303c-4da1-893f-89c0e9d64e78",
|
||||
description="Retrieve recorded meeting data",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
api = MeetingBaasAPI(api_key)
|
||||
|
||||
# Fetch meeting data
|
||||
data = await api.get_meeting_data(
|
||||
bot_id=input_data.bot_id,
|
||||
include_transcripts=input_data.include_transcripts,
|
||||
)
|
||||
|
||||
yield "mp4_url", data.get("mp4", "")
|
||||
yield "transcript", data.get("bot_data", {}).get("transcripts", [])
|
||||
yield "metadata", data.get("bot_data", {}).get("bot", {})
|
||||
|
||||
|
||||
class BaasBotDeleteRecordingBlock(Block):
|
||||
"""
|
||||
Purge MP4 + transcript data for privacy or storage management.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
bot_id: str = SchemaField(description="UUID of the bot whose data to delete")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
deleted: bool = SchemaField(
|
||||
description="Whether the data was successfully deleted"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="bf8d1aa6-42d8-4944-b6bd-6bac554c0d3b",
|
||||
description="Permanently delete a meeting's recorded data",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
api = MeetingBaasAPI(api_key)
|
||||
|
||||
# Delete recording data
|
||||
deleted = await api.delete_data(input_data.bot_id)
|
||||
|
||||
yield "deleted", deleted
|
||||
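Taken together, the four MeetingBaaS blocks above wrap a join → fetch → clean-up lifecycle around one client class. The sketch below chains the same `MeetingBaasAPI` calls directly, outside the block framework; the method names come from this diff, but the import path, exact signatures, and response keys are assumptions for illustration, not a verified API.

```python
# Illustrative only: chains the MeetingBaasAPI calls used by the blocks above.
# The import path, signatures, and response keys are assumed, not verified.
from backend.blocks.meeting_baas._api import MeetingBaasAPI  # assumed module path


async def record_meeting(api_key: str, meeting_url: str) -> dict:
    api = MeetingBaasAPI(api_key)

    # Send a recording bot into the call (what BaasBotJoinMeetingBlock does).
    join = await api.join_meeting(bot_name="Recorder", meeting_url=meeting_url)
    bot_id = join.get("bot_id", "")

    # After the meeting ends, pull the MP4 URL, transcript, and metadata.
    data = await api.get_meeting_data(bot_id=bot_id, include_transcripts=True)

    # Optionally purge the stored recording afterwards (delete block).
    await api.delete_data(bot_id)
    return data
```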
@@ -0,0 +1,3 @@
from .text_overlay import BannerbearTextOverlayBlock

__all__ = ["BannerbearTextOverlayBlock"]
@@ -0,0 +1,8 @@
from backend.sdk import BlockCostType, ProviderBuilder

bannerbear = (
    ProviderBuilder("bannerbear")
    .with_api_key("BANNERBEAR_API_KEY", "Bannerbear API Key")
    .with_base_cost(1, BlockCostType.RUN)
    .build()
)
@@ -0,0 +1,239 @@
import uuid
from typing import TYPE_CHECKING, Any, Dict, List

if TYPE_CHECKING:
    pass

from pydantic import SecretStr

from backend.sdk import (
    APIKeyCredentials,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
    CredentialsMetaInput,
    Requests,
    SchemaField,
)

from ._config import bannerbear

TEST_CREDENTIALS = APIKeyCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="bannerbear",
    api_key=SecretStr("mock-bannerbear-api-key"),
    title="Mock Bannerbear API Key",
)


class TextModification(BlockSchemaInput):
    name: str = SchemaField(
        description="The name of the layer to modify in the template"
    )
    text: str = SchemaField(description="The text content to add to this layer")
    color: str = SchemaField(
        description="Hex color code for the text (e.g., '#FF0000')",
        default="",
        advanced=True,
    )
    font_family: str = SchemaField(
        description="Font family to use for the text",
        default="",
        advanced=True,
    )
    font_size: int = SchemaField(
        description="Font size in pixels",
        default=0,
        advanced=True,
    )
    font_weight: str = SchemaField(
        description="Font weight (e.g., 'bold', 'normal')",
        default="",
        advanced=True,
    )
    text_align: str = SchemaField(
        description="Text alignment (left, center, right)",
        default="",
        advanced=True,
    )


class BannerbearTextOverlayBlock(Block):
    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = bannerbear.credentials_field(
            description="API credentials for Bannerbear"
        )
        template_id: str = SchemaField(
            description="The unique ID of your Bannerbear template"
        )
        project_id: str = SchemaField(
            description="Optional: Project ID (required when using Master API Key)",
            default="",
            advanced=True,
        )
        text_modifications: List[TextModification] = SchemaField(
            description="List of text layers to modify in the template"
        )
        image_url: str = SchemaField(
            description="Optional: URL of an image to use in the template",
            default="",
            advanced=True,
        )
        image_layer_name: str = SchemaField(
            description="Optional: Name of the image layer in the template",
            default="photo",
            advanced=True,
        )
        webhook_url: str = SchemaField(
            description="Optional: URL to receive webhook notification when image is ready",
            default="",
            advanced=True,
        )
        metadata: str = SchemaField(
            description="Optional: Custom metadata to attach to the image",
            default="",
            advanced=True,
        )

    class Output(BlockSchemaOutput):
        success: bool = SchemaField(
            description="Whether the image generation was successfully initiated"
        )
        image_url: str = SchemaField(
            description="URL of the generated image (if synchronous) or placeholder"
        )
        uid: str = SchemaField(description="Unique identifier for the generated image")
        status: str = SchemaField(description="Status of the image generation")

    def __init__(self):
        super().__init__(
            id="c7d3a5c2-05fc-450e-8dce-3b0e04626009",
            description="Add text overlay to images using Bannerbear templates. Perfect for creating social media graphics, marketing materials, and dynamic image content.",
            categories={BlockCategory.PRODUCTIVITY, BlockCategory.AI},
            input_schema=self.Input,
            output_schema=self.Output,
            test_input={
                "template_id": "jJWBKNELpQPvbX5R93Gk",
                "text_modifications": [
                    {
                        "name": "headline",
                        "text": "Amazing Product Launch!",
                        "color": "#FF0000",
                    },
                    {
                        "name": "subtitle",
                        "text": "50% OFF Today Only",
                    },
                ],
                "credentials": {
                    "provider": "bannerbear",
                    "id": str(uuid.uuid4()),
                    "type": "api_key",
                },
            },
            test_output=[
                ("success", True),
                ("image_url", "https://cdn.bannerbear.com/test-image.jpg"),
                ("uid", "test-uid-123"),
                ("status", "completed"),
            ],
            test_mock={
                "_make_api_request": lambda *args, **kwargs: {
                    "uid": "test-uid-123",
                    "status": "completed",
                    "image_url": "https://cdn.bannerbear.com/test-image.jpg",
                }
            },
            test_credentials=TEST_CREDENTIALS,
        )

    async def _make_api_request(self, payload: dict, api_key: str) -> dict:
        """Make the actual API request to Bannerbear. This is separated for easy mocking in tests."""
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

        response = await Requests().post(
            "https://sync.api.bannerbear.com/v2/images",
            headers=headers,
            json=payload,
        )

        if response.status in [200, 201, 202]:
            return response.json()
        else:
            error_msg = f"API request failed with status {response.status}"
            if response.text:
                try:
                    error_data = response.json()
                    error_msg = (
                        f"{error_msg}: {error_data.get('message', response.text)}"
                    )
                except Exception:
                    error_msg = f"{error_msg}: {response.text}"
            raise Exception(error_msg)

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        # Build the modifications array
        modifications = []

        # Add text modifications
        for text_mod in input_data.text_modifications:
            mod_data: Dict[str, Any] = {
                "name": text_mod.name,
                "text": text_mod.text,
            }

            # Add optional text styling parameters only if they have values
            if text_mod.color and text_mod.color.strip():
                mod_data["color"] = text_mod.color
            if text_mod.font_family and text_mod.font_family.strip():
                mod_data["font_family"] = text_mod.font_family
            if text_mod.font_size and text_mod.font_size > 0:
                mod_data["font_size"] = text_mod.font_size
            if text_mod.font_weight and text_mod.font_weight.strip():
                mod_data["font_weight"] = text_mod.font_weight
            if text_mod.text_align and text_mod.text_align.strip():
                mod_data["text_align"] = text_mod.text_align

            modifications.append(mod_data)

        # Add image modification if provided and not empty
        if input_data.image_url and input_data.image_url.strip():
            modifications.append(
                {
                    "name": input_data.image_layer_name,
                    "image_url": input_data.image_url,
                }
            )

        # Build the request payload - only include non-empty optional fields
        payload = {
            "template": input_data.template_id,
            "modifications": modifications,
        }

        # Add project_id if provided (required for Master API keys)
        if input_data.project_id and input_data.project_id.strip():
            payload["project_id"] = input_data.project_id

        if input_data.webhook_url and input_data.webhook_url.strip():
            payload["webhook_url"] = input_data.webhook_url
        if input_data.metadata and input_data.metadata.strip():
            payload["metadata"] = input_data.metadata

        # Make the API request using the private method
        data = await self._make_api_request(
            payload, credentials.api_key.get_secret_value()
        )

        # Synchronous request - image should be ready
        yield "success", True
        yield "image_url", data.get("image_url", "")
        yield "uid", data.get("uid", "")
        yield "status", data.get("status", "completed")
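For readers skimming the diff, the `run()` method above boils down to assembling one JSON body and POSTing it to Bannerbear's synchronous images endpoint. The sketch below shows the shape of that body; the field names are taken from the code, while the concrete values are invented for illustration.

```python
# Sketch of the request body assembled by BannerbearTextOverlayBlock.run() above,
# for a template with a "headline" text layer and a "photo" image layer.
# Field names come from the code; the values here are illustrative only.
example_payload = {
    "template": "jJWBKNELpQPvbX5R93Gk",  # input_data.template_id
    "modifications": [
        {"name": "headline", "text": "Amazing Product Launch!", "color": "#FF0000"},
        {"name": "photo", "image_url": "https://example.com/product.png"},
    ],
    # Optional keys are added only when the corresponding inputs are non-empty:
    # "project_id": "...", "webhook_url": "...", "metadata": "...",
}
# _make_api_request() sends this as JSON to https://sync.api.bannerbear.com/v2/images
# with an "Authorization: Bearer <api key>" header and reads uid/status/image_url back.
```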
@@ -1,16 +1,21 @@
import enum
from typing import Any, List
from typing import Any

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
    BlockType,
)
from backend.data.model import SchemaField
from backend.util import json
from backend.util.file import store_media_file
from backend.util.mock import MockObject
from backend.util.type import MediaFileType, convert


class FileStoreBlock(Block):
    class Input(BlockSchema):
    class Input(BlockSchemaInput):
        file_in: MediaFileType = SchemaField(
            description="The file to store in the temporary directory, it can be a URL, data URI, or local path."
        )
@@ -21,7 +26,7 @@ class FileStoreBlock(Block):
            title="Produce Base64 Output",
        )

    class Output(BlockSchema):
    class Output(BlockSchemaOutput):
        file_out: MediaFileType = SchemaField(
            description="The relative path to the stored file in the temporary directory."
        )
@@ -41,11 +46,13 @@ class FileStoreBlock(Block):
        input_data: Input,
        *,
        graph_exec_id: str,
        user_id: str,
        **kwargs,
    ) -> BlockOutput:
        yield "file_out", await store_media_file(
            graph_exec_id=graph_exec_id,
            file=input_data.file_in,
            user_id=user_id,
            return_content=input_data.base_64,
        )

@@ -57,7 +64,7 @@ class StoreValueBlock(Block):
    The block output will be static, the output can be consumed multiple times.
    """

    class Input(BlockSchema):
    class Input(BlockSchemaInput):
        input: Any = SchemaField(
            description="Trigger the block to produce the output. "
            "The value is only used when `data` is None."
@@ -68,7 +75,7 @@ class StoreValueBlock(Block):
            default=None,
        )

    class Output(BlockSchema):
    class Output(BlockSchemaOutput):
        output: Any = SchemaField(description="The stored data retained in the block.")

    def __init__(self):
@@ -94,10 +101,10 @@ class StoreValueBlock(Block):


class PrintToConsoleBlock(Block):
    class Input(BlockSchema):
    class Input(BlockSchemaInput):
        text: Any = SchemaField(description="The data to print to the console.")

    class Output(BlockSchema):
    class Output(BlockSchemaOutput):
        output: Any = SchemaField(description="The data printed to the console.")
        status: str = SchemaField(description="The status of the print operation.")

@@ -120,271 +127,11 @@ class PrintToConsoleBlock(Block):
        yield "status", "printed"


class FindInDictionaryBlock(Block):
    class Input(BlockSchema):
        input: Any = SchemaField(description="Dictionary to lookup from")
        key: str | int = SchemaField(description="Key to lookup in the dictionary")

    class Output(BlockSchema):
        output: Any = SchemaField(description="Value found for the given key")
        missing: Any = SchemaField(
            description="Value of the input that missing the key"
        )

    def __init__(self):
        super().__init__(
            id="0e50422c-6dee-4145-83d6-3a5a392f65de",
            description="Lookup the given key in the input dictionary/object/list and return the value.",
            input_schema=FindInDictionaryBlock.Input,
            output_schema=FindInDictionaryBlock.Output,
            test_input=[
                {"input": {"apple": 1, "banana": 2, "cherry": 3}, "key": "banana"},
                {"input": {"x": 10, "y": 20, "z": 30}, "key": "w"},
                {"input": [1, 2, 3], "key": 1},
                {"input": [1, 2, 3], "key": 3},
                {"input": MockObject(value="!!", key="key"), "key": "key"},
                {"input": [{"k1": "v1"}, {"k2": "v2"}, {"k1": "v3"}], "key": "k1"},
            ],
            test_output=[
                ("output", 2),
                ("missing", {"x": 10, "y": 20, "z": 30}),
                ("output", 2),
                ("missing", [1, 2, 3]),
                ("output", "key"),
                ("output", ["v1", "v3"]),
            ],
            categories={BlockCategory.BASIC},
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        obj = input_data.input
        key = input_data.key

        if isinstance(obj, str):
            obj = json.loads(obj)

        if isinstance(obj, dict) and key in obj:
            yield "output", obj[key]
        elif isinstance(obj, list) and isinstance(key, int) and 0 <= key < len(obj):
            yield "output", obj[key]
        elif isinstance(obj, list) and isinstance(key, str):
            if len(obj) == 0:
                yield "output", []
            elif isinstance(obj[0], dict) and key in obj[0]:
                yield "output", [item[key] for item in obj if key in item]
            else:
                yield "output", [getattr(val, key) for val in obj if hasattr(val, key)]
        elif isinstance(obj, object) and isinstance(key, str) and hasattr(obj, key):
            yield "output", getattr(obj, key)
        else:
            yield "missing", input_data.input
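The branch ladder in `FindInDictionaryBlock.run()` above is easy to misread, so here is its behaviour restated as data, using the block's own test cases; no semantics are claimed beyond what those tests encode.

```python
# Expected behaviour of FindInDictionaryBlock.run() above, restated as
# (input, key, output_name, value) tuples taken directly from its tests.
FIND_IN_DICT_CASES = [
    ({"apple": 1, "banana": 2, "cherry": 3}, "banana", "output", 2),
    ({"x": 10, "y": 20, "z": 30}, "w", "missing", {"x": 10, "y": 20, "z": 30}),
    ([1, 2, 3], 1, "output", 2),           # integer key acts as a list index
    ([1, 2, 3], 3, "missing", [1, 2, 3]),  # index out of range -> whole input on "missing"
    ([{"k1": "v1"}, {"k2": "v2"}, {"k1": "v3"}], "k1", "output", ["v1", "v3"]),
]
```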
class AddToDictionaryBlock(Block):
    class Input(BlockSchema):
        dictionary: dict[Any, Any] = SchemaField(
            default_factory=dict,
            description="The dictionary to add the entry to. If not provided, a new dictionary will be created.",
        )
        key: str = SchemaField(
            default="",
            description="The key for the new entry.",
            placeholder="new_key",
            advanced=False,
        )
        value: Any = SchemaField(
            default=None,
            description="The value for the new entry.",
            placeholder="new_value",
            advanced=False,
        )
        entries: dict[Any, Any] = SchemaField(
            default_factory=dict,
            description="The entries to add to the dictionary. This is the batch version of the `key` and `value` fields.",
            advanced=True,
        )

    class Output(BlockSchema):
        updated_dictionary: dict = SchemaField(
            description="The dictionary with the new entry added."
        )
        error: str = SchemaField(description="Error message if the operation failed.")

    def __init__(self):
        super().__init__(
            id="31d1064e-7446-4693-a7d4-65e5ca1180d1",
            description="Adds a new key-value pair to a dictionary. If no dictionary is provided, a new one is created.",
            categories={BlockCategory.BASIC},
            input_schema=AddToDictionaryBlock.Input,
            output_schema=AddToDictionaryBlock.Output,
            test_input=[
                {
                    "dictionary": {"existing_key": "existing_value"},
                    "key": "new_key",
                    "value": "new_value",
                },
                {"key": "first_key", "value": "first_value"},
                {
                    "dictionary": {"existing_key": "existing_value"},
                    "entries": {"new_key": "new_value", "first_key": "first_value"},
                },
            ],
            test_output=[
                (
                    "updated_dictionary",
                    {"existing_key": "existing_value", "new_key": "new_value"},
                ),
                ("updated_dictionary", {"first_key": "first_value"}),
                (
                    "updated_dictionary",
                    {
                        "existing_key": "existing_value",
                        "new_key": "new_value",
                        "first_key": "first_value",
                    },
                ),
            ],
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        updated_dict = input_data.dictionary.copy()

        if input_data.value is not None and input_data.key:
            updated_dict[input_data.key] = input_data.value

        for key, value in input_data.entries.items():
            updated_dict[key] = value

        yield "updated_dictionary", updated_dict


class AddToListBlock(Block):
    class Input(BlockSchema):
        list: List[Any] = SchemaField(
            default_factory=list,
            advanced=False,
            description="The list to add the entry to. If not provided, a new list will be created.",
        )
        entry: Any = SchemaField(
            description="The entry to add to the list. Can be of any type (string, int, dict, etc.).",
            advanced=False,
            default=None,
        )
        entries: List[Any] = SchemaField(
            default_factory=lambda: list(),
            description="The entries to add to the list. This is the batch version of the `entry` field.",
            advanced=True,
        )
        position: int | None = SchemaField(
            default=None,
            description="The position to insert the new entry. If not provided, the entry will be appended to the end of the list.",
        )

    class Output(BlockSchema):
        updated_list: List[Any] = SchemaField(
            description="The list with the new entry added."
        )
        error: str = SchemaField(description="Error message if the operation failed.")

    def __init__(self):
        super().__init__(
            id="aeb08fc1-2fc1-4141-bc8e-f758f183a822",
            description="Adds a new entry to a list. The entry can be of any type. If no list is provided, a new one is created.",
            categories={BlockCategory.BASIC},
            input_schema=AddToListBlock.Input,
            output_schema=AddToListBlock.Output,
            test_input=[
                {
                    "list": [1, "string", {"existing_key": "existing_value"}],
                    "entry": {"new_key": "new_value"},
                    "position": 1,
                },
                {"entry": "first_entry"},
                {"list": ["a", "b", "c"], "entry": "d"},
                {
                    "entry": "e",
                    "entries": ["f", "g"],
                    "list": ["a", "b"],
                    "position": 1,
                },
            ],
            test_output=[
                (
                    "updated_list",
                    [
                        1,
                        {"new_key": "new_value"},
                        "string",
                        {"existing_key": "existing_value"},
                    ],
                ),
                ("updated_list", ["first_entry"]),
                ("updated_list", ["a", "b", "c", "d"]),
                ("updated_list", ["a", "f", "g", "e", "b"]),
            ],
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        entries_added = input_data.entries.copy()
        if input_data.entry:
            entries_added.append(input_data.entry)

        updated_list = input_data.list.copy()
        if (pos := input_data.position) is not None:
            updated_list = updated_list[:pos] + entries_added + updated_list[pos:]
        else:
            updated_list += entries_added

        yield "updated_list", updated_list
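One subtlety in `AddToListBlock.run()` above: when `entry`, `entries`, and `position` are all provided, the single entry is appended to the batch and the combined batch is spliced in at `position`. A minimal restatement of the block's own fourth test case:

```python
# Splice behaviour of AddToListBlock.run() above, shown on its own test data:
# the batch `entries` are copied first, the single `entry` is appended to that
# batch, and the whole batch is inserted at `position` of the input list.
entries_added = ["f", "g"] + ["e"]            # entries, then entry
updated_list = ["a"] + entries_added + ["b"]  # spliced at position 1 of ["a", "b"]
assert updated_list == ["a", "f", "g", "e", "b"]
```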
class FindInListBlock(Block):
    class Input(BlockSchema):
        list: List[Any] = SchemaField(description="The list to search in.")
        value: Any = SchemaField(description="The value to search for.")

    class Output(BlockSchema):
        index: int = SchemaField(description="The index of the value in the list.")
        found: bool = SchemaField(
            description="Whether the value was found in the list."
        )
        not_found_value: Any = SchemaField(
            description="The value that was not found in the list."
        )

    def __init__(self):
        super().__init__(
            id="5e2c6d0a-1e37-489f-b1d0-8e1812b23333",
            description="Finds the index of the value in the list.",
            categories={BlockCategory.BASIC},
            input_schema=FindInListBlock.Input,
            output_schema=FindInListBlock.Output,
            test_input=[
                {"list": [1, 2, 3, 4, 5], "value": 3},
                {"list": [1, 2, 3, 4, 5], "value": 6},
            ],
            test_output=[
                ("index", 2),
                ("found", True),
                ("found", False),
                ("not_found_value", 6),
            ],
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        try:
            yield "index", input_data.list.index(input_data.value)
            yield "found", True
        except ValueError:
            yield "found", False
            yield "not_found_value", input_data.value


class NoteBlock(Block):
    class Input(BlockSchema):
    class Input(BlockSchemaInput):
        text: str = SchemaField(description="The text to display in the sticky note.")

    class Output(BlockSchema):
    class Output(BlockSchemaOutput):
        output: str = SchemaField(description="The text to display in the sticky note.")

    def __init__(self):
@@ -405,110 +152,6 @@ class NoteBlock(Block):
        yield "output", input_data.text


class CreateDictionaryBlock(Block):
    class Input(BlockSchema):
        values: dict[str, Any] = SchemaField(
            description="Key-value pairs to create the dictionary with",
            placeholder="e.g., {'name': 'Alice', 'age': 25}",
        )

    class Output(BlockSchema):
        dictionary: dict[str, Any] = SchemaField(
            description="The created dictionary containing the specified key-value pairs"
        )
        error: str = SchemaField(
            description="Error message if dictionary creation failed"
        )

    def __init__(self):
        super().__init__(
            id="b924ddf4-de4f-4b56-9a85-358930dcbc91",
            description="Creates a dictionary with the specified key-value pairs. Use this when you know all the values you want to add upfront.",
            categories={BlockCategory.DATA},
            input_schema=CreateDictionaryBlock.Input,
            output_schema=CreateDictionaryBlock.Output,
            test_input=[
                {
                    "values": {"name": "Alice", "age": 25, "city": "New York"},
                },
                {
                    "values": {"numbers": [1, 2, 3], "active": True, "score": 95.5},
                },
            ],
            test_output=[
                (
                    "dictionary",
                    {"name": "Alice", "age": 25, "city": "New York"},
                ),
                (
                    "dictionary",
                    {"numbers": [1, 2, 3], "active": True, "score": 95.5},
                ),
            ],
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        try:
            # The values are already validated by Pydantic schema
            yield "dictionary", input_data.values
        except Exception as e:
            yield "error", f"Failed to create dictionary: {str(e)}"


class CreateListBlock(Block):
    class Input(BlockSchema):
        values: List[Any] = SchemaField(
            description="A list of values to be combined into a new list.",
            placeholder="e.g., ['Alice', 25, True]",
        )
        max_size: int | None = SchemaField(
            default=None,
            description="Maximum size of the list. If provided, the list will be yielded in chunks of this size.",
            advanced=True,
        )

    class Output(BlockSchema):
        list: List[Any] = SchemaField(
            description="The created list containing the specified values."
        )
        error: str = SchemaField(description="Error message if list creation failed.")

    def __init__(self):
        super().__init__(
            id="a912d5c7-6e00-4542-b2a9-8034136930e4",
            description="Creates a list with the specified values. Use this when you know all the values you want to add upfront.",
            categories={BlockCategory.DATA},
            input_schema=CreateListBlock.Input,
            output_schema=CreateListBlock.Output,
            test_input=[
                {
                    "values": ["Alice", 25, True],
                },
                {
                    "values": [1, 2, 3, "four", {"key": "value"}],
                },
            ],
            test_output=[
                (
                    "list",
                    ["Alice", 25, True],
                ),
                (
                    "list",
                    [1, 2, 3, "four", {"key": "value"}],
                ),
            ],
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        try:
            max_size = input_data.max_size or len(input_data.values)
            for i in range(0, len(input_data.values), max_size):
                yield "list", input_data.values[i : i + max_size]
        except Exception as e:
            yield "error", f"Failed to create list: {str(e)}"
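Also worth noting from `CreateListBlock.run()` above: when `max_size` is set, the block yields the values in slices rather than as one list. A small standalone restatement of that chunking logic:

```python
# Chunking behaviour of CreateListBlock.run() above when max_size is provided:
# the values are emitted in slices of at most max_size items each.
values, max_size = [1, 2, 3, 4, 5], 2
chunks = [values[i : i + max_size] for i in range(0, len(values), max_size)]
assert chunks == [[1, 2], [3, 4], [5]]
```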
class TypeOptions(enum.Enum):
    STRING = "string"
    NUMBER = "number"
@@ -518,13 +161,13 @@ class TypeOptions(enum.Enum):


class UniversalTypeConverterBlock(Block):
    class Input(BlockSchema):
    class Input(BlockSchemaInput):
        value: Any = SchemaField(
            description="The value to convert to a universal type."
        )
        type: TypeOptions = SchemaField(description="The type to convert the value to.")

    class Output(BlockSchema):
    class Output(BlockSchemaOutput):
        value: Any = SchemaField(description="The converted value.")

    def __init__(self):
@@ -551,3 +194,31 @@ class UniversalTypeConverterBlock(Block):
            yield "value", converted_value
        except Exception as e:
            yield "error", f"Failed to convert value: {str(e)}"


class ReverseListOrderBlock(Block):
    """
    A block which takes in a list and returns it in the opposite order.
    """

    class Input(BlockSchemaInput):
        input_list: list[Any] = SchemaField(description="The list to reverse")

    class Output(BlockSchemaOutput):
        reversed_list: list[Any] = SchemaField(description="The list in reversed order")

    def __init__(self):
        super().__init__(
            id="422cb708-3109-4277-bfe3-bc2ae5812777",
            description="Reverses the order of elements in a list",
            categories={BlockCategory.BASIC},
            input_schema=ReverseListOrderBlock.Input,
            output_schema=ReverseListOrderBlock.Output,
            test_input={"input_list": [1, 2, 3, 4, 5]},
            test_output=[("reversed_list", [5, 4, 3, 2, 1])],
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        reversed_list = list(input_data.input_list)
        reversed_list.reverse()
        yield "reversed_list", reversed_list
@@ -2,7 +2,13 @@ import os
import re
from typing import Type

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.model import SchemaField


@@ -15,12 +21,12 @@ class BlockInstallationBlock(Block):
    for development purposes only.
    """

    class Input(BlockSchema):
    class Input(BlockSchemaInput):
        code: str = SchemaField(
            description="Python code of the block to be installed",
        )

    class Output(BlockSchema):
    class Output(BlockSchemaOutput):
        success: str = SchemaField(
            description="Success message if the block is installed successfully",
        )
Some files were not shown because too many files have changed in this diff.