Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-01-12 16:48:06 -05:00)

Compare commits: seer/featu…feat/execu (5 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 4f652cb978 | |
| | 279552a2a3 | |
| | fb6ac1d6ca | |
| | 9db15bff02 | |
| | db4b94e0dc | |

`.github/copilot-instructions.md` (vendored, 94 changed lines)
@@ -12,7 +12,6 @@ This file provides comprehensive onboarding information for GitHub Copilot codin

- **Infrastructure** - Docker configurations, CI/CD, and development tools

**Primary Languages & Frameworks:**

- **Backend**: Python 3.10-3.13, FastAPI, Prisma ORM, PostgreSQL, RabbitMQ
- **Frontend**: TypeScript, Next.js 15, React, Tailwind CSS, Radix UI
- **Development**: Docker, Poetry, pnpm, Playwright, Storybook
@@ -24,17 +23,15 @@ This file provides comprehensive onboarding information for GitHub Copilot codin

**Always run these commands in the correct directory and in this order:**

1. **Initial Setup** (required once):

```bash
# Clone and enter repository
git clone <repo> && cd AutoGPT

# Start all services (database, redis, rabbitmq, clamav)
cd autogpt_platform && docker compose --profile local up deps --build --detach
```

2. **Backend Setup** (always run before backend development):

```bash
cd autogpt_platform/backend
poetry install  # Install dependencies
@@ -51,7 +48,6 @@ This file provides comprehensive onboarding information for GitHub Copilot codin

### Runtime Requirements

**Critical:** Always ensure Docker services are running before starting development:

```bash
cd autogpt_platform && docker compose --profile local up deps --build --detach
```
@@ -62,7 +58,6 @@ cd autogpt_platform && docker compose --profile local up deps --build --detach

### Development Commands

**Backend Development:**

```bash
cd autogpt_platform/backend
poetry run serve   # Start development server (port 8000)
@@ -73,7 +68,6 @@ poetry run lint    # Lint code (ruff) - run after format
```
**Frontend Development:**

```bash
cd autogpt_platform/frontend
pnpm dev   # Start development server (port 3000) - use for active development
@@ -87,27 +81,23 @@ pnpm storybook # Start component development server

### Testing Strategy

**Backend Tests:**

- **Block Tests**: `poetry run pytest backend/blocks/test/test_block.py -xvs` (validates all blocks)
- **Specific Block**: `poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[BlockName]' -xvs`
- **Snapshot Tests**: Use `--snapshot-update` when output changes; always review the result with `git diff`

**Frontend Tests:**

- **E2E Tests**: Always run `pnpm dev` before `pnpm test` (Playwright requires a running instance)
- **Component Tests**: Use Storybook for isolated component development

### Critical Validation Steps

**Before committing changes:**

1. Run `poetry run format` (backend) and `pnpm format` (frontend)
2. Ensure all tests pass in modified areas
3. Verify Docker services are still running
4. Check that database migrations apply cleanly

**Common Issues & Workarounds:**

- **Prisma issues**: Run `poetry run prisma generate` after schema changes
- **Permission errors**: Ensure Docker has proper permissions
- **Port conflicts**: Check the `docker-compose.yml` file for the current list of exposed ports. You can list all mapped ports with:
@@ -118,7 +108,6 @@ pnpm storybook # Start component development server

### Core Architecture

**AutoGPT Platform** (`autogpt_platform/`):

- `backend/` - FastAPI server with async support
  - `backend/backend/` - Core API logic
  - `backend/blocks/` - Agent execution blocks
@@ -132,7 +121,6 @@ pnpm storybook # Start component development server
- `docker-compose.yml` - Development stack orchestration

**Key Configuration Files:**

- `pyproject.toml` - Python dependencies and tooling
- `package.json` - Node.js dependencies and scripts
- `schema.prisma` - Database schema and migrations
@@ -148,7 +136,6 @@ pnpm storybook # Start component development server

### Development Workflow

**GitHub Actions**: Multiple CI/CD workflows in `.github/workflows/`

- `platform-backend-ci.yml` - Backend testing and validation
- `platform-frontend-ci.yml` - Frontend testing and validation
- `platform-fullstack-ci.yml` - End-to-end integration tests
@@ -159,13 +146,11 @@ pnpm storybook # Start component development server

### Key Source Files

**Backend Entry Points:**

- `backend/backend/server/server.py` - FastAPI application setup
- `backend/backend/data/` - Database models and user management
- `backend/blocks/` - Agent execution blocks and logic

**Frontend Entry Points:**

- `frontend/src/app/layout.tsx` - Root application layout
- `frontend/src/app/page.tsx` - Home page
- `frontend/src/lib/supabase/` - Authentication and database client
@@ -175,7 +160,6 @@ pnpm storybook # Start component development server

### Agent Block System

Agents are built using a visual block-based system where each block performs a single action. Blocks are defined in `backend/blocks/` and must include (a minimal sketch follows this list):

- Block definition with input/output schemas
- Execution logic with proper error handling
- Tests validating functionality
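Below is a hypothetical sketch of such a block. The import paths and base-class details (`Block`, `BlockSchema`, `BlockOutput`, `SchemaField`) are assumptions inferred from the conventions described here; mirror an existing file in `backend/blocks/` for the exact signatures.

```python
# Hypothetical block sketch - import paths and signatures are assumptions;
# copy the structure of an existing block in backend/blocks/ for real work.
from backend.data.block import Block, BlockOutput, BlockSchema  # assumed module
from backend.data.model import SchemaField  # assumed module


class WordCountBlock(Block):
    """One small, single-purpose action: count words in a text input."""

    class Input(BlockSchema):
        text: str = SchemaField(description="Text to count words in")

    class Output(BlockSchema):
        word_count: int = SchemaField(description="Number of words found")

    def __init__(self):
        super().__init__(
            id="00000000-0000-0000-0000-000000000000",  # placeholder UUID
            input_schema=WordCountBlock.Input,
            output_schema=WordCountBlock.Output,
        )

    def run(self, input_data: Input, **kwargs) -> BlockOutput:
        # Execution logic with error handling; outputs are yielded by name so
        # they can connect to downstream blocks in the graph editor.
        try:
            yield "word_count", len(input_data.text.split())
        except Exception as e:
            yield "error", str(e)
```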
@@ -183,7 +167,6 @@ Agents are built using a visual block-based system where each block performs a s

### Database & ORM

**Prisma ORM** with PostgreSQL backend including pgvector for embeddings:

- Schema in `schema.prisma`
- Migrations in `backend/migrations/`
- Always run `prisma migrate dev` and `prisma generate` after schema changes
@@ -191,15 +174,13 @@ Agents are built using a visual block-based system where each block performs a s

## Environment Configuration

### Configuration Files Priority Order

1. **Backend**: `/backend/.env.default` → `/backend/.env` (user overrides)
2. **Frontend**: `/frontend/.env.default` → `/frontend/.env` (user overrides)
3. **Platform**: `/.env.default` (Supabase/shared) → `/.env` (user overrides)
4. Docker Compose `environment:` sections override file-based config
5. Shell environment variables have the highest precedence

### Docker Environment Setup

- All services use hardcoded defaults (no `${VARIABLE}` substitutions)
- The `env_file` directive loads variables INTO containers at runtime
- Backend/Frontend services use YAML anchors for consistent configuration
@@ -208,7 +189,6 @@ Agents are built using a visual block-based system where each block performs a s

## Advanced Development Patterns

### Adding New Blocks

1. Create file in `/backend/backend/blocks/`
2. Inherit from `Block` base class with input/output schemas
3. Implement `run` method with proper error handling
@@ -218,7 +198,6 @@ Agents are built using a visual block-based system where each block performs a s
7. Consider how inputs/outputs connect with other blocks in the graph editor
### API Development

1. Update routes in `/backend/backend/server/routers/`
2. Add/update Pydantic models in the same directory
3. Write tests alongside route files
@@ -226,76 +205,21 @@ Agents are built using a visual block-based system where each block performs a s
5. Run `poetry run test` to verify changes
### Frontend Development

**📖 Complete Frontend Guide**: See `autogpt_platform/frontend/CONTRIBUTING.md` and `autogpt_platform/frontend/.cursorrules` for comprehensive patterns and conventions.

**Quick Reference:**

**Component Structure:**

- Separate render logic from data/behavior
- Structure: `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
- Exception: small components (3-4 lines of logic) can be inline
- Render-only components can be direct files without folders

**Data Fetching:**

- Use generated API hooks from `@/app/api/__generated__/endpoints/`
- Generated via Orval from the backend OpenAPI spec
- Pattern: `use{Method}{Version}{OperationName}`
- Example: `useGetV2ListLibraryAgents`
- Regenerate with: `pnpm generate:api`
- **Never** use the deprecated `BackendAPI` or `src/lib/autogpt-server-api/*`

**Code Conventions:**

- Use function declarations for components and handlers (not arrow functions)
- Use arrow functions only for small inline lambdas (map, filter, etc.)
- Components: `PascalCase`; hooks: `camelCase` with `use` prefix
- No barrel files or `index.ts` re-exports
- Minimal comments (code should be self-documenting)

**Styling:**

- Use Tailwind CSS utilities only
- Use design system components from `src/components/` (atoms, molecules, organisms)
- Never use `src/components/__legacy__/*`
- Only use Phosphor Icons (`@phosphor-icons/react`)
- Prefer design tokens over hardcoded values

**Error Handling:**

- Render errors: use the `<ErrorCard />` component
- Mutation errors: display with toast notifications
- Manual exceptions: use `Sentry.captureException()`
- Global error boundaries are already configured

**Testing:**

- Add/update Storybook stories for UI components (`pnpm storybook`)
- Run Playwright E2E tests with `pnpm test`
- Verify in Chromatic after the PR

**Architecture:**

- Default to client components ("use client")
- Server components only for SEO or extreme TTFB needs
- Use React Query for server state (via generated hooks)
- Co-locate UI state in components/hooks

1. Components in `/frontend/src/components/`
2. Use existing UI components from `/frontend/src/components/ui/`
3. Add Storybook stories for component development
4. Test user-facing features with Playwright E2E tests
5. Update protected routes in middleware when needed
### Security Guidelines

**Cache Protection Middleware** (`/backend/backend/server/middleware/security.py`) — a hedged sketch of the idea follows this list:

- Default: disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses an allow-list approach for cacheable paths (static assets, health checks, public pages)
- Prevents sensitive data caching in browsers/proxies
- Add new cacheable endpoints to `CACHEABLE_PATHS`
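As a rough illustration of the default-deny pattern (not the actual implementation in `security.py`; the class name and path set here are made up):

```python
# Illustrative sketch only - the real middleware lives in
# backend/backend/server/middleware/security.py with its own names and paths.
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request

CACHEABLE_PATHS = {"/health", "/static"}  # hypothetical allow list


class CacheProtectionMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        # Default-deny: mark every response non-cacheable unless the path
        # is explicitly allow-listed.
        if not any(request.url.path.startswith(p) for p in CACHEABLE_PATHS):
            response.headers["Cache-Control"] = (
                "no-store, no-cache, must-revalidate, private"
            )
        return response
```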
### CI/CD Alignment

The repository has comprehensive CI workflows that test:

- **Backend**: Python 3.11-3.13, services (Redis/RabbitMQ/ClamAV), Prisma migrations, Poetry lock validation
- **Frontend**: Node.js 21, pnpm, Playwright with the Docker Compose stack, API schema validation
- **Integration**: Full-stack type checking and E2E testing
@@ -305,7 +229,6 @@ Match these patterns when developing locally - the copilot setup environment mir

## Collaboration with Other AI Assistants

This repository is actively developed with assistance from Claude (via CLAUDE.md files). When working on this codebase:

- Check for existing CLAUDE.md files that provide additional context
- Follow established patterns and conventions already in the codebase
- Maintain consistency with existing code style and architecture
@@ -314,9 +237,8 @@ This repository is actively developed with assistance from Claude (via CLAUDE.md

## Trust These Instructions

These instructions are comprehensive and tested. Only perform additional searches if:

1. Information here is incomplete for your specific task
2. You encounter errors not covered by the workarounds
3. You need to understand implementation details not covered above

For detailed platform development patterns, refer to `autogpt_platform/CLAUDE.md` and `AGENTS.md` in the repository root.
`.github/workflows/claude-ci-failure-auto-fix.yml` (vendored, 97 changed lines)
@@ -1,97 +0,0 @@
name: Auto Fix CI Failures

on:
  workflow_run:
    workflows: ["CI"]
    types:
      - completed

permissions:
  contents: write
  pull-requests: write
  actions: read
  issues: write
  id-token: write # Required for OIDC token exchange

jobs:
  auto-fix:
    if: |
      github.event.workflow_run.conclusion == 'failure' &&
      github.event.workflow_run.pull_requests[0] &&
      !startsWith(github.event.workflow_run.head_branch, 'claude-auto-fix-ci-')
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event.workflow_run.head_branch }}
          fetch-depth: 0
          token: ${{ secrets.GITHUB_TOKEN }}

      - name: Setup git identity
        run: |
          git config --global user.email "claude[bot]@users.noreply.github.com"
          git config --global user.name "claude[bot]"

      - name: Create fix branch
        id: branch
        run: |
          BRANCH_NAME="claude-auto-fix-ci-${{ github.event.workflow_run.head_branch }}-${{ github.run_id }}"
          git checkout -b "$BRANCH_NAME"
          echo "branch_name=$BRANCH_NAME" >> $GITHUB_OUTPUT

      - name: Get CI failure details
        id: failure_details
        uses: actions/github-script@v7
        with:
          script: |
            const run = await github.rest.actions.getWorkflowRun({
              owner: context.repo.owner,
              repo: context.repo.repo,
              run_id: ${{ github.event.workflow_run.id }}
            });

            const jobs = await github.rest.actions.listJobsForWorkflowRun({
              owner: context.repo.owner,
              repo: context.repo.repo,
              run_id: ${{ github.event.workflow_run.id }}
            });

            const failedJobs = jobs.data.jobs.filter(job => job.conclusion === 'failure');

            let errorLogs = [];
            for (const job of failedJobs) {
              const logs = await github.rest.actions.downloadJobLogsForWorkflowRun({
                owner: context.repo.owner,
                repo: context.repo.repo,
                job_id: job.id
              });
              errorLogs.push({
                jobName: job.name,
                logs: logs.data
              });
            }

            return {
              runUrl: run.data.html_url,
              failedJobs: failedJobs.map(j => j.name),
              errorLogs: errorLogs
            };

      - name: Fix CI failures with Claude
        id: claude
        uses: anthropics/claude-code-action@v1
        with:
          prompt: |
            /fix-ci
            Failed CI Run: ${{ fromJSON(steps.failure_details.outputs.result).runUrl }}
            Failed Jobs: ${{ join(fromJSON(steps.failure_details.outputs.result).failedJobs, ', ') }}
            PR Number: ${{ github.event.workflow_run.pull_requests[0].number }}
            Branch Name: ${{ steps.branch.outputs.branch_name }}
            Base Branch: ${{ github.event.workflow_run.head_branch }}
            Repository: ${{ github.repository }}

            Error logs:
            ${{ toJSON(fromJSON(steps.failure_details.outputs.result).errorLogs) }}
          anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
          claude_args: "--allowedTools 'Edit,MultiEdit,Write,Read,Glob,Grep,LS,Bash(git:*),Bash(bun:*),Bash(npm:*),Bash(npx:*),Bash(gh:*)'"
`.github/workflows/claude-dependabot.yml` (vendored, 379 changed lines)
@@ -1,379 +0,0 @@
# Claude Dependabot PR Review Workflow
#
# This workflow automatically runs Claude analysis on Dependabot PRs to:
# - Identify dependency changes and their versions
# - Look up changelogs for updated packages
# - Assess breaking changes and security impacts
# - Provide actionable recommendations for the development team
#
# Triggered on: Dependabot PRs (opened, synchronize)
# Requirements: ANTHROPIC_API_KEY secret must be configured

name: Claude Dependabot PR Review

on:
  pull_request:
    types: [opened, synchronize]

jobs:
  dependabot-review:
    # Only run on Dependabot PRs
    if: github.actor == 'dependabot[bot]'
    runs-on: ubuntu-latest
    timeout-minutes: 30

    permissions:
      contents: write
      pull-requests: read
      issues: read
      id-token: write
      actions: read # Required for CI access
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 1

      # Backend Python/Poetry setup (mirrors platform-backend-ci.yml)
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.11" # Use standard version matching CI

      - name: Set up Python dependency cache
        uses: actions/cache@v4
        with:
          path: ~/.cache/pypoetry
          key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

      - name: Install Poetry
        run: |
          # Extract Poetry version from backend/poetry.lock (matches CI)
          cd autogpt_platform/backend
          HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
          echo "Found Poetry version ${HEAD_POETRY_VERSION} in backend/poetry.lock"

          # Install Poetry
          curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -

          # Add Poetry to PATH
          echo "$HOME/.local/bin" >> $GITHUB_PATH

      - name: Check poetry.lock
        working-directory: autogpt_platform/backend
        run: |
          poetry lock
          if ! git diff --quiet --ignore-matching-lines="^# " poetry.lock; then
            echo "Warning: poetry.lock not up to date, but continuing for setup"
            git checkout poetry.lock # Reset for clean setup
          fi

      - name: Install Python dependencies
        working-directory: autogpt_platform/backend
        run: poetry install

      - name: Generate Prisma Client
        working-directory: autogpt_platform/backend
        run: poetry run prisma generate

      # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
      - name: Set up Node.js
        uses: actions/setup-node@v4
        with:
          node-version: "21"

      - name: Enable corepack
        run: corepack enable

      - name: Set pnpm store directory
        run: |
          pnpm config set store-dir ~/.pnpm-store
          echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV

      - name: Cache frontend dependencies
        uses: actions/cache@v4
        with:
          path: ~/.pnpm-store
          key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
          restore-keys: |
            ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
            ${{ runner.os }}-pnpm-

      - name: Install JavaScript dependencies
        working-directory: autogpt_platform/frontend
        run: pnpm install --frozen-lockfile

      # Install Playwright browsers for frontend testing
      # NOTE: Disabled to save ~1 minute of setup time. Re-enable if Copilot needs browser automation (e.g., for MCP)
      # - name: Install Playwright browsers
      #   working-directory: autogpt_platform/frontend
      #   run: pnpm playwright install --with-deps chromium

      # Docker setup for development environment
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Copy default environment files
        working-directory: autogpt_platform
        run: |
          # Copy default environment files for development
          cp .env.default .env
          cp backend/.env.default backend/.env
          cp frontend/.env.default frontend/.env

      # Phase 1: Cache and load Docker images for faster setup
      - name: Set up Docker image cache
        id: docker-cache
        uses: actions/cache@v4
        with:
          path: ~/docker-cache
          # Use a versioned key for cache invalidation when image list changes
          key: docker-images-v2-${{ runner.os }}-${{ hashFiles('.github/workflows/copilot-setup-steps.yml') }}
          restore-keys: |
            docker-images-v2-${{ runner.os }}-
            docker-images-v1-${{ runner.os }}-

      - name: Load or pull Docker images
        working-directory: autogpt_platform
        run: |
          mkdir -p ~/docker-cache

          # Define image list for easy maintenance
          IMAGES=(
            "redis:latest"
            "rabbitmq:management"
            "clamav/clamav-debian:latest"
            "busybox:latest"
            "kong:2.8.1"
            "supabase/gotrue:v2.170.0"
            "supabase/postgres:15.8.1.049"
            "supabase/postgres-meta:v0.86.1"
            "supabase/studio:20250224-d10db0f"
          )

          # Check if any cached tar files exist (more reliable than cache-hit)
          if ls ~/docker-cache/*.tar 1> /dev/null 2>&1; then
            echo "Docker cache found, loading images in parallel..."
            for image in "${IMAGES[@]}"; do
              # Convert image name to filename (replace : and / with -)
              filename=$(echo "$image" | tr ':/' '--')
              if [ -f ~/docker-cache/${filename}.tar ]; then
                echo "Loading $image..."
                docker load -i ~/docker-cache/${filename}.tar || echo "Warning: Failed to load $image from cache" &
              fi
            done
            wait
            echo "All cached images loaded"
          else
            echo "No Docker cache found, pulling images in parallel..."
            # Pull all images in parallel
            for image in "${IMAGES[@]}"; do
              docker pull "$image" &
            done
            wait

            # Only save cache on main branches (not PRs) to avoid cache pollution
            if [[ "${{ github.ref }}" == "refs/heads/master" ]] || [[ "${{ github.ref }}" == "refs/heads/dev" ]]; then
              echo "Saving Docker images to cache in parallel..."
              for image in "${IMAGES[@]}"; do
                # Convert image name to filename (replace : and / with -)
                filename=$(echo "$image" | tr ':/' '--')
                echo "Saving $image..."
                docker save -o ~/docker-cache/${filename}.tar "$image" || echo "Warning: Failed to save $image" &
              done
              wait
              echo "Docker image cache saved"
            else
              echo "Skipping cache save for PR/feature branch"
            fi
          fi

          echo "Docker images ready for use"

      # Phase 2: Build migrate service with GitHub Actions cache
      - name: Build migrate Docker image with cache
        working-directory: autogpt_platform
        run: |
          # Build the migrate image with buildx for GHA caching
          docker buildx build \
            --cache-from type=gha \
            --cache-to type=gha,mode=max \
            --target migrate \
            --tag autogpt_platform-migrate:latest \
            --load \
            -f backend/Dockerfile \
            ..

      # Start services using pre-built images
      - name: Start Docker services for development
        working-directory: autogpt_platform
        run: |
          # Start essential services (migrate image already built with correct tag)
          docker compose --profile local up deps --no-build --detach
          echo "Waiting for services to be ready..."

          # Wait for database to be ready
          echo "Checking database readiness..."
          timeout 30 sh -c 'until docker compose exec -T db pg_isready -U postgres 2>/dev/null; do
            echo "  Waiting for database..."
            sleep 2
          done' && echo "✅ Database is ready" || echo "⚠️ Database ready check timeout after 30s, continuing..."

          # Check migrate service status
          echo "Checking migration status..."
          docker compose ps migrate || echo "  Migrate service not visible in ps output"

          # Wait for migrate service to complete
          echo "Waiting for migrations to complete..."
          timeout 30 bash -c '
            ATTEMPTS=0
            while [ $ATTEMPTS -lt 15 ]; do
              ATTEMPTS=$((ATTEMPTS + 1))

              # Check using docker directly (more reliable than docker compose ps)
              CONTAINER_STATUS=$(docker ps -a --filter "label=com.docker.compose.service=migrate" --format "{{.Status}}" | head -1)

              if [ -z "$CONTAINER_STATUS" ]; then
                echo "  Attempt $ATTEMPTS: Migrate container not found yet..."
              elif echo "$CONTAINER_STATUS" | grep -q "Exited (0)"; then
                echo "✅ Migrations completed successfully"
                docker compose logs migrate --tail=5 2>/dev/null || true
                exit 0
              elif echo "$CONTAINER_STATUS" | grep -q "Exited ([1-9]"; then
                EXIT_CODE=$(echo "$CONTAINER_STATUS" | grep -oE "Exited \([0-9]+\)" | grep -oE "[0-9]+")
                echo "❌ Migrations failed with exit code: $EXIT_CODE"
                echo "Migration logs:"
                docker compose logs migrate --tail=20 2>/dev/null || true
                exit 1
              elif echo "$CONTAINER_STATUS" | grep -q "Up"; then
                echo "  Attempt $ATTEMPTS: Migrate container is running... ($CONTAINER_STATUS)"
              else
                echo "  Attempt $ATTEMPTS: Migrate container status: $CONTAINER_STATUS"
              fi

              sleep 2
            done

            echo "⚠️ Timeout: Could not determine migration status after 30 seconds"
            echo "Final container check:"
            docker ps -a --filter "label=com.docker.compose.service=migrate" || true
            echo "Migration logs (if available):"
            docker compose logs migrate --tail=10 2>/dev/null || echo "  No logs available"
          ' || echo "⚠️ Migration check completed with warnings, continuing..."

          # Brief wait for other services to stabilize
          echo "Waiting 5 seconds for other services to stabilize..."
          sleep 5

      # Verify installations and provide environment info
      - name: Verify setup and show environment info
        run: |
          echo "=== Python Setup ==="
          python --version
          poetry --version

          echo "=== Node.js Setup ==="
          node --version
          pnpm --version

          echo "=== Additional Tools ==="
          docker --version
          docker compose version
          gh --version || true

          echo "=== Services Status ==="
          cd autogpt_platform
          docker compose ps || true

          echo "=== Backend Dependencies ==="
          cd backend
          poetry show | head -10 || true

          echo "=== Frontend Dependencies ==="
          cd ../frontend
          pnpm list --depth=0 | head -10 || true

          echo "=== Environment Files ==="
          ls -la ../.env* || true
          ls -la .env* || true
          ls -la ../backend/.env* || true

          echo "✅ AutoGPT Platform development environment setup complete!"
          echo "🚀 Ready for development with Docker services running"
          echo "📝 Backend server: poetry run serve (port 8000)"
          echo "🌐 Frontend server: pnpm dev (port 3000)"

      - name: Run Claude Dependabot Analysis
        id: claude_review
        uses: anthropics/claude-code-action@v1
        with:
          anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
          claude_args: |
            --allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
          prompt: |
            You are Claude, an AI assistant specialized in reviewing Dependabot dependency update PRs.

            Your primary tasks are:
            1. **Analyze the dependency changes** in this Dependabot PR
            2. **Look up changelogs** for all updated dependencies to understand what changed
            3. **Identify breaking changes** and assess potential impact on the AutoGPT codebase
            4. **Provide actionable recommendations** for the development team

            ## Analysis Process:

            1. **Identify Changed Dependencies**:
               - Use git diff to see what dependencies were updated
               - Parse package.json, poetry.lock, requirements files, etc.
               - List all package versions: old → new

            2. **Changelog Research**:
               - For each updated dependency, look up its changelog/release notes
               - Use WebFetch to access GitHub releases, NPM package pages, PyPI project pages. The PR should also have some details
               - Focus on versions between the old and new versions
               - Identify: breaking changes, deprecations, security fixes, new features

            3. **Breaking Change Assessment**:
               - Categorize changes: BREAKING, MAJOR, MINOR, PATCH, SECURITY
               - Assess impact on AutoGPT's usage patterns
               - Check if AutoGPT uses affected APIs/features
               - Look for migration guides or upgrade instructions

            4. **Codebase Impact Analysis**:
               - Search the AutoGPT codebase for usage of changed APIs
               - Identify files that might be affected by breaking changes
               - Check test files for deprecated usage patterns
               - Look for configuration changes needed

            ## Output Format:

            Provide a comprehensive review comment with:

            ### 🔍 Dependency Analysis Summary
            - List of updated packages with version changes
            - Overall risk assessment (LOW/MEDIUM/HIGH)

            ### 📋 Detailed Changelog Review
            For each updated dependency:
            - **Package**: name (old_version → new_version)
            - **Changes**: Summary of key changes
            - **Breaking Changes**: List any breaking changes
            - **Security Fixes**: Note security improvements
            - **Migration Notes**: Any upgrade steps needed

            ### ⚠️ Impact Assessment
            - **Breaking Changes Found**: Yes/No with details
            - **Affected Files**: List AutoGPT files that may need updates
            - **Test Impact**: Any tests that may need updating
            - **Configuration Changes**: Required config updates

            ### 🛠️ Recommendations
            - **Action Required**: What the team should do
            - **Testing Focus**: Areas to test thoroughly
            - **Follow-up Tasks**: Any additional work needed
            - **Merge Recommendation**: APPROVE/REVIEW_NEEDED/HOLD

            ### 📚 Useful Links
            - Links to relevant changelogs, migration guides, documentation

            Be thorough but concise. Focus on actionable insights that help the development team make informed decisions about the dependency updates.
`.github/workflows/claude.yml` (vendored, 284 changed lines)
@@ -30,296 +30,18 @@ jobs:
        github.event.issue.author_association == 'COLLABORATOR'
      )
    runs-on: ubuntu-latest
    timeout-minutes: 45

    permissions:
      contents: write
      contents: read
      pull-requests: read
      issues: read
      id-token: write
      actions: read # Required for CI access
    steps:
      - name: Checkout code
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 1

      # Backend Python/Poetry setup (mirrors platform-backend-ci.yml)
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.11" # Use standard version matching CI

      - name: Set up Python dependency cache
        uses: actions/cache@v4
        with:
          path: ~/.cache/pypoetry
          key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

      - name: Install Poetry
        run: |
          # Extract Poetry version from backend/poetry.lock (matches CI)
          cd autogpt_platform/backend
          HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
          echo "Found Poetry version ${HEAD_POETRY_VERSION} in backend/poetry.lock"

          # Install Poetry
          curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -

          # Add Poetry to PATH
          echo "$HOME/.local/bin" >> $GITHUB_PATH

      - name: Check poetry.lock
        working-directory: autogpt_platform/backend
        run: |
          poetry lock
          if ! git diff --quiet --ignore-matching-lines="^# " poetry.lock; then
            echo "Warning: poetry.lock not up to date, but continuing for setup"
            git checkout poetry.lock # Reset for clean setup
          fi

      - name: Install Python dependencies
        working-directory: autogpt_platform/backend
        run: poetry install

      - name: Generate Prisma Client
        working-directory: autogpt_platform/backend
        run: poetry run prisma generate

      # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
      - name: Set up Node.js
        uses: actions/setup-node@v4
        with:
          node-version: "21"

      - name: Enable corepack
        run: corepack enable

      - name: Set pnpm store directory
        run: |
          pnpm config set store-dir ~/.pnpm-store
          echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV

      - name: Cache frontend dependencies
        uses: actions/cache@v4
        with:
          path: ~/.pnpm-store
          key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
          restore-keys: |
            ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
            ${{ runner.os }}-pnpm-

      - name: Install JavaScript dependencies
        working-directory: autogpt_platform/frontend
        run: pnpm install --frozen-lockfile

      # Install Playwright browsers for frontend testing
      # NOTE: Disabled to save ~1 minute of setup time. Re-enable if Copilot needs browser automation (e.g., for MCP)
      # - name: Install Playwright browsers
      #   working-directory: autogpt_platform/frontend
      #   run: pnpm playwright install --with-deps chromium

      # Docker setup for development environment
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Copy default environment files
        working-directory: autogpt_platform
        run: |
          # Copy default environment files for development
          cp .env.default .env
          cp backend/.env.default backend/.env
          cp frontend/.env.default frontend/.env

      # Phase 1: Cache and load Docker images for faster setup
      - name: Set up Docker image cache
        id: docker-cache
        uses: actions/cache@v4
        with:
          path: ~/docker-cache
          # Use a versioned key for cache invalidation when image list changes
          key: docker-images-v2-${{ runner.os }}-${{ hashFiles('.github/workflows/copilot-setup-steps.yml') }}
          restore-keys: |
            docker-images-v2-${{ runner.os }}-
            docker-images-v1-${{ runner.os }}-

      - name: Load or pull Docker images
        working-directory: autogpt_platform
        run: |
          mkdir -p ~/docker-cache

          # Define image list for easy maintenance
          IMAGES=(
            "redis:latest"
            "rabbitmq:management"
            "clamav/clamav-debian:latest"
            "busybox:latest"
            "kong:2.8.1"
            "supabase/gotrue:v2.170.0"
            "supabase/postgres:15.8.1.049"
            "supabase/postgres-meta:v0.86.1"
            "supabase/studio:20250224-d10db0f"
          )

          # Check if any cached tar files exist (more reliable than cache-hit)
          if ls ~/docker-cache/*.tar 1> /dev/null 2>&1; then
            echo "Docker cache found, loading images in parallel..."
            for image in "${IMAGES[@]}"; do
              # Convert image name to filename (replace : and / with -)
              filename=$(echo "$image" | tr ':/' '--')
              if [ -f ~/docker-cache/${filename}.tar ]; then
                echo "Loading $image..."
                docker load -i ~/docker-cache/${filename}.tar || echo "Warning: Failed to load $image from cache" &
              fi
            done
            wait
            echo "All cached images loaded"
          else
            echo "No Docker cache found, pulling images in parallel..."
            # Pull all images in parallel
            for image in "${IMAGES[@]}"; do
              docker pull "$image" &
            done
            wait

            # Only save cache on main branches (not PRs) to avoid cache pollution
            if [[ "${{ github.ref }}" == "refs/heads/master" ]] || [[ "${{ github.ref }}" == "refs/heads/dev" ]]; then
              echo "Saving Docker images to cache in parallel..."
              for image in "${IMAGES[@]}"; do
                # Convert image name to filename (replace : and / with -)
                filename=$(echo "$image" | tr ':/' '--')
                echo "Saving $image..."
                docker save -o ~/docker-cache/${filename}.tar "$image" || echo "Warning: Failed to save $image" &
              done
              wait
              echo "Docker image cache saved"
            else
              echo "Skipping cache save for PR/feature branch"
            fi
          fi

          echo "Docker images ready for use"

      # Phase 2: Build migrate service with GitHub Actions cache
      - name: Build migrate Docker image with cache
        working-directory: autogpt_platform
        run: |
          # Build the migrate image with buildx for GHA caching
          docker buildx build \
            --cache-from type=gha \
            --cache-to type=gha,mode=max \
            --target migrate \
            --tag autogpt_platform-migrate:latest \
            --load \
            -f backend/Dockerfile \
            ..

      # Start services using pre-built images
      - name: Start Docker services for development
        working-directory: autogpt_platform
        run: |
          # Start essential services (migrate image already built with correct tag)
          docker compose --profile local up deps --no-build --detach
          echo "Waiting for services to be ready..."

          # Wait for database to be ready
          echo "Checking database readiness..."
          timeout 30 sh -c 'until docker compose exec -T db pg_isready -U postgres 2>/dev/null; do
            echo "  Waiting for database..."
            sleep 2
          done' && echo "✅ Database is ready" || echo "⚠️ Database ready check timeout after 30s, continuing..."

          # Check migrate service status
          echo "Checking migration status..."
          docker compose ps migrate || echo "  Migrate service not visible in ps output"

          # Wait for migrate service to complete
          echo "Waiting for migrations to complete..."
          timeout 30 bash -c '
            ATTEMPTS=0
            while [ $ATTEMPTS -lt 15 ]; do
              ATTEMPTS=$((ATTEMPTS + 1))

              # Check using docker directly (more reliable than docker compose ps)
              CONTAINER_STATUS=$(docker ps -a --filter "label=com.docker.compose.service=migrate" --format "{{.Status}}" | head -1)

              if [ -z "$CONTAINER_STATUS" ]; then
                echo "  Attempt $ATTEMPTS: Migrate container not found yet..."
              elif echo "$CONTAINER_STATUS" | grep -q "Exited (0)"; then
                echo "✅ Migrations completed successfully"
                docker compose logs migrate --tail=5 2>/dev/null || true
                exit 0
              elif echo "$CONTAINER_STATUS" | grep -q "Exited ([1-9]"; then
                EXIT_CODE=$(echo "$CONTAINER_STATUS" | grep -oE "Exited \([0-9]+\)" | grep -oE "[0-9]+")
                echo "❌ Migrations failed with exit code: $EXIT_CODE"
                echo "Migration logs:"
                docker compose logs migrate --tail=20 2>/dev/null || true
                exit 1
              elif echo "$CONTAINER_STATUS" | grep -q "Up"; then
                echo "  Attempt $ATTEMPTS: Migrate container is running... ($CONTAINER_STATUS)"
              else
                echo "  Attempt $ATTEMPTS: Migrate container status: $CONTAINER_STATUS"
              fi

              sleep 2
            done

            echo "⚠️ Timeout: Could not determine migration status after 30 seconds"
            echo "Final container check:"
            docker ps -a --filter "label=com.docker.compose.service=migrate" || true
            echo "Migration logs (if available):"
            docker compose logs migrate --tail=10 2>/dev/null || echo "  No logs available"
          ' || echo "⚠️ Migration check completed with warnings, continuing..."

          # Brief wait for other services to stabilize
          echo "Waiting 5 seconds for other services to stabilize..."
          sleep 5

      # Verify installations and provide environment info
      - name: Verify setup and show environment info
        run: |
          echo "=== Python Setup ==="
          python --version
          poetry --version

          echo "=== Node.js Setup ==="
          node --version
          pnpm --version

          echo "=== Additional Tools ==="
          docker --version
          docker compose version
          gh --version || true

          echo "=== Services Status ==="
          cd autogpt_platform
          docker compose ps || true

          echo "=== Backend Dependencies ==="
          cd backend
          poetry show | head -10 || true

          echo "=== Frontend Dependencies ==="
          cd ../frontend
          pnpm list --depth=0 | head -10 || true

          echo "=== Environment Files ==="
          ls -la ../.env* || true
          ls -la .env* || true
          ls -la ../backend/.env* || true

          echo "✅ AutoGPT Platform development environment setup complete!"
          echo "🚀 Ready for development with Docker services running"
          echo "📝 Backend server: poetry run serve (port 8000)"
          echo "🌐 Frontend server: pnpm dev (port 3000)"

      - name: Run Claude Code
        id: claude
        uses: anthropics/claude-code-action@v1
        uses: anthropics/claude-code-action@beta
        with:
          anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
          claude_args: |
            --allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*), Bash(gh pr edit:*)"
            --model opus
          additional_permissions: |
            actions: read
@@ -3,7 +3,6 @@ name: AutoGPT Platform - Deploy Prod Environment
on:
  release:
    types: [published]
  workflow_dispatch:

permissions:
  contents: 'read'
@@ -18,8 +17,6 @@ jobs:
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          ref: ${{ github.ref_name || 'master' }}

      - name: Set up Python
        uses: actions/setup-python@v5
@@ -39,7 +36,7 @@ jobs:
          DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
          DIRECT_URL: ${{ secrets.BACKEND_DATABASE_URL }}

  trigger:
    needs: migrate
    runs-on: ubuntu-latest
@@ -50,5 +47,4 @@ jobs:
          token: ${{ secrets.DEPLOY_TOKEN }}
          repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
          event-type: build_deploy_prod
          client-payload: |
            {"ref": "${{ github.ref_name || 'master' }}", "repository": "${{ github.repository }}"}
          client-payload: '{"ref": "${{ github.ref }}", "sha": "${{ github.sha }}", "repository": "${{ github.repository }}"}'
@@ -5,13 +5,6 @@ on:
    branches: [ dev ]
    paths:
      - 'autogpt_platform/**'
  workflow_dispatch:
    inputs:
      git_ref:
        description: 'Git ref (branch/tag) of AutoGPT to deploy'
        required: true
        default: 'master'
        type: string

permissions:
  contents: 'read'
@@ -26,8 +19,6 @@ jobs:
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event.inputs.git_ref || github.ref_name }}

      - name: Set up Python
        uses: actions/setup-python@v5
@@ -57,4 +48,4 @@ jobs:
          token: ${{ secrets.DEPLOY_TOKEN }}
          repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
          event-type: build_deploy_dev
          client-payload: '{"ref": "${{ github.event.inputs.git_ref || github.ref }}", "repository": "${{ github.repository }}"}'
          client-payload: '{"ref": "${{ github.ref }}", "sha": "${{ github.sha }}", "repository": "${{ github.repository }}"}'
`.github/workflows/platform-backend-ci.yml` (vendored, 5 changed lines)
@@ -37,7 +37,9 @@ jobs:

    services:
      redis:
        image: redis:latest
        image: bitnami/redis:6.2
        env:
          REDIS_PASSWORD: testpassword
        ports:
          - 6379:6379
      rabbitmq:
@@ -202,6 +204,7 @@ jobs:
          JWT_VERIFY_KEY: ${{ steps.supabase.outputs.JWT_SECRET }}
          REDIS_HOST: "localhost"
          REDIS_PORT: "6379"
          REDIS_PASSWORD: "testpassword"
          ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!

        env:
`.github/workflows/platform-frontend-ci.yml` (vendored, 2 changed lines)
@@ -160,7 +160,7 @@ jobs:

      - name: Run docker compose
        run: |
          NEXT_PUBLIC_PW_TEST=true docker compose -f ../docker-compose.yml up -d
          docker compose -f ../docker-compose.yml up -d
        env:
          DOCKER_BUILDKIT: 1
          BUILDX_CACHE_FROM: type=local,src=/tmp/.buildx-cache

@@ -1,3 +1,6 @@
[pr_reviewer]
num_code_suggestions=0

[pr_code_suggestions]
commitable_code_suggestions=false
num_code_suggestions=0
@@ -61,41 +61,24 @@ poetry run pytest path/to/test.py --snapshot-update

```bash
# Install dependencies
cd frontend && pnpm i

# Generate API client from OpenAPI spec
pnpm generate:api
cd frontend && npm install

# Start development server
pnpm dev
npm run dev

# Run E2E tests
pnpm test
npm run test

# Run Storybook for component development
pnpm storybook
npm run storybook

# Build production
pnpm build

# Format and lint
pnpm format
npm run build

# Type checking
pnpm types
npm run types
```

**📖 Complete Guide**: See `/frontend/CONTRIBUTING.md` and `/frontend/.cursorrules` for comprehensive frontend patterns.

**Key Frontend Conventions:**

- Separate render logic from data/behavior in components
- Use generated API hooks from `@/app/api/__generated__/endpoints/`
- Use function declarations (not arrow functions) for components/handlers
- Use design system components from `src/components/` (atoms, molecules, organisms)
- Only use Phosphor Icons
- Never use `src/components/__legacy__/*` or deprecated `BackendAPI`

## Architecture Overview

### Backend Architecture
@@ -109,16 +92,11 @@ pnpm types

### Frontend Architecture

- **Framework**: Next.js 15 App Router (client-first approach)
- **Data Fetching**: Type-safe generated API hooks via Orval + React Query
- **State Management**: React Query for server state, co-located UI state in components/hooks
- **Component Structure**: Separate render logic (`.tsx`) from business logic (`use*.ts` hooks)
- **Framework**: Next.js App Router with React Server Components
- **State Management**: React hooks + Supabase client for real-time updates
- **Workflow Builder**: Visual graph editor using @xyflow/react
- **UI Components**: shadcn/ui (Radix UI primitives) with Tailwind CSS styling
- **Icons**: Phosphor Icons only
- **UI Components**: Radix UI primitives with Tailwind CSS styling
- **Feature Flags**: LaunchDarkly integration
- **Error Handling**: ErrorCard for render errors, toast for mutations, Sentry for exceptions
- **Testing**: Playwright for E2E, Storybook for component development

### Key Concepts
@@ -172,7 +150,6 @@ Key models (defined in `/backend/schema.prisma`):
**Adding a new block:**

Follow the comprehensive [Block SDK Guide](../../../docs/content/platform/block-sdk-guide.md) which covers:

- Provider configuration with `ProviderBuilder`
- Block schema definition
- Authentication (API keys, OAuth, webhooks)
@@ -180,7 +157,6 @@ Follow the comprehensive [Block SDK Guide](../../../docs/content/platform/block-
- File organization

Quick steps:

1. Create new file in `/backend/backend/blocks/`
2. Configure provider using `ProviderBuilder` in `_config.py`
3. Inherit from `Block` base class
@@ -201,20 +177,10 @@ ex: do the inputs and outputs tie well together?

**Frontend feature development:**

See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:

1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx`
   - Add `usePageName.ts` hook for logic
   - Put sub-components in a local `components/` folder
2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
   - Use design system components from `src/components/` (atoms, molecules, organisms)
   - Never use `src/components/__legacy__/*`
3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/`
   - Regenerate with `pnpm generate:api`
   - Pattern: `use{Method}{Version}{OperationName}`
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers

1. Components go in `/frontend/src/components/`
2. Use existing UI components from `/frontend/src/components/ui/`
3. Add Storybook stories for new components
4. Test with Playwright if user-facing

### Security Implementation
@@ -1,57 +0,0 @@
.PHONY: start-core stop-core logs-core format lint migrate run-backend run-frontend

# Run just Supabase + Redis + RabbitMQ
start-core:
	docker compose up -d deps

# Stop core services
stop-core:
	docker compose stop deps

reset-db:
	rm -rf db/docker/volumes/db/data
	cd backend && poetry run prisma migrate deploy
	cd backend && poetry run prisma generate

# View logs for core services
logs-core:
	docker compose logs -f deps

# Run formatting and linting for backend and frontend
format:
	cd backend && poetry run format
	cd frontend && pnpm format
	cd frontend && pnpm lint

init-env:
	cp -n .env.default .env || true
	cd backend && cp -n .env.default .env || true
	cd frontend && cp -n .env.default .env || true

# Run migrations for backend
migrate:
	cd backend && poetry run prisma migrate deploy
	cd backend && poetry run prisma generate

run-backend:
	cd backend && poetry run app

run-frontend:
	cd frontend && pnpm dev

test-data:
	cd backend && poetry run python test/test_data_creator.py

help:
	@echo "Usage: make <target>"
	@echo "Targets:"
	@echo "  start-core   - Start just the core services (Supabase, Redis, RabbitMQ) in background"
	@echo "  stop-core    - Stop the core services"
	@echo "  reset-db     - Reset the database by deleting the volume"
	@echo "  logs-core    - Tail the logs for core services"
	@echo "  format       - Format & lint backend (Python) and frontend (TypeScript) code"
	@echo "  migrate      - Run backend database migrations"
	@echo "  run-backend  - Run the backend FastAPI server"
	@echo "  run-frontend - Run the frontend Next.js development server"
	@echo "  test-data    - Run the test data creator"
@@ -38,37 +38,6 @@ To run the AutoGPT Platform, follow these steps:

4. After all the services are in a ready state, open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.

### Running Just Core Services

You can now run the following to enable just the core services:

```
# For help
make help

# Run just Supabase + Redis + RabbitMQ
make start-core

# Stop core services
make stop-core

# View logs from core services
make logs-core

# Run formatting and linting for backend and frontend
make format

# Run migrations for backend database
make migrate

# Run backend server
make run-backend

# Run frontend development server
make run-frontend
```

### Docker Compose Commands

Here are some useful Docker Compose commands for managing your AutoGPT Platform:
@@ -0,0 +1,35 @@
import hashlib
import secrets
from typing import NamedTuple


class APIKeyContainer(NamedTuple):
    """Container for API key parts."""

    raw: str
    prefix: str
    postfix: str
    hash: str


class APIKeyManager:
    PREFIX: str = "agpt_"
    PREFIX_LENGTH: int = 8
    POSTFIX_LENGTH: int = 8

    def generate_api_key(self) -> APIKeyContainer:
        """Generate a new API key with all its parts."""
        raw_key = f"{self.PREFIX}{secrets.token_urlsafe(32)}"
        return APIKeyContainer(
            raw=raw_key,
            prefix=raw_key[: self.PREFIX_LENGTH],
            postfix=raw_key[-self.POSTFIX_LENGTH :],
            hash=hashlib.sha256(raw_key.encode()).hexdigest(),
        )

    def verify_api_key(self, provided_key: str, stored_hash: str) -> bool:
        """Verify if a provided API key matches the stored hash."""
        if not provided_key.startswith(self.PREFIX):
            return False
        provided_hash = hashlib.sha256(provided_key.encode()).hexdigest()
        return secrets.compare_digest(provided_hash, stored_hash)
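For orientation, here is how the class above would be used by calling code (a hypothetical snippet; in the platform the hash would be persisted to the database rather than held in a variable):

```python
# Hypothetical usage of APIKeyManager from the file above.
manager = APIKeyManager()
key = manager.generate_api_key()

stored_hash = key.hash          # store this server-side
print(key.raw)                  # show the raw key to the user exactly once
print(key.prefix, key.postfix)  # safe to display for identification

assert manager.verify_api_key(key.raw, stored_hash) is True
assert manager.verify_api_key("agpt_not-the-key", stored_hash) is False
```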
@@ -1,78 +0,0 @@
|
||||
import hashlib
|
||||
import secrets
|
||||
from typing import NamedTuple
|
||||
|
||||
from cryptography.hazmat.primitives.kdf.scrypt import Scrypt
|
||||
|
||||
|
||||
class APIKeyContainer(NamedTuple):
|
||||
"""Container for API key parts."""
|
||||
|
||||
key: str
|
||||
head: str
|
||||
tail: str
|
||||
hash: str
|
||||
salt: str
|
||||
|
||||
|
||||
class APIKeySmith:
|
||||
PREFIX: str = "agpt_"
|
||||
HEAD_LENGTH: int = 8
|
||||
TAIL_LENGTH: int = 8
|
||||
|
||||
def generate_key(self) -> APIKeyContainer:
|
||||
"""Generate a new API key with secure hashing."""
|
||||
raw_key = f"{self.PREFIX}{secrets.token_urlsafe(32)}"
|
||||
hash, salt = self.hash_key(raw_key)
|
||||
|
||||
return APIKeyContainer(
|
||||
key=raw_key,
|
||||
head=raw_key[: self.HEAD_LENGTH],
|
||||
tail=raw_key[-self.TAIL_LENGTH :],
|
||||
hash=hash,
|
||||
salt=salt,
|
||||
)
|
||||
|
||||
def verify_key(
|
||||
self, provided_key: str, known_hash: str, known_salt: str | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
Verify an API key against a known hash (+ salt).
|
||||
Supports verifying both legacy SHA256 and secure Scrypt hashes.
|
||||
"""
|
||||
if not provided_key.startswith(self.PREFIX):
|
||||
return False
|
||||
|
||||
# Handle legacy SHA256 hashes (migration support)
|
||||
if known_salt is None:
|
||||
legacy_hash = hashlib.sha256(provided_key.encode()).hexdigest()
|
||||
return secrets.compare_digest(legacy_hash, known_hash)
|
||||
|
||||
try:
|
||||
salt_bytes = bytes.fromhex(known_salt)
|
||||
provided_hash = self._hash_key_with_salt(provided_key, salt_bytes)
|
||||
return secrets.compare_digest(provided_hash, known_hash)
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
def hash_key(self, raw_key: str) -> tuple[str, str]:
|
||||
"""Migrate a legacy hash to secure hash format."""
|
||||
salt = self._generate_salt()
|
||||
hash = self._hash_key_with_salt(raw_key, salt)
|
||||
return hash, salt.hex()
|
||||
|
||||
def _generate_salt(self) -> bytes:
|
||||
"""Generate a random salt for hashing."""
|
||||
return secrets.token_bytes(32)
|
||||
|
||||
def _hash_key_with_salt(self, raw_key: str, salt: bytes) -> str:
|
||||
"""Hash API key using Scrypt with salt."""
|
||||
kdf = Scrypt(
|
||||
length=32,
|
||||
salt=salt,
|
||||
n=2**14, # CPU/memory cost parameter
|
||||
r=8, # Block size parameter
|
||||
p=1, # Parallelization parameter
|
||||
)
|
||||
key_hash = kdf.derive(raw_key.encode())
|
||||
return key_hash.hex()

@@ -1,79 +0,0 @@
import hashlib

from autogpt_libs.api_key.keysmith import APIKeySmith


def test_generate_api_key():
    keysmith = APIKeySmith()
    key = keysmith.generate_key()

    assert key.key.startswith(keysmith.PREFIX)
    assert key.head == key.key[: keysmith.HEAD_LENGTH]
    assert key.tail == key.key[-keysmith.TAIL_LENGTH :]
    assert len(key.hash) == 64  # 32 bytes hex encoded
    assert len(key.salt) == 64  # 32 bytes hex encoded


def test_verify_new_secure_key():
    keysmith = APIKeySmith()
    key = keysmith.generate_key()

    # Test correct key validates
    assert keysmith.verify_key(key.key, key.hash, key.salt) is True

    # Test wrong key fails
    wrong_key = f"{keysmith.PREFIX}wrongkey123"
    assert keysmith.verify_key(wrong_key, key.hash, key.salt) is False


def test_verify_legacy_key():
    keysmith = APIKeySmith()
    legacy_key = f"{keysmith.PREFIX}legacykey123"
    legacy_hash = hashlib.sha256(legacy_key.encode()).hexdigest()

    # Test legacy key validates without salt
    assert keysmith.verify_key(legacy_key, legacy_hash) is True

    # Test wrong legacy key fails
    wrong_key = f"{keysmith.PREFIX}wronglegacy"
    assert keysmith.verify_key(wrong_key, legacy_hash) is False


def test_rehash_existing_key():
    keysmith = APIKeySmith()
    legacy_key = f"{keysmith.PREFIX}migratekey123"

    # Migrate the legacy key
    new_hash, new_salt = keysmith.hash_key(legacy_key)

    # Verify migrated key works
    assert keysmith.verify_key(legacy_key, new_hash, new_salt) is True

    # Verify different key fails with migrated hash
    wrong_key = f"{keysmith.PREFIX}wrongkey"
    assert keysmith.verify_key(wrong_key, new_hash, new_salt) is False


def test_invalid_key_prefix():
    keysmith = APIKeySmith()
    key = keysmith.generate_key()

    # Test key without proper prefix fails
    invalid_key = "invalid_prefix_key"
    assert keysmith.verify_key(invalid_key, key.hash, key.salt) is False


def test_secure_hash_requires_salt():
    keysmith = APIKeySmith()
    key = keysmith.generate_key()

    # Secure hash without salt should fail
    assert keysmith.verify_key(key.key, key.hash) is False


def test_invalid_salt_format():
    keysmith = APIKeySmith()
    key = keysmith.generate_key()

    # Invalid salt format should fail gracefully
    assert keysmith.verify_key(key.key, key.hash, "invalid_hex") is False

@@ -10,7 +10,7 @@ from .jwt_utils import get_jwt_payload, verify_user
 from .models import User


-async def requires_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
+def requires_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
     """
     FastAPI dependency that requires a valid authenticated user.

@@ -20,9 +20,7 @@ async def requires_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
     return verify_user(jwt_payload, admin_only=False)


-async def requires_admin_user(
-    jwt_payload: dict = fastapi.Security(get_jwt_payload),
-) -> User:
+def requires_admin_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
     """
     FastAPI dependency that requires a valid admin user.

@@ -32,7 +30,7 @@ async def requires_admin_user(
     return verify_user(jwt_payload, admin_only=True)


-async def get_user_id(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> str:
+def get_user_id(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> str:
     """
     FastAPI dependency that returns the ID of the authenticated user.
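
For context, a hedged sketch of how these now-synchronous dependencies are typically wired into an endpoint (the route is illustrative; the import path is assumed from the patched module above). FastAPI runs sync dependencies in a threadpool, so dropping `async` does not block the event loop and callers don't need to change:

```python
import fastapi

# Import path assumed from the module being patched above.
from autogpt_libs.auth.dependencies import get_user_id

app = fastapi.FastAPI()


@app.get("/me")
def read_me(user_id: str = fastapi.Depends(get_user_id)) -> dict:
    # get_user_id resolves the JWT payload via fastapi.Security upstream.
    return {"user_id": user_id}
```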

@@ -45,7 +45,7 @@ class TestAuthDependencies:
         """Create a test client."""
         return TestClient(app)

-    async def test_requires_user_with_valid_jwt_payload(self, mocker: MockerFixture):
+    def test_requires_user_with_valid_jwt_payload(self, mocker: MockerFixture):
         """Test requires_user with valid JWT payload."""
         jwt_payload = {"sub": "user-123", "role": "user", "email": "user@example.com"}

@@ -53,12 +53,12 @@ class TestAuthDependencies:
         mocker.patch(
             "autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
         )
-        user = await requires_user(jwt_payload)
+        user = requires_user(jwt_payload)
         assert isinstance(user, User)
         assert user.user_id == "user-123"
         assert user.role == "user"

-    async def test_requires_user_with_admin_jwt_payload(self, mocker: MockerFixture):
+    def test_requires_user_with_admin_jwt_payload(self, mocker: MockerFixture):
         """Test requires_user accepts admin users."""
         jwt_payload = {
             "sub": "admin-456",

@@ -69,28 +69,28 @@ class TestAuthDependencies:
         mocker.patch(
             "autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
         )
-        user = await requires_user(jwt_payload)
+        user = requires_user(jwt_payload)
         assert user.user_id == "admin-456"
         assert user.role == "admin"

-    async def test_requires_user_missing_sub(self):
+    def test_requires_user_missing_sub(self):
         """Test requires_user with missing user ID."""
         jwt_payload = {"role": "user", "email": "user@example.com"}

         with pytest.raises(HTTPException) as exc_info:
-            await requires_user(jwt_payload)
+            requires_user(jwt_payload)
         assert exc_info.value.status_code == 401
         assert "User ID not found" in exc_info.value.detail

-    async def test_requires_user_empty_sub(self):
+    def test_requires_user_empty_sub(self):
         """Test requires_user with empty user ID."""
         jwt_payload = {"sub": "", "role": "user"}

         with pytest.raises(HTTPException) as exc_info:
-            await requires_user(jwt_payload)
+            requires_user(jwt_payload)
         assert exc_info.value.status_code == 401

-    async def test_requires_admin_user_with_admin(self, mocker: MockerFixture):
+    def test_requires_admin_user_with_admin(self, mocker: MockerFixture):
         """Test requires_admin_user with admin role."""
         jwt_payload = {
             "sub": "admin-789",

@@ -101,51 +101,51 @@ class TestAuthDependencies:
         mocker.patch(
             "autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
         )
-        user = await requires_admin_user(jwt_payload)
+        user = requires_admin_user(jwt_payload)
         assert user.user_id == "admin-789"
         assert user.role == "admin"

-    async def test_requires_admin_user_with_regular_user(self):
+    def test_requires_admin_user_with_regular_user(self):
         """Test requires_admin_user rejects regular users."""
         jwt_payload = {"sub": "user-123", "role": "user", "email": "user@example.com"}

         with pytest.raises(HTTPException) as exc_info:
-            await requires_admin_user(jwt_payload)
+            requires_admin_user(jwt_payload)
         assert exc_info.value.status_code == 403
         assert "Admin access required" in exc_info.value.detail

-    async def test_requires_admin_user_missing_role(self):
+    def test_requires_admin_user_missing_role(self):
         """Test requires_admin_user with missing role."""
         jwt_payload = {"sub": "user-123", "email": "user@example.com"}

         with pytest.raises(KeyError):
-            await requires_admin_user(jwt_payload)
+            requires_admin_user(jwt_payload)

-    async def test_get_user_id_with_valid_payload(self, mocker: MockerFixture):
+    def test_get_user_id_with_valid_payload(self, mocker: MockerFixture):
         """Test get_user_id extracts user ID correctly."""
         jwt_payload = {"sub": "user-id-xyz", "role": "user"}

         mocker.patch(
             "autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
         )
-        user_id = await get_user_id(jwt_payload)
+        user_id = get_user_id(jwt_payload)
         assert user_id == "user-id-xyz"

-    async def test_get_user_id_missing_sub(self):
+    def test_get_user_id_missing_sub(self):
         """Test get_user_id with missing user ID."""
         jwt_payload = {"role": "user"}

         with pytest.raises(HTTPException) as exc_info:
-            await get_user_id(jwt_payload)
+            get_user_id(jwt_payload)
         assert exc_info.value.status_code == 401
         assert "User ID not found" in exc_info.value.detail

-    async def test_get_user_id_none_sub(self):
+    def test_get_user_id_none_sub(self):
         """Test get_user_id with None user ID."""
         jwt_payload = {"sub": None, "role": "user"}

         with pytest.raises(HTTPException) as exc_info:
-            await get_user_id(jwt_payload)
+            get_user_id(jwt_payload)
         assert exc_info.value.status_code == 401

@@ -170,7 +170,7 @@ class TestAuthDependenciesIntegration:

         return _create_token

-    async def test_endpoint_auth_enabled_no_token(self):
+    def test_endpoint_auth_enabled_no_token(self):
         """Test endpoints require token when auth is enabled."""
         app = FastAPI()

@@ -184,7 +184,7 @@ class TestAuthDependenciesIntegration:
         response = client.get("/test")
         assert response.status_code == 401

-    async def test_endpoint_with_valid_token(self, create_token):
+    def test_endpoint_with_valid_token(self, create_token):
         """Test endpoint with valid JWT token."""
         app = FastAPI()

@@ -203,7 +203,7 @@ class TestAuthDependenciesIntegration:
         assert response.status_code == 200
         assert response.json()["user_id"] == "test-user"

-    async def test_admin_endpoint_requires_admin_role(self, create_token):
+    def test_admin_endpoint_requires_admin_role(self, create_token):
         """Test admin endpoint rejects non-admin users."""
         app = FastAPI()

@@ -240,7 +240,7 @@ class TestAuthDependenciesIntegration:
 class TestAuthDependenciesEdgeCases:
     """Edge case tests for authentication dependencies."""

-    async def test_dependency_with_complex_payload(self):
+    def test_dependency_with_complex_payload(self):
         """Test dependencies handle complex JWT payloads."""
         complex_payload = {
             "sub": "user-123",

@@ -256,14 +256,14 @@ class TestAuthDependenciesEdgeCases:
             "exp": 9999999999,
         }

-        user = await requires_user(complex_payload)
+        user = requires_user(complex_payload)
         assert user.user_id == "user-123"
         assert user.email == "test@example.com"

-        admin = await requires_admin_user(complex_payload)
+        admin = requires_admin_user(complex_payload)
         assert admin.role == "admin"

-    async def test_dependency_with_unicode_in_payload(self):
+    def test_dependency_with_unicode_in_payload(self):
         """Test dependencies handle unicode in JWT payloads."""
         unicode_payload = {
             "sub": "user-😀-123",

@@ -272,11 +272,11 @@ class TestAuthDependenciesEdgeCases:
             "name": "日本語",
         }

-        user = await requires_user(unicode_payload)
+        user = requires_user(unicode_payload)
         assert "😀" in user.user_id
         assert user.email == "测试@example.com"

-    async def test_dependency_with_null_values(self):
+    def test_dependency_with_null_values(self):
         """Test dependencies handle null values in payload."""
         null_payload = {
             "sub": "user-123",

@@ -286,18 +286,18 @@ class TestAuthDependenciesEdgeCases:
             "metadata": None,
         }

-        user = await requires_user(null_payload)
+        user = requires_user(null_payload)
         assert user.user_id == "user-123"
         assert user.email is None

-    async def test_concurrent_requests_isolation(self):
+    def test_concurrent_requests_isolation(self):
         """Test that concurrent requests don't interfere with each other."""
         payload1 = {"sub": "user-1", "role": "user"}
         payload2 = {"sub": "user-2", "role": "admin"}

         # Simulate concurrent processing
-        user1 = await requires_user(payload1)
-        user2 = await requires_admin_user(payload2)
+        user1 = requires_user(payload1)
+        user2 = requires_admin_user(payload2)

         assert user1.user_id == "user-1"
         assert user2.user_id == "user-2"

@@ -314,7 +314,7 @@ class TestAuthDependenciesEdgeCases:
             ({"sub": "user", "role": "user"}, "Admin access required", True),
         ],
     )
-    async def test_dependency_error_cases(
+    def test_dependency_error_cases(
         self, payload, expected_error: str, admin_only: bool
     ):
         """Test that errors propagate correctly through dependencies."""

@@ -325,7 +325,7 @@ class TestAuthDependenciesEdgeCases:
             verify_user(payload, admin_only=admin_only)
         assert expected_error in exc_info.value.detail

-    async def test_dependency_valid_user(self):
+    def test_dependency_valid_user(self):
         """Test valid user case for dependency."""
         # Import verify_user to test it directly since dependencies use FastAPI Security
         from autogpt_libs.auth.jwt_utils import verify_user

@@ -16,7 +16,7 @@ bearer_jwt_auth = HTTPBearer(
 )


-async def get_jwt_payload(
+def get_jwt_payload(
     credentials: HTTPAuthorizationCredentials | None = Security(bearer_jwt_auth),
 ) -> dict[str, Any]:
     """

@@ -116,32 +116,32 @@ def test_parse_jwt_token_missing_audience():
     assert "Invalid token" in str(exc_info.value)


-async def test_get_jwt_payload_with_valid_token():
+def test_get_jwt_payload_with_valid_token():
     """Test extracting JWT payload with valid bearer token."""
     token = create_token(TEST_USER_PAYLOAD)
    credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)

-    result = await jwt_utils.get_jwt_payload(credentials)
+    result = jwt_utils.get_jwt_payload(credentials)
     assert result["sub"] == "test-user-id"
     assert result["role"] == "user"


-async def test_get_jwt_payload_no_credentials():
+def test_get_jwt_payload_no_credentials():
     """Test JWT payload when no credentials provided."""
     with pytest.raises(HTTPException) as exc_info:
-        await jwt_utils.get_jwt_payload(None)
+        jwt_utils.get_jwt_payload(None)
     assert exc_info.value.status_code == 401
     assert "Authorization header is missing" in exc_info.value.detail


-async def test_get_jwt_payload_invalid_token():
+def test_get_jwt_payload_invalid_token():
     """Test JWT payload extraction with invalid token."""
     credentials = HTTPAuthorizationCredentials(
         scheme="Bearer", credentials="invalid.token.here"
     )

     with pytest.raises(HTTPException) as exc_info:
-        await jwt_utils.get_jwt_payload(credentials)
+        jwt_utils.get_jwt_payload(credentials)
     assert exc_info.value.status_code == 401
     assert "Invalid token" in exc_info.value.detail

@@ -4,7 +4,6 @@ import logging
 import os
 import socket
 import sys
-from logging.handlers import RotatingFileHandler
 from pathlib import Path

 from pydantic import Field, field_validator

@@ -94,36 +93,42 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
     config = LoggingConfig()
     log_handlers: list[logging.Handler] = []

-    structured_logging = config.enable_cloud_logging or force_cloud_logging
-
     # Console output handlers
-    if not structured_logging:
-        stdout = logging.StreamHandler(stream=sys.stdout)
-        stdout.setLevel(config.level)
-        stdout.addFilter(BelowLevelFilter(logging.WARNING))
-        if config.level == logging.DEBUG:
-            stdout.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
-        else:
-            stdout.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
+    stdout = logging.StreamHandler(stream=sys.stdout)
+    stdout.setLevel(config.level)
+    stdout.addFilter(BelowLevelFilter(logging.WARNING))
+    if config.level == logging.DEBUG:
+        stdout.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
+    else:
+        stdout.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))

-        stderr = logging.StreamHandler()
-        stderr.setLevel(logging.WARNING)
-        if config.level == logging.DEBUG:
-            stderr.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
-        else:
-            stderr.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
+    stderr = logging.StreamHandler()
+    stderr.setLevel(logging.WARNING)
+    if config.level == logging.DEBUG:
+        stderr.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
+    else:
+        stderr.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))

-        log_handlers += [stdout, stderr]
+    log_handlers += [stdout, stderr]

     # Cloud logging setup
-    else:
-        # Use Google Cloud Structured Log Handler. Log entries are printed to stdout
-        # in a JSON format which is automatically picked up by Google Cloud Logging.
-        from google.cloud.logging.handlers import StructuredLogHandler
+    if config.enable_cloud_logging or force_cloud_logging:
+        import google.cloud.logging
+        from google.cloud.logging.handlers import CloudLoggingHandler
+        from google.cloud.logging_v2.handlers.transports import (
+            BackgroundThreadTransport,
+        )

-        structured_log_handler = StructuredLogHandler(stream=sys.stdout)
-        structured_log_handler.setLevel(config.level)
-        log_handlers.append(structured_log_handler)
+        client = google.cloud.logging.Client()
+        # Use BackgroundThreadTransport to prevent blocking the main thread
+        # and deadlocks when gRPC calls to Google Cloud Logging hang
+        cloud_handler = CloudLoggingHandler(
+            client,
+            name="autogpt_logs",
+            transport=BackgroundThreadTransport,
+        )
+        cloud_handler.setLevel(config.level)
+        log_handlers.append(cloud_handler)

     # File logging setup
     if config.enable_file_logging:

@@ -134,13 +139,8 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
         print(f"Log directory: {config.log_dir}")

         # Activity log handler (INFO and above)
-        # Security fix: Use RotatingFileHandler with size limits to prevent disk exhaustion
-        activity_log_handler = RotatingFileHandler(
-            config.log_dir / LOG_FILE,
-            mode="a",
-            encoding="utf-8",
-            maxBytes=10 * 1024 * 1024,  # 10MB per file
-            backupCount=3,  # Keep 3 backup files (40MB total)
+        activity_log_handler = logging.FileHandler(
+            config.log_dir / LOG_FILE, "a", "utf-8"
         )
         activity_log_handler.setLevel(config.level)
         activity_log_handler.setFormatter(

@@ -150,13 +150,8 @@ def configure_logging(force_cloud_logging: bool = False) -> None:

         if config.level == logging.DEBUG:
             # Debug log handler (all levels)
-            # Security fix: Use RotatingFileHandler with size limits
-            debug_log_handler = RotatingFileHandler(
-                config.log_dir / DEBUG_LOG_FILE,
-                mode="a",
-                encoding="utf-8",
-                maxBytes=10 * 1024 * 1024,  # 10MB per file
-                backupCount=3,  # Keep 3 backup files (40MB total)
+            debug_log_handler = logging.FileHandler(
+                config.log_dir / DEBUG_LOG_FILE, "a", "utf-8"
             )
             debug_log_handler.setLevel(logging.DEBUG)
             debug_log_handler.setFormatter(

@@ -165,13 +160,8 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
             log_handlers.append(debug_log_handler)

         # Error log handler (ERROR and above)
-        # Security fix: Use RotatingFileHandler with size limits
-        error_log_handler = RotatingFileHandler(
-            config.log_dir / ERROR_LOG_FILE,
-            mode="a",
-            encoding="utf-8",
-            maxBytes=10 * 1024 * 1024,  # 10MB per file
-            backupCount=3,  # Keep 3 backup files (40MB total)
+        error_log_handler = logging.FileHandler(
+            config.log_dir / ERROR_LOG_FILE, "a", "utf-8"
         )
         error_log_handler.setLevel(logging.ERROR)
         error_log_handler.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT, no_color=True))

@@ -179,13 +169,7 @@ def configure_logging(force_cloud_logging: bool = False) -> None:

     # Configure the root logger
     logging.basicConfig(
-        format=(
-            "%(levelname)s %(message)s"
-            if structured_logging
-            else (
-                DEBUG_LOG_FORMAT if config.level == logging.DEBUG else SIMPLE_LOG_FORMAT
-            )
-        ),
+        format=DEBUG_LOG_FORMAT if config.level == logging.DEBUG else SIMPLE_LOG_FORMAT,
        level=config.level,
         handlers=log_handlers,
     )
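
As a sanity check on the size limits removed here, the cap the old comments describe is one active file plus three backups (a sketch of the arithmetic only; the constants are taken from the removed lines):

```python
max_bytes = 10 * 1024 * 1024  # 10 MB per file
backup_count = 3
total_cap = max_bytes * (backup_count + 1)  # 4 files -> ~40 MB, matching "(40MB total)"
```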

@@ -1,5 +1,3 @@
-from typing import Optional
-
 from pydantic import Field
 from pydantic_settings import BaseSettings, SettingsConfigDict

@@ -15,8 +13,8 @@ class RateLimitSettings(BaseSettings):
         default="6379", description="Redis port", validation_alias="REDIS_PORT"
     )

-    redis_password: Optional[str] = Field(
-        default=None,
+    redis_password: str = Field(
+        default="password",
         description="Redis password",
         validation_alias="REDIS_PASSWORD",
     )

@@ -11,7 +11,7 @@ class RateLimiter:
         self,
         redis_host: str = RATE_LIMIT_SETTINGS.redis_host,
         redis_port: str = RATE_LIMIT_SETTINGS.redis_port,
-        redis_password: str | None = RATE_LIMIT_SETTINGS.redis_password,
+        redis_password: str = RATE_LIMIT_SETTINGS.redis_password,
         requests_per_minute: int = RATE_LIMIT_SETTINGS.requests_per_minute,
     ):
         self.redis = Redis(

autogpt_platform/autogpt_libs/autogpt_libs/utils/cache.py (new file, 266 lines)
@@ -0,0 +1,266 @@
import inspect
import logging
import threading
import time
from functools import wraps
from typing import (
    Awaitable,
    Callable,
    ParamSpec,
    Protocol,
    Tuple,
    TypeVar,
    cast,
    overload,
    runtime_checkable,
)

P = ParamSpec("P")
R = TypeVar("R")

logger = logging.getLogger(__name__)


@overload
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
    pass


@overload
def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
    pass


def thread_cached(
    func: Callable[P, R] | Callable[P, Awaitable[R]],
) -> Callable[P, R] | Callable[P, Awaitable[R]]:
    thread_local = threading.local()

    def _clear():
        if hasattr(thread_local, "cache"):
            del thread_local.cache

    if inspect.iscoroutinefunction(func):

        async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            cache = getattr(thread_local, "cache", None)
            if cache is None:
                cache = thread_local.cache = {}
            key = (args, tuple(sorted(kwargs.items())))
            if key not in cache:
                cache[key] = await cast(Callable[P, Awaitable[R]], func)(
                    *args, **kwargs
                )
            return cache[key]

        setattr(async_wrapper, "clear_cache", _clear)
        return async_wrapper

    else:

        def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            cache = getattr(thread_local, "cache", None)
            if cache is None:
                cache = thread_local.cache = {}
            key = (args, tuple(sorted(kwargs.items())))
            if key not in cache:
                cache[key] = func(*args, **kwargs)
            return cache[key]

        setattr(sync_wrapper, "clear_cache", _clear)
        return sync_wrapper


def clear_thread_cache(func: Callable) -> None:
    if clear := getattr(func, "clear_cache", None):
        clear()


FuncT = TypeVar("FuncT")


R_co = TypeVar("R_co", covariant=True)


@runtime_checkable
class AsyncCachedFunction(Protocol[P, R_co]):
    """Protocol for async functions with cache management methods."""

    def cache_clear(self) -> None:
        """Clear all cached entries."""
        return None

    def cache_info(self) -> dict[str, int | None]:
        """Get cache statistics."""
        return {}

    async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co:
        """Call the cached function."""
        return None  # type: ignore


def async_ttl_cache(
    maxsize: int = 128, ttl_seconds: int | None = None
) -> Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]:
    """
    TTL (Time To Live) cache decorator for async functions.

    Similar to functools.lru_cache but works with async functions and includes optional TTL.

    Args:
        maxsize: Maximum number of cached entries
        ttl_seconds: Time to live in seconds. If None, entries never expire (like lru_cache)

    Returns:
        Decorator function

    Example:
        # With TTL
        @async_ttl_cache(maxsize=1000, ttl_seconds=300)
        async def api_call(param: str) -> dict:
            return {"result": param}

        # Without TTL (permanent cache like lru_cache)
        @async_ttl_cache(maxsize=1000)
        async def expensive_computation(param: str) -> dict:
            return {"result": param}
    """

    def decorator(
        async_func: Callable[P, Awaitable[R]],
    ) -> AsyncCachedFunction[P, R]:
        # Cache storage - use union type to handle both cases
        cache_storage: dict[tuple, R | Tuple[R, float]] = {}

        @wraps(async_func)
        async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            # Create cache key from arguments
            key = (args, tuple(sorted(kwargs.items())))
            current_time = time.time()

            # Check if we have a valid cached entry
            if key in cache_storage:
                if ttl_seconds is None:
                    # No TTL - return cached result directly
                    logger.debug(
                        f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
                    )
                    return cast(R, cache_storage[key])
                else:
                    # With TTL - check expiration
                    cached_data = cache_storage[key]
                    if isinstance(cached_data, tuple):
                        result, timestamp = cached_data
                        if current_time - timestamp < ttl_seconds:
                            logger.debug(
                                f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
                            )
                            return cast(R, result)
                        else:
                            # Expired entry
                            del cache_storage[key]
                            logger.debug(
                                f"Cache entry expired for {async_func.__name__}"
                            )

            # Cache miss or expired - fetch fresh data
            logger.debug(
                f"Cache miss for {async_func.__name__} with key: {str(key)[:50]}"
            )
            result = await async_func(*args, **kwargs)

            # Store in cache
            if ttl_seconds is None:
                cache_storage[key] = result
            else:
                cache_storage[key] = (result, current_time)

            # Simple cleanup when cache gets too large
            if len(cache_storage) > maxsize:
                # Remove oldest entries (simple FIFO cleanup)
                cutoff = maxsize // 2
                oldest_keys = list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
                for old_key in oldest_keys:
                    cache_storage.pop(old_key, None)
                logger.debug(
                    f"Cache cleanup: removed {len(oldest_keys)} entries for {async_func.__name__}"
                )

            return result

        # Add cache management methods (similar to functools.lru_cache)
        def cache_clear() -> None:
            cache_storage.clear()

        def cache_info() -> dict[str, int | None]:
            return {
                "size": len(cache_storage),
                "maxsize": maxsize,
                "ttl_seconds": ttl_seconds,
            }

        # Attach methods to wrapper
        setattr(wrapper, "cache_clear", cache_clear)
        setattr(wrapper, "cache_info", cache_info)

        return cast(AsyncCachedFunction[P, R], wrapper)

    return decorator


@overload
def async_cache(
    func: Callable[P, Awaitable[R]],
) -> AsyncCachedFunction[P, R]:
    pass


@overload
def async_cache(
    func: None = None,
    *,
    maxsize: int = 128,
) -> Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]:
    pass


def async_cache(
    func: Callable[P, Awaitable[R]] | None = None,
    *,
    maxsize: int = 128,
) -> (
    AsyncCachedFunction[P, R]
    | Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]
):
    """
    Process-level cache decorator for async functions (no TTL).

    Similar to functools.lru_cache but works with async functions.
    This is a convenience wrapper around async_ttl_cache with ttl_seconds=None.

    Args:
        func: The async function to cache (when used without parentheses)
        maxsize: Maximum number of cached entries

    Returns:
        Decorated function or decorator

    Example:
        # Without parentheses (uses default maxsize=128)
        @async_cache
        async def get_data(param: str) -> dict:
            return {"result": param}

        # With parentheses and custom maxsize
        @async_cache(maxsize=1000)
        async def expensive_computation(param: str) -> dict:
            # Expensive computation here
            return {"result": param}
    """
    if func is None:
        # Called with parentheses @async_cache() or @async_cache(maxsize=...)
        return async_ttl_cache(maxsize=maxsize, ttl_seconds=None)
    else:
        # Called without parentheses @async_cache
        decorator = async_ttl_cache(maxsize=maxsize, ttl_seconds=None)
        return decorator(func)
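
A short usage sketch for `thread_cached` (illustrative names; the import mirrors the one used in the tests below):

```python
from autogpt_libs.utils.cache import clear_thread_cache, thread_cached


@thread_cached
def load_settings(env: str) -> dict:
    # Stand-in for an expensive lookup; the result is cached per thread,
    # keyed by the exact (args, kwargs) combination.
    return {"env": env}


first = load_settings("prod")   # computed
second = load_settings("prod")  # served from this thread's cache
assert first is second

clear_thread_cache(load_settings)  # drops this thread's cached entries
```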

autogpt_platform/autogpt_libs/autogpt_libs/utils/cache_test.py (new file, 705 lines)
@@ -0,0 +1,705 @@
"""Tests for the @thread_cached decorator.

This module tests the thread-local caching functionality including:
- Basic caching for sync and async functions
- Thread isolation (each thread has its own cache)
- Cache clearing functionality
- Exception handling (exceptions are not cached)
- Argument handling (positional vs keyword arguments)
"""

import asyncio
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import Mock

import pytest

from autogpt_libs.utils.cache import (
    async_cache,
    async_ttl_cache,
    clear_thread_cache,
    thread_cached,
)


class TestThreadCached:
    def test_sync_function_caching(self):
        call_count = 0

        @thread_cached
        def expensive_function(x: int, y: int = 0) -> int:
            nonlocal call_count
            call_count += 1
            return x + y

        assert expensive_function(1, 2) == 3
        assert call_count == 1

        assert expensive_function(1, 2) == 3
        assert call_count == 1

        assert expensive_function(1, y=2) == 3
        assert call_count == 2

        assert expensive_function(2, 3) == 5
        assert call_count == 3

        assert expensive_function(1) == 1
        assert call_count == 4

    @pytest.mark.asyncio
    async def test_async_function_caching(self):
        call_count = 0

        @thread_cached
        async def expensive_async_function(x: int, y: int = 0) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)
            return x + y

        assert await expensive_async_function(1, 2) == 3
        assert call_count == 1

        assert await expensive_async_function(1, 2) == 3
        assert call_count == 1

        assert await expensive_async_function(1, y=2) == 3
        assert call_count == 2

        assert await expensive_async_function(2, 3) == 5
        assert call_count == 3

    def test_thread_isolation(self):
        call_count = 0
        results = {}

        @thread_cached
        def thread_specific_function(x: int) -> str:
            nonlocal call_count
            call_count += 1
            return f"{threading.current_thread().name}-{x}"

        def worker(thread_id: int):
            result1 = thread_specific_function(1)
            result2 = thread_specific_function(1)
            result3 = thread_specific_function(2)
            results[thread_id] = (result1, result2, result3)

        with ThreadPoolExecutor(max_workers=3) as executor:
            futures = [executor.submit(worker, i) for i in range(3)]
            for future in futures:
                future.result()

        assert call_count >= 2

        for thread_id, (r1, r2, r3) in results.items():
            assert r1 == r2
            assert r1 != r3

    @pytest.mark.asyncio
    async def test_async_thread_isolation(self):
        call_count = 0
        results = {}

        @thread_cached
        async def async_thread_specific_function(x: int) -> str:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)
            return f"{threading.current_thread().name}-{x}"

        async def async_worker(worker_id: int):
            result1 = await async_thread_specific_function(1)
            result2 = await async_thread_specific_function(1)
            result3 = await async_thread_specific_function(2)
            results[worker_id] = (result1, result2, result3)

        tasks = [async_worker(i) for i in range(3)]
        await asyncio.gather(*tasks)

        for worker_id, (r1, r2, r3) in results.items():
            assert r1 == r2
            assert r1 != r3

    def test_clear_cache_sync(self):
        call_count = 0

        @thread_cached
        def clearable_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x * 2

        assert clearable_function(5) == 10
        assert call_count == 1

        assert clearable_function(5) == 10
        assert call_count == 1

        clear_thread_cache(clearable_function)

        assert clearable_function(5) == 10
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_clear_cache_async(self):
        call_count = 0

        @thread_cached
        async def clearable_async_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)
            return x * 2

        assert await clearable_async_function(5) == 10
        assert call_count == 1

        assert await clearable_async_function(5) == 10
        assert call_count == 1

        clear_thread_cache(clearable_async_function)

        assert await clearable_async_function(5) == 10
        assert call_count == 2

    def test_simple_arguments(self):
        call_count = 0

        @thread_cached
        def simple_function(a: str, b: int, c: str = "default") -> str:
            nonlocal call_count
            call_count += 1
            return f"{a}-{b}-{c}"

        # First call with all positional args
        result1 = simple_function("test", 42, "custom")
        assert call_count == 1

        # Same args, all positional - should hit cache
        result2 = simple_function("test", 42, "custom")
        assert call_count == 1
        assert result1 == result2

        # Same values but last arg as keyword - creates different cache key
        result3 = simple_function("test", 42, c="custom")
        assert call_count == 2
        assert result1 == result3  # Same result, different cache entry

        # Different value - new cache entry
        result4 = simple_function("test", 43, "custom")
        assert call_count == 3
        assert result1 != result4

    def test_positional_vs_keyword_args(self):
        """Test that positional and keyword arguments create different cache entries."""
        call_count = 0

        @thread_cached
        def func(a: int, b: int = 10) -> str:
            nonlocal call_count
            call_count += 1
            return f"result-{a}-{b}"

        # All positional
        result1 = func(1, 2)
        assert call_count == 1
        assert result1 == "result-1-2"

        # Same values, but second arg as keyword
        result2 = func(1, b=2)
        assert call_count == 2  # Different cache key!
        assert result2 == "result-1-2"  # Same result

        # Verify both are cached separately
        func(1, 2)  # Uses first cache entry
        assert call_count == 2

        func(1, b=2)  # Uses second cache entry
        assert call_count == 2

    def test_exception_handling(self):
        call_count = 0

        @thread_cached
        def failing_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            if x < 0:
                raise ValueError("Negative value")
            return x * 2

        assert failing_function(5) == 10
        assert call_count == 1

        with pytest.raises(ValueError):
            failing_function(-1)
        assert call_count == 2

        with pytest.raises(ValueError):
            failing_function(-1)
        assert call_count == 3

        assert failing_function(5) == 10
        assert call_count == 3

    @pytest.mark.asyncio
    async def test_async_exception_handling(self):
        call_count = 0

        @thread_cached
        async def async_failing_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)
            if x < 0:
                raise ValueError("Negative value")
            return x * 2

        assert await async_failing_function(5) == 10
        assert call_count == 1

        with pytest.raises(ValueError):
            await async_failing_function(-1)
        assert call_count == 2

        with pytest.raises(ValueError):
            await async_failing_function(-1)
        assert call_count == 3

    def test_sync_caching_performance(self):
        @thread_cached
        def slow_function(x: int) -> int:
            print(f"slow_function called with x={x}")
            time.sleep(0.1)
            return x * 2

        start = time.time()
        result1 = slow_function(5)
        first_call_time = time.time() - start
        print(f"First call took {first_call_time:.4f} seconds")

        start = time.time()
        result2 = slow_function(5)
        second_call_time = time.time() - start
        print(f"Second call took {second_call_time:.4f} seconds")

        assert result1 == result2 == 10
        assert first_call_time > 0.09
        assert second_call_time < 0.01

    @pytest.mark.asyncio
    async def test_async_caching_performance(self):
        @thread_cached
        async def slow_async_function(x: int) -> int:
            print(f"slow_async_function called with x={x}")
            await asyncio.sleep(0.1)
            return x * 2

        start = time.time()
        result1 = await slow_async_function(5)
        first_call_time = time.time() - start
        print(f"First async call took {first_call_time:.4f} seconds")

        start = time.time()
        result2 = await slow_async_function(5)
        second_call_time = time.time() - start
        print(f"Second async call took {second_call_time:.4f} seconds")

        assert result1 == result2 == 10
        assert first_call_time > 0.09
        assert second_call_time < 0.01

    def test_with_mock_objects(self):
        mock = Mock(return_value=42)

        @thread_cached
        def function_using_mock(x: int) -> int:
            return mock(x)

        assert function_using_mock(1) == 42
        assert mock.call_count == 1

        assert function_using_mock(1) == 42
        assert mock.call_count == 1

        assert function_using_mock(2) == 42
        assert mock.call_count == 2


class TestAsyncTTLCache:
    """Tests for the @async_ttl_cache decorator."""

    @pytest.mark.asyncio
    async def test_basic_caching(self):
        """Test basic caching functionality."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=60)
        async def cached_function(x: int, y: int = 0) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)  # Simulate async work
            return x + y

        # First call
        result1 = await cached_function(1, 2)
        assert result1 == 3
        assert call_count == 1

        # Second call with same args - should use cache
        result2 = await cached_function(1, 2)
        assert result2 == 3
        assert call_count == 1  # No additional call

        # Different args - should call function again
        result3 = await cached_function(2, 3)
        assert result3 == 5
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_ttl_expiration(self):
        """Test that cache entries expire after TTL."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=1)  # Short TTL
        async def short_lived_cache(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x * 2

        # First call
        result1 = await short_lived_cache(5)
        assert result1 == 10
        assert call_count == 1

        # Second call immediately - should use cache
        result2 = await short_lived_cache(5)
        assert result2 == 10
        assert call_count == 1

        # Wait for TTL to expire
        await asyncio.sleep(1.1)

        # Third call after expiration - should call function again
        result3 = await short_lived_cache(5)
        assert result3 == 10
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_cache_info(self):
        """Test cache info functionality."""

        @async_ttl_cache(maxsize=5, ttl_seconds=300)
        async def info_test_function(x: int) -> int:
            return x * 3

        # Check initial cache info
        info = info_test_function.cache_info()
        assert info["size"] == 0
        assert info["maxsize"] == 5
        assert info["ttl_seconds"] == 300

        # Add an entry
        await info_test_function(1)
        info = info_test_function.cache_info()
        assert info["size"] == 1

    @pytest.mark.asyncio
    async def test_cache_clear(self):
        """Test cache clearing functionality."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=60)
        async def clearable_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x * 4

        # First call
        result1 = await clearable_function(2)
        assert result1 == 8
        assert call_count == 1

        # Second call - should use cache
        result2 = await clearable_function(2)
        assert result2 == 8
        assert call_count == 1

        # Clear cache
        clearable_function.cache_clear()

        # Third call after clear - should call function again
        result3 = await clearable_function(2)
        assert result3 == 8
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_maxsize_cleanup(self):
        """Test that cache cleans up when maxsize is exceeded."""
        call_count = 0

        @async_ttl_cache(maxsize=3, ttl_seconds=60)
        async def size_limited_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x**2

        # Fill cache to maxsize
        await size_limited_function(1)  # call_count: 1
        await size_limited_function(2)  # call_count: 2
        await size_limited_function(3)  # call_count: 3

        info = size_limited_function.cache_info()
        assert info["size"] == 3

        # Add one more entry - should trigger cleanup
        await size_limited_function(4)  # call_count: 4

        # Cache size should be reduced (cleanup removes oldest entries)
        info = size_limited_function.cache_info()
        assert info["size"] is not None and info["size"] <= 3  # Should be cleaned up

    @pytest.mark.asyncio
    async def test_argument_variations(self):
        """Test caching with different argument patterns."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=60)
        async def arg_test_function(a: int, b: str = "default", *, c: int = 100) -> str:
            nonlocal call_count
            call_count += 1
            return f"{a}-{b}-{c}"

        # Different ways to call with same logical arguments
        result1 = await arg_test_function(1, "test", c=200)
        assert call_count == 1

        # Same arguments, same order - should use cache
        result2 = await arg_test_function(1, "test", c=200)
        assert call_count == 1
        assert result1 == result2

        # Different arguments - should call function
        result3 = await arg_test_function(2, "test", c=200)
        assert call_count == 2
        assert result1 != result3

    @pytest.mark.asyncio
    async def test_exception_handling(self):
        """Test that exceptions are not cached."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=60)
        async def exception_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            if x < 0:
                raise ValueError("Negative value not allowed")
            return x * 2

        # Successful call - should be cached
        result1 = await exception_function(5)
        assert result1 == 10
        assert call_count == 1

        # Same successful call - should use cache
        result2 = await exception_function(5)
        assert result2 == 10
        assert call_count == 1

        # Exception call - should not be cached
        with pytest.raises(ValueError):
            await exception_function(-1)
        assert call_count == 2

        # Same exception call - should call again (not cached)
        with pytest.raises(ValueError):
            await exception_function(-1)
        assert call_count == 3

    @pytest.mark.asyncio
    async def test_concurrent_calls(self):
        """Test caching behavior with concurrent calls."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=60)
        async def concurrent_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.05)  # Simulate work
            return x * x

        # Launch concurrent calls with same arguments
        tasks = [concurrent_function(3) for _ in range(5)]
        results = await asyncio.gather(*tasks)

        # All results should be the same
        assert all(result == 9 for result in results)

        # Note: Due to race conditions, call_count might be up to 5 for concurrent calls
        # This tests that the cache doesn't break under concurrent access
        assert 1 <= call_count <= 5


class TestAsyncCache:
    """Tests for the @async_cache decorator (no TTL)."""

    @pytest.mark.asyncio
    async def test_basic_caching_no_ttl(self):
        """Test basic caching functionality without TTL."""
        call_count = 0

        @async_cache(maxsize=10)
        async def cached_function(x: int, y: int = 0) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)  # Simulate async work
            return x + y

        # First call
        result1 = await cached_function(1, 2)
        assert result1 == 3
        assert call_count == 1

        # Second call with same args - should use cache
        result2 = await cached_function(1, 2)
        assert result2 == 3
        assert call_count == 1  # No additional call

        # Third call after some time - should still use cache (no TTL)
        await asyncio.sleep(0.05)
        result3 = await cached_function(1, 2)
        assert result3 == 3
        assert call_count == 1  # Still no additional call

        # Different args - should call function again
        result4 = await cached_function(2, 3)
        assert result4 == 5
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_no_ttl_vs_ttl_behavior(self):
        """Test the difference between TTL and no-TTL caching."""
        ttl_call_count = 0
        no_ttl_call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=1)  # Short TTL
        async def ttl_function(x: int) -> int:
            nonlocal ttl_call_count
            ttl_call_count += 1
            return x * 2

        @async_cache(maxsize=10)  # No TTL
        async def no_ttl_function(x: int) -> int:
            nonlocal no_ttl_call_count
            no_ttl_call_count += 1
            return x * 2

        # First calls
        await ttl_function(5)
        await no_ttl_function(5)
        assert ttl_call_count == 1
        assert no_ttl_call_count == 1

        # Wait for TTL to expire
        await asyncio.sleep(1.1)

        # Second calls after TTL expiry
        await ttl_function(5)  # Should call function again (TTL expired)
        await no_ttl_function(5)  # Should use cache (no TTL)
        assert ttl_call_count == 2  # TTL function called again
        assert no_ttl_call_count == 1  # No-TTL function still cached

    @pytest.mark.asyncio
    async def test_async_cache_info(self):
        """Test cache info for no-TTL cache."""

        @async_cache(maxsize=5)
        async def info_test_function(x: int) -> int:
            return x * 3

        # Check initial cache info
        info = info_test_function.cache_info()
        assert info["size"] == 0
        assert info["maxsize"] == 5
        assert info["ttl_seconds"] is None  # No TTL

        # Add an entry
        await info_test_function(1)
        info = info_test_function.cache_info()
        assert info["size"] == 1


class TestTTLOptional:
    """Tests for optional TTL functionality."""

    @pytest.mark.asyncio
    async def test_ttl_none_behavior(self):
        """Test that ttl_seconds=None works like no TTL."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=None)
        async def no_ttl_via_none(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x**2

        # First call
        result1 = await no_ttl_via_none(3)
        assert result1 == 9
        assert call_count == 1

        # Wait (would expire if there was TTL)
        await asyncio.sleep(0.1)

        # Second call - should still use cache
        result2 = await no_ttl_via_none(3)
        assert result2 == 9
        assert call_count == 1  # No additional call

        # Check cache info
        info = no_ttl_via_none.cache_info()
        assert info["ttl_seconds"] is None

    @pytest.mark.asyncio
    async def test_cache_options_comparison(self):
        """Test different cache options work as expected."""
        ttl_calls = 0
        no_ttl_calls = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=1)  # With TTL
        async def ttl_function(x: int) -> int:
            nonlocal ttl_calls
            ttl_calls += 1
            return x * 10

        @async_cache(maxsize=10)  # Process-level cache (no TTL)
        async def process_function(x: int) -> int:
            nonlocal no_ttl_calls
            no_ttl_calls += 1
            return x * 10

        # Both should cache initially
        await ttl_function(3)
        await process_function(3)
        assert ttl_calls == 1
        assert no_ttl_calls == 1

        # Immediate second calls - both should use cache
        await ttl_function(3)
        await process_function(3)
        assert ttl_calls == 1
        assert no_ttl_calls == 1

        # Wait for TTL to expire
        await asyncio.sleep(1.1)

        # After TTL expiry
        await ttl_function(3)  # Should call function again
        await process_function(3)  # Should still use cache
        assert ttl_calls == 2  # TTL cache expired, called again
        assert no_ttl_calls == 1  # Process cache never expires
76
autogpt_platform/autogpt_libs/poetry.lock
generated
76
autogpt_platform/autogpt_libs/poetry.lock
generated
@@ -1002,18 +1002,6 @@ dynamodb = ["boto3 (>=1.9.71)"]
redis = ["redis (>=2.10.5)"]
test-filesource = ["pyyaml (>=5.3.1)", "watchdog (>=3.0.0)"]

[[package]]
name = "nodeenv"
version = "1.9.1"
description = "Node.js virtual environment builder"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
groups = ["dev"]
files = [
    {file = "nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9"},
    {file = "nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f"},
]

[[package]]
name = "opentelemetry-api"
version = "1.35.0"
@@ -1359,27 +1347,6 @@ files = [
    {file = "pyrfc3339-2.0.1.tar.gz", hash = "sha256:e47843379ea35c1296c3b6c67a948a1a490ae0584edfcbdea0eaffb5dd29960b"},
]

[[package]]
name = "pyright"
version = "1.1.404"
description = "Command line wrapper for pyright"
optional = false
python-versions = ">=3.7"
groups = ["dev"]
files = [
    {file = "pyright-1.1.404-py3-none-any.whl", hash = "sha256:c7b7ff1fdb7219c643079e4c3e7d4125f0dafcc19d253b47e898d130ea426419"},
    {file = "pyright-1.1.404.tar.gz", hash = "sha256:455e881a558ca6be9ecca0b30ce08aa78343ecc031d37a198ffa9a7a1abeb63e"},
]

[package.dependencies]
nodeenv = ">=1.6.0"
typing-extensions = ">=4.1"

[package.extras]
all = ["nodejs-wheel-binaries", "twine (>=3.4.1)"]
dev = ["twine (>=3.4.1)"]
nodejs = ["nodejs-wheel-binaries"]

[[package]]
name = "pytest"
version = "8.4.1"
@@ -1567,31 +1534,31 @@ pyasn1 = ">=0.1.3"

[[package]]
name = "ruff"
version = "0.12.11"
version = "0.12.9"
description = "An extremely fast Python linter and code formatter, written in Rust."
optional = false
python-versions = ">=3.7"
groups = ["dev"]
files = [
    {file = "ruff-0.12.11-py3-none-linux_armv6l.whl", hash = "sha256:93fce71e1cac3a8bf9200e63a38ac5c078f3b6baebffb74ba5274fb2ab276065"},
    {file = "ruff-0.12.11-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b8e33ac7b28c772440afa80cebb972ffd823621ded90404f29e5ab6d1e2d4b93"},
    {file = "ruff-0.12.11-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d69fb9d4937aa19adb2e9f058bc4fbfe986c2040acb1a4a9747734834eaa0bfd"},
    {file = "ruff-0.12.11-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:411954eca8464595077a93e580e2918d0a01a19317af0a72132283e28ae21bee"},
    {file = "ruff-0.12.11-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6a2c0a2e1a450f387bf2c6237c727dd22191ae8c00e448e0672d624b2bbd7fb0"},
    {file = "ruff-0.12.11-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8ca4c3a7f937725fd2413c0e884b5248a19369ab9bdd850b5781348ba283f644"},
    {file = "ruff-0.12.11-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:4d1df0098124006f6a66ecf3581a7f7e754c4df7644b2e6704cd7ca80ff95211"},
    {file = "ruff-0.12.11-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5a8dd5f230efc99a24ace3b77e3555d3fbc0343aeed3fc84c8d89e75ab2ff793"},
    {file = "ruff-0.12.11-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4dc75533039d0ed04cd33fb8ca9ac9620b99672fe7ff1533b6402206901c34ee"},
    {file = "ruff-0.12.11-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4fc58f9266d62c6eccc75261a665f26b4ef64840887fc6cbc552ce5b29f96cc8"},
    {file = "ruff-0.12.11-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:5a0113bd6eafd545146440225fe60b4e9489f59eb5f5f107acd715ba5f0b3d2f"},
    {file = "ruff-0.12.11-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:0d737b4059d66295c3ea5720e6efc152623bb83fde5444209b69cd33a53e2000"},
    {file = "ruff-0.12.11-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:916fc5defee32dbc1fc1650b576a8fed68f5e8256e2180d4d9855aea43d6aab2"},
    {file = "ruff-0.12.11-py3-none-musllinux_1_2_i686.whl", hash = "sha256:c984f07d7adb42d3ded5be894fb4007f30f82c87559438b4879fe7aa08c62b39"},
    {file = "ruff-0.12.11-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e07fbb89f2e9249f219d88331c833860489b49cdf4b032b8e4432e9b13e8a4b9"},
    {file = "ruff-0.12.11-py3-none-win32.whl", hash = "sha256:c792e8f597c9c756e9bcd4d87cf407a00b60af77078c96f7b6366ea2ce9ba9d3"},
    {file = "ruff-0.12.11-py3-none-win_amd64.whl", hash = "sha256:a3283325960307915b6deb3576b96919ee89432ebd9c48771ca12ee8afe4a0fd"},
    {file = "ruff-0.12.11-py3-none-win_arm64.whl", hash = "sha256:bae4d6e6a2676f8fb0f98b74594a048bae1b944aab17e9f5d504062303c6dbea"},
    {file = "ruff-0.12.11.tar.gz", hash = "sha256:c6b09ae8426a65bbee5425b9d0b82796dbb07cb1af045743c79bfb163001165d"},
    {file = "ruff-0.12.9-py3-none-linux_armv6l.whl", hash = "sha256:fcebc6c79fcae3f220d05585229463621f5dbf24d79fdc4936d9302e177cfa3e"},
    {file = "ruff-0.12.9-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:aed9d15f8c5755c0e74467731a007fcad41f19bcce41cd75f768bbd687f8535f"},
    {file = "ruff-0.12.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:5b15ea354c6ff0d7423814ba6d44be2807644d0c05e9ed60caca87e963e93f70"},
    {file = "ruff-0.12.9-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d596c2d0393c2502eaabfef723bd74ca35348a8dac4267d18a94910087807c53"},
    {file = "ruff-0.12.9-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1b15599931a1a7a03c388b9c5df1bfa62be7ede6eb7ef753b272381f39c3d0ff"},
    {file = "ruff-0.12.9-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3d02faa2977fb6f3f32ddb7828e212b7dd499c59eb896ae6c03ea5c303575756"},
    {file = "ruff-0.12.9-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:17d5b6b0b3a25259b69ebcba87908496e6830e03acfb929ef9fd4c58675fa2ea"},
    {file = "ruff-0.12.9-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:72db7521860e246adbb43f6ef464dd2a532ef2ef1f5dd0d470455b8d9f1773e0"},
    {file = "ruff-0.12.9-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a03242c1522b4e0885af63320ad754d53983c9599157ee33e77d748363c561ce"},
    {file = "ruff-0.12.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fc83e4e9751e6c13b5046d7162f205d0a7bac5840183c5beebf824b08a27340"},
    {file = "ruff-0.12.9-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:881465ed56ba4dd26a691954650de6ad389a2d1fdb130fe51ff18a25639fe4bb"},
    {file = "ruff-0.12.9-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:43f07a3ccfc62cdb4d3a3348bf0588358a66da756aa113e071b8ca8c3b9826af"},
    {file = "ruff-0.12.9-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:07adb221c54b6bba24387911e5734357f042e5669fa5718920ee728aba3cbadc"},
    {file = "ruff-0.12.9-py3-none-musllinux_1_2_i686.whl", hash = "sha256:f5cd34fabfdea3933ab85d72359f118035882a01bff15bd1d2b15261d85d5f66"},
    {file = "ruff-0.12.9-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:f6be1d2ca0686c54564da8e7ee9e25f93bdd6868263805f8c0b8fc6a449db6d7"},
    {file = "ruff-0.12.9-py3-none-win32.whl", hash = "sha256:cc7a37bd2509974379d0115cc5608a1a4a6c4bff1b452ea69db83c8855d53f93"},
    {file = "ruff-0.12.9-py3-none-win_amd64.whl", hash = "sha256:6fb15b1977309741d7d098c8a3cb7a30bc112760a00fb6efb7abc85f00ba5908"},
    {file = "ruff-0.12.9-py3-none-win_arm64.whl", hash = "sha256:63c8c819739d86b96d500cce885956a1a48ab056bbcbc61b747ad494b2485089"},
    {file = "ruff-0.12.9.tar.gz", hash = "sha256:fbd94b2e3c623f659962934e52c2bea6fc6da11f667a427a368adaf3af2c866a"},
]

[[package]]
@@ -1773,6 +1740,7 @@ files = [
    {file = "typing_extensions-4.14.1-py3-none-any.whl", hash = "sha256:d1e1e3b58374dc93031d6eda2420a48ea44a36c2b4766a4fdeb3710755731d76"},
    {file = "typing_extensions-4.14.1.tar.gz", hash = "sha256:38b39f4aeeab64884ce9f74c94263ef78f3c22467c8724005483154c26648d36"},
]
markers = {dev = "python_version < \"3.11\""}

[[package]]
name = "typing-inspection"
@@ -1929,4 +1897,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<4.0"
content-hash = "0c40b63c3c921846cf05ccfb4e685d4959854b29c2c302245f9832e20aac6954"
content-hash = "ef7818fba061cea2841c6d7ca4852acde83e4f73b32fca1315e58660002bb0d0"

@@ -9,7 +9,6 @@ packages = [{ include = "autogpt_libs" }]
[tool.poetry.dependencies]
python = ">=3.10,<4.0"
colorama = "^0.4.6"
cryptography = "^45.0"
expiringdict = "^1.2.2"
fastapi = "^0.116.1"
google-cloud-logging = "^3.12.1"
@@ -22,12 +21,11 @@ supabase = "^2.16.0"
uvicorn = "^0.35.0"

[tool.poetry.group.dev.dependencies]
pyright = "^1.1.404"
ruff = "^0.12.9"
pytest = "^8.4.1"
pytest-asyncio = "^1.1.0"
pytest-mock = "^3.14.1"
pytest-cov = "^6.2.1"
ruff = "^0.12.11"

[build-system]
requires = ["poetry-core"]

@@ -21,7 +21,7 @@ PRISMA_SCHEMA="postgres/schema.prisma"
# Redis Configuration
REDIS_HOST=localhost
REDIS_PORT=6379
# REDIS_PASSWORD=
REDIS_PASSWORD=password

# RabbitMQ Credentials
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
@@ -66,11 +66,6 @@ NVIDIA_API_KEY=
GITHUB_CLIENT_ID=
GITHUB_CLIENT_SECRET=

# Notion OAuth App server credentials - https://developers.notion.com/docs/authorization
# Configure a public integration
NOTION_CLIENT_ID=
NOTION_CLIENT_SECRET=

# Google OAuth App server credentials - https://console.cloud.google.com/apis/credentials, and enable gmail api and set scopes
# https://console.cloud.google.com/apis/credentials/consent ?project=<your_project_id>
# You'll need to add/enable the following scopes (minimum):

10
autogpt_platform/backend/.gitignore
vendored
@@ -9,12 +9,4 @@ secrets/*
!secrets/.gitkeep

*.ignore.*
*.ign.*

# Load test results and reports
load-tests/*_RESULTS.md
load-tests/*_REPORT.md
load-tests/results/
load-tests/*.json
load-tests/*.log
load-tests/node_modules/*
*.ign.*
@@ -9,15 +9,8 @@ WORKDIR /app

RUN echo 'Acquire::http::Pipeline-Depth 0;\nAcquire::http::No-Cache true;\nAcquire::BrokenProxy true;\n' > /etc/apt/apt.conf.d/99fixbadproxy

# Install Node.js repository key and setup
# Update package list and install Python and build dependencies
RUN apt-get update --allow-releaseinfo-change --fix-missing \
    && apt-get install -y curl ca-certificates gnupg \
    && mkdir -p /etc/apt/keyrings \
    && curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg \
    && echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_20.x nodistro main" | tee /etc/apt/sources.list.d/nodesource.list

# Update package list and install Python, Node.js, and build dependencies
RUN apt-get update \
    && apt-get install -y \
    python3.13 \
    python3.13-dev \
@@ -27,9 +20,7 @@ RUN apt-get update \
    libpq5 \
    libz-dev \
    libssl-dev \
    postgresql-client \
    nodejs \
    && rm -rf /var/lib/apt/lists/*
    postgresql-client

ENV POETRY_HOME=/opt/poetry
ENV POETRY_NO_INTERACTION=1
@@ -47,7 +38,6 @@ RUN poetry install --no-ansi --no-root

# Generate Prisma client
COPY autogpt_platform/backend/schema.prisma ./
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
RUN poetry run prisma generate

FROM debian:13-slim AS server_dependencies
@@ -64,18 +54,13 @@ ENV PATH=/opt/poetry/bin:$PATH
# Install Python without upgrading system-managed packages
RUN apt-get update && apt-get install -y \
    python3.13 \
    python3-pip \
    && rm -rf /var/lib/apt/lists/*
    python3-pip

# Copy only necessary files from builder
COPY --from=builder /app /app
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
# Copy Node.js installation for Prisma
COPY --from=builder /usr/bin/node /usr/bin/node
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
COPY --from=builder /usr/bin/npm /usr/bin/npm
COPY --from=builder /usr/bin/npx /usr/bin/npx
# Copy Prisma binaries
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries

ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
@@ -93,7 +78,6 @@ FROM server_dependencies AS migrate

# Migration stage only needs schema and migrations - much lighter than full backend
COPY autogpt_platform/backend/schema.prisma /app/autogpt_platform/backend/
COPY autogpt_platform/backend/backend/data/partial_types.py /app/autogpt_platform/backend/backend/data/partial_types.py
COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migrations

FROM server_dependencies AS server

@@ -1,3 +1,4 @@
import functools
import importlib
import logging
import os
@@ -5,8 +6,6 @@ import re
from pathlib import Path
from typing import TYPE_CHECKING, TypeVar

from backend.util.cache import cached

logger = logging.getLogger(__name__)


@@ -16,7 +15,7 @@ if TYPE_CHECKING:
T = TypeVar("T")


@cached(ttl_seconds=3600)
@functools.cache
def load_all_blocks() -> dict[str, type["Block"]]:
    from backend.data.block import Block
    from backend.util.settings import Config

@@ -1,17 +1,18 @@
import logging
from typing import Any, Optional

from pydantic import JsonValue

from backend.data.block import (
    Block,
    BlockCategory,
    BlockInput,
    BlockOutput,
    BlockSchema,
    BlockSchemaInput,
    BlockType,
    get_block,
)
from backend.data.execution import ExecutionStatus, NodesInputMasks
from backend.data.execution import ExecutionStatus
from backend.data.model import NodeExecutionStats, SchemaField
from backend.util.json import validate_with_jsonschema
from backend.util.retry import func_retry
@@ -20,7 +21,7 @@ _logger = logging.getLogger(__name__)


class AgentExecutorBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        user_id: str = SchemaField(description="User ID")
        graph_id: str = SchemaField(description="Graph ID")
        graph_version: int = SchemaField(description="Graph Version")
@@ -32,7 +33,7 @@ class AgentExecutorBlock(Block):
        input_schema: dict = SchemaField(description="Input schema for the graph")
        output_schema: dict = SchemaField(description="Output schema for the graph")

        nodes_input_masks: Optional[NodesInputMasks] = SchemaField(
        nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = SchemaField(
            default=None, hidden=True
        )

@@ -54,7 +55,6 @@ class AgentExecutorBlock(Block):
            return validate_with_jsonschema(cls.get_input_schema(data), data)

    class Output(BlockSchema):
        # Use BlockSchema to avoid automatic error field that could clash with graph outputs
        pass

    def __init__(self):
@@ -67,13 +67,7 @@ class AgentExecutorBlock(Block):
            categories={BlockCategory.AGENT},
        )

    async def run(
        self,
        input_data: Input,
        *,
        graph_exec_id: str,
        **kwargs,
    ) -> BlockOutput:
    async def run(self, input_data: Input, **kwargs) -> BlockOutput:

        from backend.executor import utils as execution_utils

@@ -83,7 +77,6 @@ class AgentExecutorBlock(Block):
            user_id=input_data.user_id,
            inputs=input_data.inputs,
            nodes_input_masks=input_data.nodes_input_masks,
            parent_graph_exec_id=graph_exec_id,
        )

        logger = execution_utils.LogMetadata(

@@ -1,219 +0,0 @@
from typing import Any

from backend.blocks.llm import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    AIBlockBase,
    AICredentials,
    AICredentialsField,
    LlmModel,
    LLMResponse,
    llm_call,
)
from backend.data.block import (
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField


class AIConditionBlock(AIBlockBase):
    """
    An AI-powered condition block that uses natural language to evaluate conditions.

    This block allows users to define conditions in plain English (e.g., "the input is an email address",
    "the input is a city in the USA") and uses AI to determine if the input satisfies the condition.
    It provides the same yes/no data pass-through functionality as the standard ConditionBlock.
    """

    class Input(BlockSchemaInput):
        input_value: Any = SchemaField(
            description="The input value to evaluate with the AI condition",
            placeholder="Enter the value to be evaluated (text, number, or any data)",
        )
        condition: str = SchemaField(
            description="A plaintext English description of the condition to evaluate",
            placeholder="E.g., 'the input is the body of an email', 'the input is a City in the USA', 'the input is an error or a refusal'",
        )
        yes_value: Any = SchemaField(
            description="(Optional) Value to output if the condition is true. If not provided, input_value will be used.",
            placeholder="Leave empty to use input_value, or enter a specific value",
            default=None,
        )
        no_value: Any = SchemaField(
            description="(Optional) Value to output if the condition is false. If not provided, input_value will be used.",
            placeholder="Leave empty to use input_value, or enter a specific value",
            default=None,
        )
        model: LlmModel = SchemaField(
            title="LLM Model",
            default=LlmModel.GPT4O,
            description="The language model to use for evaluating the condition.",
            advanced=False,
        )
        credentials: AICredentials = AICredentialsField()

    class Output(BlockSchemaOutput):
        result: bool = SchemaField(
            description="The result of the AI condition evaluation (True or False)"
        )
        yes_output: Any = SchemaField(
            description="The output value if the condition is true"
        )
        no_output: Any = SchemaField(
            description="The output value if the condition is false"
        )
        error: str = SchemaField(
            description="Error message if the AI evaluation is uncertain or fails"
        )

    def __init__(self):
        super().__init__(
            id="553ec5b8-6c45-4299-8d75-b394d05f72ff",
            input_schema=AIConditionBlock.Input,
            output_schema=AIConditionBlock.Output,
            description="Uses AI to evaluate natural language conditions and provide conditional outputs",
            categories={BlockCategory.AI, BlockCategory.LOGIC},
            test_input={
                "input_value": "john@example.com",
                "condition": "the input is an email address",
                "yes_value": "Valid email",
                "no_value": "Not an email",
                "model": LlmModel.GPT4O,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("result", True),
                ("yes_output", "Valid email"),
            ],
            test_mock={
                "llm_call": lambda *args, **kwargs: LLMResponse(
                    raw_response="",
                    prompt=[],
                    response="true",
                    tool_calls=None,
                    prompt_tokens=50,
                    completion_tokens=10,
                    reasoning=None,
                )
            },
        )

    async def llm_call(
        self,
        credentials: APIKeyCredentials,
        llm_model: LlmModel,
        prompt: list,
        max_tokens: int,
    ) -> LLMResponse:
        """Wrapper method for llm_call to enable mocking in tests."""
        return await llm_call(
            credentials=credentials,
            llm_model=llm_model,
            prompt=prompt,
            force_json_output=False,
            max_tokens=max_tokens,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        Evaluate the AI condition and return appropriate outputs.
        """
        # Prepare the yes and no values, using input_value as default
        yes_value = (
            input_data.yes_value
            if input_data.yes_value is not None
            else input_data.input_value
        )
        no_value = (
            input_data.no_value
            if input_data.no_value is not None
            else input_data.input_value
        )

        # Convert input_value to string for AI evaluation
        input_str = str(input_data.input_value)

        # Create the prompt for AI evaluation
        prompt = [
            {
                "role": "system",
                "content": (
                    "You are an AI assistant that evaluates conditions based on input data. "
                    "You must respond with only 'true' or 'false' (lowercase) to indicate whether "
                    "the given condition is met by the input value. Be accurate and consider the "
                    "context and meaning of both the input and the condition."
                ),
            },
            {
                "role": "user",
                "content": (
                    f"Input value: {input_str}\n"
                    f"Condition to evaluate: {input_data.condition}\n\n"
                    f"Does the input value satisfy the condition? Respond with only 'true' or 'false'."
                ),
            },
        ]

        # Call the LLM
        try:
            response = await self.llm_call(
                credentials=credentials,
                llm_model=input_data.model,
                prompt=prompt,
                max_tokens=10,  # We only expect a true/false response
            )

            # Extract the boolean result from the response
            response_text = response.response.strip().lower()
            if response_text == "true":
                result = True
            elif response_text == "false":
                result = False
            else:
                # If the response is not clear, try to interpret it using word boundaries
                import re

                # Use word boundaries to avoid false positives like 'untrue' or '10'
                tokens = set(re.findall(r"\b(true|false|yes|no|1|0)\b", response_text))

                if tokens == {"true"} or tokens == {"yes"} or tokens == {"1"}:
                    result = True
                elif tokens == {"false"} or tokens == {"no"} or tokens == {"0"}:
                    result = False
                else:
                    # Unclear or conflicting response - default to False and yield error
                    result = False
                    yield "error", f"Unclear AI response: '{response.response}'"

            # Update internal stats
            self.merge_stats(
                NodeExecutionStats(
                    input_token_count=response.prompt_tokens,
                    output_token_count=response.completion_tokens,
                )
            )
            self.prompt = response.prompt

        except Exception as e:
            # In case of any error, default to False to be safe
            result = False
            # Log the error but don't fail the block execution
            import logging

            logger = logging.getLogger(__name__)
            logger.error(f"AI condition evaluation failed: {str(e)}")
            yield "error", f"AI evaluation failed: {str(e)}"

        # Yield results
        yield "result", result

        if result:
            yield "yes_output", yes_value
        else:
            yield "no_output", no_value
@@ -1,159 +0,0 @@
from enum import Enum
from typing import Literal

from pydantic import SecretStr
from replicate.client import Client as ReplicateClient
from replicate.helpers import FileOutput

from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.model import (
    APIKeyCredentials,
    CredentialsField,
    CredentialsMetaInput,
    SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.file import MediaFileType


class GeminiImageModel(str, Enum):
    NANO_BANANA = "google/nano-banana"


class OutputFormat(str, Enum):
    JPG = "jpg"
    PNG = "png"


TEST_CREDENTIALS = APIKeyCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="replicate",
    api_key=SecretStr("mock-replicate-api-key"),
    title="Mock Replicate API key",
    expires_at=None,
)

TEST_CREDENTIALS_INPUT = {
    "provider": TEST_CREDENTIALS.provider,
    "id": TEST_CREDENTIALS.id,
    "type": TEST_CREDENTIALS.type,
    "title": TEST_CREDENTIALS.title,
}


class AIImageCustomizerBlock(Block):
    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput[
            Literal[ProviderName.REPLICATE], Literal["api_key"]
        ] = CredentialsField(
            description="Replicate API key with permissions for Google Gemini image models",
        )
        prompt: str = SchemaField(
            description="A text description of the image you want to generate",
            title="Prompt",
        )
        model: GeminiImageModel = SchemaField(
            description="The AI model to use for image generation and editing",
            default=GeminiImageModel.NANO_BANANA,
            title="Model",
        )
        images: list[MediaFileType] = SchemaField(
            description="Optional list of input images to reference or modify",
            default=[],
            title="Input Images",
        )
        output_format: OutputFormat = SchemaField(
            description="Format of the output image",
            default=OutputFormat.PNG,
            title="Output Format",
        )

    class Output(BlockSchemaOutput):
        image_url: MediaFileType = SchemaField(description="URL of the generated image")

    def __init__(self):
        super().__init__(
            id="d76bbe4c-930e-4894-8469-b66775511f71",
            description=(
                "Generate and edit custom images using Google's Nano-Banana model from Gemini 2.5. "
                "Provide a prompt and optional reference images to create or modify images."
            ),
            categories={BlockCategory.AI, BlockCategory.MULTIMEDIA},
            input_schema=AIImageCustomizerBlock.Input,
            output_schema=AIImageCustomizerBlock.Output,
            test_input={
                "prompt": "Make the scene more vibrant and colorful",
                "model": GeminiImageModel.NANO_BANANA,
                "images": [],
                "output_format": OutputFormat.JPG,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_output=[
                ("image_url", "https://replicate.delivery/generated-image.jpg"),
            ],
            test_mock={
                "run_model": lambda *args, **kwargs: MediaFileType(
                    "https://replicate.delivery/generated-image.jpg"
                ),
            },
            test_credentials=TEST_CREDENTIALS,
        )

    async def run(
        self,
        input_data: Input,
        *,
        credentials: APIKeyCredentials,
        graph_exec_id: str,
        user_id: str,
        **kwargs,
    ) -> BlockOutput:
        try:
            result = await self.run_model(
                api_key=credentials.api_key,
                model_name=input_data.model.value,
                prompt=input_data.prompt,
                images=input_data.images,
                output_format=input_data.output_format.value,
            )
            yield "image_url", result
        except Exception as e:
            yield "error", str(e)

    async def run_model(
        self,
        api_key: SecretStr,
        model_name: str,
        prompt: str,
        images: list[MediaFileType],
        output_format: str,
    ) -> MediaFileType:
        client = ReplicateClient(api_token=api_key.get_secret_value())

        input_params: dict = {
            "prompt": prompt,
            "output_format": output_format,
        }

        # Add images to input if provided (API expects "image_input" parameter)
        if images:
            input_params["image_input"] = [str(img) for img in images]

        output: FileOutput | str = await client.async_run(  # type: ignore
            model_name,
            input=input_params,
            wait=False,
        )

        if isinstance(output, FileOutput):
            return MediaFileType(output.url)
        if isinstance(output, str):
            return MediaFileType(output)

        raise ValueError("No output received from the model")
@@ -5,7 +5,7 @@ from pydantic import SecretStr
from replicate.client import Client as ReplicateClient
from replicate.helpers import FileOutput

from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput
from backend.data.block import Block, BlockCategory, BlockSchema
from backend.data.model import (
    APIKeyCredentials,
    CredentialsField,
@@ -101,7 +101,7 @@ class ImageGenModel(str, Enum):


class AIImageGeneratorBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput[
            Literal[ProviderName.REPLICATE], Literal["api_key"]
        ] = CredentialsField(
@@ -135,8 +135,9 @@ class AIImageGeneratorBlock(Block):
            title="Image Style",
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        image_url: str = SchemaField(description="URL of the generated image")
        error: str = SchemaField(description="Error message if generation failed")

    def __init__(self):
        super().__init__(

@@ -6,13 +6,7 @@ from typing import Literal
from pydantic import SecretStr
from replicate.client import Client as ReplicateClient

from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import (
    APIKeyCredentials,
    CredentialsField,
@@ -60,7 +54,7 @@ class NormalizationStrategy(str, Enum):


class AIMusicGeneratorBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput[
            Literal[ProviderName.REPLICATE], Literal["api_key"]
        ] = CredentialsField(
@@ -113,8 +107,9 @@ class AIMusicGeneratorBlock(Block):
            title="Normalization Strategy",
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        result: str = SchemaField(description="URL of the generated audio file")
        error: str = SchemaField(description="Error message if the model run failed")

    def __init__(self):
        super().__init__(

@@ -6,13 +6,7 @@ from typing import Literal

from pydantic import SecretStr

from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import (
    APIKeyCredentials,
    CredentialsField,
@@ -154,7 +148,7 @@ logger = logging.getLogger(__name__)
class AIShortformVideoCreatorBlock(Block):
    """Creates a short‑form text‑to‑video clip using stock or AI imagery."""

    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput[
            Literal[ProviderName.REVID], Literal["api_key"]
        ] = CredentialsField(
@@ -193,8 +187,9 @@ class AIShortformVideoCreatorBlock(Block):
            placeholder=VisualMediaType.STOCK_VIDEOS,
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        video_url: str = SchemaField(description="The URL of the created video")
        error: str = SchemaField(description="Error message if the request failed")

    async def create_webhook(self) -> tuple[str, str]:
        """Create a new webhook URL for receiving notifications."""
@@ -341,7 +336,7 @@ class AIShortformVideoCreatorBlock(Block):
class AIAdMakerVideoCreatorBlock(Block):
    """Generates a 30‑second vertical AI advert using optional user‑supplied imagery."""

    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput[
            Literal[ProviderName.REVID], Literal["api_key"]
        ] = CredentialsField(
@@ -369,8 +364,9 @@ class AIAdMakerVideoCreatorBlock(Block):
            description="Restrict visuals to supplied images only.", default=True
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        video_url: str = SchemaField(description="URL of the finished advert")
        error: str = SchemaField(description="Error message on failure")

    async def create_webhook(self) -> tuple[str, str]:
        """Create a new webhook URL for receiving notifications."""
@@ -528,7 +524,7 @@ class AIAdMakerVideoCreatorBlock(Block):
class AIScreenshotToVideoAdBlock(Block):
    """Creates an advert where the supplied screenshot is narrated by an AI avatar."""

    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput[
            Literal[ProviderName.REVID], Literal["api_key"]
        ] = CredentialsField(description="Revid.ai API key")
@@ -546,8 +542,9 @@ class AIScreenshotToVideoAdBlock(Block):
            default=AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        video_url: str = SchemaField(description="Rendered video URL")
        error: str = SchemaField(description="Error, if encountered")

    async def create_webhook(self) -> tuple[str, str]:
        """Create a new webhook URL for receiving notifications."""

@@ -661,167 +661,6 @@ async def update_field(
#################################################################


async def get_table_schema(
    credentials: Credentials,
    base_id: str,
    table_id_or_name: str,
) -> dict:
    """
    Get the schema for a specific table, including all field definitions.

    Args:
        credentials: Airtable API credentials
        base_id: The base ID
        table_id_or_name: The table ID or name

    Returns:
        Dict containing table schema with fields information
    """
    # First get all tables to find the right one
    response = await Requests().get(
        f"https://api.airtable.com/v0/meta/bases/{base_id}/tables",
        headers={"Authorization": credentials.auth_header()},
    )

    data = response.json()
    tables = data.get("tables", [])

    # Find the matching table
    for table in tables:
        if table.get("id") == table_id_or_name or table.get("name") == table_id_or_name:
            return table

    raise ValueError(f"Table '{table_id_or_name}' not found in base '{base_id}'")


def get_empty_value_for_field(field_type: str) -> Any:
    """
    Return the appropriate empty value for a given Airtable field type.

    Args:
        field_type: The Airtable field type

    Returns:
        The appropriate empty value for that field type
    """
    # Fields that should be false when empty
    if field_type == "checkbox":
        return False

    # Fields that should be empty arrays
    if field_type in [
        "multipleSelects",
        "multipleRecordLinks",
        "multipleAttachments",
        "multipleLookupValues",
        "multipleCollaborators",
    ]:
        return []

    # Fields that should be 0 when empty (numeric types)
    if field_type in [
        "number",
        "percent",
        "currency",
        "rating",
        "duration",
        "count",
        "autoNumber",
    ]:
        return 0

    # Fields that should be empty strings
    if field_type in [
        "singleLineText",
        "multilineText",
        "email",
        "url",
        "phoneNumber",
        "richText",
        "barcode",
    ]:
        return ""

    # Everything else gets null (dates, single selects, formulas, etc.)
    return None


async def normalize_records(
    records: list[dict],
    table_schema: dict,
    include_field_metadata: bool = False,
) -> dict:
    """
    Normalize Airtable records to include all fields with proper empty values.

    Args:
        records: List of record objects from Airtable API
        table_schema: Table schema containing field definitions
        include_field_metadata: Whether to include field metadata in response

    Returns:
        Dict with normalized records and optionally field metadata
    """
    fields = table_schema.get("fields", [])

    # Normalize each record
    normalized_records = []
    for record in records:
        normalized = {
            "id": record.get("id"),
            "createdTime": record.get("createdTime"),
            "fields": {},
        }

        # Add existing fields
        existing_fields = record.get("fields", {})

        # Add all fields from schema, using empty values for missing ones
        for field in fields:
            field_name = field["name"]
            field_type = field["type"]

            if field_name in existing_fields:
                # Field exists, use its value
                normalized["fields"][field_name] = existing_fields[field_name]
            else:
                # Field is missing, add appropriate empty value
                normalized["fields"][field_name] = get_empty_value_for_field(field_type)

        normalized_records.append(normalized)

    # Build result dictionary
    if include_field_metadata:
        field_metadata = {}
        for field in fields:
            metadata = {"type": field["type"], "id": field["id"]}

            # Add type-specific metadata
            options = field.get("options", {})
            if field["type"] == "currency" and "symbol" in options:
                metadata["symbol"] = options["symbol"]
                metadata["precision"] = options.get("precision", 2)
            elif field["type"] == "duration" and "durationFormat" in options:
                metadata["format"] = options["durationFormat"]
            elif field["type"] == "percent" and "precision" in options:
                metadata["precision"] = options["precision"]
            elif (
                field["type"] in ["singleSelect", "multipleSelects"]
                and "choices" in options
            ):
                metadata["choices"] = [choice["name"] for choice in options["choices"]]
            elif field["type"] == "rating" and "max" in options:
                metadata["max"] = options["max"]
                metadata["icon"] = options.get("icon", "star")
                metadata["color"] = options.get("color", "yellowBright")

            field_metadata[field["name"]] = metadata

        return {"records": normalized_records, "field_metadata": field_metadata}
    else:
        return {"records": normalized_records}


async def list_records(
    credentials: Credentials,
    base_id: str,
@@ -1410,26 +1249,3 @@ async def list_bases(
    )

    return response.json()


async def get_base_tables(
    credentials: Credentials,
    base_id: str,
) -> list[dict]:
    """
    Get all tables for a specific base.

    Args:
        credentials: Airtable API credentials
        base_id: The ID of the base

    Returns:
        list[dict]: List of table objects with their schemas
    """
    response = await Requests().get(
        f"https://api.airtable.com/v0/meta/bases/{base_id}/tables",
        headers={"Authorization": credentials.auth_header()},
    )

    data = response.json()
    return data.get("tables", [])

@@ -9,22 +9,21 @@ from backend.sdk import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
    BlockSchema,
    CredentialsMetaInput,
    SchemaField,
)

from ._api import create_base, get_base_tables, list_bases
from ._api import create_base, list_bases
from ._config import airtable


class AirtableCreateBaseBlock(Block):
    """
    Creates a new base in an Airtable workspace, or returns existing base if one with the same name exists.
    Creates a new base in an Airtable workspace.
    """

    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
@@ -32,10 +31,6 @@ class AirtableCreateBaseBlock(Block):
            description="The workspace ID where the base will be created"
        )
        name: str = SchemaField(description="The name of the new base")
        find_existing: bool = SchemaField(
            description="If true, return existing base with same name instead of creating duplicate",
            default=True,
        )
        tables: list[dict] = SchemaField(
            description="At least one table and field must be specified. Array of table objects to create in the base. Each table should have 'name' and 'fields' properties",
            default=[
@@ -54,19 +49,15 @@ class AirtableCreateBaseBlock(Block):
            ],
        )

    class Output(BlockSchemaOutput):
        base_id: str = SchemaField(description="The ID of the created or found base")
    class Output(BlockSchema):
        base_id: str = SchemaField(description="The ID of the created base")
        tables: list[dict] = SchemaField(description="Array of table objects")
        table: dict = SchemaField(description="A single table object")
        was_created: bool = SchemaField(
            description="True if a new base was created, False if existing was found",
            default=True,
        )

    def __init__(self):
        super().__init__(
            id="f59b88a8-54ce-4676-a508-fd614b4e8dce",
            description="Create or find a base in Airtable",
            description="Create a new base in Airtable",
            categories={BlockCategory.DATA},
            input_schema=self.Input,
            output_schema=self.Output,
@@ -75,31 +66,6 @@ class AirtableCreateBaseBlock(Block):
    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        # If find_existing is true, check if a base with this name already exists
        if input_data.find_existing:
            # List all bases to check for existing one with same name
            # Note: Airtable API doesn't have a direct search, so we need to list and filter
            existing_bases = await list_bases(credentials)

            for base in existing_bases.get("bases", []):
                if base.get("name") == input_data.name:
                    # Base already exists, return it
                    base_id = base.get("id")
                    yield "base_id", base_id
                    yield "was_created", False

                    # Get the tables for this base
                    try:
                        tables = await get_base_tables(credentials, base_id)
                        yield "tables", tables
                        for table in tables:
                            yield "table", table
                    except Exception:
                        # If we can't get tables, return empty list
                        yield "tables", []
                    return

        # No existing base found or find_existing is false, create new one
        data = await create_base(
            credentials,
            input_data.workspace_id,
@@ -108,7 +74,6 @@ class AirtableCreateBaseBlock(Block):
        )

        yield "base_id", data.get("id", None)
        yield "was_created", True
        yield "tables", data.get("tables", [])
        for table in data.get("tables", []):
            yield "table", table
@@ -119,7 +84,7 @@ class AirtableListBasesBlock(Block):
    Lists all bases in an Airtable workspace that the user has access to.
    """

    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
@@ -130,7 +95,7 @@ class AirtableListBasesBlock(Block):
            description="Pagination offset from previous request", default=""
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        bases: list[dict] = SchemaField(description="Array of base objects")
        offset: Optional[str] = SchemaField(
            description="Offset for next page (null if no more bases)", default=None

@@ -2,15 +2,14 @@
Airtable record operation blocks.
"""

from typing import Optional, cast
from typing import Optional

from backend.sdk import (
    APIKeyCredentials,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
    BlockSchema,
    CredentialsMetaInput,
    SchemaField,
)
@@ -19,9 +18,7 @@ from ._api import (
    create_record,
    delete_multiple_records,
    get_record,
    get_table_schema,
    list_records,
    normalize_records,
    update_multiple_records,
)
from ._config import airtable
@@ -32,7 +29,7 @@ class AirtableListRecordsBlock(Block):
    Lists records from an Airtable table with optional filtering, sorting, and pagination.
    """

    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
@@ -57,24 +54,12 @@ class AirtableListRecordsBlock(Block):
        return_fields: list[str] = SchemaField(
            description="Specific fields to return (comma-separated)", default=[]
        )
        normalize_output: bool = SchemaField(
            description="Normalize output to include all fields with proper empty values (disable to skip schema fetch and get raw Airtable response)",
            default=True,
        )
        include_field_metadata: bool = SchemaField(
            description="Include field type and configuration metadata (requires normalize_output=true)",
            default=False,
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        records: list[dict] = SchemaField(description="Array of record objects")
        offset: Optional[str] = SchemaField(
            description="Offset for next page (null if no more records)", default=None
        )
        field_metadata: Optional[dict] = SchemaField(
            description="Field type and configuration metadata (only when include_field_metadata=true)",
            default=None,
        )

    def __init__(self):
        super().__init__(
@@ -88,7 +73,6 @@ class AirtableListRecordsBlock(Block):
    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:

        data = await list_records(
            credentials,
            input_data.base_id,
@@ -104,33 +88,8 @@ class AirtableListRecordsBlock(Block):
            fields=input_data.return_fields if input_data.return_fields else None,
        )

        records = data.get("records", [])

        # Normalize output if requested
        if input_data.normalize_output:
            # Fetch table schema
            table_schema = await get_table_schema(
                credentials, input_data.base_id, input_data.table_id_or_name
            )

            # Normalize the records
            normalized_data = await normalize_records(
                records,
                table_schema,
                include_field_metadata=input_data.include_field_metadata,
            )

            yield "records", normalized_data["records"]
            yield "offset", data.get("offset", None)

            if (
                input_data.include_field_metadata
                and "field_metadata" in normalized_data
            ):
                yield "field_metadata", normalized_data["field_metadata"]
        else:
            yield "records", records
            yield "offset", data.get("offset", None)
        yield "records", data.get("records", [])
        yield "offset", data.get("offset", None)


class AirtableGetRecordBlock(Block):
@@ -138,30 +97,18 @@ class AirtableGetRecordBlock(Block):
    Retrieves a single record from an Airtable table by its ID.
    """

    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
        base_id: str = SchemaField(description="The Airtable base ID")
        table_id_or_name: str = SchemaField(description="Table ID or name")
        record_id: str = SchemaField(description="The record ID to retrieve")
        normalize_output: bool = SchemaField(
            description="Normalize output to include all fields with proper empty values (disable to skip schema fetch and get raw Airtable response)",
            default=True,
        )
        include_field_metadata: bool = SchemaField(
            description="Include field type and configuration metadata (requires normalize_output=true)",
            default=False,
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        id: str = SchemaField(description="The record ID")
        fields: dict = SchemaField(description="The record fields")
        created_time: str = SchemaField(description="The record created time")
        field_metadata: Optional[dict] = SchemaField(
            description="Field type and configuration metadata (only when include_field_metadata=true)",
            default=None,
        )

    def __init__(self):
        super().__init__(
@@ -175,7 +122,6 @@ class AirtableGetRecordBlock(Block):
    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:

        record = await get_record(
            credentials,
            input_data.base_id,
@@ -183,34 +129,9 @@ class AirtableGetRecordBlock(Block):
            input_data.record_id,
        )

        # Normalize output if requested
        if input_data.normalize_output:
            # Fetch table schema
            table_schema = await get_table_schema(
                credentials, input_data.base_id, input_data.table_id_or_name
            )

            # Normalize the single record (wrap in list and unwrap result)
            normalized_data = await normalize_records(
                [record],
                table_schema,
                include_field_metadata=input_data.include_field_metadata,
            )

            normalized_record = normalized_data["records"][0]
            yield "id", normalized_record.get("id", None)
            yield "fields", normalized_record.get("fields", None)
            yield "created_time", normalized_record.get("createdTime", None)

            if (
                input_data.include_field_metadata
                and "field_metadata" in normalized_data
            ):
                yield "field_metadata", normalized_data["field_metadata"]
        else:
            yield "id", record.get("id", None)
            yield "fields", record.get("fields", None)
            yield "created_time", record.get("createdTime", None)
        yield "id", record.get("id", None)
        yield "fields", record.get("fields", None)
        yield "created_time", record.get("createdTime", None)


class AirtableCreateRecordsBlock(Block):
@@ -218,7 +139,7 @@ class AirtableCreateRecordsBlock(Block):
    Creates one or more records in an Airtable table.
    """

    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
@@ -227,10 +148,6 @@ class AirtableCreateRecordsBlock(Block):
        records: list[dict] = SchemaField(
            description="Array of records to create (each with 'fields' object)"
        )
        skip_normalization: bool = SchemaField(
            description="Skip output normalization to get raw Airtable response (faster but may have missing fields)",
            default=False,
        )
        typecast: bool = SchemaField(
            description="Automatically convert string values to appropriate types",
            default=False,
@@ -240,7 +157,7 @@ class AirtableCreateRecordsBlock(Block):
            default=None,
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        records: list[dict] = SchemaField(description="Array of created record objects")
        details: dict = SchemaField(description="Details of the created records")

@@ -256,7 +173,7 @@ class AirtableCreateRecordsBlock(Block):
    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:

        # The create_record API expects records in a specific format
        data = await create_record(
            credentials,
            input_data.base_id,
@@ -265,22 +182,8 @@ class AirtableCreateRecordsBlock(Block):
            typecast=input_data.typecast if input_data.typecast else None,
            return_fields_by_field_id=input_data.return_fields_by_field_id,
        )
        result_records = cast(list[dict], data.get("records", []))

        # Normalize output unless explicitly disabled
        if not input_data.skip_normalization and result_records:
            # Fetch table schema
            table_schema = await get_table_schema(
                credentials, input_data.base_id, input_data.table_id_or_name
            )

            # Normalize the records
            normalized_data = await normalize_records(
                result_records, table_schema, include_field_metadata=False
            )
            result_records = normalized_data["records"]

        yield "records", result_records
        yield "records", data.get("records", [])
        details = data.get("details", None)
        if details:
            yield "details", details
@@ -291,7 +194,7 @@ class AirtableUpdateRecordsBlock(Block):
    Updates one or more existing records in an Airtable table.
    """

    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
@@ -307,7 +210,7 @@ class AirtableUpdateRecordsBlock(Block):
            default=None,
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        records: list[dict] = SchemaField(description="Array of updated record objects")

    def __init__(self):
@@ -340,7 +243,7 @@ class AirtableDeleteRecordsBlock(Block):
    Deletes one or more records from an Airtable table.
    """

    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
@@ -352,7 +255,7 @@ class AirtableDeleteRecordsBlock(Block):
            description="Array of upto 10 record IDs to delete"
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        records: list[dict] = SchemaField(description="Array of deletion results")

    def __init__(self):

@@ -7,8 +7,7 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
+    BlockSchema,
     CredentialsMetaInput,
     Requests,
     SchemaField,
@@ -24,13 +23,13 @@ class AirtableListSchemaBlock(Block):
     fields, and views.
     """

-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: CredentialsMetaInput = airtable.credentials_field(
             description="Airtable API credentials"
         )
         base_id: str = SchemaField(description="The Airtable base ID")

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         base_schema: dict = SchemaField(
             description="Complete base schema with tables, fields, and views"
         )
@@ -67,7 +66,7 @@ class AirtableCreateTableBlock(Block):
     Creates a new table in an Airtable base with specified fields and views.
     """

-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: CredentialsMetaInput = airtable.credentials_field(
             description="Airtable API credentials"
         )
@@ -78,7 +77,7 @@ class AirtableCreateTableBlock(Block):
             default=[{"name": "Name", "type": "singleLineText"}],
         )

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         table: dict = SchemaField(description="Created table object")
         table_id: str = SchemaField(description="ID of the created table")

@@ -110,7 +109,7 @@ class AirtableUpdateTableBlock(Block):
     Updates an existing table's properties such as name or description.
     """

-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: CredentialsMetaInput = airtable.credentials_field(
             description="Airtable API credentials"
         )
@@ -126,7 +125,7 @@ class AirtableUpdateTableBlock(Block):
             description="The date dependency of the table to update", default=None
         )

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         table: dict = SchemaField(description="Updated table object")

     def __init__(self):
@@ -158,7 +157,7 @@ class AirtableCreateFieldBlock(Block):
     Adds a new field (column) to an existing Airtable table.
     """

-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: CredentialsMetaInput = airtable.credentials_field(
             description="Airtable API credentials"
         )
@@ -177,7 +176,7 @@ class AirtableCreateFieldBlock(Block):
             description="The options of the field to create", default=None
         )

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         field: dict = SchemaField(description="Created field object")
         field_id: str = SchemaField(description="ID of the created field")

@@ -210,7 +209,7 @@ class AirtableUpdateFieldBlock(Block):
     Updates an existing field's properties in an Airtable table.
    """

-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: CredentialsMetaInput = airtable.credentials_field(
             description="Airtable API credentials"
         )
@@ -226,7 +225,7 @@ class AirtableUpdateFieldBlock(Block):
             advanced=False,
         )

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         field: dict = SchemaField(description="Updated field object")

     def __init__(self):
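The hunks above and below all make the same mechanical swap, so it helps to see the target shape once in full. Below is a minimal sketch of a block under the `BlockSchema` convention; the block name, UUID, and fields are illustrative placeholders, not part of this diff:

```python
from backend.sdk import Block, BlockCategory, BlockOutput, BlockSchema, SchemaField


class EchoBlock(Block):
    """Illustrative example: the nested schemas declare the node's I/O contract."""

    class Input(BlockSchema):
        text: str = SchemaField(description="Text to echo back")

    class Output(BlockSchema):
        text: str = SchemaField(description="The same text, unchanged")

    def __init__(self):
        super().__init__(
            id="00000000-0000-0000-0000-000000000000",  # placeholder UUID
            description="Echoes its input (example only)",
            categories={BlockCategory.BASIC},
            input_schema=EchoBlock.Input,
            output_schema=EchoBlock.Output,
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        # Blocks are async generators: each yield is an (output_name, value) pair.
        yield "text", input_data.text
```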
@@ -3,8 +3,7 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
+    BlockSchema,
     BlockType,
     BlockWebhookConfig,
     CredentialsMetaInput,
@@ -33,7 +32,7 @@ class AirtableWebhookTriggerBlock(Block):
     Thin wrapper just forwards the payloads one at a time to the next block.
     """

-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: CredentialsMetaInput = airtable.credentials_field(
             description="Airtable API credentials"
         )
@@ -44,7 +43,7 @@ class AirtableWebhookTriggerBlock(Block):
             description="Airtable webhook event filter"
         )

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         payload: WebhookPayload = SchemaField(description="Airtable webhook payload")

     def __init__(self):
@@ -10,20 +10,14 @@ from backend.blocks.apollo.models import (
     PrimaryPhone,
     SearchOrganizationsRequest,
 )
-from backend.data.block import (
-    Block,
-    BlockCategory,
-    BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
-)
+from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
 from backend.data.model import CredentialsField, SchemaField


 class SearchOrganizationsBlock(Block):
     """Search for organizations in Apollo"""

-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         organization_num_employees_range: list[int] = SchemaField(
             description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.

@@ -75,7 +69,7 @@ To find IDs, identify the values for organization_id when you call this endpoint
             description="Apollo credentials",
         )

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         organizations: list[Organization] = SchemaField(
             description="List of organizations found",
             default_factory=list,
@@ -14,20 +14,14 @@ from backend.blocks.apollo.models import (
     SearchPeopleRequest,
     SenorityLevels,
 )
-from backend.data.block import (
-    Block,
-    BlockCategory,
-    BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
-)
+from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
 from backend.data.model import CredentialsField, SchemaField


 class SearchPeopleBlock(Block):
     """Search for people in Apollo"""

-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         person_titles: list[str] = SchemaField(
             description="""Job titles held by the people you want to find. For a person to be included in search results, they only need to match 1 of the job titles you add. Adding more job titles expands your search results.

@@ -115,7 +109,7 @@ class SearchPeopleBlock(Block):
             description="Apollo credentials",
         )

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         people: list[Contact] = SchemaField(
             description="List of people found",
             default_factory=list,
@@ -6,20 +6,14 @@ from backend.blocks.apollo._auth import (
     ApolloCredentialsInput,
 )
 from backend.blocks.apollo.models import Contact, EnrichPersonRequest
-from backend.data.block import (
-    Block,
-    BlockCategory,
-    BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
-)
+from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
 from backend.data.model import CredentialsField, SchemaField


 class GetPersonDetailBlock(Block):
     """Get detailed person data with Apollo API, including email reveal"""

-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         person_id: str = SchemaField(
             description="Apollo person ID to enrich (most accurate method)",
             default="",
@@ -74,7 +68,7 @@ class GetPersonDetailBlock(Block):
             description="Apollo credentials",
         )

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         contact: Contact = SchemaField(
             description="Enriched contact information",
         )
@@ -3,7 +3,7 @@ from typing import Optional

 from pydantic import BaseModel, Field

-from backend.data.block import BlockSchemaInput
+from backend.data.block import BlockSchema
 from backend.data.model import SchemaField, UserIntegrations
 from backend.integrations.ayrshare import AyrshareClient
 from backend.util.clients import get_database_manager_async_client
@@ -17,7 +17,7 @@ async def get_profile_key(user_id: str):
     return user_integrations.managed_credentials.ayrshare_profile_key


-class BaseAyrshareInput(BlockSchemaInput):
+class BaseAyrshareInput(BlockSchema):
     """Base input model for Ayrshare social media posts with common fields."""

     post: str = SchemaField(
@@ -3,7 +3,7 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchemaOutput,
+    BlockSchema,
     BlockType,
     SchemaField,
 )
@@ -38,7 +38,7 @@ class PostToBlueskyBlock(Block):
             advanced=True,
         )

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         post_result: PostResponse = SchemaField(description="The result of the post")
         post: PostIds = SchemaField(description="The result of the post")


@@ -3,7 +3,7 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchemaOutput,
+    BlockSchema,
     BlockType,
     SchemaField,
 )
@@ -101,7 +101,7 @@ class PostToFacebookBlock(Block):
             description="URL for custom link preview", default="", advanced=True
         )

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         post_result: PostResponse = SchemaField(description="The result of the post")
         post: PostIds = SchemaField(description="The result of the post")


@@ -3,7 +3,7 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchemaOutput,
+    BlockSchema,
     BlockType,
     SchemaField,
 )
@@ -94,7 +94,7 @@ class PostToGMBBlock(Block):
             advanced=True,
         )

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         post_result: PostResponse = SchemaField(description="The result of the post")
         post: PostIds = SchemaField(description="The result of the post")


@@ -5,7 +5,7 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchemaOutput,
+    BlockSchema,
     BlockType,
     SchemaField,
 )
@@ -94,7 +94,7 @@ class PostToInstagramBlock(Block):
             advanced=True,
         )

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         post_result: PostResponse = SchemaField(description="The result of the post")
         post: PostIds = SchemaField(description="The result of the post")


@@ -3,7 +3,7 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchemaOutput,
+    BlockSchema,
     BlockType,
     SchemaField,
 )
@@ -94,7 +94,7 @@ class PostToLinkedInBlock(Block):
             advanced=True,
         )

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         post_result: PostResponse = SchemaField(description="The result of the post")
         post: PostIds = SchemaField(description="The result of the post")


@@ -3,7 +3,7 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchemaOutput,
+    BlockSchema,
     BlockType,
     SchemaField,
 )
@@ -73,7 +73,7 @@ class PostToPinterestBlock(Block):
             advanced=True,
         )

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         post_result: PostResponse = SchemaField(description="The result of the post")
         post: PostIds = SchemaField(description="The result of the post")


@@ -3,7 +3,7 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchemaOutput,
+    BlockSchema,
     BlockType,
     SchemaField,
 )
@@ -19,7 +19,7 @@ class PostToRedditBlock(Block):

         pass  # Uses all base fields

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         post_result: PostResponse = SchemaField(description="The result of the post")
         post: PostIds = SchemaField(description="The result of the post")


@@ -3,7 +3,7 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchemaOutput,
+    BlockSchema,
     BlockType,
     SchemaField,
 )
@@ -43,7 +43,7 @@ class PostToSnapchatBlock(Block):
             advanced=True,
         )

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         post_result: PostResponse = SchemaField(description="The result of the post")
         post: PostIds = SchemaField(description="The result of the post")


@@ -3,7 +3,7 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchemaOutput,
+    BlockSchema,
     BlockType,
     SchemaField,
 )
@@ -38,7 +38,7 @@ class PostToTelegramBlock(Block):
             advanced=True,
         )

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         post_result: PostResponse = SchemaField(description="The result of the post")
         post: PostIds = SchemaField(description="The result of the post")


@@ -3,7 +3,7 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchemaOutput,
+    BlockSchema,
     BlockType,
     SchemaField,
 )
@@ -31,7 +31,7 @@ class PostToThreadsBlock(Block):
             advanced=False,
         )

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         post_result: PostResponse = SchemaField(description="The result of the post")
         post: PostIds = SchemaField(description="The result of the post")


@@ -5,7 +5,7 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchemaOutput,
+    BlockSchema,
     BlockType,
     SchemaField,
 )
@@ -98,7 +98,7 @@ class PostToTikTokBlock(Block):
             advanced=True,
         )

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         post_result: PostResponse = SchemaField(description="The result of the post")
         post: PostIds = SchemaField(description="The result of the post")


@@ -3,7 +3,7 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchemaOutput,
+    BlockSchema,
     BlockType,
     SchemaField,
 )
@@ -97,7 +97,7 @@ class PostToXBlock(Block):
             advanced=True,
         )

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         post_result: PostResponse = SchemaField(description="The result of the post")
         post: PostIds = SchemaField(description="The result of the post")


@@ -6,7 +6,7 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchemaOutput,
+    BlockSchema,
     BlockType,
     SchemaField,
 )
@@ -119,7 +119,7 @@ class PostToYouTubeBlock(Block):
             advanced=True,
         )

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         post_result: PostResponse = SchemaField(description="The result of the post")
         post: PostIds = SchemaField(description="The result of the post")

@@ -9,8 +9,7 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
+    BlockSchema,
     CredentialsMetaInput,
     SchemaField,
 )
@@ -24,7 +23,7 @@ class BaasBotJoinMeetingBlock(Block):
     Deploy a bot immediately or at a scheduled start_time to join and record a meeting.
     """

-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: CredentialsMetaInput = baas.credentials_field(
             description="Meeting BaaS API credentials"
         )
@@ -58,7 +57,7 @@ class BaasBotJoinMeetingBlock(Block):
             description="Custom metadata to attach to the bot", default={}
         )

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         bot_id: str = SchemaField(description="UUID of the deployed bot")
         join_response: dict = SchemaField(
             description="Full response from join operation"
@@ -104,13 +103,13 @@ class BaasBotLeaveMeetingBlock(Block):
     Force the bot to exit the call.
     """

-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: CredentialsMetaInput = baas.credentials_field(
             description="Meeting BaaS API credentials"
         )
         bot_id: str = SchemaField(description="UUID of the bot to remove from meeting")

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         left: bool = SchemaField(description="Whether the bot successfully left")

     def __init__(self):
@@ -139,7 +138,7 @@ class BaasBotFetchMeetingDataBlock(Block):
     Pull MP4 URL, transcript & metadata for a completed meeting.
     """

-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: CredentialsMetaInput = baas.credentials_field(
             description="Meeting BaaS API credentials"
         )
@@ -148,7 +147,7 @@ class BaasBotFetchMeetingDataBlock(Block):
             description="Include transcript data in response", default=True
         )

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         mp4_url: str = SchemaField(
             description="URL to download the meeting recording (time-limited)"
         )
@@ -186,13 +185,13 @@ class BaasBotDeleteRecordingBlock(Block):
     Purge MP4 + transcript data for privacy or storage management.
     """

-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: CredentialsMetaInput = baas.credentials_field(
             description="Meeting BaaS API credentials"
         )
         bot_id: str = SchemaField(description="UUID of the bot whose data to delete")

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         deleted: bool = SchemaField(
             description="Whether the data was successfully deleted"
         )
@@ -1,3 +0,0 @@
-from .text_overlay import BannerbearTextOverlayBlock
-
-__all__ = ["BannerbearTextOverlayBlock"]
@@ -1,8 +0,0 @@
-from backend.sdk import BlockCostType, ProviderBuilder
-
-bannerbear = (
-    ProviderBuilder("bannerbear")
-    .with_api_key("BANNERBEAR_API_KEY", "Bannerbear API Key")
-    .with_base_cost(1, BlockCostType.RUN)
-    .build()
-)
@@ -1,239 +0,0 @@
-import uuid
-from typing import TYPE_CHECKING, Any, Dict, List
-
-if TYPE_CHECKING:
-    pass
-
-from pydantic import SecretStr
-
-from backend.sdk import (
-    APIKeyCredentials,
-    Block,
-    BlockCategory,
-    BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
-    CredentialsMetaInput,
-    Requests,
-    SchemaField,
-)
-
-from ._config import bannerbear
-
-TEST_CREDENTIALS = APIKeyCredentials(
-    id="01234567-89ab-cdef-0123-456789abcdef",
-    provider="bannerbear",
-    api_key=SecretStr("mock-bannerbear-api-key"),
-    title="Mock Bannerbear API Key",
-)
-
-
-class TextModification(BlockSchemaInput):
-    name: str = SchemaField(
-        description="The name of the layer to modify in the template"
-    )
-    text: str = SchemaField(description="The text content to add to this layer")
-    color: str = SchemaField(
-        description="Hex color code for the text (e.g., '#FF0000')",
-        default="",
-        advanced=True,
-    )
-    font_family: str = SchemaField(
-        description="Font family to use for the text",
-        default="",
-        advanced=True,
-    )
-    font_size: int = SchemaField(
-        description="Font size in pixels",
-        default=0,
-        advanced=True,
-    )
-    font_weight: str = SchemaField(
-        description="Font weight (e.g., 'bold', 'normal')",
-        default="",
-        advanced=True,
-    )
-    text_align: str = SchemaField(
-        description="Text alignment (left, center, right)",
-        default="",
-        advanced=True,
-    )
-
-
-class BannerbearTextOverlayBlock(Block):
-    class Input(BlockSchemaInput):
-        credentials: CredentialsMetaInput = bannerbear.credentials_field(
-            description="API credentials for Bannerbear"
-        )
-        template_id: str = SchemaField(
-            description="The unique ID of your Bannerbear template"
-        )
-        project_id: str = SchemaField(
-            description="Optional: Project ID (required when using Master API Key)",
-            default="",
-            advanced=True,
-        )
-        text_modifications: List[TextModification] = SchemaField(
-            description="List of text layers to modify in the template"
-        )
-        image_url: str = SchemaField(
-            description="Optional: URL of an image to use in the template",
-            default="",
-            advanced=True,
-        )
-        image_layer_name: str = SchemaField(
-            description="Optional: Name of the image layer in the template",
-            default="photo",
-            advanced=True,
-        )
-        webhook_url: str = SchemaField(
-            description="Optional: URL to receive webhook notification when image is ready",
-            default="",
-            advanced=True,
-        )
-        metadata: str = SchemaField(
-            description="Optional: Custom metadata to attach to the image",
-            default="",
-            advanced=True,
-        )
-
-    class Output(BlockSchemaOutput):
-        success: bool = SchemaField(
-            description="Whether the image generation was successfully initiated"
-        )
-        image_url: str = SchemaField(
-            description="URL of the generated image (if synchronous) or placeholder"
-        )
-        uid: str = SchemaField(description="Unique identifier for the generated image")
-        status: str = SchemaField(description="Status of the image generation")
-
-    def __init__(self):
-        super().__init__(
-            id="c7d3a5c2-05fc-450e-8dce-3b0e04626009",
-            description="Add text overlay to images using Bannerbear templates. Perfect for creating social media graphics, marketing materials, and dynamic image content.",
-            categories={BlockCategory.PRODUCTIVITY, BlockCategory.AI},
-            input_schema=self.Input,
-            output_schema=self.Output,
-            test_input={
-                "template_id": "jJWBKNELpQPvbX5R93Gk",
-                "text_modifications": [
-                    {
-                        "name": "headline",
-                        "text": "Amazing Product Launch!",
-                        "color": "#FF0000",
-                    },
-                    {
-                        "name": "subtitle",
-                        "text": "50% OFF Today Only",
-                    },
-                ],
-                "credentials": {
-                    "provider": "bannerbear",
-                    "id": str(uuid.uuid4()),
-                    "type": "api_key",
-                },
-            },
-            test_output=[
-                ("success", True),
-                ("image_url", "https://cdn.bannerbear.com/test-image.jpg"),
-                ("uid", "test-uid-123"),
-                ("status", "completed"),
-            ],
-            test_mock={
-                "_make_api_request": lambda *args, **kwargs: {
-                    "uid": "test-uid-123",
-                    "status": "completed",
-                    "image_url": "https://cdn.bannerbear.com/test-image.jpg",
-                }
-            },
-            test_credentials=TEST_CREDENTIALS,
-        )
-
-    async def _make_api_request(self, payload: dict, api_key: str) -> dict:
-        """Make the actual API request to Bannerbear. This is separated for easy mocking in tests."""
-        headers = {
-            "Authorization": f"Bearer {api_key}",
-            "Content-Type": "application/json",
-        }
-
-        response = await Requests().post(
-            "https://sync.api.bannerbear.com/v2/images",
-            headers=headers,
-            json=payload,
-        )
-
-        if response.status in [200, 201, 202]:
-            return response.json()
-        else:
-            error_msg = f"API request failed with status {response.status}"
-            if response.text:
-                try:
-                    error_data = response.json()
-                    error_msg = (
-                        f"{error_msg}: {error_data.get('message', response.text)}"
-                    )
-                except Exception:
-                    error_msg = f"{error_msg}: {response.text}"
-            raise Exception(error_msg)
-
-    async def run(
-        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
-    ) -> BlockOutput:
-        # Build the modifications array
-        modifications = []
-
-        # Add text modifications
-        for text_mod in input_data.text_modifications:
-            mod_data: Dict[str, Any] = {
-                "name": text_mod.name,
-                "text": text_mod.text,
-            }
-
-            # Add optional text styling parameters only if they have values
-            if text_mod.color and text_mod.color.strip():
-                mod_data["color"] = text_mod.color
-            if text_mod.font_family and text_mod.font_family.strip():
-                mod_data["font_family"] = text_mod.font_family
-            if text_mod.font_size and text_mod.font_size > 0:
-                mod_data["font_size"] = text_mod.font_size
-            if text_mod.font_weight and text_mod.font_weight.strip():
-                mod_data["font_weight"] = text_mod.font_weight
-            if text_mod.text_align and text_mod.text_align.strip():
-                mod_data["text_align"] = text_mod.text_align
-
-            modifications.append(mod_data)
-
-        # Add image modification if provided and not empty
-        if input_data.image_url and input_data.image_url.strip():
-            modifications.append(
-                {
-                    "name": input_data.image_layer_name,
-                    "image_url": input_data.image_url,
-                }
-            )
-
-        # Build the request payload - only include non-empty optional fields
-        payload = {
-            "template": input_data.template_id,
-            "modifications": modifications,
-        }
-
-        # Add project_id if provided (required for Master API keys)
-        if input_data.project_id and input_data.project_id.strip():
-            payload["project_id"] = input_data.project_id
-
-        if input_data.webhook_url and input_data.webhook_url.strip():
-            payload["webhook_url"] = input_data.webhook_url
-        if input_data.metadata and input_data.metadata.strip():
-            payload["metadata"] = input_data.metadata
-
-        # Make the API request using the private method
-        data = await self._make_api_request(
-            payload, credentials.api_key.get_secret_value()
-        )
-
-        # Synchronous request - image should be ready
-        yield "success", True
-        yield "image_url", data.get("image_url", "")
-        yield "uid", data.get("uid", "")
-        yield "status", data.get("status", "completed")
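The deleted Bannerbear block above also shows the testing seam used throughout these files: all network I/O sits in one private method so the `test_mock` mapping can stand in for it. A self-contained sketch of the substitution that mapping performs; the class and values are illustrative, and how the real harness applies the patch is assumed, not shown in this diff:

```python
import asyncio


class DemoBlock:
    """Stand-in for a block whose only side effect lives in one method."""

    async def _make_api_request(self, payload: dict, api_key: str) -> dict:
        raise RuntimeError("real network call - never reached under test")

    async def run(self) -> dict:
        return await self._make_api_request({"template": "x"}, api_key="k")


async def main() -> None:
    block = DemoBlock()

    # Roughly the effect of test_mock={"_make_api_request": ...}: the
    # attribute is swapped out before run() executes.
    async def fake_request(*args, **kwargs) -> dict:
        return {"uid": "test-uid-123", "status": "completed"}

    block._make_api_request = fake_request  # type: ignore[method-assign]
    assert (await block.run())["status"] == "completed"


asyncio.run(main())
```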
@@ -1,21 +1,14 @@
 import enum
 from typing import Any

-from backend.data.block import (
-    Block,
-    BlockCategory,
-    BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
-    BlockType,
-)
+from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
 from backend.data.model import SchemaField
 from backend.util.file import store_media_file
 from backend.util.type import MediaFileType, convert


 class FileStoreBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         file_in: MediaFileType = SchemaField(
             description="The file to store in the temporary directory, it can be a URL, data URI, or local path."
         )
@@ -26,7 +19,7 @@ class FileStoreBlock(Block):
             title="Produce Base64 Output",
         )

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         file_out: MediaFileType = SchemaField(
             description="The relative path to the stored file in the temporary directory."
         )
@@ -64,7 +57,7 @@ class StoreValueBlock(Block):
     The block output will be static, the output can be consumed multiple times.
     """

-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         input: Any = SchemaField(
             description="Trigger the block to produce the output. "
             "The value is only used when `data` is None."
@@ -75,7 +68,7 @@ class StoreValueBlock(Block):
             default=None,
         )

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         output: Any = SchemaField(description="The stored data retained in the block.")

     def __init__(self):
@@ -101,10 +94,10 @@ class StoreValueBlock(Block):


 class PrintToConsoleBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         text: Any = SchemaField(description="The data to print to the console.")

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         output: Any = SchemaField(description="The data printed to the console.")
         status: str = SchemaField(description="The status of the print operation.")

@@ -128,10 +121,10 @@ class PrintToConsoleBlock(Block):


 class NoteBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         text: str = SchemaField(description="The text to display in the sticky note.")

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         output: str = SchemaField(description="The text to display in the sticky note.")

     def __init__(self):
@@ -161,14 +154,15 @@ class TypeOptions(enum.Enum):


 class UniversalTypeConverterBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         value: Any = SchemaField(
             description="The value to convert to a universal type."
         )
         type: TypeOptions = SchemaField(description="The type to convert the value to.")

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         value: Any = SchemaField(description="The converted value.")
         error: str = SchemaField(description="Error message if conversion failed.")

     def __init__(self):
         super().__init__(
@@ -201,10 +195,10 @@ class ReverseListOrderBlock(Block):
     A block which takes in a list and returns it in the opposite order.
     """

-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         input_list: list[Any] = SchemaField(description="The list to reverse")

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         reversed_list: list[Any] = SchemaField(description="The list in reversed order")

     def __init__(self):
@@ -2,13 +2,7 @@ import os
 import re
 from typing import Type

-from backend.data.block import (
-    Block,
-    BlockCategory,
-    BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
-)
+from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
 from backend.data.model import SchemaField


@@ -21,12 +15,12 @@ class BlockInstallationBlock(Block):
     for development purposes only.
     """

-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         code: str = SchemaField(
             description="Python code of the block to be installed",
         )

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         success: str = SchemaField(
             description="Success message if the block is installed successfully",
         )
@@ -1,13 +1,7 @@
 from enum import Enum
 from typing import Any

-from backend.data.block import (
-    Block,
-    BlockCategory,
-    BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
-)
+from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
 from backend.data.model import SchemaField
 from backend.util.type import convert

@@ -22,7 +16,7 @@ class ComparisonOperator(Enum):


 class ConditionBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         value1: Any = SchemaField(
             description="Enter the first value for comparison",
             placeholder="For example: 10 or 'hello' or True",
@@ -46,7 +40,7 @@ class ConditionBlock(Block):
             default=None,
         )

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         result: bool = SchemaField(
             description="The result of the condition evaluation (True or False)"
         )
@@ -117,7 +111,7 @@ class ConditionBlock(Block):


 class IfInputMatchesBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         input: Any = SchemaField(
             description="The input to match against",
             placeholder="For example: 10 or 'hello' or True",
@@ -137,7 +131,7 @@ class IfInputMatchesBlock(Block):
             default=None,
         )

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         result: bool = SchemaField(
             description="The result of the condition evaluation (True or False)"
         )
@@ -1,18 +1,10 @@
 from enum import Enum
-from typing import Any, Literal, Optional
+from typing import Literal

 from e2b_code_interpreter import AsyncSandbox
-from e2b_code_interpreter import Result as E2BExecutionResult
-from e2b_code_interpreter.charts import Chart as E2BExecutionResultChart
-from pydantic import BaseModel, Field, JsonValue, SecretStr
+from pydantic import SecretStr

-from backend.data.block import (
-    Block,
-    BlockCategory,
-    BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
-)
+from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
 from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,
@@ -44,135 +36,14 @@ class ProgrammingLanguage(Enum):
     JAVA = "java"


-class MainCodeExecutionResult(BaseModel):
-    """
-    *Pydantic model mirroring `e2b_code_interpreter.Result`*
-
-    Represents the data to be displayed as a result of executing a cell in a Jupyter notebook.
-    The result is similar to the structure returned by ipython kernel: https://ipython.readthedocs.io/en/stable/development/execution.html#execution-semantics
-
-    The result can contain multiple types of data, such as text, images, plots, etc. Each type of data is represented
-    as a string, and the result can contain multiple types of data. The display calls don't have to have text representation,
-    for the actual result the representation is always present for the result, the other representations are always optional.
-    """  # noqa
-
-    class Chart(BaseModel, E2BExecutionResultChart):
-        pass
-
-    text: Optional[str] = None
-    html: Optional[str] = None
-    markdown: Optional[str] = None
-    svg: Optional[str] = None
-    png: Optional[str] = None
-    jpeg: Optional[str] = None
-    pdf: Optional[str] = None
-    latex: Optional[str] = None
-    json_data: Optional[JsonValue] = Field(None, alias="json")
-    javascript: Optional[str] = None
-    data: Optional[dict] = None
-    chart: Optional[Chart] = None
-    extra: Optional[dict] = None
-    """Extra data that can be included. Not part of the standard types."""
-
-
-class CodeExecutionResult(MainCodeExecutionResult):
-    __doc__ = MainCodeExecutionResult.__doc__
-
-    is_main_result: bool = False
-    """Whether this data is the main result of the cell. Data can be produced by display calls of which can be multiple in a cell."""  # noqa
-
-
-class BaseE2BExecutorMixin:
-    """Shared implementation methods for E2B executor blocks."""
-
-    async def execute_code(
-        self,
-        api_key: str,
-        code: str,
-        language: ProgrammingLanguage,
-        template_id: str = "",
-        setup_commands: Optional[list[str]] = None,
-        timeout: Optional[int] = None,
-        sandbox_id: Optional[str] = None,
-        dispose_sandbox: bool = False,
-    ):
-        """
-        Unified code execution method that handles all three use cases:
-        1. Create new sandbox and execute (ExecuteCodeBlock)
-        2. Create new sandbox, execute, and return sandbox_id (InstantiateCodeSandboxBlock)
-        3. Connect to existing sandbox and execute (ExecuteCodeStepBlock)
-        """  # noqa
-        sandbox = None
-        try:
-            if sandbox_id:
-                # Connect to existing sandbox (ExecuteCodeStepBlock case)
-                sandbox = await AsyncSandbox.connect(
-                    sandbox_id=sandbox_id, api_key=api_key
-                )
-            else:
-                # Create new sandbox (ExecuteCodeBlock/InstantiateCodeSandboxBlock case)
-                sandbox = await AsyncSandbox.create(
-                    api_key=api_key, template=template_id, timeout=timeout
-                )
-                if setup_commands:
-                    for cmd in setup_commands:
-                        await sandbox.commands.run(cmd)
-
-            # Execute the code
-            execution = await sandbox.run_code(
-                code,
-                language=language.value,
-                on_error=lambda e: sandbox.kill(),  # Kill the sandbox on error
-            )
-
-            if execution.error:
-                raise Exception(execution.error)
-
-            results = execution.results
-            text_output = execution.text
-            stdout_logs = "".join(execution.logs.stdout)
-            stderr_logs = "".join(execution.logs.stderr)
-
-            return results, text_output, stdout_logs, stderr_logs, sandbox.sandbox_id
-        finally:
-            # Dispose of sandbox if requested to reduce usage costs
-            if dispose_sandbox and sandbox:
-                await sandbox.kill()
-
-    def process_execution_results(
-        self, results: list[E2BExecutionResult]
-    ) -> tuple[dict[str, Any] | None, list[dict[str, Any]]]:
-        """Process and filter execution results."""
-        # Filter out empty formats and convert to dicts
-        processed_results = [
-            {
-                f: value
-                for f in [*r.formats(), "extra", "is_main_result"]
-                if (value := getattr(r, f, None)) is not None
-            }
-            for r in results
-        ]
-        if main_result := next(
-            (r for r in processed_results if r.get("is_main_result")), None
-        ):
-            # Make main_result a copy we can modify & remove is_main_result
-            (main_result := {**main_result}).pop("is_main_result")
-
-        return main_result, processed_results
-
-
-class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
+class CodeExecutionBlock(Block):
     # TODO : Add support to upload and download files
-    # NOTE: Currently, you can only customize the CPU and Memory
-    # by creating a pre customized sandbox template
-    class Input(BlockSchemaInput):
+    # Currently, You can customized the CPU and Memory, only by creating a pre customized sandbox template
+    class Input(BlockSchema):
         credentials: CredentialsMetaInput[
             Literal[ProviderName.E2B], Literal["api_key"]
         ] = CredentialsField(
-            description=(
-                "Enter your API key for the E2B platform. "
-                "You can get it in here - https://e2b.dev/docs"
-            ),
+            description="Enter your api key for the E2B Sandbox. You can get it in here - https://e2b.dev/docs",
         )

         # Todo : Option to run commond in background
@@ -205,14 +76,6 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
             description="Execution timeout in seconds", default=300
         )

-        dispose_sandbox: bool = SchemaField(
-            description=(
-                "Whether to dispose of the sandbox immediately after execution. "
-                "If disabled, the sandbox will run until its timeout expires."
-            ),
-            default=True,
-        )
-
         template_id: str = SchemaField(
             description=(
                 "You can use an E2B sandbox template by entering its ID here. "
@@ -223,29 +86,21 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
             advanced=True,
         )

-    class Output(BlockSchemaOutput):
-        main_result: MainCodeExecutionResult = SchemaField(
-            title="Main Result", description="The main result from the code execution"
-        )
-        results: list[CodeExecutionResult] = SchemaField(
-            description="List of results from the code execution"
-        )
-        response: str = SchemaField(
-            title="Main Text Output",
-            description="Text output (if any) of the main execution result",
-        )
+    class Output(BlockSchema):
+        response: str = SchemaField(description="Response from code execution")
         stdout_logs: str = SchemaField(
             description="Standard output logs from execution"
         )
         stderr_logs: str = SchemaField(description="Standard error logs from execution")
         error: str = SchemaField(description="Error message if execution failed")

     def __init__(self):
         super().__init__(
             id="0b02b072-abe7-11ef-8372-fb5d162dd712",
-            description="Executes code in a sandbox environment with internet access.",
+            description="Executes code in an isolated sandbox environment with internet access.",
             categories={BlockCategory.DEVELOPER_TOOLS},
-            input_schema=ExecuteCodeBlock.Input,
-            output_schema=ExecuteCodeBlock.Output,
+            input_schema=CodeExecutionBlock.Input,
+            output_schema=CodeExecutionBlock.Output,
             test_credentials=TEST_CREDENTIALS,
             test_input={
                 "credentials": TEST_CREDENTIALS_INPUT,
@@ -256,59 +111,91 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
                 "template_id": "",
             },
             test_output=[
-                ("results", []),
                 ("response", "Hello World"),
                 ("stdout_logs", "Hello World\n"),
             ],
             test_mock={
-                "execute_code": lambda api_key, code, language, template_id, setup_commands, timeout, dispose_sandbox: (  # noqa
-                    [],  # results
-                    "Hello World",  # text_output
-                    "Hello World\n",  # stdout_logs
-                    "",  # stderr_logs
-                    "sandbox_id",  # sandbox_id
+                "execute_code": lambda code, language, setup_commands, timeout, api_key, template_id: (
+                    "Hello World",
+                    "Hello World\n",
+                    "",
                 ),
             },
         )

+    async def execute_code(
+        self,
+        code: str,
+        language: ProgrammingLanguage,
+        setup_commands: list[str],
+        timeout: int,
+        api_key: str,
+        template_id: str,
+    ):
+        try:
+            sandbox = None
+            if template_id:
+                sandbox = await AsyncSandbox.create(
+                    template=template_id, api_key=api_key, timeout=timeout
+                )
+            else:
+                sandbox = await AsyncSandbox.create(api_key=api_key, timeout=timeout)
+
+            if not sandbox:
+                raise Exception("Sandbox not created")
+
+            # Running setup commands
+            for cmd in setup_commands:
+                await sandbox.commands.run(cmd)
+
+            # Executing the code
+            execution = await sandbox.run_code(
+                code,
+                language=language.value,
+                on_error=lambda e: sandbox.kill(),  # Kill the sandbox if there is an error
+            )
+
+            if execution.error:
+                raise Exception(execution.error)
+
+            response = execution.text
+            stdout_logs = "".join(execution.logs.stdout)
+            stderr_logs = "".join(execution.logs.stderr)
+
+            return response, stdout_logs, stderr_logs
+
+        except Exception as e:
+            raise e
+
     async def run(
         self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
     ) -> BlockOutput:
         try:
-            results, text_output, stdout, stderr, _ = await self.execute_code(
-                api_key=credentials.api_key.get_secret_value(),
-                code=input_data.code,
-                language=input_data.language,
-                template_id=input_data.template_id,
-                setup_commands=input_data.setup_commands,
-                timeout=input_data.timeout,
-                dispose_sandbox=input_data.dispose_sandbox,
+            response, stdout_logs, stderr_logs = await self.execute_code(
+                input_data.code,
+                input_data.language,
+                input_data.setup_commands,
+                input_data.timeout,
+                credentials.api_key.get_secret_value(),
+                input_data.template_id,
             )

-            # Determine result object shape & filter out empty formats
-            main_result, results = self.process_execution_results(results)
-            if main_result:
-                yield "main_result", main_result
-            yield "results", results
-            if text_output:
-                yield "response", text_output
-            if stdout:
-                yield "stdout_logs", stdout
-            if stderr:
-                yield "stderr_logs", stderr
+            if response:
+                yield "response", response
+            if stdout_logs:
+                yield "stdout_logs", stdout_logs
+            if stderr_logs:
+                yield "stderr_logs", stderr_logs
         except Exception as e:
             yield "error", str(e)


-class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
-    class Input(BlockSchemaInput):
+class InstantiationBlock(Block):
+    class Input(BlockSchema):
         credentials: CredentialsMetaInput[
             Literal[ProviderName.E2B], Literal["api_key"]
         ] = CredentialsField(
-            description=(
-                "Enter your API key for the E2B platform. "
-                "You can get it in here - https://e2b.dev/docs"
-            )
+            description="Enter your api key for the E2B Sandbox. You can get it in here - https://e2b.dev/docs",
        )

         # Todo : Option to run commond in background
@@ -351,27 +238,22 @@ class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
             advanced=True,
         )

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         sandbox_id: str = SchemaField(description="ID of the sandbox instance")
-        response: str = SchemaField(
-            title="Text Result",
-            description="Text result (if any) of the setup code execution",
-        )
+        response: str = SchemaField(description="Response from code execution")
         stdout_logs: str = SchemaField(
             description="Standard output logs from execution"
         )
         stderr_logs: str = SchemaField(description="Standard error logs from execution")
         error: str = SchemaField(description="Error message if execution failed")

     def __init__(self):
         super().__init__(
             id="ff0861c9-1726-4aec-9e5b-bf53f3622112",
-            description=(
-                "Instantiate a sandbox environment with internet access "
-                "in which you can execute code with the Execute Code Step block."
-            ),
+            description="Instantiate an isolated sandbox environment with internet access where to execute code in.",
             categories={BlockCategory.DEVELOPER_TOOLS},
-            input_schema=InstantiateCodeSandboxBlock.Input,
-            output_schema=InstantiateCodeSandboxBlock.Output,
+            input_schema=InstantiationBlock.Input,
+            output_schema=InstantiationBlock.Output,
             test_credentials=TEST_CREDENTIALS,
             test_input={
                 "credentials": TEST_CREDENTIALS_INPUT,
@@ -387,12 +269,11 @@ class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
                 ("stdout_logs", "Hello World\n"),
             ],
             test_mock={
-                "execute_code": lambda api_key, code, language, template_id, setup_commands, timeout: (  # noqa
-                    [],  # results
-                    "Hello World",  # text_output
-                    "Hello World\n",  # stdout_logs
-                    "",  # stderr_logs
-                    "sandbox_id",  # sandbox_id
+                "execute_code": lambda setup_code, language, setup_commands, timeout, api_key, template_id: (
+                    "sandbox_id",
+                    "Hello World",
+                    "Hello World\n",
+                    "",
                 ),
             },
         )
@@ -401,38 +282,78 @@ class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
         self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
     ) -> BlockOutput:
         try:
-            _, text_output, stdout, stderr, sandbox_id = await self.execute_code(
-                api_key=credentials.api_key.get_secret_value(),
-                code=input_data.setup_code,
-                language=input_data.language,
-                template_id=input_data.template_id,
-                setup_commands=input_data.setup_commands,
-                timeout=input_data.timeout,
+            sandbox_id, response, stdout_logs, stderr_logs = await self.execute_code(
+                input_data.setup_code,
+                input_data.language,
+                input_data.setup_commands,
+                input_data.timeout,
+                credentials.api_key.get_secret_value(),
+                input_data.template_id,
             )
             if sandbox_id:
                 yield "sandbox_id", sandbox_id
             else:
                 yield "error", "Sandbox ID not found"

-            if text_output:
-                yield "response", text_output
-            if stdout:
-                yield "stdout_logs", stdout
-            if stderr:
-                yield "stderr_logs", stderr
+            if response:
+                yield "response", response
+            if stdout_logs:
+                yield "stdout_logs", stdout_logs
+            if stderr_logs:
+                yield "stderr_logs", stderr_logs
         except Exception as e:
             yield "error", str(e)

+    async def execute_code(
+        self,
+        code: str,
+        language: ProgrammingLanguage,
+        setup_commands: list[str],
+        timeout: int,
+        api_key: str,
+        template_id: str,
+    ):
+        try:
+            sandbox = None
+            if template_id:
+                sandbox = await AsyncSandbox.create(
+                    template=template_id, api_key=api_key, timeout=timeout
+                )
+            else:
+                sandbox = await AsyncSandbox.create(api_key=api_key, timeout=timeout)
+
-class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):
-    class Input(BlockSchemaInput):
+            if not sandbox:
+                raise Exception("Sandbox not created")
+
+            # Running setup commands
+            for cmd in setup_commands:
+                await sandbox.commands.run(cmd)
+
+            # Executing the code
+            execution = await sandbox.run_code(
+                code,
+                language=language.value,
+                on_error=lambda e: sandbox.kill(),  # Kill the sandbox if there is an error
+            )
+
+            if execution.error:
+                raise Exception(execution.error)
+
+            response = execution.text
+            stdout_logs = "".join(execution.logs.stdout)
+            stderr_logs = "".join(execution.logs.stderr)
+
+            return sandbox.sandbox_id, response, stdout_logs, stderr_logs
+
+        except Exception as e:
+            raise e
+
+
+class StepExecutionBlock(Block):
+    class Input(BlockSchema):
         credentials: CredentialsMetaInput[
             Literal[ProviderName.E2B], Literal["api_key"]
         ] = CredentialsField(
-            description=(
-                "Enter your API key for the E2B platform. "
-                "You can get it in here - https://e2b.dev/docs"
-            ),
+            description="Enter your api key for the E2B Sandbox. You can get it in here - https://e2b.dev/docs",
         )

         sandbox_id: str = SchemaField(
@@ -453,34 +374,21 @@ class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):
             advanced=False,
         )

-        dispose_sandbox: bool = SchemaField(
-            description="Whether to dispose of the sandbox after executing this code.",
-            default=False,
-        )
-
-    class Output(BlockSchemaOutput):
-        main_result: MainCodeExecutionResult = SchemaField(
-            title="Main Result", description="The main result from the code execution"
-        )
-        results: list[CodeExecutionResult] = SchemaField(
-            description="List of results from the code execution"
-        )
-        response: str = SchemaField(
-            title="Main Text Output",
-            description="Text output (if any) of the main execution result",
-        )
+    class Output(BlockSchema):
+        response: str = SchemaField(description="Response from code execution")
         stdout_logs: str = SchemaField(
             description="Standard output logs from execution"
         )
         stderr_logs: str = SchemaField(description="Standard error logs from execution")
         error: str = SchemaField(description="Error message if execution failed")

     def __init__(self):
         super().__init__(
             id="82b59b8e-ea10-4d57-9161-8b169b0adba6",
-            description="Execute code in a previously instantiated sandbox.",
+            description="Execute code in a previously instantiated sandbox environment.",
             categories={BlockCategory.DEVELOPER_TOOLS},
-            input_schema=ExecuteCodeStepBlock.Input,
-            output_schema=ExecuteCodeStepBlock.Output,
+            input_schema=StepExecutionBlock.Input,
+            output_schema=StepExecutionBlock.Output,
             test_credentials=TEST_CREDENTIALS,
             test_input={
                 "credentials": TEST_CREDENTIALS_INPUT,
@@ -489,43 +397,61 @@ class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):
                 "language": ProgrammingLanguage.PYTHON.value,
             },
             test_output=[
-                ("results", []),
                 ("response", "Hello World"),
                 ("stdout_logs", "Hello World\n"),
             ],
             test_mock={
-                "execute_code": lambda api_key, code, language, sandbox_id, dispose_sandbox: (  # noqa
-                    [],  # results
-                    "Hello World",  # text_output
-                    "Hello World\n",  # stdout_logs
-                    "",  # stderr_logs
-                    sandbox_id,  # sandbox_id
+                "execute_step_code": lambda sandbox_id, step_code, language, api_key: (
+                    "Hello World",
+                    "Hello World\n",
+                    "",
                 ),
             },
         )

+    async def execute_step_code(
+        self,
+        sandbox_id: str,
+        code: str,
+        language: ProgrammingLanguage,
+        api_key: str,
+    ):
+        try:
+            sandbox = await AsyncSandbox.connect(sandbox_id=sandbox_id, api_key=api_key)
+            if not sandbox:
+                raise Exception("Sandbox not found")
+
+            # Executing the code
+            execution = await sandbox.run_code(code, language=language.value)
+
+            if execution.error:
+                raise Exception(execution.error)
+
+            response = execution.text
+            stdout_logs = "".join(execution.logs.stdout)
+            stderr_logs = "".join(execution.logs.stderr)
+
+            return response, stdout_logs, stderr_logs
+
+        except Exception as e:
+            raise e
+
     async def run(
         self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
     ) -> BlockOutput:
         try:
-            results, text_output, stdout, stderr, _ = await self.execute_code(
-                api_key=credentials.api_key.get_secret_value(),
-                code=input_data.step_code,
-                language=input_data.language,
-                sandbox_id=input_data.sandbox_id,
-                dispose_sandbox=input_data.dispose_sandbox,
+            response, stdout_logs, stderr_logs = await self.execute_step_code(
+                input_data.sandbox_id,
+                input_data.step_code,
+                input_data.language,
+                credentials.api_key.get_secret_value(),
             )

-            # Determine result object shape & filter out empty formats
-            main_result, results = self.process_execution_results(results)
-            if main_result:
-                yield "main_result", main_result
-            yield "results", results
-            if text_output:
-                yield "response", text_output
-            if stdout:
-                yield "stdout_logs", stdout
-            if stderr:
-                yield "stderr_logs", stderr
+            if response:
+                yield "response", response
+            if stdout_logs:
+                yield "stdout_logs", stdout_logs
+            if stderr_logs:
+                yield "stderr_logs", stderr_logs
         except Exception as e:
             yield "error", str(e)
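Both the removed mixin and the per-class methods that replace it drive the same E2B sandbox lifecycle. A condensed sketch of that lifecycle, using only calls that appear in the hunks above (a real E2B API key is assumed):

```python
import asyncio

from e2b_code_interpreter import AsyncSandbox


async def run_once(api_key: str) -> str:
    sandbox = await AsyncSandbox.create(api_key=api_key, timeout=300)
    try:
        execution = await sandbox.run_code("print('hi')", language="python")
        if execution.error:
            raise Exception(execution.error)
        return "".join(execution.logs.stdout)  # -> "hi\n"
    finally:
        # Prompt disposal is what the removed dispose_sandbox flag controlled.
        await sandbox.kill()


# asyncio.run(run_once("e2b_..."))  # requires a valid key to actually run
```

Reconnecting to a live sandbox by its `sandbox_id` via `AsyncSandbox.connect(...)` is the step-execution variant that both versions of the code implement.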
@@ -1,23 +1,17 @@
 import re

-from backend.data.block import (
-    Block,
-    BlockCategory,
-    BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
-)
+from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
 from backend.data.model import SchemaField


 class CodeExtractionBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         text: str = SchemaField(
             description="Text containing code blocks to extract (e.g., AI response)",
             placeholder="Enter text containing code blocks",
         )

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         html: str = SchemaField(description="Extracted HTML code")
         css: str = SchemaField(description="Extracted CSS code")
         javascript: str = SchemaField(description="Extracted JavaScript code")
@@ -96,7 +90,7 @@ class CodeExtractionBlock(Block):
             for aliases in language_aliases.values()
             for alias in aliases
         )
-        + r")[ \t]*\n[\s\S]*?```"
+        + r")\s+[\s\S]*?```"
     )

     remaining_text = re.sub(pattern, "", input_data.text).strip()
@@ -109,9 +103,7 @@ class CodeExtractionBlock(Block):
         # Escape special regex characters in the language string
         language = re.escape(language)
         # Extract all code blocks enclosed in ```language``` blocks
-        pattern = re.compile(
-            rf"```{language}[ \t]*\n(.*?)\n```", re.DOTALL | re.IGNORECASE
-        )
+        pattern = re.compile(rf"```{language}\s+(.*?)```", re.DOTALL | re.IGNORECASE)
         matches = pattern.finditer(text)
         # Combine all code blocks for this language with newlines between them
         code_blocks = [match.group(1).strip() for match in matches]
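The regex change above swaps the strict `[ \t]*\n` separator for `\s+` between the language tag and the code body. A quick, runnable self-check of the new pattern (the language alternation is trimmed to a single hard-coded entry here for brevity):

```python
import re

# Literal triple backtick, built programmatically so this example does not
# break the surrounding markdown fence.
FENCE = "`" * 3

text = f"intro\n{FENCE}python\nprint('hi')\n{FENCE}\noutro"
pattern = re.compile(rf"{FENCE}python\s+(.*?){FENCE}", re.DOTALL | re.IGNORECASE)
print([m.group(1).strip() for m in pattern.finditer(text)])
# -> ["print('hi')"]
```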
@@ -5,8 +5,7 @@ from backend.data.block import (
     BlockCategory,
     BlockManualWebhookConfig,
     BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
+    BlockSchema,
 )
 from backend.data.model import SchemaField
 from backend.integrations.providers import ProviderName
@@ -28,10 +27,10 @@ class TranscriptionDataModel(BaseModel):


 class CompassAITriggerBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         payload: TranscriptionDataModel = SchemaField(hidden=True)

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         transcription: str = SchemaField(
             description="The contents of the compass transcription."
         )
@@ -1,22 +1,16 @@
-from backend.data.block import (
-    Block,
-    BlockCategory,
-    BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
-)
+from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
 from backend.data.model import SchemaField


 class WordCharacterCountBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         text: str = SchemaField(
             description="Input text to count words and characters",
             placeholder="Enter your text here",
             advanced=False,
         )

-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         word_count: int = SchemaField(description="Number of words in the input text")
         character_count: int = SchemaField(
             description="Number of characters in the input text"
@@ -1,12 +1,6 @@
from typing import Any, List

from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.json import loads
from backend.util.mock import MockObject
@@ -18,13 +12,13 @@ from backend.util.prompt import estimate_token_count_str


class CreateDictionaryBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        values: dict[str, Any] = SchemaField(
            description="Key-value pairs to create the dictionary with",
            placeholder="e.g., {'name': 'Alice', 'age': 25}",
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        dictionary: dict[str, Any] = SchemaField(
            description="The created dictionary containing the specified key-value pairs"
        )
@@ -68,11 +62,10 @@ class CreateDictionaryBlock(Block):


class AddToDictionaryBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        dictionary: dict[Any, Any] = SchemaField(
            default_factory=dict,
            description="The dictionary to add the entry to. If not provided, a new dictionary will be created.",
            advanced=False,
        )
        key: str = SchemaField(
            default="",
@@ -92,10 +85,11 @@ class AddToDictionaryBlock(Block):
            advanced=True,
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        updated_dictionary: dict = SchemaField(
            description="The dictionary with the new entry added."
        )
        error: str = SchemaField(description="Error message if the operation failed.")

    def __init__(self):
        super().__init__(
@@ -146,11 +140,11 @@ class AddToDictionaryBlock(Block):


class FindInDictionaryBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        input: Any = SchemaField(description="Dictionary to lookup from")
        key: str | int = SchemaField(description="Key to lookup in the dictionary")

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        output: Any = SchemaField(description="Value found for the given key")
        missing: Any = SchemaField(
            description="Value of the input that missing the key"
@@ -206,7 +200,7 @@ class FindInDictionaryBlock(Block):


class RemoveFromDictionaryBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        dictionary: dict[Any, Any] = SchemaField(
            description="The dictionary to modify."
        )
@@ -215,11 +209,12 @@ class RemoveFromDictionaryBlock(Block):
            default=False, description="Whether to return the removed value."
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        updated_dictionary: dict[Any, Any] = SchemaField(
            description="The dictionary after removal."
        )
        removed_value: Any = SchemaField(description="The removed value if requested.")
        error: str = SchemaField(description="Error message if the operation failed.")

    def __init__(self):
        super().__init__(
@@ -255,18 +250,19 @@ class RemoveFromDictionaryBlock(Block):


class ReplaceDictionaryValueBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        dictionary: dict[Any, Any] = SchemaField(
            description="The dictionary to modify."
        )
        key: str | int = SchemaField(description="Key to replace the value for.")
        value: Any = SchemaField(description="The new value for the given key.")

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        updated_dictionary: dict[Any, Any] = SchemaField(
            description="The dictionary after replacement."
        )
        old_value: Any = SchemaField(description="The value that was replaced.")
        error: str = SchemaField(description="Error message if the operation failed.")

    def __init__(self):
        super().__init__(
@@ -303,10 +299,10 @@ class ReplaceDictionaryValueBlock(Block):


class DictionaryIsEmptyBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        dictionary: dict[Any, Any] = SchemaField(description="The dictionary to check.")

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        is_empty: bool = SchemaField(description="True if the dictionary is empty.")

    def __init__(self):
@@ -330,7 +326,7 @@ class DictionaryIsEmptyBlock(Block):


class CreateListBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        values: List[Any] = SchemaField(
            description="A list of values to be combined into a new list.",
            placeholder="e.g., ['Alice', 25, True]",
@@ -346,10 +342,11 @@ class CreateListBlock(Block):
            advanced=True,
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        list: List[Any] = SchemaField(
            description="The created list containing the specified values."
        )
        error: str = SchemaField(description="Error message if list creation failed.")

    def __init__(self):
        super().__init__(
@@ -406,7 +403,7 @@ class CreateListBlock(Block):


class AddToListBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        list: List[Any] = SchemaField(
            default_factory=list,
            advanced=False,
@@ -427,10 +424,11 @@ class AddToListBlock(Block):
            description="The position to insert the new entry. If not provided, the entry will be appended to the end of the list.",
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        updated_list: List[Any] = SchemaField(
            description="The list with the new entry added."
        )
        error: str = SchemaField(description="Error message if the operation failed.")

    def __init__(self):
        super().__init__(
@@ -485,11 +483,11 @@ class AddToListBlock(Block):


class FindInListBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        list: List[Any] = SchemaField(description="The list to search in.")
        value: Any = SchemaField(description="The value to search for.")

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        index: int = SchemaField(description="The index of the value in the list.")
        found: bool = SchemaField(
            description="Whether the value was found in the list."
@@ -527,14 +525,15 @@ class FindInListBlock(Block):


class GetListItemBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        list: List[Any] = SchemaField(description="The list to get the item from.")
        index: int = SchemaField(
            description="The 0-based index of the item (supports negative indices)."
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        item: Any = SchemaField(description="The item at the specified index.")
        error: str = SchemaField(description="Error message if the operation failed.")

    def __init__(self):
        super().__init__(
@@ -561,7 +560,7 @@ class GetListItemBlock(Block):


class RemoveFromListBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        list: List[Any] = SchemaField(description="The list to modify.")
        value: Any = SchemaField(
            default=None, description="Value to remove from the list."
@@ -574,9 +573,10 @@ class RemoveFromListBlock(Block):
            default=False, description="Whether to return the removed item."
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        updated_list: List[Any] = SchemaField(description="The list after removal.")
        removed_item: Any = SchemaField(description="The removed item if requested.")
        error: str = SchemaField(description="Error message if the operation failed.")

    def __init__(self):
        super().__init__(
@@ -617,16 +617,17 @@ class RemoveFromListBlock(Block):


class ReplaceListItemBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        list: List[Any] = SchemaField(description="The list to modify.")
        index: int = SchemaField(
            description="Index of the item to replace (supports negative indices)."
        )
        value: Any = SchemaField(description="The new value for the given index.")

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        updated_list: List[Any] = SchemaField(description="The list after replacement.")
        old_item: Any = SchemaField(description="The item that was replaced.")
        error: str = SchemaField(description="Error message if the operation failed.")

    def __init__(self):
        super().__init__(
@@ -661,10 +662,10 @@ class ReplaceListItemBlock(Block):


class ListIsEmptyBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        list: List[Any] = SchemaField(description="The list to check.")

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        is_empty: bool = SchemaField(description="True if the list is empty.")

    def __init__(self):

@@ -113,7 +113,6 @@ class DataForSeoClient:
        include_serp_info: bool = False,
        include_clickstream_data: bool = False,
        limit: int = 100,
        depth: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        """
        Get related keywords from DataForSEO Labs.
@@ -126,7 +125,6 @@ class DataForSeoClient:
            include_serp_info: Include SERP data
            include_clickstream_data: Include clickstream metrics
            limit: Maximum number of results (up to 3000)
            depth: Keyword search depth (0-4), controls number of returned keywords

        Returns:
            API response with related keywords
@@ -150,8 +148,6 @@ class DataForSeoClient:
            task_data["include_clickstream_data"] = include_clickstream_data
        if limit is not None:
            task_data["limit"] = limit
        if depth is not None:
            task_data["depth"] = depth

        payload = [task_data]

@@ -8,8 +8,7 @@ from backend.sdk import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
    BlockSchema,
    CredentialsMetaInput,
    SchemaField,
    UserPasswordCredentials,
@@ -19,7 +18,7 @@ from ._api import DataForSeoClient
from ._config import dataforseo


class KeywordSuggestion(BlockSchemaInput):
class KeywordSuggestion(BlockSchema):
    """Schema for a keyword suggestion result."""

    keyword: str = SchemaField(description="The keyword suggestion")
@@ -46,7 +45,7 @@ class KeywordSuggestion(BlockSchemaInput):
class DataForSeoKeywordSuggestionsBlock(Block):
    """Block for getting keyword suggestions from DataForSEO Labs."""

    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = dataforseo.credentials_field(
            description="DataForSEO credentials (username and password)"
        )
@@ -78,7 +77,7 @@ class DataForSeoKeywordSuggestionsBlock(Block):
            le=3000,
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        suggestions: List[KeywordSuggestion] = SchemaField(
            description="List of keyword suggestions with metrics"
        )
@@ -162,63 +161,54 @@ class DataForSeoKeywordSuggestionsBlock(Block):
        **kwargs,
    ) -> BlockOutput:
        """Execute the keyword suggestions query."""
        try:
            client = DataForSeoClient(credentials)
        client = DataForSeoClient(credentials)

            results = await self._fetch_keyword_suggestions(client, input_data)
        results = await self._fetch_keyword_suggestions(client, input_data)

            # Process and format the results
            suggestions = []
            if results and len(results) > 0:
                # results is a list, get the first element
                first_result = results[0] if isinstance(results, list) else results
                items = (
                    first_result.get("items", [])
                    if isinstance(first_result, dict)
                    else []
        # Process and format the results
        suggestions = []
        if results and len(results) > 0:
            # results is a list, get the first element
            first_result = results[0] if isinstance(results, list) else results
            items = (
                first_result.get("items", []) if isinstance(first_result, dict) else []
            )
                for item in items:
                    # Create the KeywordSuggestion object
                    suggestion = KeywordSuggestion(
                        keyword=item.get("keyword", ""),
                        search_volume=item.get("keyword_info", {}).get("search_volume"),
                        competition=item.get("keyword_info", {}).get("competition"),
                        cpc=item.get("keyword_info", {}).get("cpc"),
                        keyword_difficulty=item.get("keyword_properties", {}).get(
                            "keyword_difficulty"
                        ),
                        serp_info=(
                            item.get("serp_info") if input_data.include_serp_info else None
                        ),
                        clickstream_data=(
                            item.get("clickstream_keyword_info")
                            if input_data.include_clickstream_data
                            else None
                        ),
                    )
            if items is None:
                items = []
            for item in items:
                # Create the KeywordSuggestion object
                suggestion = KeywordSuggestion(
                    keyword=item.get("keyword", ""),
                    search_volume=item.get("keyword_info", {}).get("search_volume"),
                    competition=item.get("keyword_info", {}).get("competition"),
                    cpc=item.get("keyword_info", {}).get("cpc"),
                    keyword_difficulty=item.get("keyword_properties", {}).get(
                        "keyword_difficulty"
                    ),
                    serp_info=(
                        item.get("serp_info")
                        if input_data.include_serp_info
                        else None
                    ),
                    clickstream_data=(
                        item.get("clickstream_keyword_info")
                        if input_data.include_clickstream_data
                        else None
                    ),
                )
                    yield "suggestion", suggestion
                    suggestions.append(suggestion)
                yield "suggestion", suggestion
                suggestions.append(suggestion)

            yield "suggestions", suggestions
            yield "total_count", len(suggestions)
            yield "seed_keyword", input_data.keyword
        except Exception as e:
            yield "error", f"Failed to fetch keyword suggestions: {str(e)}"
        yield "suggestions", suggestions
        yield "total_count", len(suggestions)
        yield "seed_keyword", input_data.keyword


class KeywordSuggestionExtractorBlock(Block):
    """Extracts individual fields from a KeywordSuggestion object."""

    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        suggestion: KeywordSuggestion = SchemaField(
            description="The keyword suggestion object to extract fields from"
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        keyword: str = SchemaField(description="The keyword suggestion")
        search_volume: Optional[int] = SchemaField(
            description="Monthly search volume", default=None

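Note: besides the schema-base rename, the `run()` hunk above drops the try/except wrapper and the `error` yield, so failures propagate to the executor instead of being emitted as an output. A minimal sketch of the two styles (`fetch` is a hypothetical stand-in for the API call, not a function from the repo):

```python
import asyncio


async def fetch_ok():
    return ["kw1", "kw2"]


async def run_with_error_output(fetch):
    # Old style: swallow the exception and surface it on an "error" output.
    try:
        results = await fetch()
        yield "results", results
    except Exception as e:
        yield "error", f"Failed to fetch keyword suggestions: {e}"


async def run_propagating(fetch):
    # New style: any exception here surfaces directly to the caller.
    results = await fetch()
    yield "results", results


async def main():
    async for name, value in run_propagating(fetch_ok):
        print(name, value)  # results ['kw1', 'kw2']

asyncio.run(main())
```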
@@ -8,8 +8,7 @@ from backend.sdk import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
    BlockSchema,
    CredentialsMetaInput,
    SchemaField,
    UserPasswordCredentials,
@@ -19,7 +18,7 @@ from ._api import DataForSeoClient
from ._config import dataforseo


class RelatedKeyword(BlockSchemaInput):
class RelatedKeyword(BlockSchema):
    """Schema for a related keyword result."""

    keyword: str = SchemaField(description="The related keyword")
@@ -46,7 +45,7 @@ class RelatedKeyword(BlockSchemaInput):
class DataForSeoRelatedKeywordsBlock(Block):
    """Block for getting related keywords from DataForSEO Labs."""

    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = dataforseo.credentials_field(
            description="DataForSEO credentials (username and password)"
        )
@@ -79,14 +78,8 @@ class DataForSeoRelatedKeywordsBlock(Block):
            ge=1,
            le=3000,
        )
        depth: int = SchemaField(
            description="Keyword search depth (0-4). Controls the number of returned keywords: 0=1 keyword, 1=~8 keywords, 2=~72 keywords, 3=~584 keywords, 4=~4680 keywords",
            default=1,
            ge=0,
            le=4,
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        related_keywords: List[RelatedKeyword] = SchemaField(
            description="List of related keywords with metrics"
        )
@@ -161,7 +154,6 @@ class DataForSeoRelatedKeywordsBlock(Block):
            include_serp_info=input_data.include_serp_info,
            include_clickstream_data=input_data.include_clickstream_data,
            limit=input_data.limit,
            depth=input_data.depth,
        )

    async def run(
@@ -172,71 +164,61 @@ class DataForSeoRelatedKeywordsBlock(Block):
        **kwargs,
    ) -> BlockOutput:
        """Execute the related keywords query."""
        try:
            client = DataForSeoClient(credentials)
        client = DataForSeoClient(credentials)

            results = await self._fetch_related_keywords(client, input_data)
        results = await self._fetch_related_keywords(client, input_data)

            # Process and format the results
            related_keywords = []
            if results and len(results) > 0:
                # results is a list, get the first element
                first_result = results[0] if isinstance(results, list) else results
                items = (
                    first_result.get("items", [])
                    if isinstance(first_result, dict)
                    else []
        # Process and format the results
        related_keywords = []
        if results and len(results) > 0:
            # results is a list, get the first element
            first_result = results[0] if isinstance(results, list) else results
            items = (
                first_result.get("items", []) if isinstance(first_result, dict) else []
            )
                for item in items:
                    # Extract keyword_data from the item
                    keyword_data = item.get("keyword_data", {})

                    # Create the RelatedKeyword object
                    keyword = RelatedKeyword(
                        keyword=keyword_data.get("keyword", ""),
                        search_volume=keyword_data.get("keyword_info", {}).get(
                            "search_volume"
                        ),
                        competition=keyword_data.get("keyword_info", {}).get("competition"),
                        cpc=keyword_data.get("keyword_info", {}).get("cpc"),
                        keyword_difficulty=keyword_data.get("keyword_properties", {}).get(
                            "keyword_difficulty"
                        ),
                        serp_info=(
                            keyword_data.get("serp_info")
                            if input_data.include_serp_info
                            else None
                        ),
                        clickstream_data=(
                            keyword_data.get("clickstream_keyword_info")
                            if input_data.include_clickstream_data
                            else None
                        ),
                    )
            # Ensure items is never None
            if items is None:
                items = []
            for item in items:
                # Extract keyword_data from the item
                keyword_data = item.get("keyword_data", {})
                    yield "related_keyword", keyword
                    related_keywords.append(keyword)

                # Create the RelatedKeyword object
                keyword = RelatedKeyword(
                    keyword=keyword_data.get("keyword", ""),
                    search_volume=keyword_data.get("keyword_info", {}).get(
                        "search_volume"
                    ),
                    competition=keyword_data.get("keyword_info", {}).get(
                        "competition"
                    ),
                    cpc=keyword_data.get("keyword_info", {}).get("cpc"),
                    keyword_difficulty=keyword_data.get(
                        "keyword_properties", {}
                    ).get("keyword_difficulty"),
                    serp_info=(
                        keyword_data.get("serp_info")
                        if input_data.include_serp_info
                        else None
                    ),
                    clickstream_data=(
                        keyword_data.get("clickstream_keyword_info")
                        if input_data.include_clickstream_data
                        else None
                    ),
                )
                yield "related_keyword", keyword
                related_keywords.append(keyword)

            yield "related_keywords", related_keywords
            yield "total_count", len(related_keywords)
            yield "seed_keyword", input_data.keyword
        except Exception as e:
            yield "error", f"Failed to fetch related keywords: {str(e)}"
        yield "related_keywords", related_keywords
        yield "total_count", len(related_keywords)
        yield "seed_keyword", input_data.keyword


class RelatedKeywordExtractorBlock(Block):
    """Extracts individual fields from a RelatedKeyword object."""

    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        related_keyword: RelatedKeyword = SchemaField(
            description="The related keyword object to extract fields from"
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        keyword: str = SchemaField(description="The related keyword")
        search_volume: Optional[int] = SchemaField(
            description="Monthly search volume", default=None

@@ -1,23 +1,17 @@
import codecs

from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class TextDecoderBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        text: str = SchemaField(
            description="A string containing escaped characters to be decoded",
            placeholder='Your entire text block with \\n and \\" escaped characters',
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        decoded_text: str = SchemaField(
            description="The decoded text with escape sequences processed"
        )

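Note: the decoder's `run` body is not part of this hunk; given the `import codecs` line above, it plausibly works roughly like this (the `unicode_escape` codec is an assumption, not confirmed by the diff):

```python
import codecs

# Input as a user would paste it: literal backslash-n and backslash-quote.
raw = 'Line one\\nLine two with a \\" quote'
print(codecs.decode(raw, "unicode_escape"))
# Line one
# Line two with a " quote
```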
@@ -4,19 +4,13 @@ import mimetypes
from pathlib import Path
from typing import Any

import aiohttp
import discord
from pydantic import SecretStr

from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import APIKeyCredentials, SchemaField
from backend.util.file import store_media_file
from backend.util.request import Requests
from backend.util.type import MediaFileType

from ._auth import (
@@ -34,10 +28,10 @@ TEST_CREDENTIALS_INPUT = TEST_BOT_CREDENTIALS_INPUT


class ReadDiscordMessagesBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: DiscordCredentials = DiscordCredentialsField()

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        message_content: str = SchemaField(
            description="The content of the message received"
        )
@@ -120,9 +114,10 @@ class ReadDiscordMessagesBlock(Block):
        if message.attachments:
            attachment = message.attachments[0]  # Process the first attachment
            if attachment.filename.endswith((".txt", ".py")):
                response = await Requests().get(attachment.url)
                file_content = response.text()
                self.output_data += f"\n\nFile from user: {attachment.filename}\nContent: {file_content}"
                async with aiohttp.ClientSession() as session:
                    async with session.get(attachment.url) as response:
                        file_content = response.text()
                        self.output_data += f"\n\nFile from user: {attachment.filename}\nContent: {file_content}"

        await client.close()

@@ -170,21 +165,21 @@ class ReadDiscordMessagesBlock(Block):


class SendDiscordMessageBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: DiscordCredentials = DiscordCredentialsField()
        message_content: str = SchemaField(
            description="The content of the message to send"
        )
        channel_name: str = SchemaField(
            description="Channel ID or channel name to send the message to"
            description="The name of the channel the message will be sent to"
        )
        server_name: str = SchemaField(
            description="Server name (only needed if using channel name)",
            advanced=True,
            description="The name of the server where the channel is located",
            advanced=True,  # Optional field for server name
            default="",
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        status: str = SchemaField(
            description="The status of the operation (e.g., 'Message sent', 'Error')"
        )
@@ -236,49 +231,25 @@ class SendDiscordMessageBlock(Block):
        @client.event
        async def on_ready():
            print(f"Logged in as {client.user}")
            channel = None
            for guild in client.guilds:
                if server_name and guild.name != server_name:
                    continue
                for channel in guild.text_channels:
                    if channel.name == channel_name:
                        # Split message into chunks if it exceeds 2000 characters
                        chunks = self.chunk_message(message_content)
                        last_message = None
                        for chunk in chunks:
                            last_message = await channel.send(chunk)
                        result["status"] = "Message sent"
                        result["message_id"] = (
                            str(last_message.id) if last_message else ""
                        )
                        result["channel_id"] = str(channel.id)
                        await client.close()
                        return

            # Try to parse as channel ID first
            try:
                channel_id = int(channel_name)
                channel = client.get_channel(channel_id)
            except ValueError:
                # Not a valid ID, will try name lookup
                pass

            # If not found by ID (or not an ID), try name lookup
            if not channel:
                for guild in client.guilds:
                    if server_name and guild.name != server_name:
                        continue
                    for ch in guild.text_channels:
                        if ch.name == channel_name:
                            channel = ch
                            break
                    if channel:
                        break

            if not channel:
                result["status"] = f"Channel not found: {channel_name}"
                await client.close()
                return

            # Type check - ensure it's a text channel that can send messages
            if not hasattr(channel, "send"):
                result["status"] = (
                    f"Channel {channel_name} cannot receive messages (not a text channel)"
                )
                await client.close()
                return

            # Split message into chunks if it exceeds 2000 characters
            chunks = self.chunk_message(message_content)
            last_message = None
            for chunk in chunks:
                last_message = await channel.send(chunk)  # type: ignore
            result["status"] = "Message sent"
            result["message_id"] = str(last_message.id) if last_message else ""
            result["channel_id"] = str(channel.id)
            result["status"] = "Channel not found"
            await client.close()

        await client.start(token)
@@ -316,7 +287,7 @@ class SendDiscordMessageBlock(Block):


class SendDiscordDMBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: DiscordCredentials = DiscordCredentialsField()
        user_id: str = SchemaField(
            description="The Discord user ID to send the DM to (e.g., '123456789012345678')"
@@ -325,7 +296,7 @@ class SendDiscordDMBlock(Block):
            description="The content of the direct message to send"
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        status: str = SchemaField(description="The status of the operation")
        message_id: str = SchemaField(description="The ID of the sent message")

@@ -405,7 +376,7 @@ class SendDiscordDMBlock(Block):


class SendDiscordEmbedBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: DiscordCredentials = DiscordCredentialsField()
        channel_identifier: str = SchemaField(
            description="Channel ID or channel name to send the embed to"
@@ -442,7 +413,7 @@ class SendDiscordEmbedBlock(Block):
            default=[],
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        status: str = SchemaField(description="Operation status")
        message_id: str = SchemaField(description="ID of the sent embed message")

@@ -592,7 +563,7 @@ class SendDiscordEmbedBlock(Block):


class SendDiscordFileBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: DiscordCredentials = DiscordCredentialsField()
        channel_identifier: str = SchemaField(
            description="Channel ID or channel name to send the file to"
@@ -613,7 +584,7 @@ class SendDiscordFileBlock(Block):
            description="Optional message to send with the file", default=""
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        status: str = SchemaField(description="Operation status")
        message_id: str = SchemaField(description="ID of the sent message")

@@ -704,15 +675,16 @@ class SendDiscordFileBlock(Block):

        elif file.startswith(("http://", "https://")):
            # URL - download the file
            response = await Requests().get(file)
            file_bytes = response.content
            async with aiohttp.ClientSession() as session:
                async with session.get(file) as response:
                    file_bytes = await response.read()

            # Try to get filename from URL if not provided
            if not filename:
                from urllib.parse import urlparse
                    # Try to get filename from URL if not provided
                    if not filename:
                        from urllib.parse import urlparse

                path = urlparse(file).path
                detected_filename = Path(path).name or "download"
                        path = urlparse(file).path
                        detected_filename = Path(path).name or "download"
        else:
            # Local file path - read from stored media file
            # This would be a path from a previous block's output
@@ -794,7 +766,7 @@ class SendDiscordFileBlock(Block):


class ReplyToDiscordMessageBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: DiscordCredentials = DiscordCredentialsField()
        channel_id: str = SchemaField(
            description="The channel ID where the message to reply to is located"
@@ -805,7 +777,7 @@ class ReplyToDiscordMessageBlock(Block):
            description="Whether to mention the original message author", default=True
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        status: str = SchemaField(description="Operation status")
        reply_id: str = SchemaField(description="ID of the reply message")

@@ -919,13 +891,13 @@ class ReplyToDiscordMessageBlock(Block):


class DiscordUserInfoBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: DiscordCredentials = DiscordCredentialsField()
        user_id: str = SchemaField(
            description="The Discord user ID to get information about"
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        user_id: str = SchemaField(
            description="The user's ID (passed through for chaining)"
        )
@@ -1036,7 +1008,7 @@ class DiscordUserInfoBlock(Block):


class DiscordChannelInfoBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: DiscordCredentials = DiscordCredentialsField()
        channel_identifier: str = SchemaField(
            description="Channel name or channel ID to look up"
@@ -1047,7 +1019,7 @@ class DiscordChannelInfoBlock(Block):
            default="",
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        channel_id: str = SchemaField(description="The channel's ID")
        channel_name: str = SchemaField(description="The channel's name")
        server_id: str = SchemaField(description="The server's ID")

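Note: the `SendDiscordMessageBlock` hunk above replaces a name-only guild loop with an ID-first lookup that falls back to a name search. A self-contained sketch of that resolution order (`Guild`/`Channel` are minimal stand-ins, not discord.py types; the real code uses `client.get_channel` for the ID path):

```python
from dataclasses import dataclass, field


@dataclass
class Channel:
    id: int
    name: str


@dataclass
class Guild:
    name: str
    text_channels: list[Channel] = field(default_factory=list)


def resolve_channel(guilds: list[Guild], identifier: str, server_name: str = ""):
    # First, try the identifier as a numeric channel ID.
    try:
        wanted_id = int(identifier)
        for guild in guilds:
            for ch in guild.text_channels:
                if ch.id == wanted_id:
                    return ch
    except ValueError:
        pass  # not numeric, fall through to name lookup
    # Then fall back to a name search, optionally scoped to one server.
    for guild in guilds:
        if server_name and guild.name != server_name:
            continue
        for ch in guild.text_channels:
            if ch.name == identifier:
                return ch
    return None


guilds = [Guild("dev", [Channel(42, "general")])]
print(resolve_channel(guilds, "general"))  # found by name
print(resolve_channel(guilds, "42"))       # found by ID
```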
@@ -2,13 +2,7 @@
Discord OAuth-based blocks.
"""

from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import OAuth2Credentials, SchemaField

from ._api import DiscordOAuthUser, get_current_user
@@ -27,12 +21,12 @@ class DiscordGetCurrentUserBlock(Block):
    This block requires Discord OAuth2 credentials (not bot tokens).
    """

    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: DiscordOAuthCredentialsInput = DiscordOAuthCredentialsField(
            ["identify"]
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        user_id: str = SchemaField(description="The authenticated user's Discord ID")
        username: str = SchemaField(description="The user's username")
        avatar_url: str = SchemaField(description="URL to the user's avatar image")

@@ -5,13 +5,7 @@ from typing import Literal

from pydantic import BaseModel, ConfigDict, SecretStr

from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import (
    CredentialsField,
    CredentialsMetaInput,
@@ -57,7 +51,7 @@ class SMTPConfig(BaseModel):


class SendEmailBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        to_email: str = SchemaField(
            description="Recipient email address", placeholder="recipient@example.com"
        )
@@ -73,7 +67,7 @@ class SendEmailBlock(Block):
        )
        credentials: SMTPCredentialsInput = SMTPCredentialsField()

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        status: str = SchemaField(description="Status of the email sending operation")
        error: str = SchemaField(
            description="Error message if the email sending failed"

@@ -8,13 +8,7 @@ which provides access to LinkedIn profile data and related information.
import logging
from typing import Optional

from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
from backend.util.type import MediaFileType

@@ -35,7 +29,7 @@ logger = logging.getLogger(__name__)
class GetLinkedinProfileBlock(Block):
    """Block to fetch LinkedIn profile data using Enrichlayer API."""

    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        """Input schema for GetLinkedinProfileBlock."""

        linkedin_url: str = SchemaField(
@@ -86,12 +80,13 @@ class GetLinkedinProfileBlock(Block):
            description="Enrichlayer API credentials"
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        """Output schema for GetLinkedinProfileBlock."""

        profile: PersonProfileResponse = SchemaField(
            description="LinkedIn profile data"
        )
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        """Initialize GetLinkedinProfileBlock."""
@@ -204,7 +199,7 @@ class GetLinkedinProfileBlock(Block):
class LinkedinPersonLookupBlock(Block):
    """Block to look up LinkedIn profiles by person's information using Enrichlayer API."""

    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        """Input schema for LinkedinPersonLookupBlock."""

        first_name: str = SchemaField(
@@ -247,12 +242,13 @@ class LinkedinPersonLookupBlock(Block):
            description="Enrichlayer API credentials"
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        """Output schema for LinkedinPersonLookupBlock."""

        lookup_result: PersonLookupResponse = SchemaField(
            description="LinkedIn profile lookup result"
        )
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        """Initialize LinkedinPersonLookupBlock."""
@@ -350,7 +346,7 @@ class LinkedinPersonLookupBlock(Block):
class LinkedinRoleLookupBlock(Block):
    """Block to look up LinkedIn profiles by role in a company using Enrichlayer API."""

    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        """Input schema for LinkedinRoleLookupBlock."""

        role: str = SchemaField(
@@ -370,12 +366,13 @@ class LinkedinRoleLookupBlock(Block):
            description="Enrichlayer API credentials"
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        """Output schema for LinkedinRoleLookupBlock."""

        role_lookup_result: RoleLookupResponse = SchemaField(
            description="LinkedIn role lookup result"
        )
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        """Initialize LinkedinRoleLookupBlock."""
@@ -452,7 +449,7 @@ class LinkedinRoleLookupBlock(Block):
class GetLinkedinProfilePictureBlock(Block):
    """Block to get LinkedIn profile pictures using Enrichlayer API."""

    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        """Input schema for GetLinkedinProfilePictureBlock."""

        linkedin_profile_url: str = SchemaField(
@@ -463,12 +460,13 @@ class GetLinkedinProfilePictureBlock(Block):
            description="Enrichlayer API credentials"
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        """Output schema for GetLinkedinProfilePictureBlock."""

        profile_picture_url: MediaFileType = SchemaField(
            description="LinkedIn profile picture URL"
        )
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        """Initialize GetLinkedinProfilePictureBlock."""

@@ -4,8 +4,7 @@ from backend.sdk import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
    BlockSchema,
    CredentialsMetaInput,
    Requests,
    SchemaField,
@@ -50,7 +49,7 @@ class CostDollars(BaseModel):


class ExaAnswerBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
@@ -70,7 +69,7 @@ class ExaAnswerBlock(Block):
            advanced=True,
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        answer: str = SchemaField(
            description="The generated answer based on search results"
        )

@@ -3,8 +3,7 @@ from backend.sdk import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
    BlockSchema,
    CredentialsMetaInput,
    Requests,
    SchemaField,
@@ -15,7 +14,7 @@ from .helpers import ContentSettings


class ExaContentsBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
@@ -28,7 +27,7 @@ class ExaContentsBlock(Block):
            advanced=True,
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        results: list = SchemaField(
            description="List of document contents", default_factory=list
        )

@@ -5,8 +5,7 @@ from backend.sdk import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
    BlockSchema,
    CredentialsMetaInput,
    Requests,
    SchemaField,
@@ -17,7 +16,7 @@ from .helpers import ContentSettings


class ExaSearchBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
@@ -64,7 +63,7 @@ class ExaSearchBlock(Block):
            advanced=True,
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        results: list = SchemaField(
            description="List of search results", default_factory=list
        )

@@ -6,8 +6,7 @@ from backend.sdk import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
    BlockSchema,
    CredentialsMetaInput,
    Requests,
    SchemaField,
@@ -18,7 +17,7 @@ from .helpers import ContentSettings


class ExaFindSimilarBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
@@ -66,7 +65,7 @@ class ExaFindSimilarBlock(Block):
            advanced=True,
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        results: list[Any] = SchemaField(
            description="List of similar documents with title, URL, published date, author, and score",
            default_factory=list,

@@ -9,8 +9,7 @@ from backend.sdk import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
    BlockSchema,
    BlockType,
    BlockWebhookConfig,
    CredentialsMetaInput,
@@ -85,7 +84,7 @@ class ExaWebsetWebhookBlock(Block):
    including creation, updates, searches, and exports.
    """

    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="Exa API credentials for webhook management"
        )
@@ -105,7 +104,7 @@ class ExaWebsetWebhookBlock(Block):
            description="Webhook payload data", default={}, hidden=True
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        event_type: str = SchemaField(description="Type of event that occurred")
        event_id: str = SchemaField(description="Unique identifier for this event")
        webset_id: str = SchemaField(description="ID of the affected webset")

@@ -31,8 +31,7 @@ from backend.sdk import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
    BlockSchema,
    CredentialsMetaInput,
    Requests,
    SchemaField,
@@ -105,7 +104,7 @@ class Webset(BaseModel):


class ExaCreateWebsetBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
@@ -220,7 +219,7 @@ class ExaCreateWebsetBlock(Block):
            advanced=True,
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        webset: Webset = SchemaField(
            description="The unique identifier for the created webset"
        )
@@ -405,7 +404,7 @@ class ExaCreateWebsetBlock(Block):


class ExaUpdateWebsetBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
@@ -418,7 +417,7 @@ class ExaUpdateWebsetBlock(Block):
            description="Key-value pairs to associate with this webset (set to null to clear)",
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        webset_id: str = SchemaField(description="The unique identifier for the webset")
        status: str = SchemaField(description="The status of the webset")
        external_id: Optional[str] = SchemaField(
@@ -476,7 +475,7 @@ class ExaUpdateWebsetBlock(Block):


class ExaListWebsetsBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
@@ -498,7 +497,7 @@ class ExaListWebsetsBlock(Block):
            advanced=True,
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        websets: list[Webset] = SchemaField(
            description="List of websets", default_factory=list
        )
@@ -551,7 +550,7 @@ class ExaListWebsetsBlock(Block):


class ExaGetWebsetBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
@@ -560,7 +559,7 @@ class ExaGetWebsetBlock(Block):
            placeholder="webset-id-or-external-id",
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        webset_id: str = SchemaField(description="The unique identifier for the webset")
        status: str = SchemaField(description="The status of the webset")
        external_id: Optional[str] = SchemaField(
@@ -638,7 +637,7 @@ class ExaGetWebsetBlock(Block):


class ExaDeleteWebsetBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
@@ -647,7 +646,7 @@ class ExaDeleteWebsetBlock(Block):
            placeholder="webset-id-or-external-id",
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        webset_id: str = SchemaField(
            description="The unique identifier for the deleted webset"
        )
@@ -696,7 +695,7 @@ class ExaDeleteWebsetBlock(Block):


class ExaCancelWebsetBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
@@ -705,7 +704,7 @@ class ExaCancelWebsetBlock(Block):
            placeholder="webset-id-or-external-id",
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        webset_id: str = SchemaField(description="The unique identifier for the webset")
        status: str = SchemaField(
            description="The status of the webset after cancellation"

@@ -10,13 +10,7 @@ from backend.blocks.fal._auth import (
    FalCredentialsField,
    FalCredentialsInput,
)
from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import ClientResponseError, Requests

@@ -30,7 +24,7 @@ class FalModel(str, Enum):


class AIVideoGeneratorBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        prompt: str = SchemaField(
            description="Description of the video to generate.",
            placeholder="A dog running in a field.",
@@ -42,7 +36,7 @@ class AIVideoGeneratorBlock(Block):
        )
        credentials: FalCredentialsInput = FalCredentialsField()

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        video_url: str = SchemaField(description="The URL of the generated video.")
        error: str = SchemaField(
            description="Error message if video generation failed."

@@ -1,12 +0,0 @@
from enum import Enum


class ScrapeFormat(Enum):
    MARKDOWN = "markdown"
    HTML = "html"
    RAW_HTML = "rawHtml"
    LINKS = "links"
    SCREENSHOT = "screenshot"
    SCREENSHOT_FULL_PAGE = "screenshot@fullPage"
    JSON = "json"
    CHANGE_TRACKING = "changeTracking"

@@ -1,28 +0,0 @@
"""Utility functions for converting between our ScrapeFormat enum and firecrawl FormatOption types."""

from typing import List

from firecrawl.v2.types import FormatOption, ScreenshotFormat

from backend.blocks.firecrawl._api import ScrapeFormat


def convert_to_format_options(
    formats: List[ScrapeFormat],
) -> List[FormatOption]:
    """Convert our ScrapeFormat enum values to firecrawl FormatOption types.

    Handles special cases like screenshot@fullPage which needs to be converted
    to a ScreenshotFormat object.
    """
    result: List[FormatOption] = []

    for format_enum in formats:
        if format_enum.value == "screenshot@fullPage":
            # Special case: convert to ScreenshotFormat with full_page=True
            result.append(ScreenshotFormat(type="screenshot", full_page=True))
        else:
            # Regular string literals
            result.append(format_enum.value)

    return result
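Note: for reference, possible usage of the helper deleted above (this assumes the module still exists at `backend.blocks.firecrawl._format_utils`; the output shape is inferred from the function body, not from running firecrawl):

```python
from backend.blocks.firecrawl._api import ScrapeFormat
from backend.blocks.firecrawl._format_utils import convert_to_format_options

options = convert_to_format_options(
    [ScrapeFormat.MARKDOWN, ScrapeFormat.SCREENSHOT_FULL_PAGE]
)
# Plain formats stay strings; the full-page screenshot becomes a typed object:
# ['markdown', ScreenshotFormat(type='screenshot', full_page=True)]
print(options)
```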
@@ -1,26 +1,35 @@
from enum import Enum
from typing import Any

from firecrawl import FirecrawlApp
from firecrawl.v2.types import ScrapeOptions
from firecrawl import FirecrawlApp, ScrapeOptions

from backend.blocks.firecrawl._api import ScrapeFormat
from backend.sdk import (
    APIKeyCredentials,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
    BlockSchema,
    CredentialsMetaInput,
    SchemaField,
)

from ._config import firecrawl
from ._format_utils import convert_to_format_options


class ScrapeFormat(Enum):
    MARKDOWN = "markdown"
    HTML = "html"
    RAW_HTML = "rawHtml"
    LINKS = "links"
    SCREENSHOT = "screenshot"
    SCREENSHOT_FULL_PAGE = "screenshot@fullPage"
    JSON = "json"
    CHANGE_TRACKING = "changeTracking"


class FirecrawlCrawlBlock(Block):
    class Input(BlockSchemaInput):

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = firecrawl.credentials_field()
        url: str = SchemaField(description="The URL to crawl")
        limit: int = SchemaField(description="The number of pages to crawl", default=10)
@@ -40,7 +49,7 @@ class FirecrawlCrawlBlock(Block):
            description="The format of the crawl", default=[ScrapeFormat.MARKDOWN]
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        data: list[dict[str, Any]] = SchemaField(description="The result of the crawl")
        markdown: str = SchemaField(description="The markdown of the crawl")
        html: str = SchemaField(description="The html of the crawl")
@@ -56,10 +65,6 @@ class FirecrawlCrawlBlock(Block):
        change_tracking: dict[str, Any] = SchemaField(
            description="The change tracking of the crawl"
        )
        error: str = SchemaField(
            description="Error message if the crawl failed",
            default="",
        )

    def __init__(self):
        super().__init__(
@@ -73,17 +78,18 @@ class FirecrawlCrawlBlock(Block):
    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:

        app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())

        # Sync call
        crawl_result = app.crawl(
        crawl_result = app.crawl_url(
            input_data.url,
            limit=input_data.limit,
            scrape_options=ScrapeOptions(
                formats=convert_to_format_options(input_data.formats),
                only_main_content=input_data.only_main_content,
                max_age=input_data.max_age,
                wait_for=input_data.wait_for,
                formats=[format.value for format in input_data.formats],
                onlyMainContent=input_data.only_main_content,
                maxAge=input_data.max_age,
                waitFor=input_data.wait_for,
            ),
        )
        yield "data", crawl_result.data
@@ -95,7 +101,7 @@ class FirecrawlCrawlBlock(Block):
            elif f == ScrapeFormat.HTML:
                yield "html", data.html
            elif f == ScrapeFormat.RAW_HTML:
                yield "raw_html", data.raw_html
                yield "raw_html", data.rawHtml
            elif f == ScrapeFormat.LINKS:
                yield "links", data.links
            elif f == ScrapeFormat.SCREENSHOT:
@@ -103,6 +109,6 @@ class FirecrawlCrawlBlock(Block):
            elif f == ScrapeFormat.SCREENSHOT_FULL_PAGE:
                yield "screenshot_full_page", data.screenshot
            elif f == ScrapeFormat.CHANGE_TRACKING:
                yield "change_tracking", data.change_tracking
                yield "change_tracking", data.changeTracking
            elif f == ScrapeFormat.JSON:
                yield "json", data.json

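Note: the hunk above swaps the v2-style `app.crawl(...)` with snake_case `ScrapeOptions` fields for `app.crawl_url(...)` with camelCase fields. A minimal sketch of the call style this file lands on; the names are taken from the diff and not verified against a specific firecrawl release, and the key, URL, and result handling are illustrative:

```python
from firecrawl import FirecrawlApp, ScrapeOptions

app = FirecrawlApp(api_key="fc-...")  # hypothetical API key
result = app.crawl_url(
    "https://example.com",
    limit=10,
    scrape_options=ScrapeOptions(
        formats=["markdown", "links"],  # plain string formats
        onlyMainContent=True,           # camelCase, per the new side of the diff
    ),
)
for page in result.data:
    print(page.markdown[:200])  # first 200 chars of each page's markdown
```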
@@ -9,8 +9,7 @@ from backend.sdk import (
    BlockCost,
    BlockCostType,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
    BlockSchema,
    CredentialsMetaInput,
    SchemaField,
    cost,
@@ -21,7 +20,8 @@ from ._config import firecrawl

@cost(BlockCost(2, BlockCostType.RUN))
class FirecrawlExtractBlock(Block):
    class Input(BlockSchemaInput):

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = firecrawl.credentials_field()
        urls: list[str] = SchemaField(
            description="The URLs to crawl - at least one is required. Wildcards are supported. (/*)"
@@ -38,12 +38,8 @@ class FirecrawlExtractBlock(Block):
            default=False,
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        data: dict[str, Any] = SchemaField(description="The result of the crawl")
        error: str = SchemaField(
            description="Error message if the extraction failed",
            default="",
        )

    def __init__(self):
        super().__init__(
@@ -57,6 +53,7 @@ class FirecrawlExtractBlock(Block):
    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:

        app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())

        extract_result = app.extract(

@@ -1,5 +1,3 @@
-from typing import Any
-
 from firecrawl import FirecrawlApp
 
 from backend.sdk import (
@@ -7,8 +5,7 @@ from backend.sdk import (
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
+    BlockSchema,
     CredentialsMetaInput,
     SchemaField,
 )
@@ -17,20 +14,14 @@ from ._config import firecrawl
 
 
 class FirecrawlMapWebsiteBlock(Block):
-    class Input(BlockSchemaInput):
+
+    class Input(BlockSchema):
         credentials: CredentialsMetaInput = firecrawl.credentials_field()
 
         url: str = SchemaField(description="The website url to map")
 
-    class Output(BlockSchemaOutput):
-        links: list[str] = SchemaField(description="List of URLs found on the website")
-        results: list[dict[str, Any]] = SchemaField(
-            description="List of search results with url, title, and description"
-        )
-        error: str = SchemaField(
-            description="Error message if the map failed",
-            default="",
-        )
+    class Output(BlockSchema):
+        links: list[str] = SchemaField(description="The links of the website")
 
     def __init__(self):
         super().__init__(
@@ -44,22 +35,12 @@ class FirecrawlMapWebsiteBlock(Block):
     async def run(
         self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
     ) -> BlockOutput:
+
         app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
 
         # Sync call
-        map_result = app.map(
+        map_result = app.map_url(
             url=input_data.url,
         )
 
-        # Convert SearchResult objects to dicts
-        results_data = [
-            {
-                "url": link.url,
-                "title": link.title,
-                "description": link.description,
-            }
-            for link in map_result.links
-        ]
-
-        yield "links", [link.url for link in map_result.links]
-        yield "results", results_data
+        yield "links", map_result.links
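The map hunk swaps `app.map()` (whose links carry `url`/`title`/`description`) for the legacy `app.map_url()` (which returns plain links), and drops the richer `results` output accordingly. A hedged usage sketch of both call styles; check the pinned firecrawl-py version before relying on either:

```python
from firecrawl import FirecrawlApp

app = FirecrawlApp(api_key="fc-...")  # placeholder key

# v2-style client: link objects with url/title/description (assumed)
result = app.map(url="https://example.com")
urls = [link.url for link in result.links]

# Legacy-style client: plain list of links (assumed)
# result = app.map_url(url="https://example.com")
# urls = result.links
```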
@@ -1,25 +1,35 @@
+from enum import Enum
 from typing import Any
 
 from firecrawl import FirecrawlApp
 
-from backend.blocks.firecrawl._api import ScrapeFormat
 from backend.sdk import (
     APIKeyCredentials,
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
+    BlockSchema,
     CredentialsMetaInput,
     SchemaField,
 )
 
 from ._config import firecrawl
-from ._format_utils import convert_to_format_options
 
 
+class ScrapeFormat(Enum):
+    MARKDOWN = "markdown"
+    HTML = "html"
+    RAW_HTML = "rawHtml"
+    LINKS = "links"
+    SCREENSHOT = "screenshot"
+    SCREENSHOT_FULL_PAGE = "screenshot@fullPage"
+    JSON = "json"
+    CHANGE_TRACKING = "changeTracking"
+
+
 class FirecrawlScrapeBlock(Block):
-    class Input(BlockSchemaInput):
+
+    class Input(BlockSchema):
         credentials: CredentialsMetaInput = firecrawl.credentials_field()
         url: str = SchemaField(description="The URL to crawl")
         limit: int = SchemaField(description="The number of pages to crawl", default=10)
@@ -39,7 +49,7 @@ class FirecrawlScrapeBlock(Block):
             description="The format of the crawl", default=[ScrapeFormat.MARKDOWN]
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         data: dict[str, Any] = SchemaField(description="The result of the crawl")
         markdown: str = SchemaField(description="The markdown of the crawl")
         html: str = SchemaField(description="The html of the crawl")
@@ -55,10 +65,6 @@ class FirecrawlScrapeBlock(Block):
         change_tracking: dict[str, Any] = SchemaField(
             description="The change tracking of the crawl"
         )
-        error: str = SchemaField(
-            description="Error message if the scrape failed",
-            default="",
-        )
 
     def __init__(self):
         super().__init__(
@@ -72,11 +78,12 @@ class FirecrawlScrapeBlock(Block):
     async def run(
         self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
     ) -> BlockOutput:
+
        app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
 
-        scrape_result = app.scrape(
+        scrape_result = app.scrape_url(
             input_data.url,
-            formats=convert_to_format_options(input_data.formats),
+            formats=[format.value for format in input_data.formats],
             only_main_content=input_data.only_main_content,
             max_age=input_data.max_age,
             wait_for=input_data.wait_for,
@@ -89,7 +96,7 @@ class FirecrawlScrapeBlock(Block):
             elif f == ScrapeFormat.HTML:
                 yield "html", scrape_result.html
             elif f == ScrapeFormat.RAW_HTML:
-                yield "raw_html", scrape_result.raw_html
+                yield "raw_html", scrape_result.rawHtml
             elif f == ScrapeFormat.LINKS:
                 yield "links", scrape_result.links
             elif f == ScrapeFormat.SCREENSHOT:
@@ -97,6 +104,6 @@ class FirecrawlScrapeBlock(Block):
             elif f == ScrapeFormat.SCREENSHOT_FULL_PAGE:
                 yield "screenshot_full_page", scrape_result.screenshot
             elif f == ScrapeFormat.CHANGE_TRACKING:
-                yield "change_tracking", scrape_result.change_tracking
+                yield "change_tracking", scrape_result.changeTracking
             elif f == ScrapeFormat.JSON:
                 yield "json", scrape_result.json
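A usage sketch of the two scrape entry points the hunk toggles between. Both calls are hedged: `scrape` is the v2-style method with snake_case keywords, `scrape_url` the legacy one, and keyword spellings differ across firecrawl-py releases:

```python
from firecrawl import FirecrawlApp

app = FirecrawlApp(api_key="fc-...")  # placeholder key

# v2-style client (snake_case keywords, Document-style result):
doc = app.scrape(
    "https://example.com",
    formats=["markdown", "links"],
    only_main_content=True,
)
print(doc.markdown)

# Legacy-style client (string format names, per the new side of the hunk):
# doc = app.scrape_url("https://example.com", formats=["markdown", "links"])
```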
@@ -1,26 +1,35 @@
+from enum import Enum
 from typing import Any
 
-from firecrawl import FirecrawlApp
-from firecrawl.v2.types import ScrapeOptions
+from firecrawl import FirecrawlApp, ScrapeOptions
 
-from backend.blocks.firecrawl._api import ScrapeFormat
 from backend.sdk import (
     APIKeyCredentials,
     Block,
     BlockCategory,
     BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
+    BlockSchema,
     CredentialsMetaInput,
     SchemaField,
 )
 
 from ._config import firecrawl
-from ._format_utils import convert_to_format_options
 
 
+class ScrapeFormat(Enum):
+    MARKDOWN = "markdown"
+    HTML = "html"
+    RAW_HTML = "rawHtml"
+    LINKS = "links"
+    SCREENSHOT = "screenshot"
+    SCREENSHOT_FULL_PAGE = "screenshot@fullPage"
+    JSON = "json"
+    CHANGE_TRACKING = "changeTracking"
+
+
 class FirecrawlSearchBlock(Block):
-    class Input(BlockSchemaInput):
+
+    class Input(BlockSchema):
         credentials: CredentialsMetaInput = firecrawl.credentials_field()
         query: str = SchemaField(description="The query to search for")
         limit: int = SchemaField(description="The number of pages to crawl", default=10)
@@ -36,13 +45,9 @@ class FirecrawlSearchBlock(Block):
             description="Returns the content of the search if specified", default=[]
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         data: dict[str, Any] = SchemaField(description="The result of the search")
         site: dict[str, Any] = SchemaField(description="The site of the search")
-        error: str = SchemaField(
-            description="Error message if the search failed",
-            default="",
-        )
 
     def __init__(self):
         super().__init__(
@@ -56,6 +61,7 @@ class FirecrawlSearchBlock(Block):
     async def run(
         self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
     ) -> BlockOutput:
+
         app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
 
         # Sync call
@@ -63,12 +69,11 @@ class FirecrawlSearchBlock(Block):
             input_data.query,
             limit=input_data.limit,
             scrape_options=ScrapeOptions(
-                formats=convert_to_format_options(input_data.formats) or None,
-                max_age=input_data.max_age,
-                wait_for=input_data.wait_for,
+                formats=[format.value for format in input_data.formats],
+                maxAge=input_data.max_age,
+                waitFor=input_data.wait_for,
             ),
         )
         yield "data", scrape_result
-        if hasattr(scrape_result, "web") and scrape_result.web:
-            for site in scrape_result.web:
-                yield "site", site
+        for site in scrape_result.data:
+            yield "site", site
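The search hunk makes the same v2-to-legacy swap inside `ScrapeOptions` and changes where the per-site results live (`scrape_result.web` versus `scrape_result.data`). A hedged sketch of the v2-style call; both the import path and the `web` attribute are taken from the old side of the hunk, not verified against a specific release:

```python
from firecrawl import FirecrawlApp
from firecrawl.v2.types import ScrapeOptions  # v2 import path per the hunk

app = FirecrawlApp(api_key="fc-...")  # placeholder key

results = app.search(
    "open source agent frameworks",
    limit=10,
    scrape_options=ScrapeOptions(formats=["markdown"]),
)

# v2 groups hits under .web; the legacy client returned .data instead.
for hit in getattr(results, "web", None) or []:
    print(hit)
```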
@@ -5,13 +5,7 @@ from pydantic import SecretStr
 from replicate.client import Client as ReplicateClient
 from replicate.helpers import FileOutput
 
-from backend.data.block import (
-    Block,
-    BlockCategory,
-    BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
-)
+from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
 from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,
@@ -63,7 +57,7 @@ class AspectRatio(str, Enum):
 
 
 class AIImageEditorBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: CredentialsMetaInput[
             Literal[ProviderName.REPLICATE], Literal["api_key"]
         ] = CredentialsField(
@@ -96,10 +90,11 @@ class AIImageEditorBlock(Block):
             title="Model",
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         output_image: MediaFileType = SchemaField(
             description="URL of the transformed image"
         )
+        error: str = SchemaField(description="Error message if generation failed")
 
     def __init__(self):
         super().__init__(
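Across these files the `Output` schemas keep gaining or losing an explicit `error` field as they move between base classes, which suggests `BlockSchemaOutput` supplies a built-in error output while a plain `BlockSchema` does not. A sketch of the explicit-field variant, using the repo's own SDK names (the import path is assumed from the hunks):

```python
from backend.sdk import BlockSchema, SchemaField  # import path assumed from the hunks


class Output(BlockSchema):
    data: dict = SchemaField(description="The result payload")
    # With a plain BlockSchema base, the error output must be declared
    # explicitly; BlockSchemaOutput appears to provide it for free.
    error: str = SchemaField(
        description="Error message if the operation failed",
        default="",
    )
```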
@@ -3,8 +3,7 @@ from backend.sdk import (
     BlockCategory,
     BlockManualWebhookConfig,
     BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
+    BlockSchema,
     ProviderBuilder,
     ProviderName,
     SchemaField,
@@ -20,14 +19,14 @@ generic_webhook = (
 
 
 class GenericWebhookTriggerBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         payload: dict = SchemaField(hidden=True, default_factory=dict)
         constants: dict = SchemaField(
             description="The constants to be set when the block is put on the graph",
             default_factory=dict,
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         payload: dict = SchemaField(
             description="The complete webhook payload that was received from the generic webhook."
         )
@@ -3,13 +3,7 @@ from typing import Optional
 
 from pydantic import BaseModel
 
-from backend.data.block import (
-    Block,
-    BlockCategory,
-    BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
-)
+from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
 from backend.data.model import SchemaField
 
 from ._api import get_api
@@ -45,7 +39,7 @@ class ChecksConclusion(Enum):
 class GithubCreateCheckRunBlock(Block):
     """Block for creating a new check run on a GitHub repository."""
 
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo:status")
         repo_url: str = SchemaField(
             description="URL of the GitHub repository",
@@ -82,7 +76,7 @@ class GithubCreateCheckRunBlock(Block):
             default="",
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         class CheckRunResult(BaseModel):
             id: int
             html_url: str
@@ -217,7 +211,7 @@ class GithubCreateCheckRunBlock(Block):
 class GithubUpdateCheckRunBlock(Block):
     """Block for updating an existing check run on a GitHub repository."""
 
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo:status")
         repo_url: str = SchemaField(
             description="URL of the GitHub repository",
@@ -245,7 +239,7 @@ class GithubUpdateCheckRunBlock(Block):
             default=None,
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         class CheckRunResult(BaseModel):
             id: int
             html_url: str
@@ -255,6 +249,7 @@ class GithubUpdateCheckRunBlock(Block):
         check_run: CheckRunResult = SchemaField(
             description="Details of the updated check run"
         )
+        error: str = SchemaField(description="Error message if check run update failed")
 
     def __init__(self):
         super().__init__(
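The check-run blocks nest a pydantic `BaseModel` inside `Output` to give the `check_run` pin a typed shape. A standalone sketch of that pattern (field names mirror the hunks; this is an illustration, not the repo's class):

```python
from pydantic import BaseModel


class CheckRunResult(BaseModel):
    id: int
    html_url: str
    status: str


run = CheckRunResult(
    id=1,
    html_url="https://github.com/owner/repo/runs/1",  # placeholder URL
    status="completed",
)
print(run.model_dump())  # {'id': 1, 'html_url': '...', 'status': 'completed'}
```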
@@ -5,13 +5,7 @@ from typing import Optional
 
 from typing_extensions import TypedDict
 
-from backend.data.block import (
-    Block,
-    BlockCategory,
-    BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
-)
+from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
 from backend.data.model import SchemaField
 
 from ._api import get_api
@@ -43,7 +37,7 @@ class CheckRunConclusion(Enum):
 
 
 class GithubGetCIResultsBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo: str = SchemaField(
             description="GitHub repository",
@@ -66,7 +60,7 @@ class GithubGetCIResultsBlock(Block):
             advanced=True,
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         class CheckRunItem(TypedDict, total=False):
             id: int
             name: str
@@ -110,6 +104,7 @@ class GithubGetCIResultsBlock(Block):
         total_checks: int = SchemaField(description="Total number of CI checks")
         passed_checks: int = SchemaField(description="Number of passed checks")
         failed_checks: int = SchemaField(description="Number of failed checks")
+        error: str = SchemaField(description="Error message if the operation failed")
 
     def __init__(self):
         super().__init__(
@@ -3,13 +3,7 @@ from urllib.parse import urlparse
 
 from typing_extensions import TypedDict
 
-from backend.data.block import (
-    Block,
-    BlockCategory,
-    BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
-)
+from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
 from backend.data.model import SchemaField
 
 from ._api import convert_comment_url_to_api_endpoint, get_api
@@ -30,7 +24,7 @@ def is_github_url(url: str) -> bool:
 
 # --8<-- [start:GithubCommentBlockExample]
 class GithubCommentBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         issue_url: str = SchemaField(
             description="URL of the GitHub issue or pull request",
@@ -41,7 +35,7 @@ class GithubCommentBlock(Block):
             placeholder="Enter your comment",
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         id: int = SchemaField(description="ID of the created comment")
         url: str = SchemaField(description="URL to the comment on GitHub")
         error: str = SchemaField(
@@ -118,7 +112,7 @@ class GithubCommentBlock(Block):
 
 
 class GithubUpdateCommentBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         comment_url: str = SchemaField(
             description="URL of the GitHub comment",
@@ -141,7 +135,7 @@ class GithubUpdateCommentBlock(Block):
             placeholder="Enter your comment",
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         id: int = SchemaField(description="ID of the updated comment")
         url: str = SchemaField(description="URL to the comment on GitHub")
         error: str = SchemaField(
@@ -225,14 +219,14 @@ class GithubUpdateCommentBlock(Block):
 
 
 class GithubListCommentsBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         issue_url: str = SchemaField(
             description="URL of the GitHub issue or pull request",
             placeholder="https://github.com/owner/repo/issues/1",
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         class CommentItem(TypedDict):
             id: int
             body: str
@@ -245,6 +239,7 @@ class GithubListCommentsBlock(Block):
         comments: list[CommentItem] = SchemaField(
             description="List of comments with their ID, body, user, and URL"
         )
+        error: str = SchemaField(description="Error message if listing comments failed")
 
     def __init__(self):
         super().__init__(
@@ -340,7 +335,7 @@ class GithubListCommentsBlock(Block):
 
 
 class GithubMakeIssueBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo_url: str = SchemaField(
             description="URL of the GitHub repository",
@@ -353,7 +348,7 @@ class GithubMakeIssueBlock(Block):
             description="Body of the issue", placeholder="Enter the issue body"
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         number: int = SchemaField(description="Number of the created issue")
         url: str = SchemaField(description="URL of the created issue")
         error: str = SchemaField(
@@ -415,14 +410,14 @@ class GithubMakeIssueBlock(Block):
 
 
 class GithubReadIssueBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         issue_url: str = SchemaField(
             description="URL of the GitHub issue",
             placeholder="https://github.com/owner/repo/issues/1",
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         title: str = SchemaField(description="Title of the issue")
         body: str = SchemaField(description="Body of the issue")
         user: str = SchemaField(description="User who created the issue")
@@ -488,14 +483,14 @@ class GithubReadIssueBlock(Block):
 
 
 class GithubListIssuesBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo_url: str = SchemaField(
             description="URL of the GitHub repository",
             placeholder="https://github.com/owner/repo",
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         class IssueItem(TypedDict):
             title: str
             url: str
@@ -506,6 +501,7 @@ class GithubListIssuesBlock(Block):
         issues: list[IssueItem] = SchemaField(
             description="List of issues with their title and URL"
         )
+        error: str = SchemaField(description="Error message if listing issues failed")
 
     def __init__(self):
         super().__init__(
@@ -577,7 +573,7 @@ class GithubListIssuesBlock(Block):
 
 
 class GithubAddLabelBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         issue_url: str = SchemaField(
             description="URL of the GitHub issue or pull request",
@@ -588,7 +584,7 @@ class GithubAddLabelBlock(Block):
             placeholder="Enter the label",
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         status: str = SchemaField(description="Status of the label addition operation")
         error: str = SchemaField(
             description="Error message if the label addition failed"
@@ -637,7 +633,7 @@ class GithubAddLabelBlock(Block):
 
 
 class GithubRemoveLabelBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         issue_url: str = SchemaField(
             description="URL of the GitHub issue or pull request",
@@ -648,7 +644,7 @@ class GithubRemoveLabelBlock(Block):
             placeholder="Enter the label",
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         status: str = SchemaField(description="Status of the label removal operation")
         error: str = SchemaField(
             description="Error message if the label removal failed"
@@ -698,7 +694,7 @@ class GithubRemoveLabelBlock(Block):
 
 
 class GithubAssignIssueBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         issue_url: str = SchemaField(
             description="URL of the GitHub issue",
@@ -709,7 +705,7 @@ class GithubAssignIssueBlock(Block):
             placeholder="Enter the username",
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         status: str = SchemaField(
             description="Status of the issue assignment operation"
         )
@@ -764,7 +760,7 @@ class GithubAssignIssueBlock(Block):
 
 
 class GithubUnassignIssueBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         issue_url: str = SchemaField(
             description="URL of the GitHub issue",
@@ -775,7 +771,7 @@ class GithubUnassignIssueBlock(Block):
             placeholder="Enter the username",
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         status: str = SchemaField(
             description="Status of the issue unassignment operation"
         )
@@ -2,13 +2,7 @@ import re
 
 from typing_extensions import TypedDict
 
-from backend.data.block import (
-    Block,
-    BlockCategory,
-    BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
-)
+from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
 from backend.data.model import SchemaField
 
 from ._api import get_api
@@ -22,14 +16,14 @@ from ._auth import (
 
 
 class GithubListPullRequestsBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo_url: str = SchemaField(
             description="URL of the GitHub repository",
             placeholder="https://github.com/owner/repo",
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         class PRItem(TypedDict):
             title: str
             url: str
@@ -114,7 +108,7 @@ class GithubListPullRequestsBlock(Block):
 
 
 class GithubMakePullRequestBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo_url: str = SchemaField(
             description="URL of the GitHub repository",
@@ -141,7 +135,7 @@ class GithubMakePullRequestBlock(Block):
             placeholder="Enter the base branch",
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         number: int = SchemaField(description="Number of the created pull request")
         url: str = SchemaField(description="URL of the created pull request")
         error: str = SchemaField(
@@ -215,7 +209,7 @@ class GithubMakePullRequestBlock(Block):
 
 
 class GithubReadPullRequestBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         pr_url: str = SchemaField(
             description="URL of the GitHub pull request",
@@ -227,7 +221,7 @@ class GithubReadPullRequestBlock(Block):
             advanced=False,
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         title: str = SchemaField(description="Title of the pull request")
         body: str = SchemaField(description="Body of the pull request")
         author: str = SchemaField(description="User who created the pull request")
@@ -331,7 +325,7 @@ class GithubReadPullRequestBlock(Block):
 
 
 class GithubAssignPRReviewerBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         pr_url: str = SchemaField(
             description="URL of the GitHub pull request",
@@ -342,7 +336,7 @@ class GithubAssignPRReviewerBlock(Block):
             placeholder="Enter the reviewer's username",
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         status: str = SchemaField(
             description="Status of the reviewer assignment operation"
         )
@@ -398,7 +392,7 @@ class GithubAssignPRReviewerBlock(Block):
 
 
 class GithubUnassignPRReviewerBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         pr_url: str = SchemaField(
             description="URL of the GitHub pull request",
@@ -409,7 +403,7 @@ class GithubUnassignPRReviewerBlock(Block):
             placeholder="Enter the reviewer's username",
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         status: str = SchemaField(
             description="Status of the reviewer unassignment operation"
         )
@@ -465,14 +459,14 @@ class GithubUnassignPRReviewerBlock(Block):
 
 
 class GithubListPRReviewersBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         pr_url: str = SchemaField(
             description="URL of the GitHub pull request",
             placeholder="https://github.com/owner/repo/pull/1",
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         class ReviewerItem(TypedDict):
             username: str
             url: str
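The list-style blocks above all share one shape: a `TypedDict` item class nested in `Output` plus a `list[...]` field holding them. A minimal standalone sketch of that pattern (names taken from the hunks, not the repo code):

```python
from typing_extensions import TypedDict


class PRItem(TypedDict):
    title: str
    url: str


prs: list[PRItem] = [
    {"title": "Fix flaky test", "url": "https://github.com/owner/repo/pull/1"},
]
for pr in prs:
    print(pr["title"], "->", pr["url"])
```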
@@ -2,13 +2,7 @@ import base64
 
 from typing_extensions import TypedDict
 
-from backend.data.block import (
-    Block,
-    BlockCategory,
-    BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
-)
+from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
 from backend.data.model import SchemaField
 
 from ._api import get_api
@@ -22,14 +16,14 @@ from ._auth import (
 
 
 class GithubListTagsBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo_url: str = SchemaField(
             description="URL of the GitHub repository",
             placeholder="https://github.com/owner/repo",
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         class TagItem(TypedDict):
             name: str
             url: str
@@ -40,6 +34,7 @@ class GithubListTagsBlock(Block):
         tags: list[TagItem] = SchemaField(
             description="List of tags with their name and file tree browser URL"
        )
+        error: str = SchemaField(description="Error message if listing tags failed")
 
     def __init__(self):
         super().__init__(
@@ -116,14 +111,14 @@ class GithubListTagsBlock(Block):
 
 
 class GithubListBranchesBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo_url: str = SchemaField(
             description="URL of the GitHub repository",
             placeholder="https://github.com/owner/repo",
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         class BranchItem(TypedDict):
             name: str
             url: str
@@ -135,6 +130,7 @@ class GithubListBranchesBlock(Block):
         branches: list[BranchItem] = SchemaField(
             description="List of branches with their name and file tree browser URL"
         )
+        error: str = SchemaField(description="Error message if listing branches failed")
 
     def __init__(self):
         super().__init__(
@@ -211,7 +207,7 @@ class GithubListBranchesBlock(Block):
 
 
 class GithubListDiscussionsBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo_url: str = SchemaField(
             description="URL of the GitHub repository",
@@ -221,7 +217,7 @@ class GithubListDiscussionsBlock(Block):
             description="Number of discussions to fetch", default=5
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         class DiscussionItem(TypedDict):
             title: str
             url: str
@@ -327,14 +323,14 @@ class GithubListDiscussionsBlock(Block):
 
 
 class GithubListReleasesBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo_url: str = SchemaField(
             description="URL of the GitHub repository",
             placeholder="https://github.com/owner/repo",
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         class ReleaseItem(TypedDict):
             name: str
             url: str
@@ -346,6 +342,7 @@ class GithubListReleasesBlock(Block):
         releases: list[ReleaseItem] = SchemaField(
             description="List of releases with their name and file tree browser URL"
         )
+        error: str = SchemaField(description="Error message if listing releases failed")
 
     def __init__(self):
         super().__init__(
@@ -417,7 +414,7 @@ class GithubListReleasesBlock(Block):
 
 
 class GithubReadFileBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo_url: str = SchemaField(
             description="URL of the GitHub repository",
@@ -433,7 +430,7 @@ class GithubReadFileBlock(Block):
             default="master",
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         text_content: str = SchemaField(
             description="Content of the file (decoded as UTF-8 text)"
         )
@@ -441,6 +438,7 @@ class GithubReadFileBlock(Block):
             description="Raw base64-encoded content of the file"
         )
         size: int = SchemaField(description="The size of the file (in bytes)")
+        error: str = SchemaField(description="Error message if the file reading failed")
 
     def __init__(self):
         super().__init__(
@@ -503,7 +501,7 @@ class GithubReadFileBlock(Block):
 
 
 class GithubReadFolderBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo_url: str = SchemaField(
             description="URL of the GitHub repository",
@@ -519,7 +517,7 @@ class GithubReadFolderBlock(Block):
             default="master",
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         class DirEntry(TypedDict):
             name: str
             path: str
@@ -627,7 +625,7 @@ class GithubReadFolderBlock(Block):
 
 
 class GithubMakeBranchBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo_url: str = SchemaField(
             description="URL of the GitHub repository",
@@ -642,7 +640,7 @@ class GithubMakeBranchBlock(Block):
             placeholder="source_branch_name",
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         status: str = SchemaField(description="Status of the branch creation operation")
         error: str = SchemaField(
             description="Error message if the branch creation failed"
@@ -707,7 +705,7 @@ class GithubMakeBranchBlock(Block):
 
 
 class GithubDeleteBranchBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo_url: str = SchemaField(
             description="URL of the GitHub repository",
@@ -718,7 +716,7 @@ class GithubDeleteBranchBlock(Block):
             placeholder="branch_name",
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         status: str = SchemaField(description="Status of the branch deletion operation")
         error: str = SchemaField(
             description="Error message if the branch deletion failed"
@@ -768,7 +766,7 @@ class GithubDeleteBranchBlock(Block):
 
 
 class GithubCreateFileBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo_url: str = SchemaField(
             description="URL of the GitHub repository",
@@ -791,7 +789,7 @@ class GithubCreateFileBlock(Block):
             default="Create new file",
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         url: str = SchemaField(description="URL of the created file")
         sha: str = SchemaField(description="SHA of the commit")
         error: str = SchemaField(
@@ -870,7 +868,7 @@ class GithubCreateFileBlock(Block):
 
 
 class GithubUpdateFileBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo_url: str = SchemaField(
             description="URL of the GitHub repository",
@@ -893,9 +891,10 @@ class GithubUpdateFileBlock(Block):
             default="Update file",
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         url: str = SchemaField(description="URL of the updated file")
         sha: str = SchemaField(description="SHA of the commit")
+        error: str = SchemaField(description="Error message if the file update failed")
 
     def __init__(self):
         super().__init__(
@@ -975,7 +974,7 @@ class GithubUpdateFileBlock(Block):
 
 
 class GithubCreateRepositoryBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         name: str = SchemaField(
             description="Name of the repository to create",
@@ -999,7 +998,7 @@ class GithubCreateRepositoryBlock(Block):
             default="",
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         url: str = SchemaField(description="URL of the created repository")
         clone_url: str = SchemaField(description="Git clone URL of the repository")
         error: str = SchemaField(
@@ -1078,14 +1077,14 @@ class GithubCreateRepositoryBlock(Block):
 
 
 class GithubListStargazersBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo_url: str = SchemaField(
             description="URL of the GitHub repository",
             placeholder="https://github.com/owner/repo",
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         class StargazerItem(TypedDict):
             username: str
             url: str
@@ -4,13 +4,7 @@ from typing import Any, List, Optional
 
 from typing_extensions import TypedDict
 
-from backend.data.block import (
-    Block,
-    BlockCategory,
-    BlockOutput,
-    BlockSchemaInput,
-    BlockSchemaOutput,
-)
+from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
 from backend.data.model import SchemaField
 
 from ._api import get_api
@@ -32,7 +26,7 @@ class ReviewEvent(Enum):
 
 
 class GithubCreatePRReviewBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         class ReviewComment(TypedDict, total=False):
             path: str
             position: Optional[int]
@@ -67,7 +61,7 @@ class GithubCreatePRReviewBlock(Block):
             advanced=True,
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         review_id: int = SchemaField(description="ID of the created review")
         state: str = SchemaField(
             description="State of the review (e.g., PENDING, COMMENTED, APPROVED, CHANGES_REQUESTED)"
@@ -203,7 +197,7 @@ class GithubCreatePRReviewBlock(Block):
 
 
 class GithubListPRReviewsBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo: str = SchemaField(
             description="GitHub repository",
@@ -214,7 +208,7 @@ class GithubListPRReviewsBlock(Block):
             placeholder="123",
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         class ReviewItem(TypedDict):
             id: int
             user: str
@@ -229,6 +223,7 @@ class GithubListPRReviewsBlock(Block):
         reviews: list[ReviewItem] = SchemaField(
             description="List of all reviews on the pull request"
         )
+        error: str = SchemaField(description="Error message if listing reviews failed")
 
     def __init__(self):
         super().__init__(
@@ -322,7 +317,7 @@ class GithubListPRReviewsBlock(Block):
 
 
 class GithubSubmitPendingReviewBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo: str = SchemaField(
             description="GitHub repository",
@@ -341,7 +336,7 @@ class GithubSubmitPendingReviewBlock(Block):
             default=ReviewEvent.COMMENT,
        )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         state: str = SchemaField(description="State of the submitted review")
         html_url: str = SchemaField(description="URL of the submitted review")
         error: str = SchemaField(
@@ -420,7 +415,7 @@ class GithubSubmitPendingReviewBlock(Block):
 
 
 class GithubResolveReviewDiscussionBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo: str = SchemaField(
             description="GitHub repository",
@@ -439,8 +434,9 @@ class GithubResolveReviewDiscussionBlock(Block):
             default=True,
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         success: bool = SchemaField(description="Whether the operation was successful")
+        error: str = SchemaField(description="Error message if the operation failed")
 
     def __init__(self):
         super().__init__(
@@ -583,7 +579,7 @@ class GithubResolveReviewDiscussionBlock(Block):
 
 
 class GithubGetPRReviewCommentsBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")
         repo: str = SchemaField(
             description="GitHub repository",
@@ -600,7 +596,7 @@ class GithubGetPRReviewCommentsBlock(Block):
             advanced=True,
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         class CommentItem(TypedDict):
             id: int
             user: str
@@ -620,6 +616,7 @@ class GithubGetPRReviewCommentsBlock(Block):
         comments: list[CommentItem] = SchemaField(
             description="List of all review comments on the pull request"
        )
+        error: str = SchemaField(description="Error message if getting comments failed")
 
     def __init__(self):
         super().__init__(
@@ -747,7 +744,7 @@ class GithubGetPRReviewCommentsBlock(Block):
 
 
 class GithubCreateCommentObjectBlock(Block):
-    class Input(BlockSchemaInput):
+    class Input(BlockSchema):
         path: str = SchemaField(
             description="The file path to comment on",
             placeholder="src/main.py",
@@ -784,7 +781,7 @@ class GithubCreateCommentObjectBlock(Block):
             advanced=True,
         )
 
-    class Output(BlockSchemaOutput):
+    class Output(BlockSchema):
         comment_object: dict = SchemaField(
             description="The comment object formatted for GitHub API"
         )
Some files were not shown because too many files have changed in this diff.