mirror of https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-12 16:48:06 -05:00

Compare commits
5 Commits: worktree/d... ... feat/execu...

| Author | SHA1 | Date |
|---|---|---|
| | 4f652cb978 | |
| | 279552a2a3 | |
| | fb6ac1d6ca | |
| | 9db15bff02 | |
| | db4b94e0dc | |
94  .github/copilot-instructions.md (vendored)

@@ -12,7 +12,6 @@ This file provides comprehensive onboarding information for GitHub Copilot codin
- **Infrastructure** - Docker configurations, CI/CD, and development tools

**Primary Languages & Frameworks:**

- **Backend**: Python 3.10-3.13, FastAPI, Prisma ORM, PostgreSQL, RabbitMQ
- **Frontend**: TypeScript, Next.js 15, React, Tailwind CSS, Radix UI
- **Development**: Docker, Poetry, pnpm, Playwright, Storybook

@@ -24,17 +23,15 @@ This file provides comprehensive onboarding information for GitHub Copilot codin
**Always run these commands in the correct directory and in this order:**

1. **Initial Setup** (required once):

```bash
# Clone and enter repository
git clone <repo> && cd AutoGPT

# Start all services (database, redis, rabbitmq, clamav)
cd autogpt_platform && docker compose --profile local up deps --build --detach
```

2. **Backend Setup** (always run before backend development):

```bash
cd autogpt_platform/backend
poetry install # Install dependencies

@@ -51,7 +48,6 @@ This file provides comprehensive onboarding information for GitHub Copilot codin
### Runtime Requirements

**Critical:** Always ensure Docker services are running before starting development:

```bash
cd autogpt_platform && docker compose --profile local up deps --build --detach
```

@@ -62,7 +58,6 @@ cd autogpt_platform && docker compose --profile local up deps --build --detach
### Development Commands

**Backend Development:**

```bash
cd autogpt_platform/backend
poetry run serve # Start development server (port 8000)

@@ -73,7 +68,6 @@ poetry run lint # Lint code (ruff) - run after format
```

**Frontend Development:**

```bash
cd autogpt_platform/frontend
pnpm dev # Start development server (port 3000) - use for active development

@@ -87,27 +81,23 @@ pnpm storybook # Start component development server
### Testing Strategy

**Backend Tests:**

- **Block Tests**: `poetry run pytest backend/blocks/test/test_block.py -xvs` (validates all blocks)
- **Specific Block**: `poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[BlockName]' -xvs`
- **Snapshot Tests**: Use `--snapshot-update` when output changes, always review with `git diff`

**Frontend Tests:**

- **E2E Tests**: Always run `pnpm dev` before `pnpm test` (Playwright requires running instance)
- **Component Tests**: Use Storybook for isolated component development

### Critical Validation Steps

**Before committing changes:**

1. Run `poetry run format` (backend) and `pnpm format` (frontend)
2. Ensure all tests pass in modified areas
3. Verify Docker services are still running
4. Check that database migrations apply cleanly

**Common Issues & Workarounds:**

- **Prisma issues**: Run `poetry run prisma generate` after schema changes
- **Permission errors**: Ensure Docker has proper permissions
- **Port conflicts**: Check the `docker-compose.yml` file for the current list of exposed ports. You can list all mapped ports with:

@@ -118,7 +108,6 @@ pnpm storybook # Start component development server
### Core Architecture

**AutoGPT Platform** (`autogpt_platform/`):

- `backend/` - FastAPI server with async support
- `backend/backend/` - Core API logic
- `backend/blocks/` - Agent execution blocks

@@ -132,7 +121,6 @@ pnpm storybook # Start component development server
- `docker-compose.yml` - Development stack orchestration

**Key Configuration Files:**

- `pyproject.toml` - Python dependencies and tooling
- `package.json` - Node.js dependencies and scripts
- `schema.prisma` - Database schema and migrations

@@ -148,7 +136,6 @@ pnpm storybook # Start component development server
### Development Workflow

**GitHub Actions**: Multiple CI/CD workflows in `.github/workflows/`

- `platform-backend-ci.yml` - Backend testing and validation
- `platform-frontend-ci.yml` - Frontend testing and validation
- `platform-fullstack-ci.yml` - End-to-end integration tests

@@ -159,13 +146,11 @@ pnpm storybook # Start component development server
### Key Source Files

**Backend Entry Points:**

- `backend/backend/server/server.py` - FastAPI application setup
- `backend/backend/data/` - Database models and user management
- `backend/blocks/` - Agent execution blocks and logic

**Frontend Entry Points:**

- `frontend/src/app/layout.tsx` - Root application layout
- `frontend/src/app/page.tsx` - Home page
- `frontend/src/lib/supabase/` - Authentication and database client

@@ -175,7 +160,6 @@ pnpm storybook # Start component development server
### Agent Block System

Agents are built using a visual block-based system where each block performs a single action. Blocks are defined in `backend/blocks/` and must include:

- Block definition with input/output schemas
- Execution logic with proper error handling
- Tests validating functionality

@@ -183,7 +167,6 @@ Agents are built using a visual block-based system where each block performs a s
### Database & ORM

**Prisma ORM** with PostgreSQL backend including pgvector for embeddings:

- Schema in `schema.prisma`
- Migrations in `backend/migrations/`
- Always run `prisma migrate dev` and `prisma generate` after schema changes

@@ -191,15 +174,13 @@ Agents are built using a visual block-based system where each block performs a s
## Environment Configuration

### Configuration Files Priority Order

1. **Backend**: `/backend/.env.default` → `/backend/.env` (user overrides)
2. **Frontend**: `/frontend/.env.default` → `/frontend/.env` (user overrides)
3. **Platform**: `/.env.default` (Supabase/shared) → `/.env` (user overrides)
4. Docker Compose `environment:` sections override file-based config
5. Shell environment variables have highest precedence
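A minimal sketch of how this layered precedence could be resolved in Python. The file names follow the list above, but the loader functions and their placement are illustrative assumptions, not the platform's actual implementation.

```python
import os
from pathlib import Path


def load_env_file(path: Path) -> dict[str, str]:
    """Parse simple KEY=VALUE lines from a .env-style file (illustrative only)."""
    values: dict[str, str] = {}
    if not path.exists():
        return values
    for line in path.read_text().splitlines():
        line = line.strip()
        if line and not line.startswith("#") and "=" in line:
            key, _, value = line.partition("=")
            values[key.strip()] = value.strip()
    return values


def resolve_backend_config(backend_dir: Path) -> dict[str, str]:
    """Later layers override earlier ones: .env.default < .env < shell environment."""
    config: dict[str, str] = {}
    config.update(load_env_file(backend_dir / ".env.default"))  # checked-in defaults
    config.update(load_env_file(backend_dir / ".env"))          # user overrides
    config.update(os.environ)                                   # shell has highest precedence
    return config
```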
### Docker Environment Setup

- All services use hardcoded defaults (no `${VARIABLE}` substitutions)
- The `env_file` directive loads variables INTO containers at runtime
- Backend/Frontend services use YAML anchors for consistent configuration

@@ -208,7 +189,6 @@ Agents are built using a visual block-based system where each block performs a s
## Advanced Development Patterns

### Adding New Blocks

1. Create file in `/backend/backend/blocks/`
2. Inherit from `Block` base class with input/output schemas
3. Implement `run` method with proper error handling

@@ -218,7 +198,6 @@ Agents are built using a visual block-based system where each block performs a s
7. Consider how inputs/outputs connect with other blocks in graph editor
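To make those steps concrete, here is a small, self-contained sketch of the shape such a block could take. The `WordCountBlock` name, the plain Pydantic schemas, and the `run` signature are illustrative assumptions; the platform's actual `Block` base class and schema helpers in `/backend/backend/blocks/` differ.

```python
# Hypothetical sketch only: stand-ins for the real Block base class and schema helpers.
from pydantic import BaseModel


class WordCountInput(BaseModel):
    text: str


class WordCountOutput(BaseModel):
    word_count: int


class WordCountBlock:
    """A single-action block: takes text in, emits one output value."""

    input_schema = WordCountInput
    output_schema = WordCountOutput

    def run(self, input_data: WordCountInput) -> WordCountOutput:
        try:
            return WordCountOutput(word_count=len(input_data.text.split()))
        except Exception as exc:  # step 3: proper error handling
            raise RuntimeError(f"WordCountBlock failed: {exc}") from exc
```

A block kept this small also makes step 7 easier to reason about: a plain string input and an integer output connect cleanly to neighboring blocks in a graph editor.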
### API Development

1. Update routes in `/backend/backend/server/routers/`
2. Add/update Pydantic models in same directory
3. Write tests alongside route files

@@ -226,76 +205,21 @@ Agents are built using a visual block-based system where each block performs a s
5. Run `poetry run test` to verify changes

### Frontend Development

**📖 Complete Frontend Guide**: See `autogpt_platform/frontend/CONTRIBUTING.md` and `autogpt_platform/frontend/.cursorrules` for comprehensive patterns and conventions.

**Quick Reference:**

**Component Structure:**

- Separate render logic from data/behavior
- Structure: `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
- Exception: Small components (3-4 lines of logic) can be inline
- Render-only components can be direct files without folders

**Data Fetching:**

- Use generated API hooks from `@/app/api/__generated__/endpoints/`
- Generated via Orval from backend OpenAPI spec
- Pattern: `use{Method}{Version}{OperationName}`
- Example: `useGetV2ListLibraryAgents`
- Regenerate with: `pnpm generate:api`
- **Never** use deprecated `BackendAPI` or `src/lib/autogpt-server-api/*`

**Code Conventions:**

- Use function declarations for components and handlers (not arrow functions)
- Only arrow functions for small inline lambdas (map, filter, etc.)
- Components: `PascalCase`, Hooks: `camelCase` with `use` prefix
- No barrel files or `index.ts` re-exports
- Minimal comments (code should be self-documenting)

**Styling:**

- Use Tailwind CSS utilities only
- Use design system components from `src/components/` (atoms, molecules, organisms)
- Never use `src/components/__legacy__/*`
- Only use Phosphor Icons (`@phosphor-icons/react`)
- Prefer design tokens over hardcoded values

**Error Handling:**

- Render errors: Use `<ErrorCard />` component
- Mutation errors: Display with toast notifications
- Manual exceptions: Use `Sentry.captureException()`
- Global error boundaries already configured

**Testing:**

- Add/update Storybook stories for UI components (`pnpm storybook`)
- Run Playwright E2E tests with `pnpm test`
- Verify in Chromatic after PR

**Architecture:**

- Default to client components ("use client")
- Server components only for SEO or extreme TTFB needs
- Use React Query for server state (via generated hooks)
- Co-locate UI state in components/hooks

1. Components in `/frontend/src/components/`
2. Use existing UI components from `/frontend/src/components/ui/`
3. Add Storybook stories for component development
4. Test user-facing features with Playwright E2E tests
5. Update protected routes in middleware when needed

### Security Guidelines

**Cache Protection Middleware** (`/backend/backend/server/middleware/security.py`):

- Default: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses allow list approach for cacheable paths (static assets, health checks, public pages)
- Prevents sensitive data caching in browsers/proxies
- Add new cacheable endpoints to `CACHEABLE_PATHS`
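A minimal sketch of the allow-list idea described above, written as generic FastAPI/Starlette middleware. The `CACHEABLE_PATHS` entries and the class shape here are illustrative assumptions; the real middleware in `security.py` may differ in structure and path matching.

```python
# Illustrative sketch of an allow-list cache middleware; not the actual
# implementation in backend/server/middleware/security.py.
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware

CACHEABLE_PATHS = {"/health", "/static"}  # example entries only


class CacheControlMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        # Default: disable caching everywhere so sensitive responses are never stored.
        if not any(request.url.path.startswith(p) for p in CACHEABLE_PATHS):
            response.headers["Cache-Control"] = (
                "no-store, no-cache, must-revalidate, private"
            )
        return response


app = FastAPI()
app.add_middleware(CacheControlMiddleware)
```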
### CI/CD Alignment

The repository has comprehensive CI workflows that test:

- **Backend**: Python 3.11-3.13, services (Redis/RabbitMQ/ClamAV), Prisma migrations, Poetry lock validation
- **Frontend**: Node.js 21, pnpm, Playwright with Docker Compose stack, API schema validation
- **Integration**: Full-stack type checking and E2E testing

@@ -305,7 +229,6 @@ Match these patterns when developing locally - the copilot setup environment mir
## Collaboration with Other AI Assistants

This repository is actively developed with assistance from Claude (via CLAUDE.md files). When working on this codebase:

- Check for existing CLAUDE.md files that provide additional context
- Follow established patterns and conventions already in the codebase
- Maintain consistency with existing code style and architecture

@@ -314,9 +237,8 @@ This repository is actively developed with assistance from Claude (via CLAUDE.md
## Trust These Instructions

These instructions are comprehensive and tested. Only perform additional searches if:

1. Information here is incomplete for your specific task
2. You encounter errors not covered by the workarounds
3. You need to understand implementation details not covered above

For detailed platform development patterns, refer to `autogpt_platform/CLAUDE.md` and `AGENTS.md` in the repository root.
97  .github/workflows/claude-ci-failure-auto-fix.yml (vendored)

@@ -1,97 +0,0 @@
name: Auto Fix CI Failures

on:
  workflow_run:
    workflows: ["CI"]
    types:
      - completed

permissions:
  contents: write
  pull-requests: write
  actions: read
  issues: write
  id-token: write # Required for OIDC token exchange

jobs:
  auto-fix:
    if: |
      github.event.workflow_run.conclusion == 'failure' &&
      github.event.workflow_run.pull_requests[0] &&
      !startsWith(github.event.workflow_run.head_branch, 'claude-auto-fix-ci-')
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event.workflow_run.head_branch }}
          fetch-depth: 0
          token: ${{ secrets.GITHUB_TOKEN }}

      - name: Setup git identity
        run: |
          git config --global user.email "claude[bot]@users.noreply.github.com"
          git config --global user.name "claude[bot]"

      - name: Create fix branch
        id: branch
        run: |
          BRANCH_NAME="claude-auto-fix-ci-${{ github.event.workflow_run.head_branch }}-${{ github.run_id }}"
          git checkout -b "$BRANCH_NAME"
          echo "branch_name=$BRANCH_NAME" >> $GITHUB_OUTPUT

      - name: Get CI failure details
        id: failure_details
        uses: actions/github-script@v7
        with:
          script: |
            const run = await github.rest.actions.getWorkflowRun({
              owner: context.repo.owner,
              repo: context.repo.repo,
              run_id: ${{ github.event.workflow_run.id }}
            });

            const jobs = await github.rest.actions.listJobsForWorkflowRun({
              owner: context.repo.owner,
              repo: context.repo.repo,
              run_id: ${{ github.event.workflow_run.id }}
            });

            const failedJobs = jobs.data.jobs.filter(job => job.conclusion === 'failure');

            let errorLogs = [];
            for (const job of failedJobs) {
              const logs = await github.rest.actions.downloadJobLogsForWorkflowRun({
                owner: context.repo.owner,
                repo: context.repo.repo,
                job_id: job.id
              });
              errorLogs.push({
                jobName: job.name,
                logs: logs.data
              });
            }

            return {
              runUrl: run.data.html_url,
              failedJobs: failedJobs.map(j => j.name),
              errorLogs: errorLogs
            };

      - name: Fix CI failures with Claude
        id: claude
        uses: anthropics/claude-code-action@v1
        with:
          prompt: |
            /fix-ci
            Failed CI Run: ${{ fromJSON(steps.failure_details.outputs.result).runUrl }}
            Failed Jobs: ${{ join(fromJSON(steps.failure_details.outputs.result).failedJobs, ', ') }}
            PR Number: ${{ github.event.workflow_run.pull_requests[0].number }}
            Branch Name: ${{ steps.branch.outputs.branch_name }}
            Base Branch: ${{ github.event.workflow_run.head_branch }}
            Repository: ${{ github.repository }}

            Error logs:
            ${{ toJSON(fromJSON(steps.failure_details.outputs.result).errorLogs) }}
          anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
          claude_args: "--allowedTools 'Edit,MultiEdit,Write,Read,Glob,Grep,LS,Bash(git:*),Bash(bun:*),Bash(npm:*),Bash(npx:*),Bash(gh:*)'"
379  .github/workflows/claude-dependabot.yml (vendored)
@@ -1,379 +0,0 @@
|
||||
# Claude Dependabot PR Review Workflow
|
||||
#
|
||||
# This workflow automatically runs Claude analysis on Dependabot PRs to:
|
||||
# - Identify dependency changes and their versions
|
||||
# - Look up changelogs for updated packages
|
||||
# - Assess breaking changes and security impacts
|
||||
# - Provide actionable recommendations for the development team
|
||||
#
|
||||
# Triggered on: Dependabot PRs (opened, synchronize)
|
||||
# Requirements: ANTHROPIC_API_KEY secret must be configured
|
||||
|
||||
name: Claude Dependabot PR Review
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, synchronize]
|
||||
|
||||
jobs:
|
||||
dependabot-review:
|
||||
# Only run on Dependabot PRs
|
||||
if: github.actor == 'dependabot[bot]'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: read
|
||||
issues: read
|
||||
id-token: write
|
||||
actions: read # Required for CI access
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
# Backend Python/Poetry setup (mirrors platform-backend-ci.yml)
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11" # Use standard version matching CI
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
# Extract Poetry version from backend/poetry.lock (matches CI)
|
||||
cd autogpt_platform/backend
|
||||
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
echo "Found Poetry version ${HEAD_POETRY_VERSION} in backend/poetry.lock"
|
||||
|
||||
# Install Poetry
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
||||
|
||||
# Add Poetry to PATH
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Check poetry.lock
|
||||
working-directory: autogpt_platform/backend
|
||||
run: |
|
||||
poetry lock
|
||||
if ! git diff --quiet --ignore-matching-lines="^# " poetry.lock; then
|
||||
echo "Warning: poetry.lock not up to date, but continuing for setup"
|
||||
git checkout poetry.lock # Reset for clean setup
|
||||
fi
|
||||
|
||||
- name: Install Python dependencies
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry install
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set pnpm store directory
|
||||
run: |
|
||||
pnpm config set store-dir ~/.pnpm-store
|
||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache frontend dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install JavaScript dependencies
|
||||
working-directory: autogpt_platform/frontend
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
# Install Playwright browsers for frontend testing
|
||||
# NOTE: Disabled to save ~1 minute of setup time. Re-enable if Copilot needs browser automation (e.g., for MCP)
|
||||
# - name: Install Playwright browsers
|
||||
# working-directory: autogpt_platform/frontend
|
||||
# run: pnpm playwright install --with-deps chromium
|
||||
|
||||
# Docker setup for development environment
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Copy default environment files
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Copy default environment files for development
|
||||
cp .env.default .env
|
||||
cp backend/.env.default backend/.env
|
||||
cp frontend/.env.default frontend/.env
|
||||
|
||||
# Phase 1: Cache and load Docker images for faster setup
|
||||
- name: Set up Docker image cache
|
||||
id: docker-cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/docker-cache
|
||||
# Use a versioned key for cache invalidation when image list changes
|
||||
key: docker-images-v2-${{ runner.os }}-${{ hashFiles('.github/workflows/copilot-setup-steps.yml') }}
|
||||
restore-keys: |
|
||||
docker-images-v2-${{ runner.os }}-
|
||||
docker-images-v1-${{ runner.os }}-
|
||||
|
||||
- name: Load or pull Docker images
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
mkdir -p ~/docker-cache
|
||||
|
||||
# Define image list for easy maintenance
|
||||
IMAGES=(
|
||||
"redis:latest"
|
||||
"rabbitmq:management"
|
||||
"clamav/clamav-debian:latest"
|
||||
"busybox:latest"
|
||||
"kong:2.8.1"
|
||||
"supabase/gotrue:v2.170.0"
|
||||
"supabase/postgres:15.8.1.049"
|
||||
"supabase/postgres-meta:v0.86.1"
|
||||
"supabase/studio:20250224-d10db0f"
|
||||
)
|
||||
|
||||
# Check if any cached tar files exist (more reliable than cache-hit)
|
||||
if ls ~/docker-cache/*.tar 1> /dev/null 2>&1; then
|
||||
echo "Docker cache found, loading images in parallel..."
|
||||
for image in "${IMAGES[@]}"; do
|
||||
# Convert image name to filename (replace : and / with -)
|
||||
filename=$(echo "$image" | tr ':/' '--')
|
||||
if [ -f ~/docker-cache/${filename}.tar ]; then
|
||||
echo "Loading $image..."
|
||||
docker load -i ~/docker-cache/${filename}.tar || echo "Warning: Failed to load $image from cache" &
|
||||
fi
|
||||
done
|
||||
wait
|
||||
echo "All cached images loaded"
|
||||
else
|
||||
echo "No Docker cache found, pulling images in parallel..."
|
||||
# Pull all images in parallel
|
||||
for image in "${IMAGES[@]}"; do
|
||||
docker pull "$image" &
|
||||
done
|
||||
wait
|
||||
|
||||
# Only save cache on main branches (not PRs) to avoid cache pollution
|
||||
if [[ "${{ github.ref }}" == "refs/heads/master" ]] || [[ "${{ github.ref }}" == "refs/heads/dev" ]]; then
|
||||
echo "Saving Docker images to cache in parallel..."
|
||||
for image in "${IMAGES[@]}"; do
|
||||
# Convert image name to filename (replace : and / with -)
|
||||
filename=$(echo "$image" | tr ':/' '--')
|
||||
echo "Saving $image..."
|
||||
docker save -o ~/docker-cache/${filename}.tar "$image" || echo "Warning: Failed to save $image" &
|
||||
done
|
||||
wait
|
||||
echo "Docker image cache saved"
|
||||
else
|
||||
echo "Skipping cache save for PR/feature branch"
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "Docker images ready for use"
|
||||
|
||||
# Phase 2: Build migrate service with GitHub Actions cache
|
||||
- name: Build migrate Docker image with cache
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Build the migrate image with buildx for GHA caching
|
||||
docker buildx build \
|
||||
--cache-from type=gha \
|
||||
--cache-to type=gha,mode=max \
|
||||
--target migrate \
|
||||
--tag autogpt_platform-migrate:latest \
|
||||
--load \
|
||||
-f backend/Dockerfile \
|
||||
..
|
||||
|
||||
# Start services using pre-built images
|
||||
- name: Start Docker services for development
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Start essential services (migrate image already built with correct tag)
|
||||
docker compose --profile local up deps --no-build --detach
|
||||
echo "Waiting for services to be ready..."
|
||||
|
||||
# Wait for database to be ready
|
||||
echo "Checking database readiness..."
|
||||
timeout 30 sh -c 'until docker compose exec -T db pg_isready -U postgres 2>/dev/null; do
|
||||
echo " Waiting for database..."
|
||||
sleep 2
|
||||
done' && echo "✅ Database is ready" || echo "⚠️ Database ready check timeout after 30s, continuing..."
|
||||
|
||||
# Check migrate service status
|
||||
echo "Checking migration status..."
|
||||
docker compose ps migrate || echo " Migrate service not visible in ps output"
|
||||
|
||||
# Wait for migrate service to complete
|
||||
echo "Waiting for migrations to complete..."
|
||||
timeout 30 bash -c '
|
||||
ATTEMPTS=0
|
||||
while [ $ATTEMPTS -lt 15 ]; do
|
||||
ATTEMPTS=$((ATTEMPTS + 1))
|
||||
|
||||
# Check using docker directly (more reliable than docker compose ps)
|
||||
CONTAINER_STATUS=$(docker ps -a --filter "label=com.docker.compose.service=migrate" --format "{{.Status}}" | head -1)
|
||||
|
||||
if [ -z "$CONTAINER_STATUS" ]; then
|
||||
echo " Attempt $ATTEMPTS: Migrate container not found yet..."
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Exited (0)"; then
|
||||
echo "✅ Migrations completed successfully"
|
||||
docker compose logs migrate --tail=5 2>/dev/null || true
|
||||
exit 0
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Exited ([1-9]"; then
|
||||
EXIT_CODE=$(echo "$CONTAINER_STATUS" | grep -oE "Exited \([0-9]+\)" | grep -oE "[0-9]+")
|
||||
echo "❌ Migrations failed with exit code: $EXIT_CODE"
|
||||
echo "Migration logs:"
|
||||
docker compose logs migrate --tail=20 2>/dev/null || true
|
||||
exit 1
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Up"; then
|
||||
echo " Attempt $ATTEMPTS: Migrate container is running... ($CONTAINER_STATUS)"
|
||||
else
|
||||
echo " Attempt $ATTEMPTS: Migrate container status: $CONTAINER_STATUS"
|
||||
fi
|
||||
|
||||
sleep 2
|
||||
done
|
||||
|
||||
echo "⚠️ Timeout: Could not determine migration status after 30 seconds"
|
||||
echo "Final container check:"
|
||||
docker ps -a --filter "label=com.docker.compose.service=migrate" || true
|
||||
echo "Migration logs (if available):"
|
||||
docker compose logs migrate --tail=10 2>/dev/null || echo " No logs available"
|
||||
' || echo "⚠️ Migration check completed with warnings, continuing..."
|
||||
|
||||
# Brief wait for other services to stabilize
|
||||
echo "Waiting 5 seconds for other services to stabilize..."
|
||||
sleep 5
|
||||
|
||||
# Verify installations and provide environment info
|
||||
- name: Verify setup and show environment info
|
||||
run: |
|
||||
echo "=== Python Setup ==="
|
||||
python --version
|
||||
poetry --version
|
||||
|
||||
echo "=== Node.js Setup ==="
|
||||
node --version
|
||||
pnpm --version
|
||||
|
||||
echo "=== Additional Tools ==="
|
||||
docker --version
|
||||
docker compose version
|
||||
gh --version || true
|
||||
|
||||
echo "=== Services Status ==="
|
||||
cd autogpt_platform
|
||||
docker compose ps || true
|
||||
|
||||
echo "=== Backend Dependencies ==="
|
||||
cd backend
|
||||
poetry show | head -10 || true
|
||||
|
||||
echo "=== Frontend Dependencies ==="
|
||||
cd ../frontend
|
||||
pnpm list --depth=0 | head -10 || true
|
||||
|
||||
echo "=== Environment Files ==="
|
||||
ls -la ../.env* || true
|
||||
ls -la .env* || true
|
||||
ls -la ../backend/.env* || true
|
||||
|
||||
echo "✅ AutoGPT Platform development environment setup complete!"
|
||||
echo "🚀 Ready for development with Docker services running"
|
||||
echo "📝 Backend server: poetry run serve (port 8000)"
|
||||
echo "🌐 Frontend server: pnpm dev (port 3000)"
|
||||
|
||||
|
||||
- name: Run Claude Dependabot Analysis
|
||||
id: claude_review
|
||||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
claude_args: |
|
||||
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
|
||||
prompt: |
|
||||
You are Claude, an AI assistant specialized in reviewing Dependabot dependency update PRs.
|
||||
|
||||
Your primary tasks are:
|
||||
1. **Analyze the dependency changes** in this Dependabot PR
|
||||
2. **Look up changelogs** for all updated dependencies to understand what changed
|
||||
3. **Identify breaking changes** and assess potential impact on the AutoGPT codebase
|
||||
4. **Provide actionable recommendations** for the development team
|
||||
|
||||
## Analysis Process:
|
||||
|
||||
1. **Identify Changed Dependencies**:
|
||||
- Use git diff to see what dependencies were updated
|
||||
- Parse package.json, poetry.lock, requirements files, etc.
|
||||
- List all package versions: old → new
|
||||
|
||||
2. **Changelog Research**:
|
||||
- For each updated dependency, look up its changelog/release notes
|
||||
- Use WebFetch to access GitHub releases, NPM package pages, PyPI project pages. The pr should also have some details
|
||||
- Focus on versions between the old and new versions
|
||||
- Identify: breaking changes, deprecations, security fixes, new features
|
||||
|
||||
3. **Breaking Change Assessment**:
|
||||
- Categorize changes: BREAKING, MAJOR, MINOR, PATCH, SECURITY
|
||||
- Assess impact on AutoGPT's usage patterns
|
||||
- Check if AutoGPT uses affected APIs/features
|
||||
- Look for migration guides or upgrade instructions
|
||||
|
||||
4. **Codebase Impact Analysis**:
|
||||
- Search the AutoGPT codebase for usage of changed APIs
|
||||
- Identify files that might be affected by breaking changes
|
||||
- Check test files for deprecated usage patterns
|
||||
- Look for configuration changes needed
|
||||
|
||||
## Output Format:
|
||||
|
||||
Provide a comprehensive review comment with:
|
||||
|
||||
### 🔍 Dependency Analysis Summary
|
||||
- List of updated packages with version changes
|
||||
- Overall risk assessment (LOW/MEDIUM/HIGH)
|
||||
|
||||
### 📋 Detailed Changelog Review
|
||||
For each updated dependency:
|
||||
- **Package**: name (old_version → new_version)
|
||||
- **Changes**: Summary of key changes
|
||||
- **Breaking Changes**: List any breaking changes
|
||||
- **Security Fixes**: Note security improvements
|
||||
- **Migration Notes**: Any upgrade steps needed
|
||||
|
||||
### ⚠️ Impact Assessment
|
||||
- **Breaking Changes Found**: Yes/No with details
|
||||
- **Affected Files**: List AutoGPT files that may need updates
|
||||
- **Test Impact**: Any tests that may need updating
|
||||
- **Configuration Changes**: Required config updates
|
||||
|
||||
### 🛠️ Recommendations
|
||||
- **Action Required**: What the team should do
|
||||
- **Testing Focus**: Areas to test thoroughly
|
||||
- **Follow-up Tasks**: Any additional work needed
|
||||
- **Merge Recommendation**: APPROVE/REVIEW_NEEDED/HOLD
|
||||
|
||||
### 📚 Useful Links
|
||||
- Links to relevant changelogs, migration guides, documentation
|
||||
|
||||
Be thorough but concise. Focus on actionable insights that help the development team make informed decisions about the dependency updates.
|
||||
284  .github/workflows/claude.yml (vendored)
@@ -30,296 +30,18 @@ jobs:
|
||||
github.event.issue.author_association == 'COLLABORATOR'
|
||||
)
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
contents: read
|
||||
pull-requests: read
|
||||
issues: read
|
||||
id-token: write
|
||||
actions: read # Required for CI access
|
||||
steps:
|
||||
- name: Checkout code
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
# Backend Python/Poetry setup (mirrors platform-backend-ci.yml)
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11" # Use standard version matching CI
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
# Extract Poetry version from backend/poetry.lock (matches CI)
|
||||
cd autogpt_platform/backend
|
||||
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
echo "Found Poetry version ${HEAD_POETRY_VERSION} in backend/poetry.lock"
|
||||
|
||||
# Install Poetry
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
|
||||
|
||||
# Add Poetry to PATH
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Check poetry.lock
|
||||
working-directory: autogpt_platform/backend
|
||||
run: |
|
||||
poetry lock
|
||||
if ! git diff --quiet --ignore-matching-lines="^# " poetry.lock; then
|
||||
echo "Warning: poetry.lock not up to date, but continuing for setup"
|
||||
git checkout poetry.lock # Reset for clean setup
|
||||
fi
|
||||
|
||||
- name: Install Python dependencies
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry install
|
||||
|
||||
- name: Generate Prisma Client
|
||||
working-directory: autogpt_platform/backend
|
||||
run: poetry run prisma generate
|
||||
|
||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Set pnpm store directory
|
||||
run: |
|
||||
pnpm config set store-dir ~/.pnpm-store
|
||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache frontend dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install JavaScript dependencies
|
||||
working-directory: autogpt_platform/frontend
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
# Install Playwright browsers for frontend testing
|
||||
# NOTE: Disabled to save ~1 minute of setup time. Re-enable if Copilot needs browser automation (e.g., for MCP)
|
||||
# - name: Install Playwright browsers
|
||||
# working-directory: autogpt_platform/frontend
|
||||
# run: pnpm playwright install --with-deps chromium
|
||||
|
||||
# Docker setup for development environment
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Copy default environment files
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Copy default environment files for development
|
||||
cp .env.default .env
|
||||
cp backend/.env.default backend/.env
|
||||
cp frontend/.env.default frontend/.env
|
||||
|
||||
# Phase 1: Cache and load Docker images for faster setup
|
||||
- name: Set up Docker image cache
|
||||
id: docker-cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/docker-cache
|
||||
# Use a versioned key for cache invalidation when image list changes
|
||||
key: docker-images-v2-${{ runner.os }}-${{ hashFiles('.github/workflows/copilot-setup-steps.yml') }}
|
||||
restore-keys: |
|
||||
docker-images-v2-${{ runner.os }}-
|
||||
docker-images-v1-${{ runner.os }}-
|
||||
|
||||
- name: Load or pull Docker images
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
mkdir -p ~/docker-cache
|
||||
|
||||
# Define image list for easy maintenance
|
||||
IMAGES=(
|
||||
"redis:latest"
|
||||
"rabbitmq:management"
|
||||
"clamav/clamav-debian:latest"
|
||||
"busybox:latest"
|
||||
"kong:2.8.1"
|
||||
"supabase/gotrue:v2.170.0"
|
||||
"supabase/postgres:15.8.1.049"
|
||||
"supabase/postgres-meta:v0.86.1"
|
||||
"supabase/studio:20250224-d10db0f"
|
||||
)
|
||||
|
||||
# Check if any cached tar files exist (more reliable than cache-hit)
|
||||
if ls ~/docker-cache/*.tar 1> /dev/null 2>&1; then
|
||||
echo "Docker cache found, loading images in parallel..."
|
||||
for image in "${IMAGES[@]}"; do
|
||||
# Convert image name to filename (replace : and / with -)
|
||||
filename=$(echo "$image" | tr ':/' '--')
|
||||
if [ -f ~/docker-cache/${filename}.tar ]; then
|
||||
echo "Loading $image..."
|
||||
docker load -i ~/docker-cache/${filename}.tar || echo "Warning: Failed to load $image from cache" &
|
||||
fi
|
||||
done
|
||||
wait
|
||||
echo "All cached images loaded"
|
||||
else
|
||||
echo "No Docker cache found, pulling images in parallel..."
|
||||
# Pull all images in parallel
|
||||
for image in "${IMAGES[@]}"; do
|
||||
docker pull "$image" &
|
||||
done
|
||||
wait
|
||||
|
||||
# Only save cache on main branches (not PRs) to avoid cache pollution
|
||||
if [[ "${{ github.ref }}" == "refs/heads/master" ]] || [[ "${{ github.ref }}" == "refs/heads/dev" ]]; then
|
||||
echo "Saving Docker images to cache in parallel..."
|
||||
for image in "${IMAGES[@]}"; do
|
||||
# Convert image name to filename (replace : and / with -)
|
||||
filename=$(echo "$image" | tr ':/' '--')
|
||||
echo "Saving $image..."
|
||||
docker save -o ~/docker-cache/${filename}.tar "$image" || echo "Warning: Failed to save $image" &
|
||||
done
|
||||
wait
|
||||
echo "Docker image cache saved"
|
||||
else
|
||||
echo "Skipping cache save for PR/feature branch"
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "Docker images ready for use"
|
||||
|
||||
# Phase 2: Build migrate service with GitHub Actions cache
|
||||
- name: Build migrate Docker image with cache
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Build the migrate image with buildx for GHA caching
|
||||
docker buildx build \
|
||||
--cache-from type=gha \
|
||||
--cache-to type=gha,mode=max \
|
||||
--target migrate \
|
||||
--tag autogpt_platform-migrate:latest \
|
||||
--load \
|
||||
-f backend/Dockerfile \
|
||||
..
|
||||
|
||||
# Start services using pre-built images
|
||||
- name: Start Docker services for development
|
||||
working-directory: autogpt_platform
|
||||
run: |
|
||||
# Start essential services (migrate image already built with correct tag)
|
||||
docker compose --profile local up deps --no-build --detach
|
||||
echo "Waiting for services to be ready..."
|
||||
|
||||
# Wait for database to be ready
|
||||
echo "Checking database readiness..."
|
||||
timeout 30 sh -c 'until docker compose exec -T db pg_isready -U postgres 2>/dev/null; do
|
||||
echo " Waiting for database..."
|
||||
sleep 2
|
||||
done' && echo "✅ Database is ready" || echo "⚠️ Database ready check timeout after 30s, continuing..."
|
||||
|
||||
# Check migrate service status
|
||||
echo "Checking migration status..."
|
||||
docker compose ps migrate || echo " Migrate service not visible in ps output"
|
||||
|
||||
# Wait for migrate service to complete
|
||||
echo "Waiting for migrations to complete..."
|
||||
timeout 30 bash -c '
|
||||
ATTEMPTS=0
|
||||
while [ $ATTEMPTS -lt 15 ]; do
|
||||
ATTEMPTS=$((ATTEMPTS + 1))
|
||||
|
||||
# Check using docker directly (more reliable than docker compose ps)
|
||||
CONTAINER_STATUS=$(docker ps -a --filter "label=com.docker.compose.service=migrate" --format "{{.Status}}" | head -1)
|
||||
|
||||
if [ -z "$CONTAINER_STATUS" ]; then
|
||||
echo " Attempt $ATTEMPTS: Migrate container not found yet..."
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Exited (0)"; then
|
||||
echo "✅ Migrations completed successfully"
|
||||
docker compose logs migrate --tail=5 2>/dev/null || true
|
||||
exit 0
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Exited ([1-9]"; then
|
||||
EXIT_CODE=$(echo "$CONTAINER_STATUS" | grep -oE "Exited \([0-9]+\)" | grep -oE "[0-9]+")
|
||||
echo "❌ Migrations failed with exit code: $EXIT_CODE"
|
||||
echo "Migration logs:"
|
||||
docker compose logs migrate --tail=20 2>/dev/null || true
|
||||
exit 1
|
||||
elif echo "$CONTAINER_STATUS" | grep -q "Up"; then
|
||||
echo " Attempt $ATTEMPTS: Migrate container is running... ($CONTAINER_STATUS)"
|
||||
else
|
||||
echo " Attempt $ATTEMPTS: Migrate container status: $CONTAINER_STATUS"
|
||||
fi
|
||||
|
||||
sleep 2
|
||||
done
|
||||
|
||||
echo "⚠️ Timeout: Could not determine migration status after 30 seconds"
|
||||
echo "Final container check:"
|
||||
docker ps -a --filter "label=com.docker.compose.service=migrate" || true
|
||||
echo "Migration logs (if available):"
|
||||
docker compose logs migrate --tail=10 2>/dev/null || echo " No logs available"
|
||||
' || echo "⚠️ Migration check completed with warnings, continuing..."
|
||||
|
||||
# Brief wait for other services to stabilize
|
||||
echo "Waiting 5 seconds for other services to stabilize..."
|
||||
sleep 5
|
||||
|
||||
# Verify installations and provide environment info
|
||||
- name: Verify setup and show environment info
|
||||
run: |
|
||||
echo "=== Python Setup ==="
|
||||
python --version
|
||||
poetry --version
|
||||
|
||||
echo "=== Node.js Setup ==="
|
||||
node --version
|
||||
pnpm --version
|
||||
|
||||
echo "=== Additional Tools ==="
|
||||
docker --version
|
||||
docker compose version
|
||||
gh --version || true
|
||||
|
||||
echo "=== Services Status ==="
|
||||
cd autogpt_platform
|
||||
docker compose ps || true
|
||||
|
||||
echo "=== Backend Dependencies ==="
|
||||
cd backend
|
||||
poetry show | head -10 || true
|
||||
|
||||
echo "=== Frontend Dependencies ==="
|
||||
cd ../frontend
|
||||
pnpm list --depth=0 | head -10 || true
|
||||
|
||||
echo "=== Environment Files ==="
|
||||
ls -la ../.env* || true
|
||||
ls -la .env* || true
|
||||
ls -la ../backend/.env* || true
|
||||
|
||||
echo "✅ AutoGPT Platform development environment setup complete!"
|
||||
echo "🚀 Ready for development with Docker services running"
|
||||
echo "📝 Backend server: poetry run serve (port 8000)"
|
||||
echo "🌐 Frontend server: pnpm dev (port 3000)"
|
||||
|
||||
- name: Run Claude Code
|
||||
id: claude
|
||||
uses: anthropics/claude-code-action@v1
|
||||
uses: anthropics/claude-code-action@beta
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
claude_args: |
|
||||
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*), Bash(gh pr edit:*)"
|
||||
--model opus
|
||||
additional_permissions: |
|
||||
actions: read
|
||||
|
||||
@@ -3,7 +3,6 @@ name: AutoGPT Platform - Deploy Prod Environment
on:
  release:
    types: [published]
  workflow_dispatch:

permissions:
  contents: 'read'

@@ -18,8 +17,6 @@ jobs:
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          ref: ${{ github.ref_name || 'master' }}

      - name: Set up Python
        uses: actions/setup-python@v5

@@ -39,7 +36,7 @@ jobs:
          DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
          DIRECT_URL: ${{ secrets.BACKEND_DATABASE_URL }}

  trigger:
    needs: migrate
    runs-on: ubuntu-latest

@@ -50,5 +47,4 @@ jobs:
          token: ${{ secrets.DEPLOY_TOKEN }}
          repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
          event-type: build_deploy_prod
          client-payload: |
            {"ref": "${{ github.ref_name || 'master' }}", "repository": "${{ github.repository }}"}
          client-payload: '{"ref": "${{ github.ref }}", "sha": "${{ github.sha }}", "repository": "${{ github.repository }}"}'

@@ -5,13 +5,6 @@ on:
    branches: [ dev ]
    paths:
      - 'autogpt_platform/**'
  workflow_dispatch:
    inputs:
      git_ref:
        description: 'Git ref (branch/tag) of AutoGPT to deploy'
        required: true
        default: 'master'
        type: string

permissions:
  contents: 'read'

@@ -26,8 +19,6 @@ jobs:
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event.inputs.git_ref || github.ref_name }}

      - name: Set up Python
        uses: actions/setup-python@v5

@@ -57,4 +48,4 @@ jobs:
          token: ${{ secrets.DEPLOY_TOKEN }}
          repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
          event-type: build_deploy_dev
          client-payload: '{"ref": "${{ github.event.inputs.git_ref || github.ref }}", "repository": "${{ github.repository }}"}'
          client-payload: '{"ref": "${{ github.ref }}", "sha": "${{ github.sha }}", "repository": "${{ github.repository }}"}'
5  .github/workflows/platform-backend-ci.yml (vendored)

@@ -37,7 +37,9 @@ jobs:

    services:
      redis:
        image: redis:latest
        image: bitnami/redis:6.2
        env:
          REDIS_PASSWORD: testpassword
        ports:
          - 6379:6379
      rabbitmq:

@@ -202,6 +204,7 @@ jobs:
          JWT_VERIFY_KEY: ${{ steps.supabase.outputs.JWT_SECRET }}
          REDIS_HOST: "localhost"
          REDIS_PORT: "6379"
          REDIS_PASSWORD: "testpassword"
          ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!

        env:
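For context, a minimal sketch of how a test or local script could connect to the Redis service configured above. Only the host, port, and password values come from the CI environment shown in this diff; the use of the `redis` Python package is an assumption for illustration.

```python
# Illustrative only: connects to the Redis service the CI job configures above.
import os

import redis  # assumes the `redis` package is installed

client = redis.Redis(
    host=os.getenv("REDIS_HOST", "localhost"),
    port=int(os.getenv("REDIS_PORT", "6379")),
    password=os.getenv("REDIS_PASSWORD", "testpassword"),  # matches the CI service env
)
client.ping()  # raises redis.AuthenticationError if the password is wrong
```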
10  .github/workflows/platform-frontend-ci.yml (vendored)

@@ -30,7 +30,7 @@ jobs:
      - name: Set up Node.js
        uses: actions/setup-node@v4
        with:
          node-version: "22.18.0"
          node-version: "21"

      - name: Enable corepack
        run: corepack enable

@@ -62,7 +62,7 @@ jobs:
      - name: Set up Node.js
        uses: actions/setup-node@v4
        with:
          node-version: "22.18.0"
          node-version: "21"

      - name: Enable corepack
        run: corepack enable

@@ -97,7 +97,7 @@ jobs:
      - name: Set up Node.js
        uses: actions/setup-node@v4
        with:
          node-version: "22.18.0"
          node-version: "21"

      - name: Enable corepack
        run: corepack enable

@@ -138,7 +138,7 @@ jobs:
      - name: Set up Node.js
        uses: actions/setup-node@v4
        with:
          node-version: "22.18.0"
          node-version: "21"

      - name: Enable corepack
        run: corepack enable

@@ -160,7 +160,7 @@ jobs:

      - name: Run docker compose
        run: |
          NEXT_PUBLIC_PW_TEST=true docker compose -f ../docker-compose.yml up -d
          docker compose -f ../docker-compose.yml up -d
        env:
          DOCKER_BUILDKIT: 1
          BUILDX_CACHE_FROM: type=local,src=/tmp/.buildx-cache
4  .github/workflows/platform-fullstack-ci.yml (vendored)

@@ -30,7 +30,7 @@ jobs:
      - name: Set up Node.js
        uses: actions/setup-node@v4
        with:
          node-version: "22.18.0"
          node-version: "21"

      - name: Enable corepack
        run: corepack enable

@@ -66,7 +66,7 @@ jobs:
      - name: Set up Node.js
        uses: actions/setup-node@v4
        with:
          node-version: "22.18.0"
          node-version: "21"

      - name: Enable corepack
        run: corepack enable
1  .gitignore (vendored)

@@ -178,4 +178,3 @@ autogpt_platform/backend/settings.py
*.ign.*
.test-contents
.claude/settings.local.json
/autogpt_platform/backend/logs
@@ -1,3 +1,6 @@
[pr_reviewer]
num_code_suggestions=0

[pr_code_suggestions]
commitable_code_suggestions=false
num_code_suggestions=0
@@ -61,41 +61,24 @@ poetry run pytest path/to/test.py --snapshot-update

```bash
# Install dependencies
cd frontend && pnpm i

# Generate API client from OpenAPI spec
pnpm generate:api
cd frontend && npm install

# Start development server
pnpm dev
npm run dev

# Run E2E tests
pnpm test
npm run test

# Run Storybook for component development
pnpm storybook
npm run storybook

# Build production
pnpm build

# Format and lint
pnpm format
npm run build

# Type checking
pnpm types
npm run types
```

**📖 Complete Guide**: See `/frontend/CONTRIBUTING.md` and `/frontend/.cursorrules` for comprehensive frontend patterns.

**Key Frontend Conventions:**

- Separate render logic from data/behavior in components
- Use generated API hooks from `@/app/api/__generated__/endpoints/`
- Use function declarations (not arrow functions) for components/handlers
- Use design system components from `src/components/` (atoms, molecules, organisms)
- Only use Phosphor Icons
- Never use `src/components/__legacy__/*` or deprecated `BackendAPI`

## Architecture Overview

### Backend Architecture

@@ -109,16 +92,11 @@ pnpm types

### Frontend Architecture

- **Framework**: Next.js 15 App Router (client-first approach)
- **Data Fetching**: Type-safe generated API hooks via Orval + React Query
- **State Management**: React Query for server state, co-located UI state in components/hooks
- **Component Structure**: Separate render logic (`.tsx`) from business logic (`use*.ts` hooks)
- **Framework**: Next.js App Router with React Server Components
- **State Management**: React hooks + Supabase client for real-time updates
- **Workflow Builder**: Visual graph editor using @xyflow/react
- **UI Components**: shadcn/ui (Radix UI primitives) with Tailwind CSS styling
- **Icons**: Phosphor Icons only
- **UI Components**: Radix UI primitives with Tailwind CSS styling
- **Feature Flags**: LaunchDarkly integration
- **Error Handling**: ErrorCard for render errors, toast for mutations, Sentry for exceptions
- **Testing**: Playwright for E2E, Storybook for component development

### Key Concepts

@@ -172,7 +150,6 @@ Key models (defined in `/backend/schema.prisma`):
**Adding a new block:**

Follow the comprehensive [Block SDK Guide](../../../docs/content/platform/block-sdk-guide.md) which covers:

- Provider configuration with `ProviderBuilder`
- Block schema definition
- Authentication (API keys, OAuth, webhooks)

@@ -180,7 +157,6 @@ Follow the comprehensive [Block SDK Guide](../../../docs/content/platform/block-
- File organization

Quick steps:

1. Create new file in `/backend/backend/blocks/`
2. Configure provider using `ProviderBuilder` in `_config.py`
3. Inherit from `Block` base class

@@ -192,8 +168,6 @@ Quick steps:
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph based editor or would they struggle to connect productively?
ex: do the inputs and outputs tie well together?

If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.

**Modifying the API:**

1. Update route in `/backend/backend/server/routers/`

@@ -203,20 +177,10 @@ If you get any pushback or hit complex block conditions check the new_blocks gui

**Frontend feature development:**

See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:

1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx`
   - Add `usePageName.ts` hook for logic
   - Put sub-components in local `components/` folder
2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
   - Use design system components from `src/components/` (atoms, molecules, organisms)
   - Never use `src/components/__legacy__/*`
3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/`
   - Regenerate with `pnpm generate:api`
   - Pattern: `use{Method}{Version}{OperationName}`
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
1. Components go in `/frontend/src/components/`
2. Use existing UI components from `/frontend/src/components/ui/`
3. Add Storybook stories for new components
4. Test with Playwright if user-facing

### Security Implementation
@@ -1,57 +0,0 @@
.PHONY: start-core stop-core logs-core format lint migrate run-backend run-frontend

# Run just Supabase + Redis + RabbitMQ
start-core:
	docker compose up -d deps

# Stop core services
stop-core:
	docker compose stop deps

reset-db:
	rm -rf db/docker/volumes/db/data
	cd backend && poetry run prisma migrate deploy
	cd backend && poetry run prisma generate

# View logs for core services
logs-core:
	docker compose logs -f deps

# Run formatting and linting for backend and frontend
format:
	cd backend && poetry run format
	cd frontend && pnpm format
	cd frontend && pnpm lint

init-env:
	cp -n .env.default .env || true
	cd backend && cp -n .env.default .env || true
	cd frontend && cp -n .env.default .env || true

# Run migrations for backend
migrate:
	cd backend && poetry run prisma migrate deploy
	cd backend && poetry run prisma generate

run-backend:
	cd backend && poetry run app

run-frontend:
	cd frontend && pnpm dev

test-data:
	cd backend && poetry run python test/test_data_creator.py

help:
	@echo "Usage: make <target>"
	@echo "Targets:"
	@echo " start-core - Start just the core services (Supabase, Redis, RabbitMQ) in background"
	@echo " stop-core - Stop the core services"
	@echo " reset-db - Reset the database by deleting the volume"
	@echo " logs-core - Tail the logs for core services"
	@echo " format - Format & lint backend (Python) and frontend (TypeScript) code"
	@echo " migrate - Run backend database migrations"
	@echo " run-backend - Run the backend FastAPI server"
	@echo " run-frontend - Run the frontend Next.js development server"
	@echo " test-data - Run the test data creator"
@@ -38,37 +38,6 @@ To run the AutoGPT Platform, follow these steps:

4. After all the services are in ready state, open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.

### Running Just Core services

You can now run the following to enable just the core services.

```
# For help
make help

# Run just Supabase + Redis + RabbitMQ
make start-core

# Stop core services
make stop-core

# View logs from core services
make logs-core

# Run formatting and linting for backend and frontend
make format

# Run migrations for backend database
make migrate

# Run backend server
make run-backend

# Run frontend development server
make run-frontend

```

### Docker Compose Commands

Here are some useful Docker Compose commands for managing your AutoGPT Platform:
@@ -0,0 +1,35 @@
import hashlib
import secrets
from typing import NamedTuple


class APIKeyContainer(NamedTuple):
    """Container for API key parts."""

    raw: str
    prefix: str
    postfix: str
    hash: str


class APIKeyManager:
    PREFIX: str = "agpt_"
    PREFIX_LENGTH: int = 8
    POSTFIX_LENGTH: int = 8

    def generate_api_key(self) -> APIKeyContainer:
        """Generate a new API key with all its parts."""
        raw_key = f"{self.PREFIX}{secrets.token_urlsafe(32)}"
        return APIKeyContainer(
            raw=raw_key,
            prefix=raw_key[: self.PREFIX_LENGTH],
            postfix=raw_key[-self.POSTFIX_LENGTH :],
            hash=hashlib.sha256(raw_key.encode()).hexdigest(),
        )

    def verify_api_key(self, provided_key: str, stored_hash: str) -> bool:
        """Verify if a provided API key matches the stored hash."""
        if not provided_key.startswith(self.PREFIX):
            return False
        provided_hash = hashlib.sha256(provided_key.encode()).hexdigest()
        return secrets.compare_digest(provided_hash, stored_hash)
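For orientation, here is a minimal usage sketch of the `APIKeyManager` added above. Only the class itself comes from this diff; the import path and the idea of persisting `key.hash` while displaying only `key.prefix`/`key.postfix` are illustrative assumptions.

```python
from autogpt_libs.api_key.key_manager import APIKeyManager  # assumed module path

manager = APIKeyManager()

# Generate once; show the raw key to the user a single time, store only the hash.
key = manager.generate_api_key()
stored_hash = key.hash                           # persist this
display_hint = f"{key.prefix}...{key.postfix}"   # safe to show in a UI list

# Later, verify an incoming key against the stored SHA-256 hash.
assert manager.verify_api_key(key.raw, stored_hash) is True
assert manager.verify_api_key("agpt_not-the-right-key", stored_hash) is False
```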
@@ -1,78 +0,0 @@
import hashlib
import secrets
from typing import NamedTuple

from cryptography.hazmat.primitives.kdf.scrypt import Scrypt


class APIKeyContainer(NamedTuple):
    """Container for API key parts."""

    key: str
    head: str
    tail: str
    hash: str
    salt: str


class APIKeySmith:
    PREFIX: str = "agpt_"
    HEAD_LENGTH: int = 8
    TAIL_LENGTH: int = 8

    def generate_key(self) -> APIKeyContainer:
        """Generate a new API key with secure hashing."""
        raw_key = f"{self.PREFIX}{secrets.token_urlsafe(32)}"
        hash, salt = self.hash_key(raw_key)

        return APIKeyContainer(
            key=raw_key,
            head=raw_key[: self.HEAD_LENGTH],
            tail=raw_key[-self.TAIL_LENGTH :],
            hash=hash,
            salt=salt,
        )

    def verify_key(
        self, provided_key: str, known_hash: str, known_salt: str | None = None
    ) -> bool:
        """
        Verify an API key against a known hash (+ salt).
        Supports verifying both legacy SHA256 and secure Scrypt hashes.
        """
        if not provided_key.startswith(self.PREFIX):
            return False

        # Handle legacy SHA256 hashes (migration support)
        if known_salt is None:
            legacy_hash = hashlib.sha256(provided_key.encode()).hexdigest()
            return secrets.compare_digest(legacy_hash, known_hash)

        try:
            salt_bytes = bytes.fromhex(known_salt)
            provided_hash = self._hash_key_with_salt(provided_key, salt_bytes)
            return secrets.compare_digest(provided_hash, known_hash)
        except (ValueError, TypeError):
            return False

    def hash_key(self, raw_key: str) -> tuple[str, str]:
        """Migrate a legacy hash to secure hash format."""
        salt = self._generate_salt()
        hash = self._hash_key_with_salt(raw_key, salt)
        return hash, salt.hex()

    def _generate_salt(self) -> bytes:
        """Generate a random salt for hashing."""
        return secrets.token_bytes(32)

    def _hash_key_with_salt(self, raw_key: str, salt: bytes) -> str:
        """Hash API key using Scrypt with salt."""
        kdf = Scrypt(
            length=32,
            salt=salt,
            n=2**14,  # CPU/memory cost parameter
            r=8,  # Block size parameter
            p=1,  # Parallelization parameter
        )
        key_hash = kdf.derive(raw_key.encode())
        return key_hash.hex()
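The removed `APIKeySmith` above also covered migrating legacy SHA-256 hashes to salted Scrypt hashes. A rough sketch of how a caller could combine `verify_key` and `hash_key` for an upgrade-on-verify flow, assuming a simple dict-shaped storage record (the record shape and helper are hypothetical, not part of the diff):

```python
from autogpt_libs.api_key.keysmith import APIKeySmith  # path as used by the removed tests

keysmith = APIKeySmith()


def verify_and_upgrade(provided_key: str, record: dict) -> bool:
    """record is a hypothetical store entry: {"hash": str, "salt": str | None}."""
    if not keysmith.verify_key(provided_key, record["hash"], record.get("salt")):
        return False
    if record.get("salt") is None:
        # Legacy SHA-256 entry: re-hash with Scrypt + salt and persist the upgrade.
        record["hash"], record["salt"] = keysmith.hash_key(provided_key)
    return True
```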
@@ -1,79 +0,0 @@
import hashlib

from autogpt_libs.api_key.keysmith import APIKeySmith


def test_generate_api_key():
    keysmith = APIKeySmith()
    key = keysmith.generate_key()

    assert key.key.startswith(keysmith.PREFIX)
    assert key.head == key.key[: keysmith.HEAD_LENGTH]
    assert key.tail == key.key[-keysmith.TAIL_LENGTH :]
    assert len(key.hash) == 64  # 32 bytes hex encoded
    assert len(key.salt) == 64  # 32 bytes hex encoded


def test_verify_new_secure_key():
    keysmith = APIKeySmith()
    key = keysmith.generate_key()

    # Test correct key validates
    assert keysmith.verify_key(key.key, key.hash, key.salt) is True

    # Test wrong key fails
    wrong_key = f"{keysmith.PREFIX}wrongkey123"
    assert keysmith.verify_key(wrong_key, key.hash, key.salt) is False


def test_verify_legacy_key():
    keysmith = APIKeySmith()
    legacy_key = f"{keysmith.PREFIX}legacykey123"
    legacy_hash = hashlib.sha256(legacy_key.encode()).hexdigest()

    # Test legacy key validates without salt
    assert keysmith.verify_key(legacy_key, legacy_hash) is True

    # Test wrong legacy key fails
    wrong_key = f"{keysmith.PREFIX}wronglegacy"
    assert keysmith.verify_key(wrong_key, legacy_hash) is False


def test_rehash_existing_key():
    keysmith = APIKeySmith()
    legacy_key = f"{keysmith.PREFIX}migratekey123"

    # Migrate the legacy key
    new_hash, new_salt = keysmith.hash_key(legacy_key)

    # Verify migrated key works
    assert keysmith.verify_key(legacy_key, new_hash, new_salt) is True

    # Verify different key fails with migrated hash
    wrong_key = f"{keysmith.PREFIX}wrongkey"
    assert keysmith.verify_key(wrong_key, new_hash, new_salt) is False


def test_invalid_key_prefix():
    keysmith = APIKeySmith()
    key = keysmith.generate_key()

    # Test key without proper prefix fails
    invalid_key = "invalid_prefix_key"
    assert keysmith.verify_key(invalid_key, key.hash, key.salt) is False


def test_secure_hash_requires_salt():
    keysmith = APIKeySmith()
    key = keysmith.generate_key()

    # Secure hash without salt should fail
    assert keysmith.verify_key(key.key, key.hash) is False


def test_invalid_salt_format():
    keysmith = APIKeySmith()
    key = keysmith.generate_key()

    # Invalid salt format should fail gracefully
    assert keysmith.verify_key(key.key, key.hash, "invalid_hex") is False
@@ -1,10 +1,5 @@
from .config import verify_settings
from .dependencies import (
    get_optional_user_id,
    get_user_id,
    requires_admin_user,
    requires_user,
)
from .dependencies import get_user_id, requires_admin_user, requires_user
from .helpers import add_auth_responses_to_openapi
from .models import User

@@ -13,7 +8,6 @@ __all__ = [
    "get_user_id",
    "requires_admin_user",
    "requires_user",
    "get_optional_user_id",
    "add_auth_responses_to_openapi",
    "User",
]
@@ -4,55 +4,13 @@ FastAPI dependency functions for JWT-based authentication and authorization.
These are the high-level dependency functions used in route definitions.
"""

import logging

import fastapi
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from .jwt_utils import get_jwt_payload, verify_user
from .models import User

optional_bearer = HTTPBearer(auto_error=False)

# Header name for admin impersonation
IMPERSONATION_HEADER_NAME = "X-Act-As-User-Id"

logger = logging.getLogger(__name__)


def get_optional_user_id(
    credentials: HTTPAuthorizationCredentials | None = fastapi.Security(
        optional_bearer
    ),
) -> str | None:
    """
    Attempts to extract the user ID ("sub" claim) from a Bearer JWT if provided.

    This dependency allows for both authenticated and anonymous access. If a valid
    bearer token is supplied, it parses the JWT and extracts the user ID. If the token
    is missing or invalid, it returns None, treating the request as anonymous.

    Args:
        credentials: Optional HTTPAuthorizationCredentials object from FastAPI Security dependency.

    Returns:
        The user ID (str) extracted from the JWT "sub" claim, or None if no valid token is present.
    """
    if not credentials:
        return None

    try:
        # Parse JWT token to get user ID
        from autogpt_libs.auth.jwt_utils import parse_jwt_token

        payload = parse_jwt_token(credentials.credentials)
        return payload.get("sub")
    except Exception as e:
        logger.debug(f"Auth token validation failed (anonymous access): {e}")
        return None


async def requires_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
def requires_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
    """
    FastAPI dependency that requires a valid authenticated user.

@@ -62,9 +20,7 @@ async def requires_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -
    return verify_user(jwt_payload, admin_only=False)


async def requires_admin_user(
    jwt_payload: dict = fastapi.Security(get_jwt_payload),
) -> User:
def requires_admin_user(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> User:
    """
    FastAPI dependency that requires a valid admin user.

@@ -74,44 +30,16 @@ async def requires_admin_user(
    return verify_user(jwt_payload, admin_only=True)


async def get_user_id(
    request: fastapi.Request, jwt_payload: dict = fastapi.Security(get_jwt_payload)
) -> str:
def get_user_id(jwt_payload: dict = fastapi.Security(get_jwt_payload)) -> str:
    """
    FastAPI dependency that returns the ID of the authenticated user.

    Supports admin impersonation via X-Act-As-User-Id header:
    - If the header is present and user is admin, returns the impersonated user ID
    - Otherwise returns the authenticated user's own ID
    - Logs all impersonation actions for audit trail

    Raises:
        HTTPException: 401 for authentication failures or missing user ID
        HTTPException: 403 if non-admin tries to use impersonation
    """
    # Get the authenticated user's ID from JWT
    user_id = jwt_payload.get("sub")
    if not user_id:
        raise fastapi.HTTPException(
            status_code=401, detail="User ID not found in token"
        )

    # Check for admin impersonation header
    impersonate_header = request.headers.get(IMPERSONATION_HEADER_NAME, "").strip()
    if impersonate_header:
        # Verify the authenticated user is an admin
        authenticated_user = verify_user(jwt_payload, admin_only=False)
        if authenticated_user.role != "admin":
            raise fastapi.HTTPException(
                status_code=403, detail="Only admin users can impersonate other users"
            )

        # Log the impersonation for audit trail
        logger.info(
            f"Admin impersonation: {authenticated_user.user_id} ({authenticated_user.email}) "
            f"acting as user {impersonate_header} for requesting {request.method} {request.url}"
        )

        return impersonate_header

    return user_id
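For context, a minimal sketch (not part of the diff) of how these dependencies are typically wired into routes; the app and endpoint names here are made up.

```python
import fastapi

from autogpt_libs.auth.dependencies import get_user_id, requires_admin_user

app = fastapi.FastAPI()


@app.get("/me")
def read_own_profile(user_id: str = fastapi.Depends(get_user_id)):
    # user_id is the JWT "sub" claim (or, in the pre-change version above,
    # the impersonated ID when an admin sends X-Act-As-User-Id).
    return {"user_id": user_id}


@app.get("/admin/stats")
def read_admin_stats(admin=fastapi.Depends(requires_admin_user)):
    return {"admin_id": admin.user_id}
```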
@@ -4,10 +4,9 @@ Tests the full authentication flow from HTTP requests to user validation.
|
||||
"""
|
||||
|
||||
import os
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, HTTPException, Request, Security
|
||||
from fastapi import FastAPI, HTTPException, Security
|
||||
from fastapi.testclient import TestClient
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
@@ -46,8 +45,7 @@ class TestAuthDependencies:
|
||||
"""Create a test client."""
|
||||
return TestClient(app)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_requires_user_with_valid_jwt_payload(self, mocker: MockerFixture):
|
||||
def test_requires_user_with_valid_jwt_payload(self, mocker: MockerFixture):
|
||||
"""Test requires_user with valid JWT payload."""
|
||||
jwt_payload = {"sub": "user-123", "role": "user", "email": "user@example.com"}
|
||||
|
||||
@@ -55,13 +53,12 @@ class TestAuthDependencies:
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
user = await requires_user(jwt_payload)
|
||||
user = requires_user(jwt_payload)
|
||||
assert isinstance(user, User)
|
||||
assert user.user_id == "user-123"
|
||||
assert user.role == "user"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_requires_user_with_admin_jwt_payload(self, mocker: MockerFixture):
|
||||
def test_requires_user_with_admin_jwt_payload(self, mocker: MockerFixture):
|
||||
"""Test requires_user accepts admin users."""
|
||||
jwt_payload = {
|
||||
"sub": "admin-456",
|
||||
@@ -72,31 +69,28 @@ class TestAuthDependencies:
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
user = await requires_user(jwt_payload)
|
||||
user = requires_user(jwt_payload)
|
||||
assert user.user_id == "admin-456"
|
||||
assert user.role == "admin"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_requires_user_missing_sub(self):
|
||||
def test_requires_user_missing_sub(self):
|
||||
"""Test requires_user with missing user ID."""
|
||||
jwt_payload = {"role": "user", "email": "user@example.com"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await requires_user(jwt_payload)
|
||||
requires_user(jwt_payload)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "User ID not found" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_requires_user_empty_sub(self):
|
||||
def test_requires_user_empty_sub(self):
|
||||
"""Test requires_user with empty user ID."""
|
||||
jwt_payload = {"sub": "", "role": "user"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await requires_user(jwt_payload)
|
||||
requires_user(jwt_payload)
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_requires_admin_user_with_admin(self, mocker: MockerFixture):
|
||||
def test_requires_admin_user_with_admin(self, mocker: MockerFixture):
|
||||
"""Test requires_admin_user with admin role."""
|
||||
jwt_payload = {
|
||||
"sub": "admin-789",
|
||||
@@ -107,62 +101,51 @@ class TestAuthDependencies:
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
user = await requires_admin_user(jwt_payload)
|
||||
user = requires_admin_user(jwt_payload)
|
||||
assert user.user_id == "admin-789"
|
||||
assert user.role == "admin"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_requires_admin_user_with_regular_user(self):
|
||||
def test_requires_admin_user_with_regular_user(self):
|
||||
"""Test requires_admin_user rejects regular users."""
|
||||
jwt_payload = {"sub": "user-123", "role": "user", "email": "user@example.com"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await requires_admin_user(jwt_payload)
|
||||
requires_admin_user(jwt_payload)
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "Admin access required" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_requires_admin_user_missing_role(self):
|
||||
def test_requires_admin_user_missing_role(self):
|
||||
"""Test requires_admin_user with missing role."""
|
||||
jwt_payload = {"sub": "user-123", "email": "user@example.com"}
|
||||
|
||||
with pytest.raises(KeyError):
|
||||
await requires_admin_user(jwt_payload)
|
||||
requires_admin_user(jwt_payload)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_id_with_valid_payload(self, mocker: MockerFixture):
|
||||
def test_get_user_id_with_valid_payload(self, mocker: MockerFixture):
|
||||
"""Test get_user_id extracts user ID correctly."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {}
|
||||
jwt_payload = {"sub": "user-id-xyz", "role": "user"}
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
user_id = await get_user_id(request, jwt_payload)
|
||||
user_id = get_user_id(jwt_payload)
|
||||
assert user_id == "user-id-xyz"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_id_missing_sub(self):
|
||||
def test_get_user_id_missing_sub(self):
|
||||
"""Test get_user_id with missing user ID."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {}
|
||||
jwt_payload = {"role": "user"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_user_id(request, jwt_payload)
|
||||
get_user_id(jwt_payload)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "User ID not found" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_id_none_sub(self):
|
||||
def test_get_user_id_none_sub(self):
|
||||
"""Test get_user_id with None user ID."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {}
|
||||
jwt_payload = {"sub": None, "role": "user"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_user_id(request, jwt_payload)
|
||||
get_user_id(jwt_payload)
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
@@ -187,8 +170,7 @@ class TestAuthDependenciesIntegration:
|
||||
|
||||
return _create_token
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_endpoint_auth_enabled_no_token(self):
|
||||
def test_endpoint_auth_enabled_no_token(self):
|
||||
"""Test endpoints require token when auth is enabled."""
|
||||
app = FastAPI()
|
||||
|
||||
@@ -202,8 +184,7 @@ class TestAuthDependenciesIntegration:
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_endpoint_with_valid_token(self, create_token):
|
||||
def test_endpoint_with_valid_token(self, create_token):
|
||||
"""Test endpoint with valid JWT token."""
|
||||
app = FastAPI()
|
||||
|
||||
@@ -222,8 +203,7 @@ class TestAuthDependenciesIntegration:
|
||||
assert response.status_code == 200
|
||||
assert response.json()["user_id"] == "test-user"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_endpoint_requires_admin_role(self, create_token):
|
||||
def test_admin_endpoint_requires_admin_role(self, create_token):
|
||||
"""Test admin endpoint rejects non-admin users."""
|
||||
app = FastAPI()
|
||||
|
||||
@@ -260,8 +240,7 @@ class TestAuthDependenciesIntegration:
|
||||
class TestAuthDependenciesEdgeCases:
|
||||
"""Edge case tests for authentication dependencies."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dependency_with_complex_payload(self):
|
||||
def test_dependency_with_complex_payload(self):
|
||||
"""Test dependencies handle complex JWT payloads."""
|
||||
complex_payload = {
|
||||
"sub": "user-123",
|
||||
@@ -277,15 +256,14 @@ class TestAuthDependenciesEdgeCases:
|
||||
"exp": 9999999999,
|
||||
}
|
||||
|
||||
user = await requires_user(complex_payload)
|
||||
user = requires_user(complex_payload)
|
||||
assert user.user_id == "user-123"
|
||||
assert user.email == "test@example.com"
|
||||
|
||||
admin = await requires_admin_user(complex_payload)
|
||||
admin = requires_admin_user(complex_payload)
|
||||
assert admin.role == "admin"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dependency_with_unicode_in_payload(self):
|
||||
def test_dependency_with_unicode_in_payload(self):
|
||||
"""Test dependencies handle unicode in JWT payloads."""
|
||||
unicode_payload = {
|
||||
"sub": "user-😀-123",
|
||||
@@ -294,12 +272,11 @@ class TestAuthDependenciesEdgeCases:
|
||||
"name": "日本語",
|
||||
}
|
||||
|
||||
user = await requires_user(unicode_payload)
|
||||
user = requires_user(unicode_payload)
|
||||
assert "😀" in user.user_id
|
||||
assert user.email == "测试@example.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dependency_with_null_values(self):
|
||||
def test_dependency_with_null_values(self):
|
||||
"""Test dependencies handle null values in payload."""
|
||||
null_payload = {
|
||||
"sub": "user-123",
|
||||
@@ -309,19 +286,18 @@ class TestAuthDependenciesEdgeCases:
|
||||
"metadata": None,
|
||||
}
|
||||
|
||||
user = await requires_user(null_payload)
|
||||
user = requires_user(null_payload)
|
||||
assert user.user_id == "user-123"
|
||||
assert user.email is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_requests_isolation(self):
|
||||
def test_concurrent_requests_isolation(self):
|
||||
"""Test that concurrent requests don't interfere with each other."""
|
||||
payload1 = {"sub": "user-1", "role": "user"}
|
||||
payload2 = {"sub": "user-2", "role": "admin"}
|
||||
|
||||
# Simulate concurrent processing
|
||||
user1 = await requires_user(payload1)
|
||||
user2 = await requires_admin_user(payload2)
|
||||
user1 = requires_user(payload1)
|
||||
user2 = requires_admin_user(payload2)
|
||||
|
||||
assert user1.user_id == "user-1"
|
||||
assert user2.user_id == "user-2"
|
||||
@@ -338,8 +314,7 @@ class TestAuthDependenciesEdgeCases:
|
||||
({"sub": "user", "role": "user"}, "Admin access required", True),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_dependency_error_cases(
|
||||
def test_dependency_error_cases(
|
||||
self, payload, expected_error: str, admin_only: bool
|
||||
):
|
||||
"""Test that errors propagate correctly through dependencies."""
|
||||
@@ -350,8 +325,7 @@ class TestAuthDependenciesEdgeCases:
|
||||
verify_user(payload, admin_only=admin_only)
|
||||
assert expected_error in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dependency_valid_user(self):
|
||||
def test_dependency_valid_user(self):
|
||||
"""Test valid user case for dependency."""
|
||||
# Import verify_user to test it directly since dependencies use FastAPI Security
|
||||
from autogpt_libs.auth.jwt_utils import verify_user
|
||||
@@ -359,196 +333,3 @@ class TestAuthDependenciesEdgeCases:
|
||||
# Valid case
|
||||
user = verify_user({"sub": "user", "role": "user"}, admin_only=False)
|
||||
assert user.user_id == "user"
|
||||
|
||||
|
||||
class TestAdminImpersonation:
|
||||
"""Test suite for admin user impersonation functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_impersonation_success(self, mocker: MockerFixture):
|
||||
"""Test admin successfully impersonating another user."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"X-Act-As-User-Id": "target-user-123"}
|
||||
jwt_payload = {
|
||||
"sub": "admin-456",
|
||||
"role": "admin",
|
||||
"email": "admin@example.com",
|
||||
}
|
||||
|
||||
# Mock verify_user to return admin user data
|
||||
mock_verify_user = mocker.patch("autogpt_libs.auth.dependencies.verify_user")
|
||||
mock_verify_user.return_value = Mock(
|
||||
user_id="admin-456", email="admin@example.com", role="admin"
|
||||
)
|
||||
|
||||
# Mock logger to verify audit logging
|
||||
mock_logger = mocker.patch("autogpt_libs.auth.dependencies.logger")
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
|
||||
user_id = await get_user_id(request, jwt_payload)
|
||||
|
||||
# Should return the impersonated user ID
|
||||
assert user_id == "target-user-123"
|
||||
|
||||
# Should log the impersonation attempt
|
||||
mock_logger.info.assert_called_once()
|
||||
log_call = mock_logger.info.call_args[0][0]
|
||||
assert "Admin impersonation:" in log_call
|
||||
assert "admin@example.com" in log_call
|
||||
assert "target-user-123" in log_call
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_admin_impersonation_attempt(self, mocker: MockerFixture):
|
||||
"""Test non-admin user attempting impersonation returns 403."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"X-Act-As-User-Id": "target-user-123"}
|
||||
jwt_payload = {
|
||||
"sub": "regular-user",
|
||||
"role": "user",
|
||||
"email": "user@example.com",
|
||||
}
|
||||
|
||||
# Mock verify_user to return regular user data
|
||||
mock_verify_user = mocker.patch("autogpt_libs.auth.dependencies.verify_user")
|
||||
mock_verify_user.return_value = Mock(
|
||||
user_id="regular-user", email="user@example.com", role="user"
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_user_id(request, jwt_payload)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "Only admin users can impersonate other users" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_impersonation_empty_header(self, mocker: MockerFixture):
|
||||
"""Test impersonation with empty header falls back to regular user ID."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"X-Act-As-User-Id": ""}
|
||||
jwt_payload = {
|
||||
"sub": "admin-456",
|
||||
"role": "admin",
|
||||
"email": "admin@example.com",
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
|
||||
user_id = await get_user_id(request, jwt_payload)
|
||||
|
||||
# Should fall back to the admin's own user ID
|
||||
assert user_id == "admin-456"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_impersonation_missing_header(self, mocker: MockerFixture):
|
||||
"""Test normal behavior when impersonation header is missing."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {} # No impersonation header
|
||||
jwt_payload = {
|
||||
"sub": "admin-456",
|
||||
"role": "admin",
|
||||
"email": "admin@example.com",
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
|
||||
user_id = await get_user_id(request, jwt_payload)
|
||||
|
||||
# Should return the admin's own user ID
|
||||
assert user_id == "admin-456"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_impersonation_audit_logging_details(self, mocker: MockerFixture):
|
||||
"""Test that impersonation audit logging includes all required details."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"X-Act-As-User-Id": "victim-user-789"}
|
||||
jwt_payload = {
|
||||
"sub": "admin-999",
|
||||
"role": "admin",
|
||||
"email": "superadmin@company.com",
|
||||
}
|
||||
|
||||
# Mock verify_user to return admin user data
|
||||
mock_verify_user = mocker.patch("autogpt_libs.auth.dependencies.verify_user")
|
||||
mock_verify_user.return_value = Mock(
|
||||
user_id="admin-999", email="superadmin@company.com", role="admin"
|
||||
)
|
||||
|
||||
# Mock logger to capture audit trail
|
||||
mock_logger = mocker.patch("autogpt_libs.auth.dependencies.logger")
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
|
||||
user_id = await get_user_id(request, jwt_payload)
|
||||
|
||||
# Verify all audit details are logged
|
||||
assert user_id == "victim-user-789"
|
||||
mock_logger.info.assert_called_once()
|
||||
|
||||
log_message = mock_logger.info.call_args[0][0]
|
||||
assert "Admin impersonation:" in log_message
|
||||
assert "superadmin@company.com" in log_message
|
||||
assert "victim-user-789" in log_message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_impersonation_header_case_sensitivity(self, mocker: MockerFixture):
|
||||
"""Test that impersonation header is case-sensitive."""
|
||||
request = Mock(spec=Request)
|
||||
# Use wrong case - should not trigger impersonation
|
||||
request.headers = {"x-act-as-user-id": "target-user-123"}
|
||||
jwt_payload = {
|
||||
"sub": "admin-456",
|
||||
"role": "admin",
|
||||
"email": "admin@example.com",
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
|
||||
user_id = await get_user_id(request, jwt_payload)
|
||||
|
||||
# Should fall back to admin's own ID (header case mismatch)
|
||||
assert user_id == "admin-456"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_impersonation_with_whitespace_header(self, mocker: MockerFixture):
|
||||
"""Test impersonation with whitespace in header value."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"X-Act-As-User-Id": " target-user-123 "}
|
||||
jwt_payload = {
|
||||
"sub": "admin-456",
|
||||
"role": "admin",
|
||||
"email": "admin@example.com",
|
||||
}
|
||||
|
||||
# Mock verify_user to return admin user data
|
||||
mock_verify_user = mocker.patch("autogpt_libs.auth.dependencies.verify_user")
|
||||
mock_verify_user.return_value = Mock(
|
||||
user_id="admin-456", email="admin@example.com", role="admin"
|
||||
)
|
||||
|
||||
# Mock logger
|
||||
mock_logger = mocker.patch("autogpt_libs.auth.dependencies.logger")
|
||||
|
||||
mocker.patch(
|
||||
"autogpt_libs.auth.dependencies.get_jwt_payload", return_value=jwt_payload
|
||||
)
|
||||
|
||||
user_id = await get_user_id(request, jwt_payload)
|
||||
|
||||
# Should strip whitespace and impersonate successfully
|
||||
assert user_id == "target-user-123"
|
||||
mock_logger.info.assert_called_once()
|
||||
|
||||
@@ -16,7 +16,7 @@ bearer_jwt_auth = HTTPBearer(
)


async def get_jwt_payload(
def get_jwt_payload(
    credentials: HTTPAuthorizationCredentials | None = Security(bearer_jwt_auth),
) -> dict[str, Any]:
    """
@@ -116,32 +116,32 @@ def test_parse_jwt_token_missing_audience():
    assert "Invalid token" in str(exc_info.value)


async def test_get_jwt_payload_with_valid_token():
def test_get_jwt_payload_with_valid_token():
    """Test extracting JWT payload with valid bearer token."""
    token = create_token(TEST_USER_PAYLOAD)
    credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)

    result = await jwt_utils.get_jwt_payload(credentials)
    result = jwt_utils.get_jwt_payload(credentials)
    assert result["sub"] == "test-user-id"
    assert result["role"] == "user"


async def test_get_jwt_payload_no_credentials():
def test_get_jwt_payload_no_credentials():
    """Test JWT payload when no credentials provided."""
    with pytest.raises(HTTPException) as exc_info:
        await jwt_utils.get_jwt_payload(None)
        jwt_utils.get_jwt_payload(None)
    assert exc_info.value.status_code == 401
    assert "Authorization header is missing" in exc_info.value.detail


async def test_get_jwt_payload_invalid_token():
def test_get_jwt_payload_invalid_token():
    """Test JWT payload extraction with invalid token."""
    credentials = HTTPAuthorizationCredentials(
        scheme="Bearer", credentials="invalid.token.here"
    )

    with pytest.raises(HTTPException) as exc_info:
        await jwt_utils.get_jwt_payload(credentials)
        jwt_utils.get_jwt_payload(credentials)
    assert exc_info.value.status_code == 401
    assert "Invalid token" in exc_info.value.detail
@@ -4,7 +4,6 @@ import logging
import os
import socket
import sys
from logging.handlers import RotatingFileHandler
from pathlib import Path

from pydantic import Field, field_validator
@@ -94,36 +93,42 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
    config = LoggingConfig()
    log_handlers: list[logging.Handler] = []

    structured_logging = config.enable_cloud_logging or force_cloud_logging

    # Console output handlers
    if not structured_logging:
        stdout = logging.StreamHandler(stream=sys.stdout)
        stdout.setLevel(config.level)
        stdout.addFilter(BelowLevelFilter(logging.WARNING))
        if config.level == logging.DEBUG:
            stdout.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
        else:
            stdout.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
    stdout = logging.StreamHandler(stream=sys.stdout)
    stdout.setLevel(config.level)
    stdout.addFilter(BelowLevelFilter(logging.WARNING))
    if config.level == logging.DEBUG:
        stdout.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
    else:
        stdout.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))

        stderr = logging.StreamHandler()
        stderr.setLevel(logging.WARNING)
        if config.level == logging.DEBUG:
            stderr.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
        else:
            stderr.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
    stderr = logging.StreamHandler()
    stderr.setLevel(logging.WARNING)
    if config.level == logging.DEBUG:
        stderr.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
    else:
        stderr.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))

        log_handlers += [stdout, stderr]
    log_handlers += [stdout, stderr]

    # Cloud logging setup
    else:
        # Use Google Cloud Structured Log Handler. Log entries are printed to stdout
        # in a JSON format which is automatically picked up by Google Cloud Logging.
        from google.cloud.logging.handlers import StructuredLogHandler
    if config.enable_cloud_logging or force_cloud_logging:
        import google.cloud.logging
        from google.cloud.logging.handlers import CloudLoggingHandler
        from google.cloud.logging_v2.handlers.transports import (
            BackgroundThreadTransport,
        )

        structured_log_handler = StructuredLogHandler(stream=sys.stdout)
        structured_log_handler.setLevel(config.level)
        log_handlers.append(structured_log_handler)
        client = google.cloud.logging.Client()
        # Use BackgroundThreadTransport to prevent blocking the main thread
        # and deadlocks when gRPC calls to Google Cloud Logging hang
        cloud_handler = CloudLoggingHandler(
            client,
            name="autogpt_logs",
            transport=BackgroundThreadTransport,
        )
        cloud_handler.setLevel(config.level)
        log_handlers.append(cloud_handler)

    # File logging setup
    if config.enable_file_logging:
@@ -134,13 +139,8 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
        print(f"Log directory: {config.log_dir}")

        # Activity log handler (INFO and above)
        # Security fix: Use RotatingFileHandler with size limits to prevent disk exhaustion
        activity_log_handler = RotatingFileHandler(
            config.log_dir / LOG_FILE,
            mode="a",
            encoding="utf-8",
            maxBytes=10 * 1024 * 1024,  # 10MB per file
            backupCount=3,  # Keep 3 backup files (40MB total)
        activity_log_handler = logging.FileHandler(
            config.log_dir / LOG_FILE, "a", "utf-8"
        )
        activity_log_handler.setLevel(config.level)
        activity_log_handler.setFormatter(
@@ -150,13 +150,8 @@ def configure_logging(force_cloud_logging: bool = False) -> None:

        if config.level == logging.DEBUG:
            # Debug log handler (all levels)
            # Security fix: Use RotatingFileHandler with size limits
            debug_log_handler = RotatingFileHandler(
                config.log_dir / DEBUG_LOG_FILE,
                mode="a",
                encoding="utf-8",
                maxBytes=10 * 1024 * 1024,  # 10MB per file
                backupCount=3,  # Keep 3 backup files (40MB total)
            debug_log_handler = logging.FileHandler(
                config.log_dir / DEBUG_LOG_FILE, "a", "utf-8"
            )
            debug_log_handler.setLevel(logging.DEBUG)
            debug_log_handler.setFormatter(
@@ -165,13 +160,8 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
            log_handlers.append(debug_log_handler)

        # Error log handler (ERROR and above)
        # Security fix: Use RotatingFileHandler with size limits
        error_log_handler = RotatingFileHandler(
            config.log_dir / ERROR_LOG_FILE,
            mode="a",
            encoding="utf-8",
            maxBytes=10 * 1024 * 1024,  # 10MB per file
            backupCount=3,  # Keep 3 backup files (40MB total)
        error_log_handler = logging.FileHandler(
            config.log_dir / ERROR_LOG_FILE, "a", "utf-8"
        )
        error_log_handler.setLevel(logging.ERROR)
        error_log_handler.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT, no_color=True))
@@ -179,13 +169,7 @@ def configure_logging(force_cloud_logging: bool = False) -> None:

    # Configure the root logger
    logging.basicConfig(
        format=(
            "%(levelname)s %(message)s"
            if structured_logging
            else (
                DEBUG_LOG_FORMAT if config.level == logging.DEBUG else SIMPLE_LOG_FORMAT
            )
        ),
        format=DEBUG_LOG_FORMAT if config.level == logging.DEBUG else SIMPLE_LOG_FORMAT,
        level=config.level,
        handlers=log_handlers,
    )
@@ -1,5 +1,3 @@
from typing import Optional

from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict

@@ -15,8 +13,8 @@ class RateLimitSettings(BaseSettings):
        default="6379", description="Redis port", validation_alias="REDIS_PORT"
    )

    redis_password: Optional[str] = Field(
        default=None,
    redis_password: str = Field(
        default="password",
        description="Redis password",
        validation_alias="REDIS_PASSWORD",
    )

@@ -11,7 +11,7 @@ class RateLimiter:
        self,
        redis_host: str = RATE_LIMIT_SETTINGS.redis_host,
        redis_port: str = RATE_LIMIT_SETTINGS.redis_port,
        redis_password: str | None = RATE_LIMIT_SETTINGS.redis_password,
        redis_password: str = RATE_LIMIT_SETTINGS.redis_password,
        requests_per_minute: int = RATE_LIMIT_SETTINGS.requests_per_minute,
    ):
        self.redis = Redis(
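The `RateLimiter` hunk above is truncated mid-constructor. Purely as an illustration of how such settings are usually handed to redis-py (this is not the actual continuation of the diff, and the keyword names below are assumptions):

```python
from redis import Redis

# Hypothetical illustration only; the field names come from RateLimitSettings above.
redis_client = Redis(
    host=RATE_LIMIT_SETTINGS.redis_host,
    port=int(RATE_LIMIT_SETTINGS.redis_port),  # the setting stores the port as a string
    password=RATE_LIMIT_SETTINGS.redis_password,
)
```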
266	autogpt_platform/autogpt_libs/autogpt_libs/utils/cache.py	Normal file
@@ -0,0 +1,266 @@
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from functools import wraps
|
||||
from typing import (
|
||||
Awaitable,
|
||||
Callable,
|
||||
ParamSpec,
|
||||
Protocol,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
cast,
|
||||
overload,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@overload
|
||||
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
||||
pass
|
||||
|
||||
|
||||
@overload
|
||||
def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
|
||||
pass
|
||||
|
||||
|
||||
def thread_cached(
|
||||
func: Callable[P, R] | Callable[P, Awaitable[R]],
|
||||
) -> Callable[P, R] | Callable[P, Awaitable[R]]:
|
||||
thread_local = threading.local()
|
||||
|
||||
def _clear():
|
||||
if hasattr(thread_local, "cache"):
|
||||
del thread_local.cache
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
|
||||
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
cache = getattr(thread_local, "cache", None)
|
||||
if cache is None:
|
||||
cache = thread_local.cache = {}
|
||||
key = (args, tuple(sorted(kwargs.items())))
|
||||
if key not in cache:
|
||||
cache[key] = await cast(Callable[P, Awaitable[R]], func)(
|
||||
*args, **kwargs
|
||||
)
|
||||
return cache[key]
|
||||
|
||||
setattr(async_wrapper, "clear_cache", _clear)
|
||||
return async_wrapper
|
||||
|
||||
else:
|
||||
|
||||
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
cache = getattr(thread_local, "cache", None)
|
||||
if cache is None:
|
||||
cache = thread_local.cache = {}
|
||||
key = (args, tuple(sorted(kwargs.items())))
|
||||
if key not in cache:
|
||||
cache[key] = func(*args, **kwargs)
|
||||
return cache[key]
|
||||
|
||||
setattr(sync_wrapper, "clear_cache", _clear)
|
||||
return sync_wrapper
|
||||
|
||||
|
||||
def clear_thread_cache(func: Callable) -> None:
|
||||
if clear := getattr(func, "clear_cache", None):
|
||||
clear()
|
||||
|
||||
|
||||
FuncT = TypeVar("FuncT")
|
||||
|
||||
|
||||
R_co = TypeVar("R_co", covariant=True)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class AsyncCachedFunction(Protocol[P, R_co]):
|
||||
"""Protocol for async functions with cache management methods."""
|
||||
|
||||
def cache_clear(self) -> None:
|
||||
"""Clear all cached entries."""
|
||||
return None
|
||||
|
||||
def cache_info(self) -> dict[str, int | None]:
|
||||
"""Get cache statistics."""
|
||||
return {}
|
||||
|
||||
async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co:
|
||||
"""Call the cached function."""
|
||||
return None # type: ignore
|
||||
|
||||
|
||||
def async_ttl_cache(
|
||||
maxsize: int = 128, ttl_seconds: int | None = None
|
||||
) -> Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]:
|
||||
"""
|
||||
TTL (Time To Live) cache decorator for async functions.
|
||||
|
||||
Similar to functools.lru_cache but works with async functions and includes optional TTL.
|
||||
|
||||
Args:
|
||||
maxsize: Maximum number of cached entries
|
||||
ttl_seconds: Time to live in seconds. If None, entries never expire (like lru_cache)
|
||||
|
||||
Returns:
|
||||
Decorator function
|
||||
|
||||
Example:
|
||||
# With TTL
|
||||
@async_ttl_cache(maxsize=1000, ttl_seconds=300)
|
||||
async def api_call(param: str) -> dict:
|
||||
return {"result": param}
|
||||
|
||||
# Without TTL (permanent cache like lru_cache)
|
||||
@async_ttl_cache(maxsize=1000)
|
||||
async def expensive_computation(param: str) -> dict:
|
||||
return {"result": param}
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
async_func: Callable[P, Awaitable[R]],
|
||||
) -> AsyncCachedFunction[P, R]:
|
||||
# Cache storage - use union type to handle both cases
|
||||
cache_storage: dict[tuple, R | Tuple[R, float]] = {}
|
||||
|
||||
@wraps(async_func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
# Create cache key from arguments
|
||||
key = (args, tuple(sorted(kwargs.items())))
|
||||
current_time = time.time()
|
||||
|
||||
# Check if we have a valid cached entry
|
||||
if key in cache_storage:
|
||||
if ttl_seconds is None:
|
||||
# No TTL - return cached result directly
|
||||
logger.debug(
|
||||
f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
|
||||
)
|
||||
return cast(R, cache_storage[key])
|
||||
else:
|
||||
# With TTL - check expiration
|
||||
cached_data = cache_storage[key]
|
||||
if isinstance(cached_data, tuple):
|
||||
result, timestamp = cached_data
|
||||
if current_time - timestamp < ttl_seconds:
|
||||
logger.debug(
|
||||
f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
|
||||
)
|
||||
return cast(R, result)
|
||||
else:
|
||||
# Expired entry
|
||||
del cache_storage[key]
|
||||
logger.debug(
|
||||
f"Cache entry expired for {async_func.__name__}"
|
||||
)
|
||||
|
||||
# Cache miss or expired - fetch fresh data
|
||||
logger.debug(
|
||||
f"Cache miss for {async_func.__name__} with key: {str(key)[:50]}"
|
||||
)
|
||||
result = await async_func(*args, **kwargs)
|
||||
|
||||
# Store in cache
|
||||
if ttl_seconds is None:
|
||||
cache_storage[key] = result
|
||||
else:
|
||||
cache_storage[key] = (result, current_time)
|
||||
|
||||
# Simple cleanup when cache gets too large
|
||||
if len(cache_storage) > maxsize:
|
||||
# Remove oldest entries (simple FIFO cleanup)
|
||||
cutoff = maxsize // 2
|
||||
oldest_keys = list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
|
||||
for old_key in oldest_keys:
|
||||
cache_storage.pop(old_key, None)
|
||||
logger.debug(
|
||||
f"Cache cleanup: removed {len(oldest_keys)} entries for {async_func.__name__}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
# Add cache management methods (similar to functools.lru_cache)
|
||||
def cache_clear() -> None:
|
||||
cache_storage.clear()
|
||||
|
||||
def cache_info() -> dict[str, int | None]:
|
||||
return {
|
||||
"size": len(cache_storage),
|
||||
"maxsize": maxsize,
|
||||
"ttl_seconds": ttl_seconds,
|
||||
}
|
||||
|
||||
# Attach methods to wrapper
|
||||
setattr(wrapper, "cache_clear", cache_clear)
|
||||
setattr(wrapper, "cache_info", cache_info)
|
||||
|
||||
return cast(AsyncCachedFunction[P, R], wrapper)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@overload
|
||||
def async_cache(
|
||||
func: Callable[P, Awaitable[R]],
|
||||
) -> AsyncCachedFunction[P, R]:
|
||||
pass
|
||||
|
||||
|
||||
@overload
|
||||
def async_cache(
|
||||
func: None = None,
|
||||
*,
|
||||
maxsize: int = 128,
|
||||
) -> Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]:
|
||||
pass
|
||||
|
||||
|
||||
def async_cache(
|
||||
func: Callable[P, Awaitable[R]] | None = None,
|
||||
*,
|
||||
maxsize: int = 128,
|
||||
) -> (
|
||||
AsyncCachedFunction[P, R]
|
||||
| Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]
|
||||
):
|
||||
"""
|
||||
Process-level cache decorator for async functions (no TTL).
|
||||
|
||||
Similar to functools.lru_cache but works with async functions.
|
||||
This is a convenience wrapper around async_ttl_cache with ttl_seconds=None.
|
||||
|
||||
Args:
|
||||
func: The async function to cache (when used without parentheses)
|
||||
maxsize: Maximum number of cached entries
|
||||
|
||||
Returns:
|
||||
Decorated function or decorator
|
||||
|
||||
Example:
|
||||
# Without parentheses (uses default maxsize=128)
|
||||
@async_cache
|
||||
async def get_data(param: str) -> dict:
|
||||
return {"result": param}
|
||||
|
||||
# With parentheses and custom maxsize
|
||||
@async_cache(maxsize=1000)
|
||||
async def expensive_computation(param: str) -> dict:
|
||||
# Expensive computation here
|
||||
return {"result": param}
|
||||
"""
|
||||
if func is None:
|
||||
# Called with parentheses @async_cache() or @async_cache(maxsize=...)
|
||||
return async_ttl_cache(maxsize=maxsize, ttl_seconds=None)
|
||||
else:
|
||||
# Called without parentheses @async_cache
|
||||
decorator = async_ttl_cache(maxsize=maxsize, ttl_seconds=None)
|
||||
return decorator(func)
|
||||
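As a quick illustration of the decorators added in `cache.py` above (using only the behavior shown in this diff), a cached async helper and a thread-local cached helper might be used like this; the `fetch_profile` and `load_settings` functions are made up for the example.

```python
import asyncio

from autogpt_libs.utils.cache import async_ttl_cache, thread_cached


@async_ttl_cache(maxsize=256, ttl_seconds=300)
async def fetch_profile(user_id: str) -> dict:
    # Stand-in for an expensive network or DB call.
    await asyncio.sleep(0.1)
    return {"user_id": user_id}


@thread_cached
def load_settings(env: str) -> dict:
    # Recomputed at most once per thread for each distinct argument tuple.
    return {"env": env}


async def main() -> None:
    a = await fetch_profile("user-123")  # miss: runs the coroutine
    b = await fetch_profile("user-123")  # hit: served from the TTL cache
    assert a == b
    # e.g. {"size": 1, "maxsize": 256, "ttl_seconds": 300}
    print(fetch_profile.cache_info())


asyncio.run(main())
```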
705	autogpt_platform/autogpt_libs/autogpt_libs/utils/cache_test.py	Normal file
@@ -0,0 +1,705 @@
|
||||
"""Tests for the @thread_cached decorator.
|
||||
|
||||
This module tests the thread-local caching functionality including:
|
||||
- Basic caching for sync and async functions
|
||||
- Thread isolation (each thread has its own cache)
|
||||
- Cache clearing functionality
|
||||
- Exception handling (exceptions are not cached)
|
||||
- Argument handling (positional vs keyword arguments)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from autogpt_libs.utils.cache import (
|
||||
async_cache,
|
||||
async_ttl_cache,
|
||||
clear_thread_cache,
|
||||
thread_cached,
|
||||
)
|
||||
|
||||
|
||||
class TestThreadCached:
|
||||
def test_sync_function_caching(self):
|
||||
call_count = 0
|
||||
|
||||
@thread_cached
|
||||
def expensive_function(x: int, y: int = 0) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x + y
|
||||
|
||||
assert expensive_function(1, 2) == 3
|
||||
assert call_count == 1
|
||||
|
||||
assert expensive_function(1, 2) == 3
|
||||
assert call_count == 1
|
||||
|
||||
assert expensive_function(1, y=2) == 3
|
||||
assert call_count == 2
|
||||
|
||||
assert expensive_function(2, 3) == 5
|
||||
assert call_count == 3
|
||||
|
||||
assert expensive_function(1) == 1
|
||||
assert call_count == 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_function_caching(self):
|
||||
call_count = 0
|
||||
|
||||
@thread_cached
|
||||
async def expensive_async_function(x: int, y: int = 0) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
return x + y
|
||||
|
||||
assert await expensive_async_function(1, 2) == 3
|
||||
assert call_count == 1
|
||||
|
||||
assert await expensive_async_function(1, 2) == 3
|
||||
assert call_count == 1
|
||||
|
||||
assert await expensive_async_function(1, y=2) == 3
|
||||
assert call_count == 2
|
||||
|
||||
assert await expensive_async_function(2, 3) == 5
|
||||
assert call_count == 3
|
||||
|
||||
def test_thread_isolation(self):
|
||||
call_count = 0
|
||||
results = {}
|
||||
|
||||
@thread_cached
|
||||
def thread_specific_function(x: int) -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return f"{threading.current_thread().name}-{x}"
|
||||
|
||||
def worker(thread_id: int):
|
||||
result1 = thread_specific_function(1)
|
||||
result2 = thread_specific_function(1)
|
||||
result3 = thread_specific_function(2)
|
||||
results[thread_id] = (result1, result2, result3)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=3) as executor:
|
||||
futures = [executor.submit(worker, i) for i in range(3)]
|
||||
for future in futures:
|
||||
future.result()
|
||||
|
||||
assert call_count >= 2
|
||||
|
||||
for thread_id, (r1, r2, r3) in results.items():
|
||||
assert r1 == r2
|
||||
assert r1 != r3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_thread_isolation(self):
|
||||
call_count = 0
|
||||
results = {}
|
||||
|
||||
@thread_cached
|
||||
async def async_thread_specific_function(x: int) -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
return f"{threading.current_thread().name}-{x}"
|
||||
|
||||
async def async_worker(worker_id: int):
|
||||
result1 = await async_thread_specific_function(1)
|
||||
result2 = await async_thread_specific_function(1)
|
||||
result3 = await async_thread_specific_function(2)
|
||||
results[worker_id] = (result1, result2, result3)
|
||||
|
||||
tasks = [async_worker(i) for i in range(3)]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
for worker_id, (r1, r2, r3) in results.items():
|
||||
assert r1 == r2
|
||||
assert r1 != r3
|
||||
|
||||
def test_clear_cache_sync(self):
|
||||
call_count = 0
|
||||
|
||||
@thread_cached
|
||||
def clearable_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x * 2
|
||||
|
||||
assert clearable_function(5) == 10
|
||||
assert call_count == 1
|
||||
|
||||
assert clearable_function(5) == 10
|
||||
assert call_count == 1
|
||||
|
||||
clear_thread_cache(clearable_function)
|
||||
|
||||
assert clearable_function(5) == 10
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_cache_async(self):
|
||||
call_count = 0
|
||||
|
||||
@thread_cached
|
||||
async def clearable_async_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
return x * 2
|
||||
|
||||
assert await clearable_async_function(5) == 10
|
||||
assert call_count == 1
|
||||
|
||||
assert await clearable_async_function(5) == 10
|
||||
assert call_count == 1
|
||||
|
||||
clear_thread_cache(clearable_async_function)
|
||||
|
||||
assert await clearable_async_function(5) == 10
|
||||
assert call_count == 2
|
||||
|
||||
def test_simple_arguments(self):
|
||||
call_count = 0
|
||||
|
||||
@thread_cached
|
||||
def simple_function(a: str, b: int, c: str = "default") -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return f"{a}-{b}-{c}"
|
||||
|
||||
# First call with all positional args
|
||||
result1 = simple_function("test", 42, "custom")
|
||||
assert call_count == 1
|
||||
|
||||
# Same args, all positional - should hit cache
|
||||
result2 = simple_function("test", 42, "custom")
|
||||
assert call_count == 1
|
||||
assert result1 == result2
|
||||
|
||||
# Same values but last arg as keyword - creates different cache key
|
||||
result3 = simple_function("test", 42, c="custom")
|
||||
assert call_count == 2
|
||||
assert result1 == result3 # Same result, different cache entry
|
||||
|
||||
# Different value - new cache entry
|
||||
result4 = simple_function("test", 43, "custom")
|
||||
assert call_count == 3
|
||||
assert result1 != result4
|
||||
|
||||
def test_positional_vs_keyword_args(self):
|
||||
"""Test that positional and keyword arguments create different cache entries."""
|
||||
call_count = 0
|
||||
|
||||
@thread_cached
|
||||
def func(a: int, b: int = 10) -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return f"result-{a}-{b}"
|
||||
|
||||
# All positional
|
||||
result1 = func(1, 2)
|
||||
assert call_count == 1
|
||||
assert result1 == "result-1-2"
|
||||
|
||||
# Same values, but second arg as keyword
|
||||
result2 = func(1, b=2)
|
||||
assert call_count == 2 # Different cache key!
|
||||
assert result2 == "result-1-2" # Same result
|
||||
|
||||
# Verify both are cached separately
|
||||
func(1, 2) # Uses first cache entry
|
||||
assert call_count == 2
|
||||
|
||||
func(1, b=2) # Uses second cache entry
|
||||
assert call_count == 2
|
||||
|
||||
def test_exception_handling(self):
|
||||
call_count = 0
|
||||
|
||||
@thread_cached
|
||||
def failing_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if x < 0:
|
||||
raise ValueError("Negative value")
|
            return x * 2

        assert failing_function(5) == 10
        assert call_count == 1

        with pytest.raises(ValueError):
            failing_function(-1)
        assert call_count == 2

        with pytest.raises(ValueError):
            failing_function(-1)
        assert call_count == 3

        assert failing_function(5) == 10
        assert call_count == 3

    @pytest.mark.asyncio
    async def test_async_exception_handling(self):
        call_count = 0

        @thread_cached
        async def async_failing_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)
            if x < 0:
                raise ValueError("Negative value")
            return x * 2

        assert await async_failing_function(5) == 10
        assert call_count == 1

        with pytest.raises(ValueError):
            await async_failing_function(-1)
        assert call_count == 2

        with pytest.raises(ValueError):
            await async_failing_function(-1)
        assert call_count == 3

    def test_sync_caching_performance(self):
        @thread_cached
        def slow_function(x: int) -> int:
            print(f"slow_function called with x={x}")
            time.sleep(0.1)
            return x * 2

        start = time.time()
        result1 = slow_function(5)
        first_call_time = time.time() - start
        print(f"First call took {first_call_time:.4f} seconds")

        start = time.time()
        result2 = slow_function(5)
        second_call_time = time.time() - start
        print(f"Second call took {second_call_time:.4f} seconds")

        assert result1 == result2 == 10
        assert first_call_time > 0.09
        assert second_call_time < 0.01

    @pytest.mark.asyncio
    async def test_async_caching_performance(self):
        @thread_cached
        async def slow_async_function(x: int) -> int:
            print(f"slow_async_function called with x={x}")
            await asyncio.sleep(0.1)
            return x * 2

        start = time.time()
        result1 = await slow_async_function(5)
        first_call_time = time.time() - start
        print(f"First async call took {first_call_time:.4f} seconds")

        start = time.time()
        result2 = await slow_async_function(5)
        second_call_time = time.time() - start
        print(f"Second async call took {second_call_time:.4f} seconds")

        assert result1 == result2 == 10
        assert first_call_time > 0.09
        assert second_call_time < 0.01

    def test_with_mock_objects(self):
        mock = Mock(return_value=42)

        @thread_cached
        def function_using_mock(x: int) -> int:
            return mock(x)

        assert function_using_mock(1) == 42
        assert mock.call_count == 1

        assert function_using_mock(1) == 42
        assert mock.call_count == 1

        assert function_using_mock(2) == 42
        assert mock.call_count == 2


class TestAsyncTTLCache:
    """Tests for the @async_ttl_cache decorator."""

    @pytest.mark.asyncio
    async def test_basic_caching(self):
        """Test basic caching functionality."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=60)
        async def cached_function(x: int, y: int = 0) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)  # Simulate async work
            return x + y

        # First call
        result1 = await cached_function(1, 2)
        assert result1 == 3
        assert call_count == 1

        # Second call with same args - should use cache
        result2 = await cached_function(1, 2)
        assert result2 == 3
        assert call_count == 1  # No additional call

        # Different args - should call function again
        result3 = await cached_function(2, 3)
        assert result3 == 5
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_ttl_expiration(self):
        """Test that cache entries expire after TTL."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=1)  # Short TTL
        async def short_lived_cache(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x * 2

        # First call
        result1 = await short_lived_cache(5)
        assert result1 == 10
        assert call_count == 1

        # Second call immediately - should use cache
        result2 = await short_lived_cache(5)
        assert result2 == 10
        assert call_count == 1

        # Wait for TTL to expire
        await asyncio.sleep(1.1)

        # Third call after expiration - should call function again
        result3 = await short_lived_cache(5)
        assert result3 == 10
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_cache_info(self):
        """Test cache info functionality."""

        @async_ttl_cache(maxsize=5, ttl_seconds=300)
        async def info_test_function(x: int) -> int:
            return x * 3

        # Check initial cache info
        info = info_test_function.cache_info()
        assert info["size"] == 0
        assert info["maxsize"] == 5
        assert info["ttl_seconds"] == 300

        # Add an entry
        await info_test_function(1)
        info = info_test_function.cache_info()
        assert info["size"] == 1

    @pytest.mark.asyncio
    async def test_cache_clear(self):
        """Test cache clearing functionality."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=60)
        async def clearable_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x * 4

        # First call
        result1 = await clearable_function(2)
        assert result1 == 8
        assert call_count == 1

        # Second call - should use cache
        result2 = await clearable_function(2)
        assert result2 == 8
        assert call_count == 1

        # Clear cache
        clearable_function.cache_clear()

        # Third call after clear - should call function again
        result3 = await clearable_function(2)
        assert result3 == 8
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_maxsize_cleanup(self):
        """Test that cache cleans up when maxsize is exceeded."""
        call_count = 0

        @async_ttl_cache(maxsize=3, ttl_seconds=60)
        async def size_limited_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x**2

        # Fill cache to maxsize
        await size_limited_function(1)  # call_count: 1
        await size_limited_function(2)  # call_count: 2
        await size_limited_function(3)  # call_count: 3

        info = size_limited_function.cache_info()
        assert info["size"] == 3

        # Add one more entry - should trigger cleanup
        await size_limited_function(4)  # call_count: 4

        # Cache size should be reduced (cleanup removes oldest entries)
        info = size_limited_function.cache_info()
        assert info["size"] is not None and info["size"] <= 3  # Should be cleaned up

    @pytest.mark.asyncio
    async def test_argument_variations(self):
        """Test caching with different argument patterns."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=60)
        async def arg_test_function(a: int, b: str = "default", *, c: int = 100) -> str:
            nonlocal call_count
            call_count += 1
            return f"{a}-{b}-{c}"

        # Different ways to call with same logical arguments
        result1 = await arg_test_function(1, "test", c=200)
        assert call_count == 1

        # Same arguments, same order - should use cache
        result2 = await arg_test_function(1, "test", c=200)
        assert call_count == 1
        assert result1 == result2

        # Different arguments - should call function
        result3 = await arg_test_function(2, "test", c=200)
        assert call_count == 2
        assert result1 != result3

    @pytest.mark.asyncio
    async def test_exception_handling(self):
        """Test that exceptions are not cached."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=60)
        async def exception_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            if x < 0:
                raise ValueError("Negative value not allowed")
            return x * 2

        # Successful call - should be cached
        result1 = await exception_function(5)
        assert result1 == 10
        assert call_count == 1

        # Same successful call - should use cache
        result2 = await exception_function(5)
        assert result2 == 10
        assert call_count == 1

        # Exception call - should not be cached
        with pytest.raises(ValueError):
            await exception_function(-1)
        assert call_count == 2

        # Same exception call - should call again (not cached)
        with pytest.raises(ValueError):
            await exception_function(-1)
        assert call_count == 3

    @pytest.mark.asyncio
    async def test_concurrent_calls(self):
        """Test caching behavior with concurrent calls."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=60)
        async def concurrent_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.05)  # Simulate work
            return x * x

        # Launch concurrent calls with same arguments
        tasks = [concurrent_function(3) for _ in range(5)]
        results = await asyncio.gather(*tasks)

        # All results should be the same
        assert all(result == 9 for result in results)

        # Note: Due to race conditions, call_count might be up to 5 for concurrent calls
        # This tests that the cache doesn't break under concurrent access
        assert 1 <= call_count <= 5


class TestAsyncCache:
    """Tests for the @async_cache decorator (no TTL)."""

    @pytest.mark.asyncio
    async def test_basic_caching_no_ttl(self):
        """Test basic caching functionality without TTL."""
        call_count = 0

        @async_cache(maxsize=10)
        async def cached_function(x: int, y: int = 0) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)  # Simulate async work
            return x + y

        # First call
        result1 = await cached_function(1, 2)
        assert result1 == 3
        assert call_count == 1

        # Second call with same args - should use cache
        result2 = await cached_function(1, 2)
        assert result2 == 3
        assert call_count == 1  # No additional call

        # Third call after some time - should still use cache (no TTL)
        await asyncio.sleep(0.05)
        result3 = await cached_function(1, 2)
        assert result3 == 3
        assert call_count == 1  # Still no additional call

        # Different args - should call function again
        result4 = await cached_function(2, 3)
        assert result4 == 5
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_no_ttl_vs_ttl_behavior(self):
        """Test the difference between TTL and no-TTL caching."""
        ttl_call_count = 0
        no_ttl_call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=1)  # Short TTL
        async def ttl_function(x: int) -> int:
            nonlocal ttl_call_count
            ttl_call_count += 1
            return x * 2

        @async_cache(maxsize=10)  # No TTL
        async def no_ttl_function(x: int) -> int:
            nonlocal no_ttl_call_count
            no_ttl_call_count += 1
            return x * 2

        # First calls
        await ttl_function(5)
        await no_ttl_function(5)
        assert ttl_call_count == 1
        assert no_ttl_call_count == 1

        # Wait for TTL to expire
        await asyncio.sleep(1.1)

        # Second calls after TTL expiry
        await ttl_function(5)  # Should call function again (TTL expired)
        await no_ttl_function(5)  # Should use cache (no TTL)
        assert ttl_call_count == 2  # TTL function called again
        assert no_ttl_call_count == 1  # No-TTL function still cached

    @pytest.mark.asyncio
    async def test_async_cache_info(self):
        """Test cache info for no-TTL cache."""

        @async_cache(maxsize=5)
        async def info_test_function(x: int) -> int:
            return x * 3

        # Check initial cache info
        info = info_test_function.cache_info()
        assert info["size"] == 0
        assert info["maxsize"] == 5
        assert info["ttl_seconds"] is None  # No TTL

        # Add an entry
        await info_test_function(1)
        info = info_test_function.cache_info()
        assert info["size"] == 1


class TestTTLOptional:
    """Tests for optional TTL functionality."""

    @pytest.mark.asyncio
    async def test_ttl_none_behavior(self):
        """Test that ttl_seconds=None works like no TTL."""
        call_count = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=None)
        async def no_ttl_via_none(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x**2

        # First call
        result1 = await no_ttl_via_none(3)
        assert result1 == 9
        assert call_count == 1

        # Wait (would expire if there was TTL)
        await asyncio.sleep(0.1)

        # Second call - should still use cache
        result2 = await no_ttl_via_none(3)
        assert result2 == 9
        assert call_count == 1  # No additional call

        # Check cache info
        info = no_ttl_via_none.cache_info()
        assert info["ttl_seconds"] is None

    @pytest.mark.asyncio
    async def test_cache_options_comparison(self):
        """Test different cache options work as expected."""
        ttl_calls = 0
        no_ttl_calls = 0

        @async_ttl_cache(maxsize=10, ttl_seconds=1)  # With TTL
        async def ttl_function(x: int) -> int:
            nonlocal ttl_calls
            ttl_calls += 1
            return x * 10

        @async_cache(maxsize=10)  # Process-level cache (no TTL)
        async def process_function(x: int) -> int:
            nonlocal no_ttl_calls
            no_ttl_calls += 1
            return x * 10

        # Both should cache initially
        await ttl_function(3)
        await process_function(3)
        assert ttl_calls == 1
        assert no_ttl_calls == 1

        # Immediate second calls - both should use cache
        await ttl_function(3)
        await process_function(3)
        assert ttl_calls == 1
        assert no_ttl_calls == 1

        # Wait for TTL to expire
        await asyncio.sleep(1.1)

        # After TTL expiry
        await ttl_function(3)  # Should call function again
        await process_function(3)  # Should still use cache
        assert ttl_calls == 2  # TTL cache expired, called again
        assert no_ttl_calls == 1  # Process cache never expires
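For context, the decorators exercised above are meant to be applied directly to hot functions elsewhere in the codebase. A minimal usage sketch follows; the import path `autogpt_libs.utils.cache` is an assumption made for illustration (use whichever module actually exports `thread_cached`, `async_cache`, and `async_ttl_cache`), but the decorator names and signatures match the tests above.

```python
import asyncio

# Assumed import path for illustration only; adjust to the module that
# actually exports these decorators in autogpt_libs.
from autogpt_libs.utils.cache import async_ttl_cache, thread_cached


@thread_cached
def normalize_name(name: str) -> str:
    # Recomputed at most once per distinct argument within a thread.
    return name.strip().lower()


@async_ttl_cache(maxsize=128, ttl_seconds=300)
async def load_user_settings(user_id: str) -> dict:
    # The result is reused for up to 5 minutes; exceptions are never cached,
    # matching the behaviour verified by the tests above.
    await asyncio.sleep(0.01)  # stand-in for a real database or API call
    return {"user_id": user_id, "theme": "dark"}
```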
76 autogpt_platform/autogpt_libs/poetry.lock (generated)
@@ -1002,18 +1002,6 @@ dynamodb = ["boto3 (>=1.9.71)"]
|
||||
redis = ["redis (>=2.10.5)"]
|
||||
test-filesource = ["pyyaml (>=5.3.1)", "watchdog (>=3.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "nodeenv"
|
||||
version = "1.9.1"
|
||||
description = "Node.js virtual environment builder"
|
||||
optional = false
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9"},
|
||||
{file = "nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-api"
|
||||
version = "1.35.0"
|
||||
@@ -1359,27 +1347,6 @@ files = [
|
||||
{file = "pyrfc3339-2.0.1.tar.gz", hash = "sha256:e47843379ea35c1296c3b6c67a948a1a490ae0584edfcbdea0eaffb5dd29960b"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyright"
|
||||
version = "1.1.404"
|
||||
description = "Command line wrapper for pyright"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pyright-1.1.404-py3-none-any.whl", hash = "sha256:c7b7ff1fdb7219c643079e4c3e7d4125f0dafcc19d253b47e898d130ea426419"},
|
||||
{file = "pyright-1.1.404.tar.gz", hash = "sha256:455e881a558ca6be9ecca0b30ce08aa78343ecc031d37a198ffa9a7a1abeb63e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
nodeenv = ">=1.6.0"
|
||||
typing-extensions = ">=4.1"
|
||||
|
||||
[package.extras]
|
||||
all = ["nodejs-wheel-binaries", "twine (>=3.4.1)"]
|
||||
dev = ["twine (>=3.4.1)"]
|
||||
nodejs = ["nodejs-wheel-binaries"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest"
|
||||
version = "8.4.1"
|
||||
@@ -1567,31 +1534,31 @@ pyasn1 = ">=0.1.3"
|
||||
|
||||
[[package]]
|
||||
name = "ruff"
|
||||
version = "0.12.11"
|
||||
version = "0.12.9"
|
||||
description = "An extremely fast Python linter and code formatter, written in Rust."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "ruff-0.12.11-py3-none-linux_armv6l.whl", hash = "sha256:93fce71e1cac3a8bf9200e63a38ac5c078f3b6baebffb74ba5274fb2ab276065"},
|
||||
{file = "ruff-0.12.11-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b8e33ac7b28c772440afa80cebb972ffd823621ded90404f29e5ab6d1e2d4b93"},
|
||||
{file = "ruff-0.12.11-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d69fb9d4937aa19adb2e9f058bc4fbfe986c2040acb1a4a9747734834eaa0bfd"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:411954eca8464595077a93e580e2918d0a01a19317af0a72132283e28ae21bee"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6a2c0a2e1a450f387bf2c6237c727dd22191ae8c00e448e0672d624b2bbd7fb0"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8ca4c3a7f937725fd2413c0e884b5248a19369ab9bdd850b5781348ba283f644"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:4d1df0098124006f6a66ecf3581a7f7e754c4df7644b2e6704cd7ca80ff95211"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5a8dd5f230efc99a24ace3b77e3555d3fbc0343aeed3fc84c8d89e75ab2ff793"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4dc75533039d0ed04cd33fb8ca9ac9620b99672fe7ff1533b6402206901c34ee"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4fc58f9266d62c6eccc75261a665f26b4ef64840887fc6cbc552ce5b29f96cc8"},
|
||||
{file = "ruff-0.12.11-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:5a0113bd6eafd545146440225fe60b4e9489f59eb5f5f107acd715ba5f0b3d2f"},
|
||||
{file = "ruff-0.12.11-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:0d737b4059d66295c3ea5720e6efc152623bb83fde5444209b69cd33a53e2000"},
|
||||
{file = "ruff-0.12.11-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:916fc5defee32dbc1fc1650b576a8fed68f5e8256e2180d4d9855aea43d6aab2"},
|
||||
{file = "ruff-0.12.11-py3-none-musllinux_1_2_i686.whl", hash = "sha256:c984f07d7adb42d3ded5be894fb4007f30f82c87559438b4879fe7aa08c62b39"},
|
||||
{file = "ruff-0.12.11-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e07fbb89f2e9249f219d88331c833860489b49cdf4b032b8e4432e9b13e8a4b9"},
|
||||
{file = "ruff-0.12.11-py3-none-win32.whl", hash = "sha256:c792e8f597c9c756e9bcd4d87cf407a00b60af77078c96f7b6366ea2ce9ba9d3"},
|
||||
{file = "ruff-0.12.11-py3-none-win_amd64.whl", hash = "sha256:a3283325960307915b6deb3576b96919ee89432ebd9c48771ca12ee8afe4a0fd"},
|
||||
{file = "ruff-0.12.11-py3-none-win_arm64.whl", hash = "sha256:bae4d6e6a2676f8fb0f98b74594a048bae1b944aab17e9f5d504062303c6dbea"},
|
||||
{file = "ruff-0.12.11.tar.gz", hash = "sha256:c6b09ae8426a65bbee5425b9d0b82796dbb07cb1af045743c79bfb163001165d"},
|
||||
{file = "ruff-0.12.9-py3-none-linux_armv6l.whl", hash = "sha256:fcebc6c79fcae3f220d05585229463621f5dbf24d79fdc4936d9302e177cfa3e"},
|
||||
{file = "ruff-0.12.9-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:aed9d15f8c5755c0e74467731a007fcad41f19bcce41cd75f768bbd687f8535f"},
|
||||
{file = "ruff-0.12.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:5b15ea354c6ff0d7423814ba6d44be2807644d0c05e9ed60caca87e963e93f70"},
|
||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d596c2d0393c2502eaabfef723bd74ca35348a8dac4267d18a94910087807c53"},
|
||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1b15599931a1a7a03c388b9c5df1bfa62be7ede6eb7ef753b272381f39c3d0ff"},
|
||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3d02faa2977fb6f3f32ddb7828e212b7dd499c59eb896ae6c03ea5c303575756"},
|
||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:17d5b6b0b3a25259b69ebcba87908496e6830e03acfb929ef9fd4c58675fa2ea"},
|
||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:72db7521860e246adbb43f6ef464dd2a532ef2ef1f5dd0d470455b8d9f1773e0"},
|
||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a03242c1522b4e0885af63320ad754d53983c9599157ee33e77d748363c561ce"},
|
||||
{file = "ruff-0.12.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fc83e4e9751e6c13b5046d7162f205d0a7bac5840183c5beebf824b08a27340"},
|
||||
{file = "ruff-0.12.9-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:881465ed56ba4dd26a691954650de6ad389a2d1fdb130fe51ff18a25639fe4bb"},
|
||||
{file = "ruff-0.12.9-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:43f07a3ccfc62cdb4d3a3348bf0588358a66da756aa113e071b8ca8c3b9826af"},
|
||||
{file = "ruff-0.12.9-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:07adb221c54b6bba24387911e5734357f042e5669fa5718920ee728aba3cbadc"},
|
||||
{file = "ruff-0.12.9-py3-none-musllinux_1_2_i686.whl", hash = "sha256:f5cd34fabfdea3933ab85d72359f118035882a01bff15bd1d2b15261d85d5f66"},
|
||||
{file = "ruff-0.12.9-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:f6be1d2ca0686c54564da8e7ee9e25f93bdd6868263805f8c0b8fc6a449db6d7"},
|
||||
{file = "ruff-0.12.9-py3-none-win32.whl", hash = "sha256:cc7a37bd2509974379d0115cc5608a1a4a6c4bff1b452ea69db83c8855d53f93"},
|
||||
{file = "ruff-0.12.9-py3-none-win_amd64.whl", hash = "sha256:6fb15b1977309741d7d098c8a3cb7a30bc112760a00fb6efb7abc85f00ba5908"},
|
||||
{file = "ruff-0.12.9-py3-none-win_arm64.whl", hash = "sha256:63c8c819739d86b96d500cce885956a1a48ab056bbcbc61b747ad494b2485089"},
|
||||
{file = "ruff-0.12.9.tar.gz", hash = "sha256:fbd94b2e3c623f659962934e52c2bea6fc6da11f667a427a368adaf3af2c866a"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1773,6 +1740,7 @@ files = [
|
||||
{file = "typing_extensions-4.14.1-py3-none-any.whl", hash = "sha256:d1e1e3b58374dc93031d6eda2420a48ea44a36c2b4766a4fdeb3710755731d76"},
|
||||
{file = "typing_extensions-4.14.1.tar.gz", hash = "sha256:38b39f4aeeab64884ce9f74c94263ef78f3c22467c8724005483154c26648d36"},
|
||||
]
|
||||
markers = {dev = "python_version < \"3.11\""}
|
||||
|
||||
[[package]]
|
||||
name = "typing-inspection"
|
||||
@@ -1929,4 +1897,4 @@ type = ["pytest-mypy"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<4.0"
|
||||
content-hash = "0c40b63c3c921846cf05ccfb4e685d4959854b29c2c302245f9832e20aac6954"
|
||||
content-hash = "ef7818fba061cea2841c6d7ca4852acde83e4f73b32fca1315e58660002bb0d0"
|
||||
|
||||
@@ -9,7 +9,6 @@ packages = [{ include = "autogpt_libs" }]
[tool.poetry.dependencies]
python = ">=3.10,<4.0"
colorama = "^0.4.6"
cryptography = "^45.0"
expiringdict = "^1.2.2"
fastapi = "^0.116.1"
google-cloud-logging = "^3.12.1"
@@ -22,12 +21,11 @@ supabase = "^2.16.0"
uvicorn = "^0.35.0"

[tool.poetry.group.dev.dependencies]
pyright = "^1.1.404"
ruff = "^0.12.9"
pytest = "^8.4.1"
pytest-asyncio = "^1.1.0"
pytest-mock = "^3.14.1"
pytest-cov = "^6.2.1"
ruff = "^0.12.11"

[build-system]
requires = ["poetry-core"]
@@ -21,7 +21,7 @@ PRISMA_SCHEMA="postgres/schema.prisma"
# Redis Configuration
REDIS_HOST=localhost
REDIS_PORT=6379
# REDIS_PASSWORD=
REDIS_PASSWORD=password

# RabbitMQ Credentials
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
@@ -66,11 +66,6 @@ NVIDIA_API_KEY=
GITHUB_CLIENT_ID=
GITHUB_CLIENT_SECRET=

# Notion OAuth App server credentials - https://developers.notion.com/docs/authorization
# Configure a public integration
NOTION_CLIENT_ID=
NOTION_CLIENT_SECRET=

# Google OAuth App server credentials - https://console.cloud.google.com/apis/credentials, and enable gmail api and set scopes
# https://console.cloud.google.com/apis/credentials/consent ?project=<your_project_id>
# You'll need to add/enable the following scopes (minimum):
10 autogpt_platform/backend/.gitignore (vendored)
@@ -9,12 +9,4 @@ secrets/*
!secrets/.gitkeep

*.ignore.*
*.ign.*

# Load test results and reports
load-tests/*_RESULTS.md
load-tests/*_REPORT.md
load-tests/results/
load-tests/*.json
load-tests/*.log
load-tests/node_modules/*
*.ign.*
@@ -9,15 +9,8 @@ WORKDIR /app

RUN echo 'Acquire::http::Pipeline-Depth 0;\nAcquire::http::No-Cache true;\nAcquire::BrokenProxy true;\n' > /etc/apt/apt.conf.d/99fixbadproxy

# Install Node.js repository key and setup
# Update package list and install Python and build dependencies
RUN apt-get update --allow-releaseinfo-change --fix-missing \
    && apt-get install -y curl ca-certificates gnupg \
    && mkdir -p /etc/apt/keyrings \
    && curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg \
    && echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_20.x nodistro main" | tee /etc/apt/sources.list.d/nodesource.list

# Update package list and install Python, Node.js, and build dependencies
RUN apt-get update \
    && apt-get install -y \
    python3.13 \
    python3.13-dev \
@@ -27,9 +20,7 @@ RUN apt-get update \
    libpq5 \
    libz-dev \
    libssl-dev \
    postgresql-client \
    nodejs \
    && rm -rf /var/lib/apt/lists/*
    postgresql-client

ENV POETRY_HOME=/opt/poetry
ENV POETRY_NO_INTERACTION=1
@@ -47,7 +38,6 @@ RUN poetry install --no-ansi --no-root

# Generate Prisma client
COPY autogpt_platform/backend/schema.prisma ./
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
RUN poetry run prisma generate

FROM debian:13-slim AS server_dependencies
@@ -64,18 +54,13 @@ ENV PATH=/opt/poetry/bin:$PATH
# Install Python without upgrading system-managed packages
RUN apt-get update && apt-get install -y \
    python3.13 \
    python3-pip \
    && rm -rf /var/lib/apt/lists/*
    python3-pip

# Copy only necessary files from builder
COPY --from=builder /app /app
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
# Copy Node.js installation for Prisma
COPY --from=builder /usr/bin/node /usr/bin/node
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
COPY --from=builder /usr/bin/npm /usr/bin/npm
COPY --from=builder /usr/bin/npx /usr/bin/npx
# Copy Prisma binaries
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries

ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
@@ -93,7 +78,6 @@ FROM server_dependencies AS migrate

# Migration stage only needs schema and migrations - much lighter than full backend
COPY autogpt_platform/backend/schema.prisma /app/autogpt_platform/backend/
COPY autogpt_platform/backend/backend/data/partial_types.py /app/autogpt_platform/backend/backend/data/partial_types.py
COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migrations

FROM server_dependencies AS server
@@ -1,3 +1,4 @@
import functools
import importlib
import logging
import os
@@ -5,8 +6,6 @@ import re
from pathlib import Path
from typing import TYPE_CHECKING, TypeVar

from backend.util.cache import cached

logger = logging.getLogger(__name__)


@@ -16,7 +15,7 @@ if TYPE_CHECKING:
T = TypeVar("T")


@cached(ttl_seconds=3600)
@functools.cache
def load_all_blocks() -> dict[str, type["Block"]]:
    from backend.data.block import Block
    from backend.util.settings import Config
@@ -1,17 +1,18 @@
import logging
from typing import Any, Optional

from pydantic import JsonValue

from backend.data.block import (
    Block,
    BlockCategory,
    BlockInput,
    BlockOutput,
    BlockSchema,
    BlockSchemaInput,
    BlockType,
    get_block,
)
from backend.data.execution import ExecutionStatus, NodesInputMasks
from backend.data.execution import ExecutionStatus
from backend.data.model import NodeExecutionStats, SchemaField
from backend.util.json import validate_with_jsonschema
from backend.util.retry import func_retry
@@ -20,7 +21,7 @@ _logger = logging.getLogger(__name__)


class AgentExecutorBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        user_id: str = SchemaField(description="User ID")
        graph_id: str = SchemaField(description="Graph ID")
        graph_version: int = SchemaField(description="Graph Version")
@@ -32,7 +33,7 @@ class AgentExecutorBlock(Block):
        input_schema: dict = SchemaField(description="Input schema for the graph")
        output_schema: dict = SchemaField(description="Output schema for the graph")

        nodes_input_masks: Optional[NodesInputMasks] = SchemaField(
        nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = SchemaField(
            default=None, hidden=True
        )

@@ -54,7 +55,6 @@ class AgentExecutorBlock(Block):
        return validate_with_jsonschema(cls.get_input_schema(data), data)

    class Output(BlockSchema):
        # Use BlockSchema to avoid automatic error field that could clash with graph outputs
        pass

    def __init__(self):
@@ -67,13 +67,7 @@ class AgentExecutorBlock(Block):
            categories={BlockCategory.AGENT},
        )

    async def run(
        self,
        input_data: Input,
        *,
        graph_exec_id: str,
        **kwargs,
    ) -> BlockOutput:
    async def run(self, input_data: Input, **kwargs) -> BlockOutput:

        from backend.executor import utils as execution_utils

@@ -83,8 +77,6 @@ class AgentExecutorBlock(Block):
            user_id=input_data.user_id,
            inputs=input_data.inputs,
            nodes_input_masks=input_data.nodes_input_masks,
            parent_graph_exec_id=graph_exec_id,
            is_sub_graph=True,  # AgentExecutorBlock executions are always sub-graphs
        )

        logger = execution_utils.LogMetadata(
@@ -1,219 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from backend.blocks.llm import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
AIBlockBase,
|
||||
AICredentials,
|
||||
AICredentialsField,
|
||||
LlmModel,
|
||||
LLMResponse,
|
||||
llm_call,
|
||||
)
|
||||
from backend.data.block import (
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField
|
||||
|
||||
|
||||
class AIConditionBlock(AIBlockBase):
|
||||
"""
|
||||
An AI-powered condition block that uses natural language to evaluate conditions.
|
||||
|
||||
This block allows users to define conditions in plain English (e.g., "the input is an email address",
|
||||
"the input is a city in the USA") and uses AI to determine if the input satisfies the condition.
|
||||
It provides the same yes/no data pass-through functionality as the standard ConditionBlock.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
input_value: Any = SchemaField(
|
||||
description="The input value to evaluate with the AI condition",
|
||||
placeholder="Enter the value to be evaluated (text, number, or any data)",
|
||||
)
|
||||
condition: str = SchemaField(
|
||||
description="A plaintext English description of the condition to evaluate",
|
||||
placeholder="E.g., 'the input is the body of an email', 'the input is a City in the USA', 'the input is an error or a refusal'",
|
||||
)
|
||||
yes_value: Any = SchemaField(
|
||||
description="(Optional) Value to output if the condition is true. If not provided, input_value will be used.",
|
||||
placeholder="Leave empty to use input_value, or enter a specific value",
|
||||
default=None,
|
||||
)
|
||||
no_value: Any = SchemaField(
|
||||
description="(Optional) Value to output if the condition is false. If not provided, input_value will be used.",
|
||||
placeholder="Leave empty to use input_value, or enter a specific value",
|
||||
default=None,
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for evaluating the condition.",
|
||||
advanced=False,
|
||||
)
|
||||
credentials: AICredentials = AICredentialsField()
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
result: bool = SchemaField(
|
||||
description="The result of the AI condition evaluation (True or False)"
|
||||
)
|
||||
yes_output: Any = SchemaField(
|
||||
description="The output value if the condition is true"
|
||||
)
|
||||
no_output: Any = SchemaField(
|
||||
description="The output value if the condition is false"
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the AI evaluation is uncertain or fails"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="553ec5b8-6c45-4299-8d75-b394d05f72ff",
|
||||
input_schema=AIConditionBlock.Input,
|
||||
output_schema=AIConditionBlock.Output,
|
||||
description="Uses AI to evaluate natural language conditions and provide conditional outputs",
|
||||
categories={BlockCategory.AI, BlockCategory.LOGIC},
|
||||
test_input={
|
||||
"input_value": "john@example.com",
|
||||
"condition": "the input is an email address",
|
||||
"yes_value": "Valid email",
|
||||
"no_value": "Not an email",
|
||||
"model": LlmModel.GPT4O,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("result", True),
|
||||
("yes_output", "Valid email"),
|
||||
],
|
||||
test_mock={
|
||||
"llm_call": lambda *args, **kwargs: LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[],
|
||||
response="true",
|
||||
tool_calls=None,
|
||||
prompt_tokens=50,
|
||||
completion_tokens=10,
|
||||
reasoning=None,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
async def llm_call(
|
||||
self,
|
||||
credentials: APIKeyCredentials,
|
||||
llm_model: LlmModel,
|
||||
prompt: list,
|
||||
max_tokens: int,
|
||||
) -> LLMResponse:
|
||||
"""Wrapper method for llm_call to enable mocking in tests."""
|
||||
return await llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=llm_model,
|
||||
prompt=prompt,
|
||||
force_json_output=False,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
"""
|
||||
Evaluate the AI condition and return appropriate outputs.
|
||||
"""
|
||||
# Prepare the yes and no values, using input_value as default
|
||||
yes_value = (
|
||||
input_data.yes_value
|
||||
if input_data.yes_value is not None
|
||||
else input_data.input_value
|
||||
)
|
||||
no_value = (
|
||||
input_data.no_value
|
||||
if input_data.no_value is not None
|
||||
else input_data.input_value
|
||||
)
|
||||
|
||||
# Convert input_value to string for AI evaluation
|
||||
input_str = str(input_data.input_value)
|
||||
|
||||
# Create the prompt for AI evaluation
|
||||
prompt = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You are an AI assistant that evaluates conditions based on input data. "
|
||||
"You must respond with only 'true' or 'false' (lowercase) to indicate whether "
|
||||
"the given condition is met by the input value. Be accurate and consider the "
|
||||
"context and meaning of both the input and the condition."
|
||||
),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
f"Input value: {input_str}\n"
|
||||
f"Condition to evaluate: {input_data.condition}\n\n"
|
||||
f"Does the input value satisfy the condition? Respond with only 'true' or 'false'."
|
||||
),
|
||||
},
|
||||
]
|
||||
|
||||
# Call the LLM
|
||||
try:
|
||||
response = await self.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=input_data.model,
|
||||
prompt=prompt,
|
||||
max_tokens=10, # We only expect a true/false response
|
||||
)
|
||||
|
||||
# Extract the boolean result from the response
|
||||
response_text = response.response.strip().lower()
|
||||
if response_text == "true":
|
||||
result = True
|
||||
elif response_text == "false":
|
||||
result = False
|
||||
else:
|
||||
# If the response is not clear, try to interpret it using word boundaries
|
||||
import re
|
||||
|
||||
# Use word boundaries to avoid false positives like 'untrue' or '10'
|
||||
tokens = set(re.findall(r"\b(true|false|yes|no|1|0)\b", response_text))
|
||||
|
||||
if tokens == {"true"} or tokens == {"yes"} or tokens == {"1"}:
|
||||
result = True
|
||||
elif tokens == {"false"} or tokens == {"no"} or tokens == {"0"}:
|
||||
result = False
|
||||
else:
|
||||
# Unclear or conflicting response - default to False and yield error
|
||||
result = False
|
||||
yield "error", f"Unclear AI response: '{response.response}'"
|
||||
|
||||
# Update internal stats
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
input_token_count=response.prompt_tokens,
|
||||
output_token_count=response.completion_tokens,
|
||||
)
|
||||
)
|
||||
self.prompt = response.prompt
|
||||
|
||||
except Exception as e:
|
||||
# In case of any error, default to False to be safe
|
||||
result = False
|
||||
# Log the error but don't fail the block execution
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.error(f"AI condition evaluation failed: {str(e)}")
|
||||
yield "error", f"AI evaluation failed: {str(e)}"
|
||||
|
||||
# Yield results
|
||||
yield "result", result
|
||||
|
||||
if result:
|
||||
yield "yes_output", yes_value
|
||||
else:
|
||||
yield "no_output", no_value
|
||||
@@ -1,159 +0,0 @@
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
from replicate.client import Client as ReplicateClient
|
||||
from replicate.helpers import FileOutput
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.file import MediaFileType
|
||||
|
||||
|
||||
class GeminiImageModel(str, Enum):
|
||||
NANO_BANANA = "google/nano-banana"
|
||||
|
||||
|
||||
class OutputFormat(str, Enum):
|
||||
JPG = "jpg"
|
||||
PNG = "png"
|
||||
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="replicate",
|
||||
api_key=SecretStr("mock-replicate-api-key"),
|
||||
title="Mock Replicate API key",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
class AIImageCustomizerBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REPLICATE], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description="Replicate API key with permissions for Google Gemini image models",
|
||||
)
|
||||
prompt: str = SchemaField(
|
||||
description="A text description of the image you want to generate",
|
||||
title="Prompt",
|
||||
)
|
||||
model: GeminiImageModel = SchemaField(
|
||||
description="The AI model to use for image generation and editing",
|
||||
default=GeminiImageModel.NANO_BANANA,
|
||||
title="Model",
|
||||
)
|
||||
images: list[MediaFileType] = SchemaField(
|
||||
description="Optional list of input images to reference or modify",
|
||||
default=[],
|
||||
title="Input Images",
|
||||
)
|
||||
output_format: OutputFormat = SchemaField(
|
||||
description="Format of the output image",
|
||||
default=OutputFormat.PNG,
|
||||
title="Output Format",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
image_url: MediaFileType = SchemaField(description="URL of the generated image")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d76bbe4c-930e-4894-8469-b66775511f71",
|
||||
description=(
|
||||
"Generate and edit custom images using Google's Nano-Banana model from Gemini 2.5. "
|
||||
"Provide a prompt and optional reference images to create or modify images."
|
||||
),
|
||||
categories={BlockCategory.AI, BlockCategory.MULTIMEDIA},
|
||||
input_schema=AIImageCustomizerBlock.Input,
|
||||
output_schema=AIImageCustomizerBlock.Output,
|
||||
test_input={
|
||||
"prompt": "Make the scene more vibrant and colorful",
|
||||
"model": GeminiImageModel.NANO_BANANA,
|
||||
"images": [],
|
||||
"output_format": OutputFormat.JPG,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
("image_url", "https://replicate.delivery/generated-image.jpg"),
|
||||
],
|
||||
test_mock={
|
||||
"run_model": lambda *args, **kwargs: MediaFileType(
|
||||
"https://replicate.delivery/generated-image.jpg"
|
||||
),
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
result = await self.run_model(
|
||||
api_key=credentials.api_key,
|
||||
model_name=input_data.model.value,
|
||||
prompt=input_data.prompt,
|
||||
images=input_data.images,
|
||||
output_format=input_data.output_format.value,
|
||||
)
|
||||
yield "image_url", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
async def run_model(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
images: list[MediaFileType],
|
||||
output_format: str,
|
||||
) -> MediaFileType:
|
||||
client = ReplicateClient(api_token=api_key.get_secret_value())
|
||||
|
||||
input_params: dict = {
|
||||
"prompt": prompt,
|
||||
"output_format": output_format,
|
||||
}
|
||||
|
||||
# Add images to input if provided (API expects "image_input" parameter)
|
||||
if images:
|
||||
input_params["image_input"] = [str(img) for img in images]
|
||||
|
||||
output: FileOutput | str = await client.async_run( # type: ignore
|
||||
model_name,
|
||||
input=input_params,
|
||||
wait=False,
|
||||
)
|
||||
|
||||
if isinstance(output, FileOutput):
|
||||
return MediaFileType(output.url)
|
||||
if isinstance(output, str):
|
||||
return MediaFileType(output)
|
||||
|
||||
raise ValueError("No output received from the model")
|
||||
@@ -5,7 +5,7 @@ from pydantic import SecretStr
from replicate.client import Client as ReplicateClient
from replicate.helpers import FileOutput

from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput
from backend.data.block import Block, BlockCategory, BlockSchema
from backend.data.model import (
    APIKeyCredentials,
    CredentialsField,
@@ -101,7 +101,7 @@ class ImageGenModel(str, Enum):


class AIImageGeneratorBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput[
            Literal[ProviderName.REPLICATE], Literal["api_key"]
        ] = CredentialsField(
@@ -135,8 +135,9 @@ class AIImageGeneratorBlock(Block):
            title="Image Style",
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        image_url: str = SchemaField(description="URL of the generated image")
        error: str = SchemaField(description="Error message if generation failed")

    def __init__(self):
        super().__init__(
@@ -6,13 +6,7 @@ from typing import Literal

from pydantic import SecretStr

from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import (
    APIKeyCredentials,
    CredentialsField,
@@ -60,7 +54,7 @@ class NormalizationStrategy(str, Enum):


class AIMusicGeneratorBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput[
            Literal[ProviderName.REPLICATE], Literal["api_key"]
        ] = CredentialsField(
@@ -113,8 +107,9 @@ class AIMusicGeneratorBlock(Block):
            title="Normalization Strategy",
        )

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        result: str = SchemaField(description="URL of the generated audio file")
        error: str = SchemaField(description="Error message if the model run failed")

    def __init__(self):
        super().__init__(
@@ -6,13 +6,7 @@ from typing import Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
@@ -154,7 +148,7 @@ logger = logging.getLogger(__name__)
|
||||
class AIShortformVideoCreatorBlock(Block):
|
||||
"""Creates a short‑form text‑to‑video clip using stock or AI imagery."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REVID], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
@@ -193,8 +187,9 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
placeholder=VisualMediaType.STOCK_VIDEOS,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
video_url: str = SchemaField(description="The URL of the created video")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
async def create_webhook(self) -> tuple[str, str]:
|
||||
"""Create a new webhook URL for receiving notifications."""
|
||||
@@ -341,7 +336,7 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
class AIAdMakerVideoCreatorBlock(Block):
|
||||
"""Generates a 30‑second vertical AI advert using optional user‑supplied imagery."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REVID], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
@@ -369,8 +364,9 @@ class AIAdMakerVideoCreatorBlock(Block):
|
||||
description="Restrict visuals to supplied images only.", default=True
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
video_url: str = SchemaField(description="URL of the finished advert")
|
||||
error: str = SchemaField(description="Error message on failure")
|
||||
|
||||
async def create_webhook(self) -> tuple[str, str]:
|
||||
"""Create a new webhook URL for receiving notifications."""
|
||||
@@ -528,7 +524,7 @@ class AIAdMakerVideoCreatorBlock(Block):
|
||||
class AIScreenshotToVideoAdBlock(Block):
|
||||
"""Creates an advert where the supplied screenshot is narrated by an AI avatar."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REVID], Literal["api_key"]
|
||||
] = CredentialsField(description="Revid.ai API key")
|
||||
@@ -546,8 +542,9 @@ class AIScreenshotToVideoAdBlock(Block):
|
||||
default=AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
video_url: str = SchemaField(description="Rendered video URL")
|
||||
error: str = SchemaField(description="Error, if encountered")
|
||||
|
||||
async def create_webhook(self) -> tuple[str, str]:
|
||||
"""Create a new webhook URL for receiving notifications."""
|
||||
|
||||
@@ -661,167 +661,6 @@ async def update_field(
|
||||
#################################################################
|
||||
|
||||
|
||||
async def get_table_schema(
|
||||
credentials: Credentials,
|
||||
base_id: str,
|
||||
table_id_or_name: str,
|
||||
) -> dict:
|
||||
"""
|
||||
Get the schema for a specific table, including all field definitions.
|
||||
|
||||
Args:
|
||||
credentials: Airtable API credentials
|
||||
base_id: The base ID
|
||||
table_id_or_name: The table ID or name
|
||||
|
||||
Returns:
|
||||
Dict containing table schema with fields information
|
||||
"""
|
||||
# First get all tables to find the right one
|
||||
response = await Requests().get(
|
||||
f"https://api.airtable.com/v0/meta/bases/{base_id}/tables",
|
||||
headers={"Authorization": credentials.auth_header()},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
tables = data.get("tables", [])
|
||||
|
||||
# Find the matching table
|
||||
for table in tables:
|
||||
if table.get("id") == table_id_or_name or table.get("name") == table_id_or_name:
|
||||
return table
|
||||
|
||||
raise ValueError(f"Table '{table_id_or_name}' not found in base '{base_id}'")
|
||||
|
||||
|
||||
def get_empty_value_for_field(field_type: str) -> Any:
|
||||
"""
|
||||
Return the appropriate empty value for a given Airtable field type.
|
||||
|
||||
Args:
|
||||
field_type: The Airtable field type
|
||||
|
||||
Returns:
|
||||
The appropriate empty value for that field type
|
||||
"""
|
||||
# Fields that should be false when empty
|
||||
if field_type == "checkbox":
|
||||
return False
|
||||
|
||||
# Fields that should be empty arrays
|
||||
if field_type in [
|
||||
"multipleSelects",
|
||||
"multipleRecordLinks",
|
||||
"multipleAttachments",
|
||||
"multipleLookupValues",
|
||||
"multipleCollaborators",
|
||||
]:
|
||||
return []
|
||||
|
||||
# Fields that should be 0 when empty (numeric types)
|
||||
if field_type in [
|
||||
"number",
|
||||
"percent",
|
||||
"currency",
|
||||
"rating",
|
||||
"duration",
|
||||
"count",
|
||||
"autoNumber",
|
||||
]:
|
||||
return 0
|
||||
|
||||
# Fields that should be empty strings
|
||||
if field_type in [
|
||||
"singleLineText",
|
||||
"multilineText",
|
||||
"email",
|
||||
"url",
|
||||
"phoneNumber",
|
||||
"richText",
|
||||
"barcode",
|
||||
]:
|
||||
return ""
|
||||
|
||||
# Everything else gets null (dates, single selects, formulas, etc.)
|
||||
return None
|
||||
|
||||
|
||||
async def normalize_records(
|
||||
records: list[dict],
|
||||
table_schema: dict,
|
||||
include_field_metadata: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Normalize Airtable records to include all fields with proper empty values.
|
||||
|
||||
Args:
|
||||
records: List of record objects from Airtable API
|
||||
table_schema: Table schema containing field definitions
|
||||
include_field_metadata: Whether to include field metadata in response
|
||||
|
||||
Returns:
|
||||
Dict with normalized records and optionally field metadata
|
||||
"""
|
||||
fields = table_schema.get("fields", [])
|
||||
|
||||
# Normalize each record
|
||||
normalized_records = []
|
||||
for record in records:
|
||||
normalized = {
|
||||
"id": record.get("id"),
|
||||
"createdTime": record.get("createdTime"),
|
||||
"fields": {},
|
||||
}
|
||||
|
||||
# Add existing fields
|
||||
existing_fields = record.get("fields", {})
|
||||
|
||||
# Add all fields from schema, using empty values for missing ones
|
||||
for field in fields:
|
||||
field_name = field["name"]
|
||||
field_type = field["type"]
|
||||
|
||||
if field_name in existing_fields:
|
||||
# Field exists, use its value
|
||||
normalized["fields"][field_name] = existing_fields[field_name]
|
||||
else:
|
||||
# Field is missing, add appropriate empty value
|
||||
normalized["fields"][field_name] = get_empty_value_for_field(field_type)
|
||||
|
||||
normalized_records.append(normalized)
|
||||
|
||||
# Build result dictionary
|
||||
if include_field_metadata:
|
||||
field_metadata = {}
|
||||
for field in fields:
|
||||
metadata = {"type": field["type"], "id": field["id"]}
|
||||
|
||||
# Add type-specific metadata
|
||||
options = field.get("options", {})
|
||||
if field["type"] == "currency" and "symbol" in options:
|
||||
metadata["symbol"] = options["symbol"]
|
||||
metadata["precision"] = options.get("precision", 2)
|
||||
elif field["type"] == "duration" and "durationFormat" in options:
|
||||
metadata["format"] = options["durationFormat"]
|
||||
elif field["type"] == "percent" and "precision" in options:
|
||||
metadata["precision"] = options["precision"]
|
||||
elif (
|
||||
field["type"] in ["singleSelect", "multipleSelects"]
|
||||
and "choices" in options
|
||||
):
|
||||
metadata["choices"] = [choice["name"] for choice in options["choices"]]
|
||||
elif field["type"] == "rating" and "max" in options:
|
||||
metadata["max"] = options["max"]
|
||||
metadata["icon"] = options.get("icon", "star")
|
||||
metadata["color"] = options.get("color", "yellowBright")
|
||||
|
||||
field_metadata[field["name"]] = metadata
|
||||
|
||||
return {"records": normalized_records, "field_metadata": field_metadata}
|
||||
else:
|
||||
return {"records": normalized_records}
|
||||
|
||||
|
||||
async def list_records(
|
||||
credentials: Credentials,
|
||||
base_id: str,
|
||||
@@ -1410,26 +1249,3 @@ async def list_bases(
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
|
||||
async def get_base_tables(
|
||||
credentials: Credentials,
|
||||
base_id: str,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Get all tables for a specific base.
|
||||
|
||||
Args:
|
||||
credentials: Airtable API credentials
|
||||
base_id: The ID of the base
|
||||
|
||||
Returns:
|
||||
list[dict]: List of table objects with their schemas
|
||||
"""
|
||||
response = await Requests().get(
|
||||
f"https://api.airtable.com/v0/meta/bases/{base_id}/tables",
|
||||
headers={"Authorization": credentials.auth_header()},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
return data.get("tables", [])
|
||||
|
||||
@@ -9,22 +9,21 @@ from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._api import create_base, get_base_tables, list_bases
|
||||
from ._api import create_base, list_bases
|
||||
from ._config import airtable
|
||||
|
||||
|
||||
class AirtableCreateBaseBlock(Block):
|
||||
"""
|
||||
Creates a new base in an Airtable workspace, or returns existing base if one with the same name exists.
|
||||
Creates a new base in an Airtable workspace.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
@@ -32,10 +31,6 @@ class AirtableCreateBaseBlock(Block):
|
||||
description="The workspace ID where the base will be created"
|
||||
)
|
||||
name: str = SchemaField(description="The name of the new base")
|
||||
find_existing: bool = SchemaField(
|
||||
description="If true, return existing base with same name instead of creating duplicate",
|
||||
default=True,
|
||||
)
|
||||
tables: list[dict] = SchemaField(
|
||||
description="At least one table and field must be specified. Array of table objects to create in the base. Each table should have 'name' and 'fields' properties",
|
||||
default=[
|
||||
@@ -54,19 +49,15 @@ class AirtableCreateBaseBlock(Block):
|
||||
],
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
base_id: str = SchemaField(description="The ID of the created or found base")
|
||||
class Output(BlockSchema):
|
||||
base_id: str = SchemaField(description="The ID of the created base")
|
||||
tables: list[dict] = SchemaField(description="Array of table objects")
|
||||
table: dict = SchemaField(description="A single table object")
|
||||
was_created: bool = SchemaField(
|
||||
description="True if a new base was created, False if existing was found",
|
||||
default=True,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f59b88a8-54ce-4676-a508-fd614b4e8dce",
|
||||
description="Create or find a base in Airtable",
|
||||
description="Create a new base in Airtable",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
@@ -75,31 +66,6 @@ class AirtableCreateBaseBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# If find_existing is true, check if a base with this name already exists
|
||||
if input_data.find_existing:
|
||||
# List all bases to check for existing one with same name
|
||||
# Note: Airtable API doesn't have a direct search, so we need to list and filter
|
||||
existing_bases = await list_bases(credentials)
|
||||
|
||||
for base in existing_bases.get("bases", []):
|
||||
if base.get("name") == input_data.name:
|
||||
# Base already exists, return it
|
||||
base_id = base.get("id")
|
||||
yield "base_id", base_id
|
||||
yield "was_created", False
|
||||
|
||||
# Get the tables for this base
|
||||
try:
|
||||
tables = await get_base_tables(credentials, base_id)
|
||||
yield "tables", tables
|
||||
for table in tables:
|
||||
yield "table", table
|
||||
except Exception:
|
||||
# If we can't get tables, return empty list
|
||||
yield "tables", []
|
||||
return
|
||||
|
||||
# No existing base found or find_existing is false, create new one
|
||||
data = await create_base(
|
||||
credentials,
|
||||
input_data.workspace_id,
|
||||
@@ -108,7 +74,6 @@ class AirtableCreateBaseBlock(Block):
|
||||
)
|
||||
|
||||
yield "base_id", data.get("id", None)
|
||||
yield "was_created", True
|
||||
yield "tables", data.get("tables", [])
|
||||
for table in data.get("tables", []):
|
||||
yield "table", table
|
||||
@@ -119,7 +84,7 @@ class AirtableListBasesBlock(Block):
|
||||
Lists all bases in an Airtable workspace that the user has access to.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
@@ -130,7 +95,7 @@ class AirtableListBasesBlock(Block):
|
||||
description="Pagination offset from previous request", default=""
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
bases: list[dict] = SchemaField(description="Array of base objects")
|
||||
offset: Optional[str] = SchemaField(
|
||||
description="Offset for next page (null if no more bases)", default=None
|
||||
|
||||
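The `offset` input and output above follow Airtable's cursor-style pagination: each response may carry an opaque `offset` token that has to be echoed back to fetch the next page. A rough sketch of how a caller could drain all pages, assuming a generic async `fetch_page(offset)` callable (the name is illustrative, not from the codebase):

```python
from typing import Awaitable, Callable, Optional


async def collect_all_bases(
    fetch_page: Callable[[Optional[str]], Awaitable[dict]],
) -> list[dict]:
    """Keep requesting pages until the response no longer includes an offset."""
    bases: list[dict] = []
    offset: Optional[str] = None
    while True:
        page = await fetch_page(offset)  # e.g. wraps the ListBases call
        bases.extend(page.get("bases", []))
        offset = page.get("offset")      # opaque cursor for the next page
        if not offset:                   # absent/None means the last page was reached
            return bases
```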
@@ -2,15 +2,14 @@
|
||||
Airtable record operation blocks.
|
||||
"""
|
||||
|
||||
from typing import Optional, cast
|
||||
from typing import Optional
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
@@ -19,9 +18,7 @@ from ._api import (
|
||||
create_record,
|
||||
delete_multiple_records,
|
||||
get_record,
|
||||
get_table_schema,
|
||||
list_records,
|
||||
normalize_records,
|
||||
update_multiple_records,
|
||||
)
|
||||
from ._config import airtable
|
||||
@@ -32,7 +29,7 @@ class AirtableListRecordsBlock(Block):
|
||||
Lists records from an Airtable table with optional filtering, sorting, and pagination.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
@@ -57,24 +54,12 @@ class AirtableListRecordsBlock(Block):
|
||||
return_fields: list[str] = SchemaField(
|
||||
description="Specific fields to return (comma-separated)", default=[]
|
||||
)
|
||||
normalize_output: bool = SchemaField(
|
||||
description="Normalize output to include all fields with proper empty values (disable to skip schema fetch and get raw Airtable response)",
|
||||
default=True,
|
||||
)
|
||||
include_field_metadata: bool = SchemaField(
|
||||
description="Include field type and configuration metadata (requires normalize_output=true)",
|
||||
default=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
records: list[dict] = SchemaField(description="Array of record objects")
|
||||
offset: Optional[str] = SchemaField(
|
||||
description="Offset for next page (null if no more records)", default=None
|
||||
)
|
||||
field_metadata: Optional[dict] = SchemaField(
|
||||
description="Field type and configuration metadata (only when include_field_metadata=true)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -88,7 +73,6 @@ class AirtableListRecordsBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
data = await list_records(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
@@ -104,33 +88,8 @@ class AirtableListRecordsBlock(Block):
|
||||
fields=input_data.return_fields if input_data.return_fields else None,
|
||||
)
|
||||
|
||||
records = data.get("records", [])
|
||||
|
||||
# Normalize output if requested
|
||||
if input_data.normalize_output:
|
||||
# Fetch table schema
|
||||
table_schema = await get_table_schema(
|
||||
credentials, input_data.base_id, input_data.table_id_or_name
|
||||
)
|
||||
|
||||
# Normalize the records
|
||||
normalized_data = await normalize_records(
|
||||
records,
|
||||
table_schema,
|
||||
include_field_metadata=input_data.include_field_metadata,
|
||||
)
|
||||
|
||||
yield "records", normalized_data["records"]
|
||||
yield "offset", data.get("offset", None)
|
||||
|
||||
if (
|
||||
input_data.include_field_metadata
|
||||
and "field_metadata" in normalized_data
|
||||
):
|
||||
yield "field_metadata", normalized_data["field_metadata"]
|
||||
else:
|
||||
yield "records", records
|
||||
yield "offset", data.get("offset", None)
|
||||
yield "records", data.get("records", [])
|
||||
yield "offset", data.get("offset", None)
|
||||
|
||||
|
||||
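The `normalize_output` path above fetches the table schema and runs the raw records through `normalize_records`, so every record exposes the same set of field keys even though Airtable omits empty cells from its responses. The real helper lives in `_api.py` and is not part of this diff; the sketch below only illustrates the idea, with an assumed schema shape.

```python
def normalize_record_fields(records: list[dict], table_schema: dict) -> list[dict]:
    """Fill in missing fields with None so every record has a consistent shape.

    `table_schema` is assumed to look like {"fields": [{"name": ...}, ...]};
    the project's actual schema handling may differ.
    """
    field_names = [f["name"] for f in table_schema.get("fields", [])]
    normalized = []
    for record in records:
        fields = record.get("fields", {})
        normalized.append(
            {
                **record,
                "fields": {name: fields.get(name) for name in field_names},
            }
        )
    return normalized
```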
class AirtableGetRecordBlock(Block):
|
||||
@@ -138,30 +97,18 @@ class AirtableGetRecordBlock(Block):
|
||||
Retrieves a single record from an Airtable table by its ID.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
table_id_or_name: str = SchemaField(description="Table ID or name")
|
||||
record_id: str = SchemaField(description="The record ID to retrieve")
|
||||
normalize_output: bool = SchemaField(
|
||||
description="Normalize output to include all fields with proper empty values (disable to skip schema fetch and get raw Airtable response)",
|
||||
default=True,
|
||||
)
|
||||
include_field_metadata: bool = SchemaField(
|
||||
description="Include field type and configuration metadata (requires normalize_output=true)",
|
||||
default=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
id: str = SchemaField(description="The record ID")
|
||||
fields: dict = SchemaField(description="The record fields")
|
||||
created_time: str = SchemaField(description="The record created time")
|
||||
field_metadata: Optional[dict] = SchemaField(
|
||||
description="Field type and configuration metadata (only when include_field_metadata=true)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -175,7 +122,6 @@ class AirtableGetRecordBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
record = await get_record(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
@@ -183,34 +129,9 @@ class AirtableGetRecordBlock(Block):
|
||||
input_data.record_id,
|
||||
)
|
||||
|
||||
# Normalize output if requested
|
||||
if input_data.normalize_output:
|
||||
# Fetch table schema
|
||||
table_schema = await get_table_schema(
|
||||
credentials, input_data.base_id, input_data.table_id_or_name
|
||||
)
|
||||
|
||||
# Normalize the single record (wrap in list and unwrap result)
|
||||
normalized_data = await normalize_records(
|
||||
[record],
|
||||
table_schema,
|
||||
include_field_metadata=input_data.include_field_metadata,
|
||||
)
|
||||
|
||||
normalized_record = normalized_data["records"][0]
|
||||
yield "id", normalized_record.get("id", None)
|
||||
yield "fields", normalized_record.get("fields", None)
|
||||
yield "created_time", normalized_record.get("createdTime", None)
|
||||
|
||||
if (
|
||||
input_data.include_field_metadata
|
||||
and "field_metadata" in normalized_data
|
||||
):
|
||||
yield "field_metadata", normalized_data["field_metadata"]
|
||||
else:
|
||||
yield "id", record.get("id", None)
|
||||
yield "fields", record.get("fields", None)
|
||||
yield "created_time", record.get("createdTime", None)
|
||||
yield "id", record.get("id", None)
|
||||
yield "fields", record.get("fields", None)
|
||||
yield "created_time", record.get("createdTime", None)
|
||||
|
||||
|
||||
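For the single-record path above, the same batch normalizer is reused by wrapping the record in a one-element list and unwrapping the result, which keeps a single code path for both blocks. In outline (reusing the hypothetical helper from the earlier sketch):

```python
# Reuse a batch normalizer for a single record: wrap, normalize, unwrap.
record = {"id": "recXXX", "fields": {"Name": "Ada"}}
schema = {"fields": [{"name": "Name"}, {"name": "Email"}]}

normalized = normalize_record_fields([record], schema)[0]
assert normalized["fields"] == {"Name": "Ada", "Email": None}
```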
class AirtableCreateRecordsBlock(Block):
|
||||
@@ -218,7 +139,7 @@ class AirtableCreateRecordsBlock(Block):
|
||||
Creates one or more records in an Airtable table.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
@@ -227,10 +148,6 @@ class AirtableCreateRecordsBlock(Block):
|
||||
records: list[dict] = SchemaField(
|
||||
description="Array of records to create (each with 'fields' object)"
|
||||
)
|
||||
skip_normalization: bool = SchemaField(
|
||||
description="Skip output normalization to get raw Airtable response (faster but may have missing fields)",
|
||||
default=False,
|
||||
)
|
||||
typecast: bool = SchemaField(
|
||||
description="Automatically convert string values to appropriate types",
|
||||
default=False,
|
||||
@@ -240,7 +157,7 @@ class AirtableCreateRecordsBlock(Block):
|
||||
default=None,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
records: list[dict] = SchemaField(description="Array of created record objects")
|
||||
details: dict = SchemaField(description="Details of the created records")
|
||||
|
||||
@@ -256,7 +173,7 @@ class AirtableCreateRecordsBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
# The create_record API expects records in a specific format
|
||||
data = await create_record(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
@@ -265,22 +182,8 @@ class AirtableCreateRecordsBlock(Block):
|
||||
typecast=input_data.typecast if input_data.typecast else None,
|
||||
return_fields_by_field_id=input_data.return_fields_by_field_id,
|
||||
)
|
||||
result_records = cast(list[dict], data.get("records", []))
|
||||
|
||||
# Normalize output unless explicitly disabled
|
||||
if not input_data.skip_normalization and result_records:
|
||||
# Fetch table schema
|
||||
table_schema = await get_table_schema(
|
||||
credentials, input_data.base_id, input_data.table_id_or_name
|
||||
)
|
||||
|
||||
# Normalize the records
|
||||
normalized_data = await normalize_records(
|
||||
result_records, table_schema, include_field_metadata=False
|
||||
)
|
||||
result_records = normalized_data["records"]
|
||||
|
||||
yield "records", result_records
|
||||
yield "records", data.get("records", [])
|
||||
details = data.get("details", None)
|
||||
if details:
|
||||
yield "details", details
|
||||
@@ -291,7 +194,7 @@ class AirtableUpdateRecordsBlock(Block):
|
||||
Updates one or more existing records in an Airtable table.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
@@ -307,7 +210,7 @@ class AirtableUpdateRecordsBlock(Block):
|
||||
default=None,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
records: list[dict] = SchemaField(description="Array of updated record objects")
|
||||
|
||||
def __init__(self):
|
||||
@@ -340,7 +243,7 @@ class AirtableDeleteRecordsBlock(Block):
|
||||
Deletes one or more records from an Airtable table.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
@@ -352,7 +255,7 @@ class AirtableDeleteRecordsBlock(Block):
|
||||
description="Array of upto 10 record IDs to delete"
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
records: list[dict] = SchemaField(description="Array of deletion results")
|
||||
|
||||
def __init__(self):
|
||||
|
||||
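The delete block above accepts at most 10 record IDs per call, mirroring Airtable's batch limit. A caller holding more IDs would typically chunk them first; a plain illustration, not code from the repository:

```python
def chunk_ids(record_ids: list[str], size: int = 10) -> list[list[str]]:
    """Split record IDs into batches no larger than the API's per-request limit."""
    return [record_ids[i : i + size] for i in range(0, len(record_ids), size)]


# chunk_ids([f"rec{i}" for i in range(23)]) -> 3 batches of 10, 10, and 3 IDs
```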
@@ -7,8 +7,7 @@ from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
@@ -24,13 +23,13 @@ class AirtableListSchemaBlock(Block):
|
||||
fields, and views.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
base_schema: dict = SchemaField(
|
||||
description="Complete base schema with tables, fields, and views"
|
||||
)
|
||||
@@ -67,7 +66,7 @@ class AirtableCreateTableBlock(Block):
|
||||
Creates a new table in an Airtable base with specified fields and views.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
@@ -78,7 +77,7 @@ class AirtableCreateTableBlock(Block):
|
||||
default=[{"name": "Name", "type": "singleLineText"}],
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
table: dict = SchemaField(description="Created table object")
|
||||
table_id: str = SchemaField(description="ID of the created table")
|
||||
|
||||
@@ -110,7 +109,7 @@ class AirtableUpdateTableBlock(Block):
|
||||
Updates an existing table's properties such as name or description.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
@@ -126,7 +125,7 @@ class AirtableUpdateTableBlock(Block):
|
||||
description="The date dependency of the table to update", default=None
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
table: dict = SchemaField(description="Updated table object")
|
||||
|
||||
def __init__(self):
|
||||
@@ -158,7 +157,7 @@ class AirtableCreateFieldBlock(Block):
|
||||
Adds a new field (column) to an existing Airtable table.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
@@ -177,7 +176,7 @@ class AirtableCreateFieldBlock(Block):
|
||||
description="The options of the field to create", default=None
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
field: dict = SchemaField(description="Created field object")
|
||||
field_id: str = SchemaField(description="ID of the created field")
|
||||
|
||||
@@ -210,7 +209,7 @@ class AirtableUpdateFieldBlock(Block):
|
||||
Updates an existing field's properties in an Airtable table.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
@@ -226,7 +225,7 @@ class AirtableUpdateFieldBlock(Block):
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
field: dict = SchemaField(description="Updated field object")
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@@ -3,8 +3,7 @@ from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
BlockWebhookConfig,
|
||||
CredentialsMetaInput,
|
||||
@@ -33,7 +32,7 @@ class AirtableWebhookTriggerBlock(Block):
|
||||
Thin wrapper just forwards the payloads one at a time to the next block.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
@@ -44,7 +43,7 @@ class AirtableWebhookTriggerBlock(Block):
|
||||
description="Airtable webhook event filter"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
payload: WebhookPayload = SchemaField(description="Airtable webhook payload")
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@@ -10,20 +10,14 @@ from backend.blocks.apollo.models import (
|
||||
PrimaryPhone,
|
||||
SearchOrganizationsRequest,
|
||||
)
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
|
||||
|
||||
class SearchOrganizationsBlock(Block):
|
||||
"""Search for organizations in Apollo"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
organization_num_employees_range: list[int] = SchemaField(
|
||||
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
|
||||
|
||||
@@ -75,7 +69,7 @@ To find IDs, identify the values for organization_id when you call this endpoint
|
||||
description="Apollo credentials",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
organizations: list[Organization] = SchemaField(
|
||||
description="List of organizations found",
|
||||
default_factory=list,
|
||||
|
||||
@@ -14,20 +14,14 @@ from backend.blocks.apollo.models import (
|
||||
SearchPeopleRequest,
|
||||
SenorityLevels,
|
||||
)
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
|
||||
|
||||
class SearchPeopleBlock(Block):
|
||||
"""Search for people in Apollo"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
person_titles: list[str] = SchemaField(
|
||||
description="""Job titles held by the people you want to find. For a person to be included in search results, they only need to match 1 of the job titles you add. Adding more job titles expands your search results.
|
||||
|
||||
@@ -115,7 +109,7 @@ class SearchPeopleBlock(Block):
|
||||
description="Apollo credentials",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
people: list[Contact] = SchemaField(
|
||||
description="List of people found",
|
||||
default_factory=list,
|
||||
|
||||
@@ -6,20 +6,14 @@ from backend.blocks.apollo._auth import (
|
||||
ApolloCredentialsInput,
|
||||
)
|
||||
from backend.blocks.apollo.models import Contact, EnrichPersonRequest
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
|
||||
|
||||
class GetPersonDetailBlock(Block):
|
||||
"""Get detailed person data with Apollo API, including email reveal"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
person_id: str = SchemaField(
|
||||
description="Apollo person ID to enrich (most accurate method)",
|
||||
default="",
|
||||
@@ -74,7 +68,7 @@ class GetPersonDetailBlock(Block):
|
||||
description="Apollo credentials",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
contact: Contact = SchemaField(
|
||||
description="Enriched contact information",
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.block import BlockSchemaInput
|
||||
from backend.data.block import BlockSchema
|
||||
from backend.data.model import SchemaField, UserIntegrations
|
||||
from backend.integrations.ayrshare import AyrshareClient
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
@@ -17,7 +17,7 @@ async def get_profile_key(user_id: str):
|
||||
return user_integrations.managed_credentials.ayrshare_profile_key
|
||||
|
||||
|
||||
class BaseAyrshareInput(BlockSchemaInput):
|
||||
class BaseAyrshareInput(BlockSchema):
|
||||
"""Base input model for Ayrshare social media posts with common fields."""
|
||||
|
||||
post: str = SchemaField(
|
||||
|
||||
@@ -3,7 +3,7 @@ from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
@@ -38,7 +38,7 @@ class PostToBlueskyBlock(Block):
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
@@ -101,7 +101,7 @@ class PostToFacebookBlock(Block):
|
||||
description="URL for custom link preview", default="", advanced=True
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
@@ -94,7 +94,7 @@ class PostToGMBBlock(Block):
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
@@ -94,7 +94,7 @@ class PostToInstagramBlock(Block):
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
@@ -94,7 +94,7 @@ class PostToLinkedInBlock(Block):
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
@@ -73,7 +73,7 @@ class PostToPinterestBlock(Block):
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
@@ -19,7 +19,7 @@ class PostToRedditBlock(Block):
|
||||
|
||||
pass # Uses all base fields
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
@@ -43,7 +43,7 @@ class PostToSnapchatBlock(Block):
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
@@ -38,7 +38,7 @@ class PostToTelegramBlock(Block):
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
@@ -31,7 +31,7 @@ class PostToThreadsBlock(Block):
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
@@ -98,7 +98,7 @@ class PostToTikTokBlock(Block):
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
@@ -97,7 +97,7 @@ class PostToXBlock(Block):
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
@@ -119,7 +119,7 @@ class PostToYouTubeBlock(Block):
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
|
||||
@@ -9,8 +9,7 @@ from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
@@ -24,7 +23,7 @@ class BaasBotJoinMeetingBlock(Block):
|
||||
Deploy a bot immediately or at a scheduled start_time to join and record a meeting.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
@@ -58,7 +57,7 @@ class BaasBotJoinMeetingBlock(Block):
|
||||
description="Custom metadata to attach to the bot", default={}
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
bot_id: str = SchemaField(description="UUID of the deployed bot")
|
||||
join_response: dict = SchemaField(
|
||||
description="Full response from join operation"
|
||||
@@ -104,13 +103,13 @@ class BaasBotLeaveMeetingBlock(Block):
|
||||
Force the bot to exit the call.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
bot_id: str = SchemaField(description="UUID of the bot to remove from meeting")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
left: bool = SchemaField(description="Whether the bot successfully left")
|
||||
|
||||
def __init__(self):
|
||||
@@ -139,7 +138,7 @@ class BaasBotFetchMeetingDataBlock(Block):
|
||||
Pull MP4 URL, transcript & metadata for a completed meeting.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
@@ -148,7 +147,7 @@ class BaasBotFetchMeetingDataBlock(Block):
|
||||
description="Include transcript data in response", default=True
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
mp4_url: str = SchemaField(
|
||||
description="URL to download the meeting recording (time-limited)"
|
||||
)
|
||||
@@ -186,13 +185,13 @@ class BaasBotDeleteRecordingBlock(Block):
|
||||
Purge MP4 + transcript data for privacy or storage management.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
bot_id: str = SchemaField(description="UUID of the bot whose data to delete")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
deleted: bool = SchemaField(
|
||||
description="Whether the data was successfully deleted"
|
||||
)
|
||||
|
||||
@@ -1,3 +0,0 @@
from .text_overlay import BannerbearTextOverlayBlock

__all__ = ["BannerbearTextOverlayBlock"]
@@ -1,8 +0,0 @@
from backend.sdk import BlockCostType, ProviderBuilder

bannerbear = (
    ProviderBuilder("bannerbear")
    .with_api_key("BANNERBEAR_API_KEY", "Bannerbear API Key")
    .with_base_cost(1, BlockCostType.RUN)
    .build()
)
@@ -1,239 +0,0 @@
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, Dict, List
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import bannerbear
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="bannerbear",
|
||||
api_key=SecretStr("mock-bannerbear-api-key"),
|
||||
title="Mock Bannerbear API Key",
|
||||
)
|
||||
|
||||
|
||||
class TextModification(BlockSchemaInput):
|
||||
name: str = SchemaField(
|
||||
description="The name of the layer to modify in the template"
|
||||
)
|
||||
text: str = SchemaField(description="The text content to add to this layer")
|
||||
color: str = SchemaField(
|
||||
description="Hex color code for the text (e.g., '#FF0000')",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
font_family: str = SchemaField(
|
||||
description="Font family to use for the text",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
font_size: int = SchemaField(
|
||||
description="Font size in pixels",
|
||||
default=0,
|
||||
advanced=True,
|
||||
)
|
||||
font_weight: str = SchemaField(
|
||||
description="Font weight (e.g., 'bold', 'normal')",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
text_align: str = SchemaField(
|
||||
description="Text alignment (left, center, right)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
|
||||
class BannerbearTextOverlayBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = bannerbear.credentials_field(
|
||||
description="API credentials for Bannerbear"
|
||||
)
|
||||
template_id: str = SchemaField(
|
||||
description="The unique ID of your Bannerbear template"
|
||||
)
|
||||
project_id: str = SchemaField(
|
||||
description="Optional: Project ID (required when using Master API Key)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
text_modifications: List[TextModification] = SchemaField(
|
||||
description="List of text layers to modify in the template"
|
||||
)
|
||||
image_url: str = SchemaField(
|
||||
description="Optional: URL of an image to use in the template",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
image_layer_name: str = SchemaField(
|
||||
description="Optional: Name of the image layer in the template",
|
||||
default="photo",
|
||||
advanced=True,
|
||||
)
|
||||
webhook_url: str = SchemaField(
|
||||
description="Optional: URL to receive webhook notification when image is ready",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
metadata: str = SchemaField(
|
||||
description="Optional: Custom metadata to attach to the image",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
success: bool = SchemaField(
|
||||
description="Whether the image generation was successfully initiated"
|
||||
)
|
||||
image_url: str = SchemaField(
|
||||
description="URL of the generated image (if synchronous) or placeholder"
|
||||
)
|
||||
uid: str = SchemaField(description="Unique identifier for the generated image")
|
||||
status: str = SchemaField(description="Status of the image generation")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c7d3a5c2-05fc-450e-8dce-3b0e04626009",
|
||||
description="Add text overlay to images using Bannerbear templates. Perfect for creating social media graphics, marketing materials, and dynamic image content.",
|
||||
categories={BlockCategory.PRODUCTIVITY, BlockCategory.AI},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={
|
||||
"template_id": "jJWBKNELpQPvbX5R93Gk",
|
||||
"text_modifications": [
|
||||
{
|
||||
"name": "headline",
|
||||
"text": "Amazing Product Launch!",
|
||||
"color": "#FF0000",
|
||||
},
|
||||
{
|
||||
"name": "subtitle",
|
||||
"text": "50% OFF Today Only",
|
||||
},
|
||||
],
|
||||
"credentials": {
|
||||
"provider": "bannerbear",
|
||||
"id": str(uuid.uuid4()),
|
||||
"type": "api_key",
|
||||
},
|
||||
},
|
||||
test_output=[
|
||||
("success", True),
|
||||
("image_url", "https://cdn.bannerbear.com/test-image.jpg"),
|
||||
("uid", "test-uid-123"),
|
||||
("status", "completed"),
|
||||
],
|
||||
test_mock={
|
||||
"_make_api_request": lambda *args, **kwargs: {
|
||||
"uid": "test-uid-123",
|
||||
"status": "completed",
|
||||
"image_url": "https://cdn.bannerbear.com/test-image.jpg",
|
||||
}
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def _make_api_request(self, payload: dict, api_key: str) -> dict:
|
||||
"""Make the actual API request to Bannerbear. This is separated for easy mocking in tests."""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
response = await Requests().post(
|
||||
"https://sync.api.bannerbear.com/v2/images",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
)
|
||||
|
||||
        if response.status in [200, 201, 202]:
            return response.json()
        else:
            error_msg = f"API request failed with status {response.status}"
            if response.text:
                try:
                    error_data = response.json()
                    error_msg = (
                        f"{error_msg}: {error_data.get('message', response.text)}"
                    )
                except Exception:
                    error_msg = f"{error_msg}: {response.text}"
            raise Exception(error_msg)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Build the modifications array
|
||||
modifications = []
|
||||
|
||||
# Add text modifications
|
||||
for text_mod in input_data.text_modifications:
|
||||
mod_data: Dict[str, Any] = {
|
||||
"name": text_mod.name,
|
||||
"text": text_mod.text,
|
||||
}
|
||||
|
||||
# Add optional text styling parameters only if they have values
|
||||
if text_mod.color and text_mod.color.strip():
|
||||
mod_data["color"] = text_mod.color
|
||||
if text_mod.font_family and text_mod.font_family.strip():
|
||||
mod_data["font_family"] = text_mod.font_family
|
||||
if text_mod.font_size and text_mod.font_size > 0:
|
||||
mod_data["font_size"] = text_mod.font_size
|
||||
if text_mod.font_weight and text_mod.font_weight.strip():
|
||||
mod_data["font_weight"] = text_mod.font_weight
|
||||
if text_mod.text_align and text_mod.text_align.strip():
|
||||
mod_data["text_align"] = text_mod.text_align
|
||||
|
||||
modifications.append(mod_data)
|
||||
|
||||
# Add image modification if provided and not empty
|
||||
if input_data.image_url and input_data.image_url.strip():
|
||||
modifications.append(
|
||||
{
|
||||
"name": input_data.image_layer_name,
|
||||
"image_url": input_data.image_url,
|
||||
}
|
||||
)
|
||||
|
||||
# Build the request payload - only include non-empty optional fields
|
||||
payload = {
|
||||
"template": input_data.template_id,
|
||||
"modifications": modifications,
|
||||
}
|
||||
|
||||
# Add project_id if provided (required for Master API keys)
|
||||
if input_data.project_id and input_data.project_id.strip():
|
||||
payload["project_id"] = input_data.project_id
|
||||
|
||||
if input_data.webhook_url and input_data.webhook_url.strip():
|
||||
payload["webhook_url"] = input_data.webhook_url
|
||||
if input_data.metadata and input_data.metadata.strip():
|
||||
payload["metadata"] = input_data.metadata
|
||||
|
||||
# Make the API request using the private method
|
||||
data = await self._make_api_request(
|
||||
payload, credentials.api_key.get_secret_value()
|
||||
)
|
||||
|
||||
# Synchronous request - image should be ready
|
||||
yield "success", True
|
||||
yield "image_url", data.get("image_url", "")
|
||||
yield "uid", data.get("uid", "")
|
||||
yield "status", data.get("status", "completed")
|
||||
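For reference, the `run` method above ends up POSTing a payload to Bannerbear's synchronous images endpoint shaped roughly like the dict below. The endpoint comes from `_make_api_request`; the field values here are made up, and optional keys are only added when they are non-empty.

```python
# Illustrative payload for POST https://sync.api.bannerbear.com/v2/images
# (endpoint taken from _make_api_request above; the values are examples only).
payload = {
    "template": "jJWBKNELpQPvbX5R93Gk",
    "modifications": [
        {"name": "headline", "text": "Amazing Product Launch!", "color": "#FF0000"},
        {"name": "subtitle", "text": "50% OFF Today Only"},
        {"name": "photo", "image_url": "https://example.com/product.jpg"},
    ],
    # Optional keys, included only when set:
    # "project_id": "...", "webhook_url": "...", "metadata": "...",
}
```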
@@ -1,21 +1,14 @@
|
||||
import enum
|
||||
from typing import Any
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import store_media_file
|
||||
from backend.util.type import MediaFileType, convert
|
||||
|
||||
|
||||
class FileStoreBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
file_in: MediaFileType = SchemaField(
|
||||
description="The file to store in the temporary directory, it can be a URL, data URI, or local path."
|
||||
)
|
||||
@@ -26,7 +19,7 @@ class FileStoreBlock(Block):
|
||||
title="Produce Base64 Output",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
file_out: MediaFileType = SchemaField(
|
||||
description="The relative path to the stored file in the temporary directory."
|
||||
)
|
||||
@@ -64,7 +57,7 @@ class StoreValueBlock(Block):
|
||||
The block output will be static; it can be consumed multiple times.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
input: Any = SchemaField(
|
||||
description="Trigger the block to produce the output. "
|
||||
"The value is only used when `data` is None."
|
||||
@@ -75,7 +68,7 @@ class StoreValueBlock(Block):
|
||||
default=None,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
output: Any = SchemaField(description="The stored data retained in the block.")
|
||||
|
||||
def __init__(self):
|
||||
@@ -101,10 +94,10 @@ class StoreValueBlock(Block):
|
||||
|
||||
|
||||
class PrintToConsoleBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
text: Any = SchemaField(description="The data to print to the console.")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
output: Any = SchemaField(description="The data printed to the console.")
|
||||
status: str = SchemaField(description="The status of the print operation.")
|
||||
|
||||
@@ -128,10 +121,10 @@ class PrintToConsoleBlock(Block):
|
||||
|
||||
|
||||
class NoteBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
text: str = SchemaField(description="The text to display in the sticky note.")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
output: str = SchemaField(description="The text to display in the sticky note.")
|
||||
|
||||
def __init__(self):
|
||||
@@ -161,14 +154,15 @@ class TypeOptions(enum.Enum):
|
||||
|
||||
|
||||
class UniversalTypeConverterBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
value: Any = SchemaField(
|
||||
description="The value to convert to a universal type."
|
||||
)
|
||||
type: TypeOptions = SchemaField(description="The type to convert the value to.")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
value: Any = SchemaField(description="The converted value.")
|
||||
error: str = SchemaField(description="Error message if conversion failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -201,10 +195,10 @@ class ReverseListOrderBlock(Block):
|
||||
A block which takes in a list and returns it in the opposite order.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
input_list: list[Any] = SchemaField(description="The list to reverse")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
reversed_list: list[Any] = SchemaField(description="The list in reversed order")
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@@ -2,13 +2,7 @@ import os
|
||||
import re
|
||||
from typing import Type
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
@@ -21,12 +15,12 @@ class BlockInstallationBlock(Block):
|
||||
for development purposes only.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
code: str = SchemaField(
|
||||
description="Python code of the block to be installed",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
success: str = SchemaField(
|
||||
description="Success message if the block is installed successfully",
|
||||
)
|
||||
|
||||
@@ -1,13 +1,7 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.type import convert
|
||||
|
||||
@@ -22,7 +16,7 @@ class ComparisonOperator(Enum):
|
||||
|
||||
|
||||
class ConditionBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
value1: Any = SchemaField(
|
||||
description="Enter the first value for comparison",
|
||||
placeholder="For example: 10 or 'hello' or True",
|
||||
@@ -46,7 +40,7 @@ class ConditionBlock(Block):
|
||||
default=None,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
result: bool = SchemaField(
|
||||
description="The result of the condition evaluation (True or False)"
|
||||
)
|
||||
@@ -117,7 +111,7 @@ class ConditionBlock(Block):
|
||||
|
||||
|
||||
class IfInputMatchesBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
input: Any = SchemaField(
|
||||
description="The input to match against",
|
||||
placeholder="For example: 10 or 'hello' or True",
|
||||
@@ -137,7 +131,7 @@ class IfInputMatchesBlock(Block):
|
||||
default=None,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
result: bool = SchemaField(
|
||||
description="The result of the condition evaluation (True or False)"
|
||||
)
|
||||
|
||||
@@ -1,18 +1,10 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Literal
|
||||
|
||||
from e2b_code_interpreter import AsyncSandbox
|
||||
from e2b_code_interpreter import Result as E2BExecutionResult
|
||||
from e2b_code_interpreter.charts import Chart as E2BExecutionResultChart
|
||||
from pydantic import BaseModel, Field, JsonValue, SecretStr
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
@@ -44,135 +36,14 @@ class ProgrammingLanguage(Enum):
|
||||
JAVA = "java"
|
||||
|
||||
|
||||
class MainCodeExecutionResult(BaseModel):
|
||||
"""
|
||||
*Pydantic model mirroring `e2b_code_interpreter.Result`*
|
||||
|
||||
Represents the data to be displayed as a result of executing a cell in a Jupyter notebook.
|
||||
The result is similar to the structure returned by ipython kernel: https://ipython.readthedocs.io/en/stable/development/execution.html#execution-semantics
|
||||
|
||||
The result can contain multiple types of data, such as text, images, and plots. Each type of data is represented
as a string, and a single result can carry several of these representations at once. Display calls don't have to
include a text representation; the main result always has one, while the other representations are optional.
|
||||
""" # noqa
|
||||
|
||||
class Chart(BaseModel, E2BExecutionResultChart):
|
||||
pass
|
||||
|
||||
text: Optional[str] = None
|
||||
html: Optional[str] = None
|
||||
markdown: Optional[str] = None
|
||||
svg: Optional[str] = None
|
||||
png: Optional[str] = None
|
||||
jpeg: Optional[str] = None
|
||||
pdf: Optional[str] = None
|
||||
latex: Optional[str] = None
|
||||
json_data: Optional[JsonValue] = Field(None, alias="json")
|
||||
javascript: Optional[str] = None
|
||||
data: Optional[dict] = None
|
||||
chart: Optional[Chart] = None
|
||||
extra: Optional[dict] = None
|
||||
"""Extra data that can be included. Not part of the standard types."""
|
||||
|
||||
|
||||
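Because `json_data` is declared with `alias="json"`, populating the model from a raw result dict goes through that alias. A minimal sketch, assuming Pydantic v2's `model_validate`; the input dict is invented, not an actual E2B response.

```python
# Hypothetical: build the result model from a plain dict via the "json" alias.
raw = {"text": "42", "json": {"answer": 42}}
result = MainCodeExecutionResult.model_validate(raw)
assert result.json_data == {"answer": 42}
assert result.text == "42"
```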
class CodeExecutionResult(MainCodeExecutionResult):
|
||||
__doc__ = MainCodeExecutionResult.__doc__
|
||||
|
||||
is_main_result: bool = False
|
||||
"""Whether this data is the main result of the cell. Data can be produced by display calls of which can be multiple in a cell.""" # noqa
|
||||
|
||||
|
||||
class BaseE2BExecutorMixin:
|
||||
"""Shared implementation methods for E2B executor blocks."""
|
||||
|
||||
async def execute_code(
|
||||
self,
|
||||
api_key: str,
|
||||
code: str,
|
||||
language: ProgrammingLanguage,
|
||||
template_id: str = "",
|
||||
setup_commands: Optional[list[str]] = None,
|
||||
timeout: Optional[int] = None,
|
||||
sandbox_id: Optional[str] = None,
|
||||
dispose_sandbox: bool = False,
|
||||
):
|
||||
"""
|
||||
Unified code execution method that handles all three use cases:
|
||||
1. Create new sandbox and execute (ExecuteCodeBlock)
|
||||
2. Create new sandbox, execute, and return sandbox_id (InstantiateCodeSandboxBlock)
|
||||
3. Connect to existing sandbox and execute (ExecuteCodeStepBlock)
|
||||
""" # noqa
|
||||
sandbox = None
|
||||
try:
|
||||
if sandbox_id:
|
||||
# Connect to existing sandbox (ExecuteCodeStepBlock case)
|
||||
sandbox = await AsyncSandbox.connect(
|
||||
sandbox_id=sandbox_id, api_key=api_key
|
||||
)
|
||||
else:
|
||||
# Create new sandbox (ExecuteCodeBlock/InstantiateCodeSandboxBlock case)
|
||||
sandbox = await AsyncSandbox.create(
|
||||
api_key=api_key, template=template_id, timeout=timeout
|
||||
)
|
||||
if setup_commands:
|
||||
for cmd in setup_commands:
|
||||
await sandbox.commands.run(cmd)
|
||||
|
||||
# Execute the code
|
||||
execution = await sandbox.run_code(
|
||||
code,
|
||||
language=language.value,
|
||||
on_error=lambda e: sandbox.kill(), # Kill the sandbox on error
|
||||
)
|
||||
|
||||
if execution.error:
|
||||
raise Exception(execution.error)
|
||||
|
||||
results = execution.results
|
||||
text_output = execution.text
|
||||
stdout_logs = "".join(execution.logs.stdout)
|
||||
stderr_logs = "".join(execution.logs.stderr)
|
||||
|
||||
return results, text_output, stdout_logs, stderr_logs, sandbox.sandbox_id
|
||||
finally:
|
||||
# Dispose of sandbox if requested to reduce usage costs
|
||||
if dispose_sandbox and sandbox:
|
||||
await sandbox.kill()
|
||||
|
||||
    def process_execution_results(
        self, results: list[E2BExecutionResult]
    ) -> tuple[dict[str, Any] | None, list[dict[str, Any]]]:
        """Process and filter execution results."""
        # Filter out empty formats and convert to dicts
        processed_results = [
            {
                f: value
                for f in [*r.formats(), "extra", "is_main_result"]
                if (value := getattr(r, f, None)) is not None
            }
            for r in results
        ]
        if main_result := next(
            (r for r in processed_results if r.get("is_main_result")), None
        ):
            # Make main_result a copy we can modify & remove is_main_result
            (main_result := {**main_result}).pop("is_main_result")

        return main_result, processed_results
|
||||
|
||||
|
||||
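`execute_code` above covers three call patterns: a fresh sandbox disposed right after execution (ExecuteCodeBlock), a fresh sandbox whose `sandbox_id` is kept for later steps (InstantiateCodeSandboxBlock), and reconnecting to that sandbox by ID (ExecuteCodeStepBlock). A rough standalone sketch of the second and third patterns, using only the e2b calls that appear in this diff; treat it as an outline rather than a tested script.

```python
import asyncio

from e2b_code_interpreter import AsyncSandbox


async def demo(api_key: str) -> None:
    # 1) Instantiate: create a sandbox, run setup code, keep the sandbox_id.
    sandbox = await AsyncSandbox.create(api_key=api_key, timeout=300)
    await sandbox.run_code("x = 21", language="python")
    sandbox_id = sandbox.sandbox_id

    # 2) Step: reconnect by ID later and execute more code in the same sandbox.
    sandbox = await AsyncSandbox.connect(sandbox_id=sandbox_id, api_key=api_key)
    execution = await sandbox.run_code("print(x * 2)", language="python")
    print("".join(execution.logs.stdout))  # expected: "42\n"

    # 3) Dispose when done to stop incurring sandbox time.
    await sandbox.kill()


# asyncio.run(demo("e2b_..."))
```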
class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
|
||||
class CodeExecutionBlock(Block):
|
||||
# TODO : Add support to upload and download files
|
||||
# NOTE: Currently, you can only customize the CPU and Memory
|
||||
# by creating a pre customized sandbox template
|
||||
class Input(BlockSchemaInput):
|
||||
# Currently, you can only customize the CPU and memory by creating a pre-customized sandbox template
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.E2B], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description=(
|
||||
"Enter your API key for the E2B platform. "
|
||||
"You can get it in here - https://e2b.dev/docs"
|
||||
),
|
||||
description="Enter your api key for the E2B Sandbox. You can get it in here - https://e2b.dev/docs",
|
||||
)
|
||||
|
||||
# TODO: Option to run a command in the background
|
||||
@@ -205,14 +76,6 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
|
||||
description="Execution timeout in seconds", default=300
|
||||
)
|
||||
|
||||
dispose_sandbox: bool = SchemaField(
|
||||
description=(
|
||||
"Whether to dispose of the sandbox immediately after execution. "
|
||||
"If disabled, the sandbox will run until its timeout expires."
|
||||
),
|
||||
default=True,
|
||||
)
|
||||
|
||||
template_id: str = SchemaField(
|
||||
description=(
|
||||
"You can use an E2B sandbox template by entering its ID here. "
|
||||
@@ -223,29 +86,21 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
main_result: MainCodeExecutionResult = SchemaField(
|
||||
title="Main Result", description="The main result from the code execution"
|
||||
)
|
||||
results: list[CodeExecutionResult] = SchemaField(
|
||||
description="List of results from the code execution"
|
||||
)
|
||||
response: str = SchemaField(
|
||||
title="Main Text Output",
|
||||
description="Text output (if any) of the main execution result",
|
||||
)
|
||||
class Output(BlockSchema):
|
||||
response: str = SchemaField(description="Response from code execution")
|
||||
stdout_logs: str = SchemaField(
|
||||
description="Standard output logs from execution"
|
||||
)
|
||||
stderr_logs: str = SchemaField(description="Standard error logs from execution")
|
||||
error: str = SchemaField(description="Error message if execution failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0b02b072-abe7-11ef-8372-fb5d162dd712",
|
||||
description="Executes code in a sandbox environment with internet access.",
|
||||
description="Executes code in an isolated sandbox environment with internet access.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=ExecuteCodeBlock.Input,
|
||||
output_schema=ExecuteCodeBlock.Output,
|
||||
input_schema=CodeExecutionBlock.Input,
|
||||
output_schema=CodeExecutionBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
@@ -256,59 +111,91 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
|
||||
"template_id": "",
|
||||
},
|
||||
test_output=[
|
||||
("results", []),
|
||||
("response", "Hello World"),
|
||||
("stdout_logs", "Hello World\n"),
|
||||
],
|
||||
test_mock={
|
||||
"execute_code": lambda api_key, code, language, template_id, setup_commands, timeout, dispose_sandbox: ( # noqa
|
||||
[], # results
|
||||
"Hello World", # text_output
|
||||
"Hello World\n", # stdout_logs
|
||||
"", # stderr_logs
|
||||
"sandbox_id", # sandbox_id
|
||||
"execute_code": lambda code, language, setup_commands, timeout, api_key, template_id: (
|
||||
"Hello World",
|
||||
"Hello World\n",
|
||||
"",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
async def execute_code(
|
||||
self,
|
||||
code: str,
|
||||
language: ProgrammingLanguage,
|
||||
setup_commands: list[str],
|
||||
timeout: int,
|
||||
api_key: str,
|
||||
template_id: str,
|
||||
):
|
||||
try:
|
||||
sandbox = None
|
||||
if template_id:
|
||||
sandbox = await AsyncSandbox.create(
|
||||
template=template_id, api_key=api_key, timeout=timeout
|
||||
)
|
||||
else:
|
||||
sandbox = await AsyncSandbox.create(api_key=api_key, timeout=timeout)
|
||||
|
||||
if not sandbox:
|
||||
raise Exception("Sandbox not created")
|
||||
|
||||
# Running setup commands
|
||||
for cmd in setup_commands:
|
||||
await sandbox.commands.run(cmd)
|
||||
|
||||
# Executing the code
|
||||
execution = await sandbox.run_code(
|
||||
code,
|
||||
language=language.value,
|
||||
on_error=lambda e: sandbox.kill(), # Kill the sandbox if there is an error
|
||||
)
|
||||
|
||||
if execution.error:
|
||||
raise Exception(execution.error)
|
||||
|
||||
response = execution.text
|
||||
stdout_logs = "".join(execution.logs.stdout)
|
||||
stderr_logs = "".join(execution.logs.stderr)
|
||||
|
||||
return response, stdout_logs, stderr_logs
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
results, text_output, stdout, stderr, _ = await self.execute_code(
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
code=input_data.code,
|
||||
language=input_data.language,
|
||||
template_id=input_data.template_id,
|
||||
setup_commands=input_data.setup_commands,
|
||||
timeout=input_data.timeout,
|
||||
dispose_sandbox=input_data.dispose_sandbox,
|
||||
response, stdout_logs, stderr_logs = await self.execute_code(
|
||||
input_data.code,
|
||||
input_data.language,
|
||||
input_data.setup_commands,
|
||||
input_data.timeout,
|
||||
credentials.api_key.get_secret_value(),
|
||||
input_data.template_id,
|
||||
)
|
||||
|
||||
# Determine result object shape & filter out empty formats
|
||||
main_result, results = self.process_execution_results(results)
|
||||
if main_result:
|
||||
yield "main_result", main_result
|
||||
yield "results", results
|
||||
if text_output:
|
||||
yield "response", text_output
|
||||
if stdout:
|
||||
yield "stdout_logs", stdout
|
||||
if stderr:
|
||||
yield "stderr_logs", stderr
|
||||
if response:
|
||||
yield "response", response
|
||||
if stdout_logs:
|
||||
yield "stdout_logs", stdout_logs
|
||||
if stderr_logs:
|
||||
yield "stderr_logs", stderr_logs
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
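For readers skimming this hunk, the sketch below condenses the sandbox lifecycle that both the old and new execute_code implementations follow: create a sandbox, run setup commands, execute the code, then collect the text result plus stdout and stderr. It is an illustrative sketch only; the AsyncSandbox import path, the timeout value, and the sample inputs are assumptions, not taken from this diff.

# Hedged sketch of the E2B flow shown above. The import path is assumed
# (e2b_code_interpreter); adjust to match the project's actual dependency.
import asyncio

from e2b_code_interpreter import AsyncSandbox  # assumed import


async def run_once(api_key: str, code: str, setup_commands: list[str]) -> tuple[str, str, str]:
    # Create a default sandbox (no custom template), as in the block's else-branch.
    sandbox = await AsyncSandbox.create(api_key=api_key, timeout=300)
    try:
        for cmd in setup_commands:  # setup commands run before the user code
            await sandbox.commands.run(cmd)
        execution = await sandbox.run_code(code, language="python")
        if execution.error:
            raise RuntimeError(str(execution.error))
        # Same trio the block yields: text result, stdout logs, stderr logs.
        return (
            execution.text or "",
            "".join(execution.logs.stdout),
            "".join(execution.logs.stderr),
        )
    finally:
        await sandbox.kill()  # dispose of the sandbox when done


# Example with a placeholder key:
# asyncio.run(run_once("E2B_API_KEY", "print('Hello World')", []))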
class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
|
||||
class Input(BlockSchemaInput):
|
||||
class InstantiationBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.E2B], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description=(
|
||||
"Enter your API key for the E2B platform. "
|
||||
"You can get it in here - https://e2b.dev/docs"
|
||||
)
|
||||
description="Enter your api key for the E2B Sandbox. You can get it in here - https://e2b.dev/docs",
|
||||
)
|
||||
|
||||
# TODO: Option to run command in background
|
||||
@@ -351,27 +238,22 @@ class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
sandbox_id: str = SchemaField(description="ID of the sandbox instance")
|
||||
response: str = SchemaField(
|
||||
title="Text Result",
|
||||
description="Text result (if any) of the setup code execution",
|
||||
)
|
||||
response: str = SchemaField(description="Response from code execution")
|
||||
stdout_logs: str = SchemaField(
|
||||
description="Standard output logs from execution"
|
||||
)
|
||||
stderr_logs: str = SchemaField(description="Standard error logs from execution")
|
||||
error: str = SchemaField(description="Error message if execution failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="ff0861c9-1726-4aec-9e5b-bf53f3622112",
|
||||
description=(
|
||||
"Instantiate a sandbox environment with internet access "
|
||||
"in which you can execute code with the Execute Code Step block."
|
||||
),
|
||||
description="Instantiate an isolated sandbox environment with internet access where to execute code in.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=InstantiateCodeSandboxBlock.Input,
|
||||
output_schema=InstantiateCodeSandboxBlock.Output,
|
||||
input_schema=InstantiationBlock.Input,
|
||||
output_schema=InstantiationBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
@@ -387,12 +269,11 @@ class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
|
||||
("stdout_logs", "Hello World\n"),
|
||||
],
|
||||
test_mock={
|
||||
"execute_code": lambda api_key, code, language, template_id, setup_commands, timeout: ( # noqa
|
||||
[], # results
|
||||
"Hello World", # text_output
|
||||
"Hello World\n", # stdout_logs
|
||||
"", # stderr_logs
|
||||
"sandbox_id", # sandbox_id
|
||||
"execute_code": lambda setup_code, language, setup_commands, timeout, api_key, template_id: (
|
||||
"sandbox_id",
|
||||
"Hello World",
|
||||
"Hello World\n",
|
||||
"",
|
||||
),
|
||||
},
|
||||
)
|
||||
@@ -401,38 +282,78 @@ class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
_, text_output, stdout, stderr, sandbox_id = await self.execute_code(
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
code=input_data.setup_code,
|
||||
language=input_data.language,
|
||||
template_id=input_data.template_id,
|
||||
setup_commands=input_data.setup_commands,
|
||||
timeout=input_data.timeout,
|
||||
sandbox_id, response, stdout_logs, stderr_logs = await self.execute_code(
|
||||
input_data.setup_code,
|
||||
input_data.language,
|
||||
input_data.setup_commands,
|
||||
input_data.timeout,
|
||||
credentials.api_key.get_secret_value(),
|
||||
input_data.template_id,
|
||||
)
|
||||
if sandbox_id:
|
||||
yield "sandbox_id", sandbox_id
|
||||
else:
|
||||
yield "error", "Sandbox ID not found"
|
||||
|
||||
if text_output:
|
||||
yield "response", text_output
|
||||
if stdout:
|
||||
yield "stdout_logs", stdout
|
||||
if stderr:
|
||||
yield "stderr_logs", stderr
|
||||
if response:
|
||||
yield "response", response
|
||||
if stdout_logs:
|
||||
yield "stdout_logs", stdout_logs
|
||||
if stderr_logs:
|
||||
yield "stderr_logs", stderr_logs
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
async def execute_code(
|
||||
self,
|
||||
code: str,
|
||||
language: ProgrammingLanguage,
|
||||
setup_commands: list[str],
|
||||
timeout: int,
|
||||
api_key: str,
|
||||
template_id: str,
|
||||
):
|
||||
try:
|
||||
sandbox = None
|
||||
if template_id:
|
||||
sandbox = await AsyncSandbox.create(
|
||||
template=template_id, api_key=api_key, timeout=timeout
|
||||
)
|
||||
else:
|
||||
sandbox = await AsyncSandbox.create(api_key=api_key, timeout=timeout)
|
||||
|
||||
class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):
|
||||
class Input(BlockSchemaInput):
|
||||
if not sandbox:
|
||||
raise Exception("Sandbox not created")
|
||||
|
||||
# Running setup commands
|
||||
for cmd in setup_commands:
|
||||
await sandbox.commands.run(cmd)
|
||||
|
||||
# Executing the code
|
||||
execution = await sandbox.run_code(
|
||||
code,
|
||||
language=language.value,
|
||||
on_error=lambda e: sandbox.kill(), # Kill the sandbox if there is an error
|
||||
)
|
||||
|
||||
if execution.error:
|
||||
raise Exception(execution.error)
|
||||
|
||||
response = execution.text
|
||||
stdout_logs = "".join(execution.logs.stdout)
|
||||
stderr_logs = "".join(execution.logs.stderr)
|
||||
|
||||
return sandbox.sandbox_id, response, stdout_logs, stderr_logs
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
class StepExecutionBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.E2B], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description=(
|
||||
"Enter your API key for the E2B platform. "
|
||||
"You can get it in here - https://e2b.dev/docs"
|
||||
),
|
||||
description="Enter your api key for the E2B Sandbox. You can get it in here - https://e2b.dev/docs",
|
||||
)
|
||||
|
||||
sandbox_id: str = SchemaField(
|
||||
@@ -453,34 +374,21 @@ class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
dispose_sandbox: bool = SchemaField(
|
||||
description="Whether to dispose of the sandbox after executing this code.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
main_result: MainCodeExecutionResult = SchemaField(
|
||||
title="Main Result", description="The main result from the code execution"
|
||||
)
|
||||
results: list[CodeExecutionResult] = SchemaField(
|
||||
description="List of results from the code execution"
|
||||
)
|
||||
response: str = SchemaField(
|
||||
title="Main Text Output",
|
||||
description="Text output (if any) of the main execution result",
|
||||
)
|
||||
class Output(BlockSchema):
|
||||
response: str = SchemaField(description="Response from code execution")
|
||||
stdout_logs: str = SchemaField(
|
||||
description="Standard output logs from execution"
|
||||
)
|
||||
stderr_logs: str = SchemaField(description="Standard error logs from execution")
|
||||
error: str = SchemaField(description="Error message if execution failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="82b59b8e-ea10-4d57-9161-8b169b0adba6",
|
||||
description="Execute code in a previously instantiated sandbox.",
|
||||
description="Execute code in a previously instantiated sandbox environment.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=ExecuteCodeStepBlock.Input,
|
||||
output_schema=ExecuteCodeStepBlock.Output,
|
||||
input_schema=StepExecutionBlock.Input,
|
||||
output_schema=StepExecutionBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
@@ -489,43 +397,61 @@ class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):
|
||||
"language": ProgrammingLanguage.PYTHON.value,
|
||||
},
|
||||
test_output=[
|
||||
("results", []),
|
||||
("response", "Hello World"),
|
||||
("stdout_logs", "Hello World\n"),
|
||||
],
|
||||
test_mock={
|
||||
"execute_code": lambda api_key, code, language, sandbox_id, dispose_sandbox: ( # noqa
|
||||
[], # results
|
||||
"Hello World", # text_output
|
||||
"Hello World\n", # stdout_logs
|
||||
"", # stderr_logs
|
||||
sandbox_id, # sandbox_id
|
||||
"execute_step_code": lambda sandbox_id, step_code, language, api_key: (
|
||||
"Hello World",
|
||||
"Hello World\n",
|
||||
"",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
async def execute_step_code(
|
||||
self,
|
||||
sandbox_id: str,
|
||||
code: str,
|
||||
language: ProgrammingLanguage,
|
||||
api_key: str,
|
||||
):
|
||||
try:
|
||||
sandbox = await AsyncSandbox.connect(sandbox_id=sandbox_id, api_key=api_key)
|
||||
if not sandbox:
|
||||
raise Exception("Sandbox not found")
|
||||
|
||||
# Executing the code
|
||||
execution = await sandbox.run_code(code, language=language.value)
|
||||
|
||||
if execution.error:
|
||||
raise Exception(execution.error)
|
||||
|
||||
response = execution.text
|
||||
stdout_logs = "".join(execution.logs.stdout)
|
||||
stderr_logs = "".join(execution.logs.stderr)
|
||||
|
||||
return response, stdout_logs, stderr_logs
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
results, text_output, stdout, stderr, _ = await self.execute_code(
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
code=input_data.step_code,
|
||||
language=input_data.language,
|
||||
sandbox_id=input_data.sandbox_id,
|
||||
dispose_sandbox=input_data.dispose_sandbox,
|
||||
response, stdout_logs, stderr_logs = await self.execute_step_code(
|
||||
input_data.sandbox_id,
|
||||
input_data.step_code,
|
||||
input_data.language,
|
||||
credentials.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
# Determine result object shape & filter out empty formats
|
||||
main_result, results = self.process_execution_results(results)
|
||||
if main_result:
|
||||
yield "main_result", main_result
|
||||
yield "results", results
|
||||
if text_output:
|
||||
yield "response", text_output
|
||||
if stdout:
|
||||
yield "stdout_logs", stdout
|
||||
if stderr:
|
||||
yield "stderr_logs", stderr
|
||||
if response:
|
||||
yield "response", response
|
||||
if stdout_logs:
|
||||
yield "stdout_logs", stdout_logs
|
||||
if stderr_logs:
|
||||
yield "stderr_logs", stderr_logs
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
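Before moving on to the next file, here is a hedged sketch of the reconnect-and-run pattern that the step-execution block above uses against an already instantiated sandbox: connect with the sandbox_id produced earlier, then run an additional code step in the same environment. The import path and sample values are assumptions.

# Hedged sketch: reuse a sandbox created by the instantiation block by connecting
# with its sandbox_id, then run one more code step inside it.
from e2b_code_interpreter import AsyncSandbox  # assumed import


async def run_step(api_key: str, sandbox_id: str, step_code: str) -> tuple[str, str, str]:
    sandbox = await AsyncSandbox.connect(sandbox_id=sandbox_id, api_key=api_key)
    execution = await sandbox.run_code(step_code, language="python")
    if execution.error:
        raise RuntimeError(str(execution.error))
    return (
        execution.text or "",
        "".join(execution.logs.stdout),
        "".join(execution.logs.stderr),
    )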
@@ -1,23 +1,17 @@
import re

from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class CodeExtractionBlock(Block):
class Input(BlockSchemaInput):
class Input(BlockSchema):
text: str = SchemaField(
description="Text containing code blocks to extract (e.g., AI response)",
placeholder="Enter text containing code blocks",
)

class Output(BlockSchemaOutput):
class Output(BlockSchema):
html: str = SchemaField(description="Extracted HTML code")
css: str = SchemaField(description="Extracted CSS code")
javascript: str = SchemaField(description="Extracted JavaScript code")
@@ -96,7 +90,7 @@ class CodeExtractionBlock(Block):
for aliases in language_aliases.values()
for alias in aliases
)
+ r")[ \t]*\n[\s\S]*?```"
+ r")\s+[\s\S]*?```"
)

remaining_text = re.sub(pattern, "", input_data.text).strip()
@@ -109,9 +103,7 @@ class CodeExtractionBlock(Block):
# Escape special regex characters in the language string
language = re.escape(language)
# Extract all code blocks enclosed in ```language``` blocks
pattern = re.compile(
rf"```{language}[ \t]*\n(.*?)\n```", re.DOTALL | re.IGNORECASE
)
pattern = re.compile(rf"```{language}\s+(.*?)```", re.DOTALL | re.IGNORECASE)
matches = pattern.finditer(text)
# Combine all code blocks for this language with newlines between them
code_blocks = [match.group(1).strip() for match in matches]

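To make the regex change above concrete, here is a small self-contained sketch using the newer pattern, the one that requires a newline after the language tag. The sample text and the double-newline join are illustrative assumptions; this hunk does not show how the block joins multiple matches.

# Hedged sketch of fenced-block extraction with the updated pattern.
import re


def extract_code(text: str, language: str) -> str:
    language = re.escape(language)
    pattern = re.compile(rf"```{language}[ \t]*\n(.*?)\n```", re.DOTALL | re.IGNORECASE)
    blocks = [match.group(1).strip() for match in pattern.finditer(text)]
    # Join all blocks for this language; the separator here is an assumption.
    return "\n\n".join(blocks)


sample = "Intro\n```python\nprint('hi')\n```\nOutro"
print(extract_code(sample, "python"))  # -> print('hi')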
@@ -5,8 +5,7 @@ from backend.data.block import (
|
||||
BlockCategory,
|
||||
BlockManualWebhookConfig,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
BlockSchema,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
from backend.integrations.providers import ProviderName
|
||||
@@ -28,10 +27,10 @@ class TranscriptionDataModel(BaseModel):
|
||||
|
||||
|
||||
class CompassAITriggerBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
payload: TranscriptionDataModel = SchemaField(hidden=True)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
transcription: str = SchemaField(
|
||||
description="The contents of the compass transcription."
|
||||
)
|
||||
|
||||
@@ -1,22 +1,16 @@
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class WordCharacterCountBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
text: str = SchemaField(
|
||||
description="Input text to count words and characters",
|
||||
placeholder="Enter your text here",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
word_count: int = SchemaField(description="Number of words in the input text")
|
||||
character_count: int = SchemaField(
|
||||
description="Number of characters in the input text"
|
||||
|
||||
@@ -1,12 +1,6 @@
|
||||
from typing import Any, List
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.json import loads
|
||||
from backend.util.mock import MockObject
|
||||
@@ -18,13 +12,13 @@ from backend.util.prompt import estimate_token_count_str
|
||||
|
||||
|
||||
class CreateDictionaryBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
values: dict[str, Any] = SchemaField(
|
||||
description="Key-value pairs to create the dictionary with",
|
||||
placeholder="e.g., {'name': 'Alice', 'age': 25}",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
dictionary: dict[str, Any] = SchemaField(
|
||||
description="The created dictionary containing the specified key-value pairs"
|
||||
)
|
||||
@@ -68,11 +62,10 @@ class CreateDictionaryBlock(Block):
|
||||
|
||||
|
||||
class AddToDictionaryBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
dictionary: dict[Any, Any] = SchemaField(
|
||||
default_factory=dict,
|
||||
description="The dictionary to add the entry to. If not provided, a new dictionary will be created.",
|
||||
advanced=False,
|
||||
)
|
||||
key: str = SchemaField(
|
||||
default="",
|
||||
@@ -92,10 +85,11 @@ class AddToDictionaryBlock(Block):
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
updated_dictionary: dict = SchemaField(
|
||||
description="The dictionary with the new entry added."
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -146,11 +140,11 @@ class AddToDictionaryBlock(Block):
|
||||
|
||||
|
||||
class FindInDictionaryBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
input: Any = SchemaField(description="Dictionary to lookup from")
|
||||
key: str | int = SchemaField(description="Key to lookup in the dictionary")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
output: Any = SchemaField(description="Value found for the given key")
|
||||
missing: Any = SchemaField(
|
||||
description="Value of the input that missing the key"
|
||||
@@ -206,7 +200,7 @@ class FindInDictionaryBlock(Block):
|
||||
|
||||
|
||||
class RemoveFromDictionaryBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
dictionary: dict[Any, Any] = SchemaField(
|
||||
description="The dictionary to modify."
|
||||
)
|
||||
@@ -215,11 +209,12 @@ class RemoveFromDictionaryBlock(Block):
|
||||
default=False, description="Whether to return the removed value."
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
updated_dictionary: dict[Any, Any] = SchemaField(
|
||||
description="The dictionary after removal."
|
||||
)
|
||||
removed_value: Any = SchemaField(description="The removed value if requested.")
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -255,18 +250,19 @@ class RemoveFromDictionaryBlock(Block):
|
||||
|
||||
|
||||
class ReplaceDictionaryValueBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
dictionary: dict[Any, Any] = SchemaField(
|
||||
description="The dictionary to modify."
|
||||
)
|
||||
key: str | int = SchemaField(description="Key to replace the value for.")
|
||||
value: Any = SchemaField(description="The new value for the given key.")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
updated_dictionary: dict[Any, Any] = SchemaField(
|
||||
description="The dictionary after replacement."
|
||||
)
|
||||
old_value: Any = SchemaField(description="The value that was replaced.")
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -303,10 +299,10 @@ class ReplaceDictionaryValueBlock(Block):
|
||||
|
||||
|
||||
class DictionaryIsEmptyBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
dictionary: dict[Any, Any] = SchemaField(description="The dictionary to check.")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
is_empty: bool = SchemaField(description="True if the dictionary is empty.")
|
||||
|
||||
def __init__(self):
|
||||
@@ -330,7 +326,7 @@ class DictionaryIsEmptyBlock(Block):
|
||||
|
||||
|
||||
class CreateListBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
values: List[Any] = SchemaField(
|
||||
description="A list of values to be combined into a new list.",
|
||||
placeholder="e.g., ['Alice', 25, True]",
|
||||
@@ -346,10 +342,11 @@ class CreateListBlock(Block):
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
list: List[Any] = SchemaField(
|
||||
description="The created list containing the specified values."
|
||||
)
|
||||
error: str = SchemaField(description="Error message if list creation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -406,7 +403,7 @@ class CreateListBlock(Block):
|
||||
|
||||
|
||||
class AddToListBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
list: List[Any] = SchemaField(
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
@@ -427,10 +424,11 @@ class AddToListBlock(Block):
|
||||
description="The position to insert the new entry. If not provided, the entry will be appended to the end of the list.",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
updated_list: List[Any] = SchemaField(
|
||||
description="The list with the new entry added."
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -485,11 +483,11 @@ class AddToListBlock(Block):
|
||||
|
||||
|
||||
class FindInListBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
list: List[Any] = SchemaField(description="The list to search in.")
|
||||
value: Any = SchemaField(description="The value to search for.")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
index: int = SchemaField(description="The index of the value in the list.")
|
||||
found: bool = SchemaField(
|
||||
description="Whether the value was found in the list."
|
||||
@@ -527,14 +525,15 @@ class FindInListBlock(Block):
|
||||
|
||||
|
||||
class GetListItemBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
list: List[Any] = SchemaField(description="The list to get the item from.")
|
||||
index: int = SchemaField(
|
||||
description="The 0-based index of the item (supports negative indices)."
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
item: Any = SchemaField(description="The item at the specified index.")
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -561,7 +560,7 @@ class GetListItemBlock(Block):
|
||||
|
||||
|
||||
class RemoveFromListBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
list: List[Any] = SchemaField(description="The list to modify.")
|
||||
value: Any = SchemaField(
|
||||
default=None, description="Value to remove from the list."
|
||||
@@ -574,9 +573,10 @@ class RemoveFromListBlock(Block):
|
||||
default=False, description="Whether to return the removed item."
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
updated_list: List[Any] = SchemaField(description="The list after removal.")
|
||||
removed_item: Any = SchemaField(description="The removed item if requested.")
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -617,16 +617,17 @@ class RemoveFromListBlock(Block):
|
||||
|
||||
|
||||
class ReplaceListItemBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
list: List[Any] = SchemaField(description="The list to modify.")
|
||||
index: int = SchemaField(
|
||||
description="Index of the item to replace (supports negative indices)."
|
||||
)
|
||||
value: Any = SchemaField(description="The new value for the given index.")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
updated_list: List[Any] = SchemaField(description="The list after replacement.")
|
||||
old_item: Any = SchemaField(description="The item that was replaced.")
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -661,10 +662,10 @@ class ReplaceListItemBlock(Block):
|
||||
|
||||
|
||||
class ListIsEmptyBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
list: List[Any] = SchemaField(description="The list to check.")
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
is_empty: bool = SchemaField(description="True if the list is empty.")
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@@ -113,7 +113,6 @@ class DataForSeoClient:
|
||||
include_serp_info: bool = False,
|
||||
include_clickstream_data: bool = False,
|
||||
limit: int = 100,
|
||||
depth: Optional[int] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get related keywords from DataForSEO Labs.
|
||||
@@ -126,7 +125,6 @@ class DataForSeoClient:
|
||||
include_serp_info: Include SERP data
|
||||
include_clickstream_data: Include clickstream metrics
|
||||
limit: Maximum number of results (up to 3000)
|
||||
depth: Keyword search depth (0-4), controls number of returned keywords
|
||||
|
||||
Returns:
|
||||
API response with related keywords
|
||||
@@ -150,8 +148,6 @@ class DataForSeoClient:
|
||||
task_data["include_clickstream_data"] = include_clickstream_data
|
||||
if limit is not None:
|
||||
task_data["limit"] = limit
|
||||
if depth is not None:
|
||||
task_data["depth"] = depth
|
||||
|
||||
payload = [task_data]
|
||||
|
||||
|
||||
@@ -8,8 +8,7 @@ from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
UserPasswordCredentials,
|
||||
@@ -19,7 +18,7 @@ from ._api import DataForSeoClient
|
||||
from ._config import dataforseo
|
||||
|
||||
|
||||
class KeywordSuggestion(BlockSchemaInput):
|
||||
class KeywordSuggestion(BlockSchema):
|
||||
"""Schema for a keyword suggestion result."""
|
||||
|
||||
keyword: str = SchemaField(description="The keyword suggestion")
|
||||
@@ -46,7 +45,7 @@ class KeywordSuggestion(BlockSchemaInput):
|
||||
class DataForSeoKeywordSuggestionsBlock(Block):
|
||||
"""Block for getting keyword suggestions from DataForSEO Labs."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = dataforseo.credentials_field(
|
||||
description="DataForSEO credentials (username and password)"
|
||||
)
|
||||
@@ -78,7 +77,7 @@ class DataForSeoKeywordSuggestionsBlock(Block):
|
||||
le=3000,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
suggestions: List[KeywordSuggestion] = SchemaField(
|
||||
description="List of keyword suggestions with metrics"
|
||||
)
|
||||
@@ -162,63 +161,54 @@ class DataForSeoKeywordSuggestionsBlock(Block):
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Execute the keyword suggestions query."""
|
||||
try:
|
||||
client = DataForSeoClient(credentials)
|
||||
client = DataForSeoClient(credentials)
|
||||
|
||||
results = await self._fetch_keyword_suggestions(client, input_data)
|
||||
results = await self._fetch_keyword_suggestions(client, input_data)
|
||||
|
||||
# Process and format the results
|
||||
suggestions = []
|
||||
if results and len(results) > 0:
|
||||
# results is a list, get the first element
|
||||
first_result = results[0] if isinstance(results, list) else results
|
||||
items = (
|
||||
first_result.get("items", [])
|
||||
if isinstance(first_result, dict)
|
||||
else []
|
||||
# Process and format the results
|
||||
suggestions = []
|
||||
if results and len(results) > 0:
|
||||
# results is a list, get the first element
|
||||
first_result = results[0] if isinstance(results, list) else results
|
||||
items = (
|
||||
first_result.get("items", []) if isinstance(first_result, dict) else []
|
||||
)
|
||||
for item in items:
|
||||
# Create the KeywordSuggestion object
|
||||
suggestion = KeywordSuggestion(
|
||||
keyword=item.get("keyword", ""),
|
||||
search_volume=item.get("keyword_info", {}).get("search_volume"),
|
||||
competition=item.get("keyword_info", {}).get("competition"),
|
||||
cpc=item.get("keyword_info", {}).get("cpc"),
|
||||
keyword_difficulty=item.get("keyword_properties", {}).get(
|
||||
"keyword_difficulty"
|
||||
),
|
||||
serp_info=(
|
||||
item.get("serp_info") if input_data.include_serp_info else None
|
||||
),
|
||||
clickstream_data=(
|
||||
item.get("clickstream_keyword_info")
|
||||
if input_data.include_clickstream_data
|
||||
else None
|
||||
),
|
||||
)
|
||||
if items is None:
|
||||
items = []
|
||||
for item in items:
|
||||
# Create the KeywordSuggestion object
|
||||
suggestion = KeywordSuggestion(
|
||||
keyword=item.get("keyword", ""),
|
||||
search_volume=item.get("keyword_info", {}).get("search_volume"),
|
||||
competition=item.get("keyword_info", {}).get("competition"),
|
||||
cpc=item.get("keyword_info", {}).get("cpc"),
|
||||
keyword_difficulty=item.get("keyword_properties", {}).get(
|
||||
"keyword_difficulty"
|
||||
),
|
||||
serp_info=(
|
||||
item.get("serp_info")
|
||||
if input_data.include_serp_info
|
||||
else None
|
||||
),
|
||||
clickstream_data=(
|
||||
item.get("clickstream_keyword_info")
|
||||
if input_data.include_clickstream_data
|
||||
else None
|
||||
),
|
||||
)
|
||||
yield "suggestion", suggestion
|
||||
suggestions.append(suggestion)
|
||||
yield "suggestion", suggestion
|
||||
suggestions.append(suggestion)
|
||||
|
||||
yield "suggestions", suggestions
|
||||
yield "total_count", len(suggestions)
|
||||
yield "seed_keyword", input_data.keyword
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to fetch keyword suggestions: {str(e)}"
|
||||
yield "suggestions", suggestions
|
||||
yield "total_count", len(suggestions)
|
||||
yield "seed_keyword", input_data.keyword
|
||||
|
||||
|
||||
class KeywordSuggestionExtractorBlock(Block):
|
||||
"""Extracts individual fields from a KeywordSuggestion object."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
suggestion: KeywordSuggestion = SchemaField(
|
||||
description="The keyword suggestion object to extract fields from"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
keyword: str = SchemaField(description="The keyword suggestion")
|
||||
search_volume: Optional[int] = SchemaField(
|
||||
description="Monthly search volume", default=None
|
||||
|
||||
@@ -8,8 +8,7 @@ from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
UserPasswordCredentials,
|
||||
@@ -19,7 +18,7 @@ from ._api import DataForSeoClient
|
||||
from ._config import dataforseo
|
||||
|
||||
|
||||
class RelatedKeyword(BlockSchemaInput):
|
||||
class RelatedKeyword(BlockSchema):
|
||||
"""Schema for a related keyword result."""
|
||||
|
||||
keyword: str = SchemaField(description="The related keyword")
|
||||
@@ -46,7 +45,7 @@ class RelatedKeyword(BlockSchemaInput):
|
||||
class DataForSeoRelatedKeywordsBlock(Block):
|
||||
"""Block for getting related keywords from DataForSEO Labs."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = dataforseo.credentials_field(
|
||||
description="DataForSEO credentials (username and password)"
|
||||
)
|
||||
@@ -79,14 +78,8 @@ class DataForSeoRelatedKeywordsBlock(Block):
|
||||
ge=1,
|
||||
le=3000,
|
||||
)
|
||||
depth: int = SchemaField(
|
||||
description="Keyword search depth (0-4). Controls the number of returned keywords: 0=1 keyword, 1=~8 keywords, 2=~72 keywords, 3=~584 keywords, 4=~4680 keywords",
|
||||
default=1,
|
||||
ge=0,
|
||||
le=4,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
related_keywords: List[RelatedKeyword] = SchemaField(
|
||||
description="List of related keywords with metrics"
|
||||
)
|
||||
@@ -161,7 +154,6 @@ class DataForSeoRelatedKeywordsBlock(Block):
|
||||
include_serp_info=input_data.include_serp_info,
|
||||
include_clickstream_data=input_data.include_clickstream_data,
|
||||
limit=input_data.limit,
|
||||
depth=input_data.depth,
|
||||
)
|
||||
|
||||
async def run(
|
||||
@@ -172,71 +164,61 @@ class DataForSeoRelatedKeywordsBlock(Block):
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Execute the related keywords query."""
|
||||
try:
|
||||
client = DataForSeoClient(credentials)
|
||||
client = DataForSeoClient(credentials)
|
||||
|
||||
results = await self._fetch_related_keywords(client, input_data)
|
||||
results = await self._fetch_related_keywords(client, input_data)
|
||||
|
||||
# Process and format the results
|
||||
related_keywords = []
|
||||
if results and len(results) > 0:
|
||||
# results is a list, get the first element
|
||||
first_result = results[0] if isinstance(results, list) else results
|
||||
items = (
|
||||
first_result.get("items", [])
|
||||
if isinstance(first_result, dict)
|
||||
else []
|
||||
# Process and format the results
|
||||
related_keywords = []
|
||||
if results and len(results) > 0:
|
||||
# results is a list, get the first element
|
||||
first_result = results[0] if isinstance(results, list) else results
|
||||
items = (
|
||||
first_result.get("items", []) if isinstance(first_result, dict) else []
|
||||
)
|
||||
for item in items:
|
||||
# Extract keyword_data from the item
|
||||
keyword_data = item.get("keyword_data", {})
|
||||
|
||||
# Create the RelatedKeyword object
|
||||
keyword = RelatedKeyword(
|
||||
keyword=keyword_data.get("keyword", ""),
|
||||
search_volume=keyword_data.get("keyword_info", {}).get(
|
||||
"search_volume"
|
||||
),
|
||||
competition=keyword_data.get("keyword_info", {}).get("competition"),
|
||||
cpc=keyword_data.get("keyword_info", {}).get("cpc"),
|
||||
keyword_difficulty=keyword_data.get("keyword_properties", {}).get(
|
||||
"keyword_difficulty"
|
||||
),
|
||||
serp_info=(
|
||||
keyword_data.get("serp_info")
|
||||
if input_data.include_serp_info
|
||||
else None
|
||||
),
|
||||
clickstream_data=(
|
||||
keyword_data.get("clickstream_keyword_info")
|
||||
if input_data.include_clickstream_data
|
||||
else None
|
||||
),
|
||||
)
|
||||
# Ensure items is never None
|
||||
if items is None:
|
||||
items = []
|
||||
for item in items:
|
||||
# Extract keyword_data from the item
|
||||
keyword_data = item.get("keyword_data", {})
|
||||
yield "related_keyword", keyword
|
||||
related_keywords.append(keyword)
|
||||
|
||||
# Create the RelatedKeyword object
|
||||
keyword = RelatedKeyword(
|
||||
keyword=keyword_data.get("keyword", ""),
|
||||
search_volume=keyword_data.get("keyword_info", {}).get(
|
||||
"search_volume"
|
||||
),
|
||||
competition=keyword_data.get("keyword_info", {}).get(
|
||||
"competition"
|
||||
),
|
||||
cpc=keyword_data.get("keyword_info", {}).get("cpc"),
|
||||
keyword_difficulty=keyword_data.get(
|
||||
"keyword_properties", {}
|
||||
).get("keyword_difficulty"),
|
||||
serp_info=(
|
||||
keyword_data.get("serp_info")
|
||||
if input_data.include_serp_info
|
||||
else None
|
||||
),
|
||||
clickstream_data=(
|
||||
keyword_data.get("clickstream_keyword_info")
|
||||
if input_data.include_clickstream_data
|
||||
else None
|
||||
),
|
||||
)
|
||||
yield "related_keyword", keyword
|
||||
related_keywords.append(keyword)
|
||||
|
||||
yield "related_keywords", related_keywords
|
||||
yield "total_count", len(related_keywords)
|
||||
yield "seed_keyword", input_data.keyword
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to fetch related keywords: {str(e)}"
|
||||
yield "related_keywords", related_keywords
|
||||
yield "total_count", len(related_keywords)
|
||||
yield "seed_keyword", input_data.keyword
|
||||
|
||||
|
||||
class RelatedKeywordExtractorBlock(Block):
|
||||
"""Extracts individual fields from a RelatedKeyword object."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
related_keyword: RelatedKeyword = SchemaField(
|
||||
description="The related keyword object to extract fields from"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
keyword: str = SchemaField(description="The related keyword")
|
||||
search_volume: Optional[int] = SchemaField(
|
||||
description="Monthly search volume", default=None
|
||||
|
||||
@@ -1,23 +1,17 @@
|
||||
import codecs
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class TextDecoderBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
text: str = SchemaField(
|
||||
description="A string containing escaped characters to be decoded",
|
||||
placeholder='Your entire text block with \\n and \\" escaped characters',
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
decoded_text: str = SchemaField(
|
||||
description="The decoded text with escape sequences processed"
|
||||
)
|
||||
|
||||
@@ -4,19 +4,13 @@ import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
import discord
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import APIKeyCredentials, SchemaField
|
||||
from backend.util.file import store_media_file
|
||||
from backend.util.request import Requests
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
from ._auth import (
|
||||
@@ -34,10 +28,10 @@ TEST_CREDENTIALS_INPUT = TEST_BOT_CREDENTIALS_INPUT
|
||||
|
||||
|
||||
class ReadDiscordMessagesBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: DiscordCredentials = DiscordCredentialsField()
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
message_content: str = SchemaField(
|
||||
description="The content of the message received"
|
||||
)
|
||||
@@ -120,9 +114,10 @@ class ReadDiscordMessagesBlock(Block):
|
||||
if message.attachments:
|
||||
attachment = message.attachments[0] # Process the first attachment
|
||||
if attachment.filename.endswith((".txt", ".py")):
|
||||
response = await Requests().get(attachment.url)
|
||||
file_content = response.text()
|
||||
self.output_data += f"\n\nFile from user: {attachment.filename}\nContent: {file_content}"
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(attachment.url) as response:
|
||||
file_content = response.text()
|
||||
self.output_data += f"\n\nFile from user: {attachment.filename}\nContent: {file_content}"
|
||||
|
||||
await client.close()
|
||||
|
||||
@@ -170,21 +165,21 @@ class ReadDiscordMessagesBlock(Block):
|
||||
|
||||
|
||||
class SendDiscordMessageBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: DiscordCredentials = DiscordCredentialsField()
|
||||
message_content: str = SchemaField(
|
||||
description="The content of the message to send"
|
||||
)
|
||||
channel_name: str = SchemaField(
|
||||
description="Channel ID or channel name to send the message to"
|
||||
description="The name of the channel the message will be sent to"
|
||||
)
|
||||
server_name: str = SchemaField(
|
||||
description="Server name (only needed if using channel name)",
|
||||
advanced=True,
|
||||
description="The name of the server where the channel is located",
|
||||
advanced=True, # Optional field for server name
|
||||
default="",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
status: str = SchemaField(
|
||||
description="The status of the operation (e.g., 'Message sent', 'Error')"
|
||||
)
|
||||
@@ -236,49 +231,25 @@ class SendDiscordMessageBlock(Block):
@client.event
async def on_ready():
print(f"Logged in as {client.user}")
channel = None
for guild in client.guilds:
if server_name and guild.name != server_name:
continue
for channel in guild.text_channels:
if channel.name == channel_name:
# Split message into chunks if it exceeds 2000 characters
chunks = self.chunk_message(message_content)
last_message = None
for chunk in chunks:
last_message = await channel.send(chunk)
result["status"] = "Message sent"
result["message_id"] = (
str(last_message.id) if last_message else ""
)
result["channel_id"] = str(channel.id)
await client.close()
return

# Try to parse as channel ID first
try:
channel_id = int(channel_name)
channel = client.get_channel(channel_id)
except ValueError:
# Not a valid ID, will try name lookup
pass

# If not found by ID (or not an ID), try name lookup
if not channel:
for guild in client.guilds:
if server_name and guild.name != server_name:
continue
for ch in guild.text_channels:
if ch.name == channel_name:
channel = ch
break
if channel:
break

if not channel:
result["status"] = f"Channel not found: {channel_name}"
await client.close()
return

# Type check - ensure it's a text channel that can send messages
if not hasattr(channel, "send"):
result["status"] = (
f"Channel {channel_name} cannot receive messages (not a text channel)"
)
await client.close()
return

# Split message into chunks if it exceeds 2000 characters
chunks = self.chunk_message(message_content)
last_message = None
for chunk in chunks:
last_message = await channel.send(chunk)  # type: ignore
result["status"] = "Message sent"
result["message_id"] = str(last_message.id) if last_message else ""
result["channel_id"] = str(channel.id)
result["status"] = "Channel not found"
await client.close()

await client.start(token)
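The refactor above replaces a name-only scan with an ID-first lookup followed by a name fallback and a text-channel capability check. Below is a condensed, hedged sketch of that resolution logic, assuming a connected discord.py client; the helper name and signature are illustrative, not part of the block.

# Hedged sketch of the channel resolution introduced in this hunk.
import discord


def resolve_text_channel(
    client: discord.Client, channel_identifier: str, server_name: str = ""
) -> discord.abc.Messageable | None:
    # 1) Treat the identifier as a numeric channel ID when possible.
    try:
        channel = client.get_channel(int(channel_identifier))
        if channel is not None and hasattr(channel, "send"):
            return channel
    except ValueError:
        pass  # not an ID, fall back to a name lookup

    # 2) Otherwise scan guilds (optionally filtered by server name) for a matching text channel.
    for guild in client.guilds:
        if server_name and guild.name != server_name:
            continue
        for ch in guild.text_channels:
            if ch.name == channel_identifier:
                return ch
    return None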
@@ -316,7 +287,7 @@ class SendDiscordMessageBlock(Block):
|
||||
|
||||
|
||||
class SendDiscordDMBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: DiscordCredentials = DiscordCredentialsField()
|
||||
user_id: str = SchemaField(
|
||||
description="The Discord user ID to send the DM to (e.g., '123456789012345678')"
|
||||
@@ -325,7 +296,7 @@ class SendDiscordDMBlock(Block):
|
||||
description="The content of the direct message to send"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
status: str = SchemaField(description="The status of the operation")
|
||||
message_id: str = SchemaField(description="The ID of the sent message")
|
||||
|
||||
@@ -405,7 +376,7 @@ class SendDiscordDMBlock(Block):
|
||||
|
||||
|
||||
class SendDiscordEmbedBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: DiscordCredentials = DiscordCredentialsField()
|
||||
channel_identifier: str = SchemaField(
|
||||
description="Channel ID or channel name to send the embed to"
|
||||
@@ -442,7 +413,7 @@ class SendDiscordEmbedBlock(Block):
|
||||
default=[],
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
status: str = SchemaField(description="Operation status")
|
||||
message_id: str = SchemaField(description="ID of the sent embed message")
|
||||
|
||||
@@ -592,7 +563,7 @@ class SendDiscordEmbedBlock(Block):
|
||||
|
||||
|
||||
class SendDiscordFileBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: DiscordCredentials = DiscordCredentialsField()
|
||||
channel_identifier: str = SchemaField(
|
||||
description="Channel ID or channel name to send the file to"
|
||||
@@ -613,7 +584,7 @@ class SendDiscordFileBlock(Block):
|
||||
description="Optional message to send with the file", default=""
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
status: str = SchemaField(description="Operation status")
|
||||
message_id: str = SchemaField(description="ID of the sent message")
|
||||
|
||||
@@ -704,15 +675,16 @@ class SendDiscordFileBlock(Block):
|
||||
|
||||
elif file.startswith(("http://", "https://")):
|
||||
# URL - download the file
|
||||
response = await Requests().get(file)
|
||||
file_bytes = response.content
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(file) as response:
|
||||
file_bytes = await response.read()
|
||||
|
||||
# Try to get filename from URL if not provided
|
||||
if not filename:
|
||||
from urllib.parse import urlparse
|
||||
# Try to get filename from URL if not provided
|
||||
if not filename:
|
||||
from urllib.parse import urlparse
|
||||
|
||||
path = urlparse(file).path
|
||||
detected_filename = Path(path).name or "download"
|
||||
path = urlparse(file).path
|
||||
detected_filename = Path(path).name or "download"
|
||||
else:
|
||||
# Local file path - read from stored media file
|
||||
# This would be a path from a previous block's output
|
||||
@@ -794,7 +766,7 @@ class SendDiscordFileBlock(Block):
|
||||
|
||||
|
||||
class ReplyToDiscordMessageBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: DiscordCredentials = DiscordCredentialsField()
|
||||
channel_id: str = SchemaField(
|
||||
description="The channel ID where the message to reply to is located"
|
||||
@@ -805,7 +777,7 @@ class ReplyToDiscordMessageBlock(Block):
|
||||
description="Whether to mention the original message author", default=True
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
status: str = SchemaField(description="Operation status")
|
||||
reply_id: str = SchemaField(description="ID of the reply message")
|
||||
|
||||
@@ -919,13 +891,13 @@ class ReplyToDiscordMessageBlock(Block):
|
||||
|
||||
|
||||
class DiscordUserInfoBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: DiscordCredentials = DiscordCredentialsField()
|
||||
user_id: str = SchemaField(
|
||||
description="The Discord user ID to get information about"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
user_id: str = SchemaField(
|
||||
description="The user's ID (passed through for chaining)"
|
||||
)
|
||||
@@ -1036,7 +1008,7 @@ class DiscordUserInfoBlock(Block):
|
||||
|
||||
|
||||
class DiscordChannelInfoBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: DiscordCredentials = DiscordCredentialsField()
|
||||
channel_identifier: str = SchemaField(
|
||||
description="Channel name or channel ID to look up"
|
||||
@@ -1047,7 +1019,7 @@ class DiscordChannelInfoBlock(Block):
|
||||
default="",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
channel_id: str = SchemaField(description="The channel's ID")
|
||||
channel_name: str = SchemaField(description="The channel's name")
|
||||
server_id: str = SchemaField(description="The server's ID")
|
||||
|
||||
@@ -2,13 +2,7 @@
|
||||
Discord OAuth-based blocks.
|
||||
"""
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import OAuth2Credentials, SchemaField
|
||||
|
||||
from ._api import DiscordOAuthUser, get_current_user
|
||||
@@ -27,12 +21,12 @@ class DiscordGetCurrentUserBlock(Block):
|
||||
This block requires Discord OAuth2 credentials (not bot tokens).
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: DiscordOAuthCredentialsInput = DiscordOAuthCredentialsField(
|
||||
["identify"]
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
user_id: str = SchemaField(description="The authenticated user's Discord ID")
|
||||
username: str = SchemaField(description="The user's username")
|
||||
avatar_url: str = SchemaField(description="URL to the user's avatar image")
|
||||
|
||||
@@ -5,13 +5,7 @@ from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, SecretStr
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import (
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
@@ -57,7 +51,7 @@ class SMTPConfig(BaseModel):
|
||||
|
||||
|
||||
class SendEmailBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
to_email: str = SchemaField(
|
||||
description="Recipient email address", placeholder="recipient@example.com"
|
||||
)
|
||||
@@ -73,7 +67,7 @@ class SendEmailBlock(Block):
|
||||
)
|
||||
credentials: SMTPCredentialsInput = SMTPCredentialsField()
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
status: str = SchemaField(description="Status of the email sending operation")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the email sending failed"
|
||||
|
||||
@@ -8,13 +8,7 @@ which provides access to LinkedIn profile data and related information.
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
@@ -35,7 +29,7 @@ logger = logging.getLogger(__name__)
|
||||
class GetLinkedinProfileBlock(Block):
|
||||
"""Block to fetch LinkedIn profile data using Enrichlayer API."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
"""Input schema for GetLinkedinProfileBlock."""
|
||||
|
||||
linkedin_url: str = SchemaField(
|
||||
@@ -86,12 +80,13 @@ class GetLinkedinProfileBlock(Block):
|
||||
description="Enrichlayer API credentials"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
"""Output schema for GetLinkedinProfileBlock."""
|
||||
|
||||
profile: PersonProfileResponse = SchemaField(
|
||||
description="LinkedIn profile data"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize GetLinkedinProfileBlock."""
|
||||
@@ -204,7 +199,7 @@ class GetLinkedinProfileBlock(Block):
|
||||
class LinkedinPersonLookupBlock(Block):
|
||||
"""Block to look up LinkedIn profiles by person's information using Enrichlayer API."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
"""Input schema for LinkedinPersonLookupBlock."""
|
||||
|
||||
first_name: str = SchemaField(
|
||||
@@ -247,12 +242,13 @@ class LinkedinPersonLookupBlock(Block):
|
||||
description="Enrichlayer API credentials"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
"""Output schema for LinkedinPersonLookupBlock."""
|
||||
|
||||
lookup_result: PersonLookupResponse = SchemaField(
|
||||
description="LinkedIn profile lookup result"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize LinkedinPersonLookupBlock."""
|
||||
@@ -350,7 +346,7 @@ class LinkedinPersonLookupBlock(Block):
|
||||
class LinkedinRoleLookupBlock(Block):
|
||||
"""Block to look up LinkedIn profiles by role in a company using Enrichlayer API."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
"""Input schema for LinkedinRoleLookupBlock."""
|
||||
|
||||
role: str = SchemaField(
|
||||
@@ -370,12 +366,13 @@ class LinkedinRoleLookupBlock(Block):
|
||||
description="Enrichlayer API credentials"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
"""Output schema for LinkedinRoleLookupBlock."""
|
||||
|
||||
role_lookup_result: RoleLookupResponse = SchemaField(
|
||||
description="LinkedIn role lookup result"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize LinkedinRoleLookupBlock."""
|
||||
@@ -452,7 +449,7 @@ class LinkedinRoleLookupBlock(Block):
|
||||
class GetLinkedinProfilePictureBlock(Block):
|
||||
"""Block to get LinkedIn profile pictures using Enrichlayer API."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
"""Input schema for GetLinkedinProfilePictureBlock."""
|
||||
|
||||
linkedin_profile_url: str = SchemaField(
|
||||
@@ -463,12 +460,13 @@ class GetLinkedinProfilePictureBlock(Block):
|
||||
description="Enrichlayer API credentials"
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
class Output(BlockSchema):
|
||||
"""Output schema for GetLinkedinProfilePictureBlock."""
|
||||
|
||||
profile_picture_url: MediaFileType = SchemaField(
|
||||
description="LinkedIn profile picture URL"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize GetLinkedinProfilePictureBlock."""
|
||||
|
||||
@@ -1,22 +0,0 @@
"""
Test credentials and helpers for Exa blocks.
"""

from pydantic import SecretStr

from backend.data.model import APIKeyCredentials

TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="exa",
api_key=SecretStr("mock-exa-api-key"),
title="Mock Exa API key",
expires_at=None,
)

TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}
@@ -1,59 +1,55 @@
from typing import Optional

from exa_py import AsyncExa
from exa_py.api import AnswerResponse
from pydantic import BaseModel

from backend.sdk import (
APIKeyCredentials,
BaseModel,
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
BlockSchema,
CredentialsMetaInput,
MediaFileType,
Requests,
SchemaField,
)

from ._config import exa


class AnswerCitation(BaseModel):
"""Citation model for answer endpoint."""
class CostBreakdown(BaseModel):
keywordSearch: float
neuralSearch: float
contentText: float
contentHighlight: float
contentSummary: float

id: str = SchemaField(description="The temporary ID for the document")
url: str = SchemaField(description="The URL of the search result")
title: Optional[str] = SchemaField(description="The title of the search result")
author: Optional[str] = SchemaField(description="The author of the content")
publishedDate: Optional[str] = SchemaField(
description="An estimate of the creation date"
)
text: Optional[str] = SchemaField(description="The full text content of the source")
image: Optional[MediaFileType] = SchemaField(
description="The URL of the image associated with the result"
)
favicon: Optional[MediaFileType] = SchemaField(
description="The URL of the favicon for the domain"
)

@classmethod
def from_sdk(cls, sdk_citation) -> "AnswerCitation":
"""Convert SDK AnswerResult (dataclass) to our Pydantic model."""
return cls(
id=getattr(sdk_citation, "id", ""),
url=getattr(sdk_citation, "url", ""),
title=getattr(sdk_citation, "title", None),
author=getattr(sdk_citation, "author", None),
publishedDate=getattr(sdk_citation, "published_date", None),
text=getattr(sdk_citation, "text", None),
image=getattr(sdk_citation, "image", None),
favicon=getattr(sdk_citation, "favicon", None),
)
class SearchBreakdown(BaseModel):
search: float
contents: float
breakdown: CostBreakdown


class PerRequestPrices(BaseModel):
neuralSearch_1_25_results: float
neuralSearch_26_100_results: float
neuralSearch_100_plus_results: float
keywordSearch_1_100_results: float
keywordSearch_100_plus_results: float


class PerPagePrices(BaseModel):
contentText: float
contentHighlight: float
contentSummary: float


class CostDollars(BaseModel):
total: float
breakDown: list[SearchBreakdown]
perRequestPrices: PerRequestPrices
perPagePrices: PerPagePrices


class ExaAnswerBlock(Block):
class Input(BlockSchemaInput):
class Input(BlockSchema):
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
@@ -62,21 +58,31 @@ class ExaAnswerBlock(Block):
placeholder="What is the latest valuation of SpaceX?",
)
text: bool = SchemaField(
description="Include full text content in the search results used for the answer",
default=True,
default=False,
description="If true, the response includes full text content in the search results",
advanced=True,
)
model: str = SchemaField(
default="exa",
description="The search model to use (exa or exa-pro)",
placeholder="exa",
advanced=True,
)

class Output(BlockSchemaOutput):
class Output(BlockSchema):
answer: str = SchemaField(
description="The generated answer based on search results"
)
citations: list[AnswerCitation] = SchemaField(
description="Search results used to generate the answer"
citations: list[dict] = SchemaField(
description="Search results used to generate the answer",
default_factory=list,
)
citation: AnswerCitation = SchemaField(
description="Individual citation from the answer"
cost_dollars: CostDollars = SchemaField(
description="Cost breakdown of the request"
)
error: str = SchemaField(
description="Error message if the request failed", default=""
)
error: str = SchemaField(description="Error message if the request failed")

def __init__(self):
super().__init__(
@@ -90,24 +96,26 @@ class ExaAnswerBlock(Block):
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
url = "https://api.exa.ai/answer"
headers = {
"Content-Type": "application/json",
"x-api-key": credentials.api_key.get_secret_value(),
}

# Get answer using SDK (stream=False for blocks) - this IS async, needs await
response = await aexa.answer(
query=input_data.query, text=input_data.text, stream=False
)
# Build the payload
payload = {
"query": input_data.query,
"text": input_data.text,
"model": input_data.model,
}

# this should remain true as long as they don't start defaulting to streaming only.
# provides a bit of safety for sdk updates.
assert type(response) is AnswerResponse
try:
response = await Requests().post(url, headers=headers, json=payload)
data = response.json()

yield "answer", response.answer
yield "answer", data.get("answer", "")
yield "citations", data.get("citations", [])
yield "cost_dollars", data.get("costDollars", {})

citations = [
AnswerCitation.from_sdk(sdk_citation)
for sdk_citation in response.citations or []
]

yield "citations", citations
for citation in citations:
yield "citation", citation
except Exception as e:
yield "error", str(e)

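For reference, a hedged standalone sketch of the raw `/answer` call that the Requests-based path above performs; the request and response keys (`query`, `text`, `model`, `answer`, `citations`, `costDollars`) come from this diff, while the use of `httpx` is an assumption made only to keep the example self-contained:

```python
# Hedged sketch of the Exa /answer call used by the block above. httpx is
# assumed purely for illustration; the platform code uses its own Requests wrapper.
import asyncio

import httpx


async def exa_answer(api_key: str, query: str, text: bool = False, model: str = "exa") -> dict:
    headers = {
        "Content-Type": "application/json",
        "x-api-key": api_key,
    }
    payload = {"query": query, "text": text, "model": model}
    async with httpx.AsyncClient(timeout=60) as client:
        resp = await client.post("https://api.exa.ai/answer", headers=headers, json=payload)
        resp.raise_for_status()
        data = resp.json()
    # Keys mirrored from the block's outputs: answer, citations, costDollars.
    return {
        "answer": data.get("answer", ""),
        "citations": data.get("citations", []),
        "cost_dollars": data.get("costDollars", {}),
    }


# Example (requires a real API key):
# asyncio.run(exa_answer("EXA_API_KEY", "What is the latest valuation of SpaceX?"))
```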
@@ -1,118 +0,0 @@
|
||||
"""
|
||||
Exa Code Context Block
|
||||
|
||||
Provides code search capabilities to find relevant code snippets and examples
|
||||
from open source repositories, documentation, and Stack Overflow.
|
||||
"""
|
||||
|
||||
from typing import Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
|
||||
|
||||
class CodeContextResponse(BaseModel):
|
||||
"""Stable output model for code context responses."""
|
||||
|
||||
request_id: str
|
||||
query: str
|
||||
response: str
|
||||
results_count: int
|
||||
cost_dollars: str
|
||||
search_time: float
|
||||
output_tokens: int
|
||||
|
||||
@classmethod
|
||||
def from_api(cls, data: dict) -> "CodeContextResponse":
|
||||
"""Convert API response to our stable model."""
|
||||
return cls(
|
||||
request_id=data.get("requestId", ""),
|
||||
query=data.get("query", ""),
|
||||
response=data.get("response", ""),
|
||||
results_count=data.get("resultsCount", 0),
|
||||
cost_dollars=data.get("costDollars", ""),
|
||||
search_time=data.get("searchTime", 0.0),
|
||||
output_tokens=data.get("outputTokens", 0),
|
||||
)
|
||||
|
||||
|
||||
class ExaCodeContextBlock(Block):
|
||||
"""Get relevant code snippets and examples from open source repositories."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
query: str = SchemaField(
|
||||
description="Search query to find relevant code snippets. Describe what you're trying to do or what code you're looking for.",
|
||||
placeholder="how to use React hooks for state management",
|
||||
)
|
||||
tokens_num: Union[str, int] = SchemaField(
|
||||
default="dynamic",
|
||||
description="Token limit for response. Use 'dynamic' for automatic sizing, 5000 for standard queries, or 10000 for comprehensive examples.",
|
||||
placeholder="dynamic",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
request_id: str = SchemaField(description="Unique identifier for this request")
|
||||
query: str = SchemaField(description="The search query used")
|
||||
response: str = SchemaField(
|
||||
description="Formatted code snippets and contextual examples with sources"
|
||||
)
|
||||
results_count: int = SchemaField(
|
||||
description="Number of code sources found and included"
|
||||
)
|
||||
cost_dollars: str = SchemaField(description="Cost of this request in dollars")
|
||||
search_time: float = SchemaField(
|
||||
description="Time taken to search in milliseconds"
|
||||
)
|
||||
output_tokens: int = SchemaField(description="Number of tokens in the response")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8f9e0d1c-2b3a-4567-8901-23456789abcd",
|
||||
description="Search billions of GitHub repos, docs, and Stack Overflow for relevant code examples",
|
||||
categories={BlockCategory.SEARCH, BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=ExaCodeContextBlock.Input,
|
||||
output_schema=ExaCodeContextBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
url = "https://api.exa.ai/context"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": credentials.api_key.get_secret_value(),
|
||||
}
|
||||
|
||||
payload = {
|
||||
"query": input_data.query,
|
||||
"tokensNum": input_data.tokens_num,
|
||||
}
|
||||
|
||||
response = await Requests().post(url, headers=headers, json=payload)
|
||||
data = response.json()
|
||||
|
||||
context = CodeContextResponse.from_api(data)
|
||||
|
||||
yield "request_id", context.request_id
|
||||
yield "query", context.query
|
||||
yield "response", context.response
|
||||
yield "results_count", context.results_count
|
||||
yield "cost_dollars", context.cost_dollars
|
||||
yield "search_time", context.search_time
|
||||
yield "output_tokens", context.output_tokens
|
||||
@@ -1,127 +1,39 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from exa_py import AsyncExa
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
from .helpers import (
|
||||
CostDollars,
|
||||
ExaSearchResults,
|
||||
ExtrasSettings,
|
||||
HighlightSettings,
|
||||
LivecrawlTypes,
|
||||
SummarySettings,
|
||||
)
|
||||
|
||||
|
||||
class ContentStatusTag(str, Enum):
|
||||
CRAWL_NOT_FOUND = "CRAWL_NOT_FOUND"
|
||||
CRAWL_TIMEOUT = "CRAWL_TIMEOUT"
|
||||
CRAWL_LIVECRAWL_TIMEOUT = "CRAWL_LIVECRAWL_TIMEOUT"
|
||||
SOURCE_NOT_AVAILABLE = "SOURCE_NOT_AVAILABLE"
|
||||
CRAWL_UNKNOWN_ERROR = "CRAWL_UNKNOWN_ERROR"
|
||||
|
||||
|
||||
class ContentError(BaseModel):
|
||||
tag: Optional[ContentStatusTag] = SchemaField(
|
||||
default=None, description="Specific error type"
|
||||
)
|
||||
httpStatusCode: Optional[int] = SchemaField(
|
||||
default=None, description="The corresponding HTTP status code"
|
||||
)
|
||||
|
||||
|
||||
class ContentStatus(BaseModel):
|
||||
id: str = SchemaField(description="The URL that was requested")
|
||||
status: str = SchemaField(
|
||||
description="Status of the content fetch operation (success or error)"
|
||||
)
|
||||
error: Optional[ContentError] = SchemaField(
|
||||
default=None, description="Error details, only present when status is 'error'"
|
||||
)
|
||||
from .helpers import ContentSettings
|
||||
|
||||
|
||||
class ExaContentsBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
urls: list[str] = SchemaField(
|
||||
description="Array of URLs to crawl (preferred over 'ids')",
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
ids: list[str] = SchemaField(
|
||||
description="[DEPRECATED - use 'urls' instead] Array of document IDs obtained from searches",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
description="Array of document IDs obtained from searches"
|
||||
)
|
||||
text: bool = SchemaField(
|
||||
description="Retrieve text content from pages",
|
||||
default=True,
|
||||
)
|
||||
highlights: HighlightSettings = SchemaField(
|
||||
description="Text snippets most relevant from each page",
|
||||
default=HighlightSettings(),
|
||||
)
|
||||
summary: SummarySettings = SchemaField(
|
||||
description="LLM-generated summary of the webpage",
|
||||
default=SummarySettings(),
|
||||
)
|
||||
livecrawl: Optional[LivecrawlTypes] = SchemaField(
|
||||
description="Livecrawling options: never, fallback (default), always, preferred",
|
||||
default=LivecrawlTypes.FALLBACK,
|
||||
advanced=True,
|
||||
)
|
||||
livecrawl_timeout: Optional[int] = SchemaField(
|
||||
description="Timeout for livecrawling in milliseconds",
|
||||
default=10000,
|
||||
advanced=True,
|
||||
)
|
||||
subpages: Optional[int] = SchemaField(
|
||||
description="Number of subpages to crawl", default=0, ge=0, advanced=True
|
||||
)
|
||||
subpage_target: Optional[str | list[str]] = SchemaField(
|
||||
description="Keyword(s) to find specific subpages of search results",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
extras: ExtrasSettings = SchemaField(
|
||||
description="Extra parameters for additional content",
|
||||
default=ExtrasSettings(),
|
||||
contents: ContentSettings = SchemaField(
|
||||
description="Content retrieval settings",
|
||||
default=ContentSettings(),
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
results: list[ExaSearchResults] = SchemaField(
|
||||
description="List of document contents with metadata"
|
||||
class Output(BlockSchema):
|
||||
results: list = SchemaField(
|
||||
description="List of document contents", default_factory=list
|
||||
)
|
||||
result: ExaSearchResults = SchemaField(
|
||||
description="Single document content result"
|
||||
error: str = SchemaField(
|
||||
description="Error message if the request failed", default=""
|
||||
)
|
||||
context: str = SchemaField(
|
||||
description="A formatted string of the results ready for LLMs"
|
||||
)
|
||||
request_id: str = SchemaField(description="Unique identifier for the request")
|
||||
statuses: list[ContentStatus] = SchemaField(
|
||||
description="Status information for each requested URL"
|
||||
)
|
||||
cost_dollars: Optional[CostDollars] = SchemaField(
|
||||
description="Cost breakdown for the request"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -135,91 +47,23 @@ class ExaContentsBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
if not input_data.urls and not input_data.ids:
|
||||
raise ValueError("Either 'urls' or 'ids' must be provided")
|
||||
url = "https://api.exa.ai/contents"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": credentials.api_key.get_secret_value(),
|
||||
}
|
||||
|
||||
sdk_kwargs = {}
|
||||
# Convert ContentSettings to API format
|
||||
payload = {
|
||||
"ids": input_data.ids,
|
||||
"text": input_data.contents.text,
|
||||
"highlights": input_data.contents.highlights,
|
||||
"summary": input_data.contents.summary,
|
||||
}
|
||||
|
||||
# Prefer urls over ids
|
||||
if input_data.urls:
|
||||
sdk_kwargs["urls"] = input_data.urls
|
||||
elif input_data.ids:
|
||||
sdk_kwargs["ids"] = input_data.ids
|
||||
|
||||
if input_data.text:
|
||||
sdk_kwargs["text"] = {"includeHtmlTags": True}
|
||||
|
||||
# Handle highlights - only include if modified from defaults
|
||||
if input_data.highlights and (
|
||||
input_data.highlights.num_sentences != 1
|
||||
or input_data.highlights.highlights_per_url != 1
|
||||
or input_data.highlights.query is not None
|
||||
):
|
||||
highlights_dict = {}
|
||||
highlights_dict["numSentences"] = input_data.highlights.num_sentences
|
||||
highlights_dict["highlightsPerUrl"] = (
|
||||
input_data.highlights.highlights_per_url
|
||||
)
|
||||
if input_data.highlights.query:
|
||||
highlights_dict["query"] = input_data.highlights.query
|
||||
sdk_kwargs["highlights"] = highlights_dict
|
||||
|
||||
# Handle summary - only include if modified from defaults
|
||||
if input_data.summary and (
|
||||
input_data.summary.query is not None
|
||||
or input_data.summary.schema is not None
|
||||
):
|
||||
summary_dict = {}
|
||||
if input_data.summary.query:
|
||||
summary_dict["query"] = input_data.summary.query
|
||||
if input_data.summary.schema:
|
||||
summary_dict["schema"] = input_data.summary.schema
|
||||
sdk_kwargs["summary"] = summary_dict
|
||||
|
||||
if input_data.livecrawl:
|
||||
sdk_kwargs["livecrawl"] = input_data.livecrawl.value
|
||||
|
||||
if input_data.livecrawl_timeout is not None:
|
||||
sdk_kwargs["livecrawl_timeout"] = input_data.livecrawl_timeout
|
||||
|
||||
if input_data.subpages is not None:
|
||||
sdk_kwargs["subpages"] = input_data.subpages
|
||||
|
||||
if input_data.subpage_target:
|
||||
sdk_kwargs["subpage_target"] = input_data.subpage_target
|
||||
|
||||
# Handle extras - only include if modified from defaults
|
||||
if input_data.extras and (
|
||||
input_data.extras.links > 0 or input_data.extras.image_links > 0
|
||||
):
|
||||
extras_dict = {}
|
||||
if input_data.extras.links:
|
||||
extras_dict["links"] = input_data.extras.links
|
||||
if input_data.extras.image_links:
|
||||
extras_dict["image_links"] = input_data.extras.image_links
|
||||
sdk_kwargs["extras"] = extras_dict
|
||||
|
||||
# Always enable context for LLM-ready output
|
||||
sdk_kwargs["context"] = True
|
||||
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
response = await aexa.get_contents(**sdk_kwargs)
|
||||
|
||||
converted_results = [
|
||||
ExaSearchResults.from_sdk(sdk_result)
|
||||
for sdk_result in response.results or []
|
||||
]
|
||||
|
||||
yield "results", converted_results
|
||||
|
||||
for result in converted_results:
|
||||
yield "result", result
|
||||
|
||||
if response.context:
|
||||
yield "context", response.context
|
||||
|
||||
if response.statuses:
|
||||
yield "statuses", response.statuses
|
||||
|
||||
if response.cost_dollars:
|
||||
yield "cost_dollars", response.cost_dollars
|
||||
try:
|
||||
response = await Requests().post(url, headers=headers, json=payload)
|
||||
data = response.json()
|
||||
yield "results", data.get("results", [])
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
@@ -1,150 +1,51 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Literal, Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
from backend.sdk import BaseModel, MediaFileType, SchemaField
|
||||
from backend.sdk import BaseModel, SchemaField
|
||||
|
||||
|
||||
class LivecrawlTypes(str, Enum):
|
||||
NEVER = "never"
|
||||
FALLBACK = "fallback"
|
||||
ALWAYS = "always"
|
||||
PREFERRED = "preferred"
|
||||
|
||||
|
||||
class TextEnabled(BaseModel):
|
||||
discriminator: Literal["enabled"] = "enabled"
|
||||
|
||||
|
||||
class TextDisabled(BaseModel):
|
||||
discriminator: Literal["disabled"] = "disabled"
|
||||
|
||||
|
||||
class TextAdvanced(BaseModel):
|
||||
discriminator: Literal["advanced"] = "advanced"
|
||||
max_characters: Optional[int] = SchemaField(
|
||||
default=None,
|
||||
class TextSettings(BaseModel):
|
||||
max_characters: int = SchemaField(
|
||||
default=1000,
|
||||
description="Maximum number of characters to return",
|
||||
placeholder="1000",
|
||||
)
|
||||
include_html_tags: bool = SchemaField(
|
||||
default=False,
|
||||
description="Include HTML tags in the response, helps LLMs understand text structure",
|
||||
description="Whether to include HTML tags in the text",
|
||||
placeholder="False",
|
||||
)
|
||||
|
||||
|
||||
class HighlightSettings(BaseModel):
|
||||
num_sentences: int = SchemaField(
|
||||
default=1,
|
||||
default=3,
|
||||
description="Number of sentences per highlight",
|
||||
placeholder="1",
|
||||
ge=1,
|
||||
placeholder="3",
|
||||
)
|
||||
highlights_per_url: int = SchemaField(
|
||||
default=1,
|
||||
default=3,
|
||||
description="Number of highlights per URL",
|
||||
placeholder="1",
|
||||
ge=1,
|
||||
)
|
||||
query: Optional[str] = SchemaField(
|
||||
default=None,
|
||||
description="Custom query to direct the LLM's selection of highlights",
|
||||
placeholder="Key advancements",
|
||||
placeholder="3",
|
||||
)
|
||||
|
||||
|
||||
class SummarySettings(BaseModel):
|
||||
query: Optional[str] = SchemaField(
|
||||
default=None,
|
||||
description="Custom query for the LLM-generated summary",
|
||||
placeholder="Main developments",
|
||||
)
|
||||
schema: Optional[dict] = SchemaField( # type: ignore
|
||||
default=None,
|
||||
description="JSON schema for structured output from summary",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
|
||||
class ExtrasSettings(BaseModel):
|
||||
links: int = SchemaField(
|
||||
default=0,
|
||||
description="Number of URLs to return from each webpage",
|
||||
placeholder="1",
|
||||
ge=0,
|
||||
)
|
||||
image_links: int = SchemaField(
|
||||
default=0,
|
||||
description="Number of images to return for each result",
|
||||
placeholder="1",
|
||||
ge=0,
|
||||
)
|
||||
|
||||
|
||||
class ContextEnabled(BaseModel):
|
||||
discriminator: Literal["enabled"] = "enabled"
|
||||
|
||||
|
||||
class ContextDisabled(BaseModel):
|
||||
discriminator: Literal["disabled"] = "disabled"
|
||||
|
||||
|
||||
class ContextAdvanced(BaseModel):
|
||||
discriminator: Literal["advanced"] = "advanced"
|
||||
max_characters: Optional[int] = SchemaField(
|
||||
default=None,
|
||||
description="Maximum character limit for context string",
|
||||
placeholder="10000",
|
||||
default="",
|
||||
description="Query string for summarization",
|
||||
placeholder="Enter query",
|
||||
)
|
||||
|
||||
|
||||
class ContentSettings(BaseModel):
|
||||
text: Optional[Union[bool, TextEnabled, TextDisabled, TextAdvanced]] = SchemaField(
|
||||
default=None,
|
||||
description="Text content retrieval. Boolean for simple enable/disable or object for advanced settings",
|
||||
text: TextSettings = SchemaField(
|
||||
default=TextSettings(),
|
||||
)
|
||||
highlights: Optional[HighlightSettings] = SchemaField(
|
||||
default=None,
|
||||
description="Text snippets most relevant from each page",
|
||||
highlights: HighlightSettings = SchemaField(
|
||||
default=HighlightSettings(),
|
||||
)
|
||||
summary: Optional[SummarySettings] = SchemaField(
|
||||
default=None,
|
||||
description="LLM-generated summary of the webpage",
|
||||
)
|
||||
livecrawl: Optional[LivecrawlTypes] = SchemaField(
|
||||
default=None,
|
||||
description="Livecrawling options: never, fallback, always, preferred",
|
||||
advanced=True,
|
||||
)
|
||||
livecrawl_timeout: Optional[int] = SchemaField(
|
||||
default=None,
|
||||
description="Timeout for livecrawling in milliseconds",
|
||||
placeholder="10000",
|
||||
advanced=True,
|
||||
)
|
||||
subpages: Optional[int] = SchemaField(
|
||||
default=None,
|
||||
description="Number of subpages to crawl",
|
||||
placeholder="0",
|
||||
ge=0,
|
||||
advanced=True,
|
||||
)
|
||||
subpage_target: Optional[Union[str, list[str]]] = SchemaField(
|
||||
default=None,
|
||||
description="Keyword(s) to find specific subpages of search results",
|
||||
advanced=True,
|
||||
)
|
||||
extras: Optional[ExtrasSettings] = SchemaField(
|
||||
default=None,
|
||||
description="Extra parameters for additional content",
|
||||
advanced=True,
|
||||
)
|
||||
context: Optional[Union[bool, ContextEnabled, ContextDisabled, ContextAdvanced]] = (
|
||||
SchemaField(
|
||||
default=None,
|
||||
description="Format search results into a context string for LLMs",
|
||||
advanced=True,
|
||||
)
|
||||
summary: SummarySettings = SchemaField(
|
||||
default=SummarySettings(),
|
||||
)
|
||||
|
||||
|
||||
@@ -226,225 +127,3 @@ class WebsetEnrichmentConfig(BaseModel):
|
||||
default=None,
|
||||
description="Options for the enrichment",
|
||||
)
|
||||
|
||||
|
||||
# Shared result models
|
||||
class ExaSearchExtras(BaseModel):
|
||||
links: list[str] = SchemaField(
|
||||
default_factory=list, description="Array of links from the search result"
|
||||
)
|
||||
imageLinks: list[str] = SchemaField(
|
||||
default_factory=list, description="Array of image links from the search result"
|
||||
)
|
||||
|
||||
|
||||
class ExaSearchResults(BaseModel):
|
||||
title: str | None = None
|
||||
url: str | None = None
|
||||
publishedDate: str | None = None
|
||||
author: str | None = None
|
||||
id: str
|
||||
image: MediaFileType | None = None
|
||||
favicon: MediaFileType | None = None
|
||||
text: str | None = None
|
||||
highlights: list[str] = SchemaField(default_factory=list)
|
||||
highlightScores: list[float] = SchemaField(default_factory=list)
|
||||
summary: str | None = None
|
||||
subpages: list[dict] = SchemaField(default_factory=list)
|
||||
extras: ExaSearchExtras | None = None
|
||||
|
||||
@classmethod
|
||||
def from_sdk(cls, sdk_result) -> "ExaSearchResults":
|
||||
"""Convert SDK Result (dataclass) to our Pydantic model."""
|
||||
return cls(
|
||||
id=getattr(sdk_result, "id", ""),
|
||||
url=getattr(sdk_result, "url", None),
|
||||
title=getattr(sdk_result, "title", None),
|
||||
author=getattr(sdk_result, "author", None),
|
||||
publishedDate=getattr(sdk_result, "published_date", None),
|
||||
text=getattr(sdk_result, "text", None),
|
||||
highlights=getattr(sdk_result, "highlights", None) or [],
|
||||
highlightScores=getattr(sdk_result, "highlight_scores", None) or [],
|
||||
summary=getattr(sdk_result, "summary", None),
|
||||
subpages=getattr(sdk_result, "subpages", None) or [],
|
||||
image=getattr(sdk_result, "image", None),
|
||||
favicon=getattr(sdk_result, "favicon", None),
|
||||
extras=getattr(sdk_result, "extras", None),
|
||||
)
|
||||
|
||||
|
||||
# Cost tracking models
|
||||
class CostBreakdown(BaseModel):
|
||||
keywordSearch: float = SchemaField(default=0.0)
|
||||
neuralSearch: float = SchemaField(default=0.0)
|
||||
contentText: float = SchemaField(default=0.0)
|
||||
contentHighlight: float = SchemaField(default=0.0)
|
||||
contentSummary: float = SchemaField(default=0.0)
|
||||
|
||||
|
||||
class CostBreakdownItem(BaseModel):
|
||||
search: float = SchemaField(default=0.0)
|
||||
contents: float = SchemaField(default=0.0)
|
||||
breakdown: CostBreakdown = SchemaField(default_factory=CostBreakdown)
|
||||
|
||||
|
||||
class PerRequestPrices(BaseModel):
|
||||
neuralSearch_1_25_results: float = SchemaField(default=0.005)
|
||||
neuralSearch_26_100_results: float = SchemaField(default=0.025)
|
||||
neuralSearch_100_plus_results: float = SchemaField(default=1.0)
|
||||
keywordSearch_1_100_results: float = SchemaField(default=0.0025)
|
||||
keywordSearch_100_plus_results: float = SchemaField(default=3.0)
|
||||
|
||||
|
||||
class PerPagePrices(BaseModel):
|
||||
contentText: float = SchemaField(default=0.001)
|
||||
contentHighlight: float = SchemaField(default=0.001)
|
||||
contentSummary: float = SchemaField(default=0.001)
|
||||
|
||||
|
||||
class CostDollars(BaseModel):
|
||||
total: float = SchemaField(description="Total dollar cost for your request")
|
||||
breakDown: list[CostBreakdownItem] = SchemaField(
|
||||
default_factory=list, description="Breakdown of costs by operation type"
|
||||
)
|
||||
perRequestPrices: PerRequestPrices = SchemaField(
|
||||
default_factory=PerRequestPrices,
|
||||
description="Standard price per request for different operations",
|
||||
)
|
||||
perPagePrices: PerPagePrices = SchemaField(
|
||||
default_factory=PerPagePrices,
|
||||
description="Standard price per page for different content operations",
|
||||
)
|
||||
|
||||
|
||||
# Helper functions for payload processing
|
||||
def process_text_field(
|
||||
text: Union[bool, TextEnabled, TextDisabled, TextAdvanced, None]
|
||||
) -> Optional[Union[bool, Dict[str, Any]]]:
|
||||
"""Process text field for API payload."""
|
||||
if text is None:
|
||||
return None
|
||||
|
||||
# Handle backward compatibility with boolean
|
||||
if isinstance(text, bool):
|
||||
return text
|
||||
elif isinstance(text, TextDisabled):
|
||||
return False
|
||||
elif isinstance(text, TextEnabled):
|
||||
return True
|
||||
elif isinstance(text, TextAdvanced):
|
||||
text_dict = {}
|
||||
if text.max_characters:
|
||||
text_dict["maxCharacters"] = text.max_characters
|
||||
if text.include_html_tags:
|
||||
text_dict["includeHtmlTags"] = text.include_html_tags
|
||||
return text_dict if text_dict else True
|
||||
return None
|
||||
|
||||
|
||||
def process_contents_settings(contents: Optional[ContentSettings]) -> Dict[str, Any]:
|
||||
"""Process ContentSettings into API payload format."""
|
||||
if not contents:
|
||||
return {}
|
||||
|
||||
content_settings = {}
|
||||
|
||||
# Handle text field (can be boolean or object)
|
||||
text_value = process_text_field(contents.text)
|
||||
if text_value is not None:
|
||||
content_settings["text"] = text_value
|
||||
|
||||
# Handle highlights
|
||||
if contents.highlights:
|
||||
highlights_dict: Dict[str, Any] = {
|
||||
"numSentences": contents.highlights.num_sentences,
|
||||
"highlightsPerUrl": contents.highlights.highlights_per_url,
|
||||
}
|
||||
if contents.highlights.query:
|
||||
highlights_dict["query"] = contents.highlights.query
|
||||
content_settings["highlights"] = highlights_dict
|
||||
|
||||
if contents.summary:
|
||||
summary_dict = {}
|
||||
if contents.summary.query:
|
||||
summary_dict["query"] = contents.summary.query
|
||||
if contents.summary.schema:
|
||||
summary_dict["schema"] = contents.summary.schema
|
||||
content_settings["summary"] = summary_dict
|
||||
|
||||
if contents.livecrawl:
|
||||
content_settings["livecrawl"] = contents.livecrawl.value
|
||||
|
||||
if contents.livecrawl_timeout is not None:
|
||||
content_settings["livecrawlTimeout"] = contents.livecrawl_timeout
|
||||
|
||||
if contents.subpages is not None:
|
||||
content_settings["subpages"] = contents.subpages
|
||||
|
||||
if contents.subpage_target:
|
||||
content_settings["subpageTarget"] = contents.subpage_target
|
||||
|
||||
if contents.extras:
|
||||
extras_dict = {}
|
||||
if contents.extras.links:
|
||||
extras_dict["links"] = contents.extras.links
|
||||
if contents.extras.image_links:
|
||||
extras_dict["imageLinks"] = contents.extras.image_links
|
||||
content_settings["extras"] = extras_dict
|
||||
|
||||
context_value = process_context_field(contents.context)
|
||||
if context_value is not None:
|
||||
content_settings["context"] = context_value
|
||||
|
||||
return content_settings
|
||||
|
||||
|
||||
def process_context_field(
|
||||
context: Union[bool, dict, ContextEnabled, ContextDisabled, ContextAdvanced, None]
|
||||
) -> Optional[Union[bool, Dict[str, int]]]:
|
||||
"""Process context field for API payload."""
|
||||
if context is None:
|
||||
return None
|
||||
|
||||
# Handle backward compatibility with boolean
|
||||
if isinstance(context, bool):
|
||||
return context if context else None
|
||||
elif isinstance(context, dict) and "maxCharacters" in context:
|
||||
return {"maxCharacters": context["maxCharacters"]}
|
||||
elif isinstance(context, ContextDisabled):
|
||||
return None # Don't send context field at all when disabled
|
||||
elif isinstance(context, ContextEnabled):
|
||||
return True
|
||||
elif isinstance(context, ContextAdvanced):
|
||||
if context.max_characters:
|
||||
return {"maxCharacters": context.max_characters}
|
||||
return True
|
||||
return None
|
||||
|
||||
|
||||
def format_date_fields(
|
||||
input_data: Any, date_field_mapping: Dict[str, str]
|
||||
) -> Dict[str, str]:
|
||||
"""Format datetime fields for API payload."""
|
||||
formatted_dates = {}
|
||||
for input_field, api_field in date_field_mapping.items():
|
||||
value = getattr(input_data, input_field, None)
|
||||
if value:
|
||||
formatted_dates[api_field] = value.strftime("%Y-%m-%dT%H:%M:%S.000Z")
|
||||
return formatted_dates
|
||||
|
||||
|
||||
def add_optional_fields(
|
||||
input_data: Any,
|
||||
field_mapping: Dict[str, str],
|
||||
payload: Dict[str, Any],
|
||||
process_enums: bool = False,
|
||||
) -> None:
|
||||
"""Add optional fields to payload if they have values."""
|
||||
for input_field, api_field in field_mapping.items():
|
||||
value = getattr(input_data, input_field, None)
|
||||
if value: # Only add non-empty values
|
||||
if process_enums and hasattr(value, "value"):
|
||||
payload[api_field] = value.value
|
||||
else:
|
||||
payload[api_field] = value
|
||||
|
||||
247  autogpt_platform/backend/backend/blocks/exa/model.py  Normal file
@@ -0,0 +1,247 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# Enum definitions based on available options
|
||||
class WebsetStatus(str, Enum):
|
||||
IDLE = "idle"
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
PAUSED = "paused"
|
||||
|
||||
|
||||
class WebsetSearchStatus(str, Enum):
|
||||
CREATED = "created"
|
||||
# Add more if known, based on example it's "created"
|
||||
|
||||
|
||||
class ImportStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class ImportFormat(str, Enum):
|
||||
CSV = "csv"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class EnrichmentStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class EnrichmentFormat(str, Enum):
|
||||
TEXT = "text"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class MonitorStatus(str, Enum):
|
||||
ENABLED = "enabled"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class MonitorBehaviorType(str, Enum):
|
||||
SEARCH = "search"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class MonitorRunStatus(str, Enum):
|
||||
CREATED = "created"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class CanceledReason(str, Enum):
|
||||
WEBSET_DELETED = "webset_deleted"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class FailedReason(str, Enum):
|
||||
INVALID_FORMAT = "invalid_format"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class Confidence(str, Enum):
|
||||
HIGH = "high"
|
||||
# Add more if known
|
||||
|
||||
|
||||
# Nested models
|
||||
|
||||
|
||||
class Entity(BaseModel):
|
||||
type: str
|
||||
|
||||
|
||||
class Criterion(BaseModel):
|
||||
description: str
|
||||
successRate: Optional[int] = None
|
||||
|
||||
|
||||
class ExcludeItem(BaseModel):
|
||||
source: str = Field(default="import")
|
||||
id: str
|
||||
|
||||
|
||||
class Relationship(BaseModel):
|
||||
definition: str
|
||||
limit: Optional[float] = None
|
||||
|
||||
|
||||
class ScopeItem(BaseModel):
|
||||
source: str = Field(default="import")
|
||||
id: str
|
||||
relationship: Optional[Relationship] = None
|
||||
|
||||
|
||||
class Progress(BaseModel):
|
||||
found: int
|
||||
analyzed: int
|
||||
completion: int
|
||||
timeLeft: int
|
||||
|
||||
|
||||
class Bounds(BaseModel):
|
||||
min: int
|
||||
max: int
|
||||
|
||||
|
||||
class Expected(BaseModel):
|
||||
total: int
|
||||
confidence: str = Field(default="high") # Use str or Confidence enum
|
||||
bounds: Bounds
|
||||
|
||||
|
||||
class Recall(BaseModel):
|
||||
expected: Expected
|
||||
reasoning: str
|
||||
|
||||
|
||||
class WebsetSearch(BaseModel):
|
||||
id: str
|
||||
object: str = Field(default="webset_search")
|
||||
status: str = Field(default="created") # Or use WebsetSearchStatus
|
||||
websetId: str
|
||||
query: str
|
||||
entity: Entity
|
||||
criteria: List[Criterion]
|
||||
count: int
|
||||
behavior: str = Field(default="override")
|
||||
exclude: List[ExcludeItem]
|
||||
scope: List[ScopeItem]
|
||||
progress: Progress
|
||||
recall: Recall
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
canceledAt: Optional[datetime] = None
|
||||
canceledReason: Optional[str] = Field(default=None) # Or use CanceledReason
|
||||
createdAt: datetime
|
||||
updatedAt: datetime
|
||||
|
||||
|
||||
class ImportEntity(BaseModel):
|
||||
type: str
|
||||
|
||||
|
||||
class Import(BaseModel):
|
||||
id: str
|
||||
object: str = Field(default="import")
|
||||
status: str = Field(default="pending") # Or use ImportStatus
|
||||
format: str = Field(default="csv") # Or use ImportFormat
|
||||
entity: ImportEntity
|
||||
title: str
|
||||
count: int
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
failedReason: Optional[str] = Field(default=None) # Or use FailedReason
|
||||
failedAt: Optional[datetime] = None
|
||||
failedMessage: Optional[str] = None
|
||||
createdAt: datetime
|
||||
updatedAt: datetime
|
||||
|
||||
|
||||
class Option(BaseModel):
|
||||
label: str
|
||||
|
||||
|
||||
class WebsetEnrichment(BaseModel):
|
||||
id: str
|
||||
object: str = Field(default="webset_enrichment")
|
||||
status: str = Field(default="pending") # Or use EnrichmentStatus
|
||||
websetId: str
|
||||
title: str
|
||||
description: str
|
||||
format: str = Field(default="text") # Or use EnrichmentFormat
|
||||
options: List[Option]
|
||||
instructions: str
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
createdAt: datetime
|
||||
updatedAt: datetime
|
||||
|
||||
|
||||
class Cadence(BaseModel):
|
||||
cron: str
|
||||
timezone: str = Field(default="Etc/UTC")
|
||||
|
||||
|
||||
class BehaviorConfig(BaseModel):
|
||||
query: Optional[str] = None
|
||||
criteria: Optional[List[Criterion]] = None
|
||||
entity: Optional[Entity] = None
|
||||
count: Optional[int] = None
|
||||
behavior: Optional[str] = Field(default=None)
|
||||
|
||||
|
||||
class Behavior(BaseModel):
|
||||
type: str = Field(default="search") # Or use MonitorBehaviorType
|
||||
config: BehaviorConfig
|
||||
|
||||
|
||||
class MonitorRun(BaseModel):
|
||||
id: str
|
||||
object: str = Field(default="monitor_run")
|
||||
status: str = Field(default="created") # Or use MonitorRunStatus
|
||||
monitorId: str
|
||||
type: str = Field(default="search")
|
||||
completedAt: Optional[datetime] = None
|
||||
failedAt: Optional[datetime] = None
|
||||
failedReason: Optional[str] = None
|
||||
canceledAt: Optional[datetime] = None
|
||||
createdAt: datetime
|
||||
updatedAt: datetime
|
||||
|
||||
|
||||
class Monitor(BaseModel):
|
||||
id: str
|
||||
object: str = Field(default="monitor")
|
||||
status: str = Field(default="enabled") # Or use MonitorStatus
|
||||
websetId: str
|
||||
cadence: Cadence
|
||||
behavior: Behavior
|
||||
lastRun: Optional[MonitorRun] = None
|
||||
nextRunAt: Optional[datetime] = None
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
createdAt: datetime
|
||||
updatedAt: datetime
|
||||
|
||||
|
||||
class Webset(BaseModel):
|
||||
id: str
|
||||
object: str = Field(default="webset")
|
||||
status: WebsetStatus
|
||||
externalId: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
searches: List[WebsetSearch]
|
||||
imports: List[Import]
|
||||
enrichments: List[WebsetEnrichment]
|
||||
monitors: List[Monitor]
|
||||
streams: List[Any]
|
||||
createdAt: datetime
|
||||
updatedAt: datetime
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ListWebsets(BaseModel):
|
||||
data: List[Webset]
|
||||
hasMore: bool
|
||||
nextCursor: Optional[str] = None
|
||||
@@ -1,518 +0,0 @@
|
||||
"""
|
||||
Exa Research Task Blocks
|
||||
|
||||
Provides asynchronous research capabilities that explore the web, gather sources,
|
||||
synthesize findings, and return structured results with citations.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
|
||||
|
||||
class ResearchModel(str, Enum):
|
||||
"""Available research models."""
|
||||
|
||||
FAST = "exa-research-fast"
|
||||
STANDARD = "exa-research"
|
||||
PRO = "exa-research-pro"
|
||||
|
||||
|
||||
class ResearchStatus(str, Enum):
|
||||
"""Research task status."""
|
||||
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
CANCELED = "canceled"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class ResearchCostModel(BaseModel):
|
||||
"""Cost breakdown for a research request."""
|
||||
|
||||
total: float
|
||||
num_searches: int
|
||||
num_pages: int
|
||||
reasoning_tokens: int
|
||||
|
||||
@classmethod
|
||||
def from_api(cls, data: dict) -> "ResearchCostModel":
|
||||
"""Convert API response, rounding fractional counts to integers."""
|
||||
return cls(
|
||||
total=data.get("total", 0.0),
|
||||
num_searches=int(round(data.get("numSearches", 0))),
|
||||
num_pages=int(round(data.get("numPages", 0))),
|
||||
reasoning_tokens=int(round(data.get("reasoningTokens", 0))),
|
||||
)
|
||||
|
||||
|
||||
class ResearchOutputModel(BaseModel):
|
||||
"""Research output with content and optional structured data."""
|
||||
|
||||
content: str
|
||||
parsed: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ResearchTaskModel(BaseModel):
|
||||
"""Stable output model for research tasks."""
|
||||
|
||||
research_id: str
|
||||
created_at: int
|
||||
model: str
|
||||
instructions: str
|
||||
status: str
|
||||
output_schema: Optional[Dict[str, Any]] = None
|
||||
output: Optional[ResearchOutputModel] = None
|
||||
cost_dollars: Optional[ResearchCostModel] = None
|
||||
finished_at: Optional[int] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_api(cls, data: dict) -> "ResearchTaskModel":
|
||||
"""Convert API response to our stable model."""
|
||||
output_data = data.get("output")
|
||||
output = None
|
||||
if output_data:
|
||||
output = ResearchOutputModel(
|
||||
content=output_data.get("content", ""),
|
||||
parsed=output_data.get("parsed"),
|
||||
)
|
||||
|
||||
cost_data = data.get("costDollars")
|
||||
cost = None
|
||||
if cost_data:
|
||||
cost = ResearchCostModel.from_api(cost_data)
|
||||
|
||||
return cls(
|
||||
research_id=data.get("researchId", ""),
|
||||
created_at=data.get("createdAt", 0),
|
||||
model=data.get("model", "exa-research"),
|
||||
instructions=data.get("instructions", ""),
|
||||
status=data.get("status", "pending"),
|
||||
output_schema=data.get("outputSchema"),
|
||||
output=output,
|
||||
cost_dollars=cost,
|
||||
finished_at=data.get("finishedAt"),
|
||||
error=data.get("error"),
|
||||
)
|
||||
|
||||
|
||||
class ExaCreateResearchBlock(Block):
|
||||
"""Create an asynchronous research task that explores the web and synthesizes findings."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
instructions: str = SchemaField(
|
||||
description="Research instructions - clearly define what information to find, how to conduct research, and desired output format.",
|
||||
placeholder="Research the top 5 AI coding assistants, their features, pricing, and user reviews",
|
||||
)
|
||||
model: ResearchModel = SchemaField(
|
||||
default=ResearchModel.STANDARD,
|
||||
description="Research model: 'fast' for quick results, 'standard' for balanced quality, 'pro' for thorough analysis",
|
||||
)
|
||||
output_schema: Optional[dict] = SchemaField(
|
||||
default=None,
|
||||
description="JSON Schema to enforce structured output. When provided, results are validated and returned as parsed JSON.",
|
||||
advanced=True,
|
||||
)
|
||||
wait_for_completion: bool = SchemaField(
|
||||
default=True,
|
||||
description="Wait for research to complete before returning. Ensures you get results immediately.",
|
||||
)
|
||||
polling_timeout: int = SchemaField(
|
||||
default=600,
|
||||
description="Maximum time to wait for completion in seconds (only if wait_for_completion is True)",
|
||||
advanced=True,
|
||||
ge=1,
|
||||
le=3600,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
research_id: str = SchemaField(
|
||||
description="Unique identifier for tracking this research request"
|
||||
)
|
||||
status: str = SchemaField(description="Final status of the research")
|
||||
model: str = SchemaField(description="The research model used")
|
||||
instructions: str = SchemaField(
|
||||
description="The research instructions provided"
|
||||
)
|
||||
created_at: int = SchemaField(
|
||||
description="When the research was created (Unix timestamp in ms)"
|
||||
)
|
||||
output_content: Optional[str] = SchemaField(
|
||||
description="Research output as text (only if wait_for_completion was True and completed)"
|
||||
)
|
||||
output_parsed: Optional[dict] = SchemaField(
|
||||
description="Structured JSON output (only if wait_for_completion and outputSchema were provided)"
|
||||
)
|
||||
cost_total: Optional[float] = SchemaField(
|
||||
description="Total cost in USD (only if wait_for_completion was True and completed)"
|
||||
)
|
||||
elapsed_time: Optional[float] = SchemaField(
|
||||
description="Time taken to complete in seconds (only if wait_for_completion was True)"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a1f2e3d4-c5b6-4a78-9012-3456789abcde",
|
||||
description="Create research task with optional waiting - explores web and synthesizes findings with citations",
|
||||
categories={BlockCategory.SEARCH, BlockCategory.AI},
|
||||
input_schema=ExaCreateResearchBlock.Input,
|
||||
output_schema=ExaCreateResearchBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
url = "https://api.exa.ai/research/v1"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": credentials.api_key.get_secret_value(),
|
||||
}
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"model": input_data.model.value,
|
||||
"instructions": input_data.instructions,
|
||||
}
|
||||
|
||||
if input_data.output_schema:
|
||||
payload["outputSchema"] = input_data.output_schema
|
||||
|
||||
response = await Requests().post(url, headers=headers, json=payload)
|
||||
data = response.json()
|
||||
|
||||
research_id = data.get("researchId", "")
|
||||
|
||||
if input_data.wait_for_completion:
|
||||
start_time = time.time()
|
||||
get_url = f"https://api.exa.ai/research/v1/{research_id}"
|
||||
get_headers = {"x-api-key": credentials.api_key.get_secret_value()}
|
||||
check_interval = 10
|
||||
|
||||
while time.time() - start_time < input_data.polling_timeout:
|
||||
poll_response = await Requests().get(url=get_url, headers=get_headers)
|
||||
poll_data = poll_response.json()
|
||||
|
||||
status = poll_data.get("status", "")
|
||||
|
||||
if status in ["completed", "failed", "canceled"]:
|
||||
elapsed = time.time() - start_time
|
||||
research = ResearchTaskModel.from_api(poll_data)
|
||||
|
||||
yield "research_id", research.research_id
|
||||
yield "status", research.status
|
||||
yield "model", research.model
|
||||
yield "instructions", research.instructions
|
||||
yield "created_at", research.created_at
|
||||
yield "elapsed_time", elapsed
|
||||
|
||||
if research.output:
|
||||
yield "output_content", research.output.content
|
||||
yield "output_parsed", research.output.parsed
|
||||
|
||||
if research.cost_dollars:
|
||||
yield "cost_total", research.cost_dollars.total
|
||||
return
|
||||
|
||||
await asyncio.sleep(check_interval)
|
||||
|
||||
raise ValueError(
|
||||
f"Research did not complete within {input_data.polling_timeout} seconds"
|
||||
)
|
||||
else:
|
||||
yield "research_id", research_id
|
||||
yield "status", data.get("status", "pending")
|
||||
yield "model", data.get("model", input_data.model.value)
|
||||
yield "instructions", data.get("instructions", input_data.instructions)
|
||||
yield "created_at", data.get("createdAt", 0)
|
||||
|
||||
|
||||
class ExaGetResearchBlock(Block):
|
||||
"""Get the status and results of a research task."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
research_id: str = SchemaField(
|
||||
description="The ID of the research task to retrieve",
|
||||
placeholder="01jszdfs0052sg4jc552sg4jc5",
|
||||
)
|
||||
include_events: bool = SchemaField(
|
||||
default=False,
|
||||
description="Include detailed event log of research operations",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
research_id: str = SchemaField(description="The research task identifier")
|
||||
status: str = SchemaField(
|
||||
description="Current status: pending, running, completed, canceled, or failed"
|
||||
)
|
||||
instructions: str = SchemaField(
|
||||
description="The original research instructions"
|
||||
)
|
||||
model: str = SchemaField(description="The research model used")
|
||||
created_at: int = SchemaField(
|
||||
description="When research was created (Unix timestamp in ms)"
|
||||
)
|
||||
finished_at: Optional[int] = SchemaField(
|
||||
description="When research finished (Unix timestamp in ms, if completed/canceled/failed)"
|
||||
)
|
||||
output_content: Optional[str] = SchemaField(
|
||||
description="Research output as text (if completed)"
|
||||
)
|
||||
output_parsed: Optional[dict] = SchemaField(
|
||||
description="Structured JSON output matching outputSchema (if provided and completed)"
|
||||
)
|
||||
cost_total: Optional[float] = SchemaField(
|
||||
description="Total cost in USD (if completed)"
|
||||
)
|
||||
cost_searches: Optional[int] = SchemaField(
|
||||
description="Number of searches performed (if completed)"
|
||||
)
|
||||
cost_pages: Optional[int] = SchemaField(
|
||||
description="Number of pages crawled (if completed)"
|
||||
)
|
||||
cost_reasoning_tokens: Optional[int] = SchemaField(
|
||||
description="AI tokens used for reasoning (if completed)"
|
||||
)
|
||||
error_message: Optional[str] = SchemaField(
|
||||
description="Error message if research failed"
|
||||
)
|
||||
events: Optional[List[dict]] = SchemaField(
|
||||
description="Detailed event log (if include_events was True)"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b2e3f4a5-6789-4bcd-9012-3456789abcde",
|
||||
description="Get status and results of a research task",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaGetResearchBlock.Input,
|
||||
output_schema=ExaGetResearchBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
url = f"https://api.exa.ai/research/v1/{input_data.research_id}"
|
||||
headers = {
|
||||
"x-api-key": credentials.api_key.get_secret_value(),
|
||||
}
|
||||
|
||||
params = {}
|
||||
if input_data.include_events:
|
||||
params["events"] = "true"
|
||||
|
||||
response = await Requests().get(url, headers=headers, params=params)
|
||||
data = response.json()
|
||||
|
||||
research = ResearchTaskModel.from_api(data)
|
||||
|
||||
yield "research_id", research.research_id
|
||||
yield "status", research.status
|
||||
yield "instructions", research.instructions
|
||||
yield "model", research.model
|
||||
yield "created_at", research.created_at
|
||||
yield "finished_at", research.finished_at
|
||||
|
||||
if research.output:
|
||||
yield "output_content", research.output.content
|
||||
yield "output_parsed", research.output.parsed
|
||||
|
||||
if research.cost_dollars:
|
||||
yield "cost_total", research.cost_dollars.total
|
||||
yield "cost_searches", research.cost_dollars.num_searches
|
||||
yield "cost_pages", research.cost_dollars.num_pages
|
||||
yield "cost_reasoning_tokens", research.cost_dollars.reasoning_tokens
|
||||
|
||||
yield "error_message", research.error
|
||||
|
||||
if input_data.include_events:
|
||||
yield "events", data.get("events", [])
|
||||
|
||||
|
||||
class ExaWaitForResearchBlock(Block):
|
||||
"""Wait for a research task to complete with progress tracking."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
research_id: str = SchemaField(
|
||||
description="The ID of the research task to wait for",
|
||||
placeholder="01jszdfs0052sg4jc552sg4jc5",
|
||||
)
|
||||
timeout: int = SchemaField(
|
||||
default=600,
|
||||
description="Maximum time to wait in seconds",
|
||||
ge=1,
|
||||
le=3600,
|
||||
)
|
||||
check_interval: int = SchemaField(
|
||||
default=10,
|
||||
description="Seconds between status checks",
|
||||
advanced=True,
|
||||
ge=1,
|
||||
le=60,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
research_id: str = SchemaField(description="The research task identifier")
|
||||
final_status: str = SchemaField(description="Final status when polling stopped")
|
||||
output_content: Optional[str] = SchemaField(
|
||||
description="Research output as text (if completed)"
|
||||
)
|
||||
output_parsed: Optional[dict] = SchemaField(
|
||||
description="Structured JSON output (if outputSchema was provided and completed)"
|
||||
)
|
||||
cost_total: Optional[float] = SchemaField(description="Total cost in USD")
|
||||
elapsed_time: float = SchemaField(description="Total time waited in seconds")
|
||||
timed_out: bool = SchemaField(
|
||||
description="Whether polling timed out before completion"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c3d4e5f6-7890-4abc-9012-3456789abcde",
|
||||
description="Wait for a research task to complete with configurable timeout",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaWaitForResearchBlock.Input,
|
||||
output_schema=ExaWaitForResearchBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
start_time = time.time()
|
||||
url = f"https://api.exa.ai/research/v1/{input_data.research_id}"
|
||||
headers = {
|
||||
"x-api-key": credentials.api_key.get_secret_value(),
|
||||
}
|
||||
|
||||
while time.time() - start_time < input_data.timeout:
|
||||
response = await Requests().get(url, headers=headers)
|
||||
data = response.json()
|
||||
|
||||
status = data.get("status", "")
|
||||
|
||||
if status in ["completed", "failed", "canceled"]:
|
||||
elapsed = time.time() - start_time
|
||||
research = ResearchTaskModel.from_api(data)
|
||||
|
||||
yield "research_id", research.research_id
|
||||
yield "final_status", research.status
|
||||
yield "elapsed_time", elapsed
|
||||
yield "timed_out", False
|
||||
|
||||
if research.output:
|
||||
yield "output_content", research.output.content
|
||||
yield "output_parsed", research.output.parsed
|
||||
|
||||
if research.cost_dollars:
|
||||
yield "cost_total", research.cost_dollars.total
|
||||
|
||||
return
|
||||
|
||||
await asyncio.sleep(input_data.check_interval)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
response = await Requests().get(url, headers=headers)
|
||||
data = response.json()
|
||||
|
||||
yield "research_id", input_data.research_id
|
||||
yield "final_status", data.get("status", "unknown")
|
||||
yield "elapsed_time", elapsed
|
||||
yield "timed_out", True
|
||||
|
||||
|
||||
class ExaListResearchBlock(Block):
    """List all research tasks with pagination support."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        cursor: Optional[str] = SchemaField(
            default=None,
            description="Cursor for pagination through results",
            advanced=True,
        )
        limit: int = SchemaField(
            default=10,
            description="Number of research tasks to return (1-50)",
            ge=1,
            le=50,
            advanced=True,
        )

    class Output(BlockSchemaOutput):
        research_tasks: List[ResearchTaskModel] = SchemaField(
            description="List of research tasks ordered by creation time (newest first)"
        )
        research_task: ResearchTaskModel = SchemaField(
            description="Individual research task (yielded for each task)"
        )
        has_more: bool = SchemaField(
            description="Whether there are more tasks to paginate through"
        )
        next_cursor: Optional[str] = SchemaField(
            description="Cursor for the next page of results"
        )

    def __init__(self):
        super().__init__(
            id="d4e5f6a7-8901-4bcd-9012-3456789abcde",
            description="List all research tasks with pagination support",
            categories={BlockCategory.SEARCH},
            input_schema=ExaListResearchBlock.Input,
            output_schema=ExaListResearchBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        url = "https://api.exa.ai/research/v1"
        headers = {
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        params: Dict[str, Any] = {
            "limit": input_data.limit,
        }
        if input_data.cursor:
            params["cursor"] = input_data.cursor

        response = await Requests().get(url, headers=headers, params=params)
        data = response.json()

        tasks = [ResearchTaskModel.from_api(task) for task in data.get("data", [])]

        yield "research_tasks", tasks

        for task in tasks:
            yield "research_task", task

        yield "has_more", data.get("hasMore", False)
        yield "next_cursor", data.get("nextCursor")
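
# --- Illustrative sketch (not part of the diff above) ------------------------
# The list block exposes `has_more` and `next_cursor` so callers can page
# through results. A minimal driver for that pattern, with `list_page` as a
# hypothetical stand-in for a single call to the list endpoint:
from typing import Any, Callable, Dict, Iterator, Optional


def iterate_all_tasks(
    list_page: Callable[[Optional[str], int], Dict[str, Any]],
    limit: int = 10,
) -> Iterator[Dict[str, Any]]:
    cursor: Optional[str] = None
    while True:
        page = list_page(cursor, limit)
        yield from page.get("data", [])
        if not page.get("hasMore", False):
            break
        cursor = page.get("nextCursor")


if __name__ == "__main__":
    # Two fake pages to exercise the loop.
    pages = [
        {"data": [{"id": "r-1"}, {"id": "r-2"}], "hasMore": True, "nextCursor": "c1"},
        {"data": [{"id": "r-3"}], "hasMore": False, "nextCursor": None},
    ]

    def fake_list_page(cursor: Optional[str], limit: int) -> Dict[str, Any]:
        return pages.pop(0)

    print([t["id"] for t in iterate_all_tasks(fake_list_page)])
# ------------------------------------------------------------------------------
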
@@ -1,66 +1,32 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from exa_py import AsyncExa
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
from .helpers import (
|
||||
ContentSettings,
|
||||
CostDollars,
|
||||
ExaSearchResults,
|
||||
process_contents_settings,
|
||||
)
|
||||
|
||||
|
||||
class ExaSearchTypes(Enum):
|
||||
KEYWORD = "keyword"
|
||||
NEURAL = "neural"
|
||||
FAST = "fast"
|
||||
AUTO = "auto"
|
||||
|
||||
|
||||
class ExaSearchCategories(Enum):
|
||||
COMPANY = "company"
|
||||
RESEARCH_PAPER = "research paper"
|
||||
NEWS = "news"
|
||||
PDF = "pdf"
|
||||
GITHUB = "github"
|
||||
TWEET = "tweet"
|
||||
PERSONAL_SITE = "personal site"
|
||||
LINKEDIN_PROFILE = "linkedin profile"
|
||||
FINANCIAL_REPORT = "financial report"
|
||||
from .helpers import ContentSettings
|
||||
|
||||
|
||||
class ExaSearchBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
query: str = SchemaField(description="The search query")
|
||||
type: ExaSearchTypes = SchemaField(
|
||||
description="Type of search", default=ExaSearchTypes.AUTO, advanced=True
|
||||
use_auto_prompt: bool = SchemaField(
|
||||
description="Whether to use autoprompt", default=True, advanced=True
|
||||
)
|
||||
category: ExaSearchCategories | None = SchemaField(
|
||||
description="Category to search within: company, research paper, news, pdf, github, tweet, personal site, linkedin profile, financial report",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
user_location: str | None = SchemaField(
|
||||
description="The two-letter ISO country code of the user (e.g., 'US')",
|
||||
default=None,
|
||||
advanced=True,
|
||||
type: str = SchemaField(description="Type of search", default="", advanced=True)
|
||||
category: str = SchemaField(
|
||||
description="Category to search within", default="", advanced=True
|
||||
)
|
||||
number_of_results: int = SchemaField(
|
||||
description="Number of results to return", default=10, advanced=True
|
||||
@@ -73,17 +39,17 @@ class ExaSearchBlock(Block):
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
start_crawl_date: datetime | None = SchemaField(
|
||||
description="Start date for crawled content", advanced=True, default=None
|
||||
start_crawl_date: datetime = SchemaField(
|
||||
description="Start date for crawled content"
|
||||
)
|
||||
end_crawl_date: datetime | None = SchemaField(
|
||||
description="End date for crawled content", advanced=True, default=None
|
||||
end_crawl_date: datetime = SchemaField(
|
||||
description="End date for crawled content"
|
||||
)
|
||||
start_published_date: datetime | None = SchemaField(
|
||||
description="Start date for published content", advanced=True, default=None
|
||||
start_published_date: datetime = SchemaField(
|
||||
description="Start date for published content"
|
||||
)
|
||||
end_published_date: datetime | None = SchemaField(
|
||||
description="End date for published content", advanced=True, default=None
|
||||
end_published_date: datetime = SchemaField(
|
||||
description="End date for published content"
|
||||
)
|
||||
include_text: list[str] = SchemaField(
|
||||
description="Text patterns to include", default_factory=list, advanced=True
|
||||
@@ -96,30 +62,14 @@ class ExaSearchBlock(Block):
|
||||
default=ContentSettings(),
|
||||
advanced=True,
|
||||
)
|
||||
moderation: bool = SchemaField(
|
||||
description="Enable content moderation to filter unsafe content from search results",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
results: list[ExaSearchResults] = SchemaField(
|
||||
description="List of search results"
|
||||
class Output(BlockSchema):
|
||||
results: list = SchemaField(
|
||||
description="List of search results", default_factory=list
|
||||
)
|
||||
result: ExaSearchResults = SchemaField(description="Single search result")
|
||||
context: str = SchemaField(
|
||||
description="A formatted string of the search results ready for LLMs."
|
||||
error: str = SchemaField(
|
||||
description="Error message if the request failed",
|
||||
)
|
||||
search_type: str = SchemaField(
|
||||
description="For auto searches, indicates which search type was selected."
|
||||
)
|
||||
resolved_search_type: str = SchemaField(
|
||||
description="The search type that was actually used for this request (neural or keyword)"
|
||||
)
|
||||
cost_dollars: Optional[CostDollars] = SchemaField(
|
||||
description="Cost breakdown for the request"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -133,76 +83,51 @@ class ExaSearchBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
sdk_kwargs = {
|
||||
"query": input_data.query,
|
||||
"num_results": input_data.number_of_results,
|
||||
url = "https://api.exa.ai/search"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": credentials.api_key.get_secret_value(),
|
||||
}
|
||||
|
||||
if input_data.type:
|
||||
sdk_kwargs["type"] = input_data.type.value
|
||||
payload = {
|
||||
"query": input_data.query,
|
||||
"useAutoprompt": input_data.use_auto_prompt,
|
||||
"numResults": input_data.number_of_results,
|
||||
"contents": input_data.contents.model_dump(),
|
||||
}
|
||||
|
||||
if input_data.category:
|
||||
sdk_kwargs["category"] = input_data.category.value
|
||||
date_field_mapping = {
|
||||
"start_crawl_date": "startCrawlDate",
|
||||
"end_crawl_date": "endCrawlDate",
|
||||
"start_published_date": "startPublishedDate",
|
||||
"end_published_date": "endPublishedDate",
|
||||
}
|
||||
|
||||
if input_data.user_location:
|
||||
sdk_kwargs["user_location"] = input_data.user_location
|
||||
# Add dates if they exist
|
||||
for input_field, api_field in date_field_mapping.items():
|
||||
value = getattr(input_data, input_field, None)
|
||||
if value:
|
||||
payload[api_field] = value.strftime("%Y-%m-%dT%H:%M:%S.000Z")
|
||||
|
||||
# Handle domains
|
||||
if input_data.include_domains:
|
||||
sdk_kwargs["include_domains"] = input_data.include_domains
|
||||
if input_data.exclude_domains:
|
||||
sdk_kwargs["exclude_domains"] = input_data.exclude_domains
|
||||
optional_field_mapping = {
|
||||
"type": "type",
|
||||
"category": "category",
|
||||
"include_domains": "includeDomains",
|
||||
"exclude_domains": "excludeDomains",
|
||||
"include_text": "includeText",
|
||||
"exclude_text": "excludeText",
|
||||
}
|
||||
|
||||
# Handle dates
|
||||
if input_data.start_crawl_date:
|
||||
sdk_kwargs["start_crawl_date"] = input_data.start_crawl_date.isoformat()
|
||||
if input_data.end_crawl_date:
|
||||
sdk_kwargs["end_crawl_date"] = input_data.end_crawl_date.isoformat()
|
||||
if input_data.start_published_date:
|
||||
sdk_kwargs["start_published_date"] = (
|
||||
input_data.start_published_date.isoformat()
|
||||
)
|
||||
if input_data.end_published_date:
|
||||
sdk_kwargs["end_published_date"] = input_data.end_published_date.isoformat()
|
||||
# Add other fields
|
||||
for input_field, api_field in optional_field_mapping.items():
|
||||
value = getattr(input_data, input_field)
|
||||
if value: # Only add non-empty values
|
||||
payload[api_field] = value
|
||||
|
||||
# Handle text filters
|
||||
if input_data.include_text:
|
||||
sdk_kwargs["include_text"] = input_data.include_text
|
||||
if input_data.exclude_text:
|
||||
sdk_kwargs["exclude_text"] = input_data.exclude_text
|
||||
|
||||
if input_data.moderation:
|
||||
sdk_kwargs["moderation"] = input_data.moderation
|
||||
|
||||
# check if we need to use search_and_contents
|
||||
content_settings = process_contents_settings(input_data.contents)
|
||||
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
if content_settings:
|
||||
sdk_kwargs["text"] = content_settings.get("text", False)
|
||||
if "highlights" in content_settings:
|
||||
sdk_kwargs["highlights"] = content_settings["highlights"]
|
||||
if "summary" in content_settings:
|
||||
sdk_kwargs["summary"] = content_settings["summary"]
|
||||
response = await aexa.search_and_contents(**sdk_kwargs)
|
||||
else:
|
||||
response = await aexa.search(**sdk_kwargs)
|
||||
|
||||
converted_results = [
|
||||
ExaSearchResults.from_sdk(sdk_result)
|
||||
for sdk_result in response.results or []
|
||||
]
|
||||
|
||||
yield "results", converted_results
|
||||
for result in converted_results:
|
||||
yield "result", result
|
||||
|
||||
if response.context:
|
||||
yield "context", response.context
|
||||
|
||||
if response.resolved_search_type:
|
||||
yield "resolved_search_type", response.resolved_search_type
|
||||
|
||||
if response.cost_dollars:
|
||||
yield "cost_dollars", response.cost_dollars
|
||||
try:
|
||||
response = await Requests().post(url, headers=headers, json=payload)
|
||||
data = response.json()
|
||||
# Extract just the results array from the response
|
||||
yield "results", data.get("results", [])
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
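# --- Illustrative sketch (not part of the diff above) ------------------------
# Both versions of the search block map snake_case inputs to the API's
# camelCase field names and serialize datetimes as UTC ISO-8601 strings with a
# trailing "Z", skipping empty values. A compact version of that pattern
# (field names below are a subset chosen for illustration):
from datetime import datetime
from typing import Any, Dict

FIELD_MAPPING = {
    "include_domains": "includeDomains",
    "exclude_domains": "excludeDomains",
    "start_published_date": "startPublishedDate",
    "end_published_date": "endPublishedDate",
}


def build_payload(query: str, num_results: int, **optional: Any) -> Dict[str, Any]:
    payload: Dict[str, Any] = {"query": query, "numResults": num_results}
    for snake_name, camel_name in FIELD_MAPPING.items():
        value = optional.get(snake_name)
        if not value:
            continue  # only include non-empty values, as the block does
        if isinstance(value, datetime):
            value = value.strftime("%Y-%m-%dT%H:%M:%S.000Z")
        payload[camel_name] = value
    return payload


if __name__ == "__main__":
    print(
        build_payload(
            "autonomous agents",
            10,
            include_domains=["example.com"],
            start_published_date=datetime(2024, 1, 1),
        )
    )
# ------------------------------------------------------------------------------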
@@ -1,30 +1,23 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from exa_py import AsyncExa
|
||||
from typing import Any
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
from .helpers import (
|
||||
ContentSettings,
|
||||
CostDollars,
|
||||
ExaSearchResults,
|
||||
process_contents_settings,
|
||||
)
|
||||
from .helpers import ContentSettings
|
||||
|
||||
|
||||
class ExaFindSimilarBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
@@ -35,7 +28,7 @@ class ExaFindSimilarBlock(Block):
|
||||
description="Number of results to return", default=10, advanced=True
|
||||
)
|
||||
include_domains: list[str] = SchemaField(
|
||||
description="List of domains to include in the search. If specified, results will only come from these domains.",
|
||||
description="Domains to include in search",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
@@ -44,17 +37,17 @@ class ExaFindSimilarBlock(Block):
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
start_crawl_date: Optional[datetime] = SchemaField(
|
||||
description="Start date for crawled content", advanced=True, default=None
|
||||
start_crawl_date: datetime = SchemaField(
|
||||
description="Start date for crawled content"
|
||||
)
|
||||
end_crawl_date: Optional[datetime] = SchemaField(
|
||||
description="End date for crawled content", advanced=True, default=None
|
||||
end_crawl_date: datetime = SchemaField(
|
||||
description="End date for crawled content"
|
||||
)
|
||||
start_published_date: Optional[datetime] = SchemaField(
|
||||
description="Start date for published content", advanced=True, default=None
|
||||
start_published_date: datetime = SchemaField(
|
||||
description="Start date for published content"
|
||||
)
|
||||
end_published_date: Optional[datetime] = SchemaField(
|
||||
description="End date for published content", advanced=True, default=None
|
||||
end_published_date: datetime = SchemaField(
|
||||
description="End date for published content"
|
||||
)
|
||||
include_text: list[str] = SchemaField(
|
||||
description="Text patterns to include (max 1 string, up to 5 words)",
|
||||
@@ -71,27 +64,15 @@ class ExaFindSimilarBlock(Block):
|
||||
default=ContentSettings(),
|
||||
advanced=True,
|
||||
)
|
||||
moderation: bool = SchemaField(
|
||||
description="Enable content moderation to filter unsafe content from search results",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
results: list[ExaSearchResults] = SchemaField(
|
||||
description="List of similar documents with metadata and content"
|
||||
class Output(BlockSchema):
|
||||
results: list[Any] = SchemaField(
|
||||
description="List of similar documents with title, URL, published date, author, and score",
|
||||
default_factory=list,
|
||||
)
|
||||
result: ExaSearchResults = SchemaField(
|
||||
description="Single similar document result"
|
||||
error: str = SchemaField(
|
||||
description="Error message if the request failed", default=""
|
||||
)
|
||||
context: str = SchemaField(
|
||||
description="A formatted string of the results ready for LLMs."
|
||||
)
|
||||
request_id: str = SchemaField(description="Unique identifier for the request")
|
||||
cost_dollars: Optional[CostDollars] = SchemaField(
|
||||
description="Cost breakdown for the request"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -105,65 +86,47 @@ class ExaFindSimilarBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
sdk_kwargs = {
|
||||
"url": input_data.url,
|
||||
"num_results": input_data.number_of_results,
|
||||
url = "https://api.exa.ai/findSimilar"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": credentials.api_key.get_secret_value(),
|
||||
}
|
||||
|
||||
# Handle domains
|
||||
if input_data.include_domains:
|
||||
sdk_kwargs["include_domains"] = input_data.include_domains
|
||||
if input_data.exclude_domains:
|
||||
sdk_kwargs["exclude_domains"] = input_data.exclude_domains
|
||||
payload = {
|
||||
"url": input_data.url,
|
||||
"numResults": input_data.number_of_results,
|
||||
"contents": input_data.contents.model_dump(),
|
||||
}
|
||||
|
||||
# Handle dates
|
||||
if input_data.start_crawl_date:
|
||||
sdk_kwargs["start_crawl_date"] = input_data.start_crawl_date.isoformat()
|
||||
if input_data.end_crawl_date:
|
||||
sdk_kwargs["end_crawl_date"] = input_data.end_crawl_date.isoformat()
|
||||
if input_data.start_published_date:
|
||||
sdk_kwargs["start_published_date"] = (
|
||||
input_data.start_published_date.isoformat()
|
||||
)
|
||||
if input_data.end_published_date:
|
||||
sdk_kwargs["end_published_date"] = input_data.end_published_date.isoformat()
|
||||
optional_field_mapping = {
|
||||
"include_domains": "includeDomains",
|
||||
"exclude_domains": "excludeDomains",
|
||||
"include_text": "includeText",
|
||||
"exclude_text": "excludeText",
|
||||
}
|
||||
|
||||
# Handle text filters
|
||||
if input_data.include_text:
|
||||
sdk_kwargs["include_text"] = input_data.include_text
|
||||
if input_data.exclude_text:
|
||||
sdk_kwargs["exclude_text"] = input_data.exclude_text
|
||||
# Add optional fields if they have values
|
||||
for input_field, api_field in optional_field_mapping.items():
|
||||
value = getattr(input_data, input_field)
|
||||
if value: # Only add non-empty values
|
||||
payload[api_field] = value
|
||||
|
||||
if input_data.moderation:
|
||||
sdk_kwargs["moderation"] = input_data.moderation
|
||||
date_field_mapping = {
|
||||
"start_crawl_date": "startCrawlDate",
|
||||
"end_crawl_date": "endCrawlDate",
|
||||
"start_published_date": "startPublishedDate",
|
||||
"end_published_date": "endPublishedDate",
|
||||
}
|
||||
|
||||
# check if we need to use find_similar_and_contents
|
||||
content_settings = process_contents_settings(input_data.contents)
|
||||
# Add dates if they exist
|
||||
for input_field, api_field in date_field_mapping.items():
|
||||
value = getattr(input_data, input_field, None)
|
||||
if value:
|
||||
payload[api_field] = value.strftime("%Y-%m-%dT%H:%M:%S.000Z")
|
||||
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
if content_settings:
|
||||
# Use find_similar_and_contents when contents are requested
|
||||
sdk_kwargs["text"] = content_settings.get("text", False)
|
||||
if "highlights" in content_settings:
|
||||
sdk_kwargs["highlights"] = content_settings["highlights"]
|
||||
if "summary" in content_settings:
|
||||
sdk_kwargs["summary"] = content_settings["summary"]
|
||||
response = await aexa.find_similar_and_contents(**sdk_kwargs)
|
||||
else:
|
||||
response = await aexa.find_similar(**sdk_kwargs)
|
||||
|
||||
converted_results = [
|
||||
ExaSearchResults.from_sdk(sdk_result)
|
||||
for sdk_result in response.results or []
|
||||
]
|
||||
|
||||
yield "results", converted_results
|
||||
for result in converted_results:
|
||||
yield "result", result
|
||||
|
||||
if response.context:
|
||||
yield "context", response.context
|
||||
|
||||
if response.cost_dollars:
|
||||
yield "cost_dollars", response.cost_dollars
|
||||
try:
|
||||
response = await Requests().post(url, headers=headers, json=payload)
|
||||
data = response.json()
|
||||
yield "results", data.get("results", [])
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
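# --- Illustrative sketch (not part of the diff above) ------------------------
# The SDK-based variant of the find-similar block picks between a plain lookup
# and a "with contents" lookup depending on whether any content settings were
# requested. The dispatch below shows that shape, with hypothetical callables
# standing in for the two SDK methods:
from typing import Any, Callable, Dict, Optional


def run_find_similar(
    find_similar: Callable[..., Any],
    find_similar_and_contents: Callable[..., Any],
    url: str,
    num_results: int,
    content_settings: Optional[Dict[str, Any]] = None,
) -> Any:
    kwargs: Dict[str, Any] = {"url": url, "num_results": num_results}
    if not content_settings:
        return find_similar(**kwargs)
    kwargs["text"] = content_settings.get("text", False)
    for key in ("highlights", "summary"):
        if key in content_settings:
            kwargs[key] = content_settings[key]
    return find_similar_and_contents(**kwargs)


if __name__ == "__main__":
    def plain(**kw: Any) -> Any:
        return "plain", kw

    def with_contents(**kw: Any) -> Any:
        return "with_contents", kw

    print(run_find_similar(plain, with_contents, "https://example.com", 5))
    print(run_find_similar(plain, with_contents, "https://example.com", 5, {"text": True}))
# ------------------------------------------------------------------------------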
@@ -9,8 +9,7 @@ from backend.sdk (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
BlockSchema,
BlockType,
BlockWebhookConfig,
CredentialsMetaInput,
@@ -85,7 +84,7 @@ class ExaWebsetWebhookBlock(Block):
including creation, updates, searches, and exports.
"""

class Input(BlockSchemaInput):
class Input(BlockSchema):
credentials: CredentialsMetaInput = exa.credentials_field(
description="Exa API credentials for webhook management"
)
@@ -105,7 +104,7 @@ class ExaWebsetWebhookBlock(Block):
description="Webhook payload data", default={}, hidden=True
)

class Output(BlockSchemaOutput):
class Output(BlockSchema):
event_type: str = SchemaField(description="Type of event that occurred")
event_id: str = SchemaField(description="Unique identifier for this event")
webset_id: str = SchemaField(description="ID of the affected webset")
@@ -132,33 +131,45 @@ class ExaWebsetWebhookBlock(Block):

async def run(self, input_data: Input, **kwargs) -> BlockOutput:
"""Process incoming Exa webhook payload."""
payload = input_data.payload
try:
payload = input_data.payload

# Extract event details
event_type = payload.get("eventType", "unknown")
event_id = payload.get("eventId", "")
# Extract event details
event_type = payload.get("eventType", "unknown")
event_id = payload.get("eventId", "")

# Get webset ID from payload or input
webset_id = payload.get("websetId", input_data.webset_id)
# Get webset ID from payload or input
webset_id = payload.get("websetId", input_data.webset_id)

# Check if we should process this event based on filter
should_process = self._should_process_event(event_type, input_data.event_filter)
# Check if we should process this event based on filter
should_process = self._should_process_event(
event_type, input_data.event_filter
)

if not should_process:
# Skip events that don't match our filter
return
if not should_process:
# Skip events that don't match our filter
return

# Extract event data
event_data = payload.get("data", {})
timestamp = payload.get("occurredAt", payload.get("createdAt", ""))
metadata = payload.get("metadata", {})
# Extract event data
event_data = payload.get("data", {})
timestamp = payload.get("occurredAt", payload.get("createdAt", ""))
metadata = payload.get("metadata", {})

yield "event_type", event_type
yield "event_id", event_id
yield "webset_id", webset_id
yield "data", event_data
yield "timestamp", timestamp
yield "metadata", metadata
yield "event_type", event_type
yield "event_id", event_id
yield "webset_id", webset_id
yield "data", event_data
yield "timestamp", timestamp
yield "metadata", metadata

except Exception as e:
# Handle errors gracefully
yield "event_type", "error"
yield "event_id", ""
yield "webset_id", input_data.webset_id
yield "data", {"error": str(e)}
yield "timestamp", ""
yield "metadata", {}

def _should_process_event(
self, event_type: str, event_filter: WebsetEventFilter
File diff suppressed because it is too large
@@ -1,554 +0,0 @@
|
||||
"""
|
||||
Exa Websets Enrichment Management Blocks
|
||||
|
||||
This module provides blocks for creating and managing enrichments on webset items,
|
||||
allowing extraction of additional structured data from existing items.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from exa_py import AsyncExa
|
||||
from exa_py.websets.types import WebsetEnrichment as SdkWebsetEnrichment
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
|
||||
|
||||
# Mirrored model for stability
|
||||
class WebsetEnrichmentModel(BaseModel):
|
||||
"""Stable output model mirroring SDK WebsetEnrichment."""
|
||||
|
||||
id: str
|
||||
webset_id: str
|
||||
status: str
|
||||
title: Optional[str]
|
||||
description: str
|
||||
format: str
|
||||
options: List[str]
|
||||
instructions: Optional[str]
|
||||
metadata: Dict[str, Any]
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
@classmethod
|
||||
def from_sdk(cls, enrichment: SdkWebsetEnrichment) -> "WebsetEnrichmentModel":
|
||||
"""Convert SDK WebsetEnrichment to our stable model."""
|
||||
# Extract options
|
||||
options_list = []
|
||||
if enrichment.options:
|
||||
for option in enrichment.options:
|
||||
option_dict = option.model_dump(by_alias=True)
|
||||
options_list.append(option_dict.get("label", ""))
|
||||
|
||||
return cls(
|
||||
id=enrichment.id,
|
||||
webset_id=enrichment.webset_id,
|
||||
status=(
|
||||
enrichment.status.value
|
||||
if hasattr(enrichment.status, "value")
|
||||
else str(enrichment.status)
|
||||
),
|
||||
title=enrichment.title,
|
||||
description=enrichment.description,
|
||||
format=(
|
||||
enrichment.format.value
|
||||
if enrichment.format and hasattr(enrichment.format, "value")
|
||||
else "text"
|
||||
),
|
||||
options=options_list,
|
||||
instructions=enrichment.instructions,
|
||||
metadata=enrichment.metadata if enrichment.metadata else {},
|
||||
created_at=(
|
||||
enrichment.created_at.isoformat() if enrichment.created_at else ""
|
||||
),
|
||||
updated_at=(
|
||||
enrichment.updated_at.isoformat() if enrichment.updated_at else ""
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class EnrichmentFormat(str, Enum):
|
||||
"""Format types for enrichment responses."""
|
||||
|
||||
TEXT = "text" # Free text response
|
||||
DATE = "date" # Date/datetime format
|
||||
NUMBER = "number" # Numeric value
|
||||
OPTIONS = "options" # Multiple choice from provided options
|
||||
EMAIL = "email" # Email address format
|
||||
PHONE = "phone" # Phone number format
|
||||
|
||||
|
||||
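# --- Illustrative sketch (not part of the module listing) --------------------
# When an enrichment uses the "options" format, the create block below sends
# each choice as an object with a "label" key. A small helper showing that
# payload shape (the example values are illustrative):
from typing import Any, Dict, List, Optional


def build_enrichment_params(
    description: str,
    format: str = "text",
    options: Optional[List[str]] = None,
    title: Optional[str] = None,
) -> Dict[str, Any]:
    params: Dict[str, Any] = {"description": description, "format": format}
    if title:
        params["title"] = title
    if format == "options" and options:
        params["options"] = [{"label": opt} for opt in options]
    return params


if __name__ == "__main__":
    print(
        build_enrichment_params(
            "Business model of the company",
            format="options",
            options=["B2B", "B2C", "Both", "Unknown"],
            title="Business Model",
        )
    )
# ------------------------------------------------------------------------------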
class ExaCreateEnrichmentBlock(Block):
|
||||
"""Create a new enrichment to extract additional data from webset items."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
webset_id: str = SchemaField(
|
||||
description="The ID or external ID of the Webset",
|
||||
placeholder="webset-id-or-external-id",
|
||||
)
|
||||
description: str = SchemaField(
|
||||
description="What data to extract from each item",
|
||||
placeholder="Extract the company's main product or service offering",
|
||||
)
|
||||
title: Optional[str] = SchemaField(
|
||||
default=None,
|
||||
description="Short title for this enrichment (auto-generated if not provided)",
|
||||
placeholder="Main Product",
|
||||
)
|
||||
format: EnrichmentFormat = SchemaField(
|
||||
default=EnrichmentFormat.TEXT,
|
||||
description="Expected format of the extracted data",
|
||||
)
|
||||
options: list[str] = SchemaField(
|
||||
default_factory=list,
|
||||
description="Available options when format is 'options'",
|
||||
placeholder='["B2B", "B2C", "Both", "Unknown"]',
|
||||
advanced=True,
|
||||
)
|
||||
apply_to_existing: bool = SchemaField(
|
||||
default=True,
|
||||
description="Apply this enrichment to existing items in the webset",
|
||||
)
|
||||
metadata: Optional[dict] = SchemaField(
|
||||
default=None,
|
||||
description="Metadata to attach to the enrichment",
|
||||
advanced=True,
|
||||
)
|
||||
wait_for_completion: bool = SchemaField(
|
||||
default=False,
|
||||
description="Wait for the enrichment to complete on existing items",
|
||||
)
|
||||
polling_timeout: int = SchemaField(
|
||||
default=300,
|
||||
description="Maximum time to wait for completion in seconds",
|
||||
advanced=True,
|
||||
ge=1,
|
||||
le=600,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
enrichment_id: str = SchemaField(
|
||||
description="The unique identifier for the created enrichment"
|
||||
)
|
||||
webset_id: str = SchemaField(
|
||||
description="The webset this enrichment belongs to"
|
||||
)
|
||||
status: str = SchemaField(description="Current status of the enrichment")
|
||||
title: str = SchemaField(description="Title of the enrichment")
|
||||
description: str = SchemaField(
|
||||
description="Description of what data is extracted"
|
||||
)
|
||||
format: str = SchemaField(description="Format of the extracted data")
|
||||
instructions: str = SchemaField(
|
||||
description="Generated instructions for the enrichment"
|
||||
)
|
||||
items_enriched: Optional[int] = SchemaField(
|
||||
description="Number of items enriched (if wait_for_completion was True)"
|
||||
)
|
||||
completion_time: Optional[float] = SchemaField(
|
||||
description="Time taken to complete in seconds (if wait_for_completion was True)"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="71146ae8-0cb1-4a15-8cde-eae30de71cb6",
|
||||
description="Create enrichments to extract additional structured data from webset items",
|
||||
categories={BlockCategory.AI, BlockCategory.SEARCH},
|
||||
input_schema=ExaCreateEnrichmentBlock.Input,
|
||||
output_schema=ExaCreateEnrichmentBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
import time
|
||||
|
||||
# Build the payload
|
||||
payload: dict[str, Any] = {
|
||||
"description": input_data.description,
|
||||
"format": input_data.format.value,
|
||||
}
|
||||
|
||||
# Add title if provided
|
||||
if input_data.title:
|
||||
payload["title"] = input_data.title
|
||||
|
||||
# Add options for 'options' format
|
||||
if input_data.format == EnrichmentFormat.OPTIONS and input_data.options:
|
||||
payload["options"] = [{"label": opt} for opt in input_data.options]
|
||||
|
||||
# Add metadata if provided
|
||||
if input_data.metadata:
|
||||
payload["metadata"] = input_data.metadata
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_enrichment = aexa.websets.enrichments.create(
|
||||
webset_id=input_data.webset_id, params=payload
|
||||
)
|
||||
|
||||
enrichment_id = sdk_enrichment.id
|
||||
status = (
|
||||
sdk_enrichment.status.value
|
||||
if hasattr(sdk_enrichment.status, "value")
|
||||
else str(sdk_enrichment.status)
|
||||
)
|
||||
|
||||
# If wait_for_completion is True and apply_to_existing is True, poll for completion
|
||||
if input_data.wait_for_completion and input_data.apply_to_existing:
|
||||
import asyncio
|
||||
|
||||
poll_interval = 5
|
||||
max_interval = 30
|
||||
poll_start = time.time()
|
||||
items_enriched = 0
|
||||
|
||||
while time.time() - poll_start < input_data.polling_timeout:
|
||||
current_enrich = aexa.websets.enrichments.get(
|
||||
webset_id=input_data.webset_id, id=enrichment_id
|
||||
)
|
||||
current_status = (
|
||||
current_enrich.status.value
|
||||
if hasattr(current_enrich.status, "value")
|
||||
else str(current_enrich.status)
|
||||
)
|
||||
|
||||
if current_status in ["completed", "failed", "cancelled"]:
|
||||
# Estimate items from webset searches
|
||||
webset = aexa.websets.get(id=input_data.webset_id)
|
||||
if webset.searches:
|
||||
for search in webset.searches:
|
||||
if search.progress:
|
||||
items_enriched += search.progress.found
|
||||
completion_time = time.time() - start_time
|
||||
|
||||
yield "enrichment_id", enrichment_id
|
||||
yield "webset_id", input_data.webset_id
|
||||
yield "status", current_status
|
||||
yield "title", sdk_enrichment.title
|
||||
yield "description", input_data.description
|
||||
yield "format", input_data.format.value
|
||||
yield "instructions", sdk_enrichment.instructions
|
||||
yield "items_enriched", items_enriched
|
||||
yield "completion_time", completion_time
|
||||
return
|
||||
|
||||
await asyncio.sleep(poll_interval)
|
||||
poll_interval = min(poll_interval * 1.5, max_interval)
|
||||
|
||||
# Timeout
|
||||
completion_time = time.time() - start_time
|
||||
yield "enrichment_id", enrichment_id
|
||||
yield "webset_id", input_data.webset_id
|
||||
yield "status", status
|
||||
yield "title", sdk_enrichment.title
|
||||
yield "description", input_data.description
|
||||
yield "format", input_data.format.value
|
||||
yield "instructions", sdk_enrichment.instructions
|
||||
yield "items_enriched", 0
|
||||
yield "completion_time", completion_time
|
||||
else:
|
||||
yield "enrichment_id", enrichment_id
|
||||
yield "webset_id", input_data.webset_id
|
||||
yield "status", status
|
||||
yield "title", sdk_enrichment.title
|
||||
yield "description", input_data.description
|
||||
yield "format", input_data.format.value
|
||||
yield "instructions", sdk_enrichment.instructions
|
||||
|
||||
|
||||
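# --- Illustrative sketch (not part of the module listing) --------------------
# The create-enrichment block above polls with a growing interval
# (interval * 1.5, capped at a maximum) until the enrichment reaches a
# terminal status or the polling timeout expires. The generator below
# isolates just that backoff schedule:
import itertools
from typing import Iterator


def backoff_intervals(
    initial: float = 5.0, factor: float = 1.5, maximum: float = 30.0
) -> Iterator[float]:
    """Yield sleep intervals: initial, initial*factor, ... capped at maximum."""
    interval = initial
    while True:
        yield interval
        interval = min(interval * factor, maximum)


if __name__ == "__main__":
    # First six sleep durations a poller would use: 5.0, 7.5, 11.25, ...
    print(list(itertools.islice(backoff_intervals(), 6)))
# ------------------------------------------------------------------------------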
class ExaGetEnrichmentBlock(Block):
|
||||
"""Get the status and details of a webset enrichment."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
webset_id: str = SchemaField(
|
||||
description="The ID or external ID of the Webset",
|
||||
placeholder="webset-id-or-external-id",
|
||||
)
|
||||
enrichment_id: str = SchemaField(
|
||||
description="The ID of the enrichment to retrieve",
|
||||
placeholder="enrichment-id",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
enrichment_id: str = SchemaField(
|
||||
description="The unique identifier for the enrichment"
|
||||
)
|
||||
status: str = SchemaField(description="Current status of the enrichment")
|
||||
title: str = SchemaField(description="Title of the enrichment")
|
||||
description: str = SchemaField(
|
||||
description="Description of what data is extracted"
|
||||
)
|
||||
format: str = SchemaField(description="Format of the extracted data")
|
||||
options: list[str] = SchemaField(
|
||||
description="Available options (for 'options' format)"
|
||||
)
|
||||
instructions: str = SchemaField(
|
||||
description="Generated instructions for the enrichment"
|
||||
)
|
||||
created_at: str = SchemaField(description="When the enrichment was created")
|
||||
updated_at: str = SchemaField(
|
||||
description="When the enrichment was last updated"
|
||||
)
|
||||
metadata: dict = SchemaField(description="Metadata attached to the enrichment")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b8c9d0e1-f2a3-4567-89ab-cdef01234567",
|
||||
description="Get the status and details of a webset enrichment",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaGetEnrichmentBlock.Input,
|
||||
output_schema=ExaGetEnrichmentBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_enrichment = aexa.websets.enrichments.get(
|
||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||
)
|
||||
|
||||
enrichment = WebsetEnrichmentModel.from_sdk(sdk_enrichment)
|
||||
|
||||
yield "enrichment_id", enrichment.id
|
||||
yield "status", enrichment.status
|
||||
yield "title", enrichment.title
|
||||
yield "description", enrichment.description
|
||||
yield "format", enrichment.format
|
||||
yield "options", enrichment.options
|
||||
yield "instructions", enrichment.instructions
|
||||
yield "created_at", enrichment.created_at
|
||||
yield "updated_at", enrichment.updated_at
|
||||
yield "metadata", enrichment.metadata
|
||||
|
||||
|
||||
class ExaUpdateEnrichmentBlock(Block):
|
||||
"""Update an existing enrichment configuration."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
webset_id: str = SchemaField(
|
||||
description="The ID or external ID of the Webset",
|
||||
placeholder="webset-id-or-external-id",
|
||||
)
|
||||
enrichment_id: str = SchemaField(
|
||||
description="The ID of the enrichment to update",
|
||||
placeholder="enrichment-id",
|
||||
)
|
||||
description: Optional[str] = SchemaField(
|
||||
default=None,
|
||||
description="New description for what data to extract",
|
||||
)
|
||||
format: Optional[EnrichmentFormat] = SchemaField(
|
||||
default=None,
|
||||
description="New format for the extracted data",
|
||||
)
|
||||
options: Optional[list[str]] = SchemaField(
|
||||
default=None,
|
||||
description="New options when format is 'options'",
|
||||
)
|
||||
metadata: Optional[dict] = SchemaField(
|
||||
default=None,
|
||||
description="New metadata to attach to the enrichment",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
enrichment_id: str = SchemaField(
|
||||
description="The unique identifier for the enrichment"
|
||||
)
|
||||
status: str = SchemaField(description="Current status of the enrichment")
|
||||
title: str = SchemaField(description="Title of the enrichment")
|
||||
description: str = SchemaField(description="Updated description")
|
||||
format: str = SchemaField(description="Updated format")
|
||||
success: str = SchemaField(description="Whether the update was successful")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c8d5c5fb-9684-4a29-bd2a-5b38d71776c9",
|
||||
description="Update an existing enrichment configuration",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaUpdateEnrichmentBlock.Input,
|
||||
output_schema=ExaUpdateEnrichmentBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
url = f"https://api.exa.ai/websets/v0/websets/{input_data.webset_id}/enrichments/{input_data.enrichment_id}"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": credentials.api_key.get_secret_value(),
|
||||
}
|
||||
|
||||
# Build the update payload
|
||||
payload = {}
|
||||
|
||||
if input_data.description is not None:
|
||||
payload["description"] = input_data.description
|
||||
|
||||
if input_data.format is not None:
|
||||
payload["format"] = input_data.format.value
|
||||
|
||||
if input_data.options is not None:
|
||||
payload["options"] = [{"label": opt} for opt in input_data.options]
|
||||
|
||||
if input_data.metadata is not None:
|
||||
payload["metadata"] = input_data.metadata
|
||||
|
||||
try:
|
||||
response = await Requests().patch(url, headers=headers, json=payload)
|
||||
data = response.json()
|
||||
|
||||
yield "enrichment_id", data.get("id", "")
|
||||
yield "status", data.get("status", "")
|
||||
yield "title", data.get("title", "")
|
||||
yield "description", data.get("description", "")
|
||||
yield "format", data.get("format", "")
|
||||
yield "success", "true"
|
||||
|
||||
except ValueError as e:
|
||||
# Re-raise user input validation errors
|
||||
raise ValueError(f"Failed to update enrichment: {e}") from e
|
||||
# Let all other exceptions propagate naturally
|
||||
|
||||
|
||||
class ExaDeleteEnrichmentBlock(Block):
|
||||
"""Delete an enrichment from a webset."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
webset_id: str = SchemaField(
|
||||
description="The ID or external ID of the Webset",
|
||||
placeholder="webset-id-or-external-id",
|
||||
)
|
||||
enrichment_id: str = SchemaField(
|
||||
description="The ID of the enrichment to delete",
|
||||
placeholder="enrichment-id",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
enrichment_id: str = SchemaField(description="The ID of the deleted enrichment")
|
||||
success: str = SchemaField(description="Whether the deletion was successful")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b250de56-2ca6-4237-a7b8-b5684892189f",
|
||||
description="Delete an enrichment from a webset",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaDeleteEnrichmentBlock.Input,
|
||||
output_schema=ExaDeleteEnrichmentBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
deleted_enrichment = aexa.websets.enrichments.delete(
|
||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||
)
|
||||
|
||||
yield "enrichment_id", deleted_enrichment.id
|
||||
yield "success", "true"
|
||||
|
||||
|
||||
class ExaCancelEnrichmentBlock(Block):
|
||||
"""Cancel a running enrichment operation."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
webset_id: str = SchemaField(
|
||||
description="The ID or external ID of the Webset",
|
||||
placeholder="webset-id-or-external-id",
|
||||
)
|
||||
enrichment_id: str = SchemaField(
|
||||
description="The ID of the enrichment to cancel",
|
||||
placeholder="enrichment-id",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
enrichment_id: str = SchemaField(
|
||||
description="The ID of the canceled enrichment"
|
||||
)
|
||||
status: str = SchemaField(description="Status after cancellation")
|
||||
items_enriched_before_cancel: int = SchemaField(
|
||||
description="Approximate number of items enriched before cancellation"
|
||||
)
|
||||
success: str = SchemaField(
|
||||
description="Whether the cancellation was successful"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="7e1f8f0f-b6ab-43b3-bd1d-0c534a649295",
|
||||
description="Cancel a running enrichment operation",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaCancelEnrichmentBlock.Input,
|
||||
output_schema=ExaCancelEnrichmentBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
canceled_enrichment = aexa.websets.enrichments.cancel(
|
||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||
)
|
||||
|
||||
# Try to estimate how many items were enriched before cancellation
|
||||
items_enriched = 0
|
||||
items_response = aexa.websets.items.list(
|
||||
webset_id=input_data.webset_id, limit=100
|
||||
)
|
||||
|
||||
for sdk_item in items_response.data:
|
||||
# Check if this enrichment is present
|
||||
for enrich_result in sdk_item.enrichments:
|
||||
if enrich_result.enrichment_id == input_data.enrichment_id:
|
||||
items_enriched += 1
|
||||
break
|
||||
|
||||
status = (
|
||||
canceled_enrichment.status.value
|
||||
if hasattr(canceled_enrichment.status, "value")
|
||||
else str(canceled_enrichment.status)
|
||||
)
|
||||
|
||||
yield "enrichment_id", canceled_enrichment.id
|
||||
yield "status", status
|
||||
yield "items_enriched_before_cancel", items_enriched
|
||||
yield "success", "true"
|
||||
@@ -1,676 +0,0 @@
|
||||
"""
|
||||
Exa Websets Import/Export Management Blocks
|
||||
|
||||
This module provides blocks for importing data into websets from CSV files
|
||||
and exporting webset data in various formats.
|
||||
"""
|
||||
|
||||
import csv
|
||||
import json
|
||||
from enum import Enum
|
||||
from io import StringIO
|
||||
from typing import Optional, Union
|
||||
|
||||
from exa_py import AsyncExa
|
||||
from exa_py.websets.types import CreateImportResponse
|
||||
from exa_py.websets.types import Import as SdkImport
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
from ._test import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
|
||||
|
||||
|
||||
# Mirrored model for stability - don't use SDK types directly in block outputs
|
||||
class ImportModel(BaseModel):
|
||||
"""Stable output model mirroring SDK Import."""
|
||||
|
||||
id: str
|
||||
status: str
|
||||
title: str
|
||||
format: str
|
||||
entity_type: str
|
||||
count: int
|
||||
upload_url: Optional[str] # Only in CreateImportResponse
|
||||
upload_valid_until: Optional[str] # Only in CreateImportResponse
|
||||
failed_reason: str
|
||||
failed_message: str
|
||||
metadata: dict
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
@classmethod
|
||||
def from_sdk(
|
||||
cls, import_obj: Union[SdkImport, CreateImportResponse]
|
||||
) -> "ImportModel":
|
||||
"""Convert SDK Import or CreateImportResponse to our stable model."""
|
||||
# Extract entity type from union (may be None)
|
||||
entity_type = "unknown"
|
||||
if import_obj.entity:
|
||||
entity_dict = import_obj.entity.model_dump(by_alias=True, exclude_none=True)
|
||||
entity_type = entity_dict.get("type", "unknown")
|
||||
|
||||
# Handle status enum
|
||||
status_str = (
|
||||
import_obj.status.value
|
||||
if hasattr(import_obj.status, "value")
|
||||
else str(import_obj.status)
|
||||
)
|
||||
|
||||
# Handle format enum
|
||||
format_str = (
|
||||
import_obj.format.value
|
||||
if hasattr(import_obj.format, "value")
|
||||
else str(import_obj.format)
|
||||
)
|
||||
|
||||
# Handle failed_reason enum (may be None or enum)
|
||||
failed_reason_str = ""
|
||||
if import_obj.failed_reason:
|
||||
failed_reason_str = (
|
||||
import_obj.failed_reason.value
|
||||
if hasattr(import_obj.failed_reason, "value")
|
||||
else str(import_obj.failed_reason)
|
||||
)
|
||||
|
||||
return cls(
|
||||
id=import_obj.id,
|
||||
status=status_str,
|
||||
title=import_obj.title or "",
|
||||
format=format_str,
|
||||
entity_type=entity_type,
|
||||
count=int(import_obj.count or 0),
|
||||
upload_url=getattr(
|
||||
import_obj, "upload_url", None
|
||||
), # Only in CreateImportResponse
|
||||
upload_valid_until=getattr(
|
||||
import_obj, "upload_valid_until", None
|
||||
), # Only in CreateImportResponse
|
||||
failed_reason=failed_reason_str,
|
||||
failed_message=import_obj.failed_message or "",
|
||||
metadata=import_obj.metadata or {},
|
||||
created_at=(
|
||||
import_obj.created_at.isoformat() if import_obj.created_at else ""
|
||||
),
|
||||
updated_at=(
|
||||
import_obj.updated_at.isoformat() if import_obj.updated_at else ""
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
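# --- Illustrative sketch (not part of the module listing) --------------------
# ImportModel above mirrors the SDK type into a stable pydantic model so block
# outputs don't change shape when the SDK does, normalizing enum-or-string
# fields along the way. The same pattern in miniature (SdkThing below is a
# hypothetical SDK object):
from enum import Enum
from typing import Any

from pydantic import BaseModel


def enum_to_str(value: Any) -> str:
    """Return the enum's value if present, otherwise the string form."""
    return value.value if hasattr(value, "value") else str(value)


class StableThing(BaseModel):
    id: str
    status: str

    @classmethod
    def from_sdk(cls, sdk_obj: Any) -> "StableThing":
        return cls(id=sdk_obj.id, status=enum_to_str(sdk_obj.status))


if __name__ == "__main__":
    class _Status(Enum):
        PENDING = "pending"

    class _SdkThing:  # hypothetical SDK object
        id = "import-123"
        status = _Status.PENDING

    print(StableThing.from_sdk(_SdkThing()))
# ------------------------------------------------------------------------------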
class ImportFormat(str, Enum):
|
||||
"""Supported import formats."""
|
||||
|
||||
CSV = "csv"
|
||||
# JSON = "json" # Future support
|
||||
|
||||
|
||||
class ImportEntityType(str, Enum):
|
||||
"""Entity types for imports."""
|
||||
|
||||
COMPANY = "company"
|
||||
PERSON = "person"
|
||||
ARTICLE = "article"
|
||||
RESEARCH_PAPER = "research_paper"
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
class ExportFormat(str, Enum):
|
||||
"""Supported export formats."""
|
||||
|
||||
JSON = "json"
|
||||
CSV = "csv"
|
||||
JSON_LINES = "jsonl"
|
||||
|
||||
|
||||
class ExaCreateImportBlock(Block):
|
||||
"""Create an import to load external data that can be used with websets."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
title: str = SchemaField(
|
||||
description="Title for this import",
|
||||
placeholder="Customer List Import",
|
||||
)
|
||||
csv_data: str = SchemaField(
|
||||
description="CSV data to import (as a string)",
|
||||
placeholder="name,url\nAcme Corp,https://acme.com\nExample Inc,https://example.com",
|
||||
)
|
||||
entity_type: ImportEntityType = SchemaField(
|
||||
default=ImportEntityType.COMPANY,
|
||||
description="Type of entities being imported",
|
||||
)
|
||||
entity_description: Optional[str] = SchemaField(
|
||||
default=None,
|
||||
description="Description for custom entity type",
|
||||
advanced=True,
|
||||
)
|
||||
identifier_column: int = SchemaField(
|
||||
default=0,
|
||||
description="Column index containing the identifier (0-based)",
|
||||
ge=0,
|
||||
)
|
||||
url_column: Optional[int] = SchemaField(
|
||||
default=None,
|
||||
description="Column index containing URLs (optional)",
|
||||
ge=0,
|
||||
advanced=True,
|
||||
)
|
||||
metadata: Optional[dict] = SchemaField(
|
||||
default=None,
|
||||
description="Metadata to attach to the import",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
import_id: str = SchemaField(
|
||||
description="The unique identifier for the created import"
|
||||
)
|
||||
status: str = SchemaField(description="Current status of the import")
|
||||
title: str = SchemaField(description="Title of the import")
|
||||
count: int = SchemaField(description="Number of items in the import")
|
||||
entity_type: str = SchemaField(description="Type of entities imported")
|
||||
upload_url: Optional[str] = SchemaField(
|
||||
description="Upload URL for CSV data (only if csv_data not provided in request)"
|
||||
)
|
||||
upload_valid_until: Optional[str] = SchemaField(
|
||||
description="Expiration time for upload URL (only if upload_url is provided)"
|
||||
)
|
||||
created_at: str = SchemaField(description="When the import was created")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="020a35d8-8a53-4e60-8b60-1de5cbab1df3",
|
||||
description="Import CSV data to use with websets for targeted searches",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=ExaCreateImportBlock.Input,
|
||||
output_schema=ExaCreateImportBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"title": "Test Import",
|
||||
"csv_data": "name,url\nAcme,https://acme.com",
|
||||
"entity_type": ImportEntityType.COMPANY,
|
||||
"identifier_column": 0,
|
||||
},
|
||||
test_output=[
|
||||
("import_id", "import-123"),
|
||||
("status", "pending"),
|
||||
("title", "Test Import"),
|
||||
("count", 1),
|
||||
("entity_type", "company"),
|
||||
("upload_url", None),
|
||||
("upload_valid_until", None),
|
||||
("created_at", "2024-01-01T00:00:00"),
|
||||
],
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_mock=self._create_test_mock(),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_test_mock():
|
||||
"""Create test mocks for the AsyncExa SDK."""
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
# Create mock SDK import object
|
||||
mock_import = MagicMock()
|
||||
mock_import.id = "import-123"
|
||||
mock_import.status = MagicMock(value="pending")
|
||||
mock_import.title = "Test Import"
|
||||
mock_import.format = MagicMock(value="csv")
|
||||
mock_import.count = 1
|
||||
mock_import.upload_url = None
|
||||
mock_import.upload_valid_until = None
|
||||
mock_import.failed_reason = None
|
||||
mock_import.failed_message = ""
|
||||
mock_import.metadata = {}
|
||||
mock_import.created_at = datetime.fromisoformat("2024-01-01T00:00:00")
|
||||
mock_import.updated_at = datetime.fromisoformat("2024-01-01T00:00:00")
|
||||
|
||||
# Mock entity
|
||||
mock_entity = MagicMock()
|
||||
mock_entity.model_dump = MagicMock(return_value={"type": "company"})
|
||||
mock_import.entity = mock_entity
|
||||
|
||||
return {
|
||||
"_get_client": lambda *args, **kwargs: MagicMock(
|
||||
websets=MagicMock(
|
||||
imports=MagicMock(create=lambda *args, **kwargs: mock_import)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
def _get_client(self, api_key: str) -> AsyncExa:
|
||||
"""Get Exa client (separated for testing)."""
|
||||
return AsyncExa(api_key=api_key)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
aexa = self._get_client(credentials.api_key.get_secret_value())
|
||||
|
||||
csv_reader = csv.reader(StringIO(input_data.csv_data))
|
||||
rows = list(csv_reader)
|
||||
count = len(rows) - 1 if len(rows) > 1 else 0
|
||||
|
||||
size = len(input_data.csv_data.encode("utf-8"))
|
||||
|
||||
payload = {
|
||||
"title": input_data.title,
|
||||
"format": ImportFormat.CSV.value,
|
||||
"count": count,
|
||||
"size": size,
|
||||
"csv": {
|
||||
"identifier": input_data.identifier_column,
|
||||
},
|
||||
}
|
||||
|
||||
# Add URL column if specified
|
||||
if input_data.url_column is not None:
|
||||
payload["csv"]["url"] = input_data.url_column
|
||||
|
||||
# Add entity configuration
|
||||
entity = {"type": input_data.entity_type.value}
|
||||
if (
|
||||
input_data.entity_type == ImportEntityType.CUSTOM
|
||||
and input_data.entity_description
|
||||
):
|
||||
entity["description"] = input_data.entity_description
|
||||
payload["entity"] = entity
|
||||
|
||||
# Add metadata if provided
|
||||
if input_data.metadata:
|
||||
payload["metadata"] = input_data.metadata
|
||||
|
||||
sdk_import = aexa.websets.imports.create(
|
||||
params=payload, csv_data=input_data.csv_data
|
||||
)
|
||||
|
||||
import_obj = ImportModel.from_sdk(sdk_import)
|
||||
|
||||
yield "import_id", import_obj.id
|
||||
yield "status", import_obj.status
|
||||
yield "title", import_obj.title
|
||||
yield "count", import_obj.count
|
||||
yield "entity_type", import_obj.entity_type
|
||||
yield "upload_url", import_obj.upload_url
|
||||
yield "upload_valid_until", import_obj.upload_valid_until
|
||||
yield "created_at", import_obj.created_at
|
||||
|
||||
|
||||
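# --- Illustrative sketch (not part of the module listing) --------------------
# The create-import block above derives the data-row count (header excluded)
# and the byte size that the import endpoint expects directly from the CSV
# string. The same calculation in isolation:
import csv
from io import StringIO
from typing import Tuple


def csv_count_and_size(csv_data: str) -> Tuple[int, int]:
    rows = list(csv.reader(StringIO(csv_data)))
    count = len(rows) - 1 if len(rows) > 1 else 0  # data rows, header excluded
    size = len(csv_data.encode("utf-8"))  # payload size in bytes
    return count, size


if __name__ == "__main__":
    sample = "name,url\nAcme Corp,https://acme.com\nExample Inc,https://example.com"
    print(csv_count_and_size(sample))  # -> (2, 67)
# ------------------------------------------------------------------------------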
class ExaGetImportBlock(Block):
|
||||
"""Get the status and details of an import."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
import_id: str = SchemaField(
|
||||
description="The ID of the import to retrieve",
|
||||
placeholder="import-id",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
import_id: str = SchemaField(description="The unique identifier for the import")
|
||||
status: str = SchemaField(description="Current status of the import")
|
||||
title: str = SchemaField(description="Title of the import")
|
||||
format: str = SchemaField(description="Format of the imported data")
|
||||
entity_type: str = SchemaField(description="Type of entities imported")
|
||||
count: int = SchemaField(description="Number of items imported")
|
||||
upload_url: Optional[str] = SchemaField(
|
||||
description="Upload URL for CSV data (if import not yet uploaded)"
|
||||
)
|
||||
upload_valid_until: Optional[str] = SchemaField(
|
||||
description="Expiration time for upload URL (if applicable)"
|
||||
)
|
||||
failed_reason: Optional[str] = SchemaField(
|
||||
description="Reason for failure (if applicable)"
|
||||
)
|
||||
failed_message: Optional[str] = SchemaField(
|
||||
description="Detailed failure message (if applicable)"
|
||||
)
|
||||
created_at: str = SchemaField(description="When the import was created")
|
||||
updated_at: str = SchemaField(description="When the import was last updated")
|
||||
metadata: dict = SchemaField(description="Metadata attached to the import")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="236663c8-a8dc-45f7-a050-2676bb0a3dd2",
|
||||
description="Get the status and details of an import",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=ExaGetImportBlock.Input,
|
||||
output_schema=ExaGetImportBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_import = aexa.websets.imports.get(import_id=input_data.import_id)
|
||||
|
||||
import_obj = ImportModel.from_sdk(sdk_import)
|
||||
|
||||
# Yield all fields
|
||||
yield "import_id", import_obj.id
|
||||
yield "status", import_obj.status
|
||||
yield "title", import_obj.title
|
||||
yield "format", import_obj.format
|
||||
yield "entity_type", import_obj.entity_type
|
||||
yield "count", import_obj.count
|
||||
yield "upload_url", import_obj.upload_url
|
||||
yield "upload_valid_until", import_obj.upload_valid_until
|
||||
yield "failed_reason", import_obj.failed_reason
|
||||
yield "failed_message", import_obj.failed_message
|
||||
yield "created_at", import_obj.created_at
|
||||
yield "updated_at", import_obj.updated_at
|
||||
yield "metadata", import_obj.metadata
|
||||
|
||||
|
||||
class ExaListImportsBlock(Block):
|
||||
"""List all imports with pagination."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
default=25,
|
||||
description="Number of imports to return",
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
cursor: Optional[str] = SchemaField(
|
||||
default=None,
|
||||
description="Cursor for pagination",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
imports: list[dict] = SchemaField(description="List of imports")
|
||||
import_item: dict = SchemaField(
|
||||
description="Individual import (yielded for each import)"
|
||||
)
|
||||
has_more: bool = SchemaField(
|
||||
description="Whether there are more imports to paginate through"
|
||||
)
|
||||
next_cursor: Optional[str] = SchemaField(
|
||||
description="Cursor for the next page of results"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="65323630-f7e9-4692-a624-184ba14c0686",
|
||||
description="List all imports with pagination support",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=ExaListImportsBlock.Input,
|
||||
output_schema=ExaListImportsBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
response = aexa.websets.imports.list(
|
||||
cursor=input_data.cursor,
|
||||
limit=input_data.limit,
|
||||
)
|
||||
|
||||
# Convert SDK imports to our stable models
|
||||
imports = [ImportModel.from_sdk(i) for i in response.data]
|
||||
|
||||
yield "imports", [i.model_dump() for i in imports]
|
||||
|
||||
for import_obj in imports:
|
||||
yield "import_item", import_obj.model_dump()
|
||||
|
||||
yield "has_more", response.has_more
|
||||
yield "next_cursor", response.next_cursor
|
||||
|
||||
|
||||
class ExaDeleteImportBlock(Block):
|
||||
"""Delete an import."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
import_id: str = SchemaField(
|
||||
description="The ID of the import to delete",
|
||||
placeholder="import-id",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
import_id: str = SchemaField(description="The ID of the deleted import")
|
||||
success: str = SchemaField(description="Whether the deletion was successful")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="81ae30ed-c7ba-4b5d-8483-b726846e570c",
|
||||
description="Delete an import",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=ExaDeleteImportBlock.Input,
|
||||
output_schema=ExaDeleteImportBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
deleted_import = aexa.websets.imports.delete(import_id=input_data.import_id)
|
||||
|
||||
yield "import_id", deleted_import.id
|
||||
yield "success", "true"
|
||||
|
||||
|
||||
class ExaExportWebsetBlock(Block):
|
||||
"""Export all data from a webset in various formats."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
webset_id: str = SchemaField(
|
||||
description="The ID or external ID of the Webset to export",
|
||||
placeholder="webset-id-or-external-id",
|
||||
)
|
||||
format: ExportFormat = SchemaField(
|
||||
default=ExportFormat.JSON,
|
||||
description="Export format",
|
||||
)
|
||||
include_content: bool = SchemaField(
|
||||
default=True,
|
||||
description="Include full content in export",
|
||||
)
|
||||
include_enrichments: bool = SchemaField(
|
||||
default=True,
|
||||
description="Include enrichment data in export",
|
||||
)
|
||||
max_items: int = SchemaField(
|
||||
default=100,
|
||||
description="Maximum number of items to export",
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
export_data: str = SchemaField(
|
||||
description="Exported data in the requested format"
|
||||
)
|
||||
item_count: int = SchemaField(description="Number of items exported")
|
||||
total_items: int = SchemaField(
|
||||
description="Total number of items in the webset"
|
||||
)
|
||||
truncated: bool = SchemaField(
|
||||
description="Whether the export was truncated due to max_items limit"
|
||||
)
|
||||
format: str = SchemaField(description="Format of the exported data")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="5da9d0fd-4b5b-4318-8302-8f71d0ccce9d",
|
||||
description="Export webset data in JSON, CSV, or JSON Lines format",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=ExaExportWebsetBlock.Input,
|
||||
output_schema=ExaExportWebsetBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"webset_id": "test-webset",
|
||||
"format": ExportFormat.JSON,
|
||||
"include_content": True,
|
||||
"include_enrichments": True,
|
||||
"max_items": 10,
|
||||
},
|
||||
test_output=[
|
||||
("export_data", str),
|
||||
("item_count", 2),
|
||||
("total_items", 2),
|
||||
("truncated", False),
|
||||
("format", "json"),
|
||||
],
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_mock=self._create_test_mock(),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_test_mock():
|
||||
"""Create test mocks for the AsyncExa SDK."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
# Create mock webset items
|
||||
mock_item1 = MagicMock()
|
||||
mock_item1.model_dump = MagicMock(
|
||||
return_value={
|
||||
"id": "item-1",
|
||||
"url": "https://example.com",
|
||||
"title": "Test Item 1",
|
||||
}
|
||||
)
|
||||
|
||||
mock_item2 = MagicMock()
|
||||
mock_item2.model_dump = MagicMock(
|
||||
return_value={
|
||||
"id": "item-2",
|
||||
"url": "https://example.org",
|
||||
"title": "Test Item 2",
|
||||
}
|
||||
)
|
||||
|
||||
# Create mock iterator
|
||||
mock_items = [mock_item1, mock_item2]
|
||||
|
||||
return {
|
||||
"_get_client": lambda *args, **kwargs: MagicMock(
|
||||
websets=MagicMock(
|
||||
items=MagicMock(list_all=lambda *args, **kwargs: iter(mock_items))
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
def _get_client(self, api_key: str) -> AsyncExa:
|
||||
"""Get Exa client (separated for testing)."""
|
||||
return AsyncExa(api_key=api_key)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Use AsyncExa SDK
|
||||
aexa = self._get_client(credentials.api_key.get_secret_value())
|
||||
|
||||
try:
|
||||
all_items = []
|
||||
|
||||
# Use SDK's list_all iterator to fetch items
|
||||
item_iterator = aexa.websets.items.list_all(
|
||||
webset_id=input_data.webset_id, limit=input_data.max_items
|
||||
)
|
||||
|
||||
for sdk_item in item_iterator:
|
||||
if len(all_items) >= input_data.max_items:
|
||||
break
|
||||
|
||||
# Convert to dict for export
|
||||
item_dict = sdk_item.model_dump(by_alias=True, exclude_none=True)
|
||||
all_items.append(item_dict)
|
||||
|
||||
# Calculate total and truncated
|
||||
            # SDK doesn't expose a total count; this only counts the items retrieved
            total_items = len(all_items)
|
||||
truncated = len(all_items) >= input_data.max_items
|
||||
|
||||
# Process items based on include flags
|
||||
if not input_data.include_content:
|
||||
for item in all_items:
|
||||
item.pop("content", None)
|
||||
|
||||
if not input_data.include_enrichments:
|
||||
for item in all_items:
|
||||
item.pop("enrichments", None)
|
||||
|
||||
# Format the export data
|
||||
export_data = ""
|
||||
|
||||
if input_data.format == ExportFormat.JSON:
|
||||
export_data = json.dumps(all_items, indent=2, default=str)
|
||||
|
||||
elif input_data.format == ExportFormat.JSON_LINES:
|
||||
lines = [json.dumps(item, default=str) for item in all_items]
|
||||
export_data = "\n".join(lines)
|
||||
|
||||
elif input_data.format == ExportFormat.CSV:
|
||||
# Extract all unique keys for CSV headers
|
||||
all_keys = set()
|
||||
for item in all_items:
|
||||
all_keys.update(self._flatten_dict(item).keys())
|
||||
|
||||
# Create CSV
|
||||
output = StringIO()
|
||||
writer = csv.DictWriter(output, fieldnames=sorted(all_keys))
|
||||
writer.writeheader()
|
||||
|
||||
for item in all_items:
|
||||
flat_item = self._flatten_dict(item)
|
||||
writer.writerow(flat_item)
|
||||
|
||||
export_data = output.getvalue()
|
||||
|
||||
yield "export_data", export_data
|
||||
yield "item_count", len(all_items)
|
||||
yield "total_items", total_items
|
||||
yield "truncated", truncated
|
||||
yield "format", input_data.format.value
|
||||
|
||||
except ValueError as e:
|
||||
# Re-raise user input validation errors
|
||||
raise ValueError(f"Failed to export webset: {e}") from e
|
||||
# Let all other exceptions propagate naturally
|
||||
|
||||
def _flatten_dict(self, d: dict, parent_key: str = "", sep: str = "_") -> dict:
|
||||
"""Flatten nested dictionaries for CSV export."""
|
||||
items = []
|
||||
for k, v in d.items():
|
||||
new_key = f"{parent_key}{sep}{k}" if parent_key else k
|
||||
if isinstance(v, dict):
|
||||
items.extend(self._flatten_dict(v, new_key, sep=sep).items())
|
||||
elif isinstance(v, list):
|
||||
# Convert lists to JSON strings for CSV
|
||||
items.append((new_key, json.dumps(v, default=str)))
|
||||
else:
|
||||
items.append((new_key, v))
|
||||
return dict(items)
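# Worked example of the flattening used for CSV export above: nested dict keys
# are joined with "_" and list values are serialized to JSON strings so each
# value fits in a single CSV cell (expected output is a sketch of the logic
# above, not captured program output):
#
#   {"id": "item-1", "entity": {"name": "Acme", "tags": ["ai", "saas"]}}
#   -> {"id": "item-1", "entity_name": "Acme", "entity_tags": "[\"ai\", \"saas\"]"}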
|
||||
@@ -1,591 +0,0 @@
"""
Exa Websets Item Management Blocks

This module provides blocks for managing items within Exa websets, including
retrieving, listing, deleting, and bulk operations on webset items.
"""

from typing import Any, Dict, List, Optional

from exa_py import AsyncExa
from exa_py.websets.types import WebsetItem as SdkWebsetItem
from exa_py.websets.types import (
    WebsetItemArticleProperties,
    WebsetItemCompanyProperties,
    WebsetItemCustomProperties,
    WebsetItemPersonProperties,
    WebsetItemResearchPaperProperties,
)
from pydantic import AnyUrl, BaseModel

from backend.sdk import (
    APIKeyCredentials,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
    CredentialsMetaInput,
    SchemaField,
)

from ._config import exa


# Mirrored model for enrichment results
class EnrichmentResultModel(BaseModel):
|
||||
"""Stable output model mirroring SDK EnrichmentResult."""
|
||||
|
||||
enrichment_id: str
|
||||
format: str
|
||||
result: Optional[List[str]]
|
||||
reasoning: Optional[str]
|
||||
references: List[Dict[str, Any]]
|
||||
|
||||
@classmethod
|
||||
def from_sdk(cls, sdk_enrich) -> "EnrichmentResultModel":
|
||||
"""Convert SDK EnrichmentResult to our model."""
|
||||
format_str = (
|
||||
sdk_enrich.format.value
|
||||
if hasattr(sdk_enrich.format, "value")
|
||||
else str(sdk_enrich.format)
|
||||
)
|
||||
|
||||
# Convert references to dicts
|
||||
references_list = []
|
||||
if sdk_enrich.references:
|
||||
for ref in sdk_enrich.references:
|
||||
references_list.append(ref.model_dump(by_alias=True, exclude_none=True))
|
||||
|
||||
return cls(
|
||||
enrichment_id=sdk_enrich.enrichment_id,
|
||||
format=format_str,
|
||||
result=sdk_enrich.result,
|
||||
reasoning=sdk_enrich.reasoning,
|
||||
references=references_list,
|
||||
)
|
||||
|
||||
|
||||
# Mirrored model for stability - don't use SDK types directly in block outputs
|
||||
class WebsetItemModel(BaseModel):
|
||||
"""Stable output model mirroring SDK WebsetItem."""
|
||||
|
||||
id: str
|
||||
url: Optional[AnyUrl]
|
||||
title: str
|
||||
content: str
|
||||
entity_data: Dict[str, Any]
|
||||
enrichments: Dict[str, EnrichmentResultModel]
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
@classmethod
|
||||
def from_sdk(cls, item: SdkWebsetItem) -> "WebsetItemModel":
|
||||
"""Convert SDK WebsetItem to our stable model."""
|
||||
# Extract properties from the union type
|
||||
properties_dict = {}
|
||||
url_value = None
|
||||
title = ""
|
||||
content = ""
|
||||
|
||||
if hasattr(item, "properties") and item.properties:
|
||||
properties_dict = item.properties.model_dump(
|
||||
by_alias=True, exclude_none=True
|
||||
)
|
||||
|
||||
# URL is always available on all property types
|
||||
url_value = item.properties.url
|
||||
|
||||
# Extract title using isinstance checks on the union type
|
||||
if isinstance(item.properties, WebsetItemPersonProperties):
|
||||
title = item.properties.person.name
|
||||
content = "" # Person type has no content
|
||||
elif isinstance(item.properties, WebsetItemCompanyProperties):
|
||||
title = item.properties.company.name
|
||||
content = item.properties.content or ""
|
||||
elif isinstance(item.properties, WebsetItemArticleProperties):
|
||||
title = item.properties.description
|
||||
content = item.properties.content or ""
|
||||
elif isinstance(item.properties, WebsetItemResearchPaperProperties):
|
||||
title = item.properties.description
|
||||
content = item.properties.content or ""
|
||||
elif isinstance(item.properties, WebsetItemCustomProperties):
|
||||
title = item.properties.description
|
||||
content = item.properties.content or ""
|
||||
else:
|
||||
# Fallback
|
||||
title = item.properties.description
|
||||
content = getattr(item.properties, "content", "")
|
||||
|
||||
# Convert enrichments from list to dict keyed by enrichment_id using Pydantic models
|
||||
enrichments_dict: Dict[str, EnrichmentResultModel] = {}
|
||||
if hasattr(item, "enrichments") and item.enrichments:
|
||||
for sdk_enrich in item.enrichments:
|
||||
enrich_model = EnrichmentResultModel.from_sdk(sdk_enrich)
|
||||
enrichments_dict[enrich_model.enrichment_id] = enrich_model
|
||||
|
||||
return cls(
|
||||
id=item.id,
|
||||
url=url_value,
|
||||
title=title,
|
||||
content=content or "",
|
||||
entity_data=properties_dict,
|
||||
enrichments=enrichments_dict,
|
||||
created_at=item.created_at.isoformat() if item.created_at else "",
|
||||
updated_at=item.updated_at.isoformat() if item.updated_at else "",
|
||||
)
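# Minimal sketch of the re-keying done in from_sdk above: the SDK returns
# enrichments as a list, while the block output exposes them as a dict keyed by
# enrichment_id so downstream nodes can address a specific enrichment directly.
#
#   [EnrichmentResultModel(enrichment_id="e1", ...),
#    EnrichmentResultModel(enrichment_id="e2", ...)]
#   -> {"e1": EnrichmentResultModel(...), "e2": EnrichmentResultModel(...)}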
|
||||
|
||||
|
||||
class ExaGetWebsetItemBlock(Block):
|
||||
"""Get a specific item from a webset by its ID."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
webset_id: str = SchemaField(
|
||||
description="The ID or external ID of the Webset",
|
||||
placeholder="webset-id-or-external-id",
|
||||
)
|
||||
item_id: str = SchemaField(
|
||||
description="The ID of the specific item to retrieve",
|
||||
placeholder="item-id",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
item_id: str = SchemaField(description="The unique identifier for the item")
|
||||
url: str = SchemaField(description="The URL of the original source")
|
||||
title: str = SchemaField(description="The title of the item")
|
||||
content: str = SchemaField(description="The main content of the item")
|
||||
entity_data: dict = SchemaField(description="Entity-specific structured data")
|
||||
enrichments: dict = SchemaField(description="Enrichment data added to the item")
|
||||
created_at: str = SchemaField(
|
||||
description="When the item was added to the webset"
|
||||
)
|
||||
updated_at: str = SchemaField(description="When the item was last updated")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c4a7d9e2-8f3b-4a6c-9d8e-a5b6c7d8e9f0",
|
||||
description="Get a specific item from a webset by its ID",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaGetWebsetItemBlock.Input,
|
||||
output_schema=ExaGetWebsetItemBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_item = aexa.websets.items.get(
|
||||
webset_id=input_data.webset_id, id=input_data.item_id
|
||||
)
|
||||
|
||||
item = WebsetItemModel.from_sdk(sdk_item)
|
||||
|
||||
yield "item_id", item.id
|
||||
yield "url", item.url
|
||||
yield "title", item.title
|
||||
yield "content", item.content
|
||||
yield "entity_data", item.entity_data
|
||||
yield "enrichments", item.enrichments
|
||||
yield "created_at", item.created_at
|
||||
yield "updated_at", item.updated_at
|
||||
|
||||
|
||||
class ExaListWebsetItemsBlock(Block):
|
||||
"""List items in a webset with pagination and optional filtering."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
webset_id: str = SchemaField(
|
||||
description="The ID or external ID of the Webset",
|
||||
placeholder="webset-id-or-external-id",
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
default=25,
|
||||
description="Number of items to return (1-100)",
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
cursor: Optional[str] = SchemaField(
|
||||
default=None,
|
||||
description="Cursor for pagination through results",
|
||||
advanced=True,
|
||||
)
|
||||
wait_for_items: bool = SchemaField(
|
||||
default=False,
|
||||
description="Wait for items to be available if webset is still processing",
|
||||
advanced=True,
|
||||
)
|
||||
wait_timeout: int = SchemaField(
|
||||
default=60,
|
||||
description="Maximum time to wait for items in seconds",
|
||||
advanced=True,
|
||||
ge=1,
|
||||
le=300,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
items: list[WebsetItemModel] = SchemaField(
|
||||
description="List of webset items",
|
||||
)
|
||||
webset_id: str = SchemaField(
|
||||
description="The ID of the webset",
|
||||
)
|
||||
item: WebsetItemModel = SchemaField(
|
||||
description="Individual item (yielded for each item in the list)",
|
||||
)
|
||||
has_more: bool = SchemaField(
|
||||
description="Whether there are more items to paginate through",
|
||||
)
|
||||
next_cursor: Optional[str] = SchemaField(
|
||||
description="Cursor for the next page of results",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="7b5e8c9f-01a2-43c4-95e6-f7a8b9c0d1e2",
|
||||
description="List items in a webset with pagination support",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaListWebsetItemsBlock.Input,
|
||||
output_schema=ExaListWebsetItemsBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
if input_data.wait_for_items:
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
interval = 2
|
||||
response = None
|
||||
|
||||
while time.time() - start_time < input_data.wait_timeout:
|
||||
response = aexa.websets.items.list(
|
||||
webset_id=input_data.webset_id,
|
||||
cursor=input_data.cursor,
|
||||
limit=input_data.limit,
|
||||
)
|
||||
|
||||
if response.data:
|
||||
break
|
||||
|
||||
await asyncio.sleep(interval)
|
||||
interval = min(interval * 1.2, 10)
|
||||
|
||||
if not response:
|
||||
response = aexa.websets.items.list(
|
||||
webset_id=input_data.webset_id,
|
||||
cursor=input_data.cursor,
|
||||
limit=input_data.limit,
|
||||
)
|
||||
else:
|
||||
response = aexa.websets.items.list(
|
||||
webset_id=input_data.webset_id,
|
||||
cursor=input_data.cursor,
|
||||
limit=input_data.limit,
|
||||
)
|
||||
|
||||
items = [WebsetItemModel.from_sdk(item) for item in response.data]
|
||||
|
||||
yield "items", items
|
||||
|
||||
for item in items:
|
||||
yield "item", item
|
||||
|
||||
yield "has_more", response.has_more
|
||||
yield "next_cursor", response.next_cursor
|
||||
yield "webset_id", input_data.webset_id
|
||||
|
||||
|
||||
class ExaDeleteWebsetItemBlock(Block):
|
||||
"""Delete a specific item from a webset."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
webset_id: str = SchemaField(
|
||||
description="The ID or external ID of the Webset",
|
||||
placeholder="webset-id-or-external-id",
|
||||
)
|
||||
item_id: str = SchemaField(
|
||||
description="The ID of the item to delete",
|
||||
placeholder="item-id",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
item_id: str = SchemaField(description="The ID of the deleted item")
|
||||
success: str = SchemaField(description="Whether the deletion was successful")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="12c57fbe-c270-4877-a2b6-d2d05529ba79",
|
||||
description="Delete a specific item from a webset",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaDeleteWebsetItemBlock.Input,
|
||||
output_schema=ExaDeleteWebsetItemBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
deleted_item = aexa.websets.items.delete(
|
||||
webset_id=input_data.webset_id, id=input_data.item_id
|
||||
)
|
||||
|
||||
yield "item_id", deleted_item.id
|
||||
yield "success", "true"
|
||||
|
||||
|
||||
class ExaBulkWebsetItemsBlock(Block):
|
||||
"""Get all items from a webset in a single operation (with size limits)."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
webset_id: str = SchemaField(
|
||||
description="The ID or external ID of the Webset",
|
||||
placeholder="webset-id-or-external-id",
|
||||
)
|
||||
max_items: int = SchemaField(
|
||||
default=100,
|
||||
description="Maximum number of items to retrieve (1-1000). Note: Large values may take longer.",
|
||||
ge=1,
|
||||
le=1000,
|
||||
)
|
||||
include_enrichments: bool = SchemaField(
|
||||
default=True,
|
||||
description="Include enrichment data for each item",
|
||||
)
|
||||
include_content: bool = SchemaField(
|
||||
default=True,
|
||||
description="Include full content for each item",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
items: list[WebsetItemModel] = SchemaField(
|
||||
description="All items from the webset"
|
||||
)
|
||||
item: WebsetItemModel = SchemaField(
|
||||
description="Individual item (yielded for each item)"
|
||||
)
|
||||
total_retrieved: int = SchemaField(
|
||||
description="Total number of items retrieved"
|
||||
)
|
||||
truncated: bool = SchemaField(
|
||||
description="Whether results were truncated due to max_items limit"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="dbd619f5-476e-4395-af9a-a7a7c0fb8c4e",
|
||||
description="Get all items from a webset in bulk (with configurable limits)",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaBulkWebsetItemsBlock.Input,
|
||||
output_schema=ExaBulkWebsetItemsBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
all_items: List[WebsetItemModel] = []
|
||||
item_iterator = aexa.websets.items.list_all(
|
||||
webset_id=input_data.webset_id, limit=input_data.max_items
|
||||
)
|
||||
|
||||
for sdk_item in item_iterator:
|
||||
if len(all_items) >= input_data.max_items:
|
||||
break
|
||||
|
||||
item = WebsetItemModel.from_sdk(sdk_item)
|
||||
|
||||
if not input_data.include_enrichments:
|
||||
item.enrichments = {}
|
||||
if not input_data.include_content:
|
||||
item.content = ""
|
||||
|
||||
all_items.append(item)
|
||||
|
||||
yield "items", all_items
|
||||
|
||||
for item in all_items:
|
||||
yield "item", item
|
||||
|
||||
yield "total_retrieved", len(all_items)
|
||||
yield "truncated", len(all_items) >= input_data.max_items
|
||||
|
||||
|
||||
class ExaWebsetItemsSummaryBlock(Block):
|
||||
"""Get a summary of items in a webset without retrieving all data."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
webset_id: str = SchemaField(
|
||||
description="The ID or external ID of the Webset",
|
||||
placeholder="webset-id-or-external-id",
|
||||
)
|
||||
sample_size: int = SchemaField(
|
||||
default=5,
|
||||
description="Number of sample items to include",
|
||||
ge=0,
|
||||
le=10,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
total_items: int = SchemaField(
|
||||
description="Total number of items in the webset"
|
||||
)
|
||||
entity_type: str = SchemaField(description="Type of entities in the webset")
|
||||
sample_items: list[WebsetItemModel] = SchemaField(
|
||||
description="Sample of items from the webset"
|
||||
)
|
||||
enrichment_columns: list[str] = SchemaField(
|
||||
description="List of enrichment columns available"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="db7813ad-10bd-4652-8623-5667d6fecdd5",
|
||||
description="Get a summary of webset items without retrieving all data",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaWebsetItemsSummaryBlock.Input,
|
||||
output_schema=ExaWebsetItemsSummaryBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
webset = aexa.websets.get(id=input_data.webset_id)
|
||||
|
||||
entity_type = "unknown"
|
||||
if webset.searches:
|
||||
first_search = webset.searches[0]
|
||||
if first_search.entity:
|
||||
# The entity is a union type, extract type field
|
||||
entity_dict = first_search.entity.model_dump(by_alias=True)
|
||||
entity_type = entity_dict.get("type", "unknown")
|
||||
|
||||
# Get enrichment columns
|
||||
enrichment_columns = []
|
||||
if webset.enrichments:
|
||||
enrichment_columns = [
|
||||
e.title if e.title else e.description for e in webset.enrichments
|
||||
]
|
||||
|
||||
# Get sample items if requested
|
||||
sample_items: List[WebsetItemModel] = []
|
||||
if input_data.sample_size > 0:
|
||||
items_response = aexa.websets.items.list(
|
||||
webset_id=input_data.webset_id, limit=input_data.sample_size
|
||||
)
|
||||
# Convert to our stable models
|
||||
sample_items = [
|
||||
WebsetItemModel.from_sdk(item) for item in items_response.data
|
||||
]
|
||||
|
||||
total_items = 0
|
||||
if webset.searches:
|
||||
for search in webset.searches:
|
||||
if search.progress:
|
||||
total_items += search.progress.found
|
||||
|
||||
yield "total_items", total_items
|
||||
yield "entity_type", entity_type
|
||||
yield "sample_items", sample_items
|
||||
yield "enrichment_columns", enrichment_columns
|
||||
|
||||
|
||||
class ExaGetNewItemsBlock(Block):
|
||||
"""Get items added to a webset since a specific cursor (incremental processing helper)."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
webset_id: str = SchemaField(
|
||||
description="The ID or external ID of the Webset",
|
||||
placeholder="webset-id-or-external-id",
|
||||
)
|
||||
since_cursor: Optional[str] = SchemaField(
|
||||
default=None,
|
||||
description="Cursor from previous run - only items after this will be returned. Leave empty on first run.",
|
||||
placeholder="cursor-from-previous-run",
|
||||
)
|
||||
max_items: int = SchemaField(
|
||||
default=100,
|
||||
description="Maximum number of new items to retrieve",
|
||||
ge=1,
|
||||
le=1000,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
new_items: list[WebsetItemModel] = SchemaField(
|
||||
description="Items added since the cursor"
|
||||
)
|
||||
item: WebsetItemModel = SchemaField(
|
||||
description="Individual item (yielded for each new item)"
|
||||
)
|
||||
count: int = SchemaField(description="Number of new items found")
|
||||
next_cursor: Optional[str] = SchemaField(
|
||||
description="Save this cursor for the next run to get only newer items"
|
||||
)
|
||||
has_more: bool = SchemaField(
|
||||
description="Whether there are more new items beyond max_items"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3ff9bdf5-9613-4d21-8a60-90eb8b69c414",
|
||||
description="Get items added since a cursor - enables incremental processing without reprocessing",
|
||||
categories={BlockCategory.SEARCH, BlockCategory.DATA},
|
||||
input_schema=ExaGetNewItemsBlock.Input,
|
||||
output_schema=ExaGetNewItemsBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
# Get items starting from cursor
|
||||
response = aexa.websets.items.list(
|
||||
webset_id=input_data.webset_id,
|
||||
cursor=input_data.since_cursor,
|
||||
limit=input_data.max_items,
|
||||
)
|
||||
|
||||
# Convert SDK items to our stable models
|
||||
new_items = [WebsetItemModel.from_sdk(item) for item in response.data]
|
||||
|
||||
# Yield the full list
|
||||
yield "new_items", new_items
|
||||
|
||||
# Yield individual items for processing
|
||||
for item in new_items:
|
||||
yield "item", item
|
||||
|
||||
# Yield metadata for next run
|
||||
yield "count", len(new_items)
|
||||
yield "next_cursor", response.next_cursor
|
||||
yield "has_more", response.has_more
|
||||
@@ -1,600 +0,0 @@
"""
Exa Websets Monitor Management Blocks

This module provides blocks for creating and managing monitors that automatically
keep websets updated with fresh data on a schedule.
"""

from enum import Enum
from typing import Optional

from exa_py import AsyncExa
from exa_py.websets.types import Monitor as SdkMonitor
from pydantic import BaseModel

from backend.sdk import (
    APIKeyCredentials,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
    CredentialsMetaInput,
    SchemaField,
)

from ._config import exa
from ._test import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT


# Mirrored model for stability - don't use SDK types directly in block outputs
class MonitorModel(BaseModel):
|
||||
"""Stable output model mirroring SDK Monitor."""
|
||||
|
||||
id: str
|
||||
status: str
|
||||
webset_id: str
|
||||
behavior_type: str
|
||||
behavior_config: dict
|
||||
cron_expression: str
|
||||
timezone: str
|
||||
next_run_at: str
|
||||
last_run: dict
|
||||
metadata: dict
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
@classmethod
|
||||
def from_sdk(cls, monitor: SdkMonitor) -> "MonitorModel":
|
||||
"""Convert SDK Monitor to our stable model."""
|
||||
# Extract behavior information
|
||||
behavior_dict = monitor.behavior.model_dump(by_alias=True, exclude_none=True)
|
||||
behavior_type = behavior_dict.get("type", "unknown")
|
||||
behavior_config = behavior_dict.get("config", {})
|
||||
|
||||
# Extract cadence information
|
||||
cadence_dict = monitor.cadence.model_dump(by_alias=True, exclude_none=True)
|
||||
cron_expr = cadence_dict.get("cron", "")
|
||||
timezone = cadence_dict.get("timezone", "Etc/UTC")
|
||||
|
||||
# Extract last run information
|
||||
last_run_dict = {}
|
||||
if monitor.last_run:
|
||||
last_run_dict = monitor.last_run.model_dump(
|
||||
by_alias=True, exclude_none=True
|
||||
)
|
||||
|
||||
# Handle status enum
|
||||
status_str = (
|
||||
monitor.status.value
|
||||
if hasattr(monitor.status, "value")
|
||||
else str(monitor.status)
|
||||
)
|
||||
|
||||
return cls(
|
||||
id=monitor.id,
|
||||
status=status_str,
|
||||
webset_id=monitor.webset_id,
|
||||
behavior_type=behavior_type,
|
||||
behavior_config=behavior_config,
|
||||
cron_expression=cron_expr,
|
||||
timezone=timezone,
|
||||
next_run_at=monitor.next_run_at.isoformat() if monitor.next_run_at else "",
|
||||
last_run=last_run_dict,
|
||||
metadata=monitor.metadata or {},
|
||||
created_at=monitor.created_at.isoformat() if monitor.created_at else "",
|
||||
updated_at=monitor.updated_at.isoformat() if monitor.updated_at else "",
|
||||
)
|
||||
|
||||
|
||||
class MonitorStatus(str, Enum):
    """Status of a monitor."""

    ENABLED = "enabled"
    DISABLED = "disabled"
    PAUSED = "paused"


class MonitorBehaviorType(str, Enum):
    """Type of behavior for a monitor."""

    SEARCH = "search"  # Run new searches
    REFRESH = "refresh"  # Refresh existing items


class SearchBehavior(str, Enum):
    """How search results interact with existing items."""

    APPEND = "append"
    OVERRIDE = "override"
|
||||
|
||||
|
||||
class ExaCreateMonitorBlock(Block):
|
||||
"""Create a monitor to automatically keep a webset updated on a schedule."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
webset_id: str = SchemaField(
|
||||
description="The ID or external ID of the Webset to monitor",
|
||||
placeholder="webset-id-or-external-id",
|
||||
)
|
||||
|
||||
# Schedule configuration
|
||||
cron_expression: str = SchemaField(
|
||||
description="Cron expression for scheduling (5 fields, max once per day)",
|
||||
placeholder="0 9 * * 1", # Every Monday at 9 AM
|
||||
)
|
||||
timezone: str = SchemaField(
|
||||
default="Etc/UTC",
|
||||
description="IANA timezone for the schedule",
|
||||
placeholder="America/New_York",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# Behavior configuration
|
||||
behavior_type: MonitorBehaviorType = SchemaField(
|
||||
default=MonitorBehaviorType.SEARCH,
|
||||
description="Type of monitor behavior (search for new items or refresh existing)",
|
||||
)
|
||||
|
||||
# Search configuration (for SEARCH behavior)
|
||||
search_query: Optional[str] = SchemaField(
|
||||
default=None,
|
||||
description="Search query for finding new items (required for search behavior)",
|
||||
placeholder="AI startups that raised funding in the last week",
|
||||
)
|
||||
search_count: int = SchemaField(
|
||||
default=10,
|
||||
description="Number of items to find in each search",
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
search_criteria: list[str] = SchemaField(
|
||||
default_factory=list,
|
||||
description="Criteria that items must meet",
|
||||
advanced=True,
|
||||
)
|
||||
search_behavior: SearchBehavior = SchemaField(
|
||||
default=SearchBehavior.APPEND,
|
||||
description="How new results interact with existing items",
|
||||
advanced=True,
|
||||
)
|
||||
entity_type: Optional[str] = SchemaField(
|
||||
default=None,
|
||||
description="Type of entity to search for (company, person, etc.)",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# Refresh configuration (for REFRESH behavior)
|
||||
refresh_content: bool = SchemaField(
|
||||
default=True,
|
||||
description="Refresh content from source URLs (for refresh behavior)",
|
||||
advanced=True,
|
||||
)
|
||||
refresh_enrichments: bool = SchemaField(
|
||||
default=True,
|
||||
description="Re-run enrichments on items (for refresh behavior)",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# Metadata
|
||||
metadata: Optional[dict] = SchemaField(
|
||||
default=None,
|
||||
description="Metadata to attach to the monitor",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
monitor_id: str = SchemaField(
|
||||
description="The unique identifier for the created monitor"
|
||||
)
|
||||
webset_id: str = SchemaField(description="The webset this monitor belongs to")
|
||||
status: str = SchemaField(description="Status of the monitor")
|
||||
behavior_type: str = SchemaField(description="Type of monitor behavior")
|
||||
next_run_at: Optional[str] = SchemaField(
|
||||
description="When the monitor will next run"
|
||||
)
|
||||
cron_expression: str = SchemaField(description="The schedule cron expression")
|
||||
timezone: str = SchemaField(description="The timezone for scheduling")
|
||||
created_at: str = SchemaField(description="When the monitor was created")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f8a9b0c1-d2e3-4567-890a-bcdef1234567",
|
||||
description="Create automated monitors to keep websets updated with fresh data on a schedule",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaCreateMonitorBlock.Input,
|
||||
output_schema=ExaCreateMonitorBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"webset_id": "test-webset",
|
||||
"cron_expression": "0 9 * * 1",
|
||||
"behavior_type": MonitorBehaviorType.SEARCH,
|
||||
"search_query": "AI startups",
|
||||
"search_count": 10,
|
||||
},
|
||||
test_output=[
|
||||
("monitor_id", "monitor-123"),
|
||||
("webset_id", "test-webset"),
|
||||
("status", "enabled"),
|
||||
("behavior_type", "search"),
|
||||
("next_run_at", "2024-01-01T00:00:00"),
|
||||
("cron_expression", "0 9 * * 1"),
|
||||
("timezone", "Etc/UTC"),
|
||||
("created_at", "2024-01-01T00:00:00"),
|
||||
],
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_mock=self._create_test_mock(),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_test_mock():
|
||||
"""Create test mocks for the AsyncExa SDK."""
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
# Create mock SDK monitor object
|
||||
mock_monitor = MagicMock()
|
||||
mock_monitor.id = "monitor-123"
|
||||
mock_monitor.status = MagicMock(value="enabled")
|
||||
mock_monitor.webset_id = "test-webset"
|
||||
mock_monitor.next_run_at = datetime.fromisoformat("2024-01-01T00:00:00")
|
||||
mock_monitor.created_at = datetime.fromisoformat("2024-01-01T00:00:00")
|
||||
mock_monitor.updated_at = datetime.fromisoformat("2024-01-01T00:00:00")
|
||||
mock_monitor.metadata = {}
|
||||
mock_monitor.last_run = None
|
||||
|
||||
# Mock behavior
|
||||
mock_behavior = MagicMock()
|
||||
mock_behavior.model_dump = MagicMock(
|
||||
return_value={"type": "search", "config": {}}
|
||||
)
|
||||
mock_monitor.behavior = mock_behavior
|
||||
|
||||
# Mock cadence
|
||||
mock_cadence = MagicMock()
|
||||
mock_cadence.model_dump = MagicMock(
|
||||
return_value={"cron": "0 9 * * 1", "timezone": "Etc/UTC"}
|
||||
)
|
||||
mock_monitor.cadence = mock_cadence
|
||||
|
||||
return {
|
||||
"_get_client": lambda *args, **kwargs: MagicMock(
|
||||
websets=MagicMock(
|
||||
monitors=MagicMock(create=lambda *args, **kwargs: mock_monitor)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
def _get_client(self, api_key: str) -> AsyncExa:
|
||||
"""Get Exa client (separated for testing)."""
|
||||
return AsyncExa(api_key=api_key)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
aexa = self._get_client(credentials.api_key.get_secret_value())
|
||||
|
||||
# Build the payload
|
||||
payload = {
|
||||
"websetId": input_data.webset_id,
|
||||
"cadence": {
|
||||
"cron": input_data.cron_expression,
|
||||
"timezone": input_data.timezone,
|
||||
},
|
||||
}
|
||||
|
||||
# Build behavior configuration based on type
|
||||
if input_data.behavior_type == MonitorBehaviorType.SEARCH:
|
||||
behavior_config = {
|
||||
"query": input_data.search_query or "",
|
||||
"count": input_data.search_count,
|
||||
"behavior": input_data.search_behavior.value,
|
||||
}
|
||||
|
||||
if input_data.search_criteria:
|
||||
behavior_config["criteria"] = [
|
||||
{"description": c} for c in input_data.search_criteria
|
||||
]
|
||||
|
||||
if input_data.entity_type:
|
||||
behavior_config["entity"] = {"type": input_data.entity_type}
|
||||
|
||||
payload["behavior"] = {
|
||||
"type": "search",
|
||||
"config": behavior_config,
|
||||
}
|
||||
else:
|
||||
# REFRESH behavior
|
||||
payload["behavior"] = {
|
||||
"type": "refresh",
|
||||
"config": {
|
||||
"content": input_data.refresh_content,
|
||||
"enrichments": input_data.refresh_enrichments,
|
||||
},
|
||||
}
|
||||
|
||||
# Add metadata if provided
|
||||
if input_data.metadata:
|
||||
payload["metadata"] = input_data.metadata
|
||||
|
||||
sdk_monitor = aexa.websets.monitors.create(params=payload)
|
||||
|
||||
monitor = MonitorModel.from_sdk(sdk_monitor)
|
||||
|
||||
# Yield all fields
|
||||
yield "monitor_id", monitor.id
|
||||
yield "webset_id", monitor.webset_id
|
||||
yield "status", monitor.status
|
||||
yield "behavior_type", monitor.behavior_type
|
||||
yield "next_run_at", monitor.next_run_at
|
||||
yield "cron_expression", monitor.cron_expression
|
||||
yield "timezone", monitor.timezone
|
||||
yield "created_at", monitor.created_at
|
||||
|
||||
|
||||
class ExaGetMonitorBlock(Block):
|
||||
"""Get the details and status of a monitor."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
monitor_id: str = SchemaField(
|
||||
description="The ID of the monitor to retrieve",
|
||||
placeholder="monitor-id",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
monitor_id: str = SchemaField(
|
||||
description="The unique identifier for the monitor"
|
||||
)
|
||||
webset_id: str = SchemaField(description="The webset this monitor belongs to")
|
||||
status: str = SchemaField(description="Current status of the monitor")
|
||||
behavior_type: str = SchemaField(description="Type of monitor behavior")
|
||||
behavior_config: dict = SchemaField(
|
||||
description="Configuration for the monitor behavior"
|
||||
)
|
||||
cron_expression: str = SchemaField(description="The schedule cron expression")
|
||||
timezone: str = SchemaField(description="The timezone for scheduling")
|
||||
next_run_at: Optional[str] = SchemaField(
|
||||
description="When the monitor will next run"
|
||||
)
|
||||
last_run: Optional[dict] = SchemaField(
|
||||
description="Information about the last run"
|
||||
)
|
||||
created_at: str = SchemaField(description="When the monitor was created")
|
||||
updated_at: str = SchemaField(description="When the monitor was last updated")
|
||||
metadata: dict = SchemaField(description="Metadata attached to the monitor")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="5c852a2d-d505-4a56-b711-7def8dd14e72",
|
||||
description="Get the details and status of a webset monitor",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaGetMonitorBlock.Input,
|
||||
output_schema=ExaGetMonitorBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
sdk_monitor = aexa.websets.monitors.get(monitor_id=input_data.monitor_id)
|
||||
|
||||
monitor = MonitorModel.from_sdk(sdk_monitor)
|
||||
|
||||
# Yield all fields
|
||||
yield "monitor_id", monitor.id
|
||||
yield "webset_id", monitor.webset_id
|
||||
yield "status", monitor.status
|
||||
yield "behavior_type", monitor.behavior_type
|
||||
yield "behavior_config", monitor.behavior_config
|
||||
yield "cron_expression", monitor.cron_expression
|
||||
yield "timezone", monitor.timezone
|
||||
yield "next_run_at", monitor.next_run_at
|
||||
yield "last_run", monitor.last_run
|
||||
yield "created_at", monitor.created_at
|
||||
yield "updated_at", monitor.updated_at
|
||||
yield "metadata", monitor.metadata
|
||||
|
||||
|
||||
class ExaUpdateMonitorBlock(Block):
|
||||
"""Update a monitor's configuration."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
monitor_id: str = SchemaField(
|
||||
description="The ID of the monitor to update",
|
||||
placeholder="monitor-id",
|
||||
)
|
||||
status: Optional[MonitorStatus] = SchemaField(
|
||||
default=None,
|
||||
description="New status for the monitor",
|
||||
)
|
||||
cron_expression: Optional[str] = SchemaField(
|
||||
default=None,
|
||||
description="New cron expression for scheduling",
|
||||
)
|
||||
timezone: Optional[str] = SchemaField(
|
||||
default=None,
|
||||
description="New timezone for the schedule",
|
||||
advanced=True,
|
||||
)
|
||||
metadata: Optional[dict] = SchemaField(
|
||||
default=None,
|
||||
description="New metadata for the monitor",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
monitor_id: str = SchemaField(
|
||||
description="The unique identifier for the monitor"
|
||||
)
|
||||
status: str = SchemaField(description="Updated status of the monitor")
|
||||
next_run_at: Optional[str] = SchemaField(
|
||||
description="When the monitor will next run"
|
||||
)
|
||||
updated_at: str = SchemaField(description="When the monitor was updated")
|
||||
success: str = SchemaField(description="Whether the update was successful")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="245102c3-6af3-4515-a308-c2210b7939d2",
|
||||
description="Update a monitor's status, schedule, or metadata",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaUpdateMonitorBlock.Input,
|
||||
output_schema=ExaUpdateMonitorBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
# Build update payload
|
||||
payload = {}
|
||||
|
||||
if input_data.status is not None:
|
||||
payload["status"] = input_data.status.value
|
||||
|
||||
if input_data.cron_expression is not None or input_data.timezone is not None:
|
||||
cadence = {}
|
||||
if input_data.cron_expression:
|
||||
cadence["cron"] = input_data.cron_expression
|
||||
if input_data.timezone:
|
||||
cadence["timezone"] = input_data.timezone
|
||||
payload["cadence"] = cadence
|
||||
|
||||
if input_data.metadata is not None:
|
||||
payload["metadata"] = input_data.metadata
|
||||
|
||||
sdk_monitor = aexa.websets.monitors.update(
|
||||
monitor_id=input_data.monitor_id, params=payload
|
||||
)
|
||||
|
||||
# Convert to our stable model
|
||||
monitor = MonitorModel.from_sdk(sdk_monitor)
|
||||
|
||||
# Yield fields
|
||||
yield "monitor_id", monitor.id
|
||||
yield "status", monitor.status
|
||||
yield "next_run_at", monitor.next_run_at
|
||||
yield "updated_at", monitor.updated_at
|
||||
yield "success", "true"
|
||||
|
||||
|
||||
class ExaDeleteMonitorBlock(Block):
|
||||
"""Delete a monitor from a webset."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
monitor_id: str = SchemaField(
|
||||
description="The ID of the monitor to delete",
|
||||
placeholder="monitor-id",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
monitor_id: str = SchemaField(description="The ID of the deleted monitor")
|
||||
success: str = SchemaField(description="Whether the deletion was successful")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f16f9b10-0c4d-4db8-997d-7b96b6026094",
|
||||
description="Delete a monitor from a webset",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaDeleteMonitorBlock.Input,
|
||||
output_schema=ExaDeleteMonitorBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
deleted_monitor = aexa.websets.monitors.delete(monitor_id=input_data.monitor_id)
|
||||
|
||||
yield "monitor_id", deleted_monitor.id
|
||||
yield "success", "true"
|
||||
|
||||
|
||||
class ExaListMonitorsBlock(Block):
|
||||
"""List all monitors with pagination."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
webset_id: Optional[str] = SchemaField(
|
||||
default=None,
|
||||
description="Filter monitors by webset ID",
|
||||
placeholder="webset-id",
|
||||
)
|
||||
limit: int = SchemaField(
|
||||
default=25,
|
||||
description="Number of monitors to return",
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
cursor: Optional[str] = SchemaField(
|
||||
default=None,
|
||||
description="Cursor for pagination",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
monitors: list[dict] = SchemaField(description="List of monitors")
|
||||
monitor: dict = SchemaField(
|
||||
description="Individual monitor (yielded for each monitor)"
|
||||
)
|
||||
has_more: bool = SchemaField(
|
||||
description="Whether there are more monitors to paginate through"
|
||||
)
|
||||
next_cursor: Optional[str] = SchemaField(
|
||||
description="Cursor for the next page of results"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f06e2b38-5397-4e8f-aa85-491149dd98df",
|
||||
description="List all monitors with optional webset filtering",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaListMonitorsBlock.Input,
|
||||
output_schema=ExaListMonitorsBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Use AsyncExa SDK
|
||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
response = aexa.websets.monitors.list(
|
||||
cursor=input_data.cursor,
|
||||
limit=input_data.limit,
|
||||
webset_id=input_data.webset_id,
|
||||
)
|
||||
|
||||
# Convert SDK monitors to our stable models
|
||||
monitors = [MonitorModel.from_sdk(m) for m in response.data]
|
||||
|
||||
# Yield the full list
|
||||
yield "monitors", [m.model_dump() for m in monitors]
|
||||
|
||||
# Yield individual monitors for graph chaining
|
||||
for monitor in monitors:
|
||||
yield "monitor", monitor.model_dump()
|
||||
|
||||
# Yield pagination metadata
|
||||
yield "has_more", response.has_more
|
||||
yield "next_cursor", response.next_cursor
|
||||
@@ -1,600 +0,0 @@
"""
Exa Websets Polling Blocks

This module provides dedicated polling blocks for waiting on webset operations
to complete, with progress tracking and timeout management.
"""

import asyncio
import time
from enum import Enum
from typing import Any, Dict

from exa_py import AsyncExa
from pydantic import BaseModel

from backend.sdk import (
    APIKeyCredentials,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
    CredentialsMetaInput,
    SchemaField,
)

from ._config import exa

# Import WebsetItemModel for use in enrichment samples
# This is safe as websets_items doesn't import from websets_polling
from .websets_items import WebsetItemModel


# Model for sample enrichment data
class SampleEnrichmentModel(BaseModel):
    """Sample enrichment result for display."""

    item_id: str
    item_title: str
    enrichment_data: Dict[str, Any]


class WebsetTargetStatus(str, Enum):
    IDLE = "idle"
    COMPLETED = "completed"
    RUNNING = "running"
    PAUSED = "paused"
    ANY_COMPLETE = "any_complete"  # Either idle or completed
|
||||
|
||||
|
||||
class ExaWaitForWebsetBlock(Block):
|
||||
"""Wait for a webset to reach a specific status with progress tracking."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
webset_id: str = SchemaField(
|
||||
description="The ID or external ID of the Webset to monitor",
|
||||
placeholder="webset-id-or-external-id",
|
||||
)
|
||||
target_status: WebsetTargetStatus = SchemaField(
|
||||
default=WebsetTargetStatus.IDLE,
|
||||
description="Status to wait for (idle=all operations complete, completed=search done, running=actively processing)",
|
||||
)
|
||||
timeout: int = SchemaField(
|
||||
default=300,
|
||||
description="Maximum time to wait in seconds",
|
||||
ge=1,
|
||||
le=1800, # 30 minutes max
|
||||
)
|
||||
check_interval: int = SchemaField(
|
||||
default=5,
|
||||
description="Initial interval between status checks in seconds",
|
||||
advanced=True,
|
||||
ge=1,
|
||||
le=60,
|
||||
)
|
||||
max_interval: int = SchemaField(
|
||||
default=30,
|
||||
description="Maximum interval between checks (for exponential backoff)",
|
||||
advanced=True,
|
||||
ge=5,
|
||||
le=120,
|
||||
)
|
||||
include_progress: bool = SchemaField(
|
||||
default=True,
|
||||
description="Include detailed progress information in output",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
webset_id: str = SchemaField(description="The webset ID that was monitored")
|
||||
final_status: str = SchemaField(description="The final status of the webset")
|
||||
elapsed_time: float = SchemaField(description="Total time elapsed in seconds")
|
||||
item_count: int = SchemaField(description="Number of items found")
|
||||
search_progress: dict = SchemaField(
|
||||
description="Detailed search progress information"
|
||||
)
|
||||
enrichment_progress: dict = SchemaField(
|
||||
description="Detailed enrichment progress information"
|
||||
)
|
||||
timed_out: bool = SchemaField(description="Whether the operation timed out")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="619d71e8-b72a-434d-8bd4-23376dd0342c",
|
||||
description="Wait for a webset to reach a specific status with progress tracking",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExaWaitForWebsetBlock.Input,
|
||||
            output_schema=ExaWaitForWebsetBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        start_time = time.time()
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        try:
            if input_data.target_status in [
                WebsetTargetStatus.IDLE,
                WebsetTargetStatus.ANY_COMPLETE,
            ]:
                final_webset = aexa.websets.wait_until_idle(
                    id=input_data.webset_id,
                    timeout=input_data.timeout,
                    poll_interval=input_data.check_interval,
                )

                elapsed = time.time() - start_time

                status_str = (
                    final_webset.status.value
                    if hasattr(final_webset.status, "value")
                    else str(final_webset.status)
                )

                item_count = 0
                if final_webset.searches:
                    for search in final_webset.searches:
                        if search.progress:
                            item_count += search.progress.found

                # Extract progress if requested
                search_progress = {}
                enrichment_progress = {}
                if input_data.include_progress:
                    webset_dict = final_webset.model_dump(
                        by_alias=True, exclude_none=True
                    )
                    search_progress = self._extract_search_progress(webset_dict)
                    enrichment_progress = self._extract_enrichment_progress(webset_dict)

                yield "webset_id", input_data.webset_id
                yield "final_status", status_str
                yield "elapsed_time", elapsed
                yield "item_count", item_count
                if input_data.include_progress:
                    yield "search_progress", search_progress
                    yield "enrichment_progress", enrichment_progress
                yield "timed_out", False
            else:
                # For other status targets, manually poll
                interval = input_data.check_interval
                while time.time() - start_time < input_data.timeout:
                    # Get current webset status
                    webset = aexa.websets.get(id=input_data.webset_id)
                    current_status = (
                        webset.status.value
                        if hasattr(webset.status, "value")
                        else str(webset.status)
                    )

                    # Check if target status reached
                    if current_status == input_data.target_status.value:
                        elapsed = time.time() - start_time

                        # Estimate item count from search progress
                        item_count = 0
                        if webset.searches:
                            for search in webset.searches:
                                if search.progress:
                                    item_count += search.progress.found

                        search_progress = {}
                        enrichment_progress = {}
                        if input_data.include_progress:
                            webset_dict = webset.model_dump(
                                by_alias=True, exclude_none=True
                            )
                            search_progress = self._extract_search_progress(webset_dict)
                            enrichment_progress = self._extract_enrichment_progress(
                                webset_dict
                            )

                        yield "webset_id", input_data.webset_id
                        yield "final_status", current_status
                        yield "elapsed_time", elapsed
                        yield "item_count", item_count
                        if input_data.include_progress:
                            yield "search_progress", search_progress
                            yield "enrichment_progress", enrichment_progress
                        yield "timed_out", False
                        return

                    # Wait before next check with exponential backoff
                    await asyncio.sleep(interval)
                    interval = min(interval * 1.5, input_data.max_interval)

                # Timeout reached
                elapsed = time.time() - start_time
                webset = aexa.websets.get(id=input_data.webset_id)
                final_status = (
                    webset.status.value
                    if hasattr(webset.status, "value")
                    else str(webset.status)
                )

                item_count = 0
                if webset.searches:
                    for search in webset.searches:
                        if search.progress:
                            item_count += search.progress.found

                search_progress = {}
                enrichment_progress = {}
                if input_data.include_progress:
                    webset_dict = webset.model_dump(by_alias=True, exclude_none=True)
                    search_progress = self._extract_search_progress(webset_dict)
                    enrichment_progress = self._extract_enrichment_progress(webset_dict)

                yield "webset_id", input_data.webset_id
                yield "final_status", final_status
                yield "elapsed_time", elapsed
                yield "item_count", item_count
                if input_data.include_progress:
                    yield "search_progress", search_progress
                    yield "enrichment_progress", enrichment_progress
                yield "timed_out", True

        except asyncio.TimeoutError:
            raise ValueError(
                f"Polling timed out after {input_data.timeout} seconds"
            ) from None

    def _extract_search_progress(self, webset_data: dict) -> dict:
        """Extract search progress information from webset data."""
        progress = {}
        searches = webset_data.get("searches", [])

        for idx, search in enumerate(searches):
            search_id = search.get("id", f"search_{idx}")
            search_progress = search.get("progress", {})

            progress[search_id] = {
                "status": search.get("status", "unknown"),
                "found": search_progress.get("found", 0),
                "analyzed": search_progress.get("analyzed", 0),
                "completion": search_progress.get("completion", 0),
                "time_left": search_progress.get("timeLeft", 0),
            }

        return progress

    def _extract_enrichment_progress(self, webset_data: dict) -> dict:
        """Extract enrichment progress information from webset data."""
        progress = {}
        enrichments = webset_data.get("enrichments", [])

        for idx, enrichment in enumerate(enrichments):
            enrich_id = enrichment.get("id", f"enrichment_{idx}")

            progress[enrich_id] = {
                "status": enrichment.get("status", "unknown"),
                "title": enrichment.get("title", ""),
                "description": enrichment.get("description", ""),
            }

        return progress


class ExaWaitForSearchBlock(Block):
    """Wait for a specific webset search to complete with progress tracking."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset",
            placeholder="webset-id-or-external-id",
        )
        search_id: str = SchemaField(
            description="The ID of the search to monitor",
            placeholder="search-id",
        )
        timeout: int = SchemaField(
            default=300,
            description="Maximum time to wait in seconds",
            ge=1,
            le=1800,
        )
        check_interval: int = SchemaField(
            default=5,
            description="Initial interval between status checks in seconds",
            advanced=True,
            ge=1,
            le=60,
        )

    class Output(BlockSchemaOutput):
        search_id: str = SchemaField(description="The search ID that was monitored")
        final_status: str = SchemaField(description="The final status of the search")
        items_found: int = SchemaField(
            description="Number of items found by the search"
        )
        items_analyzed: int = SchemaField(description="Number of items analyzed")
        completion_percentage: int = SchemaField(
            description="Completion percentage (0-100)"
        )
        elapsed_time: float = SchemaField(description="Total time elapsed in seconds")
        recall_info: dict = SchemaField(
            description="Information about expected results and confidence"
        )
        timed_out: bool = SchemaField(description="Whether the operation timed out")

    def __init__(self):
        super().__init__(
            id="14da21ae-40a1-41bc-a111-c8e5c9ef012b",
            description="Wait for a specific webset search to complete with progress tracking",
            categories={BlockCategory.SEARCH},
            input_schema=ExaWaitForSearchBlock.Input,
            output_schema=ExaWaitForSearchBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        start_time = time.time()
        interval = input_data.check_interval
        max_interval = 30
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        try:
            while time.time() - start_time < input_data.timeout:
                # Get current search status using SDK
                search = aexa.websets.searches.get(
                    webset_id=input_data.webset_id, id=input_data.search_id
                )

                # Extract status
                status = (
                    search.status.value
                    if hasattr(search.status, "value")
                    else str(search.status)
                )

                # Check if search is complete
                if status in ["completed", "failed", "canceled"]:
                    elapsed = time.time() - start_time

                    # Extract progress information
                    progress_dict = {}
                    if search.progress:
                        progress_dict = search.progress.model_dump(
                            by_alias=True, exclude_none=True
                        )

                    # Extract recall information
                    recall_info = {}
                    if search.recall:
                        recall_dict = search.recall.model_dump(
                            by_alias=True, exclude_none=True
                        )
                        expected = recall_dict.get("expected", {})
                        recall_info = {
                            "expected_total": expected.get("total", 0),
                            "confidence": expected.get("confidence", ""),
                            "min_expected": expected.get("bounds", {}).get("min", 0),
                            "max_expected": expected.get("bounds", {}).get("max", 0),
                            "reasoning": recall_dict.get("reasoning", ""),
                        }

                    yield "search_id", input_data.search_id
                    yield "final_status", status
                    yield "items_found", progress_dict.get("found", 0)
                    yield "items_analyzed", progress_dict.get("analyzed", 0)
                    yield "completion_percentage", progress_dict.get("completion", 0)
                    yield "elapsed_time", elapsed
                    yield "recall_info", recall_info
                    yield "timed_out", False

                    return

                # Wait before next check with exponential backoff
                await asyncio.sleep(interval)
                interval = min(interval * 1.5, max_interval)

            # Timeout reached
            elapsed = time.time() - start_time

            # Get last known status
            search = aexa.websets.searches.get(
                webset_id=input_data.webset_id, id=input_data.search_id
            )
            final_status = (
                search.status.value
                if hasattr(search.status, "value")
                else str(search.status)
            )

            progress_dict = {}
            if search.progress:
                progress_dict = search.progress.model_dump(
                    by_alias=True, exclude_none=True
                )

            yield "search_id", input_data.search_id
            yield "final_status", final_status
            yield "items_found", progress_dict.get("found", 0)
            yield "items_analyzed", progress_dict.get("analyzed", 0)
            yield "completion_percentage", progress_dict.get("completion", 0)
            yield "elapsed_time", elapsed
            yield "timed_out", True

        except asyncio.TimeoutError:
            raise ValueError(
                f"Search polling timed out after {input_data.timeout} seconds"
            ) from None


class ExaWaitForEnrichmentBlock(Block):
    """Wait for a webset enrichment to complete with progress tracking."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset",
            placeholder="webset-id-or-external-id",
        )
        enrichment_id: str = SchemaField(
            description="The ID of the enrichment to monitor",
            placeholder="enrichment-id",
        )
        timeout: int = SchemaField(
            default=300,
            description="Maximum time to wait in seconds",
            ge=1,
            le=1800,
        )
        check_interval: int = SchemaField(
            default=5,
            description="Initial interval between status checks in seconds",
            advanced=True,
            ge=1,
            le=60,
        )
        sample_results: bool = SchemaField(
            default=True,
            description="Include sample enrichment results in output",
        )

    class Output(BlockSchemaOutput):
        enrichment_id: str = SchemaField(
            description="The enrichment ID that was monitored"
        )
        final_status: str = SchemaField(
            description="The final status of the enrichment"
        )
        items_enriched: int = SchemaField(
            description="Number of items successfully enriched"
        )
        enrichment_title: str = SchemaField(
            description="Title/description of the enrichment"
        )
        elapsed_time: float = SchemaField(description="Total time elapsed in seconds")
        sample_data: list[SampleEnrichmentModel] = SchemaField(
            description="Sample of enriched data (if requested)"
        )
        timed_out: bool = SchemaField(description="Whether the operation timed out")

    def __init__(self):
        super().__init__(
            id="a11865c3-ac80-4721-8a40-ac4e3b71a558",
            description="Wait for a webset enrichment to complete with progress tracking",
            categories={BlockCategory.SEARCH},
            input_schema=ExaWaitForEnrichmentBlock.Input,
            output_schema=ExaWaitForEnrichmentBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        start_time = time.time()
        interval = input_data.check_interval
        max_interval = 30
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        try:
            while time.time() - start_time < input_data.timeout:
                # Get current enrichment status using SDK
                enrichment = aexa.websets.enrichments.get(
                    webset_id=input_data.webset_id, id=input_data.enrichment_id
                )

                # Extract status
                status = (
                    enrichment.status.value
                    if hasattr(enrichment.status, "value")
                    else str(enrichment.status)
                )

                # Check if enrichment is complete
                if status in ["completed", "failed", "canceled"]:
                    elapsed = time.time() - start_time

                    # Get sample enriched items if requested
                    sample_data = []
                    items_enriched = 0

                    if input_data.sample_results and status == "completed":
                        sample_data, items_enriched = (
                            await self._get_sample_enrichments(
                                input_data.webset_id, input_data.enrichment_id, aexa
                            )
                        )

                    yield "enrichment_id", input_data.enrichment_id
                    yield "final_status", status
                    yield "items_enriched", items_enriched
                    yield "enrichment_title", enrichment.title or enrichment.description or ""
                    yield "elapsed_time", elapsed
                    if input_data.sample_results:
                        yield "sample_data", sample_data
                    yield "timed_out", False

                    return

                # Wait before next check with exponential backoff
                await asyncio.sleep(interval)
                interval = min(interval * 1.5, max_interval)

            # Timeout reached
            elapsed = time.time() - start_time

            # Get last known status
            enrichment = aexa.websets.enrichments.get(
                webset_id=input_data.webset_id, id=input_data.enrichment_id
            )
            final_status = (
                enrichment.status.value
                if hasattr(enrichment.status, "value")
                else str(enrichment.status)
            )
            title = enrichment.title or enrichment.description or ""

            yield "enrichment_id", input_data.enrichment_id
            yield "final_status", final_status
            yield "items_enriched", 0
            yield "enrichment_title", title
            yield "elapsed_time", elapsed
            yield "timed_out", True

        except asyncio.TimeoutError:
            raise ValueError(
                f"Enrichment polling timed out after {input_data.timeout} seconds"
            ) from None

    async def _get_sample_enrichments(
        self, webset_id: str, enrichment_id: str, aexa: AsyncExa
    ) -> tuple[list[SampleEnrichmentModel], int]:
        """Get sample enriched data and count."""
        # Get a few items to see enrichment results using SDK
        response = aexa.websets.items.list(webset_id=webset_id, limit=5)

        sample_data: list[SampleEnrichmentModel] = []
        enriched_count = 0

        for sdk_item in response.data:
            # Convert to our WebsetItemModel first
            item = WebsetItemModel.from_sdk(sdk_item)

            # Check if this item has the enrichment we're looking for
            if enrichment_id in item.enrichments:
                enriched_count += 1
                enrich_model = item.enrichments[enrichment_id]

                # Create sample using our typed model
                sample = SampleEnrichmentModel(
                    item_id=item.id,
                    item_title=item.title,
                    enrichment_data=enrich_model.model_dump(exclude_none=True),
                )
                sample_data.append(sample)

        return sample_data, enriched_count
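The three wait blocks above share one polling shape: fetch the current status, stop on a terminal state, otherwise sleep and grow the interval by 1.5x up to a cap until a deadline is hit. The standalone sketch below restates that loop outside the Block classes for clarity; `fetch_status`, the terminal-state set, and the example values are illustrative stand-ins, not code from this diff.

```python
import asyncio
import time
from typing import Awaitable, Callable

TERMINAL = frozenset({"completed", "failed", "canceled"})  # assumed terminal states

async def poll_until_done(
    fetch_status: Callable[[], Awaitable[str]],  # e.g. wraps one SDK .get(...) call
    *,
    timeout: float = 300.0,
    check_interval: float = 5.0,
    max_interval: float = 30.0,
) -> tuple[str, bool]:
    """Return (last_status, timed_out), polling with exponential backoff."""
    start = time.time()
    interval = check_interval
    status = "unknown"
    while time.time() - start < timeout:
        status = await fetch_status()
        if status in TERMINAL:
            return status, False          # reached a terminal state before the deadline
        await asyncio.sleep(interval)     # back off: 5s, 7.5s, 11.25s, ... capped at 30s
        interval = min(interval * 1.5, max_interval)
    return status, True                   # deadline hit; report the last observed status
```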
@@ -1,650 +0,0 @@
"""
Exa Websets Search Management Blocks

This module provides blocks for creating and managing searches within websets,
including adding new searches, checking status, and canceling operations.
"""

from enum import Enum
from typing import Any, Dict, List, Optional

from exa_py import AsyncExa
from exa_py.websets.types import WebsetSearch as SdkWebsetSearch
from pydantic import BaseModel

from backend.sdk import (
    APIKeyCredentials,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
    CredentialsMetaInput,
    SchemaField,
)

from ._config import exa


# Mirrored model for stability
class WebsetSearchModel(BaseModel):
    """Stable output model mirroring SDK WebsetSearch."""

    id: str
    webset_id: str
    status: str
    query: str
    entity_type: str
    criteria: List[Dict[str, Any]]
    count: int
    behavior: str
    progress: Dict[str, Any]
    recall: Optional[Dict[str, Any]]
    created_at: str
    updated_at: str
    canceled_at: Optional[str]
    canceled_reason: Optional[str]
    metadata: Dict[str, Any]

    @classmethod
    def from_sdk(cls, search: SdkWebsetSearch) -> "WebsetSearchModel":
        """Convert SDK WebsetSearch to our stable model."""
        # Extract entity type
        entity_type = "auto"
        if search.entity:
            entity_dict = search.entity.model_dump(by_alias=True)
            entity_type = entity_dict.get("type", "auto")

        # Convert criteria
        criteria = [c.model_dump(by_alias=True) for c in search.criteria]

        # Convert progress
        progress_dict = {}
        if search.progress:
            progress_dict = search.progress.model_dump(by_alias=True)

        # Convert recall
        recall_dict = None
        if search.recall:
            recall_dict = search.recall.model_dump(by_alias=True)

        return cls(
            id=search.id,
            webset_id=search.webset_id,
            status=(
                search.status.value
                if hasattr(search.status, "value")
                else str(search.status)
            ),
            query=search.query,
            entity_type=entity_type,
            criteria=criteria,
            count=search.count,
            behavior=search.behavior.value if search.behavior else "override",
            progress=progress_dict,
            recall=recall_dict,
            created_at=search.created_at.isoformat() if search.created_at else "",
            updated_at=search.updated_at.isoformat() if search.updated_at else "",
            canceled_at=search.canceled_at.isoformat() if search.canceled_at else None,
            canceled_reason=(
                search.canceled_reason.value if search.canceled_reason else None
            ),
            metadata=search.metadata if search.metadata else {},
        )


class SearchBehavior(str, Enum):
    """Behavior for how new search results interact with existing items."""

    OVERRIDE = "override"  # Replace existing items
    APPEND = "append"  # Add to existing items
    MERGE = "merge"  # Merge with existing items


class SearchEntityType(str, Enum):
    COMPANY = "company"
    PERSON = "person"
    ARTICLE = "article"
    RESEARCH_PAPER = "research_paper"
    CUSTOM = "custom"
    AUTO = "auto"


class ExaCreateWebsetSearchBlock(Block):
    """Add a new search to an existing webset."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset",
            placeholder="webset-id-or-external-id",
        )
        query: str = SchemaField(
            description="Search query describing what to find",
            placeholder="Engineering managers at Fortune 500 companies",
        )
        count: int = SchemaField(
            default=10,
            description="Number of items to find",
            ge=1,
            le=1000,
        )

        # Entity configuration
        entity_type: SearchEntityType = SchemaField(
            default=SearchEntityType.AUTO,
            description="Type of entity to search for",
        )
        entity_description: Optional[str] = SchemaField(
            default=None,
            description="Description for custom entity type",
            advanced=True,
        )

        # Criteria for verification
        criteria: list[str] = SchemaField(
            default_factory=list,
            description="List of criteria that items must meet. If not provided, auto-detected from query.",
            advanced=True,
        )

        # Advanced search options
        behavior: SearchBehavior = SchemaField(
            default=SearchBehavior.APPEND,
            description="How new results interact with existing items",
            advanced=True,
        )
        recall: bool = SchemaField(
            default=True,
            description="Enable recall estimation for expected results",
            advanced=True,
        )

        # Exclude sources
        exclude_source_ids: list[str] = SchemaField(
            default_factory=list,
            description="IDs of imports/websets to exclude from results",
            advanced=True,
        )
        exclude_source_types: list[str] = SchemaField(
            default_factory=list,
            description="Types of sources to exclude ('import' or 'webset')",
            advanced=True,
        )

        # Scope sources
        scope_source_ids: list[str] = SchemaField(
            default_factory=list,
            description="IDs of imports/websets to limit search scope to",
            advanced=True,
        )
        scope_source_types: list[str] = SchemaField(
            default_factory=list,
            description="Types of scope sources ('import' or 'webset')",
            advanced=True,
        )
        scope_relationships: list[str] = SchemaField(
            default_factory=list,
            description="Relationship definitions for hop searches",
            advanced=True,
        )
        scope_relationship_limits: list[int] = SchemaField(
            default_factory=list,
            description="Limits on related entities to find",
            advanced=True,
        )

        metadata: Optional[dict] = SchemaField(
            default=None,
            description="Metadata to attach to the search",
            advanced=True,
        )

        # Polling options
        wait_for_completion: bool = SchemaField(
            default=False,
            description="Wait for the search to complete before returning",
        )
        polling_timeout: int = SchemaField(
            default=300,
            description="Maximum time to wait for completion in seconds",
            advanced=True,
            ge=1,
            le=600,
        )

    class Output(BlockSchemaOutput):
        search_id: str = SchemaField(
            description="The unique identifier for the created search"
        )
        webset_id: str = SchemaField(description="The webset this search belongs to")
        status: str = SchemaField(description="Current status of the search")
        query: str = SchemaField(description="The search query")
        expected_results: dict = SchemaField(
            description="Recall estimation of expected results"
        )
        items_found: Optional[int] = SchemaField(
            description="Number of items found (if wait_for_completion was True)"
        )
        completion_time: Optional[float] = SchemaField(
            description="Time taken to complete in seconds (if wait_for_completion was True)"
        )

    def __init__(self):
        super().__init__(
            id="342ff776-2e2c-4cdb-b392-4eeb34b21d5f",
            description="Add a new search to an existing webset to find more items",
            categories={BlockCategory.SEARCH},
            input_schema=ExaCreateWebsetSearchBlock.Input,
            output_schema=ExaCreateWebsetSearchBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        import time

        # Build the payload
        payload = {
            "query": input_data.query,
            "count": input_data.count,
            "behavior": input_data.behavior.value,
            "recall": input_data.recall,
        }

        # Add entity configuration
        if input_data.entity_type != SearchEntityType.AUTO:
            entity = {"type": input_data.entity_type.value}
            if (
                input_data.entity_type == SearchEntityType.CUSTOM
                and input_data.entity_description
            ):
                entity["description"] = input_data.entity_description
            payload["entity"] = entity

        # Add criteria if provided
        if input_data.criteria:
            payload["criteria"] = [{"description": c} for c in input_data.criteria]

        # Add exclude sources
        if input_data.exclude_source_ids:
            exclude_list = []
            for idx, src_id in enumerate(input_data.exclude_source_ids):
                src_type = "import"
                if input_data.exclude_source_types and idx < len(
                    input_data.exclude_source_types
                ):
                    src_type = input_data.exclude_source_types[idx]
                exclude_list.append({"source": src_type, "id": src_id})
            payload["exclude"] = exclude_list

        # Add scope sources
        if input_data.scope_source_ids:
            scope_list: list[dict[str, Any]] = []
            for idx, src_id in enumerate(input_data.scope_source_ids):
                scope_item: dict[str, Any] = {"source": "import", "id": src_id}

                if input_data.scope_source_types and idx < len(
                    input_data.scope_source_types
                ):
                    scope_item["source"] = input_data.scope_source_types[idx]

                # Add relationship if provided
                if input_data.scope_relationships and idx < len(
                    input_data.scope_relationships
                ):
                    relationship: dict[str, Any] = {
                        "definition": input_data.scope_relationships[idx]
                    }
                    if input_data.scope_relationship_limits and idx < len(
                        input_data.scope_relationship_limits
                    ):
                        relationship["limit"] = input_data.scope_relationship_limits[
                            idx
                        ]
                    scope_item["relationship"] = relationship

                scope_list.append(scope_item)
            payload["scope"] = scope_list

        # Add metadata if provided
        if input_data.metadata:
            payload["metadata"] = input_data.metadata

        start_time = time.time()

        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        sdk_search = aexa.websets.searches.create(
            webset_id=input_data.webset_id, params=payload
        )

        search_id = sdk_search.id
        status = (
            sdk_search.status.value
            if hasattr(sdk_search.status, "value")
            else str(sdk_search.status)
        )

        # Extract expected results from recall
        expected_results = {}
        if sdk_search.recall:
            recall_dict = sdk_search.recall.model_dump(by_alias=True)
            expected = recall_dict.get("expected", {})
            expected_results = {
                "total": expected.get("total", 0),
                "confidence": expected.get("confidence", ""),
                "min": expected.get("bounds", {}).get("min", 0),
                "max": expected.get("bounds", {}).get("max", 0),
                "reasoning": recall_dict.get("reasoning", ""),
            }

        # If wait_for_completion is True, poll for completion
        if input_data.wait_for_completion:
            import asyncio

            poll_interval = 5
            max_interval = 30
            poll_start = time.time()

            while time.time() - poll_start < input_data.polling_timeout:
                current_search = aexa.websets.searches.get(
                    webset_id=input_data.webset_id, id=search_id
                )
                current_status = (
                    current_search.status.value
                    if hasattr(current_search.status, "value")
                    else str(current_search.status)
                )

                if current_status in ["completed", "failed", "cancelled"]:
                    items_found = 0
                    if current_search.progress:
                        items_found = current_search.progress.found
                    completion_time = time.time() - start_time

                    yield "search_id", search_id
                    yield "webset_id", input_data.webset_id
                    yield "status", current_status
                    yield "query", input_data.query
                    yield "expected_results", expected_results
                    yield "items_found", items_found
                    yield "completion_time", completion_time
                    return

                await asyncio.sleep(poll_interval)
                poll_interval = min(poll_interval * 1.5, max_interval)

            # Timeout - yield what we have
            yield "search_id", search_id
            yield "webset_id", input_data.webset_id
            yield "status", status
            yield "query", input_data.query
            yield "expected_results", expected_results
            yield "items_found", 0
            yield "completion_time", time.time() - start_time
        else:
            yield "search_id", search_id
            yield "webset_id", input_data.webset_id
            yield "status", status
            yield "query", input_data.query
            yield "expected_results", expected_results


class ExaGetWebsetSearchBlock(Block):
    """Get the status and details of a webset search."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset",
            placeholder="webset-id-or-external-id",
        )
        search_id: str = SchemaField(
            description="The ID of the search to retrieve",
            placeholder="search-id",
        )

    class Output(BlockSchemaOutput):
        search_id: str = SchemaField(description="The unique identifier for the search")
        status: str = SchemaField(description="Current status of the search")
        query: str = SchemaField(description="The search query")
        entity_type: str = SchemaField(description="Type of entity being searched")
        criteria: list[dict] = SchemaField(description="Criteria used for verification")
        progress: dict = SchemaField(description="Search progress information")
        recall: dict = SchemaField(description="Recall estimation information")
        created_at: str = SchemaField(description="When the search was created")
        updated_at: str = SchemaField(description="When the search was last updated")
        canceled_at: Optional[str] = SchemaField(
            description="When the search was canceled (if applicable)"
        )
        canceled_reason: Optional[str] = SchemaField(
            description="Reason for cancellation (if applicable)"
        )
        metadata: dict = SchemaField(description="Metadata attached to the search")

    def __init__(self):
        super().__init__(
            id="4fa3e627-a0ff-485f-8732-52148051646c",
            description="Get the status and details of a webset search",
            categories={BlockCategory.SEARCH},
            input_schema=ExaGetWebsetSearchBlock.Input,
            output_schema=ExaGetWebsetSearchBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        sdk_search = aexa.websets.searches.get(
            webset_id=input_data.webset_id, id=input_data.search_id
        )

        search = WebsetSearchModel.from_sdk(sdk_search)

        # Extract progress information
        progress_info = {
            "found": search.progress.get("found", 0),
            "analyzed": search.progress.get("analyzed", 0),
            "completion": search.progress.get("completion", 0),
            "time_left": search.progress.get("timeLeft", 0),
        }

        # Extract recall information
        recall_data = {}
        if search.recall:
            expected = search.recall.get("expected", {})
            recall_data = {
                "expected_total": expected.get("total", 0),
                "confidence": expected.get("confidence", ""),
                "min_expected": expected.get("bounds", {}).get("min", 0),
                "max_expected": expected.get("bounds", {}).get("max", 0),
                "reasoning": search.recall.get("reasoning", ""),
            }

        yield "search_id", search.id
        yield "status", search.status
        yield "query", search.query
        yield "entity_type", search.entity_type
        yield "criteria", search.criteria
        yield "progress", progress_info
        yield "recall", recall_data
        yield "created_at", search.created_at
        yield "updated_at", search.updated_at
        yield "canceled_at", search.canceled_at
        yield "canceled_reason", search.canceled_reason
        yield "metadata", search.metadata


class ExaCancelWebsetSearchBlock(Block):
    """Cancel a running webset search."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset",
            placeholder="webset-id-or-external-id",
        )
        search_id: str = SchemaField(
            description="The ID of the search to cancel",
            placeholder="search-id",
        )

    class Output(BlockSchemaOutput):
        search_id: str = SchemaField(description="The ID of the canceled search")
        status: str = SchemaField(description="Status after cancellation")
        items_found_before_cancel: int = SchemaField(
            description="Number of items found before cancellation"
        )
        success: str = SchemaField(
            description="Whether the cancellation was successful"
        )

    def __init__(self):
        super().__init__(
            id="74ef9f1e-ae89-4c7f-9d7d-d217214815b4",
            description="Cancel a running webset search",
            categories={BlockCategory.SEARCH},
            input_schema=ExaCancelWebsetSearchBlock.Input,
            output_schema=ExaCancelWebsetSearchBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        canceled_search = aexa.websets.searches.cancel(
            webset_id=input_data.webset_id, id=input_data.search_id
        )

        # Extract items found before cancellation
        items_found = 0
        if canceled_search.progress:
            items_found = canceled_search.progress.found

        status = (
            canceled_search.status.value
            if hasattr(canceled_search.status, "value")
            else str(canceled_search.status)
        )

        yield "search_id", canceled_search.id
        yield "status", status
        yield "items_found_before_cancel", items_found
        yield "success", "true"


class ExaFindOrCreateSearchBlock(Block):
    """Find existing search by query or create new one (prevents duplicate searches)."""

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset",
            placeholder="webset-id-or-external-id",
        )
        query: str = SchemaField(
            description="Search query to find or create",
            placeholder="AI companies in San Francisco",
        )
        count: int = SchemaField(
            default=10,
            description="Number of items to find (only used if creating new search)",
            ge=1,
            le=1000,
        )
        entity_type: SearchEntityType = SchemaField(
            default=SearchEntityType.AUTO,
            description="Entity type (only used if creating)",
            advanced=True,
        )
        behavior: SearchBehavior = SchemaField(
            default=SearchBehavior.OVERRIDE,
            description="Search behavior (only used if creating)",
            advanced=True,
        )

    class Output(BlockSchemaOutput):
        search_id: str = SchemaField(description="The search ID (existing or new)")
        webset_id: str = SchemaField(description="The webset ID")
        status: str = SchemaField(description="Current search status")
        query: str = SchemaField(description="The search query")
        was_created: bool = SchemaField(
            description="True if search was newly created, False if already existed"
        )
        items_found: int = SchemaField(
            description="Number of items found (0 if still running)"
        )

    def __init__(self):
        super().__init__(
            id="cbdb05ac-cb73-4b03-a493-6d34e9a011da",
            description="Find existing search by query or create new - prevents duplicate searches in workflows",
            categories={BlockCategory.SEARCH},
            input_schema=ExaFindOrCreateSearchBlock.Input,
            output_schema=ExaFindOrCreateSearchBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        # Use AsyncExa SDK
        aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())

        # Get webset to check existing searches
        webset = aexa.websets.get(id=input_data.webset_id)

        # Look for existing search with same query
        existing_search = None
        if webset.searches:
            for search in webset.searches:
                if search.query.strip().lower() == input_data.query.strip().lower():
                    existing_search = search
                    break

        if existing_search:
            # Found existing search
            search = WebsetSearchModel.from_sdk(existing_search)

            yield "search_id", search.id
            yield "webset_id", input_data.webset_id
            yield "status", search.status
            yield "query", search.query
            yield "was_created", False
            yield "items_found", search.progress.get("found", 0)
        else:
            # Create new search
            payload: Dict[str, Any] = {
                "query": input_data.query,
                "count": input_data.count,
                "behavior": input_data.behavior.value,
            }

            # Add entity if not auto
            if input_data.entity_type != SearchEntityType.AUTO:
                payload["entity"] = {"type": input_data.entity_type.value}

            sdk_search = aexa.websets.searches.create(
                webset_id=input_data.webset_id, params=payload
            )

            search = WebsetSearchModel.from_sdk(sdk_search)

            yield "search_id", search.id
            yield "webset_id", input_data.webset_id
            yield "status", search.status
            yield "query", search.query
            yield "was_created", True
            yield "items_found", 0  # Newly created, no items yet

@@ -10,13 +10,7 @@ from backend.blocks.fal._auth import (
    FalCredentialsField,
    FalCredentialsInput,
)
from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import ClientResponseError, Requests

@@ -30,7 +24,7 @@ class FalModel(str, Enum):


class AIVideoGeneratorBlock(Block):
    class Input(BlockSchemaInput):
    class Input(BlockSchema):
        prompt: str = SchemaField(
            description="Description of the video to generate.",
            placeholder="A dog running in a field.",
@@ -42,7 +36,7 @@ class AIVideoGeneratorBlock(Block):
        )
        credentials: FalCredentialsInput = FalCredentialsField()

    class Output(BlockSchemaOutput):
    class Output(BlockSchema):
        video_url: str = SchemaField(description="The URL of the generated video.")
        error: str = SchemaField(
            description="Error message if video generation failed."

@@ -1,12 -0,0 @@
from enum import Enum


class ScrapeFormat(Enum):
    MARKDOWN = "markdown"
    HTML = "html"
    RAW_HTML = "rawHtml"
    LINKS = "links"
    SCREENSHOT = "screenshot"
    SCREENSHOT_FULL_PAGE = "screenshot@fullPage"
    JSON = "json"
    CHANGE_TRACKING = "changeTracking"
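For reference, a string-valued enum like the deleted `ScrapeFormat` above is usually unwrapped with `.value` when a request payload is assembled. The short sketch below illustrates that pattern only; the `build_scrape_payload` helper, the `formats` key, and the trimmed enum are assumptions for illustration, not code from this repository.

```python
from enum import Enum


class ScrapeFormat(Enum):
    MARKDOWN = "markdown"
    HTML = "html"
    LINKS = "links"


def build_scrape_payload(url: str, formats: list[ScrapeFormat]) -> dict:
    # .value unwraps each enum member into the wire-format string an API would expect
    return {"url": url, "formats": [f.value for f in formats]}


# Example:
# build_scrape_payload("https://example.com", [ScrapeFormat.MARKDOWN, ScrapeFormat.LINKS])
# -> {"url": "https://example.com", "formats": ["markdown", "links"]}
```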