Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-01-13 00:58:16 -05:00)

Compare commits: gmail-repl...autogpt-rs (4 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 894e3600fb | |
| | 9de4b09f20 | |
| | 62e41d409a | |
| | 9f03e3af47 | |
@@ -15,7 +15,6 @@
!autogpt_platform/backend/pyproject.toml
!autogpt_platform/backend/poetry.lock
!autogpt_platform/backend/README.md
!autogpt_platform/backend/.env

# Platform - Market
!autogpt_platform/market/market/
@@ -28,7 +27,6 @@
# Platform - Frontend
!autogpt_platform/frontend/src/
!autogpt_platform/frontend/public/
!autogpt_platform/frontend/scripts/
!autogpt_platform/frontend/package.json
!autogpt_platform/frontend/pnpm-lock.yaml
!autogpt_platform/frontend/tsconfig.json
@@ -36,7 +34,6 @@
## config
!autogpt_platform/frontend/*.config.*
!autogpt_platform/frontend/.env.*
!autogpt_platform/frontend/.env

# Classic - AutoGPT
!classic/original_autogpt/autogpt/

.github/PULL_REQUEST_TEMPLATE.md (vendored): 3 lines changed
@@ -24,8 +24,7 @@
</details>

#### For configuration changes:

- [ ] `.env.default` is updated or already compatible with my changes
- [ ] `.env.example` is updated or already compatible with my changes
- [ ] `docker-compose.yml` is updated or already compatible with my changes
- [ ] I have included a list of my configuration changes in the PR description (under **Changes**)

.github/copilot-instructions.md (vendored): 244 lines changed
@@ -1,244 +0,0 @@
# GitHub Copilot Instructions for AutoGPT

This file provides comprehensive onboarding information for the GitHub Copilot coding agent to work efficiently with the AutoGPT repository.

## Repository Overview

**AutoGPT** is a powerful platform for creating, deploying, and managing continuous AI agents that automate complex workflows. This is a large monorepo (~150MB) containing multiple components:

- **AutoGPT Platform** (`autogpt_platform/`) - Main focus: Modern AI agent platform (Polyform Shield License)
- **Classic AutoGPT** (`classic/`) - Legacy agent system (MIT License)
- **Documentation** (`docs/`) - MkDocs-based documentation site
- **Infrastructure** - Docker configurations, CI/CD, and development tools

**Primary Languages & Frameworks:**
- **Backend**: Python 3.10-3.13, FastAPI, Prisma ORM, PostgreSQL, RabbitMQ
- **Frontend**: TypeScript, Next.js 15, React, Tailwind CSS, Radix UI
- **Development**: Docker, Poetry, pnpm, Playwright, Storybook

## Build and Validation Instructions

### Essential Setup Commands

**Always run these commands in the correct directory and in this order:**

1. **Initial Setup** (required once):

   ```bash
   # Clone and enter repository
   git clone <repo> && cd AutoGPT

   # Start all services (database, redis, rabbitmq, clamav)
   cd autogpt_platform && docker compose --profile local up deps --build --detach
   ```

2. **Backend Setup** (always run before backend development):

   ```bash
   cd autogpt_platform/backend
   poetry install                 # Install dependencies
   poetry run prisma migrate dev  # Run database migrations
   poetry run prisma generate     # Generate Prisma client
   ```

3. **Frontend Setup** (always run before frontend development):

   ```bash
   cd autogpt_platform/frontend
   pnpm install  # Install dependencies
   ```

### Runtime Requirements

**Critical:** Always ensure Docker services are running before starting development:

```bash
cd autogpt_platform && docker compose --profile local up deps --build --detach
```

**Python Version:** Use Python 3.11 (required; managed by Poetry via `pyproject.toml`)
**Node.js Version:** Use Node.js 21+ with the pnpm package manager

### Development Commands

**Backend Development:**

```bash
cd autogpt_platform/backend
poetry run serve                   # Start development server (port 8000)
poetry run test                    # Run all tests (takes ~5 minutes)
poetry run pytest path/to/test.py  # Run a specific test
poetry run format                  # Format code (Black + isort) - always run first
poetry run lint                    # Lint code (ruff) - run after format
```

**Frontend Development:**

```bash
cd autogpt_platform/frontend
pnpm dev        # Start development server (port 3000) - use for active development
pnpm build      # Build for production (only needed for E2E tests or deployment)
pnpm test       # Run Playwright E2E tests (requires build first)
pnpm test-ui    # Run tests with UI
pnpm format     # Format and lint code
pnpm storybook  # Start component development server
```

### Testing Strategy

**Backend Tests:**
- **Block Tests**: `poetry run pytest backend/blocks/test/test_block.py -xvs` (validates all blocks)
- **Specific Block**: `poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[BlockName]' -xvs`
- **Snapshot Tests**: Use `--snapshot-update` when output changes, always review with `git diff`

**Frontend Tests:**
- **E2E Tests**: Always run `pnpm dev` before `pnpm test` (Playwright requires a running instance)
- **Component Tests**: Use Storybook for isolated component development

### Critical Validation Steps

**Before committing changes:**
1. Run `poetry run format` (backend) and `pnpm format` (frontend)
2. Ensure all tests pass in modified areas
3. Verify Docker services are still running
4. Check that database migrations apply cleanly

**Common Issues & Workarounds:**
- **Prisma issues**: Run `poetry run prisma generate` after schema changes
- **Permission errors**: Ensure Docker has proper permissions
- **Port conflicts**: Check the `docker-compose.yml` file for the current list of exposed ports; `docker compose ps` lists all currently mapped ports
- **Test timeouts**: Backend tests can take 5+ minutes; use the `-x` flag to stop on the first failure

## Project Layout & Architecture

### Core Architecture

**AutoGPT Platform** (`autogpt_platform/`):
- `backend/` - FastAPI server with async support
  - `backend/backend/` - Core API logic
  - `backend/blocks/` - Agent execution blocks
  - `backend/data/` - Database models and schemas
  - `schema.prisma` - Database schema definition
- `frontend/` - Next.js application
  - `src/app/` - App Router pages and layouts
  - `src/components/` - Reusable React components
  - `src/lib/` - Utilities and configurations
- `autogpt_libs/` - Shared Python utilities
- `docker-compose.yml` - Development stack orchestration

**Key Configuration Files:**
- `pyproject.toml` - Python dependencies and tooling
- `package.json` - Node.js dependencies and scripts
- `schema.prisma` - Database schema and migrations
- `next.config.mjs` - Next.js configuration
- `tailwind.config.ts` - Styling configuration

### Security & Middleware

**Cache Protection**: Backend includes middleware preventing sensitive data caching in browsers/proxies
**Authentication**: JWT-based with Supabase integration
**User ID Validation**: All data access requires user ID checks - verify this for any `data/*.py` changes

### Development Workflow

**GitHub Actions**: Multiple CI/CD workflows in `.github/workflows/`
- `platform-backend-ci.yml` - Backend testing and validation
- `platform-frontend-ci.yml` - Frontend testing and validation
- `platform-fullstack-ci.yml` - End-to-end integration tests

**Pre-commit Hooks**: Run linting and formatting checks
**Conventional Commits**: Use format `type(scope): description` (e.g., `feat(backend): add API`)

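The commit-message convention can be checked mechanically. A minimal sketch of such a check; the accepted types below are common Conventional Commits types and are an assumption, not necessarily the exact list this repository's tooling enforces:

```python
import re

# Common Conventional Commits types (assumed list, not taken from repo tooling)
COMMIT_RE = re.compile(
    r"^(feat|fix|docs|style|refactor|perf|test|build|ci|chore)"  # type
    r"(\([a-z0-9_-]+\))?"                                        # optional (scope)
    r": .+"                                                      # description
)

def is_conventional(message: str) -> bool:
    """Return True if the message matches `type(scope): description`."""
    return COMMIT_RE.match(message) is not None

print(is_conventional("feat(backend): add API"))  # True
print(is_conventional("added stuff"))             # False
```

A hook like this could run as a pre-commit `commit-msg` stage, though the repository's actual hooks may differ.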
### Key Source Files

**Backend Entry Points:**
- `backend/backend/server/server.py` - FastAPI application setup
- `backend/backend/data/` - Database models and user management
- `backend/blocks/` - Agent execution blocks and logic

**Frontend Entry Points:**
- `frontend/src/app/layout.tsx` - Root application layout
- `frontend/src/app/page.tsx` - Home page
- `frontend/src/lib/supabase/` - Authentication and database client

**Protected Routes**: Update `frontend/src/lib/supabase/middleware.ts` when adding protected routes

### Agent Block System

Agents are built using a visual block-based system where each block performs a single action. Blocks are defined in `backend/blocks/` and must include:
- Block definition with input/output schemas
- Execution logic with proper error handling
- Tests validating functionality

### Database & ORM

**Prisma ORM** with PostgreSQL backend including pgvector for embeddings:
- Schema in `schema.prisma`
- Migrations in `backend/migrations/`
- Always run `prisma migrate dev` and `prisma generate` after schema changes

## Environment Configuration

### Configuration Files Priority Order
1. **Backend**: `/backend/.env.default` → `/backend/.env` (user overrides)
2. **Frontend**: `/frontend/.env.default` → `/frontend/.env` (user overrides)
3. **Platform**: `/.env.default` (Supabase/shared) → `/.env` (user overrides)
4. Docker Compose `environment:` sections override file-based config
5. Shell environment variables have highest precedence

### Docker Environment Setup
- All services use hardcoded defaults (no `${VARIABLE}` substitutions)
- The `env_file` directive loads variables INTO containers at runtime
- Backend/Frontend services use YAML anchors for consistent configuration
- Copy `.env.default` files to `.env` for local development customization

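The priority order in this section amounts to a chain of overriding merges. A toy sketch of that precedence (not the platform's actual config loader; the keys and values are illustrative):

```python
def resolve_config(*layers):
    """Merge config dicts; later layers take precedence over earlier ones."""
    merged = {}
    for layer in layers:
        merged.update(layer)
    return merged

cfg = resolve_config(
    {"DB_PORT": "5432", "LOG_LEVEL": "info"},  # 1. .env.default (checked in)
    {"LOG_LEVEL": "debug"},                    # 2. .env (user override)
    {},                                        # 3. docker-compose `environment:`
    {"LOG_LEVEL": "warning"},                  # 4. shell environment (highest)
)
print(cfg)  # {'DB_PORT': '5432', 'LOG_LEVEL': 'warning'}
```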
## Advanced Development Patterns

### Adding New Blocks
1. Create file in `/backend/backend/blocks/`
2. Inherit from `Block` base class with input/output schemas
3. Implement `run` method with proper error handling
4. Generate block UUID using `uuid.uuid4()`
5. Register in block registry
6. Write tests alongside block implementation
7. Consider how inputs/outputs connect with other blocks in graph editor

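A minimal sketch of what steps 2-4 look like in practice. The `Input`/`Output` stand-ins and `ShoutBlock` are illustrative only; the real `Block` base class, its schema types, and the registry API may differ (registration, step 5, is omitted here):

```python
import uuid
from dataclasses import dataclass

# Illustrative stand-ins for the real input/output schema types
@dataclass
class Input:
    text: str

@dataclass
class Output:
    shout: str

class ShoutBlock:
    """Hypothetical block that upper-cases its input text."""

    def __init__(self):
        # Step 4: every block gets a stable UUID (generate once, then hardcode)
        self.id = str(uuid.uuid4())

    def run(self, input_data: Input) -> Output:
        # Step 3: execution logic with basic error handling
        if not input_data.text:
            raise ValueError("text must not be empty")
        return Output(shout=input_data.text.upper())

block = ShoutBlock()
print(block.run(Input(text="hello")).shout)  # HELLO
```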

### API Development
1. Update routes in `/backend/backend/server/routers/`
2. Add/update Pydantic models in the same directory
3. Write tests alongside route files
4. For `data/*.py` changes, validate user ID checks
5. Run `poetry run test` to verify changes

### Frontend Development
1. Components in `/frontend/src/components/`
2. Use existing UI components from `/frontend/src/components/ui/`
3. Add Storybook stories for component development
4. Test user-facing features with Playwright E2E tests
5. Update protected routes in middleware when needed

### Security Guidelines

**Cache Protection Middleware** (`/backend/backend/server/middleware/security.py`):
- Default: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses allow list approach for cacheable paths (static assets, health checks, public pages)
- Prevents sensitive data caching in browsers/proxies
- Add new cacheable endpoints to `CACHEABLE_PATHS`

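The allow-list logic can be sketched as a pure function. The real middleware wraps this kind of check around every response; the `CACHEABLE_PATHS` entries and the cacheable-path policy below are assumptions, not the actual contents of `security.py`:

```python
NO_STORE = "no-store, no-cache, must-revalidate, private"
CACHEABLE_PATHS = {"/health", "/static"}  # illustrative allow list

def cache_control_for(path: str) -> str:
    """Return the Cache-Control value for a request path: allow-listed paths
    may be cached; everything else gets the no-store default."""
    if any(path == p or path.startswith(p + "/") for p in CACHEABLE_PATHS):
        return "public, max-age=3600"  # assumed policy for allow-listed paths
    return NO_STORE

print(cache_control_for("/health"))      # public, max-age=3600
print(cache_control_for("/api/graphs"))  # no-store, no-cache, must-revalidate, private
```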
### CI/CD Alignment

The repository has comprehensive CI workflows that test:
- **Backend**: Python 3.11-3.13, services (Redis/RabbitMQ/ClamAV), Prisma migrations, Poetry lock validation
- **Frontend**: Node.js 21, pnpm, Playwright with Docker Compose stack, API schema validation
- **Integration**: Full-stack type checking and E2E testing

Match these patterns when developing locally - the Copilot setup environment mirrors these CI configurations.

## Collaboration with Other AI Assistants

This repository is actively developed with assistance from Claude (via CLAUDE.md files). When working on this codebase:
- Check for existing CLAUDE.md files that provide additional context
- Follow established patterns and conventions already in the codebase
- Maintain consistency with existing code style and architecture
- Consider that changes may be reviewed and extended by both human developers and AI assistants

## Trust These Instructions

These instructions are comprehensive and tested. Only perform additional searches if:
1. Information here is incomplete for your specific task
2. You encounter errors not covered by the workarounds
3. You need to understand implementation details not covered above

For detailed platform development patterns, refer to `autogpt_platform/CLAUDE.md` and `AGENTS.md` in the repository root.

.github/workflows/copilot-setup-steps.yml (vendored): 302 lines changed
@@ -1,302 +0,0 @@
name: "Copilot Setup Steps"

# Automatically run the setup steps when they are changed to allow for easy validation, and
# allow manual testing through the repository's "Actions" tab
on:
  workflow_dispatch:
  push:
    paths:
      - .github/workflows/copilot-setup-steps.yml
  pull_request:
    paths:
      - .github/workflows/copilot-setup-steps.yml

jobs:
  # The job MUST be called `copilot-setup-steps` or it will not be picked up by Copilot.
  copilot-setup-steps:
    runs-on: ubuntu-latest
    timeout-minutes: 45

    # Set the permissions to the lowest permissions possible needed for your steps.
    # Copilot will be given its own token for its operations.
    permissions:
      # If you want to clone the repository as part of your setup steps, for example to install dependencies, you'll need the `contents: read` permission. If you don't clone the repository in your setup steps, Copilot will do this for you automatically after the steps complete.
      contents: read

    # You can define any steps you want, and they will run before the agent starts.
    # If you do not check out your code, Copilot will do this for you.
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
          submodules: true

      # Backend Python/Poetry setup (mirrors platform-backend-ci.yml)
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.11" # Use standard version matching CI

      - name: Set up Python dependency cache
        uses: actions/cache@v4
        with:
          path: ~/.cache/pypoetry
          key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

      - name: Install Poetry
        run: |
          # Extract Poetry version from backend/poetry.lock (matches CI)
          cd autogpt_platform/backend
          HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
          echo "Found Poetry version ${HEAD_POETRY_VERSION} in backend/poetry.lock"

          # Install Poetry
          curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -

          # Add Poetry to PATH
          echo "$HOME/.local/bin" >> $GITHUB_PATH

      - name: Check poetry.lock
        working-directory: autogpt_platform/backend
        run: |
          poetry lock
          if ! git diff --quiet --ignore-matching-lines="^# " poetry.lock; then
            echo "Warning: poetry.lock not up to date, but continuing for setup"
            git checkout poetry.lock # Reset for clean setup
          fi

      - name: Install Python dependencies
        working-directory: autogpt_platform/backend
        run: poetry install

      - name: Generate Prisma Client
        working-directory: autogpt_platform/backend
        run: poetry run prisma generate

      # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
      - name: Set up Node.js
        uses: actions/setup-node@v4
        with:
          node-version: "21"

      - name: Enable corepack
        run: corepack enable

      - name: Set pnpm store directory
        run: |
          pnpm config set store-dir ~/.pnpm-store
          echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV

      - name: Cache frontend dependencies
        uses: actions/cache@v4
        with:
          path: ~/.pnpm-store
          key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
          restore-keys: |
            ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
            ${{ runner.os }}-pnpm-

      - name: Install JavaScript dependencies
        working-directory: autogpt_platform/frontend
        run: pnpm install --frozen-lockfile

      # Install Playwright browsers for frontend testing
      # NOTE: Disabled to save ~1 minute of setup time. Re-enable if Copilot needs browser automation (e.g., for MCP)
      # - name: Install Playwright browsers
      #   working-directory: autogpt_platform/frontend
      #   run: pnpm playwright install --with-deps chromium

      # Docker setup for development environment
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Copy default environment files
        working-directory: autogpt_platform
        run: |
          # Copy default environment files for development
          cp .env.default .env
          cp backend/.env.default backend/.env
          cp frontend/.env.default frontend/.env

      # Phase 1: Cache and load Docker images for faster setup
      - name: Set up Docker image cache
        id: docker-cache
        uses: actions/cache@v4
        with:
          path: ~/docker-cache
          # Use a versioned key for cache invalidation when image list changes
          key: docker-images-v2-${{ runner.os }}-${{ hashFiles('.github/workflows/copilot-setup-steps.yml') }}
          restore-keys: |
            docker-images-v2-${{ runner.os }}-
            docker-images-v1-${{ runner.os }}-

      - name: Load or pull Docker images
        working-directory: autogpt_platform
        run: |
          mkdir -p ~/docker-cache

          # Define image list for easy maintenance
          IMAGES=(
            "redis:latest"
            "rabbitmq:management"
            "clamav/clamav-debian:latest"
            "busybox:latest"
            "kong:2.8.1"
            "supabase/gotrue:v2.170.0"
            "supabase/postgres:15.8.1.049"
            "supabase/postgres-meta:v0.86.1"
            "supabase/studio:20250224-d10db0f"
          )

          # Check if any cached tar files exist (more reliable than cache-hit)
          if ls ~/docker-cache/*.tar 1> /dev/null 2>&1; then
            echo "Docker cache found, loading images in parallel..."
            for image in "${IMAGES[@]}"; do
              # Convert image name to filename (replace : and / with -)
              filename=$(echo "$image" | tr ':/' '--')
              if [ -f ~/docker-cache/${filename}.tar ]; then
                echo "Loading $image..."
                docker load -i ~/docker-cache/${filename}.tar || echo "Warning: Failed to load $image from cache" &
              fi
            done
            wait
            echo "All cached images loaded"
          else
            echo "No Docker cache found, pulling images in parallel..."
            # Pull all images in parallel
            for image in "${IMAGES[@]}"; do
              docker pull "$image" &
            done
            wait

            # Only save cache on main branches (not PRs) to avoid cache pollution
            if [[ "${{ github.ref }}" == "refs/heads/master" ]] || [[ "${{ github.ref }}" == "refs/heads/dev" ]]; then
              echo "Saving Docker images to cache in parallel..."
              for image in "${IMAGES[@]}"; do
                # Convert image name to filename (replace : and / with -)
                filename=$(echo "$image" | tr ':/' '--')
                echo "Saving $image..."
                docker save -o ~/docker-cache/${filename}.tar "$image" || echo "Warning: Failed to save $image" &
              done
              wait
              echo "Docker image cache saved"
            else
              echo "Skipping cache save for PR/feature branch"
            fi
          fi

          echo "Docker images ready for use"

      # Phase 2: Build migrate service with GitHub Actions cache
      - name: Build migrate Docker image with cache
        working-directory: autogpt_platform
        run: |
          # Build the migrate image with buildx for GHA caching
          docker buildx build \
            --cache-from type=gha \
            --cache-to type=gha,mode=max \
            --target migrate \
            --tag autogpt_platform-migrate:latest \
            --load \
            -f backend/Dockerfile \
            ..

      # Start services using pre-built images
      - name: Start Docker services for development
        working-directory: autogpt_platform
        run: |
          # Start essential services (migrate image already built with correct tag)
          docker compose --profile local up deps --no-build --detach
          echo "Waiting for services to be ready..."

          # Wait for database to be ready
          echo "Checking database readiness..."
          timeout 30 sh -c 'until docker compose exec -T db pg_isready -U postgres 2>/dev/null; do
            echo "  Waiting for database..."
            sleep 2
          done' && echo "✅ Database is ready" || echo "⚠️ Database ready check timeout after 30s, continuing..."

          # Check migrate service status
          echo "Checking migration status..."
          docker compose ps migrate || echo "  Migrate service not visible in ps output"

          # Wait for migrate service to complete
          echo "Waiting for migrations to complete..."
          timeout 30 bash -c '
            ATTEMPTS=0
            while [ $ATTEMPTS -lt 15 ]; do
              ATTEMPTS=$((ATTEMPTS + 1))

              # Check using docker directly (more reliable than docker compose ps)
              CONTAINER_STATUS=$(docker ps -a --filter "label=com.docker.compose.service=migrate" --format "{{.Status}}" | head -1)

              if [ -z "$CONTAINER_STATUS" ]; then
                echo "  Attempt $ATTEMPTS: Migrate container not found yet..."
              elif echo "$CONTAINER_STATUS" | grep -q "Exited (0)"; then
                echo "✅ Migrations completed successfully"
                docker compose logs migrate --tail=5 2>/dev/null || true
                exit 0
              elif echo "$CONTAINER_STATUS" | grep -q "Exited ([1-9]"; then
                EXIT_CODE=$(echo "$CONTAINER_STATUS" | grep -oE "Exited \([0-9]+\)" | grep -oE "[0-9]+")
                echo "❌ Migrations failed with exit code: $EXIT_CODE"
                echo "Migration logs:"
                docker compose logs migrate --tail=20 2>/dev/null || true
                exit 1
              elif echo "$CONTAINER_STATUS" | grep -q "Up"; then
                echo "  Attempt $ATTEMPTS: Migrate container is running... ($CONTAINER_STATUS)"
              else
                echo "  Attempt $ATTEMPTS: Migrate container status: $CONTAINER_STATUS"
              fi

              sleep 2
            done

            echo "⚠️ Timeout: Could not determine migration status after 30 seconds"
            echo "Final container check:"
            docker ps -a --filter "label=com.docker.compose.service=migrate" || true
            echo "Migration logs (if available):"
            docker compose logs migrate --tail=10 2>/dev/null || echo "  No logs available"
          ' || echo "⚠️ Migration check completed with warnings, continuing..."

          # Brief wait for other services to stabilize
          echo "Waiting 5 seconds for other services to stabilize..."
          sleep 5

      # Verify installations and provide environment info
      - name: Verify setup and show environment info
        run: |
          echo "=== Python Setup ==="
          python --version
          poetry --version

          echo "=== Node.js Setup ==="
          node --version
          pnpm --version

          echo "=== Additional Tools ==="
          docker --version
          docker compose version
          gh --version || true

          echo "=== Services Status ==="
          cd autogpt_platform
          docker compose ps || true

          echo "=== Backend Dependencies ==="
          cd backend
          poetry show | head -10 || true

          echo "=== Frontend Dependencies ==="
          cd ../frontend
          pnpm list --depth=0 | head -10 || true

          echo "=== Environment Files ==="
          ls -la ../.env* || true
          ls -la .env* || true
          ls -la ../backend/.env* || true

          echo "✅ AutoGPT Platform development environment setup complete!"
          echo "🚀 Ready for development with Docker services running"
          echo "📝 Backend server: poetry run serve (port 8000)"
          echo "🌐 Frontend server: pnpm dev (port 3000)"

.github/workflows/platform-backend-ci.yml (vendored): 2 lines changed
@@ -32,7 +32,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.11", "3.12", "3.13"]
+        python-version: ["3.11"]
     runs-on: ubuntu-latest

     services:

.github/workflows/platform-frontend-ci.yml (vendored): 46 lines changed
@@ -82,6 +82,37 @@ jobs:
      - name: Run lint
        run: pnpm lint

  type-check:
    runs-on: ubuntu-latest
    needs: setup

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Set up Node.js
        uses: actions/setup-node@v4
        with:
          node-version: "21"

      - name: Enable corepack
        run: corepack enable

      - name: Restore dependencies cache
        uses: actions/cache@v4
        with:
          path: ~/.pnpm-store
          key: ${{ needs.setup.outputs.cache-key }}
          restore-keys: |
            ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
            ${{ runner.os }}-pnpm-

      - name: Install dependencies
        run: pnpm install --frozen-lockfile

      - name: Run tsc check
        run: pnpm type-check

  chromatic:
    runs-on: ubuntu-latest
    needs: setup
@@ -145,7 +176,11 @@

      - name: Copy default supabase .env
        run: |
          cp ../.env.default ../.env
          cp ../.env.example ../.env

      - name: Copy backend .env
        run: |
          cp ../backend/.env.example ../backend/.env

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
@@ -217,6 +252,15 @@
      - name: Install dependencies
        run: pnpm install --frozen-lockfile

      - name: Setup .env
        run: cp .env.example .env

      - name: Build frontend
        run: pnpm build --turbo
        # uses Turbopack, much faster and safe enough for a test pipeline
        env:
          NEXT_PUBLIC_PW_TEST: true

      - name: Install Browser 'chromium'
        run: pnpm playwright install --with-deps chromium

.github/workflows/platform-fullstack-ci.yml (vendored): 132 lines changed
@@ -1,132 +0,0 @@
name: AutoGPT Platform - Frontend CI

on:
  push:
    branches: [master, dev]
    paths:
      - ".github/workflows/platform-fullstack-ci.yml"
      - "autogpt_platform/**"
  pull_request:
    paths:
      - ".github/workflows/platform-fullstack-ci.yml"
      - "autogpt_platform/**"
  merge_group:

defaults:
  run:
    shell: bash
    working-directory: autogpt_platform/frontend

jobs:
  setup:
    runs-on: ubuntu-latest
    outputs:
      cache-key: ${{ steps.cache-key.outputs.key }}

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Set up Node.js
        uses: actions/setup-node@v4
        with:
          node-version: "21"

      - name: Enable corepack
        run: corepack enable

      - name: Generate cache key
        id: cache-key
        run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT

      - name: Cache dependencies
        uses: actions/cache@v4
        with:
          path: ~/.pnpm-store
          key: ${{ steps.cache-key.outputs.key }}
          restore-keys: |
            ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
            ${{ runner.os }}-pnpm-

      - name: Install dependencies
        run: pnpm install --frozen-lockfile

  types:
    runs-on: ubuntu-latest
    needs: setup
    strategy:
      fail-fast: false

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          submodules: recursive

      - name: Set up Node.js
        uses: actions/setup-node@v4
        with:
          node-version: "21"

      - name: Enable corepack
        run: corepack enable

      - name: Copy default supabase .env
        run: |
          cp ../.env.default ../.env

      - name: Copy backend .env
        run: |
          cp ../backend/.env.default ../backend/.env

      - name: Run docker compose
        run: |
          docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d

      - name: Restore dependencies cache
        uses: actions/cache@v4
        with:
          path: ~/.pnpm-store
          key: ${{ needs.setup.outputs.cache-key }}
          restore-keys: |
            ${{ runner.os }}-pnpm-

      - name: Install dependencies
        run: pnpm install --frozen-lockfile

      - name: Setup .env
        run: cp .env.default .env

      - name: Wait for services to be ready
        run: |
          echo "Waiting for rest_server to be ready..."
          timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
          echo "Waiting for database to be ready..."
          timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."

      - name: Generate API queries
        run: pnpm generate:api:force

      - name: Check for API schema changes
        run: |
          if ! git diff --exit-code src/app/api/openapi.json; then
            echo "❌ API schema changes detected in src/app/api/openapi.json"
            echo ""
            echo "The openapi.json file has been modified after running 'pnpm generate:api-all'."
            echo "This usually means changes have been made in the BE endpoints without updating the Frontend."
            echo "The API schema is now out of sync with the Front-end queries."
            echo ""
            echo "To fix this:"
            echo "1. Pull the backend 'docker compose pull && docker compose up -d --build --force-recreate'"
            echo "2. Run 'pnpm generate:api' locally"
            echo "3. Run 'pnpm types' locally"
            echo "4. Fix any TypeScript errors that may have been introduced"
            echo "5. Commit and push your changes"
            echo ""
            exit 1
          else
            echo "✅ No API schema changes detected"
          fi

      - name: Run Typescript checks
        run: pnpm types

.gitignore (vendored): 3 changes
@@ -5,8 +5,6 @@ classic/original_autogpt/*.json
 auto_gpt_workspace/*
 *.mpeg
 .env
-# Root .env files
-/.env
 azure.yaml
 .vscode
 .idea/*
@@ -123,6 +121,7 @@ celerybeat.pid
 
 # Environments
 .direnv/
+.env
 .venv
 env/
 venv*/
@@ -235,7 +235,7 @@ repos:
    hooks:
      - id: tsc
        name: Typecheck - AutoGPT Platform - Frontend
-       entry: bash -c 'cd autogpt_platform/frontend && pnpm types'
+       entry: bash -c 'cd autogpt_platform/frontend && pnpm type-check'
        files: ^autogpt_platform/frontend/
        types: [file]
        language: system
README.md: 10 changes
@@ -3,16 +3,6 @@
 [](https://discord.gg/autogpt)  
 [](https://twitter.com/Auto_GPT)  
 
-<!-- Keep these links. Translations will automatically update with the README. -->
-[Deutsch](https://zdoc.app/de/Significant-Gravitas/AutoGPT) |
-[Español](https://zdoc.app/es/Significant-Gravitas/AutoGPT) |
-[français](https://zdoc.app/fr/Significant-Gravitas/AutoGPT) |
-[日本語](https://zdoc.app/ja/Significant-Gravitas/AutoGPT) |
-[한국어](https://zdoc.app/ko/Significant-Gravitas/AutoGPT) |
-[Português](https://zdoc.app/pt/Significant-Gravitas/AutoGPT) |
-[Русский](https://zdoc.app/ru/Significant-Gravitas/AutoGPT) |
-[中文](https://zdoc.app/zh/Significant-Gravitas/AutoGPT)
-
 **AutoGPT** is a powerful platform that allows you to create, deploy, and manage continuous AI agents that automate complex workflows.
 
 ## Hosting Options
@@ -1,11 +1,9 @@
 # CLAUDE.md
 
 This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
 
 ## Repository Overview
 
 AutoGPT Platform is a monorepo containing:
 
 - **Backend** (`/backend`): Python FastAPI server with async support
 - **Frontend** (`/frontend`): Next.js React application
 - **Shared Libraries** (`/autogpt_libs`): Common Python utilities
@@ -13,7 +11,6 @@ AutoGPT Platform is a monorepo containing:
 ## Essential Commands
 
 ### Backend Development
 
 ```bash
 # Install dependencies
 cd backend && poetry install
@@ -33,18 +30,11 @@ poetry run test
 # Run specific test
 poetry run pytest path/to/test_file.py::test_function_name
 
 # Run block tests (tests that validate all blocks work correctly)
 poetry run pytest backend/blocks/test/test_block.py -xvs
 
 # Run tests for a specific block (e.g. GetCurrentTimeBlock)
 poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
 
 # Lint and format
 # Prefer format if you just want to "fix" things and see only the errors that can't be autofixed
 poetry run format  # Black + isort
 poetry run lint  # ruff
 ```
 
 More details can be found in TESTING.md.
 
 #### Creating/Updating Snapshots
@@ -57,8 +47,8 @@ poetry run pytest path/to/test.py --snapshot-update
 
 ⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
 
 ### Frontend Development
 
 ```bash
 # Install dependencies
 cd frontend && npm install
@@ -76,13 +66,12 @@ npm run storybook
 npm run build
 
 # Type checking
-npm run types
+npm run type-check
 ```
 
 ## Architecture Overview
 
 ### Backend Architecture
 
 - **API Layer**: FastAPI with REST and WebSocket endpoints
 - **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
 - **Queue System**: RabbitMQ for async task processing
@@ -91,7 +80,6 @@ npm run types
 - **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
 
 ### Frontend Architecture
 
 - **Framework**: Next.js App Router with React Server Components
 - **State Management**: React hooks + Supabase client for real-time updates
 - **Workflow Builder**: Visual graph editor using @xyflow/react
@@ -99,7 +87,6 @@ npm run types
 - **Feature Flags**: LaunchDarkly integration
 
 ### Key Concepts
 
 1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
 2. **Blocks**: Reusable components in `/backend/blocks/` that perform specific tasks
 3. **Integrations**: OAuth and API connections stored per user
@@ -107,16 +94,13 @@ npm run types
 5. **Virus Scanning**: ClamAV integration for file upload security
 
 ### Testing Approach
 
 - Backend uses pytest with snapshot testing for API responses
 - Test files are colocated with source files (`*_test.py`)
 - Frontend uses Playwright for E2E tests
 - Component testing via Storybook
 
 ### Database Schema
 
 Key models (defined in `/backend/schema.prisma`):
 
 - `User`: Authentication and profile data
 - `AgentGraph`: Workflow definitions with version control
 - `AgentGraphExecution`: Execution history and results
@@ -124,31 +108,13 @@ Key models (defined in `/backend/schema.prisma`):
 - `StoreListing`: Marketplace listings for sharing agents
 
 ### Environment Configuration
 
-#### Configuration Files
-
-- **Backend**: `/backend/.env.default` (defaults) → `/backend/.env` (user overrides)
-- **Frontend**: `/frontend/.env.default` (defaults) → `/frontend/.env` (user overrides)
-- **Platform**: `/.env.default` (Supabase/shared defaults) → `/.env` (user overrides)
-
-#### Docker Environment Loading Order
-
-1. `.env.default` files provide base configuration (tracked in git)
-2. `.env` files provide user-specific overrides (gitignored)
-3. Docker Compose `environment:` sections provide service-specific overrides
-4. Shell environment variables have highest precedence
-
-#### Key Points
-
-- All services use hardcoded defaults in docker-compose files (no `${VARIABLE}` substitutions)
-- The `env_file` directive loads variables INTO containers at runtime
-- Backend/Frontend services use YAML anchors for consistent configuration
-- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
+- Backend: `.env` file in `/backend`
+- Frontend: `.env.local` file in `/frontend`
+- Both require Supabase credentials and API keys for various services
 
 ### Common Development Tasks
 
 **Adding a new block:**
 
 1. Create a new file in `/backend/backend/blocks/`
 2. Inherit from the `Block` base class
 3. Define input/output schemas
@@ -156,18 +122,13 @@ Key models (defined in `/backend/schema.prisma`):
 5. Register in the block registry
 6. Generate the block uuid using `uuid.uuid4()`
 
 Note: when making many new blocks, analyze the interfaces of each block and consider whether they would connect productively in a graph-based editor (e.g. do the inputs and outputs tie well together?).
 
 **Modifying the API:**
 
 1. Update the route in `/backend/backend/server/routers/`
 2. Add/update Pydantic models in the same directory
 3. Write tests alongside the route file
 4. Run `poetry run test` to verify
 
 **Frontend feature development:**
 
 1. Components go in `/frontend/src/components/`
 2. Use existing UI components from `/frontend/src/components/ui/`
 3. Add Storybook stories for new components
@@ -176,7 +137,6 @@ Note: do the inputs and outputs tie well together?
 ### Security Implementation
 
 **Cache Protection Middleware:**
 
 - Located in `/backend/backend/server/middleware/security.py`
 - Default behavior: disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
 - Uses an allow-list approach: only explicitly permitted paths can be cached
@@ -184,47 +144,3 @@ Note: do the inputs and outputs tie well together?
 - Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
 - To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
 - Applied to both the main API server and external API applications
 
 ### Creating Pull Requests
 
 - Create the PR against the `dev` branch of the repository.
 - Ensure the branch name is descriptive (e.g. `feature/add-new-block`).
 - Use conventional commit messages (see below).
 - Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description.
 - Run the GitHub pre-commit hooks to ensure code quality.
 
 ### Reviewing/Revising Pull Requests
 
 - When the user runs /pr-comments or tries to fetch them, also run `gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews` to get the reviews
 - Use `gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews/[review_id]/comments` to get the review contents
 - Use `gh api /repos/Significant-Gravitas/AutoGPT/issues/9924/comments` to get the PR-specific comments
 
 ### Conventional Commits
 
 Use this format for commit messages and Pull Request titles:
 
 **Conventional Commit Types:**
 
 - `feat`: Introduces a new feature to the codebase
 - `fix`: Patches a bug in the codebase
 - `refactor`: Code change that neither fixes a bug nor adds a feature; also applies to removing features
 - `ci`: Changes to CI configuration
 - `docs`: Documentation-only changes
 - `dx`: Improvements to the developer experience
 
 **Recommended Base Scopes:**
 
 - `platform`: Changes affecting both frontend and backend
 - `frontend`
 - `backend`
 - `infra`
 - `blocks`: Modifications/additions of individual blocks
 
 **Subscope Examples:**
 
 - `backend/executor`
 - `backend/db`
 - `frontend/builder` (includes changes to the block UI component)
 - `infra/prod`
 
 Use these scopes and subscopes for clarity and consistency in commit messages.
@@ -8,6 +8,7 @@ Welcome to the AutoGPT Platform - a powerful system for creating and running AI
 
 - Docker
 - Docker Compose V2 (comes with Docker Desktop, or can be installed separately)
+- Node.js & NPM (for running the frontend application)
 
 ### Running the System
 
@@ -23,10 +24,10 @@ To run the AutoGPT Platform, follow these steps:
 2. Run the following command:
 
    ```
-  cp .env.default .env
+  cp .env.example .env
   ```
 
-This command will copy the `.env.default` file to `.env`. You can modify the `.env` file to add your own environment variables.
+This command will copy the `.env.example` file to `.env`. You can modify the `.env` file to add your own environment variables.
 
 3. Run the following command:
 
@@ -36,7 +37,44 @@ To run the AutoGPT Platform, follow these steps:
 
 This command will start all the necessary backend services defined in the `docker-compose.yml` file in detached mode.
 
-4. After all the services are in ready state, open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
+4. Navigate to `frontend` within the `autogpt_platform` directory:
+
+   ```
+   cd frontend
+   ```
+
+   You will need to run your frontend application separately on your local machine.
+
+5. Run the following command:
+
+   ```
+   cp .env.example .env.local
+   ```
+
+   This command will copy the `.env.example` file to `.env.local` in the `frontend` directory. You can modify the `.env.local` within this folder to add your own environment variables for the frontend application.
+
+6. Run the following command:
+
+   Enable corepack and install dependencies by running:
+
+   ```
+   corepack enable
+   pnpm i
+   ```
+
+   Generate the API client (this step is required before running the frontend):
+
+   ```
+   pnpm generate:api-client
+   ```
+
+   Then start the frontend application in development mode:
+
+   ```
+   pnpm dev
+   ```
+
+7. Open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
 
 ### Docker Compose Commands
 
@@ -139,21 +177,20 @@ The platform includes scripts for generating and managing the API client:
 
 - `pnpm fetch:openapi`: Fetches the OpenAPI specification from the backend service (requires backend to be running on port 8006)
 - `pnpm generate:api-client`: Generates the TypeScript API client from the OpenAPI specification using Orval
-- `pnpm generate:api`: Runs both fetch and generate commands in sequence
+- `pnpm generate:api-all`: Runs both fetch and generate commands in sequence
 
 #### Manual API Client Updates
 
 If you need to update the API client after making changes to the backend API:
 
 1. Ensure the backend services are running:
 
    ```
   docker compose up -d
   ```
 
 2. Generate the updated API client:
 
    ```
-  pnpm generate:api
+  pnpm generate:api-all
   ```
 
 This will fetch the latest OpenAPI specification and regenerate the TypeScript client code.
autogpt_platform/autogpt-rs/DATABASE_MANAGER.md (new file): 802 lines
@@ -0,0 +1,802 @@
# DatabaseManager Technical Specification

## Executive Summary

This document provides a complete technical specification for implementing a drop-in replacement for the AutoGPT Platform's DatabaseManager service. The replacement must maintain 100% API compatibility while preserving all functional behaviors, security requirements, and performance characteristics.

## 1. System Overview

### 1.1 Purpose

The DatabaseManager is a centralized service that provides database access for the AutoGPT Platform's executor system. It encapsulates all database operations behind a service interface, enabling distributed execution while maintaining data consistency and security.

### 1.2 Architecture Pattern

- **Service Type**: HTTP-based microservice using FastAPI
- **Communication**: RPC-style over HTTP with JSON serialization
- **Base Class**: Inherits from `AppService` (backend.util.service)
- **Client Classes**: `DatabaseManagerClient` (sync) and `DatabaseManagerAsyncClient` (async)
- **Port**: Configurable via `config.database_api_port`

### 1.3 Critical Requirements

1. **API Compatibility**: All 40+ exposed methods must maintain exact signatures
2. **Type Safety**: Full type preservation across service boundaries
3. **User Isolation**: All operations must respect user_id boundaries
4. **Transaction Support**: Maintain ACID properties for critical operations
5. **Event Publishing**: Maintain Redis event bus integration for real-time updates

## 2. Service Implementation Requirements

### 2.1 Base Service Class

```python
import logging

from backend.data import db
from backend.util.service import AppService, expose
from backend.util.settings import Config

logger = logging.getLogger(__name__)
config = Config()


class DatabaseManager(AppService):
    """
    REQUIRED: Inherit from AppService to get:
    - Automatic endpoint generation via @expose decorator
    - Built-in health checks at /health
    - Request/response serialization
    - Error handling and logging
    """

    def run_service(self) -> None:
        """REQUIRED: Initialize database connection before starting service"""
        logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
        self.run_and_wait(db.connect())  # CRITICAL: Must connect to database
        super().run_service()  # Start HTTP server

    def cleanup(self):
        """REQUIRED: Clean disconnect on shutdown"""
        super().cleanup()
        logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
        self.run_and_wait(db.disconnect())  # CRITICAL: Must disconnect cleanly

    @classmethod
    def get_port(cls) -> int:
        """REQUIRED: Return configured port"""
        return config.database_api_port
```

### 2.2 Method Exposure Pattern

```python
@staticmethod
def _(f: Callable[P, R], name: str | None = None) -> Callable[Concatenate[object, P], R]:
    """
    REQUIRED: Helper to expose methods with proper signatures
    - Preserves function name for endpoint generation
    - Maintains type information
    - Adds 'self' parameter for instance binding
    """
    if name is not None:
        f.__name__ = name
    return cast(Callable[Concatenate[object, P], R], expose(f))
```

### 2.3 Database Connection Management

**REQUIRED: Use Prisma ORM with these exact configurations:**

```python
from prisma import Prisma

prisma = Prisma(
    auto_register=True,
    http={"timeout": HTTP_TIMEOUT},  # Default: 120 seconds
    datasource={"url": DATABASE_URL},
)

# Connection lifecycle
async def connect():
    await prisma.connect()

async def disconnect():
    await prisma.disconnect()
```
### 2.4 Transaction Support

**REQUIRED: Implement both regular and locked transactions:**

```python
import zlib
from contextlib import asynccontextmanager

@asynccontextmanager
async def transaction(timeout: float | None = None):
    """Regular database transaction"""
    async with prisma.tx(timeout=timeout) as tx:
        yield tx

@asynccontextmanager
async def locked_transaction(key: str, timeout: float | None = None):
    """Transaction with PostgreSQL advisory lock"""
    lock_key = zlib.crc32(key.encode("utf-8"))
    async with transaction(timeout=timeout) as tx:
        await tx.execute_raw("SELECT pg_advisory_xact_lock($1)", lock_key)
        yield tx
```
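The lock-key derivation above can be exercised on its own: the advisory-lock key is just the CRC-32 of the caller's string key, so the same key string always maps to the same lock (the function name here is illustrative).

```python
import zlib


def advisory_lock_key(key: str) -> int:
    """Map an arbitrary string key to a 32-bit integer suitable for
    PostgreSQL's pg_advisory_xact_lock, as locked_transaction does."""
    return zlib.crc32(key.encode("utf-8"))
```

Because the mapping is deterministic, two concurrent `create_graph_execution` calls for the same `graph_id` contend for the same advisory lock and therefore serialize.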
## 3. Complete API Specification

### 3.1 Execution Management APIs

#### get_graph_execution
```python
async def get_graph_execution(
    user_id: str,
    execution_id: str,
    *,
    include_node_executions: bool = False,
) -> GraphExecution | GraphExecutionWithNodes | None
```
**Behavior**:
- Returns execution only if user_id matches
- Optionally includes all node executions
- Returns None if not found or unauthorized

#### get_graph_executions
```python
async def get_graph_executions(
    user_id: str,
    graph_id: str | None = None,
    *,
    limit: int = 50,
    graph_version: int | None = None,
    cursor: str | None = None,
    preset_id: str | None = None,
) -> tuple[list[GraphExecution], str | None]
```
**Behavior**:
- Paginated results with cursor
- Filter by graph_id, version, or preset_id
- Returns (executions, next_cursor)
#### create_graph_execution
```python
async def create_graph_execution(
    graph_id: str,
    graph_version: int,
    starting_nodes_input: dict[str, dict[str, Any]],
    user_id: str,
    preset_id: str | None = None,
) -> GraphExecutionWithNodes
```
**Behavior**:
- Creates execution with status "QUEUED"
- Initializes all nodes with "PENDING" status
- Publishes creation event to Redis
- Uses locked transaction on graph_id

#### update_graph_execution_start_time
```python
async def update_graph_execution_start_time(
    graph_exec_id: str,
) -> None
```
**Behavior**:
- Sets start_time to current timestamp
- Only updates if currently NULL

#### update_graph_execution_stats
```python
async def update_graph_execution_stats(
    graph_exec_id: str,
    status: AgentExecutionStatus | None = None,
    stats: dict[str, Any] | None = None,
) -> GraphExecution | None
```
**Behavior**:
- Updates status and/or stats atomically
- Sets end_time if status is terminal (COMPLETED/FAILED)
- Publishes update event to Redis
- Returns updated execution

#### get_node_execution
```python
async def get_node_execution(
    node_exec_id: str,
) -> NodeExecutionResult | None
```
**Behavior**:
- No user_id check (relies on graph execution security)
- Includes all input/output data

#### get_node_executions
```python
async def get_node_executions(
    graph_exec_id: str,
) -> list[NodeExecutionResult]
```
**Behavior**:
- Returns all node executions for the graph
- Ordered by creation time

#### get_latest_node_execution
```python
async def get_latest_node_execution(
    graph_exec_id: str,
    node_id: str,
) -> NodeExecutionResult | None
```
**Behavior**:
- Returns the most recent execution of a specific node
- Used for retry/rerun scenarios

#### update_node_execution_status
```python
async def update_node_execution_status(
    node_exec_id: str,
    status: AgentExecutionStatus,
    execution_data: dict[str, Any] | None = None,
    stats: dict[str, Any] | None = None,
) -> NodeExecutionResult
```
**Behavior**:
- Updates status atomically
- Sets end_time for terminal states
- Optionally updates stats/data
- Publishes event to Redis
- Returns updated execution

#### update_node_execution_status_batch
```python
async def update_node_execution_status_batch(
    execution_updates: list[NodeExecutionUpdate],
) -> list[NodeExecutionResult]
```
**Behavior**:
- Batch-updates multiple nodes in a single transaction
- Each update can have a different status/stats
- Publishes events for all updates
- Returns all updated executions

#### update_node_execution_stats
```python
async def update_node_execution_stats(
    node_exec_id: str,
    stats: dict[str, Any],
) -> NodeExecutionResult
```
**Behavior**:
- Updates only the stats field
- Merges with existing stats
- Does not affect status

#### upsert_execution_input
```python
async def upsert_execution_input(
    node_id: str,
    graph_exec_id: str,
    input_name: str,
    input_data: Any,
    node_exec_id: str | None = None,
) -> tuple[str, BlockInput]
```
**Behavior**:
- Creates or updates input data
- If node_exec_id is not provided, creates a node execution
- Serializes input_data to JSON
- Returns (node_exec_id, input_object)

#### upsert_execution_output
```python
async def upsert_execution_output(
    node_exec_id: str,
    output_name: str,
    output_data: Any,
) -> None
```
**Behavior**:
- Creates or updates output data
- Serializes output_data to JSON
- No return value

#### get_execution_kv_data
```python
async def get_execution_kv_data(
    user_id: str,
    key: str,
) -> Any | None
```
**Behavior**:
- User-scoped key-value storage
- Returns deserialized JSON data
- Returns None if key not found

#### set_execution_kv_data
```python
async def set_execution_kv_data(
    user_id: str,
    node_exec_id: str,
    key: str,
    data: Any,
) -> Any | None
```
**Behavior**:
- Sets user-scoped key-value data
- Associates it with a node execution
- Serializes data to JSON
- Returns the previous value or None

#### get_block_error_stats
```python
async def get_block_error_stats() -> list[BlockErrorStats]
```
**Behavior**:
- Aggregates error counts by block_id
- Last 7 days of data
- Groups by error type
### 3.2 Graph Management APIs

#### get_node
```python
async def get_node(
    node_id: str,
) -> AgentNode | None
```
**Behavior**:
- Returns node with block data
- No user_id check (public blocks)

#### get_graph
```python
async def get_graph(
    graph_id: str,
    version: int | None = None,
    user_id: str | None = None,
    for_export: bool = False,
    include_subgraphs: bool = False,
) -> GraphModel | None
```
**Behavior**:
- Returns the latest version if version=None
- Checks user_id for private graphs
- for_export=True excludes internal fields
- include_subgraphs=True loads nested graphs

#### get_connected_output_nodes
```python
async def get_connected_output_nodes(
    node_id: str,
    output_name: str,
) -> list[tuple[AgentNode, AgentNodeLink]]
```
**Behavior**:
- Returns downstream nodes connected to the output
- Includes link metadata
- Used for execution flow

#### get_graph_metadata
```python
async def get_graph_metadata(
    graph_id: str,
    user_id: str,
) -> GraphMetadata | None
```
**Behavior**:
- Returns graph metadata without the full definition
- User must own or have access to the graph

### 3.3 Credit System APIs

#### get_credits
```python
async def get_credits(
    user_id: str,
) -> int
```
**Behavior**:
- Returns current credit balance
- Always non-negative

#### spend_credits
```python
async def spend_credits(
    user_id: str,
    cost: int,
    metadata: UsageTransactionMetadata,
) -> int
```
**Behavior**:
- Deducts credits atomically
- Creates a transaction record
- Throws InsufficientCredits if the balance is too low
- Returns the new balance
- metadata includes: block_id, node_exec_id, context
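The `spend_credits` contract (deduct, record a transaction, refuse on a low balance, return the new balance) can be modeled in memory. This is an illustrative sketch, not the service's implementation; the real version runs inside a database transaction and uses `UsageTransactionMetadata` rather than a plain dict.

```python
class InsufficientCredits(Exception):
    pass


class CreditLedger:
    """In-memory model of the spend_credits contract."""

    def __init__(self) -> None:
        self.balances: dict[str, int] = {}
        self.transactions: list[dict] = []

    def spend_credits(self, user_id: str, cost: int, metadata: dict) -> int:
        balance = self.balances.get(user_id, 0)
        if balance < cost:
            # Refuse the spend without touching the balance.
            raise InsufficientCredits(f"balance {balance} < cost {cost}")
        self.balances[user_id] = balance - cost
        # Record the usage transaction alongside its context.
        self.transactions.append({"user_id": user_id, "cost": -cost, **metadata})
        return self.balances[user_id]
```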
### 3.4 User Management APIs

#### get_user_metadata
```python
async def get_user_metadata(
    user_id: str,
) -> UserMetadata
```
**Behavior**:
- Returns user preferences and settings
- Creates a default record if none exists

#### update_user_metadata
```python
async def update_user_metadata(
    user_id: str,
    data: UserMetadataDTO,
) -> UserMetadata
```
**Behavior**:
- Partial update of metadata
- Validates against the schema
- Returns updated metadata

#### get_user_integrations
```python
async def get_user_integrations(
    user_id: str,
) -> UserIntegrations
```
**Behavior**:
- Returns OAuth credentials
- Decrypts sensitive data
- Creates an empty record if none exists

#### update_user_integrations
```python
async def update_user_integrations(
    user_id: str,
    data: UserIntegrations,
) -> None
```
**Behavior**:
- Updates integration credentials
- Encrypts sensitive data
- No return value

### 3.5 User Communication APIs

#### get_active_user_ids_in_timerange
```python
async def get_active_user_ids_in_timerange(
    start_time: datetime,
    end_time: datetime,
) -> list[str]
```
**Behavior**:
- Returns users with graph executions in the range
- Used for analytics/notifications

#### get_user_email_by_id
```python
async def get_user_email_by_id(
    user_id: str,
) -> str | None
```
**Behavior**:
- Returns the user's email address
- None if the user is not found

#### get_user_email_verification
```python
async def get_user_email_verification(
    user_id: str,
) -> UserEmailVerification
```
**Behavior**:
- Returns email and verification status
- Used for notification filtering

#### get_user_notification_preference
```python
async def get_user_notification_preference(
    user_id: str,
) -> NotificationPreference
```
**Behavior**:
- Returns notification settings
- Creates a default record if none exists

### 3.6 Notification APIs

#### create_or_add_to_user_notification_batch
```python
async def create_or_add_to_user_notification_batch(
    user_id: str,
    notification_type: NotificationType,
    notification_data: NotificationEvent,
) -> UserNotificationBatchDTO
```
**Behavior**:
- Adds to an existing batch or creates a new one
- Batches by type for efficiency
- Returns the updated batch

#### empty_user_notification_batch
```python
async def empty_user_notification_batch(
    user_id: str,
    notification_type: NotificationType,
) -> None
```
**Behavior**:
- Clears all notifications of the given type
- Used after sending a batch

#### get_all_batches_by_type
```python
async def get_all_batches_by_type(
    notification_type: NotificationType,
) -> list[UserNotificationBatchDTO]
```
**Behavior**:
- Returns all user batches of the given type
- Used by the notification service

#### get_user_notification_batch
```python
async def get_user_notification_batch(
    user_id: str,
    notification_type: NotificationType,
) -> UserNotificationBatchDTO | None
```
**Behavior**:
- Returns the user's batch for the given type
- None if no batch exists

#### get_user_notification_oldest_message_in_batch
```python
async def get_user_notification_oldest_message_in_batch(
    user_id: str,
    notification_type: NotificationType,
) -> NotificationEvent | None
```
**Behavior**:
- Returns the oldest notification in the batch
- Used for batch-timing decisions
## 4. Client Implementation Requirements
|
||||
|
||||
### 4.1 Synchronous Client
|
||||
|
||||
```python
|
||||
class DatabaseManagerClient(AppServiceClient):
|
||||
"""
|
||||
REQUIRED: Synchronous client that:
|
||||
- Converts async methods to sync using endpoint_to_sync
|
||||
- Maintains exact method signatures
|
||||
- Handles connection pooling
|
||||
- Implements retry logic
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_service_type(cls):
|
||||
return DatabaseManager
|
||||
|
||||
# Example method mapping
|
||||
get_graph_execution = endpoint_to_sync(DatabaseManager.get_graph_execution)
|
||||
```
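
The real `endpoint_to_sync` lives in the service framework and performs the call over HTTP; as a rough, self-contained sketch of the async-to-sync conversion idea (the endpoint below is hypothetical, standing in for `DatabaseManager.get_graph_execution`):

```python
import asyncio
from typing import Any, Callable, Coroutine


def endpoint_to_sync(endpoint: Callable[..., Coroutine[Any, Any, Any]]) -> Callable[..., Any]:
    """Illustrative sketch only: drive an async endpoint to completion synchronously."""
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        return asyncio.run(endpoint(*args, **kwargs))
    return wrapper


async def get_graph_execution(execution_id: str) -> dict:
    # Hypothetical async endpoint; the real one queries the database
    return {"id": execution_id, "status": "COMPLETED"}


get_graph_execution_sync = endpoint_to_sync(get_graph_execution)
print(get_graph_execution_sync("exec-123"))  # {'id': 'exec-123', 'status': 'COMPLETED'}
```

The wrapper preserves the method's signature and return type, which is what the client contract above requires.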

### 4.2 Asynchronous Client

```python
class DatabaseManagerAsyncClient(AppServiceClient):
    """
    REQUIRED: Async client that:
    - Directly references async methods
    - No conversion needed
    - Shares connection pool
    """

    @classmethod
    def get_service_type(cls):
        return DatabaseManager

    # Direct method reference
    get_graph_execution = DatabaseManager.get_graph_execution
```

## 5. Data Models

### 5.1 Core Enums

```python
class AgentExecutionStatus(str, Enum):
    PENDING = "PENDING"
    QUEUED = "QUEUED"
    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"
    CANCELED = "CANCELED"

class NotificationType(str, Enum):
    SYSTEM = "SYSTEM"
    REVIEW = "REVIEW"
    EXECUTION = "EXECUTION"
    MARKETING = "MARKETING"
```

### 5.2 Key Data Models

All models must exactly match the Prisma schema definitions. Key models include:

- `GraphExecution`: Execution metadata with stats
- `GraphExecutionWithNodes`: Includes all node executions
- `NodeExecutionResult`: Node execution with I/O data
- `GraphModel`: Complete graph definition
- `UserIntegrations`: OAuth credentials
- `UsageTransactionMetadata`: Credit usage context
- `NotificationEvent`: Individual notification data

## 6. Security Requirements

### 6.1 User Isolation
- **CRITICAL**: All user-scoped operations MUST filter by user_id
- Never expose data across user boundaries
- Use database-level row security where possible

### 6.2 Authentication
- Service assumes authentication handled by API gateway
- user_id parameter is trusted after authentication
- No additional auth checks within service

### 6.3 Data Protection
- Encrypt sensitive integration credentials
- Use HMAC for unsubscribe tokens
- Never log sensitive data

## 7. Performance Requirements

### 7.1 Connection Management
- Maintain persistent database connection
- Use connection pooling (default: 10 connections)
- Implement exponential backoff for retries
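
A minimal retry helper matching the requirement above might look like this (hypothetical sketch, not the platform's actual implementation; it assumes transient failures surface as `ConnectionError`):

```python
import random
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def retry_with_backoff(
    fn: Callable[[], T],
    retries: int = 5,
    base_delay: float = 0.5,
    max_delay: float = 30.0,
) -> T:
    """Retry fn with exponential backoff and full jitter (sketch)."""
    for attempt in range(retries):
        try:
            return fn()
        except ConnectionError:
            if attempt == retries - 1:
                raise
            # Sleep between 0 and min(max_delay, base_delay * 2^attempt)
            time.sleep(random.uniform(0, min(max_delay, base_delay * 2**attempt)))
    raise RuntimeError("unreachable")
```

Full jitter keeps concurrent clients from retrying in lockstep against a recovering database.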

### 7.2 Query Optimization
- Use indexes for all WHERE clauses
- Batch operations where possible
- Limit default result sets (50 items)

### 7.3 Event Publishing
- Publish events asynchronously
- Don't block on event delivery
- Use fire-and-forget pattern
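
The fire-and-forget pattern can be sketched as scheduling the publish as a task and only logging failures, so the caller never blocks on delivery (the publish function below is a stand-in for the actual broker call):

```python
import asyncio
import logging

logger = logging.getLogger(__name__)
DELIVERED: list[dict] = []  # stand-in for the message broker


async def publish_event(event: dict) -> None:
    """Stand-in for the actual broker publish call."""
    await asyncio.sleep(0)
    DELIVERED.append(event)


def publish_fire_and_forget(event: dict) -> None:
    # Schedule delivery without awaiting it; surface failures via logging only
    task = asyncio.get_running_loop().create_task(publish_event(event))
    task.add_done_callback(
        lambda t: logger.warning("event publish failed: %s", t.exception())
        if t.exception()
        else None
    )
```

The caller returns immediately after scheduling; a failed publish is logged rather than propagated, matching "don't block on event delivery".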

## 8. Error Handling

### 8.1 Standard Exceptions
```python
class InsufficientCredits(Exception):
    """Raised when user lacks credits"""

class NotFoundError(Exception):
    """Raised when entity not found"""

class AuthorizationError(Exception):
    """Raised when user lacks access"""
```

### 8.2 Error Response Format
```json
{
    "error": "error_type",
    "message": "Human readable message",
    "details": {} // Optional additional context
}
```

## 9. Testing Requirements

### 9.1 Unit Tests
- Test each method in isolation
- Mock database calls
- Verify user_id filtering

### 9.2 Integration Tests
- Test with real database
- Verify transaction boundaries
- Test concurrent operations

### 9.3 Service Tests
- Test HTTP endpoint generation
- Verify serialization/deserialization
- Test error handling

## 10. Implementation Checklist

### Phase 1: Core Service Setup
- [ ] Create DatabaseManager class inheriting from AppService
- [ ] Implement run_service() with database connection
- [ ] Implement cleanup() with proper disconnect
- [ ] Configure port from settings
- [ ] Set up method exposure helper

### Phase 2: Execution APIs (17 methods)
- [ ] get_graph_execution
- [ ] get_graph_executions
- [ ] get_graph_execution_meta
- [ ] create_graph_execution
- [ ] update_graph_execution_start_time
- [ ] update_graph_execution_stats
- [ ] get_node_execution
- [ ] get_node_executions
- [ ] get_latest_node_execution
- [ ] update_node_execution_status
- [ ] update_node_execution_status_batch
- [ ] update_node_execution_stats
- [ ] upsert_execution_input
- [ ] upsert_execution_output
- [ ] get_execution_kv_data
- [ ] set_execution_kv_data
- [ ] get_block_error_stats

### Phase 3: Graph APIs (4 methods)
- [ ] get_node
- [ ] get_graph
- [ ] get_connected_output_nodes
- [ ] get_graph_metadata

### Phase 4: Credit APIs (2 methods)
- [ ] get_credits
- [ ] spend_credits

### Phase 5: User APIs (4 methods)
- [ ] get_user_metadata
- [ ] update_user_metadata
- [ ] get_user_integrations
- [ ] update_user_integrations

### Phase 6: Communication APIs (4 methods)
- [ ] get_active_user_ids_in_timerange
- [ ] get_user_email_by_id
- [ ] get_user_email_verification
- [ ] get_user_notification_preference

### Phase 7: Notification APIs (5 methods)
- [ ] create_or_add_to_user_notification_batch
- [ ] empty_user_notification_batch
- [ ] get_all_batches_by_type
- [ ] get_user_notification_batch
- [ ] get_user_notification_oldest_message_in_batch

### Phase 8: Client Implementation
- [ ] Create DatabaseManagerClient with sync methods
- [ ] Create DatabaseManagerAsyncClient with async methods
- [ ] Test client method generation
- [ ] Verify type preservation

### Phase 9: Integration Testing
- [ ] Test all methods with real database
- [ ] Verify user isolation
- [ ] Test error scenarios
- [ ] Performance testing
- [ ] Event publishing verification

### Phase 10: Deployment Validation
- [ ] Deploy to test environment
- [ ] Run integration test suite
- [ ] Verify backward compatibility
- [ ] Performance benchmarking
- [ ] Production deployment

## 11. Success Criteria

The implementation is successful when:

1. **All 40+ methods** produce identical outputs to the original
2. **Performance** is within 10% of original implementation
3. **All tests** pass without modification
4. **No breaking changes** to any client code
5. **Security boundaries** are maintained
6. **Event publishing** works identically
7. **Error handling** matches original behavior

## 12. Critical Implementation Notes

1. **DO NOT** modify any function signatures
2. **DO NOT** change any return types
3. **DO NOT** add new required parameters
4. **DO NOT** remove any functionality
5. **ALWAYS** maintain user_id isolation
6. **ALWAYS** publish events for state changes
7. **ALWAYS** use transactions for multi-step operations
8. **ALWAYS** handle errors exactly as original

This specification, when implemented correctly, will produce a drop-in replacement for the DatabaseManager that maintains 100% compatibility with the existing system.
765
autogpt_platform/autogpt-rs/NOTIFICATION_SERVICE.md
Normal file
@@ -0,0 +1,765 @@

# Notification Service Technical Specification

## Overview

The AutoGPT Platform Notification Service is a RabbitMQ-based asynchronous notification system that handles various types of user notifications including real-time alerts, batched notifications, and scheduled summaries. The service supports email delivery via Postmark and system alerts via Discord.

## Architecture Overview

### Core Components

1. **NotificationManager Service** (`notifications.py`)
   - AppService implementation with RabbitMQ integration
   - Processes notification queues asynchronously
   - Manages batching strategies and delivery timing
   - Handles email templating and sending

2. **RabbitMQ Message Broker**
   - Multiple queues for different notification strategies
   - Dead letter exchange for failed messages
   - Topic-based routing for message distribution

3. **Email Sender** (`email.py`)
   - Postmark integration for email delivery
   - Jinja2 template rendering
   - HTML email composition with unsubscribe headers

4. **Database Storage**
   - Notification batching tables
   - User preference storage
   - Email verification tracking

## Service Exposure Mechanism

### AppService Framework

The NotificationManager extends `AppService`, which automatically exposes methods decorated with `@expose` as HTTP endpoints:

```python
class NotificationManager(AppService):
    @expose
    def queue_weekly_summary(self):
        # Implementation
        ...

    @expose
    def process_existing_batches(self, notification_types: list[NotificationType]):
        # Implementation
        ...

    @expose
    async def discord_system_alert(self, content: str):
        # Implementation
        ...
```

### Automatic HTTP Endpoint Creation

When the service starts, the AppService base class:
1. Scans for methods with `@expose` decorator
2. Creates FastAPI routes for each exposed method:
   - Route path: `/{method_name}`
   - HTTP method: POST
   - Endpoint handler: Generated via `_create_fastapi_endpoint()`

### Service Client Access

#### NotificationManagerClient
```python
class NotificationManagerClient(AppServiceClient):
    @classmethod
    def get_service_type(cls):
        return NotificationManager

    # Direct method references (sync)
    process_existing_batches = NotificationManager.process_existing_batches
    queue_weekly_summary = NotificationManager.queue_weekly_summary

    # Async-to-sync conversion
    discord_system_alert = endpoint_to_sync(NotificationManager.discord_system_alert)
```

#### Client Usage Pattern
```python
# Get client instance
client = get_service_client(NotificationManagerClient)

# Call exposed methods via HTTP
client.process_existing_batches([NotificationType.AGENT_RUN])
client.queue_weekly_summary()
client.discord_system_alert("System alert message")
```

### HTTP Communication Details

1. **Service URL**: `http://{host}:{notification_service_port}`
   - Default port: 8007
   - Host: Configurable via settings

2. **Request Format**:
   - Method: POST
   - Path: `/{method_name}`
   - Body: JSON with method parameters

3. **Client Implementation**:
   - Uses `httpx` for HTTP requests
   - Automatic retry on connection failures
   - Configurable timeout (default from api_call_timeout)

### Direct Function Calls

The service also exposes two functions that can be called directly without going through the service client:

```python
# Sync version - used by ExecutionManager
def queue_notification(event: NotificationEventModel) -> NotificationResult

# Async version - used by credit system
async def queue_notification_async(event: NotificationEventModel) -> NotificationResult
```

These functions:
- Connect directly to RabbitMQ
- Publish messages to appropriate queues
- Return success/failure status
- Are NOT exposed via HTTP

## Message Queuing Architecture

### RabbitMQ Configuration

#### Exchanges
```python
NOTIFICATION_EXCHANGE = Exchange(name="notifications", type=ExchangeType.TOPIC)
DEAD_LETTER_EXCHANGE = Exchange(name="dead_letter", type=ExchangeType.TOPIC)
```

#### Queues
1. **immediate_notifications**
   - Routing Key: `notification.immediate.#`
   - Dead Letter: `failed.immediate`
   - For: Critical alerts, errors

2. **admin_notifications**
   - Routing Key: `notification.admin.#`
   - Dead Letter: `failed.admin`
   - For: Refund requests, system alerts

3. **summary_notifications**
   - Routing Key: `notification.summary.#`
   - Dead Letter: `failed.summary`
   - For: Daily/weekly summaries

4. **batch_notifications**
   - Routing Key: `notification.batch.#`
   - Dead Letter: `failed.batch`
   - For: Agent runs, batched events

5. **failed_notifications**
   - Routing Key: `failed.#`
   - For: All failed messages

### Queue Strategies (QueueType enum)

1. **IMMEDIATE**: Send right away (errors, critical notifications)
2. **BATCH**: Batch for configured delay (agent runs)
3. **SUMMARY**: Scheduled digest (daily/weekly summaries)
4. **BACKOFF**: Exponential backoff strategy (defined but not fully implemented)
5. **ADMIN**: Admin-only notifications

## Notification Types

### Enum Values (NotificationType)
```python
AGENT_RUN                # Batch strategy, 1 day delay
ZERO_BALANCE             # Backoff strategy, 60 min delay
LOW_BALANCE              # Immediate strategy
BLOCK_EXECUTION_FAILED   # Backoff strategy, 60 min delay
CONTINUOUS_AGENT_ERROR   # Backoff strategy, 60 min delay
DAILY_SUMMARY            # Summary strategy
WEEKLY_SUMMARY           # Summary strategy
MONTHLY_SUMMARY          # Summary strategy
REFUND_REQUEST           # Admin strategy
REFUND_PROCESSED         # Admin strategy
```

## Integration Points

### 1. Scheduler Integration
The scheduler service (`backend.executor.scheduler`) imports monitoring functions that call the NotificationManagerClient:

```python
from backend.monitoring import (
    process_existing_batches,
    process_weekly_summary,
)

# These are scheduled as cron jobs
```

### 2. Execution Manager Integration
The ExecutionManager directly calls `queue_notification()` for:
- Agent run completions
- Low balance alerts

```python
from backend.notifications.notifications import queue_notification

# Called after graph execution completes
queue_notification(NotificationEventModel(
    user_id=graph_exec.user_id,
    type=NotificationType.AGENT_RUN,
    data=AgentRunData(...)
))
```

### 3. Credit System Integration
The credit system uses `queue_notification_async()` for:
- Refund requests
- Refund processed notifications

```python
from backend.notifications.notifications import queue_notification_async

await queue_notification_async(NotificationEventModel(
    user_id=user_id,
    type=NotificationType.REFUND_REQUEST,
    data=RefundRequestData(...)
))
```

### 4. Monitoring Module Wrappers
The monitoring module provides wrapper functions that are used by the scheduler:

```python
# backend/monitoring/notification_monitor.py
def process_existing_batches(**kwargs):
    args = NotificationJobArgs(**kwargs)
    get_notification_manager_client().process_existing_batches(
        args.notification_types
    )

def process_weekly_summary(**kwargs):
    get_notification_manager_client().queue_weekly_summary()
```

## Data Models

### Base Event Model
```typescript
interface BaseEventModel {
  type: NotificationType;
  user_id: string;
  created_at: string; // ISO datetime with timezone
}
```

### Notification Event Model
```typescript
interface NotificationEventModel<T> extends BaseEventModel {
  data: T;
}
```
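
On the Python side these shapes correspond to generic Pydantic models; a dependency-free dataclass sketch (field names follow the interfaces above, but this is illustrative, not the actual model code):

```python
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Generic, TypeVar

T = TypeVar("T")


@dataclass
class BaseEventModel:
    type: str  # NotificationType value
    user_id: str
    created_at: str  # ISO datetime with timezone


@dataclass
class NotificationEventModel(BaseEventModel, Generic[T]):
    data: T


event = NotificationEventModel(
    type="AGENT_RUN",
    user_id="user-1",
    created_at=datetime.now(timezone.utc).isoformat(),
    data={"agent_name": "demo"},
)
```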

### Notification Data Types

#### AgentRunData
```typescript
interface AgentRunData {
  agent_name: string;
  credits_used: number;
  execution_time: number;
  node_count: number;
  graph_id: string;
  outputs: Array<Record<string, any>>;
}
```

#### ZeroBalanceData
```typescript
interface ZeroBalanceData {
  last_transaction: number;
  last_transaction_time: string; // ISO datetime with timezone
  top_up_link: string;
}
```

#### LowBalanceData
```typescript
interface LowBalanceData {
  agent_name: string;
  current_balance: number; // credits (100 = $1)
  billing_page_link: string;
  shortfall: number;
}
```

#### BlockExecutionFailedData
```typescript
interface BlockExecutionFailedData {
  block_name: string;
  block_id: string;
  error_message: string;
  graph_id: string;
  node_id: string;
  execution_id: string;
}
```

#### ContinuousAgentErrorData
```typescript
interface ContinuousAgentErrorData {
  agent_name: string;
  error_message: string;
  graph_id: string;
  execution_id: string;
  start_time: string; // ISO datetime with timezone
  error_time: string; // ISO datetime with timezone
  attempts: number;
}
```

#### Summary Data Types
```typescript
interface BaseSummaryData {
  total_credits_used: number;
  total_executions: number;
  most_used_agent: string;
  total_execution_time: number;
  successful_runs: number;
  failed_runs: number;
  average_execution_time: number;
  cost_breakdown: Record<string, number>;
}

interface DailySummaryData extends BaseSummaryData {
  date: string; // ISO datetime with timezone
}

interface WeeklySummaryData extends BaseSummaryData {
  start_date: string; // ISO datetime with timezone
  end_date: string; // ISO datetime with timezone
}
```

#### RefundRequestData
```typescript
interface RefundRequestData {
  user_id: string;
  user_name: string;
  user_email: string;
  transaction_id: string;
  refund_request_id: string;
  reason: string;
  amount: number;
  balance: number;
}
```

### Summary Parameters
```typescript
interface BaseSummaryParams {
  start_date: string; // ISO datetime with timezone
  end_date: string; // ISO datetime with timezone
}

interface DailySummaryParams extends BaseSummaryParams {
  date: string; // ISO datetime with timezone
}

interface WeeklySummaryParams extends BaseSummaryParams {
  start_date: string; // ISO datetime with timezone
  end_date: string; // ISO datetime with timezone
}
```

## Database Schema

### NotificationEvent Table
```prisma
model NotificationEvent {
  id                      String                 @id @default(uuid())
  createdAt               DateTime               @default(now())
  updatedAt               DateTime               @default(now()) @updatedAt
  UserNotificationBatch   UserNotificationBatch? @relation
  userNotificationBatchId String?
  type                    NotificationType
  data                    Json
  @@index([userNotificationBatchId])
}
```

### UserNotificationBatch Table
```prisma
model UserNotificationBatch {
  id            String              @id @default(uuid())
  createdAt     DateTime            @default(now())
  updatedAt     DateTime            @default(now()) @updatedAt
  userId        String
  User          User                @relation
  type          NotificationType
  Notifications NotificationEvent[]
  @@unique([userId, type])
}
```

## API Methods

### Exposed Service Methods (via HTTP)

#### queue_weekly_summary()
- **HTTP Endpoint**: `POST /queue_weekly_summary`
- **Purpose**: Triggers weekly summary generation for all active users
- **Process**:
  1. Runs in background executor
  2. Queries users active in last 7 days
  3. Queues summary notification for each user
- **Used by**: Scheduler service (via cron)

#### process_existing_batches(notification_types: list[NotificationType])
- **HTTP Endpoint**: `POST /process_existing_batches`
- **Purpose**: Processes aged-out batches for specified notification types
- **Process**:
  1. Runs in background executor
  2. Retrieves all batches for given types
  3. Checks if oldest message exceeds max delay
  4. Sends batched email if aged out
  5. Clears processed batches
- **Used by**: Scheduler service (via cron)

#### discord_system_alert(content: str)
- **HTTP Endpoint**: `POST /discord_system_alert`
- **Purpose**: Sends system alerts to Discord channel
- **Async**: Yes (converted to sync by client)
- **Used by**: Monitoring services

### Direct Queue Functions (not via HTTP)

#### queue_notification(event: NotificationEventModel) -> NotificationResult
- **Purpose**: Queue a notification (sync version)
- **Used by**: ExecutionManager (same process)
- **Direct RabbitMQ**: Yes

#### queue_notification_async(event: NotificationEventModel) -> NotificationResult
- **Purpose**: Queue a notification (async version)
- **Used by**: Credit system (async context)
- **Direct RabbitMQ**: Yes

## Message Processing Flow

### 1. Message Routing
```python
def get_routing_key(event_type: NotificationType) -> str:
    strategy = NotificationTypeOverride(event_type).strategy
    if strategy == QueueType.IMMEDIATE:
        return f"notification.immediate.{event_type.value}"
    elif strategy == QueueType.BATCH:
        return f"notification.batch.{event_type.value}"
    # ... etc
```
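
The elided branches presumably mirror the queue routing keys listed earlier. A hedged, self-contained completion (the strategy is passed in directly here since `NotificationTypeOverride` isn't shown; BACKOFF falls through to the immediate queue, per the Implementation Status Notes below):

```python
from enum import Enum


class QueueType(str, Enum):
    IMMEDIATE = "immediate"
    BATCH = "batch"
    SUMMARY = "summary"
    ADMIN = "admin"
    BACKOFF = "backoff"


def get_routing_key(strategy: QueueType, event_type_value: str) -> str:
    # BACKOFF has no dedicated queue yet, so it routes as immediate
    if strategy in (QueueType.IMMEDIATE, QueueType.BACKOFF):
        return f"notification.immediate.{event_type_value}"
    return f"notification.{strategy.value}.{event_type_value}"


print(get_routing_key(QueueType.BATCH, "AGENT_RUN"))  # notification.batch.AGENT_RUN
```

Each key matches one queue's binding pattern, e.g. `notification.batch.AGENT_RUN` matches `notification.batch.#`.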

### 2. Queue Processing Methods

#### _process_immediate(message: str) -> bool
1. Parse message to NotificationEventModel
2. Retrieve user email
3. Check user preferences and email verification
4. Send email immediately via EmailSender
5. Return True if successful

#### _process_batch(message: str) -> bool
1. Parse message to NotificationEventModel
2. Add to user's notification batch
3. Check if batch is old enough (based on delay)
4. If aged out:
   - Retrieve all batch messages
   - Send combined email
   - Clear batch
5. Return True if processed or batched

#### _process_summary(message: str) -> bool
1. Parse message to SummaryParamsEventModel
2. Gather summary data (credits, executions, etc.)
   - **Note**: Currently returns hardcoded placeholder data
3. Format and send summary email
4. Return True if successful

#### _process_admin_message(message: str) -> bool
1. Parse message
2. Send to configured admin email
3. No user preference checks
4. Return True if successful

## Email Delivery

### EmailSender Class

#### Template Loading
- Base template: `templates/base.html.jinja2`
- Notification templates: `templates/{notification_type}.html.jinja2`
- Subject templates from NotificationTypeOverride
- **Note**: Templates use `.html.jinja2` extension, not just `.html`

#### Email Composition
```python
def send_templated(
    notification: NotificationType,
    user_email: str,
    data: NotificationEventModel | list[NotificationEventModel],
    user_unsub_link: str | None = None
)
```

#### Postmark Integration
- API Token: `settings.secrets.postmark_server_api_token`
- Sender Email: `settings.config.postmark_sender_email`
- Headers:
  - `List-Unsubscribe-Post: List-Unsubscribe=One-Click`
  - `List-Unsubscribe: <{unsubscribe_link}>`

## User Preferences and Permissions

### Email Verification Check
```python
validated_email = get_db().get_user_email_verification(user_id)
```

### Notification Preferences
```python
preferences = get_db().get_user_notification_preference(user_id).preferences
# Returns dict[NotificationType, bool]
```

### Preference Fields in User Model
- `notifyOnAgentRun`
- `notifyOnZeroBalance`
- `notifyOnLowBalance`
- `notifyOnBlockExecutionFailed`
- `notifyOnContinuousAgentError`
- `notifyOnDailySummary`
- `notifyOnWeeklySummary`
- `notifyOnMonthlySummary`

### Unsubscribe Link Generation
```python
def generate_unsubscribe_link(user_id: str) -> str:
    # HMAC-SHA256 signed token
    # Format: base64(user_id:signature_hex)
    # URL: {platform_base_url}/api/email/unsubscribe?token={token}
```

## Batching Logic

### Batch Delays (get_batch_delay)

**Note**: The delay configuration exists for multiple notification types, but only notifications with `QueueType.BATCH` strategy actually use batching. Others use different strategies:

- `AGENT_RUN`: 1 day (Strategy: BATCH - actually uses batching)
- `ZERO_BALANCE`: 60 minutes configured (Strategy: BACKOFF - not batched)
- `LOW_BALANCE`: 60 minutes configured (Strategy: IMMEDIATE - sent immediately)
- `BLOCK_EXECUTION_FAILED`: 60 minutes configured (Strategy: BACKOFF - not batched)
- `CONTINUOUS_AGENT_ERROR`: 60 minutes configured (Strategy: BACKOFF - not batched)

### Batch Processing
1. Messages added to UserNotificationBatch
2. Oldest message timestamp tracked
3. When `oldest_timestamp + delay < now()`:
   - Batch is processed
   - All messages sent in single email
   - Batch cleared
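
The age-out rule in step 3 reduces to a single comparison; a sketch with an assumed delay table (the fallback of 60 minutes mirrors the configured delays above):

```python
from datetime import datetime, timedelta, timezone

# Assumed delay table; only AGENT_RUN actually uses the BATCH strategy
BATCH_DELAYS = {"AGENT_RUN": timedelta(days=1)}


def batch_aged_out(notification_type: str, oldest_created_at: datetime) -> bool:
    # Process the batch once the oldest message has waited past its delay
    delay = BATCH_DELAYS.get(notification_type, timedelta(minutes=60))
    return oldest_created_at + delay < datetime.now(timezone.utc)
```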
|
||||
|
||||
## Service Lifecycle
|
||||
|
||||
### Startup
|
||||
1. Initialize FastAPI app with exposed endpoints
|
||||
2. Start HTTP server on port 8007
|
||||
3. Initialize RabbitMQ connection
|
||||
4. Create/verify exchanges and queues
|
||||
5. Set up queue consumers
|
||||
6. Start processing loop
|
||||
|
||||
### Main Loop
|
||||
```python
|
||||
while self.running:
|
||||
await self._run_queue(immediate_queue, self._process_immediate, ...)
|
||||
await self._run_queue(admin_queue, self._process_admin_message, ...)
|
||||
await self._run_queue(batch_queue, self._process_batch, ...)
|
||||
await self._run_queue(summary_queue, self._process_summary, ...)
|
||||
await asyncio.sleep(0.1)
|
||||
```
|
||||
|
||||
### Shutdown
|
||||
1. Set `running = False`
|
||||
2. Disconnect RabbitMQ
|
||||
3. Cleanup resources
|
||||
|
||||
## Configuration
|
||||
|
||||
### Environment Variables
|
||||
```python
|
||||
# Service Configuration
|
||||
notification_service_port: int = 8007
|
||||
|
||||
# Email Configuration
|
||||
postmark_sender_email: str = "invalid@invalid.com"
|
||||
refund_notification_email: str = "refund@agpt.co"
|
||||
|
||||
# Security
|
||||
unsubscribe_secret_key: str = ""
|
||||
|
||||
# Secrets
|
||||
postmark_server_api_token: str = ""
|
||||
postmark_webhook_token: str = ""
|
||||
discord_bot_token: str = ""
|
||||
|
||||
# Platform URLs
|
||||
platform_base_url: str
|
||||
frontend_base_url: str
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
### Message Processing Errors
|
||||
- Failed messages sent to dead letter queue
|
||||
- Validation errors logged but don't crash service
|
||||
- Connection errors trigger retry with `@continuous_retry()`
|
||||
|
||||
### RabbitMQ ACK/NACK Protocol
|
||||
- Success: `message.ack()`
|
||||
- Failure: `message.reject(requeue=False)`
|
||||
- Timeout/Queue empty: Continue loop
|
||||
|
||||
### HTTP Endpoint Errors
|
||||
- Wrapped in RemoteCallError for client
|
||||
- Automatic retry available via client configuration
|
||||
- Connection failures tracked and logged
|
||||
|
||||
## System Integrations
|
||||
|
||||
### DatabaseManagerClient
|
||||
- User email retrieval
|
||||
- Email verification status
|
||||
- Notification preferences
|
||||
- Batch management
|
||||
- Active user queries
|
||||
|
||||
### Discord Integration
|
||||
- Uses SendDiscordMessageBlock
|
||||
- Configured via discord_bot_token
|
||||
- For system alerts only
|
||||
|
||||
## Implementation Checklist
|
||||
|
||||
1. **Core Service**
|
||||
- [ ] AppService implementation with @expose decorators
|
||||
- [ ] FastAPI endpoint generation
|
||||
- [ ] RabbitMQ connection management
|
||||
- [ ] Queue consumer setup
|
||||
- [ ] Message routing logic
|
||||
|
||||
2. **Service Client**
|
||||
- [ ] NotificationManagerClient implementation
|
||||
- [ ] HTTP client configuration
|
||||
- [ ] Method mapping to service endpoints
|
||||
- [ ] Async-to-sync conversions
|
||||
|
||||
3. **Message Processing**
|
||||
- [ ] Parse and validate all notification types
|
||||
- [ ] Implement all queue strategies
|
||||
- [ ] Batch management with delays
|
||||
- [ ] Summary data gathering
|
||||
|
||||
4. **Email Delivery**
|
||||
- [ ] Postmark integration
|
||||
- [ ] Template loading and rendering
|
||||
- [ ] Unsubscribe header support
|
||||
- [ ] HTML email composition
|
||||
|
||||
5. **User Management**
|
||||
- [ ] Preference checking
|
||||
- [ ] Email verification
|
||||
- [ ] Unsubscribe link generation
|
||||
- [ ] Daily limit tracking
|
||||
|
||||
6. **Batching System**
|
||||
- [ ] Database batch operations
|
||||
- [ ] Age-out checking
|
||||
- [ ] Batch clearing after send
|
||||
- [ ] Oldest message tracking
|
||||
|
||||
7. **Error Handling**
|
||||
- [ ] Dead letter queue routing
|
||||
- [ ] Message rejection on failure
|
||||
- [ ] Continuous retry wrapper
|
||||
- [ ] Validation error logging
|
||||
|
||||
8. **Scheduled Operations**
|
||||
- [ ] Weekly summary generation
|
||||
- [ ] Batch processing triggers
|
||||
- [ ] Background executor usage
|
||||
|
||||
## Security Considerations
|
||||
|
||||
1. **Service-to-Service Communication**:
|
||||
- HTTP endpoints only accessible internally
|
||||
- No authentication on service endpoints (internal network only)
|
||||
- Service discovery via host/port configuration
|
||||
|
||||
2. **User Security**:
|
||||
- Email verification required for all user notifications
|
||||
- Unsubscribe tokens HMAC-signed
|
||||
- User preferences enforced
|
||||
|
||||
3. **Admin Notifications**:
|
||||
- Separate queue, no user preference checks
|
||||
- Fixed admin email configuration
|
||||
|
||||
## Testing Considerations

1. **Unit Tests**
   - Message parsing and validation
   - Routing key generation
   - Batch delay calculations
   - Template rendering

2. **Integration Tests**
   - HTTP endpoint accessibility
   - Service client method calls
   - RabbitMQ message flow
   - Database batch operations
   - Email sending (mock Postmark)

3. **Load Tests**
   - High volume message processing
   - Concurrent HTTP requests
   - Batch accumulation limits
   - Memory usage under load
## Implementation Status Notes

1. **Backoff Strategy**: While `QueueType.BACKOFF` is defined and used by several notification types (ZERO_BALANCE, BLOCK_EXECUTION_FAILED, CONTINUOUS_AGENT_ERROR), the actual exponential backoff processing logic is not implemented. These messages are routed to the immediate queue.

2. **Summary Data**: The `_gather_summary_data()` method currently returns hardcoded placeholder values rather than querying actual execution data from the database.

3. **Batch Processing**: Only `AGENT_RUN` notifications actually use batch processing. Other notification types with configured delays use different strategies (IMMEDIATE or BACKOFF).
## Future Enhancements

1. **Additional Channels**
   - SMS notifications (not implemented)
   - Webhook notifications (not implemented)
   - In-app notifications

2. **Advanced Batching**
   - Dynamic batch sizes
   - Priority-based processing
   - Custom delay configurations

3. **Analytics**
   - Delivery tracking
   - Open/click rates
   - Notification effectiveness metrics

4. **Service Improvements**
   - Authentication for HTTP endpoints
   - Rate limiting per user
   - Circuit breaker patterns
   - Implement actual backoff processing for the BACKOFF strategy
   - Implement real summary data gathering
474 autogpt_platform/autogpt-rs/SCHEDULER.md Normal file
@@ -0,0 +1,474 @@
# AutoGPT Platform Scheduler Technical Specification

## Executive Summary

This document provides a comprehensive technical specification for the AutoGPT Platform Scheduler service. The scheduler is responsible for managing scheduled graph executions, system monitoring tasks, and periodic maintenance operations. This specification is designed to enable a complete reimplementation that maintains 100% compatibility with the existing system.

## Table of Contents

1. [System Architecture](#system-architecture)
2. [Service Implementation](#service-implementation)
3. [Data Models](#data-models)
4. [API Endpoints](#api-endpoints)
5. [Database Schema](#database-schema)
6. [External Dependencies](#external-dependencies)
7. [Authentication & Authorization](#authentication--authorization)
8. [Process Management](#process-management)
9. [Error Handling](#error-handling)
10. [Configuration](#configuration)
11. [Testing Strategy](#testing-strategy)

## System Architecture

### Overview

The scheduler operates as an independent microservice within the AutoGPT platform, implementing the `AppService` base class pattern. It runs on a dedicated port (default: 8003) and exposes HTTP/JSON-RPC endpoints for communication with other services.

### Core Components

1. **Scheduler Service** (`backend/executor/scheduler.py:156`)
   - Extends `AppService` base class
   - Manages APScheduler instance with multiple jobstores
   - Handles lifecycle management and graceful shutdown

2. **Scheduler Client** (`backend/executor/scheduler.py:354`)
   - Extends `AppServiceClient` base class
   - Provides async/sync method wrappers for RPC calls
   - Implements automatic retry and connection pooling

3. **Entry Points**
   - Main executable: `backend/scheduler.py`
   - Service launcher: `backend/app.py`
## Service Implementation

### Base Service Pattern

```python
class Scheduler(AppService):
    scheduler: BlockingScheduler

    def __init__(self, register_system_tasks: bool = True):
        self.register_system_tasks = register_system_tasks

    @classmethod
    def get_port(cls) -> int:
        return config.execution_scheduler_port  # Default: 8003

    @classmethod
    def db_pool_size(cls) -> int:
        return config.scheduler_db_pool_size  # Default: 3

    def run_service(self):
        # Initialize scheduler with jobstores
        # Register system tasks if enabled
        # Start scheduler blocking loop
        ...

    def cleanup(self):
        # Graceful shutdown of scheduler
        # wait=False for immediate termination
        ...
```

### Jobstore Configuration

The scheduler uses three distinct jobstores:

1. **EXECUTION** (`Jobstores.EXECUTION.value`)
   - Type: SQLAlchemyJobStore
   - Table: `apscheduler_jobs`
   - Purpose: Graph execution schedules
   - Persistence: Required

2. **BATCHED_NOTIFICATIONS** (`Jobstores.BATCHED_NOTIFICATIONS.value`)
   - Type: SQLAlchemyJobStore
   - Table: `apscheduler_jobs_batched_notifications`
   - Purpose: Batched notification processing
   - Persistence: Required

3. **WEEKLY_NOTIFICATIONS** (`Jobstores.WEEKLY_NOTIFICATIONS.value`)
   - Type: MemoryJobStore
   - Purpose: Weekly summary notifications
   - Persistence: Not required

### System Tasks

When `register_system_tasks=True`, the following monitoring tasks are registered:

1. **Weekly Summary Processing**
   - Job ID: `process_weekly_summary`
   - Schedule: `0 * * * *` (hourly)
   - Function: `monitoring.process_weekly_summary`
   - Jobstore: WEEKLY_NOTIFICATIONS

2. **Late Execution Monitoring**
   - Job ID: `report_late_executions`
   - Schedule: Interval (`config.execution_late_notification_threshold_secs`)
   - Function: `monitoring.report_late_executions`
   - Jobstore: EXECUTION

3. **Block Error Rate Monitoring**
   - Job ID: `report_block_error_rates`
   - Schedule: Interval (`config.block_error_rate_check_interval_secs`)
   - Function: `monitoring.report_block_error_rates`
   - Jobstore: EXECUTION

4. **Cloud Storage Cleanup**
   - Job ID: `cleanup_expired_files`
   - Schedule: Interval (`config.cloud_storage_cleanup_interval_hours * 3600`)
   - Function: `cleanup_expired_files`
   - Jobstore: EXECUTION
## Data Models

### GraphExecutionJobArgs

```python
class GraphExecutionJobArgs(BaseModel):
    user_id: str
    graph_id: str
    graph_version: int
    cron: str
    input_data: BlockInput
    input_credentials: dict[str, CredentialsMetaInput] = Field(default_factory=dict)
```

### GraphExecutionJobInfo

```python
class GraphExecutionJobInfo(GraphExecutionJobArgs):
    id: str
    name: str
    next_run_time: str

    @staticmethod
    def from_db(
        job_args: GraphExecutionJobArgs, job_obj: JobObj
    ) -> "GraphExecutionJobInfo":
        return GraphExecutionJobInfo(
            id=job_obj.id,
            name=job_obj.name,
            next_run_time=job_obj.next_run_time.isoformat(),
            **job_args.model_dump(),
        )
```

### NotificationJobArgs

```python
class NotificationJobArgs(BaseModel):
    notification_types: list[NotificationType]
    cron: str
```

### CredentialsMetaInput

```python
class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
    id: str
    title: Optional[str] = None
    provider: CP
    type: CT
```
## API Endpoints

All endpoints are exposed via the `@expose` decorator and follow the HTTP POST JSON-RPC pattern.

### 1. Add Graph Execution Schedule

**Endpoint**: `/add_graph_execution_schedule`

**Request Body**:
```json
{
  "user_id": "string",
  "graph_id": "string",
  "graph_version": "integer",
  "cron": "string (crontab format)",
  "input_data": {},
  "input_credentials": {},
  "name": "string (optional)"
}
```

**Response**: `GraphExecutionJobInfo`

**Behavior**:
- Creates APScheduler job with CronTrigger
- Uses job kwargs to store GraphExecutionJobArgs
- Sets `replace_existing=True` to allow updates
- Returns job info with generated ID and next run time

### 2. Delete Graph Execution Schedule

**Endpoint**: `/delete_graph_execution_schedule`

**Request Body**:
```json
{
  "schedule_id": "string",
  "user_id": "string"
}
```

**Response**: `GraphExecutionJobInfo`

**Behavior**:
- Validates schedule exists in EXECUTION jobstore
- Verifies user_id matches job's user_id
- Removes job from scheduler
- Returns deleted job info

**Errors**:
- `NotFoundError`: If job doesn't exist
- `NotAuthorizedError`: If user_id doesn't match

### 3. Get Graph Execution Schedules

**Endpoint**: `/get_graph_execution_schedules`

**Request Body**:
```json
{
  "graph_id": "string (optional)",
  "user_id": "string (optional)"
}
```

**Response**: `list[GraphExecutionJobInfo]`

**Behavior**:
- Retrieves all jobs from EXECUTION jobstore
- Filters by graph_id and/or user_id if provided
- Validates job kwargs as GraphExecutionJobArgs
- Skips invalid jobs (ValidationError)
- Only returns jobs with next_run_time set
### 4. System Task Endpoints

- `/execute_process_existing_batches` - Trigger batch processing
- `/execute_process_weekly_summary` - Trigger weekly summary
- `/execute_report_late_executions` - Trigger late execution report
- `/execute_report_block_error_rates` - Trigger error rate report
- `/execute_cleanup_expired_files` - Trigger file cleanup

### 5. Health Check

**Endpoints**: `/health_check`, `/health_check_async`
**Methods**: POST, GET
**Response**: "OK"
## Database Schema

### APScheduler Tables

The scheduler relies on APScheduler's SQLAlchemy jobstore schema:

1. **apscheduler_jobs**
   - id: VARCHAR (PRIMARY KEY)
   - next_run_time: FLOAT
   - job_state: BLOB/BYTEA (pickled job data)

2. **apscheduler_jobs_batched_notifications**
   - Same schema as above
   - Separate table for notification jobs

### Database Configuration

- URL extraction from `DIRECT_URL` environment variable
- Schema extraction from URL query parameter
- Connection pooling: `pool_size=db_pool_size()`, `max_overflow=0`
- Metadata schema binding for multi-schema support
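The URL/schema extraction described above can be sketched with the standard library; the `schema` query parameter name and the `public` default are assumptions based on Prisma-style connection strings:

```python
from urllib.parse import parse_qs, urlparse


def extract_db_config(direct_url: str) -> tuple[str, str]:
    """Split a DIRECT_URL into (database_url, schema)."""
    parsed = urlparse(direct_url)
    # Schema rides in the query string; fall back to "public" if absent.
    schema = parse_qs(parsed.query).get("schema", ["public"])[0]
    return direct_url.split("?", 1)[0], schema
```

For example, `extract_db_config("postgresql://u:p@db:5432/autogpt?schema=platform")` yields the bare connection URL plus the schema name `platform`.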
## External Dependencies

### Required Services

1. **PostgreSQL Database**
   - Connection via `DIRECT_URL` environment variable
   - Schema support via URL parameter
   - APScheduler job persistence

2. **ExecutionManager** (via execution_utils)
   - Function: `add_graph_execution`
   - Called by: `execute_graph` job function
   - Purpose: Create graph execution entries

3. **NotificationManager** (via monitoring module)
   - Functions: `process_existing_batches`, `queue_weekly_summary`
   - Purpose: Notification processing

4. **Cloud Storage** (via util.cloud_storage)
   - Function: `cleanup_expired_files_async`
   - Purpose: File expiration management

### Python Dependencies

```
apscheduler>=3.10.0
sqlalchemy
pydantic>=2.0
httpx
uvicorn
fastapi
python-dotenv
tenacity
```
## Authentication & Authorization

### Service-Level Authentication

- No authentication required between internal services
- Services communicate via trusted internal network
- Host/port configuration via environment variables

### User-Level Authorization

- Authorization check in `delete_graph_execution_schedule`:
  - Validates `user_id` matches job's `user_id`
  - Raises `NotAuthorizedError` on mismatch
- No authorization for read operations (security consideration)
## Process Management

### Startup Sequence

1. Load environment variables via `dotenv.load_dotenv()`
2. Extract database URL and schema
3. Initialize BlockingScheduler with configured jobstores
4. Register system tasks (if enabled)
5. Add job execution listener
6. Start scheduler (blocking)

### Shutdown Sequence

1. Receive SIGTERM/SIGINT signal
2. Call `cleanup()` method
3. Shut down scheduler with `wait=False`
4. Terminate process

### Multi-Process Architecture

- Runs as independent process via `AppProcess`
- Started by `run_processes()` in app.py
- Can run in foreground or background mode
- Automatic signal handling for graceful shutdown
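The shutdown sequence can be sketched as below; `SchedulerProcess` and the wiring are illustrative, not the platform's actual `AppProcess` API:

```python
import signal


class SchedulerProcess:
    """Wire SIGTERM/SIGINT to a non-blocking scheduler shutdown."""

    def __init__(self, scheduler):
        self.scheduler = scheduler
        signal.signal(signal.SIGTERM, self._handle_signal)
        signal.signal(signal.SIGINT, self._handle_signal)

    def _handle_signal(self, signum, frame):
        self.cleanup()

    def cleanup(self):
        # wait=False: terminate immediately instead of draining running jobs
        self.scheduler.shutdown(wait=False)
```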
## Error Handling

### Job Execution Errors

- Listener on `EVENT_JOB_ERROR` logs failures
- Errors in job functions are caught and logged
- Jobs continue to run on schedule despite failures

### RPC Communication Errors

- Automatic retry via `@conn_retry` decorator
- Configurable retry count and timeout
- Connection pooling with self-healing

### Database Connection Errors

- APScheduler handles reconnection automatically
- Pool exhaustion prevented by `max_overflow=0`
- Connection errors logged but don't crash the service
## Configuration

### Environment Variables

- `DIRECT_URL`: PostgreSQL connection string (required)
- `{SERVICE_NAME}_HOST`: Override service host
- Standard logging configuration

### Config Settings (via Config class)

```python
execution_scheduler_port: int = 8003
scheduler_db_pool_size: int = 3
execution_late_notification_threshold_secs: int
block_error_rate_check_interval_secs: int
cloud_storage_cleanup_interval_hours: int
pyro_host: str = "localhost"
pyro_client_comm_timeout: float = 15
pyro_client_comm_retry: int = 3
rpc_client_call_timeout: int = 300
```
## Testing Strategy

### Unit Tests

1. Mock APScheduler for job management tests
2. Mock database connections
3. Test each RPC endpoint independently
4. Verify job serialization/deserialization

### Integration Tests

1. Test with a real PostgreSQL instance
2. Verify job persistence across restarts
3. Test concurrent job execution
4. Validate cron expression parsing

### Critical Test Cases

1. **Job Persistence**: Jobs survive scheduler restart
2. **User Isolation**: Users can only delete their own jobs
3. **Concurrent Access**: Multiple clients can add/remove jobs
4. **Error Recovery**: Service recovers from database outages
5. **Resource Cleanup**: No memory/connection leaks
## Implementation Notes

### Key Design Decisions

1. **BlockingScheduler vs AsyncIOScheduler**: Uses BlockingScheduler for simplicity and compatibility with the multiprocessing architecture

2. **Job Storage**: All job arguments stored in kwargs, not in the job name/id

3. **Separate Jobstores**: Isolation between execution and notification jobs

4. **No Authentication**: Relies on network isolation for security

### Migration Considerations

1. APScheduler job format must be preserved exactly
2. Database schema cannot change without migration
3. RPC protocol must maintain compatibility
4. Environment variables must match existing deployment

### Performance Considerations

1. Database pool size limited to prevent exhaustion
2. No job result storage (fire-and-forget pattern)
3. Minimal logging in hot paths
4. Connection reuse via pooling
## Appendix: Critical Implementation Details

### Event Loop Management

```python
@thread_cached
def get_event_loop():
    return asyncio.new_event_loop()


def execute_graph(**kwargs):
    get_event_loop().run_until_complete(_execute_graph(**kwargs))
```

### Job Function Execution Context

- Jobs run in the scheduler's process space
- Each job gets a fresh event loop
- No shared state between job executions
- Exceptions logged but don't affect the scheduler

### Cron Expression Format

- Uses standard crontab format via `CronTrigger.from_crontab()`
- Supports: minute hour day month day_of_week
- Special strings: @yearly, @monthly, @weekly, @daily, @hourly
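As a rough illustration of the five-field format, a minimal sanity check can be written in plain Python. This is a sketch only: `CronTrigger.from_crontab()` does the real parsing, and this version ignores the special strings and month/day names:

```python
def is_valid_crontab(expr: str) -> bool:
    """Loose check of a 5-field crontab: minute hour day month day_of_week."""
    ranges = [(0, 59), (0, 23), (1, 31), (1, 12), (0, 6)]
    fields = expr.split()
    if len(fields) != len(ranges):
        return False
    for field, (lo, hi) in zip(fields, ranges):
        for part in field.split(","):
            part = part.split("/")[0]  # strip step suffix, e.g. */15 -> *
            if part == "*":
                continue
            bounds = part.split("-")   # single value or lo-hi range
            if not all(p.isdigit() and lo <= int(p) <= hi for p in bounds):
                return False
    return True
```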
This specification provides all necessary details to reimplement the scheduler service while maintaining 100% compatibility with the existing system. Any deviation from these specifications may result in system incompatibility.
85 autogpt_platform/autogpt-rs/websocket/.github/workflows/ci.yml vendored Normal file
@@ -0,0 +1,85 @@
name: CI

on:
  push:
    branches: [ main, master ]
  pull_request:
    branches: [ main, master ]

env:
  CARGO_TERM_COLOR: always
  RUSTFLAGS: "-D warnings"

jobs:
  test:
    name: Test
    runs-on: ubuntu-latest
    services:
      redis:
        image: redis:7
        options: >-
          --health-cmd "redis-cli ping"
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
        ports:
          - 6379:6379
    steps:
      - uses: actions/checkout@v3
      - uses: dtolnay/rust-toolchain@stable
      - uses: Swatinem/rust-cache@v2
      - name: Run tests
        run: cargo test
        env:
          REDIS_URL: redis://localhost:6379

  clippy:
    name: Clippy
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: dtolnay/rust-toolchain@stable
        with:
          components: clippy
      - uses: Swatinem/rust-cache@v2
      - name: Run clippy
        run: |
          cargo clippy -- \
            -D warnings \
            -D clippy::unwrap_used \
            -D clippy::panic \
            -D clippy::unimplemented \
            -D clippy::todo

  fmt:
    name: Format
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: dtolnay/rust-toolchain@stable
        with:
          components: rustfmt
      - name: Check formatting
        run: cargo fmt -- --check

  bench:
    name: Benchmarks
    runs-on: ubuntu-latest
    services:
      redis:
        image: redis:7
        options: >-
          --health-cmd "redis-cli ping"
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
        ports:
          - 6379:6379
    steps:
      - uses: actions/checkout@v3
      - uses: dtolnay/rust-toolchain@stable
      - uses: Swatinem/rust-cache@v2
      - name: Build benchmarks
        run: cargo bench --no-run
        env:
          REDIS_URL: redis://localhost:6379
3382 autogpt_platform/autogpt-rs/websocket/Cargo.lock generated Normal file
File diff suppressed because it is too large
60 autogpt_platform/autogpt-rs/websocket/Cargo.toml Normal file
@@ -0,0 +1,60 @@
[package]
name = "websocket"
authors = ["AutoGPT Team"]
description = "WebSocket server for AutoGPT Platform"
version = "0.1.0"
edition = "2021"

[lib]
name = "websocket"
path = "src/lib.rs"

[[bin]]
name = "websocket"
path = "src/main.rs"

[dependencies]
axum = { version = "0.7.5", features = ["ws"] }
jsonwebtoken = "9.3.0"
redis = { version = "0.25.4", features = ["aio", "tokio-comp"] }
serde = { version = "1.0.204", features = ["derive"] }
serde_json = "1.0.120"
tokio = { version = "1.38.1", features = ["rt-multi-thread", "macros", "net", "sync", "time", "io-util"] }
tower-http = { version = "0.5.2", features = ["cors"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
futures = "0.3"
dotenvy = "0.15"
clap = { version = "4.5.4", features = ["derive"] }
toml = "0.8"

[dev-dependencies]
# Load testing and profiling
tokio-console = "0.1"
criterion = { version = "0.5", features = ["async_tokio"] }
pprof = { version = "0.13", features = ["flamegraph", "criterion"] }
# Dependencies for benchmarks
tokio-tungstenite = "0.24"
futures-util = "0.3"
chrono = "0.4"

[[bench]]
name = "websocket_bench"
harness = false

[[example]]
name = "ws_client_example"
required-features = []

[profile.release]
opt-level = 3     # Maximum optimization
lto = true        # Enable link-time optimization
codegen-units = 1 # Reduce parallel code generation units to increase optimization
panic = "abort"   # Remove panic unwinding to reduce binary size
strip = true      # Strip symbols from binary

[profile.bench]
opt-level = 3     # Maximum optimization
lto = true        # Enable link-time optimization
codegen-units = 1 # Reduce parallel code generation units to increase optimization
debug = true      # Keep debug symbols for profiling
412 autogpt_platform/autogpt-rs/websocket/README.md Normal file
@@ -0,0 +1,412 @@
# WebSocket API Technical Specification

## Overview

This document provides a complete technical specification for the AutoGPT Platform WebSocket API (`ws_api.py`). The WebSocket API provides real-time updates for graph and node execution events, enabling clients to monitor workflow execution progress.

## Architecture Overview

### Core Components

1. **WebSocket Server** (`ws_api.py`)
   - FastAPI application with WebSocket endpoint
   - Handles client connections and message routing
   - Authenticates clients via JWT tokens
   - Manages subscriptions to execution events

2. **Connection Manager** (`conn_manager.py`)
   - Maintains active WebSocket connections
   - Manages channel subscriptions
   - Routes execution events to subscribed clients
   - Handles connection lifecycle

3. **Event Broadcasting System**
   - Redis Pub/Sub based event bus
   - Asynchronous event broadcaster
   - Execution event propagation from backend services
## API Endpoint

### WebSocket Endpoint
- **URL**: `/ws`
- **Protocol**: WebSocket (ws:// or wss://)
- **Query Parameters**:
  - `token` (required when auth enabled): JWT authentication token

## Authentication

### JWT Token Authentication
- **When Required**: When `settings.config.enable_auth` is `True`
- **Token Location**: Query parameter `?token=<JWT_TOKEN>`
- **Token Validation**:
  ```python
  payload = parse_jwt_token(token)
  user_id = payload.get("sub")
  ```
- **JWT Requirements**:
  - Algorithm: Configured via `settings.JWT_ALGORITHM`
  - Secret Key: Configured via `settings.JWT_SECRET_KEY`
  - Audience: Must be "authenticated"
  - Claims: Must contain `sub` (user ID)

### Authentication Failures
- **4001**: Missing authentication token
- **4002**: Invalid token (missing user ID)
- **4003**: Invalid token (parsing error or expired)

### No-Auth Mode
- When `settings.config.enable_auth` is `False`
- Uses `DEFAULT_USER_ID` from `backend.data.user`
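The validation steps above amount to an HS256 signature and audience check. A stdlib-only sketch follows; a real deployment should use a JWT library, and the explicit `secret` parameter here is illustrative (the server reads it from settings):

```python
import base64
import hashlib
import hmac
import json


def _b64decode(part: str) -> bytes:
    # JWT segments are unpadded base64url; restore the padding first.
    return base64.urlsafe_b64decode(part + "=" * (-len(part) % 4))


def parse_jwt_token(token: str, secret: str) -> dict:
    """Verify an HS256 JWT and return its payload, mirroring the rules above."""
    header_b64, payload_b64, sig_b64 = token.split(".")
    signing_input = f"{header_b64}.{payload_b64}".encode()
    expected = hmac.new(secret.encode(), signing_input, hashlib.sha256).digest()
    if not hmac.compare_digest(_b64decode(sig_b64), expected):
        raise ValueError("invalid signature")
    payload = json.loads(_b64decode(payload_b64))
    if payload.get("aud") != "authenticated":
        raise ValueError("invalid audience")
    return payload
```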
## Message Protocol

### Message Format
All messages use JSON with the following structure:

```typescript
interface WSMessage {
  method: WSMethod;
  data?: Record<string, any> | any[] | string;
  success?: boolean;
  channel?: string;
  error?: string;
}
```

### Message Methods (WSMethod enum)

1. **Client-to-Server Methods**:
   - `SUBSCRIBE_GRAPH_EXEC`: Subscribe to a specific graph execution
   - `SUBSCRIBE_GRAPH_EXECS`: Subscribe to all executions of a graph
   - `UNSUBSCRIBE`: Unsubscribe from a channel
   - `HEARTBEAT`: Keep-alive ping

2. **Server-to-Client Methods**:
   - `GRAPH_EXECUTION_EVENT`: Graph execution status update
   - `NODE_EXECUTION_EVENT`: Node execution status update
   - `ERROR`: Error message
   - `HEARTBEAT`: Keep-alive pong
## Subscription Models

### Subscribe to a Specific Graph Execution
```typescript
interface WSSubscribeGraphExecutionRequest {
  graph_exec_id: string;
}
```
**Channel Key Format**: `{user_id}|graph_exec#{graph_exec_id}`

### Subscribe to All Graph Executions
```typescript
interface WSSubscribeGraphExecutionsRequest {
  graph_id: string;
}
```
**Channel Key Format**: `{user_id}|graph#{graph_id}|executions`
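The two channel key formats above can be captured in two small helpers:

```python
def graph_exec_channel(user_id: str, graph_exec_id: str) -> str:
    """Channel key for one specific graph execution."""
    return f"{user_id}|graph_exec#{graph_exec_id}"


def graph_execs_channel(user_id: str, graph_id: str) -> str:
    """Channel key for all executions of a graph."""
    return f"{user_id}|graph#{graph_id}|executions"
```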
## Event Models

### Graph Execution Event
```typescript
interface GraphExecutionEvent {
  event_type: "graph_execution_update";
  id: string;                // graph_exec_id
  user_id: string;
  graph_id: string;
  graph_version: number;
  preset_id?: string;
  status: ExecutionStatus;
  started_at: string;        // ISO datetime
  ended_at: string;          // ISO datetime
  inputs: Record<string, any>;
  outputs: Record<string, any>;
  stats?: {
    cost: number;            // cents
    duration: number;        // seconds
    duration_cpu_only: number;
    node_exec_time: number;
    node_exec_time_cpu_only: number;
    node_exec_count: number;
    node_error_count: number;
    error?: string;
  };
}
```

### Node Execution Event
```typescript
interface NodeExecutionEvent {
  event_type: "node_execution_update";
  user_id: string;
  graph_id: string;
  graph_version: number;
  graph_exec_id: string;
  node_exec_id: string;
  node_id: string;
  block_id: string;
  status: ExecutionStatus;
  input_data: Record<string, any>;
  output_data: Record<string, any>;
  add_time: string;          // ISO datetime
  queue_time?: string;       // ISO datetime
  start_time?: string;       // ISO datetime
  end_time?: string;         // ISO datetime
}
```

### Execution Status Enum
```typescript
enum ExecutionStatus {
  INCOMPLETE = "INCOMPLETE",
  QUEUED = "QUEUED",
  RUNNING = "RUNNING",
  COMPLETED = "COMPLETED",
  FAILED = "FAILED",
}
```
## Message Flow Examples

### 1. Subscribe to Graph Execution
```json
// Client → Server
{
  "method": "subscribe_graph_execution",
  "data": {
    "graph_exec_id": "exec-123"
  }
}

// Server → Client (Success)
{
  "method": "subscribe_graph_execution",
  "success": true,
  "channel": "user-456|graph_exec#exec-123"
}
```

### 2. Receive Execution Updates
```json
// Server → Client (Graph Update)
{
  "method": "graph_execution_event",
  "channel": "user-456|graph_exec#exec-123",
  "data": {
    "event_type": "graph_execution_update",
    "id": "exec-123",
    "user_id": "user-456",
    "graph_id": "graph-789",
    "status": "RUNNING"
    // ... other fields
  }
}

// Server → Client (Node Update)
{
  "method": "node_execution_event",
  "channel": "user-456|graph_exec#exec-123",
  "data": {
    "event_type": "node_execution_update",
    "node_exec_id": "node-exec-111",
    "status": "COMPLETED"
    // ... other fields
  }
}
```

### 3. Heartbeat
```json
// Client → Server
{
  "method": "heartbeat",
  "data": "ping"
}

// Server → Client
{
  "method": "heartbeat",
  "data": "pong",
  "success": true
}
```

### 4. Error Handling
```json
// Server → Client (Invalid Message)
{
  "method": "error",
  "success": false,
  "error": "Invalid message format. Review the schema and retry"
}
```
## Event Broadcasting Architecture

### Redis Pub/Sub Integration
1. **Event Bus Name**: Configured via `config.execution_event_bus_name`
2. **Channel Pattern**: `{event_bus_name}/{channel_key}`
3. **Event Flow**:
   - Execution services publish events to Redis
   - The event broadcaster listens to the Redis pattern `*`
   - Events are routed to WebSocket connections based on subscriptions

### Event Broadcaster
- Runs as a continuous async task using the `@continuous_retry()` decorator
- Listens to all execution events via `AsyncRedisExecutionEventBus`
- Calls `ConnectionManager.send_execution_update()` for each event
## Connection Lifecycle

### Connection Establishment
1. Client connects to the `/ws` endpoint
2. Authentication performed (JWT validation)
3. WebSocket accepted via `manager.connect_socket()`
4. Connection added to the active connections set

### Message Processing Loop
1. Receive text message from client
2. Parse and validate as `WSMessage`
3. Route to the appropriate handler based on `method`
4. Send response or error back to the client

### Connection Termination
1. `WebSocketDisconnect` exception caught
2. `manager.disconnect_socket()` called
3. Connection removed from active connections
4. All subscriptions for that connection removed
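The lifecycle above can be sketched as a small manager. Plain strings stand in for WebSocket objects, and only the method names mentioned in the text are taken from the source; the rest is illustrative:

```python
class ConnectionManager:
    """Track active connections and their channel subscriptions."""

    def __init__(self):
        self.active_connections: set = set()
        self.subscriptions: dict[str, set] = {}

    def connect_socket(self, ws) -> None:
        self.active_connections.add(ws)

    def subscribe(self, channel: str, ws) -> None:
        self.subscriptions.setdefault(channel, set()).add(ws)

    def disconnect_socket(self, ws) -> None:
        # Drop the connection and every subscription it held.
        self.active_connections.discard(ws)
        for subscribers in self.subscriptions.values():
            subscribers.discard(ws)
```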
## Error Handling

### Validation Errors
- **Invalid Message Format**: Returns error with method "error"
- **Invalid Message Data**: Returns error with a specific validation message
- **Unknown Message Type**: Returns error indicating an unsupported method

### Connection Errors
- WebSocket disconnections handled gracefully
- Failed event parsing is logged but doesn't crash the connection
- Handler exceptions are logged and the connection continues
## Configuration
|
||||
|
||||
### Environment Variables
|
||||
```python
|
||||
# WebSocket Server Configuration
|
||||
websocket_server_host: str = "0.0.0.0"
|
||||
websocket_server_port: int = 8001
|
||||
|
||||
# Authentication
|
||||
enable_auth: bool = True
|
||||
|
||||
# CORS
|
||||
backend_cors_allow_origins: List[str] = []
|
||||
|
||||
# Redis Event Bus
|
||||
execution_event_bus_name: str = "autogpt:execution_event_bus"
|
||||
|
||||
# Message Size Limits
|
||||
max_message_size_limit: int = 512000 # 512KB
|
||||
```
|
||||
|
||||
### Security Headers
|
||||
- CORS middleware applied with configured origins
|
||||
- Credentials allowed for authenticated requests
|
||||
- All methods and headers allowed (configurable)
|
||||
|
||||
## Deployment Requirements
|
||||
|
||||
### Dependencies
|
||||
1. **FastAPI**: Web framework with WebSocket support
|
||||
2. **Redis**: For pub/sub event broadcasting
|
||||
3. **JWT Libraries**: For token validation
|
||||
4. **Prisma**: Database ORM (for future graph access validation)
|
||||
|
||||
### Process Management
|
||||
- Implements `AppProcess` interface for service lifecycle
|
||||
- Runs via `uvicorn` ASGI server
|
||||
- Graceful shutdown handling in `cleanup()` method
|
||||
|
||||
### Concurrent Connections
|
||||
- No hard limit on WebSocket connections
|
||||
- Memory usage scales with active connections
|
||||
- Each connection maintains subscription set
|
||||
|
||||
## Implementation Checklist
|
||||
|
||||
To implement a compatible WebSocket API:
|
||||
|
||||
1. **Authentication**
|
||||
- [ ] JWT token validation from query parameters
|
||||
- [ ] Support for no-auth mode with default user ID
|
||||
- [ ] Proper error codes for auth failures
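The token validation in item 1 is normally done with a JWT library (PyJWT, `jsonwebtoken`, etc.); purely to illustrate what HS256 verification involves, here is a stdlib-only sketch. It checks only the signature — a real implementation must also validate `exp`, `aud`, and the algorithm in the header:

```python
import base64
import hashlib
import hmac
import json


def _b64url_decode(s: str) -> bytes:
    return base64.urlsafe_b64decode(s + "=" * (-len(s) % 4))


def _b64url_encode(raw: bytes) -> str:
    return base64.urlsafe_b64encode(raw).rstrip(b"=").decode()


def sign_hs256(claims: dict, secret: bytes) -> str:
    """Build an HS256 JWT (illustrative only)."""
    head = _b64url_encode(json.dumps({"alg": "HS256", "typ": "JWT"}).encode())
    body = _b64url_encode(json.dumps(claims).encode())
    sig = hmac.new(secret, f"{head}.{body}".encode(), hashlib.sha256).digest()
    return f"{head}.{body}.{_b64url_encode(sig)}"


def verify_hs256(token: str, secret: bytes):
    """Return the claims if the HS256 signature verifies, else None."""
    try:
        head, body, sig = token.split(".")
    except ValueError:
        return None  # malformed token -> reject the connection
    expected = hmac.new(secret, f"{head}.{body}".encode(), hashlib.sha256).digest()
    if not hmac.compare_digest(expected, _b64url_decode(sig)):
        return None
    return json.loads(_b64url_decode(body))


token = sign_hs256({"sub": "user_789", "aud": ["authenticated"]}, b"test_secret")
print(verify_hs256(token, b"test_secret")["sub"])  # → user_789
```

`hmac.compare_digest` is used instead of `==` to avoid timing side channels when comparing signatures.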
2. **Message Handling**
   - [ ] Parse and validate the `WSMessage` format
   - [ ] Implement all client-to-server methods
   - [ ] Support all server-to-client event types
   - [ ] Proper error responses for invalid messages

3. **Subscription Management**
   - [ ] Channel key generation matching the exact format
   - [ ] Support for both execution- and graph-level subscriptions
   - [ ] Unsubscribe functionality
   - [ ] Clean up subscriptions on disconnect
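The "exact format" in item 3 refers to the two channel-key shapes that appear throughout this diff (`{user_id}|graph_exec#{graph_exec_id}` for a single execution, `{user_id}|graph#{graph_id}|executions` for all executions of a graph). They can be generated as:

```python
def graph_exec_channel(user_id: str, graph_exec_id: str) -> str:
    """Channel key for a single graph execution."""
    return f"{user_id}|graph_exec#{graph_exec_id}"


def graph_executions_channel(user_id: str, graph_id: str) -> str:
    """Channel key for all executions of a graph."""
    return f"{user_id}|graph#{graph_id}|executions"


print(graph_exec_channel("test_user", "test_exec_123"))
# → test_user|graph_exec#test_exec_123
```

Both sides must produce these keys byte-for-byte identically, since they are used as dictionary keys when routing broadcast events to subscribers.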
4. **Event Broadcasting**
   - [ ] Listen to Redis pub/sub for execution events
   - [ ] Route events to the correct subscribed connections
   - [ ] Handle both graph and node execution events
   - [ ] Maintain event order and completeness

5. **Connection Management**
   - [ ] Track active WebSocket connections
   - [ ] Handle graceful disconnections
   - [ ] Implement heartbeat/keepalive
   - [ ] Memory-efficient subscription storage

6. **Configuration**
   - [ ] Support all environment variables
   - [ ] CORS configuration for allowed origins
   - [ ] Configurable host/port binding
   - [ ] Redis connection configuration

7. **Error Handling**
   - [ ] Graceful handling of malformed messages
   - [ ] Log errors without dropping connections
   - [ ] Specific error messages for debugging
   - [ ] Recovery from Redis connection issues

## Testing Considerations

1. **Unit Tests**
   - Message parsing and validation
   - Channel key generation
   - Subscription management logic

2. **Integration Tests**
   - Full WebSocket connection flow
   - Event broadcasting from Redis
   - Multi-client subscription scenarios
   - Authentication success/failure cases

3. **Load Tests**
   - Many concurrent connections
   - High-frequency event broadcasting
   - Memory usage under load
   - Connection/disconnection cycles

## Security Considerations

1. **Authentication**: JWT tokens are transmitted via query parameters (consider upgrading to headers)
2. **Authorization**: Currently no graph-level access validation (commented out in the code)
3. **Rate limiting**: No rate limiting is implemented
4. **Message size**: Limited by the `max_message_size_limit` configuration
5. **Input validation**: All inputs are validated via Pydantic models

## Future Enhancements (Currently Commented Out)

1. **Graph access validation**: Verify the user has read access to subscribed graphs
2. **Message compression**: For large execution payloads
3. **Batch updates**: Aggregate multiple events in a single message
4. **Selective field subscription**: Subscribe to specific fields only
93 autogpt_platform/autogpt-rs/websocket/benches/README.md Normal file
@@ -0,0 +1,93 @@
# WebSocket Server Benchmarks

This directory contains performance benchmarks for the AutoGPT WebSocket server.

## Prerequisites

1. Redis must be running locally, or set the `REDIS_URL` environment variable:
   ```bash
   docker run -d -p 6379:6379 redis:latest
   ```

2. Build the project in release mode:
   ```bash
   cargo build --release
   ```

## Running Benchmarks

Run all benchmarks:

```bash
cargo bench
```

Run a specific benchmark group:

```bash
cargo bench connection_establishment
cargo bench subscriptions
cargo bench message_throughput
cargo bench concurrent_connections
cargo bench message_parsing
cargo bench redis_event_processing
```

## Benchmark Categories

### Connection Establishment

Tests the performance of establishing WebSocket connections with different authentication scenarios:

- No authentication
- Valid JWT authentication
- Invalid JWT authentication (connection rejection)

### Subscriptions

Measures the performance of subscription operations:

- Subscribing to graph execution events
- Unsubscribing from channels

### Message Throughput

Tests how many messages the server can process per second with varying message counts (10, 100, 1000).

### Concurrent Connections

Benchmarks the server's ability to handle multiple simultaneous connections (100, 500, and 1000 clients).

### Message Parsing

Tests JSON parsing performance with different message sizes (100 B to 10 KB).

### Redis Event Processing

Benchmarks the parsing of execution events received from Redis.

## Profiling

To generate flamegraphs for CPU profiling:

1. Install flamegraph tools:
   ```bash
   cargo install flamegraph
   ```

2. Run benchmarks with profiling:
   ```bash
   cargo bench --bench websocket_bench -- --profile-time=10
   ```

## Interpreting Results

- **Throughput**: Higher is better (operations/second or elements/second)
- **Time**: Lower is better (nanoseconds per operation)
- **Error margins**: Look for stable results with a low standard deviation

## Optimizing Performance

Based on benchmark results, consider:

1. **Connection pooling** for Redis connections
2. **Message batching** for high-throughput scenarios
3. **Async task tuning** for concurrent connection handling
4. **JSON parsing optimization** using simd-json or other fast parsers
5. **Memory allocation** optimization using arena allocators

## Notes

- Benchmarks create actual WebSocket servers on random ports
- Each benchmark iteration properly cleans up resources
- Results may vary based on system resources and Redis performance
406 autogpt_platform/autogpt-rs/websocket/benches/websocket_bench.rs Normal file
@@ -0,0 +1,406 @@
```rust
#![allow(clippy::unwrap_used)] // Benchmarks can panic on setup errors

use axum::{routing::get, Router};
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use futures_util::{SinkExt, StreamExt};
use serde_json::json;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::runtime::Runtime;
use tokio_tungstenite::{connect_async, tungstenite::Message};

// Import the actual websocket server components
use websocket::{models, ws_handler, AppState, Config, ConnectionManager, Stats};

// Helper to create a test server
async fn create_test_server(enable_auth: bool) -> (String, tokio::task::JoinHandle<()>) {
    // Set environment variables for test config
    std::env::set_var("WEBSOCKET_SERVER_HOST", "127.0.0.1");
    std::env::set_var("WEBSOCKET_SERVER_PORT", "0");
    std::env::set_var("ENABLE_AUTH", enable_auth.to_string());
    std::env::set_var("SUPABASE_JWT_SECRET", "test_secret");
    std::env::set_var("DEFAULT_USER_ID", "test_user");
    if std::env::var("REDIS_URL").is_err() {
        std::env::set_var("REDIS_URL", "redis://localhost:6379");
    }

    let mut config = Config::load(None);
    config.port = 0; // Force OS to assign port

    let redis_client =
        redis::Client::open(config.redis_url.clone()).expect("Failed to connect to Redis");
    let stats = Arc::new(Stats::default());
    let mgr = Arc::new(ConnectionManager::new(
        redis_client,
        config.execution_event_bus_name.clone(),
        stats.clone(),
    ));

    // Start broadcaster
    let mgr_clone = mgr.clone();
    tokio::spawn(async move {
        mgr_clone.run_broadcaster().await;
    });

    let state = AppState {
        mgr,
        config: Arc::new(config),
        stats,
    };

    let app = Router::new()
        .route("/ws", get(ws_handler))
        .layer(axum::Extension(state));

    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();
    let server_url = format!("ws://{addr}");

    let server_handle = tokio::spawn(async move {
        axum::serve(listener, app.into_make_service())
            .await
            .unwrap();
    });

    // Give the server time to start
    tokio::time::sleep(Duration::from_millis(100)).await;

    (server_url, server_handle)
}

// Helper to create a valid JWT token
fn create_jwt_token(user_id: &str) -> String {
    use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
    use serde::Serialize;

    #[derive(Serialize)]
    struct Claims {
        sub: String,
        aud: Vec<String>,
        exp: usize,
    }

    let claims = Claims {
        sub: user_id.to_string(),
        aud: vec!["authenticated".to_string()],
        exp: (chrono::Utc::now() + chrono::Duration::hours(1)).timestamp() as usize,
    };

    encode(
        &Header::new(Algorithm::HS256),
        &claims,
        &EncodingKey::from_secret(b"test_secret"),
    )
    .unwrap()
}

// Benchmark connection establishment
fn benchmark_connection_establishment(c: &mut Criterion) {
    let rt = Runtime::new().unwrap();

    let mut group = c.benchmark_group("connection_establishment");
    group.measurement_time(Duration::from_secs(30));

    // Test without auth
    group.bench_function("no_auth", |b| {
        b.to_async(&rt).iter_with_large_drop(|| async {
            let (server_url, server_handle) = create_test_server(false).await;
            let url = format!("{server_url}/ws");
            let (ws_stream, _) = connect_async(&url).await.unwrap();
            drop(ws_stream);
            server_handle.abort();
        });
    });

    // Test with valid auth
    group.bench_function("valid_auth", |b| {
        b.to_async(&rt).iter_with_large_drop(|| async {
            let (server_url, server_handle) = create_test_server(true).await;
            let token = create_jwt_token("test_user");
            let url = format!("{server_url}/ws?token={token}");
            let (ws_stream, _) = connect_async(&url).await.unwrap();
            drop(ws_stream);
            server_handle.abort();
        });
    });

    // Test with invalid auth
    group.bench_function("invalid_auth", |b| {
        b.to_async(&rt).iter_with_large_drop(|| async {
            let (server_url, server_handle) = create_test_server(true).await;
            let url = format!("{server_url}/ws?token=invalid");
            let result = connect_async(&url).await;
            assert!(
                result.is_err() || {
                    if let Ok((mut ws_stream, _)) = result {
                        // Should receive a close frame
                        matches!(ws_stream.next().await, Some(Ok(Message::Close(_))))
                    } else {
                        false
                    }
                }
            );
            server_handle.abort();
        });
    });

    group.finish();
}

// Benchmark subscription operations
fn benchmark_subscriptions(c: &mut Criterion) {
    let rt = Runtime::new().unwrap();

    let mut group = c.benchmark_group("subscriptions");
    group.measurement_time(Duration::from_secs(20));

    group.bench_function("subscribe_graph_execution", |b| {
        b.to_async(&rt).iter_with_large_drop(|| async {
            let (server_url, server_handle) = create_test_server(false).await;
            let url = format!("{server_url}/ws");
            let (mut ws_stream, _) = connect_async(&url).await.unwrap();
            let msg = json!({
                "method": "subscribe_graph_execution",
                "data": {
                    "graph_exec_id": "test_exec_123"
                }
            });

            ws_stream
                .send(Message::Text(msg.to_string()))
                .await
                .unwrap();

            // Wait for response
            if let Some(Ok(Message::Text(response))) = ws_stream.next().await {
                let resp: serde_json::Value = serde_json::from_str(&response).unwrap();
                assert_eq!(resp["success"], true);
            }

            server_handle.abort();
        });
    });

    group.bench_function("unsubscribe", |b| {
        b.to_async(&rt).iter_with_large_drop(|| async {
            let (server_url, server_handle) = create_test_server(false).await;
            let url = format!("{server_url}/ws");
            let (mut ws_stream, _) = connect_async(&url).await.unwrap();

            // First subscribe
            let msg = json!({
                "method": "subscribe_graph_execution",
                "data": {
                    "graph_exec_id": "test_exec_123"
                }
            });
            ws_stream
                .send(Message::Text(msg.to_string()))
                .await
                .unwrap();
            ws_stream.next().await; // Consume response

            let msg = json!({
                "method": "unsubscribe",
                "data": {
                    "channel": "test_user|graph_exec#test_exec_123"
                }
            });

            ws_stream
                .send(Message::Text(msg.to_string()))
                .await
                .unwrap();

            // Wait for response
            if let Some(Ok(Message::Text(response))) = ws_stream.next().await {
                let resp: serde_json::Value = serde_json::from_str(&response).unwrap();
                assert_eq!(resp["success"], true);
            }

            server_handle.abort();
        });
    });

    group.finish();
}

// Benchmark message throughput
fn benchmark_message_throughput(c: &mut Criterion) {
    let rt = Runtime::new().unwrap();

    let mut group = c.benchmark_group("message_throughput");
    group.measurement_time(Duration::from_secs(30));

    for msg_count in [10, 100, 1000].iter() {
        group.throughput(Throughput::Elements(*msg_count as u64));
        group.bench_with_input(
            BenchmarkId::from_parameter(msg_count),
            msg_count,
            |b, &msg_count| {
                b.to_async(&rt).iter_with_large_drop(|| async {
                    let (server_url, server_handle) = create_test_server(false).await;
                    let url = format!("{server_url}/ws");
                    let (mut ws_stream, _) = connect_async(&url).await.unwrap();

                    // Send multiple heartbeat messages
                    for _ in 0..msg_count {
                        let msg = json!({
                            "method": "heartbeat",
                            "data": "ping"
                        });
                        ws_stream
                            .send(Message::Text(msg.to_string()))
                            .await
                            .unwrap();
                    }

                    // Receive all responses
                    for _ in 0..msg_count {
                        ws_stream.next().await;
                    }

                    server_handle.abort();
                });
            },
        );
    }

    group.finish();
}

// Benchmark concurrent connections
fn benchmark_concurrent_connections(c: &mut Criterion) {
    let rt = Runtime::new().unwrap();

    let mut group = c.benchmark_group("concurrent_connections");
    group.measurement_time(Duration::from_secs(60));
    group.sample_size(10);

    for num_clients in [100, 500, 1000].iter() {
        group.throughput(Throughput::Elements(*num_clients as u64));
        group.bench_with_input(
            BenchmarkId::from_parameter(num_clients),
            num_clients,
            |b, &num_clients| {
                b.to_async(&rt).iter_with_large_drop(|| async {
                    let (server_url, server_handle) = create_test_server(false).await;
                    let url = format!("{server_url}/ws");

                    // Create multiple concurrent connections
                    let mut handles = vec![];
                    for i in 0..num_clients {
                        let url = url.clone();
                        let handle = tokio::spawn(async move {
                            let (mut ws_stream, _) = connect_async(&url).await.unwrap();

                            // Subscribe to a unique channel
                            let msg = json!({
                                "method": "subscribe_graph_execution",
                                "data": {
                                    "graph_exec_id": format!("exec_{}", i)
                                }
                            });
                            ws_stream
                                .send(Message::Text(msg.to_string()))
                                .await
                                .unwrap();
                            ws_stream.next().await; // Wait for response

                            // Send a heartbeat
                            let msg = json!({
                                "method": "heartbeat",
                                "data": "ping"
                            });
                            ws_stream
                                .send(Message::Text(msg.to_string()))
                                .await
                                .unwrap();
                            ws_stream.next().await; // Wait for response

                            ws_stream
                        });
                        handles.push(handle);
                    }

                    // Wait for all connections to complete
                    for handle in handles {
                        let _ = handle.await;
                    }

                    server_handle.abort();
                });
            },
        );
    }

    group.finish();
}

// Benchmark message parsing
fn benchmark_message_parsing(c: &mut Criterion) {
    let mut group = c.benchmark_group("message_parsing");

    // Test different message sizes
    for msg_size in [100, 1000, 10000].iter() {
        group.throughput(Throughput::Bytes(*msg_size as u64));
        group.bench_with_input(
            BenchmarkId::new("parse_json", msg_size),
            msg_size,
            |b, &msg_size| {
                let data_str = "x".repeat(msg_size);
                let json_msg = json!({
                    "method": "subscribe_graph_execution",
                    "data": {
                        "graph_exec_id": data_str
                    }
                });
                let json_str = json_msg.to_string();

                b.iter(|| {
                    let _: models::WSMessage = serde_json::from_str(&json_str).unwrap();
                });
            },
        );
    }

    group.finish();
}

// Benchmark Redis event processing
fn benchmark_redis_event_processing(c: &mut Criterion) {
    let mut group = c.benchmark_group("redis_event_processing");

    group.bench_function("parse_execution_event", |b| {
        let event = json!({
            "payload": {
                "event_type": "graph_execution_update",
                "id": "exec_123",
                "graph_id": "graph_456",
                "graph_version": 1,
                "user_id": "user_789",
                "status": "RUNNING",
                "started_at": "2024-01-01T00:00:00Z",
                "inputs": {"test": "data"},
                "outputs": {}
            }
        });
        let event_str = event.to_string();

        b.iter(|| {
            let _: models::RedisEventWrapper = serde_json::from_str(&event_str).unwrap();
        });
    });

    group.finish();
}

criterion_group!(
    benches,
    benchmark_connection_establishment,
    benchmark_subscriptions,
    benchmark_message_throughput,
    benchmark_concurrent_connections,
    benchmark_message_parsing,
    benchmark_redis_event_processing
);
criterion_main!(benches);
```
10 autogpt_platform/autogpt-rs/websocket/clippy.toml Normal file
@@ -0,0 +1,10 @@
```toml
# Clippy configuration for robust error handling

# Set the maximum cognitive complexity allowed
cognitive-complexity-threshold = 30

# Disallow `dbg!` macros in tests
allow-dbg-in-tests = false

# Enforce documentation on crate items
missing-docs-in-crate-items = true
```
23 autogpt_platform/autogpt-rs/websocket/config.toml Normal file
@@ -0,0 +1,23 @@
```toml
# WebSocket API Configuration

# Server settings
host = "0.0.0.0"
port = 8001

# Authentication
enable_auth = true
jwt_secret = "your-super-secret-jwt-token-with-at-least-32-characters-long"
jwt_algorithm = "HS256"
default_user_id = "default"

# Redis configuration
redis_url = "redis://:password@localhost:6379/"

# Event bus
execution_event_bus_name = "execution_event"

# Message size limit (in bytes)
max_message_size_limit = 512000

# CORS allowed origins
backend_cors_allow_origins = ["http://localhost:3000", "https://559f69c159ef.ngrok.app"]
```
@@ -0,0 +1,75 @@
```rust
use futures_util::{SinkExt, StreamExt};
use serde_json::json;
use tokio_tungstenite::{connect_async, tungstenite::Message};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let url = "ws://localhost:8001/ws";

    println!("Connecting to {url}");
    let (mut ws_stream, _) = connect_async(url).await?;
    println!("Connected!");

    // Subscribe to a graph execution
    let subscribe_msg = json!({
        "method": "subscribe_graph_execution",
        "data": {
            "graph_exec_id": "test_exec_123"
        }
    });

    println!("Sending subscription request...");
    ws_stream
        .send(Message::Text(subscribe_msg.to_string()))
        .await?;

    // Wait for response
    if let Some(msg) = ws_stream.next().await {
        if let Message::Text(text) = msg? {
            println!("Received: {text}");
        }
    }

    // Send heartbeat
    let heartbeat_msg = json!({
        "method": "heartbeat",
        "data": "ping"
    });

    println!("Sending heartbeat...");
    ws_stream
        .send(Message::Text(heartbeat_msg.to_string()))
        .await?;

    // Wait for pong
    if let Some(msg) = ws_stream.next().await {
        if let Message::Text(text) = msg? {
            println!("Received: {text}");
        }
    }

    // Unsubscribe
    let unsubscribe_msg = json!({
        "method": "unsubscribe",
        "data": {
            "channel": "default|graph_exec#test_exec_123"
        }
    });

    println!("Sending unsubscribe request...");
    ws_stream
        .send(Message::Text(unsubscribe_msg.to_string()))
        .await?;

    // Wait for response
    if let Some(msg) = ws_stream.next().await {
        if let Message::Text(text) = msg? {
            println!("Received: {text}");
        }
    }

    println!("Closing connection...");
    ws_stream.close(None).await?;

    Ok(())
}
```
99 autogpt_platform/autogpt-rs/websocket/src/config.rs Normal file
@@ -0,0 +1,99 @@
```rust
use jsonwebtoken::Algorithm;
use serde::Deserialize;
use std::env;
use std::fs;
use std::path::Path;
use std::str::FromStr;

#[derive(Clone, Debug, Deserialize)]
pub struct Config {
    pub host: String,
    pub port: u16,
    pub enable_auth: bool,
    pub jwt_secret: String,
    pub jwt_algorithm: Algorithm,
    pub execution_event_bus_name: String,
    pub redis_url: String,
    pub default_user_id: String,
    pub max_message_size_limit: usize,
    pub backend_cors_allow_origins: Vec<String>,
}

impl Config {
    pub fn load(config_path: Option<&Path>) -> Self {
        let path = config_path.unwrap_or(Path::new("config.toml"));
        let toml_result = fs::read_to_string(path)
            .ok()
            .and_then(|s| toml::from_str::<Config>(&s).ok());

        let mut config = match toml_result {
            Some(config) => config,
            None => Config {
                host: env::var("WEBSOCKET_SERVER_HOST").unwrap_or_else(|_| "0.0.0.0".to_string()),
                port: env::var("WEBSOCKET_SERVER_PORT")
                    .ok()
                    .and_then(|s| s.parse().ok())
                    .unwrap_or(8001),
                enable_auth: env::var("ENABLE_AUTH")
                    .ok()
                    .and_then(|s| s.parse().ok())
                    .unwrap_or(true),
                jwt_secret: env::var("SUPABASE_JWT_SECRET")
                    .unwrap_or_else(|_| "dummy_secret_for_no_auth".to_string()),
                jwt_algorithm: Algorithm::HS256,
                execution_event_bus_name: env::var("EXECUTION_EVENT_BUS_NAME")
                    .unwrap_or_else(|_| "execution_event".to_string()),
                redis_url: env::var("REDIS_URL")
                    .unwrap_or_else(|_| "redis://localhost/".to_string()),
                default_user_id: "default".to_string(),
                max_message_size_limit: env::var("MAX_MESSAGE_SIZE_LIMIT")
                    .ok()
                    .and_then(|s| s.parse().ok())
                    .unwrap_or(512000),
                backend_cors_allow_origins: env::var("BACKEND_CORS_ALLOW_ORIGINS")
                    .unwrap_or_default()
                    .split(',')
                    .map(|s| s.trim().to_string())
                    .filter(|s| !s.is_empty())
                    .collect(),
            },
        };

        if let Ok(v) = env::var("WEBSOCKET_SERVER_HOST") {
            config.host = v;
        }
        if let Ok(v) = env::var("WEBSOCKET_SERVER_PORT") {
            config.port = v.parse().unwrap_or(8001);
        }
        if let Ok(v) = env::var("ENABLE_AUTH") {
            config.enable_auth = v.parse().unwrap_or(true);
        }
        if let Ok(v) = env::var("SUPABASE_JWT_SECRET") {
            config.jwt_secret = v;
        }
        if let Ok(v) = env::var("JWT_ALGORITHM") {
            config.jwt_algorithm = Algorithm::from_str(&v).unwrap_or(Algorithm::HS256);
        }
        if let Ok(v) = env::var("EXECUTION_EVENT_BUS_NAME") {
            config.execution_event_bus_name = v;
        }
        if let Ok(v) = env::var("REDIS_URL") {
            config.redis_url = v;
        }
        if let Ok(v) = env::var("DEFAULT_USER_ID") {
            config.default_user_id = v;
        }
        if let Ok(v) = env::var("MAX_MESSAGE_SIZE_LIMIT") {
            config.max_message_size_limit = v.parse().unwrap_or(512000);
        }
        if let Ok(v) = env::var("BACKEND_CORS_ALLOW_ORIGINS") {
            config.backend_cors_allow_origins = v
                .split(',')
                .map(|s| s.trim().to_string())
                .filter(|s| !s.is_empty())
                .collect();
        }

        config
    }
}
```
277 autogpt_platform/autogpt-rs/websocket/src/connection_manager.rs Normal file
@@ -0,0 +1,277 @@
```rust
use futures::StreamExt;
use redis::Client as RedisClient;
use std::collections::{HashMap, HashSet};
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};
use tracing::{debug, error, info, warn};

use crate::models::{ExecutionEvent, RedisEventWrapper, WSMessage};
use crate::stats::Stats;

pub struct ConnectionManager {
    pub subscribers: RwLock<HashMap<String, HashSet<u64>>>,
    pub clients: RwLock<HashMap<u64, (String, mpsc::Sender<String>)>>,
    pub client_channels: RwLock<HashMap<u64, HashSet<String>>>,
    pub next_id: AtomicU64,
    pub redis_client: RedisClient,
    pub bus_name: String,
    pub stats: Arc<Stats>,
}

impl ConnectionManager {
    pub fn new(redis_client: RedisClient, bus_name: String, stats: Arc<Stats>) -> Self {
        Self {
            subscribers: RwLock::new(HashMap::new()),
            clients: RwLock::new(HashMap::new()),
            client_channels: RwLock::new(HashMap::new()),
            next_id: AtomicU64::new(0),
            redis_client,
            bus_name,
            stats,
        }
    }

    pub async fn run_broadcaster(self: Arc<Self>) {
        info!("🚀 Starting Redis event broadcaster");

        loop {
            match self.run_broadcaster_inner().await {
                Ok(_) => {
                    warn!("⚠️ Event broadcaster stopped unexpectedly, restarting in 5 seconds");
                    tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
                }
                Err(e) => {
                    error!("❌ Event broadcaster error: {}, restarting in 5 seconds", e);
                    tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
                }
            }
        }
    }

    async fn run_broadcaster_inner(
        self: &Arc<Self>,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let mut pubsub = self.redis_client.get_async_pubsub().await?;
        pubsub.psubscribe("*").await?;
        info!(
            "📡 Listening to all Redis events, filtering for bus: {}",
            self.bus_name
        );

        let mut pubsub_stream = pubsub.on_message();

        loop {
            let msg = pubsub_stream.next().await;
            match msg {
                Some(msg) => {
                    let channel: String = msg.get_channel_name().to_string();
                    debug!("📨 Received message on Redis channel: {}", channel);
                    self.stats
                        .redis_messages_received
                        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);

                    let payload: String = match msg.get_payload() {
                        Ok(p) => p,
                        Err(e) => {
                            warn!("⚠️ Failed to get payload from Redis message: {}", e);
                            self.stats
                                .errors_total
                                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                            continue;
                        }
                    };

                    // Parse the channel format: execution_event/{user_id}/{graph_id}/{graph_exec_id}
                    let parts: Vec<&str> = channel.split('/').collect();

                    // Check if this is an execution event channel
                    if parts.len() != 4 || parts[0] != self.bus_name {
                        debug!(
                            "🚫 Ignoring non-execution event channel: {} (parts: {:?}, bus_name: {})",
                            channel, parts, self.bus_name
                        );
                        self.stats
                            .redis_messages_ignored
                            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                        continue;
                    }

                    let user_id = parts[1];
                    let graph_id = parts[2];
                    let graph_exec_id = parts[3];

                    debug!(
                        "📥 Received event - user: {}, graph: {}, exec: {}",
                        user_id, graph_id, graph_exec_id
                    );

                    // Parse the wrapped event
                    let wrapped_event = match RedisEventWrapper::parse(&payload) {
                        Ok(e) => e,
                        Err(e) => {
                            warn!("⚠️ Failed to parse event JSON: {}, payload: {}", e, payload);
                            self.stats
                                .errors_json_parse
                                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                            self.stats
                                .errors_total
                                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                            continue;
                        }
                    };

                    let event = wrapped_event.payload;
                    debug!("📦 Event received: {:?}", event);

                    let (method, event_json) = match &event {
                        ExecutionEvent::GraphExecutionUpdate(graph_event) => {
                            self.stats
                                .graph_execution_events
                                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                            self.stats
                                .events_received_total
                                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                            (
                                "graph_execution_event",
                                match serde_json::to_value(graph_event) {
                                    Ok(v) => v,
                                    Err(e) => {
                                        error!("❌ Failed to serialize graph event: {}", e);
                                        continue;
                                    }
                                },
                            )
                        }
                        ExecutionEvent::NodeExecutionUpdate(node_event) => {
                            self.stats
                                .node_execution_events
                                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                            self.stats
                                .events_received_total
                                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                            (
                                "node_execution_event",
                                match serde_json::to_value(node_event) {
                                    Ok(v) => v,
                                    Err(e) => {
                                        error!("❌ Failed to serialize node event: {}", e);
                                        continue;
                                    }
                                },
                            )
                        }
                    };

                    // Create the channel keys in the format expected by WebSocket clients
                    let mut channels_to_notify = Vec::new();

                    // For both event types, notify the specific execution channel
                    let exec_channel = format!("{user_id}|graph_exec#{graph_exec_id}");
                    channels_to_notify.push(exec_channel.clone());

                    // For graph execution events, also notify the graph executions channel
                    if matches!(&event, ExecutionEvent::GraphExecutionUpdate(_)) {
                        let graph_channel = format!("{user_id}|graph#{graph_id}|executions");
                        channels_to_notify.push(graph_channel);
                    }

                    debug!(
                        "📢 Broadcasting {} event to channels: {:?}",
                        method, channels_to_notify
                    );

                    let subs = self.subscribers.read().await;
|
||||
|
||||
// Log current subscriber state
|
||||
debug!("📊 Current subscribers count: {}", subs.len());
|
||||
|
||||
for channel_key in channels_to_notify {
|
||||
let ws_msg = WSMessage {
|
||||
method: method.to_string(),
|
||||
channel: Some(channel_key.clone()),
|
||||
data: Some(event_json.clone()),
|
||||
..Default::default()
|
||||
};
|
||||
let json_msg = match serde_json::to_string(&ws_msg) {
|
||||
Ok(j) => {
|
||||
debug!("📤 Sending WebSocket message: {}", j);
|
||||
j
|
||||
}
|
||||
Err(e) => {
|
||||
error!("❌ Failed to serialize WebSocket message: {}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(client_ids) = subs.get(&channel_key) {
|
||||
let clients = self.clients.read().await;
|
||||
let client_count = client_ids.len();
|
||||
debug!(
|
||||
"📣 Broadcasting to {} clients on channel: {}",
|
||||
client_count, channel_key
|
||||
);
|
||||
|
||||
for &cid in client_ids {
|
||||
if let Some((user_id, tx)) = clients.get(&cid) {
|
||||
match tx.try_send(json_msg.clone()) {
|
||||
Ok(_) => {
|
||||
debug!(
|
||||
"✅ Message sent immediately to client {} (user: {})",
|
||||
cid, user_id
|
||||
);
|
||||
self.stats
|
||||
.messages_sent_total
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
Err(mpsc::error::TrySendError::Full(_)) => {
|
||||
// Channel is full, try with a small timeout
|
||||
let tx_clone = tx.clone();
|
||||
let msg_clone = json_msg.clone();
|
||||
let stats_clone = self.stats.clone();
|
||||
tokio::spawn(async move {
|
||||
match tokio::time::timeout(
|
||||
std::time::Duration::from_millis(100),
|
||||
tx_clone.send(msg_clone),
|
||||
)
|
||||
.await {
|
||||
Ok(Ok(_)) => {
|
||||
stats_clone
|
||||
.messages_sent_total
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
_ => {
|
||||
stats_clone
|
||||
.messages_failed_total
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
});
|
||||
warn!("⚠️ Channel full for client {} (user: {}), sending async", cid, user_id);
|
||||
}
|
||||
Err(mpsc::error::TrySendError::Closed(_)) => {
|
||||
warn!(
|
||||
"⚠️ Channel closed for client {} (user: {})",
|
||||
cid, user_id
|
||||
);
|
||||
self.stats
|
||||
.messages_failed_total
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
warn!("⚠️ Client {} not found in clients map", cid);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
info!("📭 No subscribers for channel: {}", channel_key);
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
return Err("❌ Redis pubsub stream ended".into());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
autogpt_platform/autogpt-rs/websocket/src/handlers.rs (new file, 442 lines)
@@ -0,0 +1,442 @@
use axum::extract::ws::{CloseFrame, Message, WebSocket};
use axum::{
    extract::{Query, WebSocketUpgrade},
    http::HeaderMap,
    response::IntoResponse,
    Extension,
};
use jsonwebtoken::{decode, DecodingKey, Validation};
use serde_json::{json, Value};
use std::collections::HashMap;
use tokio::sync::mpsc;
use tracing::{debug, error, info, warn};

use crate::connection_manager::ConnectionManager;
use crate::models::{Claims, WSMessage};
use crate::AppState;

// Helper function to safely serialize messages
fn serialize_message(msg: &WSMessage) -> String {
    serde_json::to_string(msg).unwrap_or_else(|e| {
        error!("❌ Failed to serialize WebSocket message: {}", e);
        json!({"method": "error", "success": false, "error": "Internal serialization error"})
            .to_string()
    })
}

pub async fn ws_handler(
    ws: WebSocketUpgrade,
    query: Query<HashMap<String, String>>,
    _headers: HeaderMap,
    Extension(state): Extension<AppState>,
) -> impl IntoResponse {
    let token = query.0.get("token").cloned();
    let mut user_id = state.config.default_user_id.clone();
    let mut auth_error_code: Option<u16> = None;

    if state.config.enable_auth {
        match token {
            Some(token_str) => {
                debug!("🔐 Authenticating WebSocket connection");
                let mut validation = Validation::new(state.config.jwt_algorithm);
                validation.set_audience(&["authenticated"]);

                let key = DecodingKey::from_secret(state.config.jwt_secret.as_bytes());

                match decode::<Claims>(&token_str, &key, &validation) {
                    Ok(token_data) => {
                        user_id = token_data.claims.sub.clone();
                        debug!("✅ WebSocket authenticated for user: {}", user_id);
                    }
                    Err(e) => {
                        warn!("⚠️ JWT validation failed: {}", e);
                        auth_error_code = Some(4003);
                    }
                }
            }
            None => {
                warn!("⚠️ Missing authentication token in WebSocket connection");
                auth_error_code = Some(4001);
            }
        }
    } else {
        debug!("🔓 WebSocket connection without auth (auth disabled)");
    }

    if let Some(code) = auth_error_code {
        error!("❌ WebSocket authentication failed with code: {}", code);
        state
            .mgr
            .stats
            .connections_failed_auth
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        state
            .mgr
            .stats
            .connections_total
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        return ws
            .on_upgrade(move |mut socket: WebSocket| async move {
                let close_frame = Some(CloseFrame {
                    code,
                    reason: "Authentication failed".into(),
                });
                let _ = socket.send(Message::Close(close_frame)).await;
                let _ = socket.close().await;
            })
            .into_response();
    }

    debug!("✅ WebSocket connection established for user: {}", user_id);
    ws.on_upgrade(move |socket| {
        handle_socket(
            socket,
            user_id,
            state.mgr.clone(),
            state.config.max_message_size_limit,
        )
    })
}

async fn update_subscription_stats(mgr: &ConnectionManager, channel: &str, add: bool) {
    if add {
        mgr.stats
            .subscriptions_total
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        mgr.stats
            .subscriptions_active
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);

        let mut channel_stats = mgr.stats.channels_active.write().await;
        let count = channel_stats.entry(channel.to_string()).or_insert(0);
        *count += 1;
    } else {
        mgr.stats
            .unsubscriptions_total
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        mgr.stats
            .subscriptions_active
            .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);

        let mut channel_stats = mgr.stats.channels_active.write().await;
        if let Some(count) = channel_stats.get_mut(channel) {
            *count = count.saturating_sub(1);
            if *count == 0 {
                channel_stats.remove(channel);
            }
        }
    }
}

pub async fn handle_socket(
    mut socket: WebSocket,
    user_id: String,
    mgr: std::sync::Arc<ConnectionManager>,
    max_size: usize,
) {
    let client_id = mgr
        .next_id
        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    let (tx, mut rx) = mpsc::channel::<String>(10);
    info!("👋 New WebSocket client {} for user: {}", client_id, user_id);

    // Update connection stats
    mgr.stats
        .connections_total
        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    mgr.stats
        .connections_active
        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);

    // Update active users
    {
        let mut active_users = mgr.stats.active_users.write().await;
        let count = active_users.entry(user_id.clone()).or_insert(0);
        *count += 1;
    }

    {
        let mut clients = mgr.clients.write().await;
        clients.insert(client_id, (user_id.clone(), tx));
    }

    {
        let mut client_channels = mgr.client_channels.write().await;
        client_channels.insert(client_id, std::collections::HashSet::new());
    }

    loop {
        tokio::select! {
            msg = rx.recv() => {
                if let Some(msg) = msg {
                    if socket.send(Message::Text(msg)).await.is_err() {
                        break;
                    }
                } else {
                    break;
                }
            }
            incoming = socket.recv() => {
                let msg = match incoming {
                    Some(Ok(msg)) => msg,
                    _ => break,
                };
                match msg {
                    Message::Text(text) => {
                        if text.len() > max_size {
                            warn!("⚠️ Message from client {} exceeds size limit: {} > {}", client_id, text.len(), max_size);
                            let err_resp = serialize_message(&WSMessage {
                                method: "error".to_string(),
                                success: Some(false),
                                error: Some("Message exceeds size limit".to_string()),
                                ..Default::default()
                            });
                            if socket.send(Message::Text(err_resp)).await.is_err() {
                                break;
                            }
                            continue;
                        }

                        mgr.stats.messages_received_total.fetch_add(1, std::sync::atomic::Ordering::Relaxed);

                        let ws_msg: WSMessage = match serde_json::from_str(&text) {
                            Ok(m) => m,
                            Err(e) => {
                                warn!("⚠️ Invalid message format from client {}: {}", client_id, e);
                                mgr.stats.errors_json_parse.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                                mgr.stats.errors_total.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                                let err_resp = serialize_message(&WSMessage {
                                    method: "error".to_string(),
                                    success: Some(false),
                                    error: Some("Invalid message format. Review the schema and retry".to_string()),
                                    ..Default::default()
                                });
                                if socket.send(Message::Text(err_resp)).await.is_err() {
                                    break;
                                }
                                continue;
                            }
                        };

                        debug!("📥 Received {} message from client {}", ws_msg.method, client_id);

                        match ws_msg.method.as_str() {
                            "subscribe_graph_execution" => {
                                let graph_exec_id = match &ws_msg.data {
                                    Some(Value::Object(map)) => map.get("graph_exec_id").and_then(|v| v.as_str()),
                                    _ => None,
                                };
                                let Some(graph_exec_id) = graph_exec_id else {
                                    warn!("⚠️ Missing graph_exec_id in subscribe_graph_execution from client {}", client_id);
                                    let err_resp = json!({"method": "error", "success": false, "error": "Missing graph_exec_id"});
                                    if socket.send(Message::Text(err_resp.to_string())).await.is_err() {
                                        break;
                                    }
                                    continue;
                                };
                                let channel = format!("{user_id}|graph_exec#{graph_exec_id}");
                                debug!("📌 Client {} subscribing to channel: {}", client_id, channel);

                                {
                                    let mut subs = mgr.subscribers.write().await;
                                    subs.entry(channel.clone()).or_insert(std::collections::HashSet::new()).insert(client_id);
                                }
                                {
                                    let mut chs = mgr.client_channels.write().await;
                                    if let Some(set) = chs.get_mut(&client_id) {
                                        set.insert(channel.clone());
                                    }
                                }

                                // Update subscription stats
                                update_subscription_stats(&mgr, &channel, true).await;

                                let resp = WSMessage {
                                    method: "subscribe_graph_execution".to_string(),
                                    success: Some(true),
                                    channel: Some(channel),
                                    ..Default::default()
                                };
                                if socket.send(Message::Text(serialize_message(&resp))).await.is_err() {
                                    break;
                                }
                            }
                            "subscribe_graph_executions" => {
                                let graph_id = match &ws_msg.data {
                                    Some(Value::Object(map)) => map.get("graph_id").and_then(|v| v.as_str()),
                                    _ => None,
                                };
                                let Some(graph_id) = graph_id else {
                                    let err_resp = json!({"method": "error", "success": false, "error": "Missing graph_id"});
                                    if socket.send(Message::Text(err_resp.to_string())).await.is_err() {
                                        break;
                                    }
                                    continue;
                                };
                                let channel = format!("{user_id}|graph#{graph_id}|executions");

                                {
                                    let mut subs = mgr.subscribers.write().await;
                                    subs.entry(channel.clone()).or_insert(std::collections::HashSet::new()).insert(client_id);
                                }
                                {
                                    let mut chs = mgr.client_channels.write().await;
                                    if let Some(set) = chs.get_mut(&client_id) {
                                        set.insert(channel.clone());
                                    }
                                }
                                debug!("📌 Client {} subscribing to channel: {}", client_id, channel);
                                // Update subscription stats
                                update_subscription_stats(&mgr, &channel, true).await;

                                let resp = WSMessage {
                                    method: "subscribe_graph_executions".to_string(),
                                    success: Some(true),
                                    channel: Some(channel),
                                    ..Default::default()
                                };
                                if socket.send(Message::Text(serialize_message(&resp))).await.is_err() {
                                    break;
                                }
                            }
                            "unsubscribe" => {
                                let channel = match &ws_msg.data {
                                    Some(Value::String(s)) => Some(s.as_str()),
                                    Some(Value::Object(map)) => map.get("channel").and_then(|v| v.as_str()),
                                    _ => None,
                                };
                                let Some(channel) = channel else {
                                    let err_resp = json!({"method": "error", "success": false, "error": "Missing channel"});
                                    if socket.send(Message::Text(err_resp.to_string())).await.is_err() {
                                        break;
                                    }
                                    continue;
                                };
                                let channel = channel.to_string();

                                if !channel.starts_with(&format!("{user_id}|")) {
                                    let err_resp = json!({"method": "error", "success": false, "error": "Unauthorized channel"});
                                    if socket.send(Message::Text(err_resp.to_string())).await.is_err() {
                                        break;
                                    }
                                    continue;
                                }

                                {
                                    let mut subs = mgr.subscribers.write().await;
                                    if let Some(set) = subs.get_mut(&channel) {
                                        set.remove(&client_id);
                                        if set.is_empty() {
                                            subs.remove(&channel);
                                        }
                                    }
                                }
                                {
                                    let mut chs = mgr.client_channels.write().await;
                                    if let Some(set) = chs.get_mut(&client_id) {
                                        set.remove(&channel);
                                    }
                                }

                                // Update subscription stats
                                update_subscription_stats(&mgr, &channel, false).await;

                                let resp = WSMessage {
                                    method: "unsubscribe".to_string(),
                                    success: Some(true),
                                    channel: Some(channel),
                                    ..Default::default()
                                };
                                if socket.send(Message::Text(serialize_message(&resp))).await.is_err() {
                                    break;
                                }
                            }
                            "heartbeat" => {
                                if ws_msg.data == Some(Value::String("ping".to_string())) {
                                    let resp = WSMessage {
                                        method: "heartbeat".to_string(),
                                        data: Some(Value::String("pong".to_string())),
                                        success: Some(true),
                                        ..Default::default()
                                    };
                                    if socket.send(Message::Text(serialize_message(&resp))).await.is_err() {
                                        break;
                                    }
                                } else {
                                    let err_resp = json!({"method": "error", "success": false, "error": "Invalid heartbeat"});
                                    if socket.send(Message::Text(err_resp.to_string())).await.is_err() {
                                        break;
                                    }
                                }
                            }
                            _ => {
                                warn!("❓ Unknown method '{}' from client {}", ws_msg.method, client_id);
                                let err_resp = json!({"method": "error", "success": false, "error": "Unknown method"});
                                if socket.send(Message::Text(err_resp.to_string())).await.is_err() {
                                    break;
                                }
                            }
                        }
                    }
                    Message::Close(_) => break,
                    Message::Ping(_) => {
                        if socket.send(Message::Pong(vec![])).await.is_err() {
                            break;
                        }
                    }
                    Message::Pong(_) => {}
                    _ => {}
                }
            }
            else => break,
        }
    }

    // Cleanup
    debug!("👋 WebSocket client {} disconnected, cleaning up", client_id);

    // Update connection stats
    mgr.stats
        .connections_active
        .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);

    // Update active users
    {
        let mut active_users = mgr.stats.active_users.write().await;
        if let Some(count) = active_users.get_mut(&user_id) {
            *count = count.saturating_sub(1);
            if *count == 0 {
                active_users.remove(&user_id);
            }
        }
    }

    let channels = {
        let mut client_channels = mgr.client_channels.write().await;
        client_channels.remove(&client_id).unwrap_or_default()
    };

    {
        let mut subs = mgr.subscribers.write().await;
        for channel in &channels {
            if let Some(set) = subs.get_mut(channel) {
                set.remove(&client_id);
                if set.is_empty() {
                    subs.remove(channel);
                }
            }
        }
    }

    // Update subscription stats for all channels the client was subscribed to
    for channel in &channels {
        update_subscription_stats(&mgr, channel, false).await;
    }

    {
        let mut clients = mgr.clients.write().await;
        clients.remove(&client_id);
    }

    debug!("✨ Cleanup completed for client {}", client_id);
}
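Taken together, the handler above defines a small JSON protocol over the `/ws` endpoint: each client frame is a `WSMessage` with a `method` and optional `data`, and the server replies with `success`/`channel`/`error` fields. A sketch of the client-side messages, reconstructed from the match arms above (the ids `abc-123`, `graph-1`, and `user-1` are placeholders, not values from the source):

```json
{"method": "subscribe_graph_execution", "data": {"graph_exec_id": "abc-123"}}
{"method": "subscribe_graph_executions", "data": {"graph_id": "graph-1"}}
{"method": "unsubscribe", "data": {"channel": "user-1|graph_exec#abc-123"}}
{"method": "heartbeat", "data": "ping"}
```

A successful subscribe would then be acknowledged with something like `{"method": "subscribe_graph_execution", "success": true, "channel": "user-1|graph_exec#abc-123"}`, since `serialize_message` skips `None` fields.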
autogpt_platform/autogpt-rs/websocket/src/lib.rs (new file, 26 lines)
@@ -0,0 +1,26 @@
#![deny(warnings)]
#![deny(clippy::unwrap_used)]
#![deny(clippy::panic)]
#![deny(clippy::unimplemented)]
#![deny(clippy::todo)]

pub mod config;
pub mod connection_manager;
pub mod handlers;
pub mod models;
pub mod stats;

pub use config::Config;
pub use connection_manager::ConnectionManager;
pub use handlers::ws_handler;
pub use stats::Stats;

use std::sync::Arc;

#[derive(Clone)]
pub struct AppState {
    pub mgr: Arc<ConnectionManager>,
    pub config: Arc<Config>,
    pub stats: Arc<Stats>,
}
autogpt_platform/autogpt-rs/websocket/src/main.rs (new file, 172 lines)
@@ -0,0 +1,172 @@
use axum::{
    body::Body,
    http::{header, StatusCode},
    response::Response,
    routing::get,
    Router,
};
use clap::Parser;
use std::sync::Arc;
use tokio::net::TcpListener;
use tower_http::cors::{Any, CorsLayer};
use tracing::{debug, error, info};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

use crate::config::Config;
use crate::connection_manager::ConnectionManager;
use crate::handlers::ws_handler;

async fn stats_handler(
    axum::Extension(state): axum::Extension<AppState>,
) -> Result<axum::response::Json<stats::StatsSnapshot>, StatusCode> {
    let snapshot = state.stats.snapshot().await;
    Ok(axum::response::Json(snapshot))
}

async fn prometheus_handler(
    axum::Extension(state): axum::Extension<AppState>,
) -> Result<Response, StatusCode> {
    let snapshot = state.stats.snapshot().await;
    let prometheus_text = state.stats.to_prometheus_format(&snapshot);

    Response::builder()
        .status(StatusCode::OK)
        .header(header::CONTENT_TYPE, "text/plain; version=0.0.4")
        .body(Body::from(prometheus_text))
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}

mod config;
mod connection_manager;
mod handlers;
mod models;
mod stats;

#[derive(Parser, Debug)]
#[command(author, version, about)]
struct Cli {
    /// Path to a TOML configuration file
    #[arg(short = 'c', long = "config", value_name = "FILE")]
    config: Option<std::path::PathBuf>,
}

#[derive(Clone)]
pub struct AppState {
    mgr: Arc<ConnectionManager>,
    config: Arc<Config>,
    stats: Arc<stats::Stats>,
}

#[tokio::main]
async fn main() {
    // Initialize tracing
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| "websocket=info,tower_http=debug".into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    info!("🚀 Starting WebSocket API server");

    let cli = Cli::parse();
    let config = Arc::new(Config::load(cli.config.as_deref()));
    info!(
        "⚙️ Configuration loaded - host: {}, port: {}, auth: {}",
        config.host, config.port, config.enable_auth
    );

    let redis_client = match redis::Client::open(config.redis_url.clone()) {
        Ok(client) => {
            debug!("✅ Redis client created successfully");
            client
        }
        Err(e) => {
            error!(
                "❌ Failed to create Redis client: {}. Please check REDIS_URL environment variable",
                e
            );
            std::process::exit(1);
        }
    };

    let stats = Arc::new(stats::Stats::default());
    let mgr = Arc::new(ConnectionManager::new(
        redis_client,
        config.execution_event_bus_name.clone(),
        stats.clone(),
    ));

    let mgr_clone = mgr.clone();
    tokio::spawn(async move {
        debug!("📡 Starting event broadcaster task");
        mgr_clone.run_broadcaster().await;
    });

    let state = AppState {
        mgr,
        config: config.clone(),
        stats,
    };

    let app = Router::new()
        .route("/ws", get(ws_handler))
        .route("/stats", get(stats_handler))
        .route("/metrics", get(prometheus_handler))
        .layer(axum::Extension(state));

    let cors = if config.backend_cors_allow_origins.is_empty() {
        // If no specific origins configured, allow any origin but without credentials
        CorsLayer::new()
            .allow_methods(Any)
            .allow_headers(Any)
            .allow_origin(Any)
    } else {
        // If specific origins configured, allow credentials
        CorsLayer::new()
            .allow_methods([
                axum::http::Method::GET,
                axum::http::Method::POST,
                axum::http::Method::PUT,
                axum::http::Method::DELETE,
                axum::http::Method::OPTIONS,
            ])
            .allow_headers(vec![
                axum::http::header::CONTENT_TYPE,
                axum::http::header::AUTHORIZATION,
            ])
            .allow_credentials(true)
            .allow_origin(
                config
                    .backend_cors_allow_origins
                    .iter()
                    .filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
                    .collect::<Vec<_>>(),
            )
    };

    let app = app.layer(cors);

    let addr = format!("{}:{}", config.host, config.port);
    let listener = match TcpListener::bind(&addr).await {
        Ok(listener) => {
            info!("🎧 WebSocket server listening on: {}", addr);
            listener
        }
        Err(e) => {
            error!(
                "❌ Failed to bind to {}: {}. Please check if the port is already in use",
                addr, e
            );
            std::process::exit(1);
        }
    };

    info!("✨ WebSocket API server ready to accept connections");

    if let Err(e) = axum::serve(listener, app.into_make_service()).await {
        error!("💥 Server error: {}", e);
        std::process::exit(1);
    }
}
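`main` loads its settings via `Config::load`, which reads an optional TOML file passed with `-c`/`--config`. The config module itself is not part of this diff, so the exact key names are an assumption; assuming the TOML keys mirror the field accesses seen above (`config.host`, `config.port`, etc.), a file might look like:

```toml
# Hypothetical sketch - key names inferred from Config field accesses in
# main.rs and handlers.rs; the authoritative schema lives in config.rs.
host = "0.0.0.0"
port = 8001                              # placeholder port
enable_auth = true
jwt_secret = "change-me"
default_user_id = "default"
max_message_size_limit = 65536
redis_url = "redis://localhost:6379"
execution_event_bus_name = "execution_event"
backend_cors_allow_origins = ["http://localhost:3000"]
```

Leaving `backend_cors_allow_origins` empty switches the CORS layer to the allow-any-origin, no-credentials branch shown above.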
autogpt_platform/autogpt-rs/websocket/src/models.rs (new file, 103 lines)
@@ -0,0 +1,103 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;

#[derive(Default, Clone, Debug, Serialize, Deserialize)]
pub struct WSMessage {
    pub method: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub data: Option<Value>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub success: Option<bool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub channel: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}

#[derive(Deserialize)]
pub struct Claims {
    pub sub: String,
}

// Event models moved from events.rs

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "event_type")]
pub enum ExecutionEvent {
    #[serde(rename = "graph_execution_update")]
    GraphExecutionUpdate(GraphExecutionEvent),
    #[serde(rename = "node_execution_update")]
    NodeExecutionUpdate(NodeExecutionEvent),
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphExecutionEvent {
    pub id: String,
    pub graph_id: String,
    pub graph_version: u32,
    pub user_id: String,
    pub status: ExecutionStatus,
    pub started_at: Option<String>,
    pub ended_at: Option<String>,
    pub preset_id: Option<String>,
    pub stats: Option<ExecutionStats>,

    // Keep these as JSON since they vary by graph
    pub inputs: Value,
    pub outputs: Value,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeExecutionEvent {
    pub node_exec_id: String,
    pub node_id: String,
    pub graph_exec_id: String,
    pub graph_id: String,
    pub graph_version: u32,
    pub user_id: String,
    pub block_id: String,
    pub status: ExecutionStatus,
    pub add_time: String,
    pub queue_time: Option<String>,
    pub start_time: Option<String>,
    pub end_time: Option<String>,

    // Keep these as JSON since they vary by node type
    pub input_data: Value,
    pub output_data: Value,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionStats {
    pub cost: f64,
    pub duration: f64,
    pub duration_cpu_only: f64,
    pub error: Option<String>,
    pub node_error_count: u32,
    pub node_exec_count: u32,
    pub node_exec_time: f64,
    pub node_exec_time_cpu_only: f64,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum ExecutionStatus {
    Queued,
    Running,
    Completed,
    Failed,
    Incomplete,
    Terminated,
}

// Wrapper for the Redis event that includes the payload
#[derive(Debug, Deserialize)]
pub struct RedisEventWrapper {
    pub payload: ExecutionEvent,
}

impl RedisEventWrapper {
    pub fn parse(json_str: &str) -> Result<Self, serde_json::Error> {
        serde_json::from_str(json_str)
    }
}
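These serde attributes fix the wire format that `RedisEventWrapper::parse` expects on the `execution_event/{user_id}/{graph_id}/{graph_exec_id}` pub/sub channel: an outer object with a `payload` key, a `event_type` tag for the enum variant, and `SCREAMING_SNAKE_CASE` status values. A sketch of one such message (all ids and timestamps are placeholder values):

```json
{
  "payload": {
    "event_type": "graph_execution_update",
    "id": "exec-123",
    "graph_id": "graph-1",
    "graph_version": 1,
    "user_id": "user-1",
    "status": "RUNNING",
    "started_at": "2024-01-01T00:00:00Z",
    "ended_at": null,
    "preset_id": null,
    "stats": null,
    "inputs": {},
    "outputs": {}
  }
}
```

Because the variant payloads are untagged structs, any extra fields the publisher adds alongside `payload` are ignored by this deserializer.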
autogpt_platform/autogpt-rs/websocket/src/stats.rs (new file, 238 lines)
@@ -0,0 +1,238 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::RwLock;

#[derive(Default)]
pub struct Stats {
    // Connection metrics
    pub connections_total: AtomicU64,
    pub connections_active: AtomicU64,
    pub connections_failed_auth: AtomicU64,

    // Message metrics
    pub messages_received_total: AtomicU64,
    pub messages_sent_total: AtomicU64,
    pub messages_failed_total: AtomicU64,

    // Subscription metrics
    pub subscriptions_total: AtomicU64,
    pub subscriptions_active: AtomicU64,
    pub unsubscriptions_total: AtomicU64,

    // Event metrics by type
    pub events_received_total: AtomicU64,
    pub graph_execution_events: AtomicU64,
    pub node_execution_events: AtomicU64,

    // Redis metrics
    pub redis_messages_received: AtomicU64,
    pub redis_messages_ignored: AtomicU64,

    // Channel metrics
    pub channels_active: RwLock<HashMap<String, usize>>, // channel -> subscriber count

    // User metrics
    pub active_users: RwLock<HashMap<String, usize>>, // user_id -> connection count

    // Error metrics
    pub errors_total: AtomicU64,
    pub errors_json_parse: AtomicU64,
    pub errors_message_size: AtomicU64,
}

#[derive(Serialize, Deserialize)]
pub struct StatsSnapshot {
    // Connection metrics
    pub connections_total: u64,
    pub connections_active: u64,
    pub connections_failed_auth: u64,

    // Message metrics
    pub messages_received_total: u64,
    pub messages_sent_total: u64,
    pub messages_failed_total: u64,

    // Subscription metrics
    pub subscriptions_total: u64,
    pub subscriptions_active: u64,
    pub unsubscriptions_total: u64,

    // Event metrics
    pub events_received_total: u64,
    pub graph_execution_events: u64,
    pub node_execution_events: u64,

    // Redis metrics
    pub redis_messages_received: u64,
    pub redis_messages_ignored: u64,

    // Channel metrics
    pub channels_active_count: usize,
    pub total_subscribers: usize,

    // User metrics
    pub active_users_count: usize,

    // Error metrics
    pub errors_total: u64,
    pub errors_json_parse: u64,
    pub errors_message_size: u64,
}

impl Stats {
    pub async fn snapshot(&self) -> StatsSnapshot {
        // Take read locks for HashMap data - it's ok if this is slightly stale
        let channels = self.channels_active.read().await;
        let total_subscribers: usize = channels.values().sum();
        let channels_active_count = channels.len();
        drop(channels); // Release lock early

        let users = self.active_users.read().await;
        let active_users_count = users.len();
        drop(users); // Release lock early

        StatsSnapshot {
            connections_total: self.connections_total.load(Ordering::Relaxed),
            connections_active: self.connections_active.load(Ordering::Relaxed),
            connections_failed_auth: self.connections_failed_auth.load(Ordering::Relaxed),

            messages_received_total: self.messages_received_total.load(Ordering::Relaxed),
            messages_sent_total: self.messages_sent_total.load(Ordering::Relaxed),
            messages_failed_total: self.messages_failed_total.load(Ordering::Relaxed),

            subscriptions_total: self.subscriptions_total.load(Ordering::Relaxed),
            subscriptions_active: self.subscriptions_active.load(Ordering::Relaxed),
            unsubscriptions_total: self.unsubscriptions_total.load(Ordering::Relaxed),

            events_received_total: self.events_received_total.load(Ordering::Relaxed),
            graph_execution_events: self.graph_execution_events.load(Ordering::Relaxed),
            node_execution_events: self.node_execution_events.load(Ordering::Relaxed),

            redis_messages_received: self.redis_messages_received.load(Ordering::Relaxed),
            redis_messages_ignored: self.redis_messages_ignored.load(Ordering::Relaxed),

            channels_active_count,
            total_subscribers,
            active_users_count,

            errors_total: self.errors_total.load(Ordering::Relaxed),
            errors_json_parse: self.errors_json_parse.load(Ordering::Relaxed),
            errors_message_size: self.errors_message_size.load(Ordering::Relaxed),
        }
    }

    pub fn to_prometheus_format(&self, snapshot: &StatsSnapshot) -> String {
        let mut output = String::new();

        // Connection metrics
        output.push_str("# HELP ws_connections_total Total number of WebSocket connections\n");
        output.push_str("# TYPE ws_connections_total counter\n");
        output.push_str(&format!(
            "ws_connections_total {}\n\n",
            snapshot.connections_total
        ));

        output.push_str(
            "# HELP ws_connections_active Current number of active WebSocket connections\n",
|
||||
);
|
||||
output.push_str("# TYPE ws_connections_active gauge\n");
|
||||
output.push_str(&format!(
|
||||
"ws_connections_active {}\n\n",
|
||||
snapshot.connections_active
|
||||
));
|
||||
|
||||
output
|
||||
.push_str("# HELP ws_connections_failed_auth Total number of failed authentications\n");
|
||||
output.push_str("# TYPE ws_connections_failed_auth counter\n");
|
||||
output.push_str(&format!(
|
||||
"ws_connections_failed_auth {}\n\n",
|
||||
snapshot.connections_failed_auth
|
||||
));
|
||||
|
||||
// Message metrics
|
||||
output.push_str(
|
||||
"# HELP ws_messages_received_total Total number of messages received from clients\n",
|
||||
);
|
||||
output.push_str("# TYPE ws_messages_received_total counter\n");
|
||||
output.push_str(&format!(
|
||||
"ws_messages_received_total {}\n\n",
|
||||
snapshot.messages_received_total
|
||||
));
|
||||
|
||||
output.push_str("# HELP ws_messages_sent_total Total number of messages sent to clients\n");
|
||||
output.push_str("# TYPE ws_messages_sent_total counter\n");
|
||||
output.push_str(&format!(
|
||||
"ws_messages_sent_total {}\n\n",
|
||||
snapshot.messages_sent_total
|
||||
));
|
||||
|
||||
// Subscription metrics
|
||||
output.push_str("# HELP ws_subscriptions_active Current number of active subscriptions\n");
|
||||
output.push_str("# TYPE ws_subscriptions_active gauge\n");
|
||||
output.push_str(&format!(
|
||||
"ws_subscriptions_active {}\n\n",
|
||||
snapshot.subscriptions_active
|
||||
));
|
||||
|
||||
// Event metrics
|
||||
output.push_str(
|
||||
"# HELP ws_events_received_total Total number of events received from Redis\n",
|
||||
);
|
||||
output.push_str("# TYPE ws_events_received_total counter\n");
|
||||
output.push_str(&format!(
|
||||
"ws_events_received_total {}\n\n",
|
||||
snapshot.events_received_total
|
||||
));
|
||||
|
||||
output.push_str(
|
||||
"# HELP ws_graph_execution_events_total Total number of graph execution events\n",
|
||||
);
|
||||
output.push_str("# TYPE ws_graph_execution_events_total counter\n");
|
||||
output.push_str(&format!(
|
||||
"ws_graph_execution_events_total {}\n\n",
|
||||
snapshot.graph_execution_events
|
||||
));
|
||||
|
||||
output.push_str(
|
||||
"# HELP ws_node_execution_events_total Total number of node execution events\n",
|
||||
);
|
||||
output.push_str("# TYPE ws_node_execution_events_total counter\n");
|
||||
output.push_str(&format!(
|
||||
"ws_node_execution_events_total {}\n\n",
|
||||
snapshot.node_execution_events
|
||||
));
|
||||
|
||||
// Channel metrics
|
||||
output.push_str("# HELP ws_channels_active Number of active channels\n");
|
||||
output.push_str("# TYPE ws_channels_active gauge\n");
|
||||
output.push_str(&format!(
|
||||
"ws_channels_active {}\n\n",
|
||||
snapshot.channels_active_count
|
||||
));
|
||||
|
||||
output.push_str(
|
||||
"# HELP ws_total_subscribers Total number of subscribers across all channels\n",
|
||||
);
|
||||
output.push_str("# TYPE ws_total_subscribers gauge\n");
|
||||
output.push_str(&format!(
|
||||
"ws_total_subscribers {}\n\n",
|
||||
snapshot.total_subscribers
|
||||
));
|
||||
|
||||
// User metrics
|
||||
output.push_str("# HELP ws_active_users Number of unique users with active connections\n");
|
||||
output.push_str("# TYPE ws_active_users gauge\n");
|
||||
output.push_str(&format!(
|
||||
"ws_active_users {}\n\n",
|
||||
snapshot.active_users_count
|
||||
));
|
||||
|
||||
// Error metrics
|
||||
output.push_str("# HELP ws_errors_total Total number of errors\n");
|
||||
output.push_str("# TYPE ws_errors_total counter\n");
|
||||
output.push_str(&format!("ws_errors_total {}\n", snapshot.errors_total));
|
||||
|
||||
output
|
||||
}
|
||||
}
|
||||
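The `to_prometheus_format` method above writes the Prometheus text exposition format by hand: for each metric, a `# HELP` line, a `# TYPE` line, and a sample line. A minimal Python sketch of the same layout, reusing a metric name from the snippet (the `render_counter` helper is hypothetical, for illustration only):

```python
def render_counter(name: str, help_text: str, value: int) -> str:
    # Prometheus text exposition: a HELP line, a TYPE line, then the sample,
    # followed by a blank separator line (matching the Rust code's "\n\n").
    return (
        f"# HELP {name} {help_text}\n"
        f"# TYPE {name} counter\n"
        f"{name} {value}\n\n"
    )


metrics = render_counter(
    "ws_connections_total", "Total number of WebSocket connections", 42
)
print(metrics)
```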
@@ -7,5 +7,9 @@ class Settings:
        self.ENABLE_AUTH: bool = os.getenv("ENABLE_AUTH", "false").lower() == "true"
        self.JWT_ALGORITHM: str = "HS256"

    @property
    def is_configured(self) -> bool:
        return bool(self.JWT_SECRET_KEY)


settings = Settings()

@@ -10,8 +10,8 @@ from starlette.status import HTTP_401_UNAUTHORIZED
from .config import settings
from .jwt_utils import parse_jwt_token

security = HTTPBearer()
logger = logging.getLogger(__name__)
bearer_auth = HTTPBearer(auto_error=False)


async def auth_middleware(request: Request):
@@ -20,10 +20,11 @@ async def auth_middleware(request: Request):
        logger.warning("Auth disabled")
        return {}

    credentials = await bearer_auth(request)
    security = HTTPBearer()
    credentials = await security(request)

    if not credentials:
        raise HTTPException(status_code=401, detail="Not authenticated")
        raise HTTPException(status_code=401, detail="Authorization header is missing")

    try:
        payload = parse_jwt_token(credentials.credentials)

@@ -0,0 +1,166 @@
import asyncio
import contextlib
import logging
from functools import wraps
from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar, Union, cast

import ldclient
from fastapi import HTTPException
from ldclient import Context, LDClient
from ldclient.config import Config
from typing_extensions import ParamSpec

from .config import SETTINGS

logger = logging.getLogger(__name__)

P = ParamSpec("P")
T = TypeVar("T")


def get_client() -> LDClient:
    """Get the LaunchDarkly client singleton."""
    return ldclient.get()


def initialize_launchdarkly() -> None:
    sdk_key = SETTINGS.launch_darkly_sdk_key
    logger.debug(
        f"Initializing LaunchDarkly with SDK key: {'present' if sdk_key else 'missing'}"
    )

    if not sdk_key:
        logger.warning("LaunchDarkly SDK key not configured")
        return

    config = Config(sdk_key)
    ldclient.set_config(config)

    if ldclient.get().is_initialized():
        logger.info("LaunchDarkly client initialized successfully")
    else:
        logger.error("LaunchDarkly client failed to initialize")


def shutdown_launchdarkly() -> None:
    """Shutdown the LaunchDarkly client."""
    if ldclient.get().is_initialized():
        ldclient.get().close()
        logger.info("LaunchDarkly client closed successfully")


def create_context(
    user_id: str, additional_attributes: Optional[Dict[str, Any]] = None
) -> Context:
    """Create LaunchDarkly context with optional additional attributes."""
    builder = Context.builder(str(user_id)).kind("user")
    if additional_attributes:
        for key, value in additional_attributes.items():
            builder.set(key, value)
    return builder.build()


def feature_flag(
    flag_key: str,
    default: bool = False,
) -> Callable[
    [Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
]:
    """
    Decorator for feature flag protected endpoints.
    """

    def decorator(
        func: Callable[P, Union[T, Awaitable[T]]],
    ) -> Callable[P, Union[T, Awaitable[T]]]:
        @wraps(func)
        async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            try:
                user_id = kwargs.get("user_id")
                if not user_id:
                    raise ValueError("user_id is required")

                if not get_client().is_initialized():
                    logger.warning(
                        f"LaunchDarkly not initialized, using default={default}"
                    )
                    is_enabled = default
                else:
                    context = create_context(str(user_id))
                    is_enabled = get_client().variation(flag_key, context, default)

                if not is_enabled:
                    raise HTTPException(status_code=404, detail="Feature not available")

                result = func(*args, **kwargs)
                if asyncio.iscoroutine(result):
                    return await result
                return cast(T, result)
            except Exception as e:
                logger.error(f"Error evaluating feature flag {flag_key}: {e}")
                raise

        @wraps(func)
        def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            try:
                user_id = kwargs.get("user_id")
                if not user_id:
                    raise ValueError("user_id is required")

                if not get_client().is_initialized():
                    logger.warning(
                        f"LaunchDarkly not initialized, using default={default}"
                    )
                    is_enabled = default
                else:
                    context = create_context(str(user_id))
                    is_enabled = get_client().variation(flag_key, context, default)

                if not is_enabled:
                    raise HTTPException(status_code=404, detail="Feature not available")

                return cast(T, func(*args, **kwargs))
            except Exception as e:
                logger.error(f"Error evaluating feature flag {flag_key}: {e}")
                raise

        return cast(
            Callable[P, Union[T, Awaitable[T]]],
            async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper,
        )

    return decorator


def percentage_rollout(
    flag_key: str,
    default: bool = False,
) -> Callable[
    [Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
]:
    """Decorator for percentage-based rollouts."""
    return feature_flag(flag_key, default)


def beta_feature(
    flag_key: Optional[str] = None,
    unauthorized_response: Any = {"message": "Not available in beta"},
) -> Callable[
    [Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
]:
    """Decorator for beta features."""
    actual_key = f"beta-{flag_key}" if flag_key else "beta"
    return feature_flag(actual_key, False)


@contextlib.contextmanager
def mock_flag_variation(flag_key: str, return_value: Any):
    """Context manager for testing feature flags."""
    original_variation = get_client().variation
    get_client().variation = lambda key, context, default: (
        return_value if key == flag_key else original_variation(key, context, default)
    )
    try:
        yield
    finally:
        get_client().variation = original_variation
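The `mock_flag_variation` helper above works by swapping the client's `variation` method for a lambda that overrides a single flag, then restoring the original on exit. The same swap-and-restore pattern, as a self-contained sketch (the `FakeClient` stand-in is hypothetical, not part of the module):

```python
import contextlib
from typing import Any


class FakeClient:
    # Stand-in for an LDClient: variation() normally just returns the default.
    def variation(self, key: str, context: Any, default: bool) -> bool:
        return default


client = FakeClient()


@contextlib.contextmanager
def mock_flag_variation(flag_key: str, return_value: Any):
    # Swap in a lambda that overrides one flag; other flags fall through
    # to the original method. Restore unconditionally in the finally block.
    original = client.variation
    client.variation = lambda key, context, default: (
        return_value if key == flag_key else original(key, context, default)
    )
    try:
        yield
    finally:
        client.variation = original
```

Inside the context, `client.variation("my-flag", None, False)` returns the mocked value; once the context exits, the original method is back even if the body raised.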
@@ -0,0 +1,45 @@
import pytest
from fastapi import HTTPException
from ldclient import LDClient

from autogpt_libs.feature_flag.client import feature_flag, mock_flag_variation


@pytest.fixture
def ld_client(mocker):
    client = mocker.Mock(spec=LDClient)
    mocker.patch("ldclient.get", return_value=client)
    client.is_initialized.return_value = True
    return client


@pytest.mark.asyncio
async def test_feature_flag_enabled(ld_client):
    ld_client.variation.return_value = True

    @feature_flag("test-flag")
    async def test_function(user_id: str):
        return "success"

    result = await test_function(user_id="test-user")
    assert result == "success"
    ld_client.variation.assert_called_once()


@pytest.mark.asyncio
async def test_feature_flag_unauthorized_response(ld_client):
    ld_client.variation.return_value = False

    @feature_flag("test-flag")
    async def test_function(user_id: str):
        return "success"

    # The decorator raises HTTPException(404) when the flag is disabled
    with pytest.raises(HTTPException):
        await test_function(user_id="test-user")


def test_mock_flag_variation(ld_client):
    with mock_flag_variation("test-flag", True):
        assert ld_client.variation("test-flag", None, False)

    with mock_flag_variation("test-flag", False):
        assert not ld_client.variation("test-flag", None, False)
@@ -0,0 +1,15 @@
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    launch_darkly_sdk_key: str = Field(
        default="",
        description="The Launch Darkly SDK key",
        validation_alias="LAUNCH_DARKLY_SDK_KEY",
    )

    model_config = SettingsConfigDict(case_sensitive=True, extra="ignore")


SETTINGS = Settings()
@@ -1,8 +1,6 @@
"""Logging module for Auto-GPT."""

import logging
import os
import socket
import sys
from pathlib import Path

@@ -12,15 +10,6 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
from .filters import BelowLevelFilter
from .formatters import AGPTFormatter

# Configure global socket timeout and gRPC keepalive to prevent deadlocks
# This must be done at import time before any gRPC connections are established
socket.setdefaulttimeout(30)  # 30-second socket timeout

# Enable gRPC keepalive to detect dead connections faster
os.environ.setdefault("GRPC_KEEPALIVE_TIME_MS", "30000")  # 30 seconds
os.environ.setdefault("GRPC_KEEPALIVE_TIMEOUT_MS", "5000")  # 5 seconds
os.environ.setdefault("GRPC_KEEPALIVE_PERMIT_WITHOUT_CALLS", "true")

LOG_DIR = Path(__file__).parent.parent.parent.parent / "logs"
LOG_FILE = "activity.log"
DEBUG_LOG_FILE = "debug.log"
@@ -90,6 +79,7 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
    Note: This function is typically called at the start of the application
    to set up the logging infrastructure.
    """

    config = LoggingConfig()
    log_handlers: list[logging.Handler] = []

@@ -115,17 +105,13 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
    if config.enable_cloud_logging or force_cloud_logging:
        import google.cloud.logging
        from google.cloud.logging.handlers import CloudLoggingHandler
        from google.cloud.logging_v2.handlers.transports import (
            BackgroundThreadTransport,
        )
        from google.cloud.logging_v2.handlers.transports.sync import SyncTransport

        client = google.cloud.logging.Client()
        # Use BackgroundThreadTransport to prevent blocking the main thread
        # and deadlocks when gRPC calls to Google Cloud Logging hang
        cloud_handler = CloudLoggingHandler(
            client,
            name="autogpt_logs",
            transport=BackgroundThreadTransport,
            transport=SyncTransport,
        )
        cloud_handler.setLevel(config.level)
        log_handlers.append(cloud_handler)

@@ -1,5 +1,39 @@
import logging
import re
from typing import Any

import uvicorn.config
from colorama import Fore


def remove_color_codes(s: str) -> str:
    return re.sub(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])", "", s)


def fmt_kwargs(kwargs: dict) -> str:
    return ", ".join(f"{n}={repr(v)}" for n, v in kwargs.items())


def print_attribute(
    title: str, value: Any, title_color: str = Fore.GREEN, value_color: str = ""
) -> None:
    logger = logging.getLogger()
    logger.info(
        str(value),
        extra={
            "title": f"{title.rstrip(':')}:",
            "title_color": title_color,
            "color": value_color,
        },
    )


def generate_uvicorn_config():
    """
    Generates a uvicorn logging config that silences uvicorn's default logging
    and tells it to use the native logging module.
    """
    log_config = dict(uvicorn.config.LOGGING_CONFIG)
    log_config["loggers"]["uvicorn"] = {"handlers": []}
    log_config["loggers"]["uvicorn.error"] = {"handlers": []}
    log_config["loggers"]["uvicorn.access"] = {"handlers": []}
    return log_config
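The regex in `remove_color_codes` above matches a single ESC-prefixed ANSI sequence, including CSI color sequences like `\x1b[31m`; it can be exercised standalone:

```python
import re


def remove_color_codes(s: str) -> str:
    # ESC followed by either a single Fe byte or a full CSI sequence
    # (parameter bytes 0-?, intermediate bytes space-/, final byte @-~).
    return re.sub(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])", "", s)


print(remove_color_codes("\x1b[31mred\x1b[0m text"))  # -> red text
```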

@@ -1,34 +1,17 @@
import inspect
import logging
import threading
import time
from functools import wraps
from typing import (
    Awaitable,
    Callable,
    ParamSpec,
    Protocol,
    Tuple,
    TypeVar,
    cast,
    overload,
    runtime_checkable,
)
from typing import Awaitable, Callable, ParamSpec, TypeVar, cast, overload

P = ParamSpec("P")
R = TypeVar("R")

logger = logging.getLogger(__name__)

@overload
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: ...


@overload
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
    pass


@overload
def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
    pass
def thread_cached(func: Callable[P, R]) -> Callable[P, R]: ...


def thread_cached(
@@ -74,193 +57,3 @@ def thread_cached(
def clear_thread_cache(func: Callable) -> None:
    if clear := getattr(func, "clear_cache", None):
        clear()


FuncT = TypeVar("FuncT")


R_co = TypeVar("R_co", covariant=True)


@runtime_checkable
class AsyncCachedFunction(Protocol[P, R_co]):
    """Protocol for async functions with cache management methods."""

    def cache_clear(self) -> None:
        """Clear all cached entries."""
        return None

    def cache_info(self) -> dict[str, int | None]:
        """Get cache statistics."""
        return {}

    async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co:
        """Call the cached function."""
        return None  # type: ignore


def async_ttl_cache(
    maxsize: int = 128, ttl_seconds: int | None = None
) -> Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]:
    """
    TTL (Time To Live) cache decorator for async functions.

    Similar to functools.lru_cache but works with async functions and includes optional TTL.

    Args:
        maxsize: Maximum number of cached entries
        ttl_seconds: Time to live in seconds. If None, entries never expire (like lru_cache)

    Returns:
        Decorator function

    Example:
        # With TTL
        @async_ttl_cache(maxsize=1000, ttl_seconds=300)
        async def api_call(param: str) -> dict:
            return {"result": param}

        # Without TTL (permanent cache like lru_cache)
        @async_ttl_cache(maxsize=1000)
        async def expensive_computation(param: str) -> dict:
            return {"result": param}
    """

    def decorator(
        async_func: Callable[P, Awaitable[R]],
    ) -> AsyncCachedFunction[P, R]:
        # Cache storage - use union type to handle both cases
        cache_storage: dict[tuple, R | Tuple[R, float]] = {}

        @wraps(async_func)
        async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            # Create cache key from arguments
            key = (args, tuple(sorted(kwargs.items())))
            current_time = time.time()

            # Check if we have a valid cached entry
            if key in cache_storage:
                if ttl_seconds is None:
                    # No TTL - return cached result directly
                    logger.debug(
                        f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
                    )
                    return cast(R, cache_storage[key])
                else:
                    # With TTL - check expiration
                    cached_data = cache_storage[key]
                    if isinstance(cached_data, tuple):
                        result, timestamp = cached_data
                        if current_time - timestamp < ttl_seconds:
                            logger.debug(
                                f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
                            )
                            return cast(R, result)
                        else:
                            # Expired entry
                            del cache_storage[key]
                            logger.debug(
                                f"Cache entry expired for {async_func.__name__}"
                            )

            # Cache miss or expired - fetch fresh data
            logger.debug(
                f"Cache miss for {async_func.__name__} with key: {str(key)[:50]}"
            )
            result = await async_func(*args, **kwargs)

            # Store in cache
            if ttl_seconds is None:
                cache_storage[key] = result
            else:
                cache_storage[key] = (result, current_time)

            # Simple cleanup when cache gets too large
            if len(cache_storage) > maxsize:
                # Remove oldest entries (simple FIFO cleanup)
                cutoff = maxsize // 2
                oldest_keys = list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
                for old_key in oldest_keys:
                    cache_storage.pop(old_key, None)
                logger.debug(
                    f"Cache cleanup: removed {len(oldest_keys)} entries for {async_func.__name__}"
                )

            return result

        # Add cache management methods (similar to functools.lru_cache)
        def cache_clear() -> None:
            cache_storage.clear()

        def cache_info() -> dict[str, int | None]:
            return {
                "size": len(cache_storage),
                "maxsize": maxsize,
                "ttl_seconds": ttl_seconds,
            }

        # Attach methods to wrapper
        setattr(wrapper, "cache_clear", cache_clear)
        setattr(wrapper, "cache_info", cache_info)

        return cast(AsyncCachedFunction[P, R], wrapper)

    return decorator


@overload
def async_cache(
    func: Callable[P, Awaitable[R]],
) -> AsyncCachedFunction[P, R]:
    pass


@overload
def async_cache(
    func: None = None,
    *,
    maxsize: int = 128,
) -> Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]:
    pass


def async_cache(
    func: Callable[P, Awaitable[R]] | None = None,
    *,
    maxsize: int = 128,
) -> (
    AsyncCachedFunction[P, R]
    | Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]
):
    """
    Process-level cache decorator for async functions (no TTL).

    Similar to functools.lru_cache but works with async functions.
    This is a convenience wrapper around async_ttl_cache with ttl_seconds=None.

    Args:
        func: The async function to cache (when used without parentheses)
        maxsize: Maximum number of cached entries

    Returns:
        Decorated function or decorator

    Example:
        # Without parentheses (uses default maxsize=128)
        @async_cache
        async def get_data(param: str) -> dict:
            return {"result": param}

        # With parentheses and custom maxsize
        @async_cache(maxsize=1000)
        async def expensive_computation(param: str) -> dict:
            # Expensive computation here
            return {"result": param}
    """
    if func is None:
        # Called with parentheses @async_cache() or @async_cache(maxsize=...)
        return async_ttl_cache(maxsize=maxsize, ttl_seconds=None)
    else:
        # Called without parentheses @async_cache
        decorator = async_ttl_cache(maxsize=maxsize, ttl_seconds=None)
        return decorator(func)

@@ -16,12 +16,7 @@ from unittest.mock import Mock

import pytest

from autogpt_libs.utils.cache import (
    async_cache,
    async_ttl_cache,
    clear_thread_cache,
    thread_cached,
)
from autogpt_libs.utils.cache import clear_thread_cache, thread_cached


class TestThreadCached:
@@ -328,378 +323,3 @@ class TestThreadCached:

        assert function_using_mock(2) == 42
        assert mock.call_count == 2


class TestAsyncTTLCache:
    """Tests for the @async_ttl_cache decorator."""

    @pytest.mark.asyncio
    async def test_basic_caching(self):
        """Test basic caching functionality."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=60)
        async def cached_function(x: int, y: int = 0) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)  # Simulate async work
            return x + y

        # First call
        result1 = await cached_function(1, 2)
        assert result1 == 3
        assert call_count == 1

        # Second call with same args - should use cache
        result2 = await cached_function(1, 2)
        assert result2 == 3
        assert call_count == 1  # No additional call

        # Different args - should call function again
        result3 = await cached_function(2, 3)
        assert result3 == 5
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_ttl_expiration(self):
        """Test that cache entries expire after TTL."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=1)  # Short TTL
        async def short_lived_cache(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x * 2

        # First call
        result1 = await short_lived_cache(5)
        assert result1 == 10
        assert call_count == 1

        # Second call immediately - should use cache
        result2 = await short_lived_cache(5)
        assert result2 == 10
        assert call_count == 1

        # Wait for TTL to expire
        await asyncio.sleep(1.1)

        # Third call after expiration - should call function again
        result3 = await short_lived_cache(5)
        assert result3 == 10
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_cache_info(self):
        """Test cache info functionality."""

        @async_ttl_cache(maxsize=5, ttl_seconds=300)
        async def info_test_function(x: int) -> int:
            return x * 3

        # Check initial cache info
        info = info_test_function.cache_info()
        assert info["size"] == 0
        assert info["maxsize"] == 5
        assert info["ttl_seconds"] == 300

        # Add an entry
        await info_test_function(1)
        info = info_test_function.cache_info()
        assert info["size"] == 1

    @pytest.mark.asyncio
    async def test_cache_clear(self):
        """Test cache clearing functionality."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=60)
        async def clearable_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x * 4

        # First call
        result1 = await clearable_function(2)
        assert result1 == 8
        assert call_count == 1

        # Second call - should use cache
        result2 = await clearable_function(2)
        assert result2 == 8
        assert call_count == 1

        # Clear cache
        clearable_function.cache_clear()

        # Third call after clear - should call function again
        result3 = await clearable_function(2)
        assert result3 == 8
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_maxsize_cleanup(self):
        """Test that cache cleans up when maxsize is exceeded."""
        call_count = 0

        @async_ttl_cache(maxsize=3, ttl_seconds=60)
        async def size_limited_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x**2

        # Fill cache to maxsize
        await size_limited_function(1)  # call_count: 1
        await size_limited_function(2)  # call_count: 2
        await size_limited_function(3)  # call_count: 3

        info = size_limited_function.cache_info()
        assert info["size"] == 3

        # Add one more entry - should trigger cleanup
        await size_limited_function(4)  # call_count: 4

        # Cache size should be reduced (cleanup removes oldest entries)
        info = size_limited_function.cache_info()
        assert info["size"] is not None and info["size"] <= 3  # Should be cleaned up

    @pytest.mark.asyncio
    async def test_argument_variations(self):
        """Test caching with different argument patterns."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=60)
        async def arg_test_function(a: int, b: str = "default", *, c: int = 100) -> str:
            nonlocal call_count
            call_count += 1
            return f"{a}-{b}-{c}"

        # Different ways to call with same logical arguments
        result1 = await arg_test_function(1, "test", c=200)
        assert call_count == 1

        # Same arguments, same order - should use cache
        result2 = await arg_test_function(1, "test", c=200)
        assert call_count == 1
        assert result1 == result2

        # Different arguments - should call function
        result3 = await arg_test_function(2, "test", c=200)
        assert call_count == 2
        assert result1 != result3

    @pytest.mark.asyncio
    async def test_exception_handling(self):
        """Test that exceptions are not cached."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=60)
        async def exception_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            if x < 0:
                raise ValueError("Negative value not allowed")
            return x * 2

        # Successful call - should be cached
        result1 = await exception_function(5)
        assert result1 == 10
        assert call_count == 1

        # Same successful call - should use cache
        result2 = await exception_function(5)
        assert result2 == 10
        assert call_count == 1

        # Exception call - should not be cached
        with pytest.raises(ValueError):
            await exception_function(-1)
        assert call_count == 2

        # Same exception call - should call again (not cached)
        with pytest.raises(ValueError):
            await exception_function(-1)
        assert call_count == 3

    @pytest.mark.asyncio
    async def test_concurrent_calls(self):
        """Test caching behavior with concurrent calls."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=60)
        async def concurrent_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.05)  # Simulate work
            return x * x

        # Launch concurrent calls with same arguments
        tasks = [concurrent_function(3) for _ in range(5)]
        results = await asyncio.gather(*tasks)

        # All results should be the same
        assert all(result == 9 for result in results)

        # Note: Due to race conditions, call_count might be up to 5 for concurrent calls
        # This tests that the cache doesn't break under concurrent access
        assert 1 <= call_count <= 5


class TestAsyncCache:
    """Tests for the @async_cache decorator (no TTL)."""

    @pytest.mark.asyncio
    async def test_basic_caching_no_ttl(self):
        """Test basic caching functionality without TTL."""
        call_count = 0

        @async_cache(maxsize=10)
        async def cached_function(x: int, y: int = 0) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)  # Simulate async work
            return x + y

        # First call
        result1 = await cached_function(1, 2)
        assert result1 == 3
        assert call_count == 1

        # Second call with same args - should use cache
        result2 = await cached_function(1, 2)
        assert result2 == 3
        assert call_count == 1  # No additional call

        # Third call after some time - should still use cache (no TTL)
        await asyncio.sleep(0.05)
        result3 = await cached_function(1, 2)
        assert result3 == 3
        assert call_count == 1  # Still no additional call

        # Different args - should call function again
        result4 = await cached_function(2, 3)
        assert result4 == 5
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_no_ttl_vs_ttl_behavior(self):
        """Test the difference between TTL and no-TTL caching."""
|
||||
ttl_call_count = 0
|
||||
no_ttl_call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=1) # Short TTL
|
||||
async def ttl_function(x: int) -> int:
|
||||
nonlocal ttl_call_count
|
||||
ttl_call_count += 1
|
||||
return x * 2
|
||||
|
||||
@async_cache(maxsize=10) # No TTL
|
||||
async def no_ttl_function(x: int) -> int:
|
||||
nonlocal no_ttl_call_count
|
||||
no_ttl_call_count += 1
|
||||
return x * 2
|
||||
|
||||
# First calls
|
||||
await ttl_function(5)
|
||||
await no_ttl_function(5)
|
||||
assert ttl_call_count == 1
|
||||
assert no_ttl_call_count == 1
|
||||
|
||||
# Wait for TTL to expire
|
||||
await asyncio.sleep(1.1)
|
||||
|
||||
# Second calls after TTL expiry
|
||||
await ttl_function(5) # Should call function again (TTL expired)
|
||||
await no_ttl_function(5) # Should use cache (no TTL)
|
||||
assert ttl_call_count == 2 # TTL function called again
|
||||
assert no_ttl_call_count == 1 # No-TTL function still cached
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_cache_info(self):
|
||||
"""Test cache info for no-TTL cache."""
|
||||
|
||||
@async_cache(maxsize=5)
|
||||
async def info_test_function(x: int) -> int:
|
||||
return x * 3
|
||||
|
||||
# Check initial cache info
|
||||
info = info_test_function.cache_info()
|
||||
assert info["size"] == 0
|
||||
assert info["maxsize"] == 5
|
||||
assert info["ttl_seconds"] is None # No TTL
|
||||
|
||||
# Add an entry
|
||||
await info_test_function(1)
|
||||
info = info_test_function.cache_info()
|
||||
assert info["size"] == 1
|
||||
|
||||
|
||||
class TestTTLOptional:
|
||||
"""Tests for optional TTL functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ttl_none_behavior(self):
|
||||
"""Test that ttl_seconds=None works like no TTL."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=None)
|
||||
async def no_ttl_via_none(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x**2
|
||||
|
||||
# First call
|
||||
result1 = await no_ttl_via_none(3)
|
||||
assert result1 == 9
|
||||
assert call_count == 1
|
||||
|
||||
# Wait (would expire if there was TTL)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Second call - should still use cache
|
||||
result2 = await no_ttl_via_none(3)
|
||||
assert result2 == 9
|
||||
assert call_count == 1 # No additional call
|
||||
|
||||
# Check cache info
|
||||
info = no_ttl_via_none.cache_info()
|
||||
assert info["ttl_seconds"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_options_comparison(self):
|
||||
"""Test different cache options work as expected."""
|
||||
ttl_calls = 0
|
||||
no_ttl_calls = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=1) # With TTL
|
||||
async def ttl_function(x: int) -> int:
|
||||
nonlocal ttl_calls
|
||||
ttl_calls += 1
|
||||
return x * 10
|
||||
|
||||
@async_cache(maxsize=10) # Process-level cache (no TTL)
|
||||
async def process_function(x: int) -> int:
|
||||
nonlocal no_ttl_calls
|
||||
no_ttl_calls += 1
|
||||
return x * 10
|
||||
|
||||
# Both should cache initially
|
||||
await ttl_function(3)
|
||||
await process_function(3)
|
||||
assert ttl_calls == 1
|
||||
assert no_ttl_calls == 1
|
||||
|
||||
# Immediate second calls - both should use cache
|
||||
await ttl_function(3)
|
||||
await process_function(3)
|
||||
assert ttl_calls == 1
|
||||
assert no_ttl_calls == 1
|
||||
|
||||
# Wait for TTL to expire
|
||||
await asyncio.sleep(1.1)
|
||||
|
||||
# After TTL expiry
|
||||
await ttl_function(3) # Should call function again
|
||||
await process_function(3) # Should still use cache
|
||||
assert ttl_calls == 2 # TTL cache expired, called again
|
||||
assert no_ttl_calls == 1 # Process cache never expires
|
||||
|
||||
41
autogpt_platform/autogpt_libs/poetry.lock
generated
@@ -1253,31 +1253,30 @@ pyasn1 = ">=0.1.3"

[[package]]
name = "ruff"
version = "0.12.9"
version = "0.12.3"
description = "An extremely fast Python linter and code formatter, written in Rust."
optional = false
python-versions = ">=3.7"
groups = ["dev"]
files = [
    {file = "ruff-0.12.9-py3-none-linux_armv6l.whl", hash = "sha256:fcebc6c79fcae3f220d05585229463621f5dbf24d79fdc4936d9302e177cfa3e"},
    {file = "ruff-0.12.9-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:aed9d15f8c5755c0e74467731a007fcad41f19bcce41cd75f768bbd687f8535f"},
    {file = "ruff-0.12.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:5b15ea354c6ff0d7423814ba6d44be2807644d0c05e9ed60caca87e963e93f70"},
    {file = "ruff-0.12.9-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d596c2d0393c2502eaabfef723bd74ca35348a8dac4267d18a94910087807c53"},
    {file = "ruff-0.12.9-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1b15599931a1a7a03c388b9c5df1bfa62be7ede6eb7ef753b272381f39c3d0ff"},
    {file = "ruff-0.12.9-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3d02faa2977fb6f3f32ddb7828e212b7dd499c59eb896ae6c03ea5c303575756"},
    {file = "ruff-0.12.9-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:17d5b6b0b3a25259b69ebcba87908496e6830e03acfb929ef9fd4c58675fa2ea"},
    {file = "ruff-0.12.9-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:72db7521860e246adbb43f6ef464dd2a532ef2ef1f5dd0d470455b8d9f1773e0"},
    {file = "ruff-0.12.9-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a03242c1522b4e0885af63320ad754d53983c9599157ee33e77d748363c561ce"},
    {file = "ruff-0.12.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fc83e4e9751e6c13b5046d7162f205d0a7bac5840183c5beebf824b08a27340"},
    {file = "ruff-0.12.9-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:881465ed56ba4dd26a691954650de6ad389a2d1fdb130fe51ff18a25639fe4bb"},
    {file = "ruff-0.12.9-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:43f07a3ccfc62cdb4d3a3348bf0588358a66da756aa113e071b8ca8c3b9826af"},
    {file = "ruff-0.12.9-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:07adb221c54b6bba24387911e5734357f042e5669fa5718920ee728aba3cbadc"},
    {file = "ruff-0.12.9-py3-none-musllinux_1_2_i686.whl", hash = "sha256:f5cd34fabfdea3933ab85d72359f118035882a01bff15bd1d2b15261d85d5f66"},
    {file = "ruff-0.12.9-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:f6be1d2ca0686c54564da8e7ee9e25f93bdd6868263805f8c0b8fc6a449db6d7"},
    {file = "ruff-0.12.9-py3-none-win32.whl", hash = "sha256:cc7a37bd2509974379d0115cc5608a1a4a6c4bff1b452ea69db83c8855d53f93"},
    {file = "ruff-0.12.9-py3-none-win_amd64.whl", hash = "sha256:6fb15b1977309741d7d098c8a3cb7a30bc112760a00fb6efb7abc85f00ba5908"},
    {file = "ruff-0.12.9-py3-none-win_arm64.whl", hash = "sha256:63c8c819739d86b96d500cce885956a1a48ab056bbcbc61b747ad494b2485089"},
    {file = "ruff-0.12.9.tar.gz", hash = "sha256:fbd94b2e3c623f659962934e52c2bea6fc6da11f667a427a368adaf3af2c866a"},
    {file = "ruff-0.12.3-py3-none-linux_armv6l.whl", hash = "sha256:47552138f7206454eaf0c4fe827e546e9ddac62c2a3d2585ca54d29a890137a2"},
    {file = "ruff-0.12.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:0a9153b000c6fe169bb307f5bd1b691221c4286c133407b8827c406a55282041"},
    {file = "ruff-0.12.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fa6b24600cf3b750e48ddb6057e901dd5b9aa426e316addb2a1af185a7509882"},
    {file = "ruff-0.12.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2506961bf6ead54887ba3562604d69cb430f59b42133d36976421bc8bd45901"},
    {file = "ruff-0.12.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c4faaff1f90cea9d3033cbbcdf1acf5d7fb11d8180758feb31337391691f3df0"},
    {file = "ruff-0.12.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40dced4a79d7c264389de1c59467d5d5cefd79e7e06d1dfa2c75497b5269a5a6"},
    {file = "ruff-0.12.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:0262d50ba2767ed0fe212aa7e62112a1dcbfd46b858c5bf7bbd11f326998bafc"},
    {file = "ruff-0.12.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:12371aec33e1a3758597c5c631bae9a5286f3c963bdfb4d17acdd2d395406687"},
    {file = "ruff-0.12.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:560f13b6baa49785665276c963edc363f8ad4b4fc910a883e2625bdb14a83a9e"},
    {file = "ruff-0.12.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:023040a3499f6f974ae9091bcdd0385dd9e9eb4942f231c23c57708147b06311"},
    {file = "ruff-0.12.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:883d844967bffff5ab28bba1a4d246c1a1b2933f48cb9840f3fdc5111c603b07"},
    {file = "ruff-0.12.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2120d3aa855ff385e0e562fdee14d564c9675edbe41625c87eeab744a7830d12"},
    {file = "ruff-0.12.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:6b16647cbb470eaf4750d27dddc6ebf7758b918887b56d39e9c22cce2049082b"},
    {file = "ruff-0.12.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e1417051edb436230023575b149e8ff843a324557fe0a265863b7602df86722f"},
    {file = "ruff-0.12.3-py3-none-win32.whl", hash = "sha256:dfd45e6e926deb6409d0616078a666ebce93e55e07f0fb0228d4b2608b2c248d"},
    {file = "ruff-0.12.3-py3-none-win_amd64.whl", hash = "sha256:a946cf1e7ba3209bdef039eb97647f1c77f6f540e5845ec9c114d3af8df873e7"},
    {file = "ruff-0.12.3-py3-none-win_arm64.whl", hash = "sha256:5f9c7c9c8f84c2d7f27e93674d27136fbf489720251544c4da7fb3d742e011b1"},
    {file = "ruff-0.12.3.tar.gz", hash = "sha256:f1b5a4b6668fd7b7ea3697d8d98857390b40c1320a63a178eee6be0899ea2d77"},
]

[[package]]
@@ -1615,4 +1614,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<4.0"
content-hash = "4cc687aabe5865665fb8c4ccc0ea7e0af80b41e401ca37919f57efa6e0b5be00"
content-hash = "f67db13e6f68b1d67a55eee908c1c560bfa44da8509f98f842889a7570a9830f"
@@ -23,7 +23,7 @@ supabase = "^2.16.0"
uvicorn = "^0.35.0"

[tool.poetry.group.dev.dependencies]
ruff = "^0.12.9"
ruff = "^0.12.3"

[build-system]
requires = ["poetry-core"]
@@ -1,52 +0,0 @@
# Development and testing files
**/__pycache__
**/*.pyc
**/*.pyo
**/*.pyd
**/.Python
**/env/
**/venv/
**/.venv/
**/pip-log.txt
**/.pytest_cache/
**/test-results/
**/snapshots/
**/test/

# IDE and editor files
**/.vscode/
**/.idea/
**/*.swp
**/*.swo
*~

# OS files
.DS_Store
Thumbs.db

# Logs
**/*.log
**/logs/

# Git
.git/
.gitignore

# Documentation
**/*.md
!README.md

# Local development files
.env
.env.local
**/.env.test

# Build artifacts
**/dist/
**/build/
**/target/

# Docker files (avoid recursion)
Dockerfile*
docker-compose*
.dockerignore
@@ -1,9 +1,3 @@
# Backend Configuration
# This file contains environment variables that MUST be set for the AutoGPT platform
# Variables with working defaults in settings.py are not included here

## ===== REQUIRED DATABASE CONFIGURATION ===== ##
# PostgreSQL Database Connection
DB_USER=postgres
DB_PASS=your-super-secret-and-long-postgres-password
DB_NAME=postgres
@@ -16,50 +10,72 @@ DB_SCHEMA=platform
DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
DIRECT_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
PRISMA_SCHEMA="postgres/schema.prisma"
ENABLE_AUTH=true

## ===== REQUIRED SERVICE CREDENTIALS ===== ##
# Redis Configuration
# EXECUTOR
NUM_GRAPH_WORKERS=10

BACKEND_CORS_ALLOW_ORIGINS=["http://localhost:3000"]

# generate using `from cryptography.fernet import Fernet;Fernet.generate_key().decode()`
ENCRYPTION_KEY='dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw='
UNSUBSCRIBE_SECRET_KEY = 'HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio='

REDIS_HOST=localhost
REDIS_PORT=6379
REDIS_PASSWORD=password

# RabbitMQ Credentials
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
ENABLE_CREDIT=false
STRIPE_API_KEY=
STRIPE_WEBHOOK_SECRET=

# Supabase Authentication
# What environment things should be logged under: local dev or prod
APP_ENV=local
# What environment to behave as: "local" or "cloud"
BEHAVE_AS=local
PYRO_HOST=localhost
SENTRY_DSN=

# Email For Postmark so we can send emails
POSTMARK_SERVER_API_TOKEN=
POSTMARK_SENDER_EMAIL=invalid@invalid.com
POSTMARK_WEBHOOK_TOKEN=

## User auth with Supabase is required for any of the 3rd party integrations with auth to work.
ENABLE_AUTH=true
SUPABASE_URL=http://localhost:8000
SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long

## ===== REQUIRED SECURITY KEYS ===== ##
# Generate using: from cryptography.fernet import Fernet;Fernet.generate_key().decode()
ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
# RabbitMQ credentials -- Used for communication between services
RABBITMQ_HOST=localhost
RABBITMQ_PORT=5672
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7

## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
# Platform URLs (set these for webhooks and OAuth to work)
PLATFORM_BASE_URL=http://localhost:8000
FRONTEND_BASE_URL=http://localhost:3000

# Media Storage (required for marketplace and library functionality)
## GCS bucket is required for marketplace and library functionality
MEDIA_GCS_BUCKET_NAME=

## ===== API KEYS AND OAUTH CREDENTIALS ===== ##
# All API keys below are optional - only add what you need
## For local development, you may need to set FRONTEND_BASE_URL for the OAuth flow
## for integrations to work. Defaults to the value of PLATFORM_BASE_URL if not set.
# FRONTEND_BASE_URL=http://localhost:3000

# AI/LLM Services
OPENAI_API_KEY=
ANTHROPIC_API_KEY=
GROQ_API_KEY=
LLAMA_API_KEY=
AIML_API_KEY=
V0_API_KEY=
OPEN_ROUTER_API_KEY=
NVIDIA_API_KEY=
## PLATFORM_BASE_URL must be set to a *publicly accessible* URL pointing to your backend
## to use the platform's webhook-related functionality.
## If you are developing locally, you can use something like ngrok to get a public URL
## and tunnel it to your locally running backend.
PLATFORM_BASE_URL=http://localhost:3000

## Cloudflare Turnstile (CAPTCHA) Configuration
## Get these from the Cloudflare Turnstile dashboard: https://dash.cloudflare.com/?to=/:account/turnstile
## This is the backend secret key
TURNSTILE_SECRET_KEY=
## This is the verify URL
TURNSTILE_VERIFY_URL=https://challenges.cloudflare.com/turnstile/v0/siteverify

## == INTEGRATION CREDENTIALS == ##
# Each set of server side credentials is required for the corresponding 3rd party
# integration to work.

# OAuth Credentials
# For the OAuth callback URL, use <your_frontend_url>/auth/integrations/oauth_callback,
# e.g. http://localhost:3000/auth/integrations/oauth_callback

@@ -69,6 +85,7 @@ GITHUB_CLIENT_SECRET=

# Google OAuth App server credentials - https://console.cloud.google.com/apis/credentials, and enable gmail api and set scopes
# https://console.cloud.google.com/apis/credentials/consent ?project=<your_project_id>

# You'll need to add/enable the following scopes (minimum):
# https://console.developers.google.com/apis/api/gmail.googleapis.com/overview ?project=<your_project_id>
# https://console.cloud.google.com/apis/library/sheets.googleapis.com/ ?project=<your_project_id>
@@ -104,75 +121,104 @@ LINEAR_CLIENT_SECRET=
TODOIST_CLIENT_ID=
TODOIST_CLIENT_SECRET=

NOTION_CLIENT_ID=
NOTION_CLIENT_SECRET=
## ===== OPTIONAL API KEYS ===== ##

# Discord OAuth App credentials
# 1. Go to https://discord.com/developers/applications
# 2. Create a new application
# 3. Go to OAuth2 section and add redirect URI: http://localhost:3000/auth/integrations/oauth_callback
# 4. Copy Client ID and Client Secret below
DISCORD_CLIENT_ID=
DISCORD_CLIENT_SECRET=
# LLM
OPENAI_API_KEY=
ANTHROPIC_API_KEY=
AIML_API_KEY=
GROQ_API_KEY=
OPEN_ROUTER_API_KEY=
LLAMA_API_KEY=

# Reddit
# Go to https://www.reddit.com/prefs/apps and create a new app
# Choose "script" for the type
# Fill in the redirect uri as <your_frontend_url>/auth/integrations/oauth_callback, e.g. http://localhost:3000/auth/integrations/oauth_callback
REDDIT_CLIENT_ID=
REDDIT_CLIENT_SECRET=
REDDIT_USER_AGENT="AutoGPT:1.0 (by /u/autogpt)"

# Payment Processing
STRIPE_API_KEY=
STRIPE_WEBHOOK_SECRET=

# Email Service (for sending notifications and confirmations)
POSTMARK_SERVER_API_TOKEN=
POSTMARK_SENDER_EMAIL=invalid@invalid.com
POSTMARK_WEBHOOK_TOKEN=

# Error Tracking
SENTRY_DSN=

# Cloudflare Turnstile (CAPTCHA) Configuration
# Get these from the Cloudflare Turnstile dashboard: https://dash.cloudflare.com/?to=/:account/turnstile
# This is the backend secret key
TURNSTILE_SECRET_KEY=
# This is the verify URL
TURNSTILE_VERIFY_URL=https://challenges.cloudflare.com/turnstile/v0/siteverify

# Feature Flags
LAUNCH_DARKLY_SDK_KEY=

# Content Generation & Media
DID_API_KEY=
FAL_API_KEY=
IDEOGRAM_API_KEY=
REPLICATE_API_KEY=
REVID_API_KEY=
SCREENSHOTONE_API_KEY=
UNREAL_SPEECH_API_KEY=

# Data & Search Services
E2B_API_KEY=
EXA_API_KEY=
JINA_API_KEY=
MEM0_API_KEY=
OPENWEATHERMAP_API_KEY=
GOOGLE_MAPS_API_KEY=

# Communication Services
# Discord
DISCORD_BOT_TOKEN=
MEDIUM_API_KEY=
MEDIUM_AUTHOR_ID=

# SMTP/Email
SMTP_SERVER=
SMTP_PORT=
SMTP_USERNAME=
SMTP_PASSWORD=

# Business & Marketing Tools
# D-ID
DID_API_KEY=

# Open Weather Map
OPENWEATHERMAP_API_KEY=

# SMTP
SMTP_SERVER=
SMTP_PORT=
SMTP_USERNAME=
SMTP_PASSWORD=

# Medium
MEDIUM_API_KEY=
MEDIUM_AUTHOR_ID=

# Google Maps
GOOGLE_MAPS_API_KEY=

# Replicate
REPLICATE_API_KEY=

# Ideogram
IDEOGRAM_API_KEY=

# Fal
FAL_API_KEY=

# Exa
EXA_API_KEY=

# E2B
E2B_API_KEY=

# Mem0
MEM0_API_KEY=

# Nvidia
NVIDIA_API_KEY=

# Apollo
APOLLO_API_KEY=
ENRICHLAYER_API_KEY=
AYRSHARE_API_KEY=
AYRSHARE_JWT_KEY=

# SmartLead
SMARTLEAD_API_KEY=

# ZeroBounce
ZEROBOUNCE_API_KEY=

# Other Services
AUTOMOD_API_KEY=
# Ayrshare
AYRSHARE_API_KEY=
AYRSHARE_JWT_KEY=

## ===== OPTIONAL API KEYS END ===== ##

# Block Error Rate Monitoring
BLOCK_ERROR_RATE_THRESHOLD=0.5
BLOCK_ERROR_RATE_CHECK_INTERVAL_SECS=86400

# Logging Configuration
LOG_LEVEL=INFO
ENABLE_CLOUD_LOGGING=false
ENABLE_FILE_LOGGING=false
# Use to manually set the log directory
# LOG_DIR=./logs

# Example Blocks Configuration
# Set to true to enable example blocks in development
# These blocks are disabled by default in production
ENABLE_EXAMPLE_BLOCKS=false

# Cloud Storage Configuration
# Cleanup interval for expired files (hours between cleanup runs, 1-24 hours)
CLOUD_STORAGE_CLEANUP_INTERVAL_HOURS=6
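The `ENCRYPTION_KEY` comment in the configuration above says to generate the value with `cryptography`'s `Fernet.generate_key()`. A Fernet key is defined as 32 random bytes encoded as URL-safe base64, so an equivalent stdlib-only sketch (handy if `cryptography` is not installed) looks like this:

```python
import base64
import os


def generate_fernet_style_key() -> str:
    """Produce a key with the same shape as Fernet.generate_key().decode()."""
    # 32 random bytes, URL-safe base64-encoded -> a 44-character string
    return base64.urlsafe_b64encode(os.urandom(32)).decode()


KEY = generate_fernet_style_key()
```

For real deployments, prefer the `cryptography` one-liner the file itself suggests, since that guarantees the key is accepted by the library's `Fernet` constructor.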
1
autogpt_platform/backend/.gitignore
vendored
@@ -1,4 +1,3 @@
.env
database.db
database.db-journal
dev.db
@@ -1,34 +1,31 @@
FROM debian:13-slim AS builder
FROM python:3.11.10-slim-bookworm AS builder

# Set environment variables
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONDONTWRITEBYTECODE 1
ENV PYTHONUNBUFFERED 1

WORKDIR /app

RUN echo 'Acquire::http::Pipeline-Depth 0;\nAcquire::http::No-Cache true;\nAcquire::BrokenProxy true;\n' > /etc/apt/apt.conf.d/99fixbadproxy

# Update package list and install Python and build dependencies
RUN apt-get update --allow-releaseinfo-change --fix-missing \
    && apt-get install -y \
    python3.13 \
    python3.13-dev \
    python3.13-venv \
    python3-pip \
    build-essential \
    libpq5 \
    libz-dev \
    libssl-dev \
    postgresql-client
RUN apt-get update --allow-releaseinfo-change --fix-missing

# Install build dependencies
RUN apt-get install -y build-essential
RUN apt-get install -y libpq5
RUN apt-get install -y libz-dev
RUN apt-get install -y libssl-dev
RUN apt-get install -y postgresql-client

ENV POETRY_HOME=/opt/poetry
ENV POETRY_NO_INTERACTION=1
ENV POETRY_VIRTUALENVS_CREATE=true
ENV POETRY_VIRTUALENVS_IN_PROJECT=true
ENV POETRY_VIRTUALENVS_CREATE=false
ENV PATH=/opt/poetry/bin:$PATH

RUN pip3 install poetry --break-system-packages
# Upgrade pip and setuptools to fix security vulnerabilities
RUN pip3 install --upgrade pip setuptools

RUN pip3 install poetry

# Copy and install dependencies
COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
@@ -40,30 +37,27 @@ RUN poetry install --no-ansi --no-root
COPY autogpt_platform/backend/schema.prisma ./
RUN poetry run prisma generate

FROM debian:13-slim AS server_dependencies
FROM python:3.11.10-slim-bookworm AS server_dependencies

WORKDIR /app

ENV POETRY_HOME=/opt/poetry \
    POETRY_NO_INTERACTION=1 \
    POETRY_VIRTUALENVS_CREATE=true \
    POETRY_VIRTUALENVS_IN_PROJECT=true \
    DEBIAN_FRONTEND=noninteractive
    POETRY_VIRTUALENVS_CREATE=false
ENV PATH=/opt/poetry/bin:$PATH

# Install Python without upgrading system-managed packages
RUN apt-get update && apt-get install -y \
    python3.13 \
    python3-pip
# Upgrade pip and setuptools to fix security vulnerabilities
RUN pip3 install --upgrade pip setuptools

# Copy only necessary files from builder
COPY --from=builder /app /app
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
COPY --from=builder /usr/local/lib/python3.11 /usr/local/lib/python3.11
COPY --from=builder /usr/local/bin /usr/local/bin
# Copy Prisma binaries
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries

ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"

ENV PATH="/app/.venv/bin:$PATH"

RUN mkdir -p /app/autogpt_platform/autogpt_libs
RUN mkdir -p /app/autogpt_platform/backend
@@ -74,12 +68,6 @@ COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.tom

WORKDIR /app/autogpt_platform/backend

FROM server_dependencies AS migrate

# Migration stage only needs schema and migrations - much lighter than full backend
COPY autogpt_platform/backend/schema.prisma /app/autogpt_platform/backend/
COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migrations

FROM server_dependencies AS server

COPY autogpt_platform/backend /app/autogpt_platform/backend
@@ -43,11 +43,11 @@ def main(**kwargs):

    run_processes(
        DatabaseManager().set_log_level("warning"),
        ExecutionManager(),
        Scheduler(),
        NotificationManager(),
        WebsocketServer(),
        AgentServer(),
        ExecutionManager(),
        **kwargs,
    )
@@ -1,3 +1,4 @@
import asyncio
import logging
from typing import Any, Optional

@@ -14,8 +15,7 @@ from backend.data.block import (
)
from backend.data.execution import ExecutionStatus
from backend.data.model import NodeExecutionStats, SchemaField
from backend.util.json import validate_with_jsonschema
from backend.util.retry import func_retry
from backend.util import json, retry

_logger = logging.getLogger(__name__)

@@ -25,9 +25,6 @@ class AgentExecutorBlock(Block):
        user_id: str = SchemaField(description="User ID")
        graph_id: str = SchemaField(description="Graph ID")
        graph_version: int = SchemaField(description="Graph Version")
        agent_name: Optional[str] = SchemaField(
            default=None, description="Name to display in the Builder UI"
        )

        inputs: BlockInput = SchemaField(description="Input data for the graph")
        input_schema: dict = SchemaField(description="Input schema for the graph")
@@ -52,7 +49,7 @@ class AgentExecutorBlock(Block):

    @classmethod
    def get_mismatch_error(cls, data: BlockInput) -> str | None:
        return validate_with_jsonschema(cls.get_input_schema(data), data)
        return json.validate_with_jsonschema(cls.get_input_schema(data), data)

    class Output(BlockSchema):
        pass
@@ -98,14 +95,23 @@ class AgentExecutorBlock(Block):
            logger=logger,
        ):
            yield name, data
        except BaseException as e:
        except asyncio.CancelledError:
            await self._stop(
                graph_exec_id=graph_exec.id,
                user_id=input_data.user_id,
                logger=logger,
            )
            logger.warning(
                f"Execution of graph {input_data.graph_id}v{input_data.graph_version} failed: {e.__class__.__name__} {str(e)}; execution is stopped."
                f"Execution of graph {input_data.graph_id}v{input_data.graph_version} was cancelled."
            )
        except Exception as e:
            await self._stop(
                graph_exec_id=graph_exec.id,
                user_id=input_data.user_id,
                logger=logger,
            )
            logger.error(
                f"Execution of graph {input_data.graph_id}v{input_data.graph_version} failed: {e}, execution is stopped."
            )
            raise

@@ -125,7 +131,6 @@ class AgentExecutorBlock(Block):

        log_id = f"Graph #{graph_id}-V{graph_version}, exec-id: {graph_exec_id}"
        logger.info(f"Starting execution of {log_id}")
        yielded_node_exec_ids = set()

        async for event in event_bus.listen(
            user_id=user_id,
@@ -157,14 +162,6 @@ class AgentExecutorBlock(Block):
                f"Execution {log_id} produced input {event.input_data} output {event.output_data}"
            )

            if event.node_exec_id in yielded_node_exec_ids:
                logger.warning(
                    f"{log_id} received duplicate event for node execution {event.node_exec_id}"
                )
                continue
            else:
                yielded_node_exec_ids.add(event.node_exec_id)

            if not event.block_id:
                logger.warning(f"{log_id} received event without block_id {event}")
                continue
@@ -184,7 +181,7 @@ class AgentExecutorBlock(Block):
            )
            yield output_name, output_data

    @func_retry
    @retry.func_retry
    async def _stop(
        self,
        graph_exec_id: str,
@@ -200,8 +197,7 @@ class AgentExecutorBlock(Block):
            await execution_utils.stop_graph_execution(
                graph_exec_id=graph_exec_id,
                user_id=user_id,
                wait_timeout=3600,
            )
            logger.info(f"Execution {log_id} stopped successfully.")
        except TimeoutError as e:
            logger.error(f"Execution {log_id} stop timed out: {e}")
        except Exception as e:
            logger.error(f"Failed to stop execution {log_id}: {e}")
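The `AgentExecutorBlock` hunk above replaces a blanket `except BaseException` with separate `asyncio.CancelledError` and `Exception` branches: both run the cleanup, but only real failures are re-raised. A minimal sketch of that pattern (function names here are illustrative, not the platform's API):

```python
import asyncio


async def run_with_cleanup(work, stop):
    try:
        return await work()
    except asyncio.CancelledError:
        # Cancellation is a control-flow signal: clean up, but do not re-raise
        # here (the block treats a cancelled graph as stopped, not failed).
        await stop()
    except Exception:
        # A real failure: clean up, then propagate to the caller.
        await stop()
        raise


async def _demo():
    events = []

    async def work():
        await asyncio.sleep(10)  # stands in for a long-running graph execution

    async def stop():
        events.append("stopped")

    task = asyncio.create_task(run_with_cleanup(work, stop))
    await asyncio.sleep(0.01)
    task.cancel()  # delivers CancelledError into the task
    await asyncio.gather(task, return_exceptions=True)
    return events


EVENTS = asyncio.run(_demo())
```

Catching `Exception` instead of `BaseException` also means signals like `KeyboardInterrupt` and `SystemExit` propagate immediately rather than being absorbed by the cleanup path.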
@@ -166,7 +166,7 @@ class AIMusicGeneratorBlock(Block):
                    output_format=input_data.output_format,
                    normalization_strategy=input_data.normalization_strategy,
                )
                if result and isinstance(result, str) and result.startswith("http"):
                if result and result != "No output received":
                    yield "result", result
                    return
                else:
@@ -9,24 +9,6 @@ from backend.sdk import BaseModel, Credentials, Requests
logger = getLogger(__name__)


def _convert_bools(
    obj: Any,
) -> Any:  # noqa: ANN401 – allow Any for deep conversion utility
    """Recursively walk *obj* and coerce string booleans to real booleans."""
    if isinstance(obj, str):
        lowered = obj.lower()
        if lowered == "true":
            return True
        if lowered == "false":
            return False
        return obj
    if isinstance(obj, list):
        return [_convert_bools(item) for item in obj]
    if isinstance(obj, dict):
        return {k: _convert_bools(v) for k, v in obj.items()}
    return obj


class WebhookFilters(BaseModel):
    dataTypes: list[str]
    changeTypes: list[str] | None = None
@@ -597,7 +579,7 @@ async def update_table(
    response = await Requests().patch(
        f"https://api.airtable.com/v0/meta/bases/{base_id}/tables/{table_id}",
        headers={"Authorization": credentials.auth_header()},
        json=_convert_bools(params),
        json=params,
    )

    return response.json()
@@ -627,7 +609,7 @@ async def create_field(
    response = await Requests().post(
        f"https://api.airtable.com/v0/meta/bases/{base_id}/tables/{table_id}/fields",
        headers={"Authorization": credentials.auth_header()},
        json=_convert_bools(params),
        json=params,
    )
    return response.json()

@@ -651,7 +633,7 @@ async def update_field(
    response = await Requests().patch(
        f"https://api.airtable.com/v0/meta/bases/{base_id}/tables/{table_id}/fields/{field_id}",
        headers={"Authorization": credentials.auth_header()},
        json=_convert_bools(params),
        json=params,
    )
    return response.json()

@@ -709,7 +691,7 @@ async def list_records(
    response = await Requests().get(
        f"https://api.airtable.com/v0/{base_id}/{table_id_or_name}",
        headers={"Authorization": credentials.auth_header()},
        json=_convert_bools(params),
        params=params,
    )
    return response.json()

@@ -738,22 +720,20 @@ async def update_multiple_records(
    typecast: bool | None = None,
) -> dict[str, dict[str, dict[str, str]]]:

    params: dict[
        str, str | bool | dict[str, list[str]] | list[dict[str, dict[str, str]]]
    ] = {}
    params: dict[str, str | dict[str, list[str]] | list[dict[str, dict[str, str]]]] = {}
    if perform_upsert:
        params["performUpsert"] = perform_upsert
    if return_fields_by_field_id:
        params["returnFieldsByFieldId"] = str(return_fields_by_field_id)
    if typecast:
        params["typecast"] = typecast
        params["typecast"] = str(typecast)

    params["records"] = [_convert_bools(record) for record in records]
    params["records"] = records

    response = await Requests().patch(
        f"https://api.airtable.com/v0/{base_id}/{table_id_or_name}",
        headers={"Authorization": credentials.auth_header()},
        json=_convert_bools(params),
        json=params,
    )
    return response.json()

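A side note on the `str(...)` coercions this diff introduces for `typecast` and `returnFieldsByFieldId`: Python's `str()` renders booleans capitalized, which differs from the lowercase `true`/`false` literals that `json.dumps` would emit — a quick illustration:

```python
import json

# str() renders Python booleans with a capital letter...
print(str(True))   # → True
print(str(False))  # → False

# ...while JSON serialization uses lowercase literals.
print(json.dumps({"typecast": True}))  # → {"typecast": true}
```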
@@ -767,20 +747,18 @@ async def update_record(
    typecast: bool | None = None,
    fields: dict[str, Any] | None = None,
) -> dict[str, dict[str, dict[str, str]]]:
    params: dict[str, str | bool | dict[str, Any] | list[dict[str, dict[str, str]]]] = (
        {}
    )
    params: dict[str, str | dict[str, Any] | list[dict[str, dict[str, str]]]] = {}
    if return_fields_by_field_id:
        params["returnFieldsByFieldId"] = return_fields_by_field_id
        params["returnFieldsByFieldId"] = str(return_fields_by_field_id)
    if typecast:
        params["typecast"] = typecast
        params["typecast"] = str(typecast)
    if fields:
        params["fields"] = fields

    response = await Requests().patch(
        f"https://api.airtable.com/v0/{base_id}/{table_id_or_name}/{record_id}",
        headers={"Authorization": credentials.auth_header()},
        json=_convert_bools(params),
        json=params,
    )
    return response.json()

@@ -801,22 +779,21 @@ async def create_record(
        len(records) <= 10
    ), "Only up to 10 records can be provided when using records"

    params: dict[str, str | bool | dict[str, Any] | list[dict[str, Any]]] = {}
    params: dict[str, str | dict[str, Any] | list[dict[str, Any]]] = {}
    if fields:
        params["fields"] = fields
    if records:
        params["records"] = records
    if return_fields_by_field_id:
        params["returnFieldsByFieldId"] = return_fields_by_field_id
        params["returnFieldsByFieldId"] = str(return_fields_by_field_id)
    if typecast:
        params["typecast"] = typecast
        params["typecast"] = str(typecast)

    response = await Requests().post(
        f"https://api.airtable.com/v0/{base_id}/{table_id_or_name}",
        headers={"Authorization": credentials.auth_header()},
        json=_convert_bools(params),
        json=params,
    )

    return response.json()


@@ -873,7 +850,7 @@ async def create_webhook(
    response = await Requests().post(
        f"https://api.airtable.com/v0/bases/{base_id}/webhooks",
        headers={"Authorization": credentials.auth_header()},
        json=_convert_bools(params),
        json=params,
    )
    return response.json()

@@ -1218,7 +1195,7 @@ async def create_base(
            "Authorization": credentials.auth_header(),
            "Content-Type": "application/json",
        },
        json=_convert_bools(params),
        json=params,
    )

    return response.json()

@@ -159,7 +159,6 @@ class AirtableOAuthHandler(BaseOAuthHandler):
        logger.info("Successfully refreshed tokens")

        new_credentials = OAuth2Credentials(
            id=credentials.id,
            access_token=SecretStr(response.access_token),
            refresh_token=SecretStr(response.refresh_token),
            access_token_expires_at=int(time.time()) + response.expires_in,

@@ -4,19 +4,11 @@ from typing import Optional
from pydantic import BaseModel, Field

from backend.data.block import BlockSchema
from backend.data.model import SchemaField, UserIntegrations
from backend.data.model import SchemaField
from backend.integrations.ayrshare import AyrshareClient
from backend.util.clients import get_database_manager_async_client
from backend.util.exceptions import MissingConfigError


async def get_profile_key(user_id: str):
    user_integrations: UserIntegrations = (
        await get_database_manager_async_client().get_user_integrations(user_id)
    )
    return user_integrations.managed_credentials.ayrshare_profile_key


class BaseAyrshareInput(BlockSchema):
    """Base input model for Ayrshare social media posts with common fields."""


@@ -6,9 +6,10 @@ from backend.sdk import (
    BlockSchema,
    BlockType,
    SchemaField,
    SecretStr,
)

from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
from ._util import BaseAyrshareInput, create_ayrshare_client


class PostToBlueskyBlock(Block):
@@ -57,12 +58,10 @@ class PostToBlueskyBlock(Block):
        self,
        input_data: "PostToBlueskyBlock.Input",
        *,
        user_id: str,
        profile_key: SecretStr,
        **kwargs,
    ) -> BlockOutput:
        """Post to Bluesky with Bluesky-specific options."""

        profile_key = await get_profile_key(user_id)
        if not profile_key:
            yield "error", "Please link a social account via Ayrshare"
            return

@@ -6,14 +6,10 @@ from backend.sdk import (
    BlockSchema,
    BlockType,
    SchemaField,
    SecretStr,
)

from ._util import (
    BaseAyrshareInput,
    CarouselItem,
    create_ayrshare_client,
    get_profile_key,
)
from ._util import BaseAyrshareInput, CarouselItem, create_ayrshare_client


class PostToFacebookBlock(Block):
@@ -120,11 +116,10 @@ class PostToFacebookBlock(Block):
        self,
        input_data: "PostToFacebookBlock.Input",
        *,
        user_id: str,
        profile_key: SecretStr,
        **kwargs,
    ) -> BlockOutput:
        """Post to Facebook with Facebook-specific options."""
        profile_key = await get_profile_key(user_id)
        if not profile_key:
            yield "error", "Please link a social account via Ayrshare"
            return

@@ -6,9 +6,10 @@ from backend.sdk import (
    BlockSchema,
    BlockType,
    SchemaField,
    SecretStr,
)

from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
from ._util import BaseAyrshareInput, create_ayrshare_client


class PostToGMBBlock(Block):
@@ -110,10 +111,9 @@ class PostToGMBBlock(Block):
    )

    async def run(
        self, input_data: "PostToGMBBlock.Input", *, user_id: str, **kwargs
        self, input_data: "PostToGMBBlock.Input", *, profile_key: SecretStr, **kwargs
    ) -> BlockOutput:
        """Post to Google My Business with GMB-specific options."""
        profile_key = await get_profile_key(user_id)
        if not profile_key:
            yield "error", "Please link a social account via Ayrshare"
            return

@@ -8,14 +8,10 @@ from backend.sdk import (
    BlockSchema,
    BlockType,
    SchemaField,
    SecretStr,
)

from ._util import (
    BaseAyrshareInput,
    InstagramUserTag,
    create_ayrshare_client,
    get_profile_key,
)
from ._util import BaseAyrshareInput, InstagramUserTag, create_ayrshare_client


class PostToInstagramBlock(Block):
@@ -112,11 +108,10 @@ class PostToInstagramBlock(Block):
        self,
        input_data: "PostToInstagramBlock.Input",
        *,
        user_id: str,
        profile_key: SecretStr,
        **kwargs,
    ) -> BlockOutput:
        """Post to Instagram with Instagram-specific options."""
        profile_key = await get_profile_key(user_id)
        if not profile_key:
            yield "error", "Please link a social account via Ayrshare"
            return

@@ -6,9 +6,10 @@ from backend.sdk import (
    BlockSchema,
    BlockType,
    SchemaField,
    SecretStr,
)

from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
from ._util import BaseAyrshareInput, create_ayrshare_client


class PostToLinkedInBlock(Block):
@@ -112,11 +113,10 @@ class PostToLinkedInBlock(Block):
        self,
        input_data: "PostToLinkedInBlock.Input",
        *,
        user_id: str,
        profile_key: SecretStr,
        **kwargs,
    ) -> BlockOutput:
        """Post to LinkedIn with LinkedIn-specific options."""
        profile_key = await get_profile_key(user_id)
        if not profile_key:
            yield "error", "Please link a social account via Ayrshare"
            return

@@ -6,14 +6,10 @@ from backend.sdk import (
    BlockSchema,
    BlockType,
    SchemaField,
    SecretStr,
)

from ._util import (
    BaseAyrshareInput,
    PinterestCarouselOption,
    create_ayrshare_client,
    get_profile_key,
)
from ._util import BaseAyrshareInput, PinterestCarouselOption, create_ayrshare_client


class PostToPinterestBlock(Block):
@@ -92,11 +88,10 @@ class PostToPinterestBlock(Block):
        self,
        input_data: "PostToPinterestBlock.Input",
        *,
        user_id: str,
        profile_key: SecretStr,
        **kwargs,
    ) -> BlockOutput:
        """Post to Pinterest with Pinterest-specific options."""
        profile_key = await get_profile_key(user_id)
        if not profile_key:
            yield "error", "Please link a social account via Ayrshare"
            return

@@ -6,9 +6,10 @@ from backend.sdk import (
    BlockSchema,
    BlockType,
    SchemaField,
    SecretStr,
)

from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
from ._util import BaseAyrshareInput, create_ayrshare_client


class PostToRedditBlock(Block):
@@ -35,9 +36,8 @@ class PostToRedditBlock(Block):
    )

    async def run(
        self, input_data: "PostToRedditBlock.Input", *, user_id: str, **kwargs
        self, input_data: "PostToRedditBlock.Input", *, profile_key: SecretStr, **kwargs
    ) -> BlockOutput:
        profile_key = await get_profile_key(user_id)
        if not profile_key:
            yield "error", "Please link a social account via Ayrshare"
            return

@@ -6,9 +6,10 @@ from backend.sdk import (
    BlockSchema,
    BlockType,
    SchemaField,
    SecretStr,
)

from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
from ._util import BaseAyrshareInput, create_ayrshare_client


class PostToSnapchatBlock(Block):
@@ -62,11 +63,10 @@ class PostToSnapchatBlock(Block):
        self,
        input_data: "PostToSnapchatBlock.Input",
        *,
        user_id: str,
        profile_key: SecretStr,
        **kwargs,
    ) -> BlockOutput:
        """Post to Snapchat with Snapchat-specific options."""
        profile_key = await get_profile_key(user_id)
        if not profile_key:
            yield "error", "Please link a social account via Ayrshare"
            return

@@ -6,9 +6,10 @@ from backend.sdk import (
    BlockSchema,
    BlockType,
    SchemaField,
    SecretStr,
)

from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
from ._util import BaseAyrshareInput, create_ayrshare_client


class PostToTelegramBlock(Block):
@@ -57,11 +58,10 @@ class PostToTelegramBlock(Block):
        self,
        input_data: "PostToTelegramBlock.Input",
        *,
        user_id: str,
        profile_key: SecretStr,
        **kwargs,
    ) -> BlockOutput:
        """Post to Telegram with Telegram-specific validation."""
        profile_key = await get_profile_key(user_id)
        if not profile_key:
            yield "error", "Please link a social account via Ayrshare"
            return

@@ -6,9 +6,10 @@ from backend.sdk import (
    BlockSchema,
    BlockType,
    SchemaField,
    SecretStr,
)

from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
from ._util import BaseAyrshareInput, create_ayrshare_client


class PostToThreadsBlock(Block):
@@ -50,11 +51,10 @@ class PostToThreadsBlock(Block):
        self,
        input_data: "PostToThreadsBlock.Input",
        *,
        user_id: str,
        profile_key: SecretStr,
        **kwargs,
    ) -> BlockOutput:
        """Post to Threads with Threads-specific validation."""
        profile_key = await get_profile_key(user_id)
        if not profile_key:
            yield "error", "Please link a social account via Ayrshare"
            return

@@ -1,5 +1,3 @@
from enum import Enum

from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
    Block,
@@ -8,15 +6,10 @@ from backend.sdk import (
    BlockSchema,
    BlockType,
    SchemaField,
    SecretStr,
)

from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key


class TikTokVisibility(str, Enum):
    PUBLIC = "public"
    PRIVATE = "private"
    FOLLOWERS = "followers"
from ._util import BaseAyrshareInput, create_ayrshare_client


class PostToTikTokBlock(Block):
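The hunk above replaces the `TikTokVisibility` `str`-based Enum with a plain `str` field. One property worth noting about `str`-mixin Enums like the removed one: members compare equal to their raw string values, so the `!= "public"` comparisons further down keep working either way. A minimal standalone sketch:

```python
from enum import Enum


class TikTokVisibility(str, Enum):
    PUBLIC = "public"
    PRIVATE = "private"
    FOLLOWERS = "followers"


# str-Enum members compare equal to their raw string value, and the class
# can be constructed from that value to recover the member.
print(TikTokVisibility.PUBLIC == "public")                          # → True
print(TikTokVisibility("followers") is TikTokVisibility.FOLLOWERS)  # → True
print(TikTokVisibility.PRIVATE.value)                               # → private
```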
@@ -28,6 +21,7 @@ class PostToTikTokBlock(Block):
        # Override post field to include TikTok-specific information
        post: str = SchemaField(
            description="The post text (max 2,200 chars, empty string allowed). Use @handle to mention users. Line breaks will be ignored.",
            default="",
            advanced=False,
        )

@@ -40,7 +34,7 @@ class PostToTikTokBlock(Block):

        # TikTok-specific options
        auto_add_music: bool = SchemaField(
            description="Whether to automatically add recommended music to the post. If you set this field to true, you can change the music later in the TikTok app.",
            description="Automatically add recommended music to image posts",
            default=False,
            advanced=True,
        )
@@ -60,17 +54,17 @@ class PostToTikTokBlock(Block):
            advanced=True,
        )
        is_ai_generated: bool = SchemaField(
            description="If you enable the toggle, your video will be labeled as “Creator labeled as AI-generated” once posted and can’t be changed. The “Creator labeled as AI-generated” label indicates that the content was completely AI-generated or significantly edited with AI.",
            description="Label content as AI-generated (video only)",
            default=False,
            advanced=True,
        )
        is_branded_content: bool = SchemaField(
            description="Whether to enable the Branded Content toggle. If this field is set to true, the video will be labeled as Branded Content, indicating you are in a paid partnership with a brand. A “Paid partnership” label will be attached to the video.",
            description="Label as branded content (paid partnership)",
            default=False,
            advanced=True,
        )
        is_brand_organic: bool = SchemaField(
            description="Whether to enable the Brand Organic Content toggle. If this field is set to true, the video will be labeled as Brand Organic Content, indicating you are promoting yourself or your own business. A “Promotional content” label will be attached to the video.",
            description="Label as brand organic content (promotional)",
            default=False,
            advanced=True,
        )
@@ -87,9 +81,9 @@ class PostToTikTokBlock(Block):
            default=0,
            advanced=True,
        )
        visibility: TikTokVisibility = SchemaField(
        visibility: str = SchemaField(
            description="Post visibility: 'public', 'private', 'followers', or 'friends'",
            default=TikTokVisibility.PUBLIC,
            default="public",
            advanced=True,
        )
        draft: bool = SchemaField(
@@ -104,6 +98,7 @@ class PostToTikTokBlock(Block):

    def __init__(self):
        super().__init__(
            disabled=True,
            id="7faf4b27-96b0-4f05-bf64-e0de54ae74e1",
            description="Post to TikTok using Ayrshare",
            categories={BlockCategory.SOCIAL},
@@ -113,10 +108,9 @@ class PostToTikTokBlock(Block):
        )

    async def run(
        self, input_data: "PostToTikTokBlock.Input", *, user_id: str, **kwargs
        self, input_data: "PostToTikTokBlock.Input", *, profile_key: SecretStr, **kwargs
    ) -> BlockOutput:
        """Post to TikTok with TikTok-specific validation and options."""
        profile_key = await get_profile_key(user_id)
        if not profile_key:
            yield "error", "Please link a social account via Ayrshare"
            return
@@ -166,6 +160,12 @@ class PostToTikTokBlock(Block):
            yield "error", f"Image cover index {input_data.image_cover_index} is out of range (max: {len(input_data.media_urls) - 1})"
            return

        # Validate visibility option
        valid_visibility = ["public", "private", "followers", "friends"]
        if input_data.visibility not in valid_visibility:
            yield "error", f"TikTok visibility must be one of: {', '.join(valid_visibility)}"
            return

        # Check for PNG files (not supported)
        has_png = any(url.lower().endswith(".png") for url in input_data.media_urls)
        if has_png:
@@ -218,8 +218,8 @@ class PostToTikTokBlock(Block):
        if input_data.title:
            tiktok_options["title"] = input_data.title

        if input_data.visibility != TikTokVisibility.PUBLIC:
            tiktok_options["visibility"] = input_data.visibility.value
        if input_data.visibility != "public":
            tiktok_options["visibility"] = input_data.visibility

        response = await client.create_post(
            post=input_data.post,

@@ -6,9 +6,10 @@ from backend.sdk import (
    BlockSchema,
    BlockType,
    SchemaField,
    SecretStr,
)

from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
from ._util import BaseAyrshareInput, create_ayrshare_client


class PostToXBlock(Block):
@@ -115,11 +116,10 @@ class PostToXBlock(Block):
        self,
        input_data: "PostToXBlock.Input",
        *,
        user_id: str,
        profile_key: SecretStr,
        **kwargs,
    ) -> BlockOutput:
        """Post to X / Twitter with enhanced X-specific options."""
        profile_key = await get_profile_key(user_id)
        if not profile_key:
            yield "error", "Please link a social account via Ayrshare"
            return

@@ -9,9 +9,10 @@ from backend.sdk import (
    BlockSchema,
    BlockType,
    SchemaField,
    SecretStr,
)

from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
from ._util import BaseAyrshareInput, create_ayrshare_client


class YouTubeVisibility(str, Enum):
@@ -137,12 +138,10 @@ class PostToYouTubeBlock(Block):
        self,
        input_data: "PostToYouTubeBlock.Input",
        *,
        user_id: str,
        profile_key: SecretStr,
        **kwargs,
    ) -> BlockOutput:
        """Post to YouTube with YouTube-specific validation and options."""

        profile_key = await get_profile_key(user_id)
        if not profile_key:
            yield "error", "Please link a social account via Ayrshare"
            return

@@ -1,205 +0,0 @@
"""
Meeting BaaS API client module.
All API calls centralized for consistency and maintainability.
"""

from typing import Any, Dict, List, Optional

from backend.sdk import Requests


class MeetingBaasAPI:
    """Client for Meeting BaaS API endpoints."""

    BASE_URL = "https://api.meetingbaas.com"

    def __init__(self, api_key: str):
        """Initialize API client with authentication key."""
        self.api_key = api_key
        self.headers = {"x-meeting-baas-api-key": api_key}
        self.requests = Requests()

    # Bot Management Endpoints

    async def join_meeting(
        self,
        bot_name: str,
        meeting_url: str,
        reserved: bool = False,
        bot_image: Optional[str] = None,
        entry_message: Optional[str] = None,
        start_time: Optional[int] = None,
        speech_to_text: Optional[Dict[str, Any]] = None,
        webhook_url: Optional[str] = None,
        automatic_leave: Optional[Dict[str, Any]] = None,
        extra: Optional[Dict[str, Any]] = None,
        recording_mode: str = "speaker_view",
        streaming: Optional[Dict[str, Any]] = None,
        deduplication_key: Optional[str] = None,
        zoom_sdk_id: Optional[str] = None,
        zoom_sdk_pwd: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Deploy a bot to join and record a meeting.

        POST /bots
        """
        body = {
            "bot_name": bot_name,
            "meeting_url": meeting_url,
            "reserved": reserved,
            "recording_mode": recording_mode,
        }

        # Add optional fields if provided
        if bot_image is not None:
            body["bot_image"] = bot_image
        if entry_message is not None:
            body["entry_message"] = entry_message
        if start_time is not None:
            body["start_time"] = start_time
        if speech_to_text is not None:
            body["speech_to_text"] = speech_to_text
        if webhook_url is not None:
            body["webhook_url"] = webhook_url
        if automatic_leave is not None:
            body["automatic_leave"] = automatic_leave
        if extra is not None:
            body["extra"] = extra
        if streaming is not None:
            body["streaming"] = streaming
        if deduplication_key is not None:
            body["deduplication_key"] = deduplication_key
        if zoom_sdk_id is not None:
            body["zoom_sdk_id"] = zoom_sdk_id
        if zoom_sdk_pwd is not None:
            body["zoom_sdk_pwd"] = zoom_sdk_pwd

        response = await self.requests.post(
            f"{self.BASE_URL}/bots",
            headers=self.headers,
            json=body,
        )
        return response.json()

    async def leave_meeting(self, bot_id: str) -> bool:
        """
        Remove a bot from an ongoing meeting.

        DELETE /bots/{uuid}
        """
        response = await self.requests.delete(
            f"{self.BASE_URL}/bots/{bot_id}",
            headers=self.headers,
        )
        return response.status in [200, 204]

    async def retranscribe(
        self,
        bot_uuid: str,
        speech_to_text: Optional[Dict[str, Any]] = None,
        webhook_url: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Re-run transcription on a bot's audio.

        POST /bots/retranscribe
        """
        body: Dict[str, Any] = {"bot_uuid": bot_uuid}

        if speech_to_text is not None:
            body["speech_to_text"] = speech_to_text
        if webhook_url is not None:
            body["webhook_url"] = webhook_url

        response = await self.requests.post(
            f"{self.BASE_URL}/bots/retranscribe",
            headers=self.headers,
            json=body,
        )

        if response.status == 202:
            return {"accepted": True}
        return response.json()

    # Data Retrieval Endpoints

    async def get_meeting_data(
        self, bot_id: str, include_transcripts: bool = True
    ) -> Dict[str, Any]:
        """
        Retrieve meeting data including recording and transcripts.

        GET /bots/meeting_data
        """
        params = {
            "bot_id": bot_id,
            "include_transcripts": str(include_transcripts).lower(),
        }

        response = await self.requests.get(
            f"{self.BASE_URL}/bots/meeting_data",
            headers=self.headers,
            params=params,
        )
        return response.json()

    async def get_screenshots(self, bot_id: str) -> List[Dict[str, Any]]:
        """
        Retrieve screenshots captured during a meeting.

        GET /bots/{uuid}/screenshots
        """
        response = await self.requests.get(
            f"{self.BASE_URL}/bots/{bot_id}/screenshots",
            headers=self.headers,
        )
        result = response.json()
        # Ensure we return a list
        if isinstance(result, list):
            return result
        return []

    async def delete_data(self, bot_id: str) -> bool:
        """
        Delete a bot's recorded data.

        POST /bots/{uuid}/delete_data
        """
        response = await self.requests.post(
            f"{self.BASE_URL}/bots/{bot_id}/delete_data",
            headers=self.headers,
        )
        return response.status == 200

    async def list_bots_with_metadata(
        self,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        sort_by: Optional[str] = None,
        sort_order: Optional[str] = None,
        filter_by: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        List bots with metadata including IDs, names, and meeting details.

        GET /bots/bots_with_metadata
        """
        params = {}
        if limit is not None:
            params["limit"] = limit
        if offset is not None:
            params["offset"] = offset
        if sort_by is not None:
            params["sort_by"] = sort_by
        if sort_order is not None:
            params["sort_order"] = sort_order
        if filter_by is not None:
            params.update(filter_by)

        response = await self.requests.get(
            f"{self.BASE_URL}/bots/bots_with_metadata",
            headers=self.headers,
            params=params,
        )
        return response.json()
@@ -1,13 +0,0 @@
"""
Shared configuration for all Meeting BaaS blocks using the SDK pattern.
"""

from backend.sdk import BlockCostType, ProviderBuilder

# Configure the Meeting BaaS provider with API key authentication
baas = (
    ProviderBuilder("baas")
    .with_api_key("MEETING_BAAS_API_KEY", "Meeting BaaS API Key")
    .with_base_cost(5, BlockCostType.RUN)  # Higher cost for meeting recording service
    .build()
)
@@ -1,217 +0,0 @@
"""
Meeting BaaS bot (recording) blocks.
"""

from typing import Optional

from backend.sdk import (
    APIKeyCredentials,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchema,
    CredentialsMetaInput,
    SchemaField,
)

from ._api import MeetingBaasAPI
from ._config import baas


class BaasBotJoinMeetingBlock(Block):
    """
    Deploy a bot immediately or at a scheduled start_time to join and record a meeting.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = baas.credentials_field(
            description="Meeting BaaS API credentials"
        )
        meeting_url: str = SchemaField(
            description="The URL of the meeting the bot should join"
        )
        bot_name: str = SchemaField(
            description="Display name for the bot in the meeting"
        )
        bot_image: str = SchemaField(
            description="URL to an image for the bot's avatar (16:9 ratio recommended)",
            default="",
        )
        entry_message: str = SchemaField(
            description="Chat message the bot will post upon entry", default=""
        )
        reserved: bool = SchemaField(
            description="Use a reserved bot slot (joins 4 min before meeting)",
            default=False,
        )
        start_time: Optional[int] = SchemaField(
            description="Unix timestamp (ms) when bot should join", default=None
        )
        webhook_url: str | None = SchemaField(
            description="URL to receive webhook events for this bot", default=None
        )
        timeouts: dict = SchemaField(
            description="Automatic leave timeouts configuration", default={}
        )
        extra: dict = SchemaField(
            description="Custom metadata to attach to the bot", default={}
        )

    class Output(BlockSchema):
        bot_id: str = SchemaField(description="UUID of the deployed bot")
        join_response: dict = SchemaField(
            description="Full response from join operation"
        )

    def __init__(self):
        super().__init__(
            id="377d1a6a-a99b-46cf-9af3-1d1b12758e04",
            description="Deploy a bot to join and record a meeting",
            categories={BlockCategory.COMMUNICATION},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        api_key = credentials.api_key.get_secret_value()
        api = MeetingBaasAPI(api_key)

        # Call API with all parameters
        data = await api.join_meeting(
            bot_name=input_data.bot_name,
            meeting_url=input_data.meeting_url,
            reserved=input_data.reserved,
            bot_image=input_data.bot_image if input_data.bot_image else None,
            entry_message=(
                input_data.entry_message if input_data.entry_message else None
            ),
            start_time=input_data.start_time,
            speech_to_text={"provider": "Default"},
            webhook_url=input_data.webhook_url if input_data.webhook_url else None,
            automatic_leave=input_data.timeouts if input_data.timeouts else None,
            extra=input_data.extra if input_data.extra else None,
        )

        yield "bot_id", data.get("bot_id", "")
        yield "join_response", data


class BaasBotLeaveMeetingBlock(Block):
    """
    Force the bot to exit the call.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = baas.credentials_field(
            description="Meeting BaaS API credentials"
        )
        bot_id: str = SchemaField(description="UUID of the bot to remove from meeting")

    class Output(BlockSchema):
        left: bool = SchemaField(description="Whether the bot successfully left")

    def __init__(self):
        super().__init__(
            id="bf77d128-8b25-4280-b5c7-2d553ba7e482",
            description="Remove a bot from an ongoing meeting",
            categories={BlockCategory.COMMUNICATION},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        api_key = credentials.api_key.get_secret_value()
        api = MeetingBaasAPI(api_key)

        # Leave meeting
        left = await api.leave_meeting(input_data.bot_id)

        yield "left", left


class BaasBotFetchMeetingDataBlock(Block):
    """
    Pull MP4 URL, transcript & metadata for a completed meeting.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = baas.credentials_field(
            description="Meeting BaaS API credentials"
        )
        bot_id: str = SchemaField(description="UUID of the bot whose data to fetch")
|
||||
include_transcripts: bool = SchemaField(
|
||||
description="Include transcript data in response", default=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
mp4_url: str = SchemaField(
|
||||
description="URL to download the meeting recording (time-limited)"
|
||||
)
|
||||
transcript: list = SchemaField(description="Meeting transcript data")
|
||||
metadata: dict = SchemaField(description="Meeting metadata and bot information")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="ea7c1309-303c-4da1-893f-89c0e9d64e78",
|
||||
description="Retrieve recorded meeting data",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
api = MeetingBaasAPI(api_key)
|
||||
|
||||
# Fetch meeting data
|
||||
data = await api.get_meeting_data(
|
||||
bot_id=input_data.bot_id,
|
||||
include_transcripts=input_data.include_transcripts,
|
||||
)
|
||||
|
||||
yield "mp4_url", data.get("mp4", "")
|
||||
yield "transcript", data.get("bot_data", {}).get("transcripts", [])
|
||||
yield "metadata", data.get("bot_data", {}).get("bot", {})
|
||||
|
||||
|
||||
class BaasBotDeleteRecordingBlock(Block):
|
||||
"""
|
||||
Purge MP4 + transcript data for privacy or storage management.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
bot_id: str = SchemaField(description="UUID of the bot whose data to delete")
|
||||
|
||||
class Output(BlockSchema):
|
||||
deleted: bool = SchemaField(
|
||||
description="Whether the data was successfully deleted"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="bf8d1aa6-42d8-4944-b6bd-6bac554c0d3b",
|
||||
description="Permanently delete a meeting's recorded data",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
api = MeetingBaasAPI(api_key)
|
||||
|
||||
# Delete recording data
|
||||
deleted = await api.delete_data(input_data.bot_id)
|
||||
|
||||
yield "deleted", deleted
|
||||
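The join block above repeatedly maps empty strings and empty dicts to `None` so that optional fields are omitted from the API call rather than sent as empty values. That pattern can be factored into a small helper; this is a sketch, and `or_none` is an illustrative name, not part of the AutoGPT SDK:

```python
def or_none(value):
    """Return None for falsy values ("" or {}) so optional fields are
    omitted from the outgoing API payload instead of being sent empty."""
    return value if value else None

# Normalizing optional join parameters the way the block does inline:
params = {
    "bot_image": or_none(""),
    "entry_message": or_none("Hi!"),
    "automatic_leave": or_none({}),
}
print(params)  # {'bot_image': None, 'entry_message': 'Hi!', 'automatic_leave': None}
```

Note that this helper also maps `0` to `None`, which is why `start_time` (a timestamp) is passed through unchanged in the block rather than normalized this way.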
@@ -1,178 +0,0 @@
"""
DataForSEO API client with async support using the SDK patterns.
"""

import base64
from typing import Any, Dict, List, Optional

from backend.sdk import Requests, UserPasswordCredentials


class DataForSeoClient:
    """Client for the DataForSEO API using async requests."""

    API_URL = "https://api.dataforseo.com"

    def __init__(self, credentials: UserPasswordCredentials):
        self.credentials = credentials
        self.requests = Requests(
            trusted_origins=["https://api.dataforseo.com"],
            raise_for_status=False,
        )

    def _get_headers(self) -> Dict[str, str]:
        """Generate the authorization header using Basic Auth."""
        username = self.credentials.username.get_secret_value()
        password = self.credentials.password.get_secret_value()
        credentials_str = f"{username}:{password}"
        encoded = base64.b64encode(credentials_str.encode("ascii")).decode("ascii")
        return {
            "Authorization": f"Basic {encoded}",
            "Content-Type": "application/json",
        }

    async def keyword_suggestions(
        self,
        keyword: str,
        location_code: Optional[int] = None,
        language_code: Optional[str] = None,
        include_seed_keyword: bool = True,
        include_serp_info: bool = False,
        include_clickstream_data: bool = False,
        limit: int = 100,
    ) -> List[Dict[str, Any]]:
        """
        Get keyword suggestions from DataForSEO Labs.

        Args:
            keyword: Seed keyword
            location_code: Location code for targeting
            language_code: Language code (e.g., "en")
            include_seed_keyword: Include seed keyword in results
            include_serp_info: Include SERP data
            include_clickstream_data: Include clickstream metrics
            limit: Maximum number of results (up to 3000)

        Returns:
            API response with keyword suggestions
        """
        endpoint = f"{self.API_URL}/v3/dataforseo_labs/google/keyword_suggestions/live"

        # Build payload only with non-None values to avoid sending null fields
        task_data: dict[str, Any] = {
            "keyword": keyword,
        }

        if location_code is not None:
            task_data["location_code"] = location_code
        if language_code is not None:
            task_data["language_code"] = language_code
        if include_seed_keyword is not None:
            task_data["include_seed_keyword"] = include_seed_keyword
        if include_serp_info is not None:
            task_data["include_serp_info"] = include_serp_info
        if include_clickstream_data is not None:
            task_data["include_clickstream_data"] = include_clickstream_data
        if limit is not None:
            task_data["limit"] = limit

        payload = [task_data]

        response = await self.requests.post(
            endpoint,
            headers=self._get_headers(),
            json=payload,
        )

        data = response.json()

        # Check for API errors
        if response.status != 200:
            error_message = data.get("status_message", "Unknown error")
            raise Exception(
                f"DataForSEO API error ({response.status}): {error_message}"
            )

        # Extract the results from the response
        if data.get("tasks") and len(data["tasks"]) > 0:
            task = data["tasks"][0]
            if task.get("status_code") == 20000:  # Success code
                return task.get("result", [])
            else:
                error_msg = task.get("status_message", "Task failed")
                raise Exception(f"DataForSEO task error: {error_msg}")

        return []

    async def related_keywords(
        self,
        keyword: str,
        location_code: Optional[int] = None,
        language_code: Optional[str] = None,
        include_seed_keyword: bool = True,
        include_serp_info: bool = False,
        include_clickstream_data: bool = False,
        limit: int = 100,
    ) -> List[Dict[str, Any]]:
        """
        Get related keywords from DataForSEO Labs.

        Args:
            keyword: Seed keyword
            location_code: Location code for targeting
            language_code: Language code (e.g., "en")
            include_seed_keyword: Include seed keyword in results
            include_serp_info: Include SERP data
            include_clickstream_data: Include clickstream metrics
            limit: Maximum number of results (up to 3000)

        Returns:
            API response with related keywords
        """
        endpoint = f"{self.API_URL}/v3/dataforseo_labs/google/related_keywords/live"

        # Build payload only with non-None values to avoid sending null fields
        task_data: dict[str, Any] = {
            "keyword": keyword,
        }

        if location_code is not None:
            task_data["location_code"] = location_code
        if language_code is not None:
            task_data["language_code"] = language_code
        if include_seed_keyword is not None:
            task_data["include_seed_keyword"] = include_seed_keyword
        if include_serp_info is not None:
            task_data["include_serp_info"] = include_serp_info
        if include_clickstream_data is not None:
            task_data["include_clickstream_data"] = include_clickstream_data
        if limit is not None:
            task_data["limit"] = limit

        payload = [task_data]

        response = await self.requests.post(
            endpoint,
            headers=self._get_headers(),
            json=payload,
        )

        data = response.json()

        # Check for API errors
        if response.status != 200:
            error_message = data.get("status_message", "Unknown error")
            raise Exception(
                f"DataForSEO API error ({response.status}): {error_message}"
            )

        # Extract the results from the response
        if data.get("tasks") and len(data["tasks"]) > 0:
            task = data["tasks"][0]
            if task.get("status_code") == 20000:  # Success code
                return task.get("result", [])
            else:
                error_msg = task.get("status_message", "Task failed")
                raise Exception(f"DataForSEO task error: {error_msg}")

        return []
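The `_get_headers` method above is standard HTTP Basic Auth: base64-encode `username:password` and send it in the `Authorization` header. A self-contained sketch using only the standard library (the function name `basic_auth_header` is illustrative):

```python
import base64


def basic_auth_header(username: str, password: str) -> dict:
    """Build an HTTP Basic Auth header: 'Basic ' + base64('user:pass')."""
    token = base64.b64encode(f"{username}:{password}".encode("ascii")).decode("ascii")
    return {"Authorization": f"Basic {token}"}


print(basic_auth_header("user", "pass"))
# {'Authorization': 'Basic dXNlcjpwYXNz'}
```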
@@ -1,17 +0,0 @@
"""
Configuration for all DataForSEO blocks using the new SDK pattern.
"""

from backend.sdk import BlockCostType, ProviderBuilder

# Build the DataForSEO provider with username/password authentication
dataforseo = (
    ProviderBuilder("dataforseo")
    .with_user_password(
        username_env_var="DATAFORSEO_USERNAME",
        password_env_var="DATAFORSEO_PASSWORD",
        title="DataForSEO Credentials",
    )
    .with_base_cost(1, BlockCostType.RUN)
    .build()
)
@@ -1,273 +0,0 @@
"""
DataForSEO Google Keyword Suggestions block.
"""

from typing import Any, Dict, List, Optional

from backend.sdk import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchema,
    CredentialsMetaInput,
    SchemaField,
    UserPasswordCredentials,
)

from ._api import DataForSeoClient
from ._config import dataforseo


class KeywordSuggestion(BlockSchema):
    """Schema for a keyword suggestion result."""

    keyword: str = SchemaField(description="The keyword suggestion")
    search_volume: Optional[int] = SchemaField(
        description="Monthly search volume", default=None
    )
    competition: Optional[float] = SchemaField(
        description="Competition level (0-1)", default=None
    )
    cpc: Optional[float] = SchemaField(
        description="Cost per click in USD", default=None
    )
    keyword_difficulty: Optional[int] = SchemaField(
        description="Keyword difficulty score", default=None
    )
    serp_info: Optional[Dict[str, Any]] = SchemaField(
        description="Data from SERP for each keyword", default=None
    )
    clickstream_data: Optional[Dict[str, Any]] = SchemaField(
        description="Clickstream data metrics", default=None
    )


class DataForSeoKeywordSuggestionsBlock(Block):
    """Block for getting keyword suggestions from DataForSEO Labs."""

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = dataforseo.credentials_field(
            description="DataForSEO credentials (username and password)"
        )
        keyword: str = SchemaField(description="Seed keyword to get suggestions for")
        location_code: Optional[int] = SchemaField(
            description="Location code for targeting (e.g., 2840 for USA)",
            default=2840,  # USA
        )
        language_code: Optional[str] = SchemaField(
            description="Language code (e.g., 'en' for English)",
            default="en",
        )
        include_seed_keyword: bool = SchemaField(
            description="Include the seed keyword in results",
            default=True,
        )
        include_serp_info: bool = SchemaField(
            description="Include SERP information",
            default=False,
        )
        include_clickstream_data: bool = SchemaField(
            description="Include clickstream metrics",
            default=False,
        )
        limit: int = SchemaField(
            description="Maximum number of results (up to 3000)",
            default=100,
            ge=1,
            le=3000,
        )

    class Output(BlockSchema):
        suggestions: List[KeywordSuggestion] = SchemaField(
            description="List of keyword suggestions with metrics"
        )
        suggestion: KeywordSuggestion = SchemaField(
            description="A single keyword suggestion with metrics"
        )
        total_count: int = SchemaField(
            description="Total number of suggestions returned"
        )
        seed_keyword: str = SchemaField(
            description="The seed keyword used for the query"
        )

    def __init__(self):
        super().__init__(
            id="73c3e7c4-2b3f-4e9f-9e3e-8f7a5c3e2d45",
            description="Get keyword suggestions from DataForSEO Labs Google API",
            categories={BlockCategory.SEARCH, BlockCategory.DATA},
            input_schema=self.Input,
            output_schema=self.Output,
            test_input={
                "credentials": dataforseo.get_test_credentials().model_dump(),
                "keyword": "digital marketing",
                "location_code": 2840,
                "language_code": "en",
                "limit": 1,
            },
            test_credentials=dataforseo.get_test_credentials(),
            test_output=[
                (
                    "suggestion",
                    lambda x: hasattr(x, "keyword")
                    and x.keyword == "digital marketing strategy",
                ),
                ("suggestions", lambda x: isinstance(x, list) and len(x) == 1),
                ("total_count", 1),
                ("seed_keyword", "digital marketing"),
            ],
            test_mock={
                "_fetch_keyword_suggestions": lambda *args, **kwargs: [
                    {
                        "items": [
                            {
                                "keyword": "digital marketing strategy",
                                "keyword_info": {
                                    "search_volume": 10000,
                                    "competition": 0.5,
                                    "cpc": 2.5,
                                },
                                "keyword_properties": {
                                    "keyword_difficulty": 50,
                                },
                            }
                        ]
                    }
                ]
            },
        )

    async def _fetch_keyword_suggestions(
        self,
        client: DataForSeoClient,
        input_data: Input,
    ) -> Any:
        """Private method to fetch keyword suggestions - can be mocked for testing."""
        return await client.keyword_suggestions(
            keyword=input_data.keyword,
            location_code=input_data.location_code,
            language_code=input_data.language_code,
            include_seed_keyword=input_data.include_seed_keyword,
            include_serp_info=input_data.include_serp_info,
            include_clickstream_data=input_data.include_clickstream_data,
            limit=input_data.limit,
        )

    async def run(
        self,
        input_data: Input,
        *,
        credentials: UserPasswordCredentials,
        **kwargs,
    ) -> BlockOutput:
        """Execute the keyword suggestions query."""
        client = DataForSeoClient(credentials)

        results = await self._fetch_keyword_suggestions(client, input_data)

        # Process and format the results
        suggestions = []
        if results and len(results) > 0:
            # results is a list, get the first element
            first_result = results[0] if isinstance(results, list) else results
            items = (
                first_result.get("items", []) if isinstance(first_result, dict) else []
            )
            for item in items:
                # Create the KeywordSuggestion object
                suggestion = KeywordSuggestion(
                    keyword=item.get("keyword", ""),
                    search_volume=item.get("keyword_info", {}).get("search_volume"),
                    competition=item.get("keyword_info", {}).get("competition"),
                    cpc=item.get("keyword_info", {}).get("cpc"),
                    keyword_difficulty=item.get("keyword_properties", {}).get(
                        "keyword_difficulty"
                    ),
                    serp_info=(
                        item.get("serp_info") if input_data.include_serp_info else None
                    ),
                    clickstream_data=(
                        item.get("clickstream_keyword_info")
                        if input_data.include_clickstream_data
                        else None
                    ),
                )
                yield "suggestion", suggestion
                suggestions.append(suggestion)

        yield "suggestions", suggestions
        yield "total_count", len(suggestions)
        yield "seed_keyword", input_data.keyword


class KeywordSuggestionExtractorBlock(Block):
    """Extracts individual fields from a KeywordSuggestion object."""

    class Input(BlockSchema):
        suggestion: KeywordSuggestion = SchemaField(
            description="The keyword suggestion object to extract fields from"
        )

    class Output(BlockSchema):
        keyword: str = SchemaField(description="The keyword suggestion")
        search_volume: Optional[int] = SchemaField(
            description="Monthly search volume", default=None
        )
        competition: Optional[float] = SchemaField(
            description="Competition level (0-1)", default=None
        )
        cpc: Optional[float] = SchemaField(
            description="Cost per click in USD", default=None
        )
        keyword_difficulty: Optional[int] = SchemaField(
            description="Keyword difficulty score", default=None
        )
        serp_info: Optional[Dict[str, Any]] = SchemaField(
            description="Data from SERP for each keyword", default=None
        )
        clickstream_data: Optional[Dict[str, Any]] = SchemaField(
            description="Clickstream data metrics", default=None
        )

    def __init__(self):
        super().__init__(
            id="4193cb94-677c-48b0-9eec-6ac72fffd0f2",
            description="Extract individual fields from a KeywordSuggestion object",
            categories={BlockCategory.DATA},
            input_schema=self.Input,
            output_schema=self.Output,
            test_input={
                "suggestion": KeywordSuggestion(
                    keyword="test keyword",
                    search_volume=1000,
                    competition=0.5,
                    cpc=2.5,
                    keyword_difficulty=60,
                ).model_dump()
            },
            test_output=[
                ("keyword", "test keyword"),
                ("search_volume", 1000),
                ("competition", 0.5),
                ("cpc", 2.5),
                ("keyword_difficulty", 60),
                ("serp_info", None),
                ("clickstream_data", None),
            ],
        )

    async def run(
        self,
        input_data: Input,
        **kwargs,
    ) -> BlockOutput:
        """Extract fields from the KeywordSuggestion object."""
        suggestion = input_data.suggestion

        yield "keyword", suggestion.keyword
        yield "search_volume", suggestion.search_volume
        yield "competition", suggestion.competition
        yield "cpc", suggestion.cpc
        yield "keyword_difficulty", suggestion.keyword_difficulty
        yield "serp_info", suggestion.serp_info
        yield "clickstream_data", suggestion.clickstream_data
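The `run` method above flattens DataForSEO's nested item shape (`keyword_info`, `keyword_properties`) into a flat suggestion record. That extraction step in isolation, as a sketch (`flatten_item` is a hypothetical helper name, and the sample dict mirrors the block's own test mock):

```python
def flatten_item(item: dict) -> dict:
    """Pull the fields used by KeywordSuggestion out of one raw API item."""
    info = item.get("keyword_info", {})
    props = item.get("keyword_properties", {})
    return {
        "keyword": item.get("keyword", ""),
        "search_volume": info.get("search_volume"),
        "competition": info.get("competition"),
        "cpc": info.get("cpc"),
        "keyword_difficulty": props.get("keyword_difficulty"),
    }


raw = {
    "keyword": "digital marketing strategy",
    "keyword_info": {"search_volume": 10000, "competition": 0.5, "cpc": 2.5},
    "keyword_properties": {"keyword_difficulty": 50},
}
print(flatten_item(raw)["search_volume"])  # 10000
```

Using `.get(...)` with `{}` fallbacks means a partially populated item simply yields `None` metrics instead of raising `KeyError`.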
@@ -1,283 +0,0 @@
"""
DataForSEO Google Related Keywords block.
"""

from typing import Any, Dict, List, Optional

from backend.sdk import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchema,
    CredentialsMetaInput,
    SchemaField,
    UserPasswordCredentials,
)

from ._api import DataForSeoClient
from ._config import dataforseo


class RelatedKeyword(BlockSchema):
    """Schema for a related keyword result."""

    keyword: str = SchemaField(description="The related keyword")
    search_volume: Optional[int] = SchemaField(
        description="Monthly search volume", default=None
    )
    competition: Optional[float] = SchemaField(
        description="Competition level (0-1)", default=None
    )
    cpc: Optional[float] = SchemaField(
        description="Cost per click in USD", default=None
    )
    keyword_difficulty: Optional[int] = SchemaField(
        description="Keyword difficulty score", default=None
    )
    serp_info: Optional[Dict[str, Any]] = SchemaField(
        description="SERP data for the keyword", default=None
    )
    clickstream_data: Optional[Dict[str, Any]] = SchemaField(
        description="Clickstream data metrics", default=None
    )


class DataForSeoRelatedKeywordsBlock(Block):
    """Block for getting related keywords from DataForSEO Labs."""

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = dataforseo.credentials_field(
            description="DataForSEO credentials (username and password)"
        )
        keyword: str = SchemaField(
            description="Seed keyword to find related keywords for"
        )
        location_code: Optional[int] = SchemaField(
            description="Location code for targeting (e.g., 2840 for USA)",
            default=2840,  # USA
        )
        language_code: Optional[str] = SchemaField(
            description="Language code (e.g., 'en' for English)",
            default="en",
        )
        include_seed_keyword: bool = SchemaField(
            description="Include the seed keyword in results",
            default=True,
        )
        include_serp_info: bool = SchemaField(
            description="Include SERP information",
            default=False,
        )
        include_clickstream_data: bool = SchemaField(
            description="Include clickstream metrics",
            default=False,
        )
        limit: int = SchemaField(
            description="Maximum number of results (up to 3000)",
            default=100,
            ge=1,
            le=3000,
        )

    class Output(BlockSchema):
        related_keywords: List[RelatedKeyword] = SchemaField(
            description="List of related keywords with metrics"
        )
        related_keyword: RelatedKeyword = SchemaField(
            description="A related keyword with metrics"
        )
        total_count: int = SchemaField(
            description="Total number of related keywords returned"
        )
        seed_keyword: str = SchemaField(
            description="The seed keyword used for the query"
        )

    def __init__(self):
        super().__init__(
            id="8f2e4d6a-1b3c-4a5e-9d7f-2c8e6a4b3f1d",
            description="Get related keywords from DataForSEO Labs Google API",
            categories={BlockCategory.SEARCH, BlockCategory.DATA},
            input_schema=self.Input,
            output_schema=self.Output,
            test_input={
                "credentials": dataforseo.get_test_credentials().model_dump(),
                "keyword": "content marketing",
                "location_code": 2840,
                "language_code": "en",
                "limit": 1,
            },
            test_credentials=dataforseo.get_test_credentials(),
            test_output=[
                (
                    "related_keyword",
                    lambda x: hasattr(x, "keyword") and x.keyword == "content strategy",
                ),
                ("related_keywords", lambda x: isinstance(x, list) and len(x) == 1),
                ("total_count", 1),
                ("seed_keyword", "content marketing"),
            ],
            test_mock={
                "_fetch_related_keywords": lambda *args, **kwargs: [
                    {
                        "items": [
                            {
                                "keyword_data": {
                                    "keyword": "content strategy",
                                    "keyword_info": {
                                        "search_volume": 8000,
                                        "competition": 0.4,
                                        "cpc": 3.0,
                                    },
                                    "keyword_properties": {
                                        "keyword_difficulty": 45,
                                    },
                                }
                            }
                        ]
                    }
                ]
            },
        )

    async def _fetch_related_keywords(
        self,
        client: DataForSeoClient,
        input_data: Input,
    ) -> Any:
        """Private method to fetch related keywords - can be mocked for testing."""
        return await client.related_keywords(
            keyword=input_data.keyword,
            location_code=input_data.location_code,
            language_code=input_data.language_code,
            include_seed_keyword=input_data.include_seed_keyword,
            include_serp_info=input_data.include_serp_info,
            include_clickstream_data=input_data.include_clickstream_data,
            limit=input_data.limit,
        )

    async def run(
        self,
        input_data: Input,
        *,
        credentials: UserPasswordCredentials,
        **kwargs,
    ) -> BlockOutput:
        """Execute the related keywords query."""
        client = DataForSeoClient(credentials)

        results = await self._fetch_related_keywords(client, input_data)

        # Process and format the results
        related_keywords = []
        if results and len(results) > 0:
            # results is a list, get the first element
            first_result = results[0] if isinstance(results, list) else results
            items = (
                first_result.get("items", []) if isinstance(first_result, dict) else []
            )
            for item in items:
                # Extract keyword_data from the item
                keyword_data = item.get("keyword_data", {})

                # Create the RelatedKeyword object
                keyword = RelatedKeyword(
                    keyword=keyword_data.get("keyword", ""),
                    search_volume=keyword_data.get("keyword_info", {}).get(
                        "search_volume"
                    ),
                    competition=keyword_data.get("keyword_info", {}).get("competition"),
                    cpc=keyword_data.get("keyword_info", {}).get("cpc"),
                    keyword_difficulty=keyword_data.get("keyword_properties", {}).get(
                        "keyword_difficulty"
                    ),
                    serp_info=(
                        keyword_data.get("serp_info")
                        if input_data.include_serp_info
                        else None
                    ),
                    clickstream_data=(
                        keyword_data.get("clickstream_keyword_info")
                        if input_data.include_clickstream_data
                        else None
                    ),
                )
                yield "related_keyword", keyword
                related_keywords.append(keyword)

        yield "related_keywords", related_keywords
        yield "total_count", len(related_keywords)
        yield "seed_keyword", input_data.keyword


class RelatedKeywordExtractorBlock(Block):
    """Extracts individual fields from a RelatedKeyword object."""

    class Input(BlockSchema):
        related_keyword: RelatedKeyword = SchemaField(
            description="The related keyword object to extract fields from"
        )

    class Output(BlockSchema):
        keyword: str = SchemaField(description="The related keyword")
        search_volume: Optional[int] = SchemaField(
            description="Monthly search volume", default=None
        )
        competition: Optional[float] = SchemaField(
            description="Competition level (0-1)", default=None
        )
        cpc: Optional[float] = SchemaField(
            description="Cost per click in USD", default=None
        )
        keyword_difficulty: Optional[int] = SchemaField(
            description="Keyword difficulty score", default=None
        )
        serp_info: Optional[Dict[str, Any]] = SchemaField(
            description="SERP data for the keyword", default=None
        )
        clickstream_data: Optional[Dict[str, Any]] = SchemaField(
            description="Clickstream data metrics", default=None
        )

    def __init__(self):
        super().__init__(
            id="98342061-09d2-4952-bf77-0761fc8cc9a8",
            description="Extract individual fields from a RelatedKeyword object",
            categories={BlockCategory.DATA},
            input_schema=self.Input,
            output_schema=self.Output,
            test_input={
                "related_keyword": RelatedKeyword(
                    keyword="test related keyword",
                    search_volume=800,
                    competition=0.4,
                    cpc=3.0,
                    keyword_difficulty=55,
                ).model_dump()
            },
            test_output=[
                ("keyword", "test related keyword"),
                ("search_volume", 800),
                ("competition", 0.4),
                ("cpc", 3.0),
                ("keyword_difficulty", 55),
                ("serp_info", None),
                ("clickstream_data", None),
            ],
        )

    async def run(
        self,
        input_data: Input,
        **kwargs,
    ) -> BlockOutput:
        """Extract fields from the RelatedKeyword object."""
        related_keyword = input_data.related_keyword

        yield "keyword", related_keyword.keyword
        yield "search_volume", related_keyword.search_volume
        yield "competition", related_keyword.competition
        yield "cpc", related_keyword.cpc
        yield "keyword_difficulty", related_keyword.keyword_difficulty
        yield "serp_info", related_keyword.serp_info
        yield "clickstream_data", related_keyword.clickstream_data
237
autogpt_platform/backend/backend/blocks/discord.py
Normal file
237
autogpt_platform/backend/backend/blocks/discord.py
Normal file
@@ -0,0 +1,237 @@
from typing import Literal

import aiohttp
import discord
from pydantic import SecretStr

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import (
    APIKeyCredentials,
    CredentialsField,
    CredentialsMetaInput,
    SchemaField,
)
from backend.integrations.providers import ProviderName

DiscordCredentials = CredentialsMetaInput[
    Literal[ProviderName.DISCORD], Literal["api_key"]
]


def DiscordCredentialsField() -> DiscordCredentials:
    return CredentialsField(description="Discord bot token")


TEST_CREDENTIALS = APIKeyCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="discord",
    api_key=SecretStr("test_api_key"),
    title="Mock Discord API key",
    expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
    "provider": TEST_CREDENTIALS.provider,
    "id": TEST_CREDENTIALS.id,
    "type": TEST_CREDENTIALS.type,
    "title": TEST_CREDENTIALS.title,
}


class ReadDiscordMessagesBlock(Block):
    class Input(BlockSchema):
        credentials: DiscordCredentials = DiscordCredentialsField()

    class Output(BlockSchema):
        message_content: str = SchemaField(
            description="The content of the message received"
        )
        channel_name: str = SchemaField(
            description="The name of the channel the message was received from"
        )
        username: str = SchemaField(
            description="The username of the user who sent the message"
        )

    def __init__(self):
        super().__init__(
            id="df06086a-d5ac-4abb-9996-2ad0acb2eff7",
            input_schema=ReadDiscordMessagesBlock.Input,  # Assign input schema
            output_schema=ReadDiscordMessagesBlock.Output,  # Assign output schema
            description="Reads messages from a Discord channel using a bot token.",
            categories={BlockCategory.SOCIAL},
            test_input={
                "continuous_read": False,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                (
                    "message_content",
                    "Hello!\n\nFile from user: example.txt\nContent: This is the content of the file.",
                ),
                ("channel_name", "general"),
                ("username", "test_user"),
            ],
            test_mock={
                "run_bot": lambda token: {
                    "output_data": "Hello!\n\nFile from user: example.txt\nContent: This is the content of the file.",
                    "channel_name": "general",
                    "username": "test_user",
                }
            },
        )

    async def run_bot(self, token: SecretStr):
        intents = discord.Intents.default()
        intents.message_content = True

        client = discord.Client(intents=intents)

        self.output_data = None
        self.channel_name = None
        self.username = None

        @client.event
        async def on_ready():
            print(f"Logged in as {client.user}")

        @client.event
        async def on_message(message):
            if message.author == client.user:
                return

            self.output_data = message.content
            self.channel_name = message.channel.name
            self.username = message.author.name

            if message.attachments:
                attachment = message.attachments[0]  # Process the first attachment
                if attachment.filename.endswith((".txt", ".py")):
                    async with aiohttp.ClientSession() as session:
                        async with session.get(attachment.url) as response:
                            # aiohttp's ClientResponse.text() is a coroutine
                            file_content = await response.text()
                            self.output_data += f"\n\nFile from user: {attachment.filename}\nContent: {file_content}"

            await client.close()

        await client.start(token.get_secret_value())

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        async for output_name, output_value in self.__run(input_data, credentials):
            yield output_name, output_value

    async def __run(
        self, input_data: Input, credentials: APIKeyCredentials
    ) -> BlockOutput:
        try:
            result = await self.run_bot(credentials.api_key)

            # For testing purposes, use the mocked result
            if isinstance(result, dict):
                self.output_data = result.get("output_data")
                self.channel_name = result.get("channel_name")
                self.username = result.get("username")

            if (
                self.output_data is None
                or self.channel_name is None
                or self.username is None
            ):
                raise ValueError("No message, channel name, or username received.")

            yield "message_content", self.output_data
            yield "channel_name", self.channel_name
            yield "username", self.username

        except discord.errors.LoginFailure as login_err:
            raise ValueError(f"Login error occurred: {login_err}")
        except Exception as e:
            raise ValueError(f"An error occurred: {e}")


class SendDiscordMessageBlock(Block):
    class Input(BlockSchema):
        credentials: DiscordCredentials = DiscordCredentialsField()
        message_content: str = SchemaField(
            description="The content of the message to send"
        )
        channel_name: str = SchemaField(
            description="The name of the channel to send the message to"
        )

    class Output(BlockSchema):
        status: str = SchemaField(
            description="The status of the operation (e.g., 'Message sent', 'Error')"
        )

    def __init__(self):
        super().__init__(
            id="d0822ab5-9f8a-44a3-8971-531dd0178b6b",
            input_schema=SendDiscordMessageBlock.Input,  # Assign input schema
            output_schema=SendDiscordMessageBlock.Output,  # Assign output schema
            description="Sends a message to a Discord channel using a bot token.",
            categories={BlockCategory.SOCIAL},
            test_input={
                "channel_name": "general",
                "message_content": "Hello, Discord!",
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_output=[("status", "Message sent")],
            test_mock={
                "send_message": lambda token, channel_name, message_content: "Message sent"
            },
            test_credentials=TEST_CREDENTIALS,
        )

    async def send_message(self, token: str, channel_name: str, message_content: str):
        intents = discord.Intents.default()
        intents.guilds = True  # Required for fetching guild/channel information
        client = discord.Client(intents=intents)

        @client.event
        async def on_ready():
            print(f"Logged in as {client.user}")
            for guild in client.guilds:
                for channel in guild.text_channels:
                    if channel.name == channel_name:
                        # Split message into chunks if it exceeds 2000 characters
                        for chunk in self.chunk_message(message_content):
                            await channel.send(chunk)
                        self.output_data = "Message sent"
                        await client.close()
                        return

            self.output_data = "Channel not found"
            await client.close()

        await client.start(token)

    def chunk_message(self, message: str, limit: int = 2000) -> list:
        """Splits a message into chunks not exceeding the Discord limit."""
        return [message[i : i + limit] for i in range(0, len(message), limit)]

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        try:
            result = await self.send_message(
                credentials.api_key.get_secret_value(),
                input_data.channel_name,
                input_data.message_content,
            )

            # For testing purposes, use the mocked result
            if isinstance(result, str):
                self.output_data = result

            if self.output_data is None:
                raise ValueError("No status message received.")

            yield "status", self.output_data

        except discord.errors.LoginFailure as login_err:
            raise ValueError(f"Login error occurred: {login_err}")
        except Exception as e:
            raise ValueError(f"An error occurred: {e}")
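The `chunk_message` helper in `SendDiscordMessageBlock` splits outgoing text into slices that respect Discord's 2,000-character message limit. A minimal standalone sketch of the same slicing logic (the function is reproduced outside the block for illustration):

```python
def chunk_message(message: str, limit: int = 2000) -> list[str]:
    """Split a message into consecutive slices of at most `limit` characters."""
    return [message[i : i + limit] for i in range(0, len(message), limit)]


# A 4500-character message yields three chunks of 2000, 2000, and 500 characters.
chunks = chunk_message("x" * 4500)
print([len(c) for c in chunks])  # [2000, 2000, 500]
```

Note that an empty message produces an empty list, so the send loop in `on_ready` simply sends nothing in that case.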
@@ -1,117 +0,0 @@
"""
Discord API helper functions for making authenticated requests.
"""

import logging
from typing import Optional

from pydantic import BaseModel

from backend.data.model import OAuth2Credentials
from backend.util.request import Requests

logger = logging.getLogger(__name__)


class DiscordAPIException(Exception):
    """Exception raised for Discord API errors."""

    def __init__(self, message: str, status_code: int):
        super().__init__(message)
        self.status_code = status_code


class DiscordOAuthUser(BaseModel):
    """Model for Discord OAuth user response."""

    user_id: str
    username: str
    avatar_url: str
    banner: Optional[str] = None
    accent_color: Optional[int] = None


def get_api(credentials: OAuth2Credentials) -> Requests:
    """
    Create a Requests instance configured for Discord API calls with OAuth2 credentials.

    Args:
        credentials: The OAuth2 credentials containing the access token.

    Returns:
        A configured Requests instance for Discord API calls.
    """
    return Requests(
        trusted_origins=[],
        extra_headers={
            "Authorization": f"Bearer {credentials.access_token.get_secret_value()}",
            "Content-Type": "application/json",
        },
        raise_for_status=False,
    )


async def get_current_user(credentials: OAuth2Credentials) -> DiscordOAuthUser:
    """
    Fetch the current user's information using Discord OAuth2 API.

    Reference: https://discord.com/developers/docs/resources/user#get-current-user

    Args:
        credentials: The OAuth2 credentials.

    Returns:
        A model containing user data with avatar URL.

    Raises:
        DiscordAPIException: If the API request fails.
    """
    api = get_api(credentials)
    response = await api.get("https://discord.com/api/oauth2/@me")

    if not response.ok:
        error_text = response.text()
        raise DiscordAPIException(
            f"Failed to fetch user info: {response.status} - {error_text}",
            response.status,
        )

    data = response.json()
    logger.info(f"Discord OAuth2 API Response: {data}")

    # The /api/oauth2/@me endpoint returns a user object nested in the response
    user_info = data.get("user", {})
    logger.info(f"User info extracted: {user_info}")

    # Build avatar URL
    user_id = user_info.get("id")
    avatar_hash = user_info.get("avatar")
    if avatar_hash:
        # Custom avatar
        avatar_ext = "gif" if avatar_hash.startswith("a_") else "png"
        avatar_url = (
            f"https://cdn.discordapp.com/avatars/{user_id}/{avatar_hash}.{avatar_ext}"
        )
    else:
        # Default avatar based on discriminator or user ID
        discriminator = user_info.get("discriminator", "0")
        if discriminator == "0":
            # New username system - use user ID for default avatar
            default_avatar_index = (int(user_id) >> 22) % 6
        else:
            # Legacy discriminator system
            default_avatar_index = int(discriminator) % 5
        avatar_url = (
            f"https://cdn.discordapp.com/embed/avatars/{default_avatar_index}.png"
        )

    result = DiscordOAuthUser(
        user_id=user_id,
        username=user_info.get("username", ""),
        avatar_url=avatar_url,
        banner=user_info.get("banner"),
        accent_color=user_info.get("accent_color"),
    )

    logger.info(f"Returning user data: {result.model_dump()}")
    return result
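The default-avatar branch in `get_current_user` follows Discord's two schemes: new-system usernames (discriminator `"0"`) derive the index from the snowflake ID's timestamp bits, while legacy tags use the discriminator modulo 5. A small sketch isolating that branch (the helper name is illustrative, not part of the module):

```python
def default_avatar_index(user_id: str, discriminator: str = "0") -> int:
    """Index of Discord's default embed avatar: 0-5 (new system) or 0-4 (legacy)."""
    if discriminator == "0":
        # New username system: shift off the snowflake's low 22 bits
        return (int(user_id) >> 22) % 6
    # Legacy discriminator system
    return int(discriminator) % 5


print(default_avatar_index(str(7 << 22)))  # 1
print(default_avatar_index("123", "9999"))  # 4
```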
@@ -1,74 +0,0 @@
from typing import Literal

from pydantic import SecretStr

from backend.data.model import (
    APIKeyCredentials,
    CredentialsField,
    CredentialsMetaInput,
    OAuth2Credentials,
)
from backend.integrations.providers import ProviderName
from backend.util.settings import Secrets

secrets = Secrets()
DISCORD_OAUTH_IS_CONFIGURED = bool(
    secrets.discord_client_id and secrets.discord_client_secret
)

# Bot token credentials (existing)
DiscordBotCredentials = APIKeyCredentials
DiscordBotCredentialsInput = CredentialsMetaInput[
    Literal[ProviderName.DISCORD], Literal["api_key"]
]

# OAuth2 credentials (new)
DiscordOAuthCredentials = OAuth2Credentials
DiscordOAuthCredentialsInput = CredentialsMetaInput[
    Literal[ProviderName.DISCORD], Literal["oauth2"]
]


def DiscordBotCredentialsField() -> DiscordBotCredentialsInput:
    """Creates a Discord bot token credentials field."""
    return CredentialsField(description="Discord bot token")


def DiscordOAuthCredentialsField(scopes: list[str]) -> DiscordOAuthCredentialsInput:
    """Creates a Discord OAuth2 credentials field."""
    return CredentialsField(
        description="Discord OAuth2 credentials",
        required_scopes=set(scopes) | {"identify"},  # Basic user info scope
    )


# Test credentials for bot tokens
TEST_BOT_CREDENTIALS = APIKeyCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="discord",
    api_key=SecretStr("test_api_key"),
    title="Mock Discord API key",
    expires_at=None,
)
TEST_BOT_CREDENTIALS_INPUT = {
    "provider": TEST_BOT_CREDENTIALS.provider,
    "id": TEST_BOT_CREDENTIALS.id,
    "type": TEST_BOT_CREDENTIALS.type,
    "title": TEST_BOT_CREDENTIALS.type,
}

# Test credentials for OAuth2
TEST_OAUTH_CREDENTIALS = OAuth2Credentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="discord",
    access_token=SecretStr("test_access_token"),
    title="Mock Discord OAuth",
    scopes=["identify"],
    username="testuser",
)
TEST_OAUTH_CREDENTIALS_INPUT = {
    "provider": TEST_OAUTH_CREDENTIALS.provider,
    "id": TEST_OAUTH_CREDENTIALS.id,
    "type": TEST_OAUTH_CREDENTIALS.type,
    "title": TEST_OAUTH_CREDENTIALS.type,
}
File diff suppressed because it is too large
@@ -1,99 +0,0 @@
"""
Discord OAuth-based blocks.
"""

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import OAuth2Credentials, SchemaField

from ._api import DiscordOAuthUser, get_current_user
from ._auth import (
    DISCORD_OAUTH_IS_CONFIGURED,
    TEST_OAUTH_CREDENTIALS,
    TEST_OAUTH_CREDENTIALS_INPUT,
    DiscordOAuthCredentialsField,
    DiscordOAuthCredentialsInput,
)


class DiscordGetCurrentUserBlock(Block):
    """
    Gets information about the currently authenticated Discord user using OAuth2.
    This block requires Discord OAuth2 credentials (not bot tokens).
    """

    class Input(BlockSchema):
        credentials: DiscordOAuthCredentialsInput = DiscordOAuthCredentialsField(
            ["identify"]
        )

    class Output(BlockSchema):
        user_id: str = SchemaField(description="The authenticated user's Discord ID")
        username: str = SchemaField(description="The user's username")
        avatar_url: str = SchemaField(description="URL to the user's avatar image")
        banner_url: str = SchemaField(
            description="URL to the user's banner image (if set)", default=""
        )
        accent_color: int = SchemaField(
            description="The user's accent color as an integer", default=0
        )

    def __init__(self):
        super().__init__(
            id="8c7e39b8-4e9d-4f3a-b4e1-2a8c9d5f6e3b",
            input_schema=DiscordGetCurrentUserBlock.Input,
            output_schema=DiscordGetCurrentUserBlock.Output,
            description="Gets information about the currently authenticated Discord user using OAuth2 credentials.",
            categories={BlockCategory.SOCIAL},
            disabled=not DISCORD_OAUTH_IS_CONFIGURED,
            test_input={
                "credentials": TEST_OAUTH_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_OAUTH_CREDENTIALS,
            test_output=[
                ("user_id", "123456789012345678"),
                ("username", "testuser"),
                (
                    "avatar_url",
                    "https://cdn.discordapp.com/avatars/123456789012345678/avatar.png",
                ),
                ("banner_url", ""),
                ("accent_color", 0),
            ],
            test_mock={
                "get_user": lambda _: DiscordOAuthUser(
                    user_id="123456789012345678",
                    username="testuser",
                    avatar_url="https://cdn.discordapp.com/avatars/123456789012345678/avatar.png",
                    banner=None,
                    accent_color=0,
                )
            },
        )

    @staticmethod
    async def get_user(credentials: OAuth2Credentials) -> DiscordOAuthUser:
        user_info = await get_current_user(credentials)
        return user_info

    async def run(
        self, input_data: Input, *, credentials: OAuth2Credentials, **kwargs
    ) -> BlockOutput:
        try:
            result = await self.get_user(credentials)

            # Yield each output field
            yield "user_id", result.user_id
            yield "username", result.username
            yield "avatar_url", result.avatar_url

            # Handle banner URL if banner hash exists
            if result.banner:
                banner_url = f"https://cdn.discordapp.com/banners/{result.user_id}/{result.banner}.png"
                yield "banner_url", banner_url
            else:
                yield "banner_url", ""

            yield "accent_color", result.accent_color or 0

        except Exception as e:
            raise ValueError(f"Failed to get Discord user info: {e}")
@@ -1,408 +0,0 @@
"""
API module for Enrichlayer integration.

This module provides a client for interacting with the Enrichlayer API,
which allows fetching LinkedIn profile data and related information.
"""

import datetime
import enum
import logging
from json import JSONDecodeError
from typing import Any, Optional, TypeVar

from pydantic import BaseModel, Field

from backend.data.model import APIKeyCredentials
from backend.util.request import Requests

logger = logging.getLogger(__name__)

T = TypeVar("T")


class EnrichlayerAPIException(Exception):
    """Exception raised for Enrichlayer API errors."""

    def __init__(self, message: str, status_code: int):
        super().__init__(message)
        self.status_code = status_code


class FallbackToCache(enum.Enum):
    ON_ERROR = "on-error"
    NEVER = "never"


class UseCache(enum.Enum):
    IF_PRESENT = "if-present"
    NEVER = "never"


class SocialMediaProfiles(BaseModel):
    """Social media profiles model."""

    twitter: Optional[str] = None
    facebook: Optional[str] = None
    github: Optional[str] = None


class Experience(BaseModel):
    """Experience model for LinkedIn profiles."""

    company: Optional[str] = None
    title: Optional[str] = None
    description: Optional[str] = None
    location: Optional[str] = None
    starts_at: Optional[dict[str, int]] = None
    ends_at: Optional[dict[str, int]] = None
    company_linkedin_profile_url: Optional[str] = None


class Education(BaseModel):
    """Education model for LinkedIn profiles."""

    school: Optional[str] = None
    degree_name: Optional[str] = None
    field_of_study: Optional[str] = None
    starts_at: Optional[dict[str, int]] = None
    ends_at: Optional[dict[str, int]] = None
    school_linkedin_profile_url: Optional[str] = None


class PersonProfileResponse(BaseModel):
    """Response model for LinkedIn person profile.

    This model represents the response from Enrichlayer's LinkedIn profile API.
    The API returns comprehensive profile data including work experience,
    education, skills, and contact information (when available).

    Example API Response:
        {
            "public_identifier": "johnsmith",
            "full_name": "John Smith",
            "occupation": "Software Engineer at Tech Corp",
            "experiences": [
                {
                    "company": "Tech Corp",
                    "title": "Software Engineer",
                    "starts_at": {"year": 2020, "month": 1}
                }
            ],
            "education": [...],
            "skills": ["Python", "JavaScript", ...]
        }
    """

    public_identifier: Optional[str] = None
    profile_pic_url: Optional[str] = None
    full_name: Optional[str] = None
    first_name: Optional[str] = None
    last_name: Optional[str] = None
    occupation: Optional[str] = None
    headline: Optional[str] = None
    summary: Optional[str] = None
    country: Optional[str] = None
    country_full_name: Optional[str] = None
    city: Optional[str] = None
    state: Optional[str] = None
    experiences: Optional[list[Experience]] = None
    education: Optional[list[Education]] = None
    languages: Optional[list[str]] = None
    skills: Optional[list[str]] = None
    inferred_salary: Optional[dict[str, Any]] = None
    personal_email: Optional[str] = None
    personal_contact_number: Optional[str] = None
    social_media_profiles: Optional[SocialMediaProfiles] = None
    extra: Optional[dict[str, Any]] = None


class SimilarProfile(BaseModel):
    """Similar profile model for LinkedIn person lookup."""

    similarity: float
    linkedin_profile_url: str


class PersonLookupResponse(BaseModel):
    """Response model for LinkedIn person lookup.

    This model represents the response from Enrichlayer's person lookup API.
    The API returns a LinkedIn profile URL and similarity scores when
    searching for a person by name and company.

    Example API Response:
        {
            "url": "https://www.linkedin.com/in/johnsmith/",
            "name_similarity_score": 0.95,
            "company_similarity_score": 0.88,
            "title_similarity_score": 0.75,
            "location_similarity_score": 0.60
        }
    """

    url: str | None = None
    name_similarity_score: float | None
    company_similarity_score: float | None
    title_similarity_score: float | None
    location_similarity_score: float | None
    last_updated: datetime.datetime | None = None
    profile: PersonProfileResponse | None = None


class RoleLookupResponse(BaseModel):
    """Response model for LinkedIn role lookup.

    This model represents the response from Enrichlayer's role lookup API.
    The API returns LinkedIn profile data for a specific role at a company.

    Example API Response:
        {
            "linkedin_profile_url": "https://www.linkedin.com/in/johnsmith/",
            "profile_data": {...} // Full PersonProfileResponse data when enrich_profile=True
        }
    """

    linkedin_profile_url: Optional[str] = None
    profile_data: Optional[PersonProfileResponse] = None


class ProfilePictureResponse(BaseModel):
    """Response model for LinkedIn profile picture.

    This model represents the response from Enrichlayer's profile picture API.
    The API returns a URL to the person's LinkedIn profile picture.

    Example API Response:
        {
            "tmp_profile_pic_url": "https://media.licdn.com/dms/image/..."
        }
    """

    tmp_profile_pic_url: str = Field(
        ..., description="URL of the profile picture", alias="tmp_profile_pic_url"
    )

    @property
    def profile_picture_url(self) -> str:
        """Backward compatibility property for profile_picture_url."""
        return self.tmp_profile_pic_url


class EnrichlayerClient:
    """Client for interacting with the Enrichlayer API."""

    API_BASE_URL = "https://enrichlayer.com/api/v2"

    def __init__(
        self,
        credentials: Optional[APIKeyCredentials] = None,
        custom_requests: Optional[Requests] = None,
    ):
        """
        Initialize the Enrichlayer client.

        Args:
            credentials: The credentials to use for authentication.
            custom_requests: Custom Requests instance for testing.
        """
        if custom_requests:
            self._requests = custom_requests
        else:
            headers: dict[str, str] = {
                "Content-Type": "application/json",
            }
            if credentials:
                headers["Authorization"] = (
                    f"Bearer {credentials.api_key.get_secret_value()}"
                )

            self._requests = Requests(
                extra_headers=headers,
                raise_for_status=False,
            )

    async def _handle_response(self, response) -> Any:
        """
        Handle API response and check for errors.

        Args:
            response: The response object from the request.

        Returns:
            The response data.

        Raises:
            EnrichlayerAPIException: If the API request fails.
        """
        if not response.ok:
            try:
                error_data = response.json()
                error_message = error_data.get("message", "")
            except JSONDecodeError:
                error_message = response.text

            raise EnrichlayerAPIException(
                f"Enrichlayer API request failed ({response.status_code}): {error_message}",
                response.status_code,
            )

        return response.json()

    async def fetch_profile(
        self,
        linkedin_url: str,
        fallback_to_cache: FallbackToCache = FallbackToCache.ON_ERROR,
        use_cache: UseCache = UseCache.IF_PRESENT,
        include_skills: bool = False,
        include_inferred_salary: bool = False,
        include_personal_email: bool = False,
        include_personal_contact_number: bool = False,
        include_social_media: bool = False,
        include_extra: bool = False,
    ) -> PersonProfileResponse:
        """
        Fetch a LinkedIn profile with optional parameters.

        Args:
            linkedin_url: The LinkedIn profile URL to fetch.
            fallback_to_cache: Cache usage if live fetch fails ('on-error' or 'never').
            use_cache: Cache utilization ('if-present' or 'never').
            include_skills: Whether to include skills data.
            include_inferred_salary: Whether to include inferred salary data.
            include_personal_email: Whether to include personal email.
            include_personal_contact_number: Whether to include personal contact number.
            include_social_media: Whether to include social media profiles.
            include_extra: Whether to include additional data.

        Returns:
            The LinkedIn profile data.

        Raises:
            EnrichlayerAPIException: If the API request fails.
        """
        params = {
            "url": linkedin_url,
            "fallback_to_cache": fallback_to_cache.value.lower(),
            "use_cache": use_cache.value.lower(),
        }

        if include_skills:
            params["skills"] = "include"
        if include_inferred_salary:
            params["inferred_salary"] = "include"
        if include_personal_email:
            params["personal_email"] = "include"
        if include_personal_contact_number:
            params["personal_contact_number"] = "include"
        if include_social_media:
            params["twitter_profile_id"] = "include"
            params["facebook_profile_id"] = "include"
            params["github_profile_id"] = "include"
        if include_extra:
            params["extra"] = "include"

        response = await self._requests.get(
            f"{self.API_BASE_URL}/profile", params=params
        )
        return PersonProfileResponse(**await self._handle_response(response))

    async def lookup_person(
        self,
        first_name: str,
        company_domain: str,
        last_name: str | None = None,
        location: Optional[str] = None,
        title: Optional[str] = None,
        include_similarity_checks: bool = False,
        enrich_profile: bool = False,
    ) -> PersonLookupResponse:
        """
        Look up a LinkedIn profile by person's information.

        Args:
            first_name: The person's first name.
            last_name: The person's last name.
            company_domain: The domain of the company they work for.
            location: The person's location.
            title: The person's job title.
            include_similarity_checks: Whether to include similarity checks.
            enrich_profile: Whether to enrich the profile.

        Returns:
            The LinkedIn profile lookup result.

        Raises:
            EnrichlayerAPIException: If the API request fails.
        """
        params = {"first_name": first_name, "company_domain": company_domain}

        if last_name:
            params["last_name"] = last_name
        if location:
            params["location"] = location
        if title:
            params["title"] = title
        if include_similarity_checks:
            params["similarity_checks"] = "include"
        if enrich_profile:
            params["enrich_profile"] = "enrich"

        response = await self._requests.get(
            f"{self.API_BASE_URL}/profile/resolve", params=params
        )
        return PersonLookupResponse(**await self._handle_response(response))

    async def lookup_role(
        self, role: str, company_name: str, enrich_profile: bool = False
    ) -> RoleLookupResponse:
        """
        Look up a LinkedIn profile by role in a company.

        Args:
            role: The role title (e.g., CEO, CTO).
            company_name: The name of the company.
            enrich_profile: Whether to enrich the profile.

        Returns:
            The LinkedIn profile lookup result.

        Raises:
            EnrichlayerAPIException: If the API request fails.
        """
        params = {
            "role": role,
            "company_name": company_name,
        }

        if enrich_profile:
            params["enrich_profile"] = "enrich"

        response = await self._requests.get(
            f"{self.API_BASE_URL}/find/company/role", params=params
        )
        return RoleLookupResponse(**await self._handle_response(response))

    async def get_profile_picture(
        self, linkedin_profile_url: str
    ) -> ProfilePictureResponse:
        """
        Get a LinkedIn profile picture URL.

        Args:
            linkedin_profile_url: The LinkedIn profile URL.

        Returns:
            The profile picture URL.

        Raises:
            EnrichlayerAPIException: If the API request fails.
        """
        params = {
            "linkedin_person_profile_url": linkedin_profile_url,
        }

        response = await self._requests.get(
            f"{self.API_BASE_URL}/person/profile-picture", params=params
        )
        return ProfilePictureResponse(**await self._handle_response(response))
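`EnrichlayerClient.fetch_profile` maps boolean `include_*` flags onto string-valued query parameters (the API expects `"include"` rather than `true`) alongside cache-strategy enum values. That mapping can be isolated as a pure function, sketched here for clarity (the helper name and keyword-driven shape are illustrative, not part of the client):

```python
from enum import Enum


class FallbackToCache(Enum):
    ON_ERROR = "on-error"
    NEVER = "never"


class UseCache(Enum):
    IF_PRESENT = "if-present"
    NEVER = "never"


def build_profile_params(
    linkedin_url: str,
    fallback_to_cache: FallbackToCache = FallbackToCache.ON_ERROR,
    use_cache: UseCache = UseCache.IF_PRESENT,
    **include_flags: bool,
) -> dict[str, str]:
    """Translate keyword include-flags into the API's string-valued query params."""
    params = {
        "url": linkedin_url,
        "fallback_to_cache": fallback_to_cache.value,
        "use_cache": use_cache.value,
    }
    for flag, enabled in include_flags.items():
        if enabled:
            # e.g. skills=True becomes params["skills"] = "include"
            params[flag] = "include"
    return params


params = build_profile_params(
    "https://www.linkedin.com/in/username/", skills=True, personal_email=False
)
print(params)
```

Disabled flags are omitted entirely rather than sent as `"exclude"`, matching the conditional `if include_*` structure in `fetch_profile`.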
@@ -1,34 +0,0 @@
"""
Authentication module for Enrichlayer API integration.

This module provides credential types and test credentials for the Enrichlayer API.
"""

from typing import Literal

from pydantic import SecretStr

from backend.data.model import APIKeyCredentials, CredentialsMetaInput
from backend.integrations.providers import ProviderName

# Define the type of credentials input expected for Enrichlayer API
EnrichlayerCredentialsInput = CredentialsMetaInput[
    Literal[ProviderName.ENRICHLAYER], Literal["api_key"]
]

# Mock credentials for testing Enrichlayer API integration
TEST_CREDENTIALS = APIKeyCredentials(
    id="1234a567-89bc-4def-ab12-3456cdef7890",
    provider="enrichlayer",
    api_key=SecretStr("mock-enrichlayer-api-key"),
    title="Mock Enrichlayer API key",
    expires_at=None,
)

# Dictionary representation of test credentials for input fields
TEST_CREDENTIALS_INPUT = {
    "provider": TEST_CREDENTIALS.provider,
    "id": TEST_CREDENTIALS.id,
    "type": TEST_CREDENTIALS.type,
    "title": TEST_CREDENTIALS.title,
}
@@ -1,527 +0,0 @@
"""
Block definitions for Enrichlayer API integration.

This module implements blocks for interacting with the Enrichlayer API,
which provides access to LinkedIn profile data and related information.
"""

import logging
from typing import Optional

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
from backend.util.type import MediaFileType

from ._api import (
    EnrichlayerClient,
    Experience,
    FallbackToCache,
    PersonLookupResponse,
    PersonProfileResponse,
    RoleLookupResponse,
    UseCache,
)
from ._auth import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, EnrichlayerCredentialsInput

logger = logging.getLogger(__name__)


class GetLinkedinProfileBlock(Block):
    """Block to fetch LinkedIn profile data using Enrichlayer API."""

    class Input(BlockSchema):
        """Input schema for GetLinkedinProfileBlock."""

        linkedin_url: str = SchemaField(
            description="LinkedIn profile URL to fetch data from",
            placeholder="https://www.linkedin.com/in/username/",
        )
        fallback_to_cache: FallbackToCache = SchemaField(
            description="Cache usage if live fetch fails",
            default=FallbackToCache.ON_ERROR,
            advanced=True,
        )
        use_cache: UseCache = SchemaField(
            description="Cache utilization strategy",
            default=UseCache.IF_PRESENT,
            advanced=True,
        )
        include_skills: bool = SchemaField(
            description="Include skills data",
            default=False,
            advanced=True,
        )
        include_inferred_salary: bool = SchemaField(
            description="Include inferred salary data",
            default=False,
            advanced=True,
        )
        include_personal_email: bool = SchemaField(
            description="Include personal email",
            default=False,
            advanced=True,
        )
        include_personal_contact_number: bool = SchemaField(
            description="Include personal contact number",
            default=False,
            advanced=True,
        )
        include_social_media: bool = SchemaField(
            description="Include social media profiles",
            default=False,
            advanced=True,
        )
        include_extra: bool = SchemaField(
            description="Include additional data",
            default=False,
            advanced=True,
        )
        credentials: EnrichlayerCredentialsInput = CredentialsField(
            description="Enrichlayer API credentials"
        )

    class Output(BlockSchema):
        """Output schema for GetLinkedinProfileBlock."""

        profile: PersonProfileResponse = SchemaField(
            description="LinkedIn profile data"
        )
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        """Initialize GetLinkedinProfileBlock."""
        super().__init__(
            id="f6e0ac73-4f1d-4acb-b4b7-b67066c5984e",
            description="Fetch LinkedIn profile data using Enrichlayer",
            categories={BlockCategory.SOCIAL},
            input_schema=GetLinkedinProfileBlock.Input,
            output_schema=GetLinkedinProfileBlock.Output,
            test_input={
                "linkedin_url": "https://www.linkedin.com/in/williamhgates/",
                "include_skills": True,
                "include_social_media": True,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_output=[
                (
                    "profile",
                    PersonProfileResponse(
                        public_identifier="williamhgates",
                        full_name="Bill Gates",
                        occupation="Co-chair at Bill & Melinda Gates Foundation",
                        experiences=[
                            Experience(
                                company="Bill & Melinda Gates Foundation",
                                title="Co-chair",
                                starts_at={"year": 2000},
                            )
                        ],
                    ),
                )
            ],
            test_credentials=TEST_CREDENTIALS,
            test_mock={
                "_fetch_profile": lambda *args, **kwargs: PersonProfileResponse(
                    public_identifier="williamhgates",
                    full_name="Bill Gates",
                    occupation="Co-chair at Bill & Melinda Gates Foundation",
                    experiences=[
                        Experience(
                            company="Bill & Melinda Gates Foundation",
                            title="Co-chair",
                            starts_at={"year": 2000},
                        )
                    ],
                ),
            },
        )

    @staticmethod
    async def _fetch_profile(
        credentials: APIKeyCredentials,
        linkedin_url: str,
        fallback_to_cache: FallbackToCache = FallbackToCache.ON_ERROR,
        use_cache: UseCache = UseCache.IF_PRESENT,
        include_skills: bool = False,
        include_inferred_salary: bool = False,
        include_personal_email: bool = False,
        include_personal_contact_number: bool = False,
        include_social_media: bool = False,
        include_extra: bool = False,
    ):
        client = EnrichlayerClient(credentials)
        profile = await client.fetch_profile(
            linkedin_url=linkedin_url,
            fallback_to_cache=fallback_to_cache,
            use_cache=use_cache,
            include_skills=include_skills,
            include_inferred_salary=include_inferred_salary,
            include_personal_email=include_personal_email,
            include_personal_contact_number=include_personal_contact_number,
            include_social_media=include_social_media,
            include_extra=include_extra,
        )
        return profile

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        Run the block to fetch LinkedIn profile data.

        Args:
            input_data: Input parameters for the block
            credentials: API key credentials for Enrichlayer
            **kwargs: Additional keyword arguments

        Yields:
            Tuples of (output_name, output_value)
        """
        try:
            profile = await self._fetch_profile(
                credentials=credentials,
                linkedin_url=input_data.linkedin_url,
                fallback_to_cache=input_data.fallback_to_cache,
                use_cache=input_data.use_cache,
                include_skills=input_data.include_skills,
                include_inferred_salary=input_data.include_inferred_salary,
                include_personal_email=input_data.include_personal_email,
                include_personal_contact_number=input_data.include_personal_contact_number,
                include_social_media=input_data.include_social_media,
                include_extra=input_data.include_extra,
            )
            yield "profile", profile
        except Exception as e:
            logger.error(f"Error fetching LinkedIn profile: {str(e)}")
            yield "error", str(e)


class LinkedinPersonLookupBlock(Block):
    """Block to look up LinkedIn profiles by person's information using Enrichlayer API."""

    class Input(BlockSchema):
        """Input schema for LinkedinPersonLookupBlock."""

        first_name: str = SchemaField(
            description="Person's first name",
            placeholder="John",
            advanced=False,
        )
        last_name: str | None = SchemaField(
            description="Person's last name",
            placeholder="Doe",
            default=None,
            advanced=False,
        )
        company_domain: str = SchemaField(
            description="Domain of the company they work for (optional)",
            placeholder="example.com",
            advanced=False,
        )
        location: Optional[str] = SchemaField(
            description="Person's location (optional)",
            placeholder="San Francisco",
            default=None,
        )
        title: Optional[str] = SchemaField(
            description="Person's job title (optional)",
            placeholder="CEO",
            default=None,
        )
        include_similarity_checks: bool = SchemaField(
            description="Include similarity checks",
            default=False,
            advanced=True,
        )
        enrich_profile: bool = SchemaField(
            description="Enrich the profile with additional data",
            default=False,
            advanced=True,
        )
        credentials: EnrichlayerCredentialsInput = CredentialsField(
            description="Enrichlayer API credentials"
        )

    class Output(BlockSchema):
        """Output schema for LinkedinPersonLookupBlock."""

        lookup_result: PersonLookupResponse = SchemaField(
            description="LinkedIn profile lookup result"
        )
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        """Initialize LinkedinPersonLookupBlock."""
        super().__init__(
            id="d237a98a-5c4b-4a1c-b9e3-e6f9a6c81df7",
            description="Look up LinkedIn profiles by person information using Enrichlayer",
            categories={BlockCategory.SOCIAL},
            input_schema=LinkedinPersonLookupBlock.Input,
            output_schema=LinkedinPersonLookupBlock.Output,
            test_input={
                "first_name": "Bill",
                "last_name": "Gates",
                "company_domain": "gatesfoundation.org",
                "include_similarity_checks": True,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_output=[
                (
                    "lookup_result",
                    PersonLookupResponse(
                        url="https://www.linkedin.com/in/williamhgates/",
                        name_similarity_score=0.93,
                        company_similarity_score=0.83,
                        title_similarity_score=0.3,
                        location_similarity_score=0.20,
                    ),
                )
            ],
            test_credentials=TEST_CREDENTIALS,
            test_mock={
                "_lookup_person": lambda *args, **kwargs: PersonLookupResponse(
                    url="https://www.linkedin.com/in/williamhgates/",
                    name_similarity_score=0.93,
                    company_similarity_score=0.83,
                    title_similarity_score=0.3,
                    location_similarity_score=0.20,
                )
            },
        )

    @staticmethod
    async def _lookup_person(
        credentials: APIKeyCredentials,
        first_name: str,
        company_domain: str,
        last_name: str | None = None,
        location: Optional[str] = None,
        title: Optional[str] = None,
        include_similarity_checks: bool = False,
        enrich_profile: bool = False,
    ):
        client = EnrichlayerClient(credentials=credentials)
        lookup_result = await client.lookup_person(
            first_name=first_name,
            last_name=last_name,
            company_domain=company_domain,
            location=location,
            title=title,
            include_similarity_checks=include_similarity_checks,
            enrich_profile=enrich_profile,
        )
        return lookup_result

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        Run the block to look up LinkedIn profiles.

        Args:
            input_data: Input parameters for the block
            credentials: API key credentials for Enrichlayer
            **kwargs: Additional keyword arguments

        Yields:
            Tuples of (output_name, output_value)
        """
        try:
            lookup_result = await self._lookup_person(
                credentials=credentials,
                first_name=input_data.first_name,
                last_name=input_data.last_name,
                company_domain=input_data.company_domain,
                location=input_data.location,
                title=input_data.title,
                include_similarity_checks=input_data.include_similarity_checks,
                enrich_profile=input_data.enrich_profile,
            )
            yield "lookup_result", lookup_result
        except Exception as e:
            logger.error(f"Error looking up LinkedIn profile: {str(e)}")
            yield "error", str(e)


class LinkedinRoleLookupBlock(Block):
    """Block to look up LinkedIn profiles by role in a company using Enrichlayer API."""

    class Input(BlockSchema):
        """Input schema for LinkedinRoleLookupBlock."""

        role: str = SchemaField(
            description="Role title (e.g., CEO, CTO)",
            placeholder="CEO",
        )
        company_name: str = SchemaField(
            description="Name of the company",
            placeholder="Microsoft",
        )
        enrich_profile: bool = SchemaField(
            description="Enrich the profile with additional data",
            default=False,
            advanced=True,
        )
        credentials: EnrichlayerCredentialsInput = CredentialsField(
            description="Enrichlayer API credentials"
        )

    class Output(BlockSchema):
        """Output schema for LinkedinRoleLookupBlock."""

        role_lookup_result: RoleLookupResponse = SchemaField(
            description="LinkedIn role lookup result"
        )
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        """Initialize LinkedinRoleLookupBlock."""
        super().__init__(
            id="3b9fc742-06d4-49c7-b5ce-7e302dd7c8a7",
            description="Look up LinkedIn profiles by role in a company using Enrichlayer",
            categories={BlockCategory.SOCIAL},
            input_schema=LinkedinRoleLookupBlock.Input,
            output_schema=LinkedinRoleLookupBlock.Output,
            test_input={
                "role": "Co-chair",
                "company_name": "Gates Foundation",
                "enrich_profile": True,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_output=[
                (
                    "role_lookup_result",
                    RoleLookupResponse(
                        linkedin_profile_url="https://www.linkedin.com/in/williamhgates/",
                    ),
                )
            ],
            test_credentials=TEST_CREDENTIALS,
            test_mock={
                "_lookup_role": lambda *args, **kwargs: RoleLookupResponse(
                    linkedin_profile_url="https://www.linkedin.com/in/williamhgates/",
                ),
            },
        )

    @staticmethod
    async def _lookup_role(
        credentials: APIKeyCredentials,
        role: str,
        company_name: str,
        enrich_profile: bool = False,
    ):
        client = EnrichlayerClient(credentials=credentials)
        role_lookup_result = await client.lookup_role(
            role=role,
            company_name=company_name,
            enrich_profile=enrich_profile,
        )
        return role_lookup_result

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        Run the block to look up LinkedIn profiles by role.

        Args:
            input_data: Input parameters for the block
            credentials: API key credentials for Enrichlayer
            **kwargs: Additional keyword arguments

        Yields:
            Tuples of (output_name, output_value)
        """
        try:
            role_lookup_result = await self._lookup_role(
                credentials=credentials,
                role=input_data.role,
                company_name=input_data.company_name,
                enrich_profile=input_data.enrich_profile,
            )
            yield "role_lookup_result", role_lookup_result
        except Exception as e:
            logger.error(f"Error looking up role in company: {str(e)}")
            yield "error", str(e)


class GetLinkedinProfilePictureBlock(Block):
    """Block to get LinkedIn profile pictures using Enrichlayer API."""

    class Input(BlockSchema):
        """Input schema for GetLinkedinProfilePictureBlock."""

        linkedin_profile_url: str = SchemaField(
            description="LinkedIn profile URL",
            placeholder="https://www.linkedin.com/in/username/",
        )
        credentials: EnrichlayerCredentialsInput = CredentialsField(
            description="Enrichlayer API credentials"
        )

    class Output(BlockSchema):
        """Output schema for GetLinkedinProfilePictureBlock."""

        profile_picture_url: MediaFileType = SchemaField(
            description="LinkedIn profile picture URL"
        )
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        """Initialize GetLinkedinProfilePictureBlock."""
        super().__init__(
            id="68d5a942-9b3f-4e9a-b7c1-d96ea4321f0d",
            description="Get LinkedIn profile pictures using Enrichlayer",
            categories={BlockCategory.SOCIAL},
            input_schema=GetLinkedinProfilePictureBlock.Input,
            output_schema=GetLinkedinProfilePictureBlock.Output,
            test_input={
                "linkedin_profile_url": "https://www.linkedin.com/in/williamhgates/",
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_output=[
                (
                    "profile_picture_url",
                    "https://media.licdn.com/dms/image/C4D03AQFj-xjuXrLFSQ/profile-displayphoto-shrink_800_800/0/1576881858598?e=1686787200&v=beta&t=zrQC76QwsfQQIWthfOnrKRBMZ5D-qIAvzLXLmWgYvTk",
                )
            ],
            test_credentials=TEST_CREDENTIALS,
            test_mock={
                "_get_profile_picture": lambda *args, **kwargs: "https://media.licdn.com/dms/image/C4D03AQFj-xjuXrLFSQ/profile-displayphoto-shrink_800_800/0/1576881858598?e=1686787200&v=beta&t=zrQC76QwsfQQIWthfOnrKRBMZ5D-qIAvzLXLmWgYvTk",
            },
        )

    @staticmethod
    async def _get_profile_picture(
        credentials: APIKeyCredentials, linkedin_profile_url: str
    ):
        client = EnrichlayerClient(credentials=credentials)
        profile_picture_response = await client.get_profile_picture(
            linkedin_profile_url=linkedin_profile_url,
        )
        return profile_picture_response.profile_picture_url

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        Run the block to get LinkedIn profile pictures.

        Args:
            input_data: Input parameters for the block
            credentials: API key credentials for Enrichlayer
            **kwargs: Additional keyword arguments

        Yields:
            Tuples of (output_name, output_value)
        """
        try:
            profile_picture = await self._get_profile_picture(
                credentials=credentials,
                linkedin_profile_url=input_data.linkedin_profile_url,
            )
            yield "profile_picture_url", profile_picture
        except Exception as e:
            logger.error(f"Error getting profile picture: {str(e)}")
            yield "error", str(e)
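The four blocks above share one testing pattern: the network call is isolated in a static helper (`_fetch_profile`, `_lookup_person`, `_lookup_role`, `_get_profile_picture`) so that `test_mock` can swap it out without touching `run()`. A minimal, stdlib-only sketch of that pattern follows; `ProfileBlock`, `_mock_fetch`, and the URL are illustrative stand-ins, not part of the AutoGPT SDK:

```python
import asyncio


class ProfileBlock:
    """Toy stand-in for the framework's Block: run() is an async generator
    yielding (output_name, value) tuples, and failures become an "error" output."""

    async def run(self, linkedin_url: str):
        try:
            profile = await self._fetch_profile(linkedin_url)
            yield "profile", profile
        except Exception as e:
            yield "error", str(e)

    @staticmethod
    async def _fetch_profile(linkedin_url: str) -> dict:
        # A real implementation would call the remote API here.
        raise RuntimeError("network disabled in this sketch")


async def _mock_fetch(linkedin_url: str) -> dict:
    # Canned response, playing the role of a test_mock entry.
    return {"public_identifier": "example", "url": linkedin_url}


async def collect(block: ProfileBlock, url: str):
    return [pair async for pair in block.run(url)]


# Patch the static helper, mirroring how test_mock replaces "_fetch_profile".
ProfileBlock._fetch_profile = staticmethod(_mock_fetch)
outputs = asyncio.run(collect(ProfileBlock(), "https://www.linkedin.com/in/example/"))
print(outputs)
```

Because `run()` only ever goes through the helper, patching that one attribute is enough to make the block fully deterministic in tests.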
@@ -1,247 +0,0 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field


# Enum definitions based on available options
class WebsetStatus(str, Enum):
    IDLE = "idle"
    PENDING = "pending"
    RUNNING = "running"
    PAUSED = "paused"


class WebsetSearchStatus(str, Enum):
    CREATED = "created"
    # Add more if known, based on example it's "created"


class ImportStatus(str, Enum):
    PENDING = "pending"
    # Add more if known


class ImportFormat(str, Enum):
    CSV = "csv"
    # Add more if known


class EnrichmentStatus(str, Enum):
    PENDING = "pending"
    # Add more if known


class EnrichmentFormat(str, Enum):
    TEXT = "text"
    # Add more if known


class MonitorStatus(str, Enum):
    ENABLED = "enabled"
    # Add more if known


class MonitorBehaviorType(str, Enum):
    SEARCH = "search"
    # Add more if known


class MonitorRunStatus(str, Enum):
    CREATED = "created"
    # Add more if known


class CanceledReason(str, Enum):
    WEBSET_DELETED = "webset_deleted"
    # Add more if known


class FailedReason(str, Enum):
    INVALID_FORMAT = "invalid_format"
    # Add more if known


class Confidence(str, Enum):
    HIGH = "high"
    # Add more if known


# Nested models


class Entity(BaseModel):
    type: str


class Criterion(BaseModel):
    description: str
    successRate: Optional[int] = None


class ExcludeItem(BaseModel):
    source: str = Field(default="import")
    id: str


class Relationship(BaseModel):
    definition: str
    limit: Optional[float] = None


class ScopeItem(BaseModel):
    source: str = Field(default="import")
    id: str
    relationship: Optional[Relationship] = None


class Progress(BaseModel):
    found: int
    analyzed: int
    completion: int
    timeLeft: int


class Bounds(BaseModel):
    min: int
    max: int


class Expected(BaseModel):
    total: int
    confidence: str = Field(default="high")  # Use str or Confidence enum
    bounds: Bounds


class Recall(BaseModel):
    expected: Expected
    reasoning: str


class WebsetSearch(BaseModel):
    id: str
    object: str = Field(default="webset_search")
    status: str = Field(default="created")  # Or use WebsetSearchStatus
    websetId: str
    query: str
    entity: Entity
    criteria: List[Criterion]
    count: int
    behavior: str = Field(default="override")
    exclude: List[ExcludeItem]
    scope: List[ScopeItem]
    progress: Progress
    recall: Recall
    metadata: Dict[str, Any] = Field(default_factory=dict)
    canceledAt: Optional[datetime] = None
    canceledReason: Optional[str] = Field(default=None)  # Or use CanceledReason
    createdAt: datetime
    updatedAt: datetime


class ImportEntity(BaseModel):
    type: str


class Import(BaseModel):
    id: str
    object: str = Field(default="import")
    status: str = Field(default="pending")  # Or use ImportStatus
    format: str = Field(default="csv")  # Or use ImportFormat
    entity: ImportEntity
    title: str
    count: int
    metadata: Dict[str, Any] = Field(default_factory=dict)
    failedReason: Optional[str] = Field(default=None)  # Or use FailedReason
    failedAt: Optional[datetime] = None
    failedMessage: Optional[str] = None
    createdAt: datetime
    updatedAt: datetime


class Option(BaseModel):
    label: str


class WebsetEnrichment(BaseModel):
    id: str
    object: str = Field(default="webset_enrichment")
    status: str = Field(default="pending")  # Or use EnrichmentStatus
    websetId: str
    title: str
    description: str
    format: str = Field(default="text")  # Or use EnrichmentFormat
    options: List[Option]
    instructions: str
    metadata: Dict[str, Any] = Field(default_factory=dict)
    createdAt: datetime
    updatedAt: datetime


class Cadence(BaseModel):
    cron: str
    timezone: str = Field(default="Etc/UTC")


class BehaviorConfig(BaseModel):
    query: Optional[str] = None
    criteria: Optional[List[Criterion]] = None
    entity: Optional[Entity] = None
    count: Optional[int] = None
    behavior: Optional[str] = Field(default=None)


class Behavior(BaseModel):
    type: str = Field(default="search")  # Or use MonitorBehaviorType
    config: BehaviorConfig


class MonitorRun(BaseModel):
    id: str
    object: str = Field(default="monitor_run")
    status: str = Field(default="created")  # Or use MonitorRunStatus
    monitorId: str
    type: str = Field(default="search")
    completedAt: Optional[datetime] = None
    failedAt: Optional[datetime] = None
    failedReason: Optional[str] = None
    canceledAt: Optional[datetime] = None
    createdAt: datetime
    updatedAt: datetime


class Monitor(BaseModel):
    id: str
    object: str = Field(default="monitor")
    status: str = Field(default="enabled")  # Or use MonitorStatus
    websetId: str
    cadence: Cadence
    behavior: Behavior
    lastRun: Optional[MonitorRun] = None
    nextRunAt: Optional[datetime] = None
    metadata: Dict[str, Any] = Field(default_factory=dict)
    createdAt: datetime
    updatedAt: datetime


class Webset(BaseModel):
    id: str
    object: str = Field(default="webset")
    status: WebsetStatus
    externalId: Optional[str] = None
    title: Optional[str] = None
    searches: List[WebsetSearch]
    imports: List[Import]
    enrichments: List[WebsetEnrichment]
    monitors: List[Monitor]
    streams: List[Any]
    createdAt: datetime
    updatedAt: datetime
    metadata: Dict[str, Any] = Field(default_factory=dict)


class ListWebsets(BaseModel):
    data: List[Webset]
    hasMore: bool
    nextCursor: Optional[str] = None
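`ListWebsets` models a cursor-paginated listing: `hasMore` says whether another page exists and `nextCursor` is the token to request it. A hedged sketch of how a client would drain such an endpoint; `fetch_page` and the canned `_PAGES` data are stand-ins for real API calls, not part of the Exa client:

```python
from typing import Any, Optional

# Canned pages standing in for successive API responses.
_PAGES: dict[Optional[str], dict[str, Any]] = {
    None: {"data": [{"id": "ws_1"}, {"id": "ws_2"}], "hasMore": True, "nextCursor": "c1"},
    "c1": {"data": [{"id": "ws_3"}], "hasMore": False, "nextCursor": None},
}


def fetch_page(cursor: Optional[str]) -> dict[str, Any]:
    # Stand-in for GET /websets?cursor=...; returns one ListWebsets-shaped page.
    return _PAGES[cursor]


def list_all_websets() -> list[dict[str, Any]]:
    # Follow nextCursor until hasMore is False, accumulating every item.
    items: list[dict[str, Any]] = []
    cursor: Optional[str] = None
    while True:
        page = fetch_page(cursor)
        items.extend(page["data"])
        if not page["hasMore"]:
            return items
        cursor = page["nextCursor"]


print([w["id"] for w in list_all_websets()])  # ['ws_1', 'ws_2', 'ws_3']
```

The first request passes no cursor; each later request echoes back the previous page's `nextCursor`, which is the usual contract for this response shape.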
@@ -114,7 +114,6 @@ class ExaWebsetWebhookBlock(Block):

    def __init__(self):
        super().__init__(
            disabled=True,
            id="d0204ed8-8b81-408d-8b8d-ed087a546228",
            description="Receive webhook notifications for Exa webset events",
            categories={BlockCategory.INPUT},

@@ -1,33 +1,7 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Dict, List, Optional
|
||||
|
||||
from exa_py import Exa
|
||||
from exa_py.websets.types import (
|
||||
CreateCriterionParameters,
|
||||
CreateEnrichmentParameters,
|
||||
CreateWebsetParameters,
|
||||
CreateWebsetParametersSearch,
|
||||
ExcludeItem,
|
||||
Format,
|
||||
ImportItem,
|
||||
ImportSource,
|
||||
Option,
|
||||
ScopeItem,
|
||||
ScopeRelationship,
|
||||
ScopeSourceType,
|
||||
WebsetArticleEntity,
|
||||
WebsetCompanyEntity,
|
||||
WebsetCustomEntity,
|
||||
WebsetPersonEntity,
|
||||
WebsetResearchPaperEntity,
|
||||
WebsetStatus,
|
||||
)
|
||||
from pydantic import Field
|
||||
from typing import Any, Optional
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
BaseModel,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
@@ -38,69 +12,7 @@ from backend.sdk import (
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
|
||||
|
||||
class SearchEntityType(str, Enum):
|
||||
COMPANY = "company"
|
||||
PERSON = "person"
|
||||
ARTICLE = "article"
|
||||
RESEARCH_PAPER = "research_paper"
|
||||
CUSTOM = "custom"
|
||||
AUTO = "auto"
|
||||
|
||||
|
||||
class SearchType(str, Enum):
|
||||
IMPORT = "import"
|
||||
WEBSET = "webset"
|
||||
|
||||
|
||||
class EnrichmentFormat(str, Enum):
|
||||
TEXT = "text"
|
||||
DATE = "date"
|
||||
NUMBER = "number"
|
||||
OPTIONS = "options"
|
||||
EMAIL = "email"
|
||||
PHONE = "phone"
|
||||
|
||||
|
||||
class Webset(BaseModel):
|
||||
id: str
|
||||
status: WebsetStatus | None = Field(..., title="WebsetStatus")
|
||||
"""
|
||||
The status of the webset
|
||||
"""
|
||||
external_id: Annotated[Optional[str], Field(alias="externalId")] = None
|
||||
"""
|
||||
The external identifier for the webset
|
||||
NOTE: Returning dict to avoid ui crashing due to nested objects
|
||||
"""
|
||||
searches: List[dict[str, Any]] | None = None
|
||||
"""
|
||||
The searches that have been performed on the webset.
|
||||
NOTE: Returning dict to avoid ui crashing due to nested objects
|
||||
"""
|
||||
enrichments: List[dict[str, Any]] | None = None
|
||||
"""
|
||||
The Enrichments to apply to the Webset Items.
|
||||
NOTE: Returning dict to avoid ui crashing due to nested objects
|
||||
"""
|
||||
monitors: List[dict[str, Any]] | None = None
|
||||
"""
|
||||
The Monitors for the Webset.
|
||||
NOTE: Returning dict to avoid ui crashing due to nested objects
|
||||
"""
|
||||
metadata: Optional[Dict[str, Any]] = {}
|
||||
"""
|
||||
Set of key-value pairs you want to associate with this object.
|
||||
"""
|
||||
created_at: Annotated[datetime, Field(alias="createdAt")] | None = None
|
||||
"""
|
||||
The date and time the webset was created
|
||||
"""
|
||||
updated_at: Annotated[datetime, Field(alias="updatedAt")] | None = None
|
||||
"""
|
||||
The date and time the webset was last updated
|
||||
"""
|
||||
from .helpers import WebsetEnrichmentConfig, WebsetSearchConfig
|
||||
|
||||
|
||||
class ExaCreateWebsetBlock(Block):
|
||||
@@ -108,121 +20,40 @@ class ExaCreateWebsetBlock(Block):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
|
||||
# Search parameters (flattened)
|
||||
search_query: str = SchemaField(
|
||||
description="Your search query. Use this to describe what you are looking for. Any URL provided will be crawled and used as context for the search.",
|
||||
placeholder="Marketing agencies based in the US, that focus on consumer products",
|
||||
search: WebsetSearchConfig = SchemaField(
|
||||
description="Initial search configuration for the Webset"
|
||||
)
|
||||
search_count: Optional[int] = SchemaField(
|
||||
default=10,
|
||||
description="Number of items the search will attempt to find. The actual number of items found may be less than this number depending on the search complexity.",
|
||||
ge=1,
|
||||
le=1000,
|
||||
)
|
||||
search_entity_type: SearchEntityType = SchemaField(
|
||||
default=SearchEntityType.AUTO,
|
||||
description="Entity type: 'company', 'person', 'article', 'research_paper', or 'custom'. If not provided, we automatically detect the entity from the query.",
|
||||
advanced=True,
|
||||
)
|
||||
search_entity_description: Optional[str] = SchemaField(
|
||||
enrichments: Optional[list[WebsetEnrichmentConfig]] = SchemaField(
|
||||
default=None,
|
||||
description="Description for custom entity type (required when search_entity_type is 'custom')",
|
||||
description="Enrichments to apply to Webset items",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# Search criteria (flattened)
|
||||
search_criteria: list[str] = SchemaField(
|
||||
default_factory=list,
|
||||
description="List of criteria descriptions that every item will be evaluated against. If not provided, we automatically detect the criteria from the query.",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# Search exclude sources (flattened)
|
||||
search_exclude_sources: list[str] = SchemaField(
|
||||
default_factory=list,
|
||||
description="List of source IDs (imports or websets) to exclude from search results",
|
||||
advanced=True,
|
||||
)
|
||||
search_exclude_types: list[SearchType] = SchemaField(
|
||||
default_factory=list,
|
||||
description="List of source types corresponding to exclude sources ('import' or 'webset')",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# Search scope sources (flattened)
|
||||
search_scope_sources: list[str] = SchemaField(
|
||||
default_factory=list,
|
||||
description="List of source IDs (imports or websets) to limit search scope to",
|
||||
advanced=True,
|
||||
)
|
||||
search_scope_types: list[SearchType] = SchemaField(
|
||||
default_factory=list,
|
||||
description="List of source types corresponding to scope sources ('import' or 'webset')",
|
||||
            advanced=True,
        )
        search_scope_relationships: list[str] = SchemaField(
            default_factory=list,
            description="List of relationship definitions for hop searches (optional, one per scope source)",
            advanced=True,
        )
        search_scope_relationship_limits: list[int] = SchemaField(
            default_factory=list,
            description="List of limits on the number of related entities to find (optional, one per scope relationship)",
            advanced=True,
        )

        # Import parameters (flattened)
        import_sources: list[str] = SchemaField(
            default_factory=list,
            description="List of source IDs to import from",
            advanced=True,
        )
        import_types: list[SearchType] = SchemaField(
            default_factory=list,
            description="List of source types corresponding to import sources ('import' or 'webset')",
            advanced=True,
        )

        # Enrichment parameters (flattened)
        enrichment_descriptions: list[str] = SchemaField(
            default_factory=list,
            description="List of enrichment task descriptions to perform on each webset item",
            advanced=True,
        )
        enrichment_formats: list[EnrichmentFormat] = SchemaField(
            default_factory=list,
            description="List of formats for enrichment responses ('text', 'date', 'number', 'options', 'email', 'phone'). If not specified, we automatically select the best format.",
            advanced=True,
        )
        enrichment_options: list[list[str]] = SchemaField(
            default_factory=list,
            description="List of option lists for enrichments with 'options' format. Each inner list contains the option labels.",
            advanced=True,
        )
        enrichment_metadata: list[dict] = SchemaField(
            default_factory=list,
            description="List of metadata dictionaries for enrichments",
            advanced=True,
        )

        # Webset metadata
        external_id: Optional[str] = SchemaField(
            default=None,
            description="External identifier for the webset. You can use this to reference the webset by your own internal identifiers.",
            description="External identifier for the webset",
            placeholder="my-webset-123",
            advanced=True,
        )
        metadata: Optional[dict] = SchemaField(
            default_factory=dict,
            default=None,
            description="Key-value pairs to associate with this webset",
            advanced=True,
        )

    class Output(BlockSchema):
        webset: Webset = SchemaField(
        webset_id: str = SchemaField(
            description="The unique identifier for the created webset"
        )
        status: str = SchemaField(description="The status of the webset")
        external_id: Optional[str] = SchemaField(
            description="The external identifier for the webset", default=None
        )
        created_at: str = SchemaField(
            description="The date and time the webset was created"
        )
        error: str = SchemaField(
            description="Error message if the request failed", default=""
        )

    def __init__(self):
        super().__init__(
@@ -236,171 +67,44 @@ class ExaCreateWebsetBlock(Block):
    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        url = "https://api.exa.ai/websets/v0/websets"
        headers = {
            "Content-Type": "application/json",
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        exa = Exa(credentials.api_key.get_secret_value())
        # Build the payload
        payload: dict[str, Any] = {
            "search": input_data.search.model_dump(exclude_none=True),
        }

        # ------------------------------------------------------------
        # Build entity (if explicitly provided)
        # ------------------------------------------------------------
        entity = None
        if input_data.search_entity_type == SearchEntityType.COMPANY:
            entity = WebsetCompanyEntity(type="company")
        elif input_data.search_entity_type == SearchEntityType.PERSON:
            entity = WebsetPersonEntity(type="person")
        elif input_data.search_entity_type == SearchEntityType.ARTICLE:
            entity = WebsetArticleEntity(type="article")
        elif input_data.search_entity_type == SearchEntityType.RESEARCH_PAPER:
            entity = WebsetResearchPaperEntity(type="research_paper")
        elif (
            input_data.search_entity_type == SearchEntityType.CUSTOM
            and input_data.search_entity_description
        ):
            entity = WebsetCustomEntity(
                type="custom", description=input_data.search_entity_description
            )
        # Convert enrichments to API format
        if input_data.enrichments:
            enrichments_data = []
            for enrichment in input_data.enrichments:
                enrichments_data.append(enrichment.model_dump(exclude_none=True))
            payload["enrichments"] = enrichments_data

        # ------------------------------------------------------------
        # Build criteria list
        # ------------------------------------------------------------
        criteria = None
        if input_data.search_criteria:
            criteria = [
                CreateCriterionParameters(description=item)
                for item in input_data.search_criteria
            ]
        if input_data.external_id:
            payload["externalId"] = input_data.external_id

        # ------------------------------------------------------------
        # Build exclude sources list
        # ------------------------------------------------------------
        exclude_items = None
        if input_data.search_exclude_sources:
            exclude_items = []
            for idx, src_id in enumerate(input_data.search_exclude_sources):
                src_type = None
                if input_data.search_exclude_types and idx < len(
                    input_data.search_exclude_types
                ):
                    src_type = input_data.search_exclude_types[idx]
                # Default to IMPORT if type missing
                if src_type == SearchType.WEBSET:
                    source_enum = ImportSource.webset
                else:
                    source_enum = ImportSource.import_
                exclude_items.append(ExcludeItem(source=source_enum, id=src_id))
        if input_data.metadata:
            payload["metadata"] = input_data.metadata

        # ------------------------------------------------------------
        # Build scope list
        # ------------------------------------------------------------
        scope_items = None
        if input_data.search_scope_sources:
            scope_items = []
            for idx, src_id in enumerate(input_data.search_scope_sources):
                src_type = None
                if input_data.search_scope_types and idx < len(
                    input_data.search_scope_types
                ):
                    src_type = input_data.search_scope_types[idx]
                relationship = None
                if input_data.search_scope_relationships and idx < len(
                    input_data.search_scope_relationships
                ):
                    rel_def = input_data.search_scope_relationships[idx]
                    lim = None
                    if input_data.search_scope_relationship_limits and idx < len(
                        input_data.search_scope_relationship_limits
                    ):
                        lim = input_data.search_scope_relationship_limits[idx]
                    relationship = ScopeRelationship(definition=rel_def, limit=lim)
                if src_type == SearchType.WEBSET:
                    src_enum = ScopeSourceType.webset
                else:
                    src_enum = ScopeSourceType.import_
                scope_items.append(
                    ScopeItem(source=src_enum, id=src_id, relationship=relationship)
                )
        try:
            response = await Requests().post(url, headers=headers, json=payload)
            data = response.json()

        # ------------------------------------------------------------
        # Assemble search parameters (only if a query is provided)
        # ------------------------------------------------------------
        search_params = None
        if input_data.search_query:
            search_params = CreateWebsetParametersSearch(
                query=input_data.search_query,
                count=input_data.search_count,
                entity=entity,
                criteria=criteria,
                exclude=exclude_items,
                scope=scope_items,
            )
            yield "webset_id", data.get("id", "")
            yield "status", data.get("status", "")
            yield "external_id", data.get("externalId")
            yield "created_at", data.get("createdAt", "")

        # ------------------------------------------------------------
        # Build imports list
        # ------------------------------------------------------------
        imports_params = None
        if input_data.import_sources:
            imports_params = []
            for idx, src_id in enumerate(input_data.import_sources):
                src_type = None
                if input_data.import_types and idx < len(input_data.import_types):
                    src_type = input_data.import_types[idx]
                if src_type == SearchType.WEBSET:
                    source_enum = ImportSource.webset
                else:
                    source_enum = ImportSource.import_
                imports_params.append(ImportItem(source=source_enum, id=src_id))

        # ------------------------------------------------------------
        # Build enrichment list
        # ------------------------------------------------------------
        enrichments_params = None
        if input_data.enrichment_descriptions:
            enrichments_params = []
            for idx, desc in enumerate(input_data.enrichment_descriptions):
                fmt = None
                if input_data.enrichment_formats and idx < len(
                    input_data.enrichment_formats
                ):
                    fmt_enum = input_data.enrichment_formats[idx]
                    if fmt_enum is not None:
                        fmt = Format(
                            fmt_enum.value if isinstance(fmt_enum, Enum) else fmt_enum
                        )
                options_list = None
                if input_data.enrichment_options and idx < len(
                    input_data.enrichment_options
                ):
                    raw_opts = input_data.enrichment_options[idx]
                    if raw_opts:
                        options_list = [Option(label=o) for o in raw_opts]
                metadata_obj = None
                if input_data.enrichment_metadata and idx < len(
                    input_data.enrichment_metadata
                ):
                    metadata_obj = input_data.enrichment_metadata[idx]
                enrichments_params.append(
                    CreateEnrichmentParameters(
                        description=desc,
                        format=fmt,
                        options=options_list,
                        metadata=metadata_obj,
                    )
                )

        # ------------------------------------------------------------
        # Create the webset
        # ------------------------------------------------------------
        webset = exa.websets.create(
            params=CreateWebsetParameters(
                search=search_params,
                imports=imports_params,
                enrichments=enrichments_params,
                external_id=input_data.external_id,
                metadata=input_data.metadata,
            )
        )

        # Use alias field names returned from Exa SDK so that nested models validate correctly
        yield "webset", Webset.model_validate(webset.model_dump(by_alias=True))
        except Exception as e:
            yield "error", str(e)
            yield "webset_id", ""
            yield "status", ""
            yield "created_at", ""


class ExaUpdateWebsetBlock(Block):
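The `run` method above repeatedly pairs a primary list with optional parallel "flattened" lists by index, guarding each lookup with a length check. The same pairing can be expressed once with `itertools.zip_longest`; a minimal sketch (the helper name `pair_parallel` is illustrative, not part of the block):

```python
from itertools import zip_longest


def pair_parallel(primary, *parallels):
    """Yield (item, *extras) tuples; missing parallel entries are padded with None."""
    for row in zip_longest(primary, *parallels):
        item, *extras = row
        if item is None:  # stop once the primary list is exhausted
            break
        yield (item, *extras)


sources = ["src-a", "src-b", "src-c"]
types = ["webset"]  # shorter than sources; missing entries become None
pairs = list(pair_parallel(sources, types))
# → [("src-a", "webset"), ("src-b", None), ("src-c", None)]
```

This mirrors the "optional, one per scope source" convention the schema fields describe: absent entries simply fall back to the default branch.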
@@ -479,11 +183,6 @@ class ExaListWebsetsBlock(Block):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        trigger: Any | None = SchemaField(
            default=None,
            description="Trigger for the webset, value is ignored!",
            advanced=False,
        )
        cursor: Optional[str] = SchemaField(
            default=None,
            description="Cursor for pagination through results",
@@ -498,9 +197,7 @@ class ExaListWebsetsBlock(Block):
        )

    class Output(BlockSchema):
        websets: list[Webset] = SchemaField(
            description="List of websets", default_factory=list
        )
        websets: list = SchemaField(description="List of websets", default_factory=list)
        has_more: bool = SchemaField(
            description="Whether there are more results to paginate through",
            default=False,
@@ -558,6 +255,9 @@ class ExaGetWebsetBlock(Block):
            description="The ID or external ID of the Webset to retrieve",
            placeholder="webset-id-or-external-id",
        )
        expand_items: bool = SchemaField(
            default=False, description="Include items in the response", advanced=True
        )

    class Output(BlockSchema):
        webset_id: str = SchemaField(description="The unique identifier for the webset")
@@ -609,8 +309,12 @@ class ExaGetWebsetBlock(Block):
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        params = {}
        if input_data.expand_items:
            params["expand[]"] = "items"

        try:
            response = await Requests().get(url, headers=headers)
            response = await Requests().get(url, headers=headers, params=params)
            data = response.json()

            yield "webset_id", data.get("id", "")
5  autogpt_platform/backend/backend/blocks/firecrawl/extract.py  Executable file → Normal file
@@ -29,8 +29,8 @@ class FirecrawlExtractBlock(Block):
        prompt: str | None = SchemaField(
            description="The prompt to use for the crawl", default=None, advanced=False
        )
        output_schema: dict | None = SchemaField(
            description="A Json Schema describing the output structure if more rigid structure is desired.",
        output_schema: str | None = SchemaField(
            description="A more rigid structure if you already know the JSON layout.",
            default=None,
        )
        enable_web_search: bool = SchemaField(
@@ -56,6 +56,7 @@ class FirecrawlExtractBlock(Block):

        app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())

        # Sync call
        extract_result = app.extract(
            urls=input_data.urls,
            prompt=input_data.prompt,
@@ -1,388 +0,0 @@
import logging
import re
from enum import Enum
from typing import Optional

from typing_extensions import TypedDict

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField

from ._api import get_api
from ._auth import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    GithubCredentials,
    GithubCredentialsField,
    GithubCredentialsInput,
)

logger = logging.getLogger(__name__)


class CheckRunStatus(Enum):
    QUEUED = "queued"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"


class CheckRunConclusion(Enum):
    SUCCESS = "success"
    FAILURE = "failure"
    NEUTRAL = "neutral"
    CANCELLED = "cancelled"
    SKIPPED = "skipped"
    TIMED_OUT = "timed_out"
    ACTION_REQUIRED = "action_required"


class GithubGetCIResultsBlock(Block):
    class Input(BlockSchema):
        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
        repo: str = SchemaField(
            description="GitHub repository",
            placeholder="owner/repo",
        )
        target: str | int = SchemaField(
            description="Commit SHA or PR number to get CI results for",
            placeholder="abc123def or 123",
        )
        search_pattern: Optional[str] = SchemaField(
            description="Optional regex pattern to search for in CI logs (e.g., error messages, file names)",
            placeholder=".*error.*|.*warning.*",
            default=None,
            advanced=True,
        )
        check_name_filter: Optional[str] = SchemaField(
            description="Optional filter for specific check names (supports wildcards)",
            placeholder="*lint* or build-*",
            default=None,
            advanced=True,
        )

    class Output(BlockSchema):
        class CheckRunItem(TypedDict, total=False):
            id: int
            name: str
            status: str
            conclusion: Optional[str]
            started_at: Optional[str]
            completed_at: Optional[str]
            html_url: str
            details_url: Optional[str]
            output_title: Optional[str]
            output_summary: Optional[str]
            output_text: Optional[str]
            annotations: list[dict]

        class MatchedLine(TypedDict):
            check_name: str
            line_number: int
            line: str
            context: list[str]

        check_run: CheckRunItem = SchemaField(
            title="Check Run",
            description="Individual CI check run with details",
        )
        check_runs: list[CheckRunItem] = SchemaField(
            description="List of all CI check runs"
        )
        matched_line: MatchedLine = SchemaField(
            title="Matched Line",
            description="Line matching the search pattern with context",
        )
        matched_lines: list[MatchedLine] = SchemaField(
            description="All lines matching the search pattern across all checks"
        )
        overall_status: str = SchemaField(
            description="Overall CI status (pending, success, failure)"
        )
        overall_conclusion: str = SchemaField(
            description="Overall CI conclusion if completed"
        )
        total_checks: int = SchemaField(description="Total number of CI checks")
        passed_checks: int = SchemaField(description="Number of passed checks")
        failed_checks: int = SchemaField(description="Number of failed checks")
        error: str = SchemaField(description="Error message if the operation failed")

    def __init__(self):
        super().__init__(
            id="8ad9e103-78f2-4fdb-ba12-3571f2c95e98",
            description="This block gets CI results for a commit or PR, with optional search for specific errors/warnings in logs.",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=GithubGetCIResultsBlock.Input,
            output_schema=GithubGetCIResultsBlock.Output,
            test_input={
                "repo": "owner/repo",
                "target": "abc123def456",
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("overall_status", "completed"),
                ("overall_conclusion", "success"),
                ("total_checks", 1),
                ("passed_checks", 1),
                ("failed_checks", 0),
                (
                    "check_runs",
                    [
                        {
                            "id": 123456,
                            "name": "build",
                            "status": "completed",
                            "conclusion": "success",
                            "started_at": "2024-01-01T00:00:00Z",
                            "completed_at": "2024-01-01T00:05:00Z",
                            "html_url": "https://github.com/owner/repo/runs/123456",
                            "details_url": None,
                            "output_title": "Build passed",
                            "output_summary": "All tests passed",
                            "output_text": "Build log output...",
                            "annotations": [],
                        }
                    ],
                ),
            ],
            test_mock={
                "get_ci_results": lambda *args, **kwargs: {
                    "check_runs": [
                        {
                            "id": 123456,
                            "name": "build",
                            "status": "completed",
                            "conclusion": "success",
                            "started_at": "2024-01-01T00:00:00Z",
                            "completed_at": "2024-01-01T00:05:00Z",
                            "html_url": "https://github.com/owner/repo/runs/123456",
                            "details_url": None,
                            "output_title": "Build passed",
                            "output_summary": "All tests passed",
                            "output_text": "Build log output...",
                            "annotations": [],
                        }
                    ],
                    "total_count": 1,
                }
            },
        )

    @staticmethod
    async def get_commit_sha(api, repo: str, target: str | int) -> str:
        """Get the commit SHA from either a commit SHA or a PR number."""
        # If it's already a SHA, return it
        if isinstance(target, str):
            if re.match(r"^[0-9a-f]{6,40}$", target, re.IGNORECASE):
                return target

        # If it's a PR number, get the head SHA
        if isinstance(target, int):
            pr_url = f"https://api.github.com/repos/{repo}/pulls/{target}"
            response = await api.get(pr_url)
            pr_data = response.json()
            return pr_data["head"]["sha"]

        raise ValueError("Target must be a commit SHA or PR number")

    @staticmethod
    async def search_in_logs(
        check_runs: list,
        pattern: str,
    ) -> list[Output.MatchedLine]:
        """Search for pattern in check run logs."""
        if not pattern:
            return []

        matched_lines = []
        regex = re.compile(pattern, re.IGNORECASE | re.MULTILINE)

        for check in check_runs:
            output_text = check.get("output_text", "") or ""
            if not output_text:
                continue

            lines = output_text.split("\n")
            for i, line in enumerate(lines):
                if regex.search(line):
                    # Get context (2 lines before and after)
                    start = max(0, i - 2)
                    end = min(len(lines), i + 3)
                    context = lines[start:end]

                    matched_lines.append(
                        {
                            "check_name": check["name"],
                            "line_number": i + 1,
                            "line": line,
                            "context": context,
                        }
                    )

        return matched_lines

    @staticmethod
    async def get_ci_results(
        credentials: GithubCredentials,
        repo: str,
        target: str | int,
        search_pattern: Optional[str] = None,
        check_name_filter: Optional[str] = None,
    ) -> dict:
        api = get_api(credentials, convert_urls=False)

        # Get the commit SHA
        commit_sha = await GithubGetCIResultsBlock.get_commit_sha(api, repo, target)

        # Get check runs for the commit
        check_runs_url = (
            f"https://api.github.com/repos/{repo}/commits/{commit_sha}/check-runs"
        )

        # Get all pages of check runs
        all_check_runs = []
        page = 1
        per_page = 100

        while True:
            response = await api.get(
                check_runs_url, params={"per_page": per_page, "page": page}
            )
            data = response.json()

            check_runs = data.get("check_runs", [])
            all_check_runs.extend(check_runs)

            if len(check_runs) < per_page:
                break
            page += 1

        # Filter by check name if specified
        if check_name_filter:
            import fnmatch

            filtered_runs = []
            for run in all_check_runs:
                if fnmatch.fnmatch(run["name"].lower(), check_name_filter.lower()):
                    filtered_runs.append(run)
            all_check_runs = filtered_runs

        # Get check run details with logs
        detailed_runs = []
        for run in all_check_runs:
            # Get detailed output including logs
            if run.get("output", {}).get("text"):
                # Already has output
                detailed_run = {
                    "id": run["id"],
                    "name": run["name"],
                    "status": run["status"],
                    "conclusion": run.get("conclusion"),
                    "started_at": run.get("started_at"),
                    "completed_at": run.get("completed_at"),
                    "html_url": run["html_url"],
                    "details_url": run.get("details_url"),
                    "output_title": run.get("output", {}).get("title"),
                    "output_summary": run.get("output", {}).get("summary"),
                    "output_text": run.get("output", {}).get("text"),
                    "annotations": [],
                }
            else:
                # Try to get logs from the check run
                detailed_run = {
                    "id": run["id"],
                    "name": run["name"],
                    "status": run["status"],
                    "conclusion": run.get("conclusion"),
                    "started_at": run.get("started_at"),
                    "completed_at": run.get("completed_at"),
                    "html_url": run["html_url"],
                    "details_url": run.get("details_url"),
                    "output_title": run.get("output", {}).get("title"),
                    "output_summary": run.get("output", {}).get("summary"),
                    "output_text": None,
                    "annotations": [],
                }

            # Get annotations if available
            if run.get("output", {}).get("annotations_count", 0) > 0:
                annotations_url = f"https://api.github.com/repos/{repo}/check-runs/{run['id']}/annotations"
                try:
                    ann_response = await api.get(annotations_url)
                    detailed_run["annotations"] = ann_response.json()
                except Exception:
                    pass

            detailed_runs.append(detailed_run)

        return {
            "check_runs": detailed_runs,
            "total_count": len(detailed_runs),
        }

    async def run(
        self,
        input_data: Input,
        *,
        credentials: GithubCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            target = int(input_data.target)
        except ValueError:
            target = input_data.target

        result = await self.get_ci_results(
            credentials,
            input_data.repo,
            target,
            input_data.search_pattern,
            input_data.check_name_filter,
        )

        check_runs = result["check_runs"]

        # Calculate overall status
        if not check_runs:
            yield "overall_status", "no_checks"
            yield "overall_conclusion", "no_checks"
        else:
            all_completed = all(run["status"] == "completed" for run in check_runs)
            if all_completed:
                yield "overall_status", "completed"
                # Determine overall conclusion
                has_failure = any(
                    run["conclusion"] in ["failure", "timed_out", "action_required"]
                    for run in check_runs
                )
                if has_failure:
                    yield "overall_conclusion", "failure"
                else:
                    yield "overall_conclusion", "success"
            else:
                yield "overall_status", "pending"
                yield "overall_conclusion", "pending"

        # Count checks
        total = len(check_runs)
        passed = sum(1 for run in check_runs if run.get("conclusion") == "success")
        failed = sum(
            1 for run in check_runs if run.get("conclusion") in ["failure", "timed_out"]
        )

        yield "total_checks", total
        yield "passed_checks", passed
        yield "failed_checks", failed

        # Output check runs
        yield "check_runs", check_runs

        # Search for patterns if specified
        if input_data.search_pattern:
            matched_lines = await self.search_in_logs(
                check_runs, input_data.search_pattern
            )
            if matched_lines:
                yield "matched_lines", matched_lines
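The context-window extraction that `search_in_logs` performs (grep with two lines of surrounding context) can be exercised standalone. A minimal sketch; the helper name `grep_with_context` is illustrative, not part of the block:

```python
import re


def grep_with_context(text: str, pattern: str, before: int = 2, after: int = 2):
    """Return (line_number, line, context) tuples for lines matching pattern."""
    regex = re.compile(pattern, re.IGNORECASE)
    lines = text.split("\n")
    hits = []
    for i, line in enumerate(lines):
        if regex.search(line):
            # slice is clamped to the start/end of the log, like search_in_logs
            context = lines[max(0, i - before) : min(len(lines), i + after + 1)]
            hits.append((i + 1, line, context))
    return hits


log = "step 1 ok\nstep 2 ok\nERROR: build failed\nstep 4 skipped"
hits = grep_with_context(log, r"error")
# → one match, on line 3, with the surrounding lines as context
```

Line numbers are 1-based to match how CI log viewers display them.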
@@ -1,840 +0,0 @@
import logging
from enum import Enum
from typing import Any, List, Optional

from typing_extensions import TypedDict

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField

from ._api import get_api
from ._auth import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    GithubCredentials,
    GithubCredentialsField,
    GithubCredentialsInput,
)

logger = logging.getLogger(__name__)


class ReviewEvent(Enum):
    COMMENT = "COMMENT"
    APPROVE = "APPROVE"
    REQUEST_CHANGES = "REQUEST_CHANGES"


class GithubCreatePRReviewBlock(Block):
    class Input(BlockSchema):
        class ReviewComment(TypedDict, total=False):
            path: str
            position: Optional[int]
            body: str
            line: Optional[int]  # Will be used as position if position not provided

        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
        repo: str = SchemaField(
            description="GitHub repository",
            placeholder="owner/repo",
        )
        pr_number: int = SchemaField(
            description="Pull request number",
            placeholder="123",
        )
        body: str = SchemaField(
            description="Body of the review comment",
            placeholder="Enter your review comment",
        )
        event: ReviewEvent = SchemaField(
            description="The review action to perform",
            default=ReviewEvent.COMMENT,
        )
        create_as_draft: bool = SchemaField(
            description="Create the review as a draft (pending) or post it immediately",
            default=False,
            advanced=False,
        )
        comments: Optional[List[ReviewComment]] = SchemaField(
            description="Optional inline comments to add to specific files/lines. Note: Only path, body, and position are supported. Position is the line number in the diff, counted from the first @@ hunk header.",
            default=None,
            advanced=True,
        )

    class Output(BlockSchema):
        review_id: int = SchemaField(description="ID of the created review")
        state: str = SchemaField(
            description="State of the review (e.g., PENDING, COMMENTED, APPROVED, CHANGES_REQUESTED)"
        )
        html_url: str = SchemaField(description="URL of the created review")
        error: str = SchemaField(
            description="Error message if the review creation failed"
        )

    def __init__(self):
        super().__init__(
            id="84754b30-97d2-4c37-a3b8-eb39f268275b",
            description="This block creates a review on a GitHub pull request with optional inline comments. You can create it as a draft or post it immediately. Note: For inline comments, 'position' should be the line number in the diff (starting from the first @@ hunk header).",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=GithubCreatePRReviewBlock.Input,
            output_schema=GithubCreatePRReviewBlock.Output,
            test_input={
                "repo": "owner/repo",
                "pr_number": 1,
                "body": "This looks good to me!",
                "event": "APPROVE",
                "create_as_draft": False,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("review_id", 123456),
                ("state", "APPROVED"),
                (
                    "html_url",
                    "https://github.com/owner/repo/pull/1#pullrequestreview-123456",
                ),
            ],
            test_mock={
                "create_review": lambda *args, **kwargs: (
                    123456,
                    "APPROVED",
                    "https://github.com/owner/repo/pull/1#pullrequestreview-123456",
                )
            },
        )

    @staticmethod
    async def create_review(
        credentials: GithubCredentials,
        repo: str,
        pr_number: int,
        body: str,
        event: ReviewEvent,
        create_as_draft: bool,
        comments: Optional[List[Input.ReviewComment]] = None,
    ) -> tuple[int, str, str]:
        api = get_api(credentials, convert_urls=False)

        # GitHub API endpoint for creating reviews
        reviews_url = f"https://api.github.com/repos/{repo}/pulls/{pr_number}/reviews"

        # Get commit_id if we have comments
        commit_id = None
        if comments:
            # Get PR details to get the head commit for inline comments
            pr_url = f"https://api.github.com/repos/{repo}/pulls/{pr_number}"
            pr_response = await api.get(pr_url)
            pr_data = pr_response.json()
            commit_id = pr_data["head"]["sha"]

        # Prepare the request data
        # If create_as_draft is True, omit the event field (creates a PENDING review)
        # Otherwise, use the actual event value which will auto-submit the review
        data: dict[str, Any] = {"body": body}

        # Add commit_id if we have it
        if commit_id:
            data["commit_id"] = commit_id

        # Add comments if provided
        if comments:
            # Process comments to ensure they have the required fields
            processed_comments = []
            for comment in comments:
                comment_data: dict = {
                    "path": comment.get("path", ""),
                    "body": comment.get("body", ""),
                }
                # Add position or line
                # Note: For review comments, only position is supported (not line/side)
                if "position" in comment and comment.get("position") is not None:
                    comment_data["position"] = comment.get("position")
                elif "line" in comment and comment.get("line") is not None:
                    # Note: Using line as position - may not work correctly
                    # Position should be calculated from the diff
                    comment_data["position"] = comment.get("line")

                # Note: side, start_line, and start_side are NOT supported for review comments
                # They are only for standalone PR comments

                processed_comments.append(comment_data)

            data["comments"] = processed_comments

        if not create_as_draft:
            # Only add event field if not creating a draft
            data["event"] = event.value

        # Create the review
        response = await api.post(reviews_url, json=data)
        review_data = response.json()

        return review_data["id"], review_data["state"], review_data["html_url"]
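The "position is the line number in the diff, counted from the first @@ hunk header" convention that `create_review` relies on can be made concrete with a worked example. A simplified sketch for the single-hunk case (the helper `position_of` is hypothetical; GitHub's full rules also count subsequent @@ lines in multi-hunk diffs):

```python
# Position 1 is the first line after the first @@ hunk header;
# context (" "), deleted ("-"), and added ("+") lines all count.
diff = """@@ -1,3 +1,3 @@
 unchanged line
-old line
+new line"""


def position_of(diff_text: str, needle: str) -> int:
    """Return the diff position of the first line containing needle."""
    pos = 0
    counting = False
    for line in diff_text.splitlines():
        if line.startswith("@@"):
            counting = True  # simplified: assumes a single hunk
            continue
        if counting:
            pos += 1
            if needle in line:
                return pos
    raise ValueError(f"{needle!r} not found in diff")


position_of(diff, "new line")  # → 3
```

So an inline comment on the added line would be sent as `{"path": ..., "body": ..., "position": 3}`, which is why passing a file line number straight through (the `line`-as-`position` fallback above) may land the comment on the wrong row.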
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
review_id, state, html_url = await self.create_review(
|
||||
credentials,
|
||||
input_data.repo,
|
||||
input_data.pr_number,
|
||||
input_data.body,
|
||||
input_data.event,
|
||||
input_data.create_as_draft,
|
||||
input_data.comments,
|
||||
)
|
||||
yield "review_id", review_id
|
||||
yield "state", state
|
||||
yield "html_url", html_url
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubListPRReviewsBlock(Block):
    class Input(BlockSchema):
        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
        repo: str = SchemaField(
            description="GitHub repository",
            placeholder="owner/repo",
        )
        pr_number: int = SchemaField(
            description="Pull request number",
            placeholder="123",
        )

    class Output(BlockSchema):
        class ReviewItem(TypedDict):
            id: int
            user: str
            state: str
            body: str
            html_url: str

        review: ReviewItem = SchemaField(
            title="Review",
            description="Individual review with details",
        )
        reviews: list[ReviewItem] = SchemaField(
            description="List of all reviews on the pull request"
        )
        error: str = SchemaField(description="Error message if listing reviews failed")

    def __init__(self):
        super().__init__(
            id="f79bc6eb-33c0-4099-9c0f-d664ae1ba4d0",
            description="This block lists all reviews for a specified GitHub pull request.",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=GithubListPRReviewsBlock.Input,
            output_schema=GithubListPRReviewsBlock.Output,
            test_input={
                "repo": "owner/repo",
                "pr_number": 1,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                (
                    "reviews",
                    [
                        {
                            "id": 123456,
                            "user": "reviewer1",
                            "state": "APPROVED",
                            "body": "Looks good!",
                            "html_url": "https://github.com/owner/repo/pull/1#pullrequestreview-123456",
                        }
                    ],
                ),
                (
                    "review",
                    {
                        "id": 123456,
                        "user": "reviewer1",
                        "state": "APPROVED",
                        "body": "Looks good!",
                        "html_url": "https://github.com/owner/repo/pull/1#pullrequestreview-123456",
                    },
                ),
            ],
            test_mock={
                "list_reviews": lambda *args, **kwargs: [
                    {
                        "id": 123456,
                        "user": "reviewer1",
                        "state": "APPROVED",
                        "body": "Looks good!",
                        "html_url": "https://github.com/owner/repo/pull/1#pullrequestreview-123456",
                    }
                ]
            },
        )

    @staticmethod
    async def list_reviews(
        credentials: GithubCredentials, repo: str, pr_number: int
    ) -> list[Output.ReviewItem]:
        api = get_api(credentials, convert_urls=False)

        # GitHub API endpoint for listing reviews
        reviews_url = f"https://api.github.com/repos/{repo}/pulls/{pr_number}/reviews"

        response = await api.get(reviews_url)
        data = response.json()

        reviews: list[GithubListPRReviewsBlock.Output.ReviewItem] = [
            {
                "id": review["id"],
                "user": review["user"]["login"],
                "state": review["state"],
                "body": review.get("body", ""),
                "html_url": review["html_url"],
            }
            for review in data
        ]
        return reviews

    async def run(
        self,
        input_data: Input,
        *,
        credentials: GithubCredentials,
        **kwargs,
    ) -> BlockOutput:
        reviews = await self.list_reviews(
            credentials,
            input_data.repo,
            input_data.pr_number,
        )
        yield "reviews", reviews
        for review in reviews:
            yield "review", review


class GithubSubmitPendingReviewBlock(Block):
    class Input(BlockSchema):
        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
        repo: str = SchemaField(
            description="GitHub repository",
            placeholder="owner/repo",
        )
        pr_number: int = SchemaField(
            description="Pull request number",
            placeholder="123",
        )
        review_id: int = SchemaField(
            description="ID of the pending review to submit",
            placeholder="123456",
        )
        event: ReviewEvent = SchemaField(
            description="The review action to perform when submitting",
            default=ReviewEvent.COMMENT,
        )

    class Output(BlockSchema):
        state: str = SchemaField(description="State of the submitted review")
        html_url: str = SchemaField(description="URL of the submitted review")
        error: str = SchemaField(
            description="Error message if the review submission failed"
        )

    def __init__(self):
        super().__init__(
            id="2e468217-7ca0-4201-9553-36e93eb9357a",
            description="This block submits a pending (draft) review on a GitHub pull request.",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=GithubSubmitPendingReviewBlock.Input,
            output_schema=GithubSubmitPendingReviewBlock.Output,
            test_input={
                "repo": "owner/repo",
                "pr_number": 1,
                "review_id": 123456,
                "event": "APPROVE",
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("state", "APPROVED"),
                (
                    "html_url",
                    "https://github.com/owner/repo/pull/1#pullrequestreview-123456",
                ),
            ],
            test_mock={
                "submit_review": lambda *args, **kwargs: (
                    "APPROVED",
                    "https://github.com/owner/repo/pull/1#pullrequestreview-123456",
                )
            },
        )

    @staticmethod
    async def submit_review(
        credentials: GithubCredentials,
        repo: str,
        pr_number: int,
        review_id: int,
        event: ReviewEvent,
    ) -> tuple[str, str]:
        api = get_api(credentials, convert_urls=False)

        # GitHub API endpoint for submitting a review
        submit_url = f"https://api.github.com/repos/{repo}/pulls/{pr_number}/reviews/{review_id}/events"

        data = {"event": event.value}

        response = await api.post(submit_url, json=data)
        review_data = response.json()

        return review_data["state"], review_data["html_url"]

    async def run(
        self,
        input_data: Input,
        *,
        credentials: GithubCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            state, html_url = await self.submit_review(
                credentials,
                input_data.repo,
                input_data.pr_number,
                input_data.review_id,
                input_data.event,
            )
            yield "state", state
            yield "html_url", html_url
        except Exception as e:
            yield "error", str(e)


class GithubResolveReviewDiscussionBlock(Block):
    class Input(BlockSchema):
        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
        repo: str = SchemaField(
            description="GitHub repository",
            placeholder="owner/repo",
        )
        pr_number: int = SchemaField(
            description="Pull request number",
            placeholder="123",
        )
        comment_id: int = SchemaField(
            description="ID of the review comment to resolve/unresolve",
            placeholder="123456",
        )
        resolve: bool = SchemaField(
            description="Whether to resolve (true) or unresolve (false) the discussion",
            default=True,
        )

    class Output(BlockSchema):
        success: bool = SchemaField(description="Whether the operation was successful")
        error: str = SchemaField(description="Error message if the operation failed")

    def __init__(self):
        super().__init__(
            id="b4b8a38c-95ae-4c91-9ef8-c2cffaf2b5d1",
            description="This block resolves or unresolves a review discussion thread on a GitHub pull request.",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=GithubResolveReviewDiscussionBlock.Input,
            output_schema=GithubResolveReviewDiscussionBlock.Output,
            test_input={
                "repo": "owner/repo",
                "pr_number": 1,
                "comment_id": 123456,
                "resolve": True,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("success", True),
            ],
            test_mock={"resolve_discussion": lambda *args, **kwargs: True},
        )

    @staticmethod
    async def resolve_discussion(
        credentials: GithubCredentials,
        repo: str,
        pr_number: int,
        comment_id: int,
        resolve: bool,
    ) -> bool:
        api = get_api(credentials, convert_urls=False)

        # Extract owner and repo name
        parts = repo.split("/")
        owner = parts[0]
        repo_name = parts[1]

        # GitHub GraphQL API is needed for resolving/unresolving discussions
        # First, we need to get the node ID of the comment
        graphql_url = "https://api.github.com/graphql"

        # Query to get the review comment node ID
        query = """
        query($owner: String!, $repo: String!, $number: Int!) {
            repository(owner: $owner, name: $repo) {
                pullRequest(number: $number) {
                    reviewThreads(first: 100) {
                        nodes {
                            comments(first: 100) {
                                nodes {
                                    databaseId
                                    id
                                }
                            }
                            id
                            isResolved
                        }
                    }
                }
            }
        }
        """

        variables = {"owner": owner, "repo": repo_name, "number": pr_number}

        response = await api.post(
            graphql_url, json={"query": query, "variables": variables}
        )
        data = response.json()

        # Find the thread containing our comment
        thread_id = None
        for thread in data["data"]["repository"]["pullRequest"]["reviewThreads"][
            "nodes"
        ]:
            for comment in thread["comments"]["nodes"]:
                if comment["databaseId"] == comment_id:
                    thread_id = thread["id"]
                    break
            if thread_id:
                break

        if not thread_id:
            raise ValueError(f"Comment {comment_id} not found in pull request")

        # Now resolve or unresolve the thread
        # GitHub's GraphQL API has separate mutations for resolve and unresolve
        if resolve:
            mutation = """
            mutation($threadId: ID!) {
                resolveReviewThread(input: {threadId: $threadId}) {
                    thread {
                        isResolved
                    }
                }
            }
            """
        else:
            mutation = """
            mutation($threadId: ID!) {
                unresolveReviewThread(input: {threadId: $threadId}) {
                    thread {
                        isResolved
                    }
                }
            }
            """

        mutation_variables = {"threadId": thread_id}

        response = await api.post(
            graphql_url, json={"query": mutation, "variables": mutation_variables}
        )
        result = response.json()

        if "errors" in result:
            raise Exception(f"GraphQL error: {result['errors']}")

        return True

    async def run(
        self,
        input_data: Input,
        *,
        credentials: GithubCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            success = await self.resolve_discussion(
                credentials,
                input_data.repo,
                input_data.pr_number,
                input_data.comment_id,
                input_data.resolve,
            )
            yield "success", success
        except Exception as e:
            yield "success", False
            yield "error", str(e)


class GithubGetPRReviewCommentsBlock(Block):
    class Input(BlockSchema):
        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
        repo: str = SchemaField(
            description="GitHub repository",
            placeholder="owner/repo",
        )
        pr_number: int = SchemaField(
            description="Pull request number",
            placeholder="123",
        )
        review_id: Optional[int] = SchemaField(
            description="ID of a specific review to get comments from (optional)",
            placeholder="123456",
            default=None,
            advanced=True,
        )

    class Output(BlockSchema):
        class CommentItem(TypedDict):
            id: int
            user: str
            body: str
            path: str
            line: int
            side: str
            created_at: str
            updated_at: str
            in_reply_to_id: Optional[int]
            html_url: str

        comment: CommentItem = SchemaField(
            title="Comment",
            description="Individual review comment with details",
        )
        comments: list[CommentItem] = SchemaField(
            description="List of all review comments on the pull request"
        )
        error: str = SchemaField(description="Error message if getting comments failed")

    def __init__(self):
        super().__init__(
            id="1d34db7f-10c1-45c1-9d43-749f743c8bd4",
            description="This block gets all review comments from a GitHub pull request or from a specific review.",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=GithubGetPRReviewCommentsBlock.Input,
            output_schema=GithubGetPRReviewCommentsBlock.Output,
            test_input={
                "repo": "owner/repo",
                "pr_number": 1,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                (
                    "comments",
                    [
                        {
                            "id": 123456,
                            "user": "reviewer1",
                            "body": "This needs improvement",
                            "path": "src/main.py",
                            "line": 42,
                            "side": "RIGHT",
                            "created_at": "2024-01-01T00:00:00Z",
                            "updated_at": "2024-01-01T00:00:00Z",
                            "in_reply_to_id": None,
                            "html_url": "https://github.com/owner/repo/pull/1#discussion_r123456",
                        }
                    ],
                ),
                (
                    "comment",
                    {
                        "id": 123456,
                        "user": "reviewer1",
                        "body": "This needs improvement",
                        "path": "src/main.py",
                        "line": 42,
                        "side": "RIGHT",
                        "created_at": "2024-01-01T00:00:00Z",
                        "updated_at": "2024-01-01T00:00:00Z",
                        "in_reply_to_id": None,
                        "html_url": "https://github.com/owner/repo/pull/1#discussion_r123456",
                    },
                ),
            ],
            test_mock={
                "get_comments": lambda *args, **kwargs: [
                    {
                        "id": 123456,
                        "user": "reviewer1",
                        "body": "This needs improvement",
                        "path": "src/main.py",
                        "line": 42,
                        "side": "RIGHT",
                        "created_at": "2024-01-01T00:00:00Z",
                        "updated_at": "2024-01-01T00:00:00Z",
                        "in_reply_to_id": None,
                        "html_url": "https://github.com/owner/repo/pull/1#discussion_r123456",
                    }
                ]
            },
        )

    @staticmethod
    async def get_comments(
        credentials: GithubCredentials,
        repo: str,
        pr_number: int,
        review_id: Optional[int] = None,
    ) -> list[Output.CommentItem]:
        api = get_api(credentials, convert_urls=False)

        # Determine the endpoint based on whether we want comments from a specific review
        if review_id:
            # Get comments from a specific review
            comments_url = f"https://api.github.com/repos/{repo}/pulls/{pr_number}/reviews/{review_id}/comments"
        else:
            # Get all review comments on the PR
            comments_url = (
                f"https://api.github.com/repos/{repo}/pulls/{pr_number}/comments"
            )

        response = await api.get(comments_url)
        data = response.json()

        comments: list[GithubGetPRReviewCommentsBlock.Output.CommentItem] = [
            {
                "id": comment["id"],
                "user": comment["user"]["login"],
                "body": comment["body"],
                "path": comment.get("path", ""),
                "line": comment.get("line", 0),
                "side": comment.get("side", ""),
                "created_at": comment["created_at"],
                "updated_at": comment["updated_at"],
                "in_reply_to_id": comment.get("in_reply_to_id"),
                "html_url": comment["html_url"],
            }
            for comment in data
        ]
        return comments

    async def run(
        self,
        input_data: Input,
        *,
        credentials: GithubCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            comments = await self.get_comments(
                credentials,
                input_data.repo,
                input_data.pr_number,
                input_data.review_id,
            )
            yield "comments", comments
            for comment in comments:
                yield "comment", comment
        except Exception as e:
            yield "error", str(e)


class GithubCreateCommentObjectBlock(Block):
    class Input(BlockSchema):
        path: str = SchemaField(
            description="The file path to comment on",
            placeholder="src/main.py",
        )
        body: str = SchemaField(
            description="The comment text",
            placeholder="Please fix this issue",
        )
        position: Optional[int] = SchemaField(
            description="Position in the diff (line number from first @@ hunk). Use this OR line.",
            placeholder="6",
            default=None,
            advanced=True,
        )
        line: Optional[int] = SchemaField(
            description="Line number in the file (will be used as position if position not provided)",
            placeholder="42",
            default=None,
            advanced=True,
        )
        side: Optional[str] = SchemaField(
            description="Side of the diff to comment on (NOTE: Only for standalone comments, not review comments)",
            default="RIGHT",
            advanced=True,
        )
        start_line: Optional[int] = SchemaField(
            description="Start line for multi-line comments (NOTE: Only for standalone comments, not review comments)",
            default=None,
            advanced=True,
        )
        start_side: Optional[str] = SchemaField(
            description="Side for the start of multi-line comments (NOTE: Only for standalone comments, not review comments)",
            default=None,
            advanced=True,
        )

    class Output(BlockSchema):
        comment_object: dict = SchemaField(
            description="The comment object formatted for GitHub API"
        )

    def __init__(self):
        super().__init__(
            id="b7d5e4f2-8c3a-4e6b-9f1d-7a8b9c5e4d3f",
            description="Creates a comment object for use with GitHub blocks. Note: For review comments, only path, body, and position are used. Side fields are only for standalone PR comments.",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=GithubCreateCommentObjectBlock.Input,
            output_schema=GithubCreateCommentObjectBlock.Output,
            test_input={
                "path": "src/main.py",
                "body": "Please fix this issue",
                "position": 6,
            },
            test_output=[
                (
                    "comment_object",
                    {
                        "path": "src/main.py",
                        "body": "Please fix this issue",
                        "position": 6,
                    },
                ),
            ],
        )

    async def run(
        self,
        input_data: Input,
        **kwargs,
    ) -> BlockOutput:
        # Build the comment object
        comment_obj: dict = {
            "path": input_data.path,
            "body": input_data.body,
        }

        # Add position or line
        if input_data.position is not None:
            comment_obj["position"] = input_data.position
        elif input_data.line is not None:
            # Note: line will be used as position, which may not be accurate
            # Position should be calculated from the diff
            comment_obj["position"] = input_data.line

        # Add optional fields only if they differ from defaults or are explicitly provided
        if input_data.side and input_data.side != "RIGHT":
            comment_obj["side"] = input_data.side
        if input_data.start_line is not None:
            comment_obj["start_line"] = input_data.start_line
        if input_data.start_side:
            comment_obj["start_side"] = input_data.start_side

        yield "comment_object", comment_obj
@@ -21,8 +21,6 @@ from ._auth import (
     GoogleCredentialsInput,
 )
 
-settings = Settings()
-
 
 class CalendarEvent(BaseModel):
     """Structured representation of a Google Calendar event."""
@@ -223,8 +221,8 @@ class GoogleCalendarReadEventsBlock(Block):
                 else None
             ),
             token_uri="https://oauth2.googleapis.com/token",
-            client_id=settings.secrets.google_client_id,
-            client_secret=settings.secrets.google_client_secret,
+            client_id=Settings().secrets.google_client_id,
+            client_secret=Settings().secrets.google_client_secret,
             scopes=credentials.scopes,
         )
         return build("calendar", "v3", credentials=creds)
@@ -571,8 +569,8 @@ class GoogleCalendarCreateEventBlock(Block):
                 else None
             ),
             token_uri="https://oauth2.googleapis.com/token",
-            client_id=settings.secrets.google_client_id,
-            client_secret=settings.secrets.google_client_secret,
+            client_id=Settings().secrets.google_client_id,
+            client_secret=Settings().secrets.google_client_secret,
             scopes=credentials.scopes,
         )
         return build("calendar", "v3", credentials=creds)

File diff suppressed because it is too large
@@ -37,7 +37,6 @@ LLMProviderName = Literal[
     ProviderName.OPENAI,
     ProviderName.OPEN_ROUTER,
     ProviderName.LLAMA_API,
-    ProviderName.V0,
 ]
 AICredentials = CredentialsMetaInput[LLMProviderName, Literal["api_key"]]
 
@@ -82,11 +81,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
     O3 = "o3-2025-04-16"
     O1 = "o1"
     O1_MINI = "o1-mini"
-    # GPT-5 models
-    GPT5 = "gpt-5-2025-08-07"
-    GPT5_MINI = "gpt-5-mini-2025-08-07"
-    GPT5_NANO = "gpt-5-nano-2025-08-07"
-    GPT5_CHAT = "gpt-5-chat-latest"
     GPT41 = "gpt-4.1-2025-04-14"
     GPT41_MINI = "gpt-4.1-mini-2025-04-14"
     GPT4O_MINI = "gpt-4o-mini"
@@ -94,7 +88,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
     GPT4_TURBO = "gpt-4-turbo"
     GPT3_5_TURBO = "gpt-3.5-turbo"
     # Anthropic models
-    CLAUDE_4_1_OPUS = "claude-opus-4-1-20250805"
     CLAUDE_4_OPUS = "claude-opus-4-20250514"
     CLAUDE_4_SONNET = "claude-sonnet-4-20250514"
     CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
@@ -122,8 +115,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
     OLLAMA_LLAMA3_405B = "llama3.1:405b"
     OLLAMA_DOLPHIN = "dolphin-mistral:latest"
     # OpenRouter models
-    OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
-    OPENAI_GPT_OSS_20B = "openai/gpt-oss-20b"
     GEMINI_FLASH_1_5 = "google/gemini-flash-1.5"
     GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
     GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
@@ -156,10 +147,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
     LLAMA_API_LLAMA4_MAVERICK = "Llama-4-Maverick-17B-128E-Instruct-FP8"
     LLAMA_API_LLAMA3_3_8B = "Llama-3.3-8B-Instruct"
     LLAMA_API_LLAMA3_3_70B = "Llama-3.3-70B-Instruct"
-    # v0 by Vercel models
-    V0_1_5_MD = "v0-1.5-md"
-    V0_1_5_LG = "v0-1.5-lg"
-    V0_1_0_MD = "v0-1.0-md"
 
     @property
     def metadata(self) -> ModelMetadata:
@@ -184,11 +171,6 @@ MODEL_METADATA = {
     LlmModel.O3_MINI: ModelMetadata("openai", 200000, 100000),  # o3-mini-2025-01-31
     LlmModel.O1: ModelMetadata("openai", 200000, 100000),  # o1-2024-12-17
     LlmModel.O1_MINI: ModelMetadata("openai", 128000, 65536),  # o1-mini-2024-09-12
-    # GPT-5 models
-    LlmModel.GPT5: ModelMetadata("openai", 400000, 128000),
-    LlmModel.GPT5_MINI: ModelMetadata("openai", 400000, 128000),
-    LlmModel.GPT5_NANO: ModelMetadata("openai", 400000, 128000),
-    LlmModel.GPT5_CHAT: ModelMetadata("openai", 400000, 16384),
     LlmModel.GPT41: ModelMetadata("openai", 1047576, 32768),
     LlmModel.GPT41_MINI: ModelMetadata("openai", 1047576, 32768),
     LlmModel.GPT4O_MINI: ModelMetadata(
@@ -200,9 +182,6 @@ MODEL_METADATA = {
     ),  # gpt-4-turbo-2024-04-09
     LlmModel.GPT3_5_TURBO: ModelMetadata("openai", 16385, 4096),  # gpt-3.5-turbo-0125
     # https://docs.anthropic.com/en/docs/about-claude/models
-    LlmModel.CLAUDE_4_1_OPUS: ModelMetadata(
-        "anthropic", 200000, 32000
-    ),  # claude-opus-4-1-20250805
     LlmModel.CLAUDE_4_OPUS: ModelMetadata(
         "anthropic", 200000, 8192
     ),  # claude-4-opus-20250514
@@ -267,8 +246,6 @@ MODEL_METADATA = {
     LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: ModelMetadata(
         "open_router", 12288, 12288
     ),
-    LlmModel.OPENAI_GPT_OSS_120B: ModelMetadata("open_router", 131072, 131072),
-    LlmModel.OPENAI_GPT_OSS_20B: ModelMetadata("open_router", 131072, 32768),
     LlmModel.AMAZON_NOVA_LITE_V1: ModelMetadata("open_router", 300000, 5120),
     LlmModel.AMAZON_NOVA_MICRO_V1: ModelMetadata("open_router", 128000, 5120),
     LlmModel.AMAZON_NOVA_PRO_V1: ModelMetadata("open_router", 300000, 5120),
@@ -285,10 +262,6 @@ MODEL_METADATA = {
     LlmModel.LLAMA_API_LLAMA4_MAVERICK: ModelMetadata("llama_api", 128000, 4028),
     LlmModel.LLAMA_API_LLAMA3_3_8B: ModelMetadata("llama_api", 128000, 4028),
     LlmModel.LLAMA_API_LLAMA3_3_70B: ModelMetadata("llama_api", 128000, 4028),
-    # v0 by Vercel models
-    LlmModel.V0_1_5_MD: ModelMetadata("v0", 128000, 64000),
-    LlmModel.V0_1_5_LG: ModelMetadata("v0", 512000, 64000),
-    LlmModel.V0_1_0_MD: ModelMetadata("v0", 128000, 64000),
 }
 
 for model in LlmModel:
@@ -502,7 +475,6 @@ async def llm_call(
             messages=messages,
             max_tokens=max_tokens,
             tools=an_tools,
-            timeout=600,
         )
 
         if not resp.content:
@@ -685,11 +657,7 @@ async def llm_call(
         client = openai.OpenAI(
             base_url="https://api.aimlapi.com/v2",
             api_key=credentials.api_key.get_secret_value(),
-            default_headers={
-                "X-Project": "AutoGPT",
-                "X-Title": "AutoGPT",
-                "HTTP-Referer": "https://github.com/Significant-Gravitas/AutoGPT",
-            },
+            default_headers={"X-Project": "AutoGPT"},
         )
 
         completion = client.chat.completions.create(
@@ -709,42 +677,6 @@ async def llm_call(
             ),
             reasoning=None,
         )
-    elif provider == "v0":
-        tools_param = tools if tools else openai.NOT_GIVEN
-        client = openai.AsyncOpenAI(
-            base_url="https://api.v0.dev/v1",
-            api_key=credentials.api_key.get_secret_value(),
-        )
-
-        response_format = None
-        if json_format:
-            response_format = {"type": "json_object"}
-
-        parallel_tool_calls_param = get_parallel_tool_calls_param(
-            llm_model, parallel_tool_calls
-        )
-
-        response = await client.chat.completions.create(
-            model=llm_model.value,
-            messages=prompt,  # type: ignore
-            response_format=response_format,  # type: ignore
-            max_tokens=max_tokens,
-            tools=tools_param,  # type: ignore
-            parallel_tool_calls=parallel_tool_calls_param,
-        )
-
-        tool_calls = extract_openai_tool_calls(response)
-        reasoning = extract_openai_reasoning(response)
-
-        return LLMResponse(
-            raw_response=response.choices[0].message,
-            prompt=prompt,
-            response=response.choices[0].message.content or "",
-            tool_calls=tool_calls,
-            prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
-            completion_tokens=response.usage.completion_tokens if response.usage else 0,
-            reasoning=reasoning,
-        )
     else:
         raise ValueError(f"Unsupported LLM provider: {provider}")
 
@@ -1,13 +1,22 @@
 import logging
 from typing import Any, Literal
 
+from autogpt_libs.utils.cache import thread_cached
+
 from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
 from backend.data.model import SchemaField
 from backend.util.clients import get_database_manager_async_client
 
 logger = logging.getLogger(__name__)
 
 
+@thread_cached
+def get_database_manager_client():
+    from backend.executor import DatabaseManagerAsyncClient
+    from backend.util.service import get_service_client
+
+    return get_service_client(DatabaseManagerAsyncClient, health_check=False)
+
+
 StorageScope = Literal["within_agent", "across_agents"]
@@ -79,7 +88,7 @@ class PersistInformationBlock(Block):
     async def _store_data(
         self, user_id: str, node_exec_id: str, key: str, data: Any
     ) -> Any | None:
-        return await get_database_manager_async_client().set_execution_kv_data(
+        return await get_database_manager_client().set_execution_kv_data(
             user_id=user_id,
             node_exec_id=node_exec_id,
             key=key,
@@ -140,7 +149,7 @@ class RetrieveInformationBlock(Block):
             yield "value", input_data.default_value
 
     async def _retrieve_data(self, user_id: str, key: str) -> Any | None:
-        return await get_database_manager_async_client().get_execution_kv_data(
+        return await get_database_manager_client().get_execution_kv_data(
             user_id=user_id,
             key=key,
         )

@@ -3,7 +3,8 @@ from typing import List
 
 from backend.data.block import BlockOutput, BlockSchema
 from backend.data.model import APIKeyCredentials, SchemaField
-from backend.util.settings import BehaveAs, Settings
+from backend.util import settings
+from backend.util.settings import BehaveAs
 
 from ._api import (
     TEST_CREDENTIALS,
@@ -15,8 +16,6 @@ from ._api import (
 )
 from .base import Slant3DBlockBase
 
-settings = Settings()
-
 
 class Slant3DCreateOrderBlock(Slant3DBlockBase):
     """Block for creating new orders"""
@@ -281,7 +280,7 @@ class Slant3DGetOrdersBlock(Slant3DBlockBase):
             input_schema=self.Input,
             output_schema=self.Output,
             # This block is disabled for cloud hosted because it allows access to all orders for the account
-            disabled=settings.config.behave_as == BehaveAs.CLOUD,
+            disabled=settings.Settings().config.behave_as == BehaveAs.CLOUD,
             test_input={"credentials": TEST_CREDENTIALS_INPUT},
             test_credentials=TEST_CREDENTIALS,
             test_output=[

@@ -9,7 +9,8 @@ from backend.data.block import (
 )
 from backend.data.model import SchemaField
 from backend.integrations.providers import ProviderName
-from backend.util.settings import AppEnvironment, BehaveAs, Settings
+from backend.util import settings
+from backend.util.settings import AppEnvironment, BehaveAs
 
 from ._api import (
     TEST_CREDENTIALS,
@@ -18,8 +19,6 @@ from ._api import (
     Slant3DCredentialsInput,
 )
 
-settings = Settings()
-
 
 class Slant3DTriggerBase:
     """Base class for Slant3D webhook triggers"""
@@ -77,8 +76,8 @@ class Slant3DOrderWebhookBlock(Slant3DTriggerBase, Block):
             ),
             # All webhooks are currently subscribed to for all orders. This works for self hosted, but not for cloud hosted prod
             disabled=(
-                settings.config.behave_as == BehaveAs.CLOUD
-                and settings.config.app_env != AppEnvironment.LOCAL
+                settings.Settings().config.behave_as == BehaveAs.CLOUD
+                and settings.Settings().config.app_env != AppEnvironment.LOCAL
             ),
             categories={BlockCategory.DEVELOPER_TOOLS},
             input_schema=self.Input,

@@ -3,6 +3,8 @@ import re
|
||||
from collections import Counter
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
|
||||
import backend.blocks.llm as llm
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.data.block import (
|
||||
@@ -15,7 +17,6 @@ from backend.data.block import (
|
||||
)
|
||||
from backend.data.model import NodeExecutionStats, SchemaField
|
||||
from backend.util import json
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.graph import Link, Node
|
||||
@@ -23,6 +24,14 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_database_manager_client():
|
||||
from backend.executor import DatabaseManagerAsyncClient
|
||||
from backend.util.service import get_service_client
|
||||
|
||||
return get_service_client(DatabaseManagerAsyncClient, health_check=False)
|
||||
|
||||
|
||||
def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
|
||||
"""
|
||||
Return a list of tool_call_ids if the entry is a tool request.
|
||||
@@ -291,32 +300,9 @@ class SmartDecisionMakerBlock(Block):
 
         for link in links:
             sink_name = SmartDecisionMakerBlock.cleanup(link.sink_name)
-
-            # Handle dynamic fields (e.g., values_#_*, items_$_*, etc.)
-            # These are fields that get merged by the executor into their base field
-            if (
-                "_#_" in link.sink_name
-                or "_$_" in link.sink_name
-                or "_@_" in link.sink_name
-            ):
-                # For dynamic fields, provide a generic string schema
-                # The executor will handle merging these into the appropriate structure
-                properties[sink_name] = {
-                    "type": "string",
-                    "description": f"Dynamic value for {link.sink_name}",
-                }
-            else:
-                # For regular fields, use the block's schema
-                try:
-                    properties[sink_name] = sink_block_input_schema.get_field_schema(
-                        link.sink_name
-                    )
-                except (KeyError, AttributeError):
-                    # If the field doesn't exist in the schema, provide a generic schema
-                    properties[sink_name] = {
-                        "type": "string",
-                        "description": f"Value for {link.sink_name}",
-                    }
+            properties[sink_name] = sink_block_input_schema.get_field_schema(
+                link.sink_name
+            )
 
         tool_function["parameters"] = {
             **block.input_schema.jsonschema(),
@@ -347,7 +333,7 @@ class SmartDecisionMakerBlock(Block):
         if not graph_id or not graph_version:
             raise ValueError("Graph ID or Graph Version not found in sink node.")
 
-        db_client = get_database_manager_async_client()
+        db_client = get_database_manager_client()
         sink_graph_meta = await db_client.get_graph_metadata(graph_id, graph_version)
         if not sink_graph_meta:
             raise ValueError(
@@ -407,7 +393,7 @@ class SmartDecisionMakerBlock(Block):
             ValueError: If no tool links are found for the specified node_id, or if a sink node
                 or its metadata cannot be found.
         """
-        db_client = get_database_manager_async_client()
+        db_client = get_database_manager_client()
         tools = [
             (link, node)
             for link, node in await db_client.get_connected_output_nodes(node_id)
@@ -501,6 +487,10 @@ class SmartDecisionMakerBlock(Block):
                 }
             )
         prompt.extend(tool_output)
+        if input_data.multiple_tool_calls:
+            input_data.sys_prompt += "\nYou can call a tool (different tools) multiple times in a single response."
+        else:
+            input_data.sys_prompt += "\nOnly provide EXACTLY one function call, multiple tool calls is strictly prohibited."
 
         values = input_data.prompt_values
         if values:
@@ -539,6 +529,15 @@ class SmartDecisionMakerBlock(Block):
             )
         )
 
+        # Add reasoning to conversation history if available
+        if response.reasoning:
+            prompt.append(
+                {"role": "assistant", "content": f"[Reasoning]: {response.reasoning}"}
+            )
+
+        prompt.append(response.raw_response)
+        yield "conversations", prompt
+
         if not response.tool_calls:
             yield "finished", response.response
             return
@@ -572,12 +571,3 @@ class SmartDecisionMakerBlock(Block):
                     yield f"tools_^_{tool_name}_~_{arg_name}", tool_args[arg_name]
                 else:
                     yield f"tools_^_{tool_name}_~_{arg_name}", None
-
-        # Add reasoning to conversation history if available
-        if response.reasoning:
-            prompt.append(
-                {"role": "assistant", "content": f"[Reasoning]: {response.reasoning}"}
-            )
-
-        prompt.append(response.raw_response)
-        yield "conversations", prompt
@@ -1 +0,0 @@
-
@@ -1,283 +0,0 @@
-import logging
-from typing import Any
-
-from pydantic import BaseModel
-
-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
-from backend.data.model import SchemaField
-from backend.util.clients import get_database_manager_async_client
-
-logger = logging.getLogger(__name__)
-
-
-# Duplicate pydantic models for store data so we don't accidently change the data shape in the blocks unintentionally when editing the backend
-class LibraryAgent(BaseModel):
-    """Model representing an agent in the user's library."""
-
-    library_agent_id: str = ""
-    agent_id: str = ""
-    agent_version: int = 0
-    agent_name: str = ""
-    description: str = ""
-    creator: str = ""
-    is_archived: bool = False
-    categories: list[str] = []
-
-
-class AddToLibraryFromStoreBlock(Block):
-    """
-    Block that adds an agent from the store to the user's library.
-    This enables users to easily import agents from the marketplace into their personal collection.
-    """
-
-    class Input(BlockSchema):
-        store_listing_version_id: str = SchemaField(
-            description="The ID of the store listing version to add to library"
-        )
-        agent_name: str | None = SchemaField(
-            description="Optional custom name for the agent in your library",
-            default=None,
-        )
-
-    class Output(BlockSchema):
-        success: bool = SchemaField(
-            description="Whether the agent was successfully added to library"
-        )
-        library_agent_id: str = SchemaField(
-            description="The ID of the library agent entry"
-        )
-        agent_id: str = SchemaField(description="The ID of the agent graph")
-        agent_version: int = SchemaField(
-            description="The version number of the agent graph"
-        )
-        agent_name: str = SchemaField(description="The name of the agent")
-        message: str = SchemaField(description="Success or error message")
-
-    def __init__(self):
-        super().__init__(
-            id="2602a7b1-3f4d-4e5f-9c8b-1a2b3c4d5e6f",
-            description="Add an agent from the store to your personal library",
-            categories={BlockCategory.BASIC},
-            input_schema=AddToLibraryFromStoreBlock.Input,
-            output_schema=AddToLibraryFromStoreBlock.Output,
-            test_input={
-                "store_listing_version_id": "test-listing-id",
-                "agent_name": "My Custom Agent",
-            },
-            test_output=[
-                ("success", True),
-                ("library_agent_id", "test-library-id"),
-                ("agent_id", "test-agent-id"),
-                ("agent_version", 1),
-                ("agent_name", "Test Agent"),
-                ("message", "Agent successfully added to library"),
-            ],
-            test_mock={
-                "_add_to_library": lambda *_, **__: LibraryAgent(
-                    library_agent_id="test-library-id",
-                    agent_id="test-agent-id",
-                    agent_version=1,
-                    agent_name="Test Agent",
-                )
-            },
-        )
-
-    async def run(
-        self,
-        input_data: Input,
-        *,
-        user_id: str,
-        **kwargs,
-    ) -> BlockOutput:
-        library_agent = await self._add_to_library(
-            user_id=user_id,
-            store_listing_version_id=input_data.store_listing_version_id,
-            custom_name=input_data.agent_name,
-        )
-
-        yield "success", True
-        yield "library_agent_id", library_agent.library_agent_id
-        yield "agent_id", library_agent.agent_id
-        yield "agent_version", library_agent.agent_version
-        yield "agent_name", library_agent.agent_name
-        yield "message", "Agent successfully added to library"
-
-    async def _add_to_library(
-        self,
-        user_id: str,
-        store_listing_version_id: str,
-        custom_name: str | None = None,
-    ) -> LibraryAgent:
-        """
-        Add a store agent to the user's library using the existing library database function.
-        """
-        library_agent = (
-            await get_database_manager_async_client().add_store_agent_to_library(
-                store_listing_version_id=store_listing_version_id, user_id=user_id
-            )
-        )
-
-        # If custom name is provided, we could update the library agent name here
-        # For now, we'll just return the agent info
-        agent_name = custom_name if custom_name else library_agent.name
-
-        return LibraryAgent(
-            library_agent_id=library_agent.id,
-            agent_id=library_agent.graph_id,
-            agent_version=library_agent.graph_version,
-            agent_name=agent_name,
-        )
-
-
-class ListLibraryAgentsBlock(Block):
-    """
-    Block that lists all agents in the user's library.
-    """
-
-    class Input(BlockSchema):
-        search_query: str | None = SchemaField(
-            description="Optional search query to filter agents", default=None
-        )
-        limit: int = SchemaField(
-            description="Maximum number of agents to return", default=50, ge=1, le=100
-        )
-        page: int = SchemaField(
-            description="Page number for pagination", default=1, ge=1
-        )
-
-    class Output(BlockSchema):
-        agents: list[LibraryAgent] = SchemaField(
-            description="List of agents in the library",
-            default_factory=list,
-        )
-        agent: LibraryAgent = SchemaField(
-            description="Individual library agent (yielded for each agent)"
-        )
-        total_count: int = SchemaField(
-            description="Total number of agents in library", default=0
-        )
-        page: int = SchemaField(description="Current page number", default=1)
-        total_pages: int = SchemaField(description="Total number of pages", default=1)
-
-    def __init__(self):
-        super().__init__(
-            id="082602d3-a74d-4600-9e9c-15b3af7eae98",
-            description="List all agents in your personal library",
-            categories={BlockCategory.BASIC, BlockCategory.DATA},
-            input_schema=ListLibraryAgentsBlock.Input,
-            output_schema=ListLibraryAgentsBlock.Output,
-            test_input={
-                "search_query": None,
-                "limit": 10,
-                "page": 1,
-            },
-            test_output=[
-                (
-                    "agents",
-                    [
-                        LibraryAgent(
-                            library_agent_id="test-lib-id",
-                            agent_id="test-agent-id",
-                            agent_version=1,
-                            agent_name="Test Library Agent",
-                            description="A test agent in library",
-                            creator="Test User",
-                        ),
-                    ],
-                ),
-                ("total_count", 1),
-                ("page", 1),
-                ("total_pages", 1),
-                (
-                    "agent",
-                    LibraryAgent(
-                        library_agent_id="test-lib-id",
-                        agent_id="test-agent-id",
-                        agent_version=1,
-                        agent_name="Test Library Agent",
-                        description="A test agent in library",
-                        creator="Test User",
-                    ),
-                ),
-            ],
-            test_mock={
-                "_list_library_agents": lambda *_, **__: {
-                    "agents": [
-                        LibraryAgent(
-                            library_agent_id="test-lib-id",
-                            agent_id="test-agent-id",
-                            agent_version=1,
-                            agent_name="Test Library Agent",
-                            description="A test agent in library",
-                            creator="Test User",
-                        )
-                    ],
-                    "total": 1,
-                    "page": 1,
-                    "total_pages": 1,
-                }
-            },
-        )
-
-    async def run(
-        self,
-        input_data: Input,
-        *,
-        user_id: str,
-        **kwargs,
-    ) -> BlockOutput:
-        result = await self._list_library_agents(
-            user_id=user_id,
-            search_query=input_data.search_query,
-            limit=input_data.limit,
-            page=input_data.page,
-        )
-
-        agents = result["agents"]
-
-        yield "agents", agents
-        yield "total_count", result["total"]
-        yield "page", result["page"]
-        yield "total_pages", result["total_pages"]
-
-        # Yield each agent individually for better graph connectivity
-        for agent in agents:
-            yield "agent", agent
-
-    async def _list_library_agents(
-        self,
-        user_id: str,
-        search_query: str | None = None,
-        limit: int = 50,
-        page: int = 1,
-    ) -> dict[str, Any]:
-        """
-        List agents in the user's library using the database client.
-        """
-        result = await get_database_manager_async_client().list_library_agents(
-            user_id=user_id,
-            search_term=search_query,
-            page=page,
-            page_size=limit,
-        )
-
-        agents = [
-            LibraryAgent(
-                library_agent_id=agent.id,
-                agent_id=agent.graph_id,
-                agent_version=agent.graph_version,
-                agent_name=agent.name,
-                description=getattr(agent, "description", ""),
-                creator=getattr(agent, "creator", ""),
-                is_archived=getattr(agent, "is_archived", False),
-                categories=getattr(agent, "categories", []),
-            )
-            for agent in result.agents
-        ]
-
-        return {
-            "agents": agents,
-            "total": result.pagination.total_items,
-            "page": result.pagination.current_page,
-            "total_pages": result.pagination.total_pages,
-        }
@@ -1,311 +0,0 @@
-import logging
-from typing import Literal
-
-from pydantic import BaseModel
-
-from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
-from backend.data.model import SchemaField
-from backend.util.clients import get_database_manager_async_client
-
-logger = logging.getLogger(__name__)
-
-
-# Duplicate pydantic models for store data so we don't accidently change the data shape in the blocks unintentionally when editing the backend
-class StoreAgent(BaseModel):
-    """Model representing a store agent."""
-
-    slug: str = ""
-    name: str = ""
-    description: str = ""
-    creator: str = ""
-    rating: float = 0.0
-    runs: int = 0
-    categories: list[str] = []
-
-
-class StoreAgentDict(BaseModel):
-    """Dictionary representation of a store agent."""
-
-    slug: str
-    name: str
-    description: str
-    creator: str
-    rating: float
-    runs: int
-
-
-class SearchAgentsResponse(BaseModel):
-    """Response from searching store agents."""
-
-    agents: list[StoreAgentDict]
-    total_count: int
-
-
-class StoreAgentDetails(BaseModel):
-    """Detailed information about a store agent."""
-
-    found: bool
-    store_listing_version_id: str = ""
-    agent_name: str = ""
-    description: str = ""
-    creator: str = ""
-    categories: list[str] = []
-    runs: int = 0
-    rating: float = 0.0
-
-
-class GetStoreAgentDetailsBlock(Block):
-    """
-    Block that retrieves detailed information about an agent from the store.
-    """
-
-    class Input(BlockSchema):
-        creator: str = SchemaField(description="The username of the agent creator")
-        slug: str = SchemaField(description="The name of the agent")
-
-    class Output(BlockSchema):
-        found: bool = SchemaField(
-            description="Whether the agent was found in the store"
-        )
-        store_listing_version_id: str = SchemaField(
-            description="The store listing version ID"
-        )
-        agent_name: str = SchemaField(description="Name of the agent")
-        description: str = SchemaField(description="Description of the agent")
-        creator: str = SchemaField(description="Creator of the agent")
-        categories: list[str] = SchemaField(
-            description="Categories the agent belongs to", default_factory=list
-        )
-        runs: int = SchemaField(
-            description="Number of times the agent has been run", default=0
-        )
-        rating: float = SchemaField(
-            description="Average rating of the agent", default=0.0
-        )
-
-    def __init__(self):
-        super().__init__(
-            id="b604f0ec-6e0d-40a7-bf55-9fd09997cced",
-            description="Get detailed information about an agent from the store",
-            categories={BlockCategory.BASIC, BlockCategory.DATA},
-            input_schema=GetStoreAgentDetailsBlock.Input,
-            output_schema=GetStoreAgentDetailsBlock.Output,
-            test_input={"creator": "test-creator", "slug": "test-agent-slug"},
-            test_output=[
-                ("found", True),
-                ("store_listing_version_id", "test-listing-id"),
-                ("agent_name", "Test Agent"),
-                ("description", "A test agent"),
-                ("creator", "Test Creator"),
-                ("categories", ["productivity", "automation"]),
-                ("runs", 100),
-                ("rating", 4.5),
-            ],
-            test_mock={
-                "_get_agent_details": lambda *_, **__: StoreAgentDetails(
-                    found=True,
-                    store_listing_version_id="test-listing-id",
-                    agent_name="Test Agent",
-                    description="A test agent",
-                    creator="Test Creator",
-                    categories=["productivity", "automation"],
-                    runs=100,
-                    rating=4.5,
-                )
-            },
-            static_output=True,
-        )
-
-    async def run(
-        self,
-        input_data: Input,
-        **kwargs,
-    ) -> BlockOutput:
-        details = await self._get_agent_details(
-            creator=input_data.creator, slug=input_data.slug
-        )
-        yield "found", details.found
-        yield "store_listing_version_id", details.store_listing_version_id
-        yield "agent_name", details.agent_name
-        yield "description", details.description
-        yield "creator", details.creator
-        yield "categories", details.categories
-        yield "runs", details.runs
-        yield "rating", details.rating
-
-    async def _get_agent_details(self, creator: str, slug: str) -> StoreAgentDetails:
-        """
-        Retrieve detailed information about a store agent.
-        """
-        # Get by specific version ID
-        agent_details = (
-            await get_database_manager_async_client().get_store_agent_details(
-                username=creator, agent_name=slug
-            )
-        )
-
-        return StoreAgentDetails(
-            found=True,
-            store_listing_version_id=agent_details.store_listing_version_id,
-            agent_name=agent_details.agent_name,
-            description=agent_details.description,
-            creator=agent_details.creator,
-            categories=(
-                agent_details.categories if hasattr(agent_details, "categories") else []
-            ),
-            runs=agent_details.runs,
-            rating=agent_details.rating,
-        )
-
-
-class SearchStoreAgentsBlock(Block):
-    """
-    Block that searches for agents in the store based on various criteria.
-    """
-
-    class Input(BlockSchema):
-        query: str | None = SchemaField(
-            description="Search query to find agents", default=None
-        )
-        category: str | None = SchemaField(
-            description="Filter by category", default=None
-        )
-        sort_by: Literal["rating", "runs", "name", "recent"] = SchemaField(
-            description="How to sort the results", default="rating"
-        )
-        limit: int = SchemaField(
-            description="Maximum number of results to return", default=10, ge=1, le=100
-        )
-
-    class Output(BlockSchema):
-        agents: list[StoreAgent] = SchemaField(
-            description="List of agents matching the search criteria",
-            default_factory=list,
-        )
-        agent: StoreAgent = SchemaField(description="Basic information of the agent")
-        total_count: int = SchemaField(
-            description="Total number of agents found", default=0
-        )
-
-    def __init__(self):
-        super().__init__(
-            id="39524701-026c-4328-87cc-1b88c8e2cb4c",
-            description="Search for agents in the store",
-            categories={BlockCategory.BASIC, BlockCategory.DATA},
-            input_schema=SearchStoreAgentsBlock.Input,
-            output_schema=SearchStoreAgentsBlock.Output,
-            test_input={
-                "query": "productivity",
-                "category": None,
-                "sort_by": "rating",
-                "limit": 10,
-            },
-            test_output=[
-                (
-                    "agents",
-                    [
-                        {
-                            "slug": "test-agent",
-                            "name": "Test Agent",
-                            "description": "A test agent",
-                            "creator": "Test Creator",
-                            "rating": 4.5,
-                            "runs": 100,
-                        }
-                    ],
-                ),
-                ("total_count", 1),
-                (
-                    "agent",
-                    {
-                        "slug": "test-agent",
-                        "name": "Test Agent",
-                        "description": "A test agent",
-                        "creator": "Test Creator",
-                        "rating": 4.5,
-                        "runs": 100,
-                    },
-                ),
-            ],
-            test_mock={
-                "_search_agents": lambda *_, **__: SearchAgentsResponse(
-                    agents=[
-                        StoreAgentDict(
-                            slug="test-agent",
-                            name="Test Agent",
-                            description="A test agent",
-                            creator="Test Creator",
-                            rating=4.5,
-                            runs=100,
-                        )
-                    ],
-                    total_count=1,
-                )
-            },
-        )
-
-    async def run(
-        self,
-        input_data: Input,
-        **kwargs,
-    ) -> BlockOutput:
-        result = await self._search_agents(
-            query=input_data.query,
-            category=input_data.category,
-            sort_by=input_data.sort_by,
-            limit=input_data.limit,
-        )
-
-        agents = result.agents
-        total_count = result.total_count
-
-        # Convert to dict for output
-        agents_as_dicts = [agent.model_dump() for agent in agents]
-
-        yield "agents", agents_as_dicts
-        yield "total_count", total_count
-
-        for agent_dict in agents_as_dicts:
-            yield "agent", agent_dict
-
-    async def _search_agents(
-        self,
-        query: str | None = None,
-        category: str | None = None,
-        sort_by: str = "rating",
-        limit: int = 10,
-    ) -> SearchAgentsResponse:
-        """
-        Search for agents in the store using the existing store database function.
-        """
-        # Map our sort_by to the store's sorted_by parameter
-        sorted_by_map = {
-            "rating": "most_popular",
-            "runs": "most_runs",
-            "name": "alphabetical",
-            "recent": "recently_updated",
-        }
-
-        result = await get_database_manager_async_client().get_store_agents(
-            featured=False,
-            creators=None,
-            sorted_by=sorted_by_map.get(sort_by, "most_popular"),
-            search_query=query,
-            category=category,
-            page=1,
-            page_size=limit,
-        )
-
-        agents = [
-            StoreAgentDict(
-                slug=agent.slug,
-                name=agent.agent_name,
-                description=agent.description,
-                creator=agent.creator,
-                rating=agent.rating,
-                runs=agent.runs,
-            )
-            for agent in result.agents
-        ]
-
-        return SearchAgentsResponse(agents=agents, total_count=len(agents))
@@ -1,8 +1,9 @@
 import logging
 
 import pytest
+from prisma.models import User
 
-from backend.data.model import ProviderName, User
+from backend.data.model import ProviderName
 from backend.server.model import CreateGraph
 from backend.server.rest_api import AgentServer
 from backend.usecases.sample import create_test_graph, create_test_user
Some files were not shown because too many files have changed in this diff.