Compare commits


127 Commits

Author SHA1 Message Date
Abhimanyu Yadav
3f917327b2 Merge branch 'dev' into redesigning-block-menu 2025-06-18 15:12:11 +05:30
Abhimanyu Yadav
53c92be68f Remove unused imports 2025-06-18 15:07:04 +05:30
Abhimanyu Yadav
7d0c4b9f7f Update pnpm-lock.yaml 2025-06-17 22:31:52 +05:30
abhi1992002
8902ecb3ae Merge branch 'dev' into redesigning-block-menu 2025-06-17 22:30:53 +05:30
Abhimanyu Yadav
845a6ab38b refactor(block-menu): remove commented-out code for recent searches in SuggestionContent to improve code clarity 2025-06-17 22:18:57 +05:30
Abhimanyu Yadav
79afd83919 refactor(block-menu): replace hardcoded scrollbar styles with centralized scrollbarStyles import for consistency across components 2025-06-17 22:15:50 +05:30
Abhimanyu Yadav
be8144b305 refactor(block-menu): optimize key assignment in IntegrationBlocks to use block ID for improved rendering 2025-06-17 21:59:27 +05:30
Abhimanyu Yadav
e63575877d refactor(block-menu): consolidate block content components into PaginatedBlocksContent for improved maintainability 2025-06-17 21:57:18 +05:30
Abhimanyu Yadav
9206d24017 refactor(block-menu): remove React.FC type annotation from block menu components for consistency 2025-06-17 21:54:45 +05:30
Abhimanyu Yadav
b2ab2602fe refactor(block-menu): optimize block key assignment in AllBlocksContent for improved rendering performance 2025-06-17 21:50:58 +05:30
Abhimanyu Yadav
aa4de454b2 refactor(block-menu): simplify useEffect dependencies in AllBlocksContent, IntegrationBlocks, and SuggestionContent components 2025-06-17 21:48:57 +05:30
Abhimanyu Yadav
9ea44b6267 fix format 2025-06-12 07:58:35 +05:30
Abhimanyu Yadav
3cd214d0d4 remove unused state and simplify hover behavior in FilterChip component 2025-06-12 07:55:11 +05:30
Abhimanyu Yadav
04d30efc5d refactor(block-menu): update button styles across components to improve disabled state visibility 2025-06-12 07:50:44 +05:30
Abhimanyu Yadav
9157388723 refactor(block-menu): enhance text highlighting functionality in IntegrationBlock by escaping special characters 2025-06-12 07:41:08 +05:30
Abhimanyu Yadav
455f273ccf refactor(block-menu): replace hardcoded filter defaults with getDefaultFilters utility for consistency 2025-06-12 07:34:38 +05:30
Abhimanyu Yadav
382598f2be refactor(block-menu): improve error message in useBlockMenuContext for clarity on provider usage 2025-06-12 07:28:57 +05:30
Abhimanyu Yadav
79b6a56b56 refactor(block-menu): adjust conditional rendering for image display in UGCAgentBlock component 2025-06-12 07:28:05 +05:30
Abhimanyu Yadav
68cec8b2e7 refactor(block-menu): enhance error handling in ErrorState component by introducing parseErrorMessage utility 2025-06-12 07:14:41 +05:30
Abhimanyu Yadav
b921edb062 refactor(block-menu): update ControlPanelButton styles for improved clarity and consistency 2025-06-12 07:09:13 +05:30
Abhimanyu Yadav
b7408415df refactor(block-menu): implement search debounce and update ControlPanelButton imports for consistency 2025-06-12 06:39:50 +05:30
Abhimanyu Yadav
59752054fa refactor(block-menu): export components in various files for improved modularity and consistency 2025-06-12 06:33:01 +05:30
Abhimanyu Yadav
478f31141d refactor(block-menu): export components in Block, BlockMenu, BlockMenuContent, and related files for improved modularity 2025-06-12 06:32:44 +05:30
Abhimanyu Yadav
5c264c253c refactor(block-menu): simplify callback functions in BlockMenu, BlockMenuContent, and FilterSheet components 2025-06-12 05:44:08 +05:30
Abhimanyu Yadav
d6d4703bbc refactor(block-menu): simplify conditional rendering for title and description in Block, Integration, IntegrationBlock, IntegrationChip, MarketplaceAgentBlock, and UGCAgentBlock components 2025-06-11 19:23:59 +05:30
Abhimanyu Yadav
0b602600cb refactor(styles): remove unused scroll-container styles from globals.css 2025-06-11 19:01:27 +05:30
Abhimanyu Yadav
19382072b1 chore: update integration images with compressed ones 2025-06-11 18:59:05 +05:30
Abhimanyu Yadav
3e2b388df0 Merge branch 'dev' into redesigning-block-menu 2025-06-09 21:29:30 +05:30
Abhimanyu Yadav
a50532a975 Merge branch 'dev' into redesigning-block-menu 2025-06-09 16:40:30 +05:30
Abhimanyu Yadav
27e53aa3dd Comment out monitor test suite 2025-06-09 10:45:25 +05:30
Abhimanyu Yadav
a24673d15f Comment out build.spec.ts test file 2025-06-09 10:34:58 +05:30
Abhimanyu Yadav
9d2d9606e8 Merge branch 'dev' into redesigning-block-menu 2025-06-09 10:17:38 +05:30
Abhimanyu Yadav
91407dfc33 Add expandable creator list in filter sheet menu 2025-06-06 18:34:00 +05:30
abhi1992002
851919d2d5 Merge branch 'dev' into redesigning-block-menu 2025-06-06 18:23:17 +05:30
Abhimanyu Yadav
d6acb02cb6 Merge branch 'dev' into redesigning-block-menu 2025-06-06 18:20:05 +05:30
Krzysztof Czerwinski
9c07206725 Merge branch 'redesigning-block-menu' into kpczerwinski/secrt-1320-backend-update 2025-06-06 14:43:17 +02:00
Krzysztof Czerwinski
4bd3447301 Cleanup and comments 2025-06-06 14:39:13 +02:00
Abhimanyu Yadav
8adc9f967d Remove commented TODO and clean up code formatting 2025-06-06 17:11:40 +05:30
Abhimanyu Yadav
349b70c4bc Remove unused imports and cleanup effects 2025-06-06 17:07:59 +05:30
Abhimanyu Yadav
9ecfa1e1f1 Add scrolling and fixed footer to filter sheet panel 2025-06-06 17:03:24 +05:30
Krzysztof Czerwinski
4e17f9c49e Include block costs in get_blocks 2025-06-06 13:05:44 +02:00
Krzysztof Czerwinski
31fdeeb706 Make agent_name optional 2025-06-06 13:03:32 +02:00
Abhimanyu Yadav
e42b24c029 Add hasLocalActiveFilters for applying filter state 2025-06-06 11:10:47 +05:30
Abhimanyu Yadav
2d52a57a21 Fix "Agent page" link propagation in menu block 2025-06-06 10:53:55 +05:30
Krzysztof Czerwinski
f45123f6b6 Fix Agent Executor block name 2025-06-05 17:00:40 +02:00
Abhimanyu Yadav
d524518f41 Update pnpm-lock.yaml 2025-06-05 18:32:07 +05:30
abhi1992002
81d1b28d92 Merge remote-tracking branch 'upstream/dev' into redesigning-block-menu 2025-06-05 18:30:19 +05:30
Abhimanyu Yadav
4e4e754ac1 Merge branch 'backend-temp' into redesigning-block-menu 2025-06-05 17:00:08 +05:30
Krzysztof Czerwinski
e409d7aa34 Fixes 2025-06-04 14:58:49 +02:00
Abhimanyu Yadav
8312a339c2 cleaning up frontend code 2025-06-04 11:22:46 +05:30
Abhimanyu Yadav
5b45d246ef fix blockType utils 2025-06-04 10:46:45 +05:30
Krzysztof Czerwinski
5c7c7ca874 Suggested blocks 2025-06-03 19:27:02 +02:00
Krzysztof Czerwinski
c93c5e35ba Merge branch 'redesigning-block-menu' into kpczerwinski/secrt-1320-backend-update 2025-06-03 11:45:34 +02:00
Abhimanyu Yadav
ce989b1bf7 remove providers from filter list and add support of ai blocks in search
list]
2025-06-03 10:46:25 +05:30
Abhimanyu Yadav
c1c919b88b Merge branch 'backend-temp' into redesigning-block-menu 2025-06-03 10:22:01 +05:30
Krzysztof Czerwinski
21a91fe9fd Merge branch 'redesigning-block-menu' into kpczerwinski/secrt-1320-backend-update 2025-06-02 15:07:16 +02:00
Krzysztof Czerwinski
b2f3d8c1f2 Search model names 2025-06-02 15:06:54 +02:00
Krzysztof Czerwinski
46ab2e3b20 Remove providers filter from search 2025-06-02 10:13:05 +02:00
Abhimanyu Yadav
5b40700299 fetching creator list from searchList
Moves the `getBlockType` function from the SearchList component to the
`utils.ts` file to make it more reusable. Also removes the unused
`creators` state and `setCreators` function from the
BlockMenuContext and instead calculates the creators list dynamically
within the FilterSheet component based on the available search data.
2025-06-02 13:07:35 +05:30
Abhimanyu Yadav
1a97020eeb fix marketplace agent block and libray agent block in searchList 2025-06-02 12:50:07 +05:30
Abhimanyu Yadav
39d03f2090 Add Marketplace Agents to builder
Adds functionality to add Marketplace agents to the user's library and then to builder.
Includes a loading indicator while the agent is being added.
Refactors agent-to-block conversion into a utility function.
2025-06-02 12:32:13 +05:30
Abhimanyu Yadav
8088d294f4 Add Agent Blocks to Flow
This commit adds the ability to add Agent blocks to the
flow.  Clicking on an agent in the My Agents menu will add
it to the flow.  The block includes the necessary
information such as input/output schemas.
2025-06-02 12:03:51 +05:30
Abhimanyu Yadav
31266949ed Clears all filters when the search input is cleared and redesign filter based on new design. 2025-06-02 11:34:01 +05:30
abhi1992002
f4eb00a6ad Fetch Block Counts in Block Menu
Adds API calls to fetch block counts for each category
in the block menu and displays them next to the category
name.  This replaces the hardcoded numbers previously
displayed.
2025-06-02 10:50:26 +05:30
Abhimanyu Yadav
f75cc0dd11 Merge branch 'dev' into redesigning-block-menu 2025-06-02 10:34:16 +05:30
Krzysztof Czerwinski
21b612625f Format frontend 2025-05-31 13:42:32 +02:00
Krzysztof Czerwinski
eec0d276d5 Add output_schema to LibraryAgent 2025-05-31 13:42:07 +02:00
Krzysztof Czerwinski
c6941e7f6e Merge branch 'redesigning-block-menu' into kpczerwinski/secrt-1320-backend-update 2025-05-31 12:49:36 +02:00
Abhimanyu Yadav
325684a10f remove recent searches from suggestionContent and done some cleanup as
well
2025-05-30 17:33:44 +05:30
Abhimanyu Yadav
cf057cbbda fixed pagination problem in default menus in block menu 2025-05-30 17:26:59 +05:30
Abhimanyu Yadav
f3a7be1fd3 add highlighted description while searching 2025-05-30 12:06:36 +05:30
Abhimanyu Yadav
97bcb0f95e fix searchlist pagination 2025-05-30 11:11:40 +05:30
Abhimanyu Yadav
dd71d65706 adding beautify String in integration chips 2025-05-30 10:57:08 +05:30
Abhimanyu Yadav
2b2d26bcde remove items expanding when selecting menus 2025-05-30 10:54:56 +05:30
Abhimanyu Yadav
67f6f43e1b fix error state layout in input/output/action blocks list 2025-05-30 10:43:03 +05:30
Abhimanyu Yadav
a3409c9578 fix MarketplaceAgentBlock layout 2025-05-30 10:29:53 +05:30
Abhimanyu Yadav
7f82457ea4 add external agent link to marketplace agent block 2025-05-30 10:25:40 +05:30
Abhimanyu Yadav
a5c0fabc00 fix design of clear button in searchMenuBar 2025-05-30 10:12:09 +05:30
Krzysztof Czerwinski
09dba93a4a Add counts endpoint 2025-05-29 16:17:33 +02:00
Krzysztof Czerwinski
ea2cd3e7bf Merge branch 'redesigning-block-menu' into kpczerwinski/secrt-1320-backend-update 2025-05-29 13:16:18 +02:00
Abhimanyu Yadav
d3d0ccf732 fix menu item hover state and add a clear button at the end of searchbar 2025-05-29 11:09:55 +05:30
Abhimanyu Yadav
d8d5d6ec0c make hover state correct on all reusable compoents in block menu 2025-05-29 11:01:09 +05:30
Abhimanyu Yadav
f45b09c0b5 fix hover state and heading text in suggestion content page 2025-05-29 10:57:08 +05:30
Abhimanyu Yadav
1e89b6d3a4 add beautifyString in block, integration and integration block 2025-05-29 10:31:10 +05:30
Abhimanyu Yadav
950a85e179 fix image sizes warning with fill 2025-05-28 17:40:35 +05:30
Abhimanyu Yadav
c5e3148145 add better error handling in all components 2025-05-28 17:27:07 +05:30
Abhimanyu Yadav
a135ba3f0b refactor addBlock implementation in flow.tsx 2025-05-28 15:58:31 +05:30
Abhimanyu Yadav
fe95e27226 only show scroller when hovering 2025-05-28 15:48:28 +05:30
Abhimanyu Yadav
711ca10cc9 add relative time in my_agent block using react-timeago library 2025-05-28 15:29:34 +05:30
Abhimanyu Yadav
1346d8230c add 500ms debouncer on searchbar 2025-05-28 15:19:56 +05:30
Abhimanyu Yadav
07c84a4757 add categories filter in search 2025-05-28 13:57:39 +05:30
Abhimanyu Yadav
596824c1e7 add pagination on search list 2025-05-28 13:28:32 +05:30
Abhimanyu Yadav
79afa6db99 add search functioanlity in block menu 2025-05-28 12:15:38 +05:30
Abhimanyu Yadav
e034c16f31 add pagination in all components in default state 2025-05-26 21:13:51 +05:30
Abhimanyu Yadav
9012eff1ac add basic data fetching in all default state components 2025-05-26 10:27:15 +05:30
Abhimanyu Yadav
0361ea4aa4 connection integration list and blocks 2025-05-26 00:25:30 +05:30
Abhimanyu Yadav
6f1c522ea3 add some images and connect suggestion content frontend with backend 2025-05-25 23:09:22 +05:30
Krzysztof Czerwinski
2d654bf64b Update frontend types and api client 2025-05-25 15:12:01 +02:00
Krzysztof Czerwinski
bb69e32fee Update backend 2025-05-25 15:11:29 +02:00
Krzysztof Czerwinski
1be830835b Update signatures, disable providers 2025-05-23 17:22:35 +02:00
Krzysztof Czerwinski
a2a4d546f7 Merge branch 'redesigning-block-menu' into kpczerwinski/secrt-1320-backend-update 2025-05-23 16:51:53 +02:00
Krzysztof Czerwinski
3053a7bd06 Add types and function on the frontend 2025-05-23 16:50:52 +02:00
Krzysztof Czerwinski
bbf4108136 Add builder router and get_blocks endpoint 2025-05-23 16:50:05 +02:00
Krzysztof Czerwinski
95387bcf78 Add model and functions 2025-05-23 16:48:34 +02:00
Abhimanyu Yadav
e1fc56e6f3 fix small optimisation and DX issue 2025-05-21 18:10:29 +05:30
Abhimanyu Yadav
2a06956802 fix max width in sidebar 2025-05-20 17:01:03 +05:30
Abhimanyu Yadav
32231ff80f Implement search text highlighting in Block components, add transitions
to FilterChip, and create NoSearchResult component for empty searches. Move
SearchItem types to provider context for better access.
2025-05-20 15:31:02 +05:30
Abhimanyu Yadav
d0b23c085f add context api for block menu 2025-05-20 11:58:45 +05:30
Abhimanyu Yadav
e718d3d3d8 fix filter sheets 2025-05-20 11:25:46 +05:30
Abhimanyu Yadav
1971a62684 fix checkbox tick design 2025-05-20 10:38:51 +05:30
Abhimanyu Yadav
e125b5923c fix width of left sidebar 2025-05-20 10:30:18 +05:30
Abhimanyu Yadav
c6942e4e6f prevent layout shift when clicking result elements with border 2025-05-20 10:19:02 +05:30
Abhimanyu Yadav
c9e421a219 Merge branch 'dev' into redesigning-block-menu 2025-05-19 22:27:23 +05:30
Abhimanyu Yadav
7868373897 fix comments 2025-05-19 17:06:17 +05:30
Abhimanyu Yadav
f1c8399e0e fix recent searches onClick 2025-05-19 16:55:59 +05:30
Abhimanyu Yadav
97ba69ef1c fix lint 2025-05-19 16:35:55 +05:30
Abhimanyu Yadav
773e1488bf add filter sheet 2025-05-19 16:34:26 +05:30
Abhimanyu Yadav
4273be59ba fix format 2025-05-19 15:37:13 +05:30
Abhimanyu Yadav
06e524788a fix format 2025-05-18 21:10:26 +05:30
Abhimanyu Yadav
bc08012771 add search list in block menu 2025-05-18 21:10:19 +05:30
Abhimanyu Yadav
4af0aedebd fix format 2025-05-18 17:16:45 +05:30
Abhimanyu Yadav
d22464a75e Add skeleton components and loading states 2025-05-18 17:16:08 +05:30
Abhimanyu Yadav
82e3a485f0 complete frontend design for default state 2025-05-18 10:19:25 +05:30
Abhimanyu Yadav
8165ad5879 fix scrollbar in default content 2025-05-18 08:52:20 +05:30
Abhimanyu Yadav
451284de76 Add tailwind-scrollbar-hide and implement block menu UI
The commit adds a new block menu UI component with sidebar navigation,
integration chips, and scrollable content areas. It includes tailwind-
scrollbar-hide for better UI experience and custom CSS for scroll
containers. The implementation features different content sections
for blocks categorized by type (input, action, output) and supports
search functionality.
2025-05-17 21:18:08 +05:30
Abhimanyu Yadav
1d8c7c5e1a Merge branch 'dev' into redesigning-block-menu 2025-05-17 00:13:52 +05:30
Abhimanyu Yadav
34be6a3379 creating small ui reusable component 2025-05-17 00:01:40 +05:30
933 changed files with 31834 additions and 88918 deletions

View File

@@ -9,13 +9,11 @@
# Platform - Backend
!autogpt_platform/backend/backend/
!autogpt_platform/backend/test/e2e_test_data.py
!autogpt_platform/backend/migrations/
!autogpt_platform/backend/schema.prisma
!autogpt_platform/backend/pyproject.toml
!autogpt_platform/backend/poetry.lock
!autogpt_platform/backend/README.md
!autogpt_platform/backend/.env
# Platform - Market
!autogpt_platform/market/market/
@@ -28,7 +26,6 @@
# Platform - Frontend
!autogpt_platform/frontend/src/
!autogpt_platform/frontend/public/
!autogpt_platform/frontend/scripts/
!autogpt_platform/frontend/package.json
!autogpt_platform/frontend/pnpm-lock.yaml
!autogpt_platform/frontend/tsconfig.json
@@ -36,7 +33,6 @@
## config
!autogpt_platform/frontend/*.config.*
!autogpt_platform/frontend/.env.*
!autogpt_platform/frontend/.env
# Classic - AutoGPT
!classic/original_autogpt/autogpt/

View File

@@ -24,8 +24,7 @@
</details>
#### For configuration changes:
- [ ] `.env.default` is updated or already compatible with my changes
- [ ] `.env.example` is updated or already compatible with my changes
- [ ] `docker-compose.yml` is updated or already compatible with my changes
- [ ] I have included a list of my configuration changes in the PR description (under **Changes**)

View File

@@ -1,244 +0,0 @@
# GitHub Copilot Instructions for AutoGPT
This file provides comprehensive onboarding information for GitHub Copilot coding agent to work efficiently with the AutoGPT repository.
## Repository Overview
**AutoGPT** is a powerful platform for creating, deploying, and managing continuous AI agents that automate complex workflows. This is a large monorepo (~150MB) containing multiple components:
- **AutoGPT Platform** (`autogpt_platform/`) - Main focus: Modern AI agent platform (Polyform Shield License)
- **Classic AutoGPT** (`classic/`) - Legacy agent system (MIT License)
- **Documentation** (`docs/`) - MkDocs-based documentation site
- **Infrastructure** - Docker configurations, CI/CD, and development tools
**Primary Languages & Frameworks:**
- **Backend**: Python 3.10-3.13, FastAPI, Prisma ORM, PostgreSQL, RabbitMQ
- **Frontend**: TypeScript, Next.js 15, React, Tailwind CSS, Radix UI
- **Development**: Docker, Poetry, pnpm, Playwright, Storybook
## Build and Validation Instructions
### Essential Setup Commands
**Always run these commands in the correct directory and in this order:**
1. **Initial Setup** (required once):
```bash
# Clone and enter repository
git clone <repo> && cd AutoGPT
# Start all services (database, redis, rabbitmq, clamav)
cd autogpt_platform && docker compose --profile local up deps --build --detach
```
2. **Backend Setup** (always run before backend development):
```bash
cd autogpt_platform/backend
poetry install # Install dependencies
poetry run prisma migrate dev # Run database migrations
poetry run prisma generate # Generate Prisma client
```
3. **Frontend Setup** (always run before frontend development):
```bash
cd autogpt_platform/frontend
pnpm install # Install dependencies
```
### Runtime Requirements
**Critical:** Always ensure Docker services are running before starting development:
```bash
cd autogpt_platform && docker compose --profile local up deps --build --detach
```
**Python Version:** Use Python 3.11 (required; managed by Poetry via pyproject.toml)
**Node.js Version:** Use Node.js 21+ with pnpm package manager
### Development Commands
**Backend Development:**
```bash
cd autogpt_platform/backend
poetry run serve # Start development server (port 8000)
poetry run test # Run all tests (requires ~5 minutes)
poetry run pytest path/to/test.py # Run specific test
poetry run format # Format code (Black + isort) - always run first
poetry run lint # Lint code (ruff) - run after format
```
**Frontend Development:**
```bash
cd autogpt_platform/frontend
pnpm dev # Start development server (port 3000) - use for active development
pnpm build # Build for production (only needed for E2E tests or deployment)
pnpm test # Run Playwright E2E tests (requires build first)
pnpm test-ui # Run tests with UI
pnpm format # Format and lint code
pnpm storybook # Start component development server
```
### Testing Strategy
**Backend Tests:**
- **Block Tests**: `poetry run pytest backend/blocks/test/test_block.py -xvs` (validates all blocks)
- **Specific Block**: `poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[BlockName]' -xvs`
- **Snapshot Tests**: Use `--snapshot-update` when output changes, always review with `git diff`
**Frontend Tests:**
- **E2E Tests**: Always run `pnpm dev` before `pnpm test` (Playwright requires running instance)
- **Component Tests**: Use Storybook for isolated component development
### Critical Validation Steps
**Before committing changes:**
1. Run `poetry run format` (backend) and `pnpm format` (frontend)
2. Ensure all tests pass in modified areas
3. Verify Docker services are still running
4. Check that database migrations apply cleanly
**Common Issues & Workarounds:**
- **Prisma issues**: Run `poetry run prisma generate` after schema changes
- **Permission errors**: Ensure Docker has proper permissions
- **Port conflicts**: Check the `docker-compose.yml` file for the current list of exposed ports. You can also list all mapped ports with a command such as `docker compose ps`.
- **Test timeouts**: Backend tests can take 5+ minutes, use `-x` flag to stop on first failure
## Project Layout & Architecture
### Core Architecture
**AutoGPT Platform** (`autogpt_platform/`):
- `backend/` - FastAPI server with async support
- `backend/backend/` - Core API logic
- `backend/blocks/` - Agent execution blocks
- `backend/data/` - Database models and schemas
- `schema.prisma` - Database schema definition
- `frontend/` - Next.js application
- `src/app/` - App Router pages and layouts
- `src/components/` - Reusable React components
- `src/lib/` - Utilities and configurations
- `autogpt_libs/` - Shared Python utilities
- `docker-compose.yml` - Development stack orchestration
**Key Configuration Files:**
- `pyproject.toml` - Python dependencies and tooling
- `package.json` - Node.js dependencies and scripts
- `schema.prisma` - Database schema and migrations
- `next.config.mjs` - Next.js configuration
- `tailwind.config.ts` - Styling configuration
### Security & Middleware
**Cache Protection**: Backend includes middleware preventing sensitive data caching in browsers/proxies
**Authentication**: JWT-based with Supabase integration
**User ID Validation**: All data access requires user ID checks - verify this for any `data/*.py` changes
### Development Workflow
**GitHub Actions**: Multiple CI/CD workflows in `.github/workflows/`
- `platform-backend-ci.yml` - Backend testing and validation
- `platform-frontend-ci.yml` - Frontend testing and validation
- `platform-fullstack-ci.yml` - End-to-end integration tests
**Pre-commit Hooks**: Run linting and formatting checks
**Conventional Commits**: Use format `type(scope): description` (e.g., `feat(backend): add API`)
### Key Source Files
**Backend Entry Points:**
- `backend/backend/server/server.py` - FastAPI application setup
- `backend/backend/data/` - Database models and user management
- `backend/blocks/` - Agent execution blocks and logic
**Frontend Entry Points:**
- `frontend/src/app/layout.tsx` - Root application layout
- `frontend/src/app/page.tsx` - Home page
- `frontend/src/lib/supabase/` - Authentication and database client
**Protected Routes**: Update `frontend/lib/supabase/middleware.ts` when adding protected routes
### Agent Block System
Agents are built using a visual block-based system where each block performs a single action. Blocks are defined in `backend/blocks/` and must include:
- Block definition with input/output schemas
- Execution logic with proper error handling
- Tests validating functionality
### Database & ORM
**Prisma ORM** with PostgreSQL backend including pgvector for embeddings:
- Schema in `schema.prisma`
- Migrations in `backend/migrations/`
- Always run `prisma migrate dev` and `prisma generate` after schema changes
## Environment Configuration
### Configuration Files Priority Order
1. **Backend**: `/backend/.env.default` → `/backend/.env` (user overrides)
2. **Frontend**: `/frontend/.env.default` → `/frontend/.env` (user overrides)
3. **Platform**: `/.env.default` (Supabase/shared) → `/.env` (user overrides)
4. Docker Compose `environment:` sections override file-based config
5. Shell environment variables have highest precedence
### Docker Environment Setup
- All services use hardcoded defaults (no `${VARIABLE}` substitutions)
- The `env_file` directive loads variables INTO containers at runtime
- Backend/Frontend services use YAML anchors for consistent configuration
- Copy `.env.default` files to `.env` for local development customization
## Advanced Development Patterns
### Adding New Blocks
1. Create file in `/backend/backend/blocks/`
2. Inherit from `Block` base class with input/output schemas
3. Implement `run` method with proper error handling
4. Generate block UUID using `uuid.uuid4()`
5. Register in block registry
6. Write tests alongside block implementation
7. Consider how inputs/outputs connect with other blocks in the graph editor
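A minimal sketch of what such a block could look like, assuming the `Block`/`BlockSchema` base classes and the yield-based `run` contract described above; the module path, class names, and `SchemaField` import are illustrative, so check existing files in `backend/blocks/` for the actual API:
```python
# Hypothetical example block; names and imports are assumptions, not verbatim platform API.
from backend.data.block import Block, BlockOutput, BlockSchema  # assumed base classes
from backend.data.model import SchemaField  # assumed field helper


class WordCountBlock(Block):
    class Input(BlockSchema):
        text: str = SchemaField(description="Text to count words in")

    class Output(BlockSchema):
        count: int = SchemaField(description="Number of words found")

    def __init__(self):
        super().__init__(
            # Step 4: a UUID generated once with uuid.uuid4() and hardcoded here
            id="11111111-2222-3333-4444-555555555555",
            description="Counts the words in a piece of text",
            input_schema=WordCountBlock.Input,
            output_schema=WordCountBlock.Output,
            test_input={"text": "hello world"},
            test_output=("count", 2),
        )

    def run(self, input_data: Input, **kwargs) -> BlockOutput:
        # One action per block; yield named outputs that match the Output schema
        yield "count", len(input_data.text.split())
```
Whatever the exact signatures, the pattern is the same: typed input/output schemas, a hardcoded UUID, and a `run` method that yields each output by name.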
### API Development
1. Update routes in `/backend/backend/server/routers/`
2. Add/update Pydantic models in same directory
3. Write tests alongside route files
4. For `data/*.py` changes, validate user ID checks
5. Run `poetry run test` to verify changes
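As a rough illustration of that flow (the router filename, `get_user_id` dependency, and data-layer helper are assumptions made up for this example, not the platform's actual identifiers):
```python
# Hypothetical route file in /backend/backend/server/routers/; everything here is illustrative.
from fastapi import APIRouter, Depends
from pydantic import BaseModel

from backend.server.utils import get_user_id  # assumed auth dependency returning the caller's user ID
from backend.data import notes as notes_db  # assumed data-layer module that enforces user ID checks

router = APIRouter(prefix="/notes", tags=["notes"])


class Note(BaseModel):
    id: str
    text: str


@router.get("", response_model=list[Note])
async def list_notes(user_id: str = Depends(get_user_id)) -> list[Note]:
    # Step 4: data access is always scoped to the authenticated user's ID
    return await notes_db.list_notes(user_id=user_id)
```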
### Frontend Development
1. Components in `/frontend/src/components/`
2. Use existing UI components from `/frontend/src/components/ui/`
3. Add Storybook stories for component development
4. Test user-facing features with Playwright E2E tests
5. Update protected routes in middleware when needed
### Security Guidelines
**Cache Protection Middleware** (`/backend/backend/server/middleware/security.py`):
- Default: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses allow list approach for cacheable paths (static assets, health checks, public pages)
- Prevents sensitive data caching in browsers/proxies
- Add new cacheable endpoints to `CACHEABLE_PATHS`
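The mechanism is roughly the following (an illustrative Starlette-style sketch, not the repository's actual implementation):
```python
# Illustrative cache-protection middleware; the real logic lives in
# /backend/backend/server/middleware/security.py and may differ in detail.
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request

CACHEABLE_PATHS = ("/health", "/static")  # example allow list of cacheable prefixes


class CacheControlMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        if not request.url.path.startswith(CACHEABLE_PATHS):
            # Default branch: forbid caching of potentially sensitive responses
            response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate, private"
        return response
```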
### CI/CD Alignment
The repository has comprehensive CI workflows that test:
- **Backend**: Python 3.11-3.13, services (Redis/RabbitMQ/ClamAV), Prisma migrations, Poetry lock validation
- **Frontend**: Node.js 21, pnpm, Playwright with Docker Compose stack, API schema validation
- **Integration**: Full-stack type checking and E2E testing
Match these patterns when developing locally - the copilot setup environment mirrors these CI configurations.
## Collaboration with Other AI Assistants
This repository is actively developed with assistance from Claude (via CLAUDE.md files). When working on this codebase:
- Check for existing CLAUDE.md files that provide additional context
- Follow established patterns and conventions already in the codebase
- Maintain consistency with existing code style and architecture
- Consider that changes may be reviewed and extended by both human developers and AI assistants
## Trust These Instructions
These instructions are comprehensive and tested. Only perform additional searches if:
1. Information here is incomplete for your specific task
2. You encounter errors not covered by the workarounds
3. You need to understand implementation details not covered above
For detailed platform development patterns, refer to `autogpt_platform/CLAUDE.md` and `AGENTS.md` in the repository root.

View File

@@ -1,302 +0,0 @@
name: "Copilot Setup Steps"
# Automatically run the setup steps when they are changed to allow for easy validation, and
# allow manual testing through the repository's "Actions" tab
on:
workflow_dispatch:
push:
paths:
- .github/workflows/copilot-setup-steps.yml
pull_request:
paths:
- .github/workflows/copilot-setup-steps.yml
jobs:
# The job MUST be called `copilot-setup-steps` or it will not be picked up by Copilot.
copilot-setup-steps:
runs-on: ubuntu-latest
timeout-minutes: 45
# Set the permissions to the lowest permissions possible needed for your steps.
# Copilot will be given its own token for its operations.
permissions:
# If you want to clone the repository as part of your setup steps, for example to install dependencies, you'll need the `contents: read` permission. If you don't clone the repository in your setup steps, Copilot will do this for you automatically after the steps complete.
contents: read
# You can define any steps you want, and they will run before the agent starts.
# If you do not check out your code, Copilot will do this for you.
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
submodules: true
# Backend Python/Poetry setup (mirrors platform-backend-ci.yml)
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11" # Use standard version matching CI
- name: Set up Python dependency cache
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
- name: Install Poetry
run: |
# Extract Poetry version from backend/poetry.lock (matches CI)
cd autogpt_platform/backend
HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
echo "Found Poetry version ${HEAD_POETRY_VERSION} in backend/poetry.lock"
# Install Poetry
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -
# Add Poetry to PATH
echo "$HOME/.local/bin" >> $GITHUB_PATH
- name: Check poetry.lock
working-directory: autogpt_platform/backend
run: |
poetry lock
if ! git diff --quiet --ignore-matching-lines="^# " poetry.lock; then
echo "Warning: poetry.lock not up to date, but continuing for setup"
git checkout poetry.lock # Reset for clean setup
fi
- name: Install Python dependencies
working-directory: autogpt_platform/backend
run: poetry install
- name: Generate Prisma Client
working-directory: autogpt_platform/backend
run: poetry run prisma generate
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "21"
- name: Enable corepack
run: corepack enable
- name: Set pnpm store directory
run: |
pnpm config set store-dir ~/.pnpm-store
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
- name: Cache frontend dependencies
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install JavaScript dependencies
working-directory: autogpt_platform/frontend
run: pnpm install --frozen-lockfile
# Install Playwright browsers for frontend testing
# NOTE: Disabled to save ~1 minute of setup time. Re-enable if Copilot needs browser automation (e.g., for MCP)
# - name: Install Playwright browsers
# working-directory: autogpt_platform/frontend
# run: pnpm playwright install --with-deps chromium
# Docker setup for development environment
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Copy default environment files
working-directory: autogpt_platform
run: |
# Copy default environment files for development
cp .env.default .env
cp backend/.env.default backend/.env
cp frontend/.env.default frontend/.env
# Phase 1: Cache and load Docker images for faster setup
- name: Set up Docker image cache
id: docker-cache
uses: actions/cache@v4
with:
path: ~/docker-cache
# Use a versioned key for cache invalidation when image list changes
key: docker-images-v2-${{ runner.os }}-${{ hashFiles('.github/workflows/copilot-setup-steps.yml') }}
restore-keys: |
docker-images-v2-${{ runner.os }}-
docker-images-v1-${{ runner.os }}-
- name: Load or pull Docker images
working-directory: autogpt_platform
run: |
mkdir -p ~/docker-cache
# Define image list for easy maintenance
IMAGES=(
"redis:latest"
"rabbitmq:management"
"clamav/clamav-debian:latest"
"busybox:latest"
"kong:2.8.1"
"supabase/gotrue:v2.170.0"
"supabase/postgres:15.8.1.049"
"supabase/postgres-meta:v0.86.1"
"supabase/studio:20250224-d10db0f"
)
# Check if any cached tar files exist (more reliable than cache-hit)
if ls ~/docker-cache/*.tar 1> /dev/null 2>&1; then
echo "Docker cache found, loading images in parallel..."
for image in "${IMAGES[@]}"; do
# Convert image name to filename (replace : and / with -)
filename=$(echo "$image" | tr ':/' '--')
if [ -f ~/docker-cache/${filename}.tar ]; then
echo "Loading $image..."
docker load -i ~/docker-cache/${filename}.tar || echo "Warning: Failed to load $image from cache" &
fi
done
wait
echo "All cached images loaded"
else
echo "No Docker cache found, pulling images in parallel..."
# Pull all images in parallel
for image in "${IMAGES[@]}"; do
docker pull "$image" &
done
wait
# Only save cache on main branches (not PRs) to avoid cache pollution
if [[ "${{ github.ref }}" == "refs/heads/master" ]] || [[ "${{ github.ref }}" == "refs/heads/dev" ]]; then
echo "Saving Docker images to cache in parallel..."
for image in "${IMAGES[@]}"; do
# Convert image name to filename (replace : and / with -)
filename=$(echo "$image" | tr ':/' '--')
echo "Saving $image..."
docker save -o ~/docker-cache/${filename}.tar "$image" || echo "Warning: Failed to save $image" &
done
wait
echo "Docker image cache saved"
else
echo "Skipping cache save for PR/feature branch"
fi
fi
echo "Docker images ready for use"
# Phase 2: Build migrate service with GitHub Actions cache
- name: Build migrate Docker image with cache
working-directory: autogpt_platform
run: |
# Build the migrate image with buildx for GHA caching
docker buildx build \
--cache-from type=gha \
--cache-to type=gha,mode=max \
--target migrate \
--tag autogpt_platform-migrate:latest \
--load \
-f backend/Dockerfile \
..
# Start services using pre-built images
- name: Start Docker services for development
working-directory: autogpt_platform
run: |
# Start essential services (migrate image already built with correct tag)
docker compose --profile local up deps --no-build --detach
echo "Waiting for services to be ready..."
# Wait for database to be ready
echo "Checking database readiness..."
timeout 30 sh -c 'until docker compose exec -T db pg_isready -U postgres 2>/dev/null; do
echo " Waiting for database..."
sleep 2
done' && echo "✅ Database is ready" || echo "⚠️ Database ready check timeout after 30s, continuing..."
# Check migrate service status
echo "Checking migration status..."
docker compose ps migrate || echo " Migrate service not visible in ps output"
# Wait for migrate service to complete
echo "Waiting for migrations to complete..."
timeout 30 bash -c '
ATTEMPTS=0
while [ $ATTEMPTS -lt 15 ]; do
ATTEMPTS=$((ATTEMPTS + 1))
# Check using docker directly (more reliable than docker compose ps)
CONTAINER_STATUS=$(docker ps -a --filter "label=com.docker.compose.service=migrate" --format "{{.Status}}" | head -1)
if [ -z "$CONTAINER_STATUS" ]; then
echo " Attempt $ATTEMPTS: Migrate container not found yet..."
elif echo "$CONTAINER_STATUS" | grep -q "Exited (0)"; then
echo "✅ Migrations completed successfully"
docker compose logs migrate --tail=5 2>/dev/null || true
exit 0
elif echo "$CONTAINER_STATUS" | grep -q "Exited ([1-9]"; then
EXIT_CODE=$(echo "$CONTAINER_STATUS" | grep -oE "Exited \([0-9]+\)" | grep -oE "[0-9]+")
echo "❌ Migrations failed with exit code: $EXIT_CODE"
echo "Migration logs:"
docker compose logs migrate --tail=20 2>/dev/null || true
exit 1
elif echo "$CONTAINER_STATUS" | grep -q "Up"; then
echo " Attempt $ATTEMPTS: Migrate container is running... ($CONTAINER_STATUS)"
else
echo " Attempt $ATTEMPTS: Migrate container status: $CONTAINER_STATUS"
fi
sleep 2
done
echo "⚠️ Timeout: Could not determine migration status after 30 seconds"
echo "Final container check:"
docker ps -a --filter "label=com.docker.compose.service=migrate" || true
echo "Migration logs (if available):"
docker compose logs migrate --tail=10 2>/dev/null || echo " No logs available"
' || echo "⚠️ Migration check completed with warnings, continuing..."
# Brief wait for other services to stabilize
echo "Waiting 5 seconds for other services to stabilize..."
sleep 5
# Verify installations and provide environment info
- name: Verify setup and show environment info
run: |
echo "=== Python Setup ==="
python --version
poetry --version
echo "=== Node.js Setup ==="
node --version
pnpm --version
echo "=== Additional Tools ==="
docker --version
docker compose version
gh --version || true
echo "=== Services Status ==="
cd autogpt_platform
docker compose ps || true
echo "=== Backend Dependencies ==="
cd backend
poetry show | head -10 || true
echo "=== Frontend Dependencies ==="
cd ../frontend
pnpm list --depth=0 | head -10 || true
echo "=== Environment Files ==="
ls -la ../.env* || true
ls -la .env* || true
ls -la ../backend/.env* || true
echo "✅ AutoGPT Platform development environment setup complete!"
echo "🚀 Ready for development with Docker services running"
echo "📝 Backend server: poetry run serve (port 8000)"
echo "🌐 Frontend server: pnpm dev (port 3000)"

View File

@@ -32,7 +32,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.11", "3.12", "3.13"]
python-version: ["3.11"]
runs-on: ubuntu-latest
services:
@@ -50,23 +50,6 @@ jobs:
env:
RABBITMQ_DEFAULT_USER: ${{ env.RABBITMQ_DEFAULT_USER }}
RABBITMQ_DEFAULT_PASS: ${{ env.RABBITMQ_DEFAULT_PASS }}
clamav:
image: clamav/clamav-debian:latest
ports:
- 3310:3310
env:
CLAMAV_NO_FRESHCLAMD: false
CLAMD_CONF_StreamMaxLength: 50M
CLAMD_CONF_MaxFileSize: 100M
CLAMD_CONF_MaxScanSize: 100M
CLAMD_CONF_MaxThreads: 4
CLAMD_CONF_ReadTimeout: 300
options: >-
--health-cmd "clamdscan --version || exit 1"
--health-interval 30s
--health-timeout 10s
--health-retries 5
--health-start-period 180s
steps:
- name: Checkout repository
@@ -148,35 +131,6 @@ jobs:
# outputs:
# DB_URL, API_URL, GRAPHQL_URL, ANON_KEY, SERVICE_ROLE_KEY, JWT_SECRET
- name: Wait for ClamAV to be ready
run: |
echo "Waiting for ClamAV daemon to start..."
max_attempts=60
attempt=0
until nc -z localhost 3310 || [ $attempt -eq $max_attempts ]; do
echo "ClamAV is unavailable - sleeping (attempt $((attempt+1))/$max_attempts)"
sleep 5
attempt=$((attempt+1))
done
if [ $attempt -eq $max_attempts ]; then
echo "ClamAV failed to start after $((max_attempts*5)) seconds"
echo "Checking ClamAV service logs..."
docker logs $(docker ps -q --filter "ancestor=clamav/clamav-debian:latest") 2>&1 | tail -50 || echo "No ClamAV container found"
exit 1
fi
echo "ClamAV is ready!"
# Verify ClamAV is responsive
echo "Testing ClamAV connection..."
timeout 10 bash -c 'echo "PING" | nc localhost 3310' || {
echo "ClamAV is not responding to PING"
docker logs $(docker ps -q --filter "ancestor=clamav/clamav-debian:latest") 2>&1 | tail -50 || echo "No ClamAV container found"
exit 1
}
- name: Run Database Migrations
run: poetry run prisma migrate dev --name updates
env:
@@ -190,9 +144,9 @@ jobs:
- name: Run pytest with coverage
run: |
if [[ "${{ runner.debug }}" == "1" ]]; then
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG test
else
poetry run pytest -s -vv
poetry run pytest -s -vv test
fi
if: success() || (failure() && steps.lint.outcome == 'failure')
env:
@@ -205,7 +159,6 @@ jobs:
REDIS_HOST: "localhost"
REDIS_PORT: "6379"
REDIS_PASSWORD: "testpassword"
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
env:
CI: true

View File

@@ -18,46 +18,11 @@ defaults:
working-directory: autogpt_platform/frontend
jobs:
setup:
runs-on: ubuntu-latest
outputs:
cache-key: ${{ steps.cache-key.outputs.key }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "21"
- name: Enable corepack
run: corepack enable
- name: Generate cache key
id: cache-key
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
- name: Cache dependencies
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ steps.cache-key.outputs.key }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
lint:
runs-on: ubuntu-latest
needs: setup
steps:
- name: Checkout repository
uses: actions/checkout@v4
- uses: actions/checkout@v4
- name: Set up Node.js
uses: actions/setup-node@v4
@@ -67,32 +32,17 @@ jobs:
- name: Enable corepack
run: corepack enable
- name: Restore dependencies cache
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Run lint
run: pnpm lint
chromatic:
type-check:
runs-on: ubuntu-latest
needs: setup
# Only run on dev branch pushes or PRs targeting dev
if: github.ref == 'refs/heads/dev' || github.base_ref == 'dev'
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
- uses: actions/checkout@v4
- name: Set up Node.js
uses: actions/setup-node@v4
@@ -102,32 +52,18 @@ jobs:
- name: Enable corepack
run: corepack enable
- name: Restore dependencies cache
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Run Chromatic
uses: chromaui/action@latest
with:
projectToken: chpt_9e7c1a76478c9c8
onlyChanged: true
workingDir: autogpt_platform/frontend
token: ${{ secrets.GITHUB_TOKEN }}
exitOnceUploaded: true
- name: Run tsc check
run: pnpm type-check
test:
runs-on: big-boi
needs: setup
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
browser: [chromium, webkit]
steps:
- name: Checkout repository
@@ -143,93 +79,47 @@ jobs:
- name: Enable corepack
run: corepack enable
- name: Free Disk Space (Ubuntu)
uses: jlumbroso/free-disk-space@main
with:
large-packages: false # slow
docker-images: false # limited benefit
- name: Copy default supabase .env
run: |
cp ../.env.default ../.env
cp ../.env.example ../.env
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Cache Docker layers
uses: actions/cache@v4
with:
path: /tmp/.buildx-cache
key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }}
restore-keys: |
${{ runner.os }}-buildx-frontend-test-
- name: Copy backend .env
run: |
cp ../backend/.env.example ../backend/.env
- name: Run docker compose
run: |
docker compose -f ../docker-compose.yml up -d
env:
DOCKER_BUILDKIT: 1
BUILDX_CACHE_FROM: type=local,src=/tmp/.buildx-cache
BUILDX_CACHE_TO: type=local,dest=/tmp/.buildx-cache-new,mode=max
- name: Move cache
run: |
rm -rf /tmp/.buildx-cache
if [ -d "/tmp/.buildx-cache-new" ]; then
mv /tmp/.buildx-cache-new /tmp/.buildx-cache
fi
- name: Wait for services to be ready
run: |
echo "Waiting for rest_server to be ready..."
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
echo "Waiting for database to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
- name: Create E2E test data
run: |
echo "Creating E2E test data..."
# First try to run the script from inside the container
if docker compose -f ../docker-compose.yml exec -T rest_server test -f /app/autogpt_platform/backend/test/e2e_test_data.py; then
echo "✅ Found e2e_test_data.py in container, running it..."
docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python backend/test/e2e_test_data.py" || {
echo "❌ E2E test data creation failed!"
docker compose -f ../docker-compose.yml logs --tail=50 rest_server
exit 1
}
else
echo "⚠️ e2e_test_data.py not found in container, copying and running..."
# Copy the script into the container and run it
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.yml ps -q rest_server):/tmp/e2e_test_data.py || {
echo "❌ Failed to copy script to container"
exit 1
}
docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
echo "❌ E2E test data creation failed!"
docker compose -f ../docker-compose.yml logs --tail=50 rest_server
exit 1
}
fi
- name: Restore dependencies cache
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Install Browser 'chromium'
run: pnpm playwright install --with-deps chromium
- name: Setup .env
run: cp .env.example .env
- name: Build frontend
run: pnpm build --turbo
# uses Turbopack, much faster and safe enough for a test pipeline
- name: Install Browser '${{ matrix.browser }}'
run: pnpm playwright install --with-deps ${{ matrix.browser }}
- name: Run Playwright tests
run: pnpm test:no-build
- name: Upload Playwright artifacts
if: failure()
uses: actions/upload-artifact@v4
with:
name: playwright-report
path: playwright-report
run: pnpm test:no-build --project=${{ matrix.browser }}
- name: Print Final Docker Compose logs
if: always()
run: docker compose -f ../docker-compose.yml logs
- uses: actions/upload-artifact@v4
if: ${{ !cancelled() }}
with:
name: playwright-report-${{ matrix.browser }}
path: playwright-report/
retention-days: 30

View File

@@ -1,132 +0,0 @@
name: AutoGPT Platform - Frontend CI
on:
push:
branches: [master, dev]
paths:
- ".github/workflows/platform-fullstack-ci.yml"
- "autogpt_platform/**"
pull_request:
paths:
- ".github/workflows/platform-fullstack-ci.yml"
- "autogpt_platform/**"
merge_group:
defaults:
run:
shell: bash
working-directory: autogpt_platform/frontend
jobs:
setup:
runs-on: ubuntu-latest
outputs:
cache-key: ${{ steps.cache-key.outputs.key }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "21"
- name: Enable corepack
run: corepack enable
- name: Generate cache key
id: cache-key
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
- name: Cache dependencies
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ steps.cache-key.outputs.key }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
types:
runs-on: ubuntu-latest
needs: setup
strategy:
fail-fast: false
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
submodules: recursive
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "21"
- name: Enable corepack
run: corepack enable
- name: Copy default supabase .env
run: |
cp ../.env.default ../.env
- name: Copy backend .env
run: |
cp ../backend/.env.default ../backend/.env
- name: Run docker compose
run: |
docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
- name: Restore dependencies cache
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
restore-keys: |
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Setup .env
run: cp .env.default .env
- name: Wait for services to be ready
run: |
echo "Waiting for rest_server to be ready..."
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
echo "Waiting for database to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
- name: Generate API queries
run: pnpm generate:api:force
- name: Check for API schema changes
run: |
if ! git diff --exit-code src/app/api/openapi.json; then
echo "❌ API schema changes detected in src/app/api/openapi.json"
echo ""
echo "The openapi.json file has been modified after running 'pnpm generate:api-all'."
echo "This usually means changes have been made in the BE endpoints without updating the Frontend."
echo "The API schema is now out of sync with the Front-end queries."
echo ""
echo "To fix this:"
echo "1. Pull the backend 'docker compose pull && docker compose up -d --build --force-recreate'"
echo "2. Run 'pnpm generate:api' locally"
echo "3. Run 'pnpm types' locally"
echo "4. Fix any TypeScript errors that may have been introduced"
echo "5. Commit and push your changes"
echo ""
exit 1
else
echo "✅ No API schema changes detected"
fi
- name: Run Typescript checks
run: pnpm types

.gitignore vendored
View File

@@ -5,8 +5,6 @@ classic/original_autogpt/*.json
auto_gpt_workspace/*
*.mpeg
.env
# Root .env files
/.env
azure.yaml
.vscode
.idea/*
@@ -123,6 +121,7 @@ celerybeat.pid
# Environments
.direnv/
.env
.venv
env/
venv*/
@@ -166,7 +165,7 @@ package-lock.json
# Allow for locally private items
# private
pri*
pri*
# ignore
ig*
.github_access_token

View File

@@ -235,7 +235,7 @@ repos:
hooks:
- id: tsc
name: Typecheck - AutoGPT Platform - Frontend
entry: bash -c 'cd autogpt_platform/frontend && pnpm types'
entry: bash -c 'cd autogpt_platform/frontend && pnpm type-check'
files: ^autogpt_platform/frontend/
types: [file]
language: system

.vscode/launch.json vendored
View File

@@ -6,7 +6,7 @@
"type": "node-terminal",
"request": "launch",
"cwd": "${workspaceFolder}/autogpt_platform/frontend",
"command": "pnpm dev"
"command": "yarn dev"
},
{
"name": "Frontend: Client Side",
@@ -19,12 +19,12 @@
"type": "node-terminal",
"request": "launch",
"command": "pnpm dev",
"command": "yarn dev",
"cwd": "${workspaceFolder}/autogpt_platform/frontend",
"serverReadyAction": {
"pattern": "- Local:.+(https?://.+)",
"uriFormat": "%s",
"action": "debugWithChrome"
"action": "debugWithEdge"
}
},
{

LICENSE
View File

@@ -1,197 +1,6 @@
All portions of this repository are under one of two licenses.
All portions of this repository are under one of two licenses. The majority of the AutoGPT repository is under the MIT License below. The autogpt_platform folder is under the
Polyform Shield License.
- Everything inside the autogpt_platform folder is under the Polyform Shield License.
- Everything outside the autogpt_platform folder is under the MIT License.
More info:
**Polyform Shield License:**
Code and content within the `autogpt_platform` folder is licensed under the Polyform Shield License. This new project is our in-developlemt platform for building, deploying and managing agents.
Read more about this effort here: https://agpt.co/blog/introducing-the-autogpt-platform
**MIT License:**
All other portions of the AutoGPT repository (i.e., everything outside the `autogpt_platform` folder) are licensed under the MIT License. This includes:
- The Original, stand-alone AutoGPT Agent
- Forge: https://github.com/Significant-Gravitas/AutoGPT/tree/master/classic/forge
- AG Benchmark: https://github.com/Significant-Gravitas/AutoGPT/tree/master/classic/benchmark
- AutoGPT Classic GUI: https://github.com/Significant-Gravitas/AutoGPT/tree/master/classic/frontend.
We also publish additional work under the MIT Licence in other repositories, such as GravitasML (https://github.com/Significant-Gravitas/gravitasml) which is developed for and used in the AutoGPT Platform, and our [Code Ability](https://github.com/Significant-Gravitas/AutoGPT-Code-Ability) project.
Both licences are available to read below:
=====================================================
-----------------------------------------------------
=====================================================
# PolyForm Shield License 1.0.0
<https://polyformproject.org/licenses/shield/1.0.0>
## Acceptance
In order to get any license under these terms, you must agree
to them as both strict obligations and conditions to all
your licenses.
## Copyright License
The licensor grants you a copyright license for the
software to do everything you might do with the software
that would otherwise infringe the licensor's copyright
in it for any permitted purpose. However, you may
only distribute the software according to [Distribution
License](#distribution-license) and make changes or new works
based on the software according to [Changes and New Works
License](#changes-and-new-works-license).
## Distribution License
The licensor grants you an additional copyright license
to distribute copies of the software. Your license
to distribute covers distributing the software with
changes and new works permitted by [Changes and New Works
License](#changes-and-new-works-license).
## Notices
You must ensure that anyone who gets a copy of any part of
the software from you also gets a copy of these terms or the
URL for them above, as well as copies of any plain-text lines
beginning with `Required Notice:` that the licensor provided
with the software. For example:
> Required Notice: Copyright Yoyodyne, Inc. (http://example.com)
## Changes and New Works License
The licensor grants you an additional copyright license to
make changes and new works based on the software for any
permitted purpose.
## Patent License
The licensor grants you a patent license for the software that
covers patent claims the licensor can license, or becomes able
to license, that you would infringe by using the software.
## Noncompete
Any purpose is a permitted purpose, except for providing any
product that competes with the software or any product the
licensor or any of its affiliates provides using the software.
## Competition
Goods and services compete even when they provide functionality
through different kinds of interfaces or for different technical
platforms. Applications can compete with services, libraries
with plugins, frameworks with development tools, and so on,
even if they're written in different programming languages
or for different computer architectures. Goods and services
compete even when provided free of charge. If you market a
product as a practical substitute for the software or another
product, it definitely competes.
## New Products
If you are using the software to provide a product that does
not compete, but the licensor or any of its affiliates brings
your product into competition by providing a new version of
the software or another product using the software, you may
continue using versions of the software available under these
terms beforehand to provide your competing product, but not
any later versions.
## Discontinued Products
You may begin using the software to compete with a product
or service that the licensor or any of its affiliates has
stopped providing, unless the licensor includes a plain-text
line beginning with `Licensor Line of Business:` with the
software that mentions that line of business. For example:
> Licensor Line of Business: YoyodyneCMS Content Management
System (http://example.com/cms)
## Sales of Business
If the licensor or any of its affiliates sells a line of
business developing the software or using the software
to provide a product, the buyer can also enforce
[Noncompete](#noncompete) for that product.
## Fair Use
You may have "fair use" rights for the software under the
law. These terms do not limit them.
## No Other Rights
These terms do not allow you to sublicense or transfer any of
your licenses to anyone else, or prevent the licensor from
granting licenses to anyone else. These terms do not imply
any other licenses.
## Patent Defense
If you make any written claim that the software infringes or
contributes to infringement of any patent, your patent license
for the software granted under these terms ends immediately. If
your company makes such a claim, your patent license ends
immediately for work on behalf of your company.
## Violations
The first time you are notified in writing that you have
violated any of these terms, or done anything with the software
not covered by your licenses, your licenses can nonetheless
continue if you come into full compliance with these terms,
and take practical steps to correct past violations, within
32 days of receiving notice. Otherwise, all your licenses
end immediately.
## No Liability
***As far as the law allows, the software comes as is, without
any warranty or condition, and the licensor will not be liable
to you for any damages arising out of these terms or the use
or nature of the software, under any kind of legal claim.***
## Definitions
The **licensor** is the individual or entity offering these
terms, and the **software** is the software the licensor makes
available under these terms.
A **product** can be a good or service, or a combination
of them.
**You** refers to the individual or entity agreeing to these
terms.
**Your company** is any legal entity, sole proprietorship,
or other kind of organization that you work for, plus all
its affiliates.
**Affiliates** means the other organizations than an
organization has control over, is under the control of, or is
under common control with.
**Control** means ownership of substantially all the assets of
an entity, or the power to direct its management and policies
by vote, contract, or otherwise. Control can be direct or
indirect.
**Your licenses** are all the licenses granted to you for the
software under these terms.
**Use** means anything you do with the software requiring one
of your licenses.
=====================================================
-----------------------------------------------------
=====================================================
MIT License

View File

@@ -1,25 +1,16 @@
# AutoGPT: Build, Deploy, and Run AI Agents
[![Discord Follow](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fautogpt%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&label=total%20members&logo=discord&logoColor=white&color=7289da)](https://discord.gg/autogpt) &ensp;
[![Discord Follow](https://dcbadge.vercel.app/api/server/autogpt?style=flat)](https://discord.gg/autogpt) &ensp;
[![Twitter Follow](https://img.shields.io/twitter/follow/Auto_GPT?style=social)](https://twitter.com/Auto_GPT) &ensp;
<!-- Keep these links. Translations will automatically update with the README. -->
[Deutsch](https://zdoc.app/de/Significant-Gravitas/AutoGPT) |
[Español](https://zdoc.app/es/Significant-Gravitas/AutoGPT) |
[français](https://zdoc.app/fr/Significant-Gravitas/AutoGPT) |
[日本語](https://zdoc.app/ja/Significant-Gravitas/AutoGPT) |
[한국어](https://zdoc.app/ko/Significant-Gravitas/AutoGPT) |
[Português](https://zdoc.app/pt/Significant-Gravitas/AutoGPT) |
[Русский](https://zdoc.app/ru/Significant-Gravitas/AutoGPT) |
[中文](https://zdoc.app/zh/Significant-Gravitas/AutoGPT)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
**AutoGPT** is a powerful platform that allows you to create, deploy, and manage continuous AI agents that automate complex workflows.
## Hosting Options
- Download to self-host (Free!)
- [Join the Waitlist](https://bit.ly/3ZDijAI) for the cloud-hosted beta (Closed Beta - Public release Coming Soon!)
- Download to self-host
- [Join the Waitlist](https://bit.ly/3ZDijAI) for the cloud-hosted beta
## How to Self-Host the AutoGPT Platform
## How to Setup for Self-Hosting
> [!NOTE]
> Setting up and hosting the AutoGPT Platform yourself is a technical process.
> If you'd rather something that just works, we recommend [joining the waitlist](https://bit.ly/3ZDijAI) for the cloud-hosted beta.
@@ -59,24 +50,6 @@ We've moved to a fully maintained and regularly updated documentation site.
This tutorial assumes you have Docker, VSCode, git and npm installed.
---
#### ⚡ Quick Setup with One-Line Script (Recommended for Local Hosting)
Skip the manual steps and get started in minutes using our automatic setup script.
For macOS/Linux:
```
curl -fsSL https://setup.agpt.co/install.sh -o install.sh && bash install.sh
```
For Windows (PowerShell):
```
powershell -c "iwr https://setup.agpt.co/install.bat -o install.bat; ./install.bat"
```
This will install dependencies, configure Docker, and launch your local instance — all in one go.
### 🧱 AutoGPT Frontend
The AutoGPT frontend is where users interact with our powerful AI automation platform. It offers multiple ways to engage with and leverage our AI agents. This is the interface where you'll bring your AI automation ideas to life:
@@ -123,17 +96,7 @@ Here are two examples of what you can do with AutoGPT:
These examples show just a glimpse of what you can achieve with AutoGPT! You can create customized workflows to build agents for any use case.
---
### **License Overview:**
🛡️ **Polyform Shield License:**
All code and content within the `autogpt_platform` folder is licensed under the Polyform Shield License. This new project is our in-development platform for building, deploying and managing agents.</br>_[Read more about this effort](https://agpt.co/blog/introducing-the-autogpt-platform)_
🦉 **MIT License:**
All other portions of the AutoGPT repository (i.e., everything outside the `autogpt_platform` folder) are licensed under the MIT License. This includes the original stand-alone AutoGPT Agent, along with projects such as [Forge](https://github.com/Significant-Gravitas/AutoGPT/tree/master/classic/forge), [agbenchmark](https://github.com/Significant-Gravitas/AutoGPT/tree/master/classic/benchmark) and the [AutoGPT Classic GUI](https://github.com/Significant-Gravitas/AutoGPT/tree/master/classic/frontend).</br>We also publish additional work under the MIT License in other repositories, such as [GravitasML](https://github.com/Significant-Gravitas/gravitasml), which is developed for and used in the AutoGPT Platform. See also our MIT Licensed [Code Ability](https://github.com/Significant-Gravitas/AutoGPT-Code-Ability) project.
---
### Mission
### Mission and Licensing
Our mission is to provide the tools, so that you can focus on what matters:
- 🏗️ **Building** - Lay the foundation for something amazing.
@@ -146,6 +109,14 @@ Be part of the revolution! **AutoGPT** is here to stay, at the forefront of AI i
&ensp;|&ensp;
**🚀 [Contributing](CONTRIBUTING.md)**
**Licensing:**
MIT License: The majority of the AutoGPT repository is under the MIT License.
Polyform Shield License: This license applies to the autogpt_platform folder.
For more information, see https://agpt.co/blog/introducing-the-autogpt-platform
---
## 🤖 AutoGPT Classic
> Below is information about the classic version of AutoGPT.

View File

@@ -5,7 +5,6 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
## Repository Overview
AutoGPT Platform is a monorepo containing:
- **Backend** (`/backend`): Python FastAPI server with async support
- **Frontend** (`/frontend`): Next.js React application
- **Shared Libraries** (`/autogpt_libs`): Common Python utilities
@@ -13,7 +12,6 @@ AutoGPT Platform is a monorepo containing:
## Essential Commands
### Backend Development
```bash
# Install dependencies
cd backend && poetry install
@@ -21,7 +19,7 @@ cd backend && poetry install
# Run database migrations
poetry run prisma migrate dev
# Start all services (database, redis, rabbitmq, clamav)
# Start all services (database, redis, rabbitmq)
docker compose up -d
# Run the backend server
@@ -33,18 +31,10 @@ poetry run test
# Run specific test
poetry run pytest path/to/test_file.py::test_function_name
# Run block tests (tests that validate all blocks work correctly)
poetry run pytest backend/blocks/test/test_block.py -xvs
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
# Lint and format
# Prefer format if you just want to auto-fix issues and only see the errors that can't be autofixed
poetry run format # Black + isort
poetry run lint # ruff
```
More details can be found in TESTING.md
#### Creating/Updating Snapshots
@@ -57,8 +47,8 @@ poetry run pytest path/to/test.py --snapshot-update
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
### Frontend Development
### Frontend Development
```bash
# Install dependencies
cd frontend && npm install
@@ -76,22 +66,19 @@ npm run storybook
npm run build
# Type checking
npm run types
npm run type-check
```
## Architecture Overview
### Backend Architecture
- **API Layer**: FastAPI with REST and WebSocket endpoints (a minimal sketch follows this list)
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
- **Queue System**: RabbitMQ for async task processing
- **Execution Engine**: Separate executor service processes agent workflows
- **Authentication**: JWT-based with Supabase integration
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
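To make the API-layer shape concrete, here is a minimal, hypothetical FastAPI sketch with one REST route and one WebSocket route. The route paths and payloads are illustrative only and are not the platform's actual endpoints.
```python
# Minimal, hypothetical sketch of the API-layer shape described above.
# Route paths and payloads are illustrative, not the platform's real endpoints.
from fastapi import FastAPI, WebSocket

app = FastAPI()


@app.get("/api/health")
async def health() -> dict:
    # Plain REST endpoint returning JSON.
    return {"status": "ok"}


@app.websocket("/ws/executions")
async def execution_updates(websocket: WebSocket) -> None:
    # WebSocket endpoint that echoes messages back to the client.
    await websocket.accept()
    while True:
        message = await websocket.receive_text()
        await websocket.send_text(f"received: {message}")
```
In the real backend, routes like these live under `/backend/backend/server/routers/` and are wired into the executor and database layers described above.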
### Frontend Architecture
- **Framework**: Next.js App Router with React Server Components
- **State Management**: React hooks + Supabase client for real-time updates
- **Workflow Builder**: Visual graph editor using @xyflow/react
@@ -99,24 +86,19 @@ npm run types
- **Feature Flags**: LaunchDarkly integration
### Key Concepts
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
2. **Blocks**: Reusable components in `/backend/blocks/` that perform specific tasks
3. **Integrations**: OAuth and API connections stored per user
4. **Store**: Marketplace for sharing agent templates
5. **Virus Scanning**: ClamAV integration for file upload security
### Testing Approach
- Backend uses pytest with snapshot testing for API responses
- Test files are colocated with source files (`*_test.py`)
- Frontend uses Playwright for E2E tests
- Component testing via Storybook
### Database Schema
Key models (defined in `/backend/schema.prisma`):
- `User`: Authentication and profile data
- `AgentGraph`: Workflow definitions with version control
- `AgentGraphExecution`: Execution history and results
@@ -124,107 +106,27 @@ Key models (defined in `/backend/schema.prisma`):
- `StoreListing`: Marketplace listings for sharing agents
### Environment Configuration
#### Configuration Files
- **Backend**: `/backend/.env.default` (defaults) → `/backend/.env` (user overrides)
- **Frontend**: `/frontend/.env.default` (defaults) → `/frontend/.env` (user overrides)
- **Platform**: `/.env.default` (Supabase/shared defaults) → `/.env` (user overrides)
#### Docker Environment Loading Order
1. `.env.default` files provide base configuration (tracked in git)
2. `.env` files provide user-specific overrides (gitignored)
3. Docker Compose `environment:` sections provide service-specific overrides
4. Shell environment variables have highest precedence
#### Key Points
- All services use hardcoded defaults in docker-compose files (no `${VARIABLE}` substitutions)
- The `env_file` directive loads variables INTO containers at runtime
- Backend/Frontend services use YAML anchors for consistent configuration
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
- Backend: `.env` file in `/backend`
- Frontend: `.env.local` file in `/frontend`
- Both require Supabase credentials and API keys for various services
### Common Development Tasks
**Adding a new block:**
1. Create new file in `/backend/backend/blocks/`
2. Inherit from `Block` base class
3. Define input/output schemas
4. Implement `run` method
5. Register in block registry
6. Generate the block uuid using `uuid.uuid4()`
Note: when creating many new blocks, review their interfaces together and consider whether they would connect productively in a graph-based editor, e.g. do the inputs and outputs tie together well? A minimal sketch of a block is shown below.
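The sketch below follows the steps above but is illustrative only: the real `Block` base class and registration mechanism live in `/backend/backend/blocks/`, and their exact interface may differ, so the class and schema names here are hypothetical.
```python
# Hypothetical block sketch; the real Block base class in /backend/backend/blocks/
# defines the actual interface, IDs, and registration mechanism (steps 2, 5, 6).
from pydantic import BaseModel


class GreetingInput(BaseModel):
    name: str


class GreetingOutput(BaseModel):
    greeting: str


class GreetingBlock:
    """Toy block that turns a name into a greeting (in the real codebase it
    would inherit from Block and be registered in the block registry)."""

    # Step 6: generate this once with uuid.uuid4() and hard-code the value.
    id = "11111111-2222-3333-4444-555555555555"
    input_schema = GreetingInput
    output_schema = GreetingOutput

    def run(self, input_data: GreetingInput) -> GreetingOutput:
        # Step 4: implement the block's behavior.
        return GreetingOutput(greeting=f"Hello, {input_data.name}!")
```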
**Modifying the API:**
1. Update route in `/backend/backend/server/routers/`
2. Add/update Pydantic models in same directory
3. Write tests alongside the route file
4. Run `poetry run test` to verify (a rough sketch of such a route follows this list)
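A rough, hypothetical illustration of steps 1 and 2: a router module with its Pydantic request/response models. The path, tag, and model names are made up for this example; real routers in `/backend/backend/server/routers/` have their own structure and dependencies.
```python
# Hypothetical route and Pydantic models; real routers in
# /backend/backend/server/routers/ have their own structure and dependencies.
from fastapi import APIRouter
from pydantic import BaseModel

router = APIRouter(prefix="/examples", tags=["examples"])


class ExampleRequest(BaseModel):
    text: str


class ExampleResponse(BaseModel):
    length: int


@router.post("/length", response_model=ExampleResponse)
async def compute_length(body: ExampleRequest) -> ExampleResponse:
    # Step 3: a colocated *_test.py would exercise this route;
    # step 4: `poetry run test` runs it.
    return ExampleResponse(length=len(body.text))
```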
**Frontend feature development:**
1. Components go in `/frontend/src/components/`
2. Use existing UI components from `/frontend/src/components/ui/`
3. Add Storybook stories for new components
4. Test with Playwright if user-facing
### Security Implementation
**Cache Protection Middleware:**
- Located in `/backend/backend/server/middleware/security.py`
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses an allow list approach - only explicitly permitted paths can be cached
- Cacheable paths include: static assets (`/static/*`, `/_next/static/*`), health checks, public store pages, documentation
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
- Applied to both the main API server and external API applications (a rough sketch of the approach follows below)
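As a rough sketch of the allow-list idea, assuming Starlette's `BaseHTTPMiddleware` and example path prefixes; the actual middleware in `/backend/backend/server/middleware/security.py` is the source of truth.
```python
# Rough sketch of the allow-list caching approach; the actual implementation
# lives in /backend/backend/server/middleware/security.py.
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response

CACHEABLE_PATHS = ("/static/", "/_next/static/", "/health")


class CacheProtectionMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next) -> Response:
        response = await call_next(request)
        if not any(request.url.path.startswith(path) for path in CACHEABLE_PATHS):
            # Default: forbid caching so tokens, API keys, and user data never
            # end up in browser or proxy caches.
            response.headers["Cache-Control"] = (
                "no-store, no-cache, must-revalidate, private"
            )
        return response
```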
### Creating Pull Requests
- Create the PR against the `dev` branch of the repository.
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`).
- Use conventional commit messages (see below).
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description.
- Run the github pre-commit hooks to ensure code quality.
### Reviewing/Revising Pull Requests
- When the user runs /pr-comments or tries to fetch them, also run gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews to get the reviews
- Use gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews/[review_id]/comments to get the review contents
- Use gh api /repos/Significant-Gravitas/AutoGPT/issues/[issuenum]/comments to get the PR-specific comments
### Conventional Commits
Use this format for commit messages and Pull Request titles:
**Conventional Commit Types:**
- `feat`: Introduces a new feature to the codebase
- `fix`: Patches a bug in the codebase
- `refactor`: Code change that neither fixes a bug nor adds a feature; also applies to removing features
- `ci`: Changes to CI configuration
- `docs`: Documentation-only changes
- `dx`: Improvements to the developer experience
**Recommended Base Scopes:**
- `platform`: Changes affecting both frontend and backend
- `frontend`
- `backend`
- `infra`
- `blocks`: Modifications/additions of individual blocks
**Subscope Examples:**
- `backend/executor`
- `backend/db`
- `frontend/builder` (includes changes to the block UI component)
- `infra/prod`
Use these scopes and subscopes for clarity and consistency in commit messages.
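For example, a hypothetical PR title or commit message using these conventions could look like:
```
fix(backend/executor): retry transient node failures before marking the run as errored
```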

View File

@@ -8,6 +8,7 @@ Welcome to the AutoGPT Platform - a powerful system for creating and running AI
- Docker
- Docker Compose V2 (comes with Docker Desktop, or can be installed separately)
- Node.js & NPM (for running the frontend application)
### Running the System
@@ -23,10 +24,10 @@ To run the AutoGPT Platform, follow these steps:
2. Run the following command:
```
cp .env.default .env
cp .env.example .env
```
This command will copy the `.env.default` file to `.env`. You can modify the `.env` file to add your own environment variables.
This command will copy the `.env.example` file to `.env`. You can modify the `.env` file to add your own environment variables.
3. Run the following command:
@@ -36,7 +37,38 @@ To run the AutoGPT Platform, follow these steps:
This command will start all the necessary backend services defined in the `docker-compose.yml` file in detached mode.
4. After all the services are in ready state, open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
4. Navigate to `frontend` within the `autogpt_platform` directory:
```
cd frontend
```
You will need to run your frontend application separately on your local machine.
5. Run the following command:
```
cp .env.example .env.local
```
This command will copy the `.env.example` file to `.env.local` in the `frontend` directory. You can modify the `.env.local` within this folder to add your own environment variables for the frontend application.
6. Enable corepack and install dependencies by running:
```
corepack enable
pnpm i
```
Then start the frontend application in development mode:
```
pnpm dev
```
7. Open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
### Docker Compose Commands
@@ -132,28 +164,3 @@ To persist data for PostgreSQL and Redis, you can modify the `docker-compose.yml
3. Save the file and run `docker compose up -d` to apply the changes.
This configuration will create named volumes for PostgreSQL and Redis, ensuring that your data persists across container restarts.
### API Client Generation
The platform includes scripts for generating and managing the API client:
- `pnpm fetch:openapi`: Fetches the OpenAPI specification from the backend service (requires backend to be running on port 8006)
- `pnpm generate:api-client`: Generates the TypeScript API client from the OpenAPI specification using Orval
- `pnpm generate:api`: Runs both fetch and generate commands in sequence
#### Manual API Client Updates
If you need to update the API client after making changes to the backend API:
1. Ensure the backend services are running:
```
docker compose up -d
```
2. Generate the updated API client:
```
pnpm generate:api
```
This will fetch the latest OpenAPI specification and regenerate the TypeScript client code.

View File

@@ -7,5 +7,9 @@ class Settings:
self.ENABLE_AUTH: bool = os.getenv("ENABLE_AUTH", "false").lower() == "true"
self.JWT_ALGORITHM: str = "HS256"
@property
def is_configured(self) -> bool:
return bool(self.JWT_SECRET_KEY)
settings = Settings()

View File

@@ -10,8 +10,8 @@ from starlette.status import HTTP_401_UNAUTHORIZED
from .config import settings
from .jwt_utils import parse_jwt_token
security = HTTPBearer()
logger = logging.getLogger(__name__)
bearer_auth = HTTPBearer(auto_error=False)
async def auth_middleware(request: Request):
@@ -20,10 +20,11 @@ async def auth_middleware(request: Request):
logger.warning("Auth disabled")
return {}
credentials = await bearer_auth(request)
security = HTTPBearer()
credentials = await security(request)
if not credentials:
raise HTTPException(status_code=401, detail="Not authenticated")
raise HTTPException(status_code=401, detail="Authorization header is missing")
try:
payload = parse_jwt_token(credentials.credentials)

View File

@@ -0,0 +1,166 @@
import asyncio
import contextlib
import logging
from functools import wraps
from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar, Union, cast
import ldclient
from fastapi import HTTPException
from ldclient import Context, LDClient
from ldclient.config import Config
from typing_extensions import ParamSpec
from .config import SETTINGS
logger = logging.getLogger(__name__)
P = ParamSpec("P")
T = TypeVar("T")
def get_client() -> LDClient:
"""Get the LaunchDarkly client singleton."""
return ldclient.get()
def initialize_launchdarkly() -> None:
sdk_key = SETTINGS.launch_darkly_sdk_key
logger.debug(
f"Initializing LaunchDarkly with SDK key: {'present' if sdk_key else 'missing'}"
)
if not sdk_key:
logger.warning("LaunchDarkly SDK key not configured")
return
config = Config(sdk_key)
ldclient.set_config(config)
if ldclient.get().is_initialized():
logger.info("LaunchDarkly client initialized successfully")
else:
logger.error("LaunchDarkly client failed to initialize")
def shutdown_launchdarkly() -> None:
"""Shutdown the LaunchDarkly client."""
if ldclient.get().is_initialized():
ldclient.get().close()
logger.info("LaunchDarkly client closed successfully")
def create_context(
user_id: str, additional_attributes: Optional[Dict[str, Any]] = None
) -> Context:
"""Create LaunchDarkly context with optional additional attributes."""
builder = Context.builder(str(user_id)).kind("user")
if additional_attributes:
for key, value in additional_attributes.items():
builder.set(key, value)
return builder.build()
def feature_flag(
flag_key: str,
default: bool = False,
) -> Callable[
[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
]:
"""
Decorator for feature flag protected endpoints.
"""
def decorator(
func: Callable[P, Union[T, Awaitable[T]]],
) -> Callable[P, Union[T, Awaitable[T]]]:
@wraps(func)
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
try:
user_id = kwargs.get("user_id")
if not user_id:
raise ValueError("user_id is required")
if not get_client().is_initialized():
logger.warning(
f"LaunchDarkly not initialized, using default={default}"
)
is_enabled = default
else:
context = create_context(str(user_id))
is_enabled = get_client().variation(flag_key, context, default)
if not is_enabled:
raise HTTPException(status_code=404, detail="Feature not available")
result = func(*args, **kwargs)
if asyncio.iscoroutine(result):
return await result
return cast(T, result)
except Exception as e:
logger.error(f"Error evaluating feature flag {flag_key}: {e}")
raise
@wraps(func)
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
try:
user_id = kwargs.get("user_id")
if not user_id:
raise ValueError("user_id is required")
if not get_client().is_initialized():
logger.warning(
f"LaunchDarkly not initialized, using default={default}"
)
is_enabled = default
else:
context = create_context(str(user_id))
is_enabled = get_client().variation(flag_key, context, default)
if not is_enabled:
raise HTTPException(status_code=404, detail="Feature not available")
return cast(T, func(*args, **kwargs))
except Exception as e:
logger.error(f"Error evaluating feature flag {flag_key}: {e}")
raise
return cast(
Callable[P, Union[T, Awaitable[T]]],
async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper,
)
return decorator
def percentage_rollout(
flag_key: str,
default: bool = False,
) -> Callable[
[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
]:
"""Decorator for percentage-based rollouts."""
return feature_flag(flag_key, default)
def beta_feature(
flag_key: Optional[str] = None,
unauthorized_response: Any = {"message": "Not available in beta"},
) -> Callable[
[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
]:
"""Decorator for beta features."""
actual_key = f"beta-{flag_key}" if flag_key else "beta"
return feature_flag(actual_key, False)
@contextlib.contextmanager
def mock_flag_variation(flag_key: str, return_value: Any):
"""Context manager for testing feature flags."""
original_variation = get_client().variation
get_client().variation = lambda key, context, default: (
return_value if key == flag_key else original_variation(key, context, default)
)
try:
yield
finally:
get_client().variation = original_variation

View File

@@ -0,0 +1,45 @@
import pytest
from ldclient import LDClient
from fastapi import HTTPException
from autogpt_libs.feature_flag.client import feature_flag, mock_flag_variation
@pytest.fixture
def ld_client(mocker):
client = mocker.Mock(spec=LDClient)
mocker.patch("ldclient.get", return_value=client)
client.is_initialized.return_value = True
return client
@pytest.mark.asyncio
async def test_feature_flag_enabled(ld_client):
ld_client.variation.return_value = True
@feature_flag("test-flag")
async def test_function(user_id: str):
return "success"
result = await test_function(user_id="test-user")
assert result == "success"
ld_client.variation.assert_called_once()
@pytest.mark.asyncio
async def test_feature_flag_unauthorized_response(ld_client):
ld_client.variation.return_value = False
@feature_flag("test-flag")
async def test_function(user_id: str):
return "success"
with pytest.raises(HTTPException):
await test_function(user_id="test-user")
def test_mock_flag_variation(ld_client):
with mock_flag_variation("test-flag", True):
assert ld_client.variation("test-flag", None, False)
with mock_flag_variation("test-flag", False):
assert not ld_client.variation("test-flag", None, False)

View File

@@ -0,0 +1,15 @@
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
launch_darkly_sdk_key: str = Field(
default="",
description="The Launch Darkly SDK key",
validation_alias="LAUNCH_DARKLY_SDK_KEY",
)
model_config = SettingsConfigDict(case_sensitive=True, extra="ignore")
SETTINGS = Settings()

View File

@@ -1,8 +1,6 @@
"""Logging module for Auto-GPT."""
import logging
import os
import socket
import sys
from pathlib import Path
@@ -12,15 +10,6 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
from .filters import BelowLevelFilter
from .formatters import AGPTFormatter
# Configure global socket timeout and gRPC keepalive to prevent deadlocks
# This must be done at import time before any gRPC connections are established
socket.setdefaulttimeout(30) # 30-second socket timeout
# Enable gRPC keepalive to detect dead connections faster
os.environ.setdefault("GRPC_KEEPALIVE_TIME_MS", "30000") # 30 seconds
os.environ.setdefault("GRPC_KEEPALIVE_TIMEOUT_MS", "5000") # 5 seconds
os.environ.setdefault("GRPC_KEEPALIVE_PERMIT_WITHOUT_CALLS", "true")
LOG_DIR = Path(__file__).parent.parent.parent.parent / "logs"
LOG_FILE = "activity.log"
DEBUG_LOG_FILE = "debug.log"
@@ -90,6 +79,7 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
Note: This function is typically called at the start of the application
to set up the logging infrastructure.
"""
config = LoggingConfig()
log_handlers: list[logging.Handler] = []
@@ -115,17 +105,13 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
if config.enable_cloud_logging or force_cloud_logging:
import google.cloud.logging
from google.cloud.logging.handlers import CloudLoggingHandler
from google.cloud.logging_v2.handlers.transports import (
BackgroundThreadTransport,
)
from google.cloud.logging_v2.handlers.transports.sync import SyncTransport
client = google.cloud.logging.Client()
# Use BackgroundThreadTransport to prevent blocking the main thread
# and deadlocks when gRPC calls to Google Cloud Logging hang
cloud_handler = CloudLoggingHandler(
client,
name="autogpt_logs",
transport=BackgroundThreadTransport,
transport=SyncTransport,
)
cloud_handler.setLevel(config.level)
log_handlers.append(cloud_handler)

View File

@@ -1,5 +1,39 @@
import logging
import re
from typing import Any
import uvicorn.config
from colorama import Fore
def remove_color_codes(s: str) -> str:
return re.sub(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])", "", s)
def fmt_kwargs(kwargs: dict) -> str:
return ", ".join(f"{n}={repr(v)}" for n, v in kwargs.items())
def print_attribute(
title: str, value: Any, title_color: str = Fore.GREEN, value_color: str = ""
) -> None:
logger = logging.getLogger()
logger.info(
str(value),
extra={
"title": f"{title.rstrip(':')}:",
"title_color": title_color,
"color": value_color,
},
)
def generate_uvicorn_config():
"""
Generates a uvicorn logging config that silences uvicorn's default logging and tells it to use the native logging module.
"""
log_config = dict(uvicorn.config.LOGGING_CONFIG)
log_config["loggers"]["uvicorn"] = {"handlers": []}
log_config["loggers"]["uvicorn.error"] = {"handlers": []}
log_config["loggers"]["uvicorn.access"] = {"handlers": []}
return log_config

View File

@@ -1,34 +1,17 @@
import inspect
import logging
import threading
import time
from functools import wraps
from typing import (
Awaitable,
Callable,
ParamSpec,
Protocol,
Tuple,
TypeVar,
cast,
overload,
runtime_checkable,
)
from typing import Awaitable, Callable, ParamSpec, TypeVar, cast, overload
P = ParamSpec("P")
R = TypeVar("R")
logger = logging.getLogger(__name__)
@overload
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: ...
@overload
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
pass
@overload
def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
pass
def thread_cached(func: Callable[P, R]) -> Callable[P, R]: ...
def thread_cached(
@@ -74,193 +57,3 @@ def thread_cached(
def clear_thread_cache(func: Callable) -> None:
if clear := getattr(func, "clear_cache", None):
clear()
FuncT = TypeVar("FuncT")
R_co = TypeVar("R_co", covariant=True)
@runtime_checkable
class AsyncCachedFunction(Protocol[P, R_co]):
"""Protocol for async functions with cache management methods."""
def cache_clear(self) -> None:
"""Clear all cached entries."""
return None
def cache_info(self) -> dict[str, int | None]:
"""Get cache statistics."""
return {}
async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co:
"""Call the cached function."""
return None # type: ignore
def async_ttl_cache(
maxsize: int = 128, ttl_seconds: int | None = None
) -> Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]:
"""
TTL (Time To Live) cache decorator for async functions.
Similar to functools.lru_cache but works with async functions and includes optional TTL.
Args:
maxsize: Maximum number of cached entries
ttl_seconds: Time to live in seconds. If None, entries never expire (like lru_cache)
Returns:
Decorator function
Example:
# With TTL
@async_ttl_cache(maxsize=1000, ttl_seconds=300)
async def api_call(param: str) -> dict:
return {"result": param}
# Without TTL (permanent cache like lru_cache)
@async_ttl_cache(maxsize=1000)
async def expensive_computation(param: str) -> dict:
return {"result": param}
"""
def decorator(
async_func: Callable[P, Awaitable[R]],
) -> AsyncCachedFunction[P, R]:
# Cache storage - use union type to handle both cases
cache_storage: dict[tuple, R | Tuple[R, float]] = {}
@wraps(async_func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
# Create cache key from arguments
key = (args, tuple(sorted(kwargs.items())))
current_time = time.time()
# Check if we have a valid cached entry
if key in cache_storage:
if ttl_seconds is None:
# No TTL - return cached result directly
logger.debug(
f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
)
return cast(R, cache_storage[key])
else:
# With TTL - check expiration
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
logger.debug(
f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
)
return cast(R, result)
else:
# Expired entry
del cache_storage[key]
logger.debug(
f"Cache entry expired for {async_func.__name__}"
)
# Cache miss or expired - fetch fresh data
logger.debug(
f"Cache miss for {async_func.__name__} with key: {str(key)[:50]}"
)
result = await async_func(*args, **kwargs)
# Store in cache
if ttl_seconds is None:
cache_storage[key] = result
else:
cache_storage[key] = (result, current_time)
# Simple cleanup when cache gets too large
if len(cache_storage) > maxsize:
# Remove oldest entries (simple FIFO cleanup)
cutoff = maxsize // 2
oldest_keys = list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
for old_key in oldest_keys:
cache_storage.pop(old_key, None)
logger.debug(
f"Cache cleanup: removed {len(oldest_keys)} entries for {async_func.__name__}"
)
return result
# Add cache management methods (similar to functools.lru_cache)
def cache_clear() -> None:
cache_storage.clear()
def cache_info() -> dict[str, int | None]:
return {
"size": len(cache_storage),
"maxsize": maxsize,
"ttl_seconds": ttl_seconds,
}
# Attach methods to wrapper
setattr(wrapper, "cache_clear", cache_clear)
setattr(wrapper, "cache_info", cache_info)
return cast(AsyncCachedFunction[P, R], wrapper)
return decorator
@overload
def async_cache(
func: Callable[P, Awaitable[R]],
) -> AsyncCachedFunction[P, R]:
pass
@overload
def async_cache(
func: None = None,
*,
maxsize: int = 128,
) -> Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]:
pass
def async_cache(
func: Callable[P, Awaitable[R]] | None = None,
*,
maxsize: int = 128,
) -> (
AsyncCachedFunction[P, R]
| Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]
):
"""
Process-level cache decorator for async functions (no TTL).
Similar to functools.lru_cache but works with async functions.
This is a convenience wrapper around async_ttl_cache with ttl_seconds=None.
Args:
func: The async function to cache (when used without parentheses)
maxsize: Maximum number of cached entries
Returns:
Decorated function or decorator
Example:
# Without parentheses (uses default maxsize=128)
@async_cache
async def get_data(param: str) -> dict:
return {"result": param}
# With parentheses and custom maxsize
@async_cache(maxsize=1000)
async def expensive_computation(param: str) -> dict:
# Expensive computation here
return {"result": param}
"""
if func is None:
# Called with parentheses @async_cache() or @async_cache(maxsize=...)
return async_ttl_cache(maxsize=maxsize, ttl_seconds=None)
else:
# Called without parentheses @async_cache
decorator = async_ttl_cache(maxsize=maxsize, ttl_seconds=None)
return decorator(func)

View File

@@ -1,705 +0,0 @@
"""Tests for the @thread_cached decorator.
This module tests the thread-local caching functionality including:
- Basic caching for sync and async functions
- Thread isolation (each thread has its own cache)
- Cache clearing functionality
- Exception handling (exceptions are not cached)
- Argument handling (positional vs keyword arguments)
"""
import asyncio
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import Mock
import pytest
from autogpt_libs.utils.cache import (
async_cache,
async_ttl_cache,
clear_thread_cache,
thread_cached,
)
class TestThreadCached:
def test_sync_function_caching(self):
call_count = 0
@thread_cached
def expensive_function(x: int, y: int = 0) -> int:
nonlocal call_count
call_count += 1
return x + y
assert expensive_function(1, 2) == 3
assert call_count == 1
assert expensive_function(1, 2) == 3
assert call_count == 1
assert expensive_function(1, y=2) == 3
assert call_count == 2
assert expensive_function(2, 3) == 5
assert call_count == 3
assert expensive_function(1) == 1
assert call_count == 4
@pytest.mark.asyncio
async def test_async_function_caching(self):
call_count = 0
@thread_cached
async def expensive_async_function(x: int, y: int = 0) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.01)
return x + y
assert await expensive_async_function(1, 2) == 3
assert call_count == 1
assert await expensive_async_function(1, 2) == 3
assert call_count == 1
assert await expensive_async_function(1, y=2) == 3
assert call_count == 2
assert await expensive_async_function(2, 3) == 5
assert call_count == 3
def test_thread_isolation(self):
call_count = 0
results = {}
@thread_cached
def thread_specific_function(x: int) -> str:
nonlocal call_count
call_count += 1
return f"{threading.current_thread().name}-{x}"
def worker(thread_id: int):
result1 = thread_specific_function(1)
result2 = thread_specific_function(1)
result3 = thread_specific_function(2)
results[thread_id] = (result1, result2, result3)
with ThreadPoolExecutor(max_workers=3) as executor:
futures = [executor.submit(worker, i) for i in range(3)]
for future in futures:
future.result()
assert call_count >= 2
for thread_id, (r1, r2, r3) in results.items():
assert r1 == r2
assert r1 != r3
@pytest.mark.asyncio
async def test_async_thread_isolation(self):
call_count = 0
results = {}
@thread_cached
async def async_thread_specific_function(x: int) -> str:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.01)
return f"{threading.current_thread().name}-{x}"
async def async_worker(worker_id: int):
result1 = await async_thread_specific_function(1)
result2 = await async_thread_specific_function(1)
result3 = await async_thread_specific_function(2)
results[worker_id] = (result1, result2, result3)
tasks = [async_worker(i) for i in range(3)]
await asyncio.gather(*tasks)
for worker_id, (r1, r2, r3) in results.items():
assert r1 == r2
assert r1 != r3
def test_clear_cache_sync(self):
call_count = 0
@thread_cached
def clearable_function(x: int) -> int:
nonlocal call_count
call_count += 1
return x * 2
assert clearable_function(5) == 10
assert call_count == 1
assert clearable_function(5) == 10
assert call_count == 1
clear_thread_cache(clearable_function)
assert clearable_function(5) == 10
assert call_count == 2
@pytest.mark.asyncio
async def test_clear_cache_async(self):
call_count = 0
@thread_cached
async def clearable_async_function(x: int) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.01)
return x * 2
assert await clearable_async_function(5) == 10
assert call_count == 1
assert await clearable_async_function(5) == 10
assert call_count == 1
clear_thread_cache(clearable_async_function)
assert await clearable_async_function(5) == 10
assert call_count == 2
def test_simple_arguments(self):
call_count = 0
@thread_cached
def simple_function(a: str, b: int, c: str = "default") -> str:
nonlocal call_count
call_count += 1
return f"{a}-{b}-{c}"
# First call with all positional args
result1 = simple_function("test", 42, "custom")
assert call_count == 1
# Same args, all positional - should hit cache
result2 = simple_function("test", 42, "custom")
assert call_count == 1
assert result1 == result2
# Same values but last arg as keyword - creates different cache key
result3 = simple_function("test", 42, c="custom")
assert call_count == 2
assert result1 == result3 # Same result, different cache entry
# Different value - new cache entry
result4 = simple_function("test", 43, "custom")
assert call_count == 3
assert result1 != result4
def test_positional_vs_keyword_args(self):
"""Test that positional and keyword arguments create different cache entries."""
call_count = 0
@thread_cached
def func(a: int, b: int = 10) -> str:
nonlocal call_count
call_count += 1
return f"result-{a}-{b}"
# All positional
result1 = func(1, 2)
assert call_count == 1
assert result1 == "result-1-2"
# Same values, but second arg as keyword
result2 = func(1, b=2)
assert call_count == 2 # Different cache key!
assert result2 == "result-1-2" # Same result
# Verify both are cached separately
func(1, 2) # Uses first cache entry
assert call_count == 2
func(1, b=2) # Uses second cache entry
assert call_count == 2
def test_exception_handling(self):
call_count = 0
@thread_cached
def failing_function(x: int) -> int:
nonlocal call_count
call_count += 1
if x < 0:
raise ValueError("Negative value")
return x * 2
assert failing_function(5) == 10
assert call_count == 1
with pytest.raises(ValueError):
failing_function(-1)
assert call_count == 2
with pytest.raises(ValueError):
failing_function(-1)
assert call_count == 3
assert failing_function(5) == 10
assert call_count == 3
@pytest.mark.asyncio
async def test_async_exception_handling(self):
call_count = 0
@thread_cached
async def async_failing_function(x: int) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.01)
if x < 0:
raise ValueError("Negative value")
return x * 2
assert await async_failing_function(5) == 10
assert call_count == 1
with pytest.raises(ValueError):
await async_failing_function(-1)
assert call_count == 2
with pytest.raises(ValueError):
await async_failing_function(-1)
assert call_count == 3
def test_sync_caching_performance(self):
@thread_cached
def slow_function(x: int) -> int:
print(f"slow_function called with x={x}")
time.sleep(0.1)
return x * 2
start = time.time()
result1 = slow_function(5)
first_call_time = time.time() - start
print(f"First call took {first_call_time:.4f} seconds")
start = time.time()
result2 = slow_function(5)
second_call_time = time.time() - start
print(f"Second call took {second_call_time:.4f} seconds")
assert result1 == result2 == 10
assert first_call_time > 0.09
assert second_call_time < 0.01
@pytest.mark.asyncio
async def test_async_caching_performance(self):
@thread_cached
async def slow_async_function(x: int) -> int:
print(f"slow_async_function called with x={x}")
await asyncio.sleep(0.1)
return x * 2
start = time.time()
result1 = await slow_async_function(5)
first_call_time = time.time() - start
print(f"First async call took {first_call_time:.4f} seconds")
start = time.time()
result2 = await slow_async_function(5)
second_call_time = time.time() - start
print(f"Second async call took {second_call_time:.4f} seconds")
assert result1 == result2 == 10
assert first_call_time > 0.09
assert second_call_time < 0.01
def test_with_mock_objects(self):
mock = Mock(return_value=42)
@thread_cached
def function_using_mock(x: int) -> int:
return mock(x)
assert function_using_mock(1) == 42
assert mock.call_count == 1
assert function_using_mock(1) == 42
assert mock.call_count == 1
assert function_using_mock(2) == 42
assert mock.call_count == 2
class TestAsyncTTLCache:
"""Tests for the @async_ttl_cache decorator."""
@pytest.mark.asyncio
async def test_basic_caching(self):
"""Test basic caching functionality."""
call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=60)
async def cached_function(x: int, y: int = 0) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.01) # Simulate async work
return x + y
# First call
result1 = await cached_function(1, 2)
assert result1 == 3
assert call_count == 1
# Second call with same args - should use cache
result2 = await cached_function(1, 2)
assert result2 == 3
assert call_count == 1 # No additional call
# Different args - should call function again
result3 = await cached_function(2, 3)
assert result3 == 5
assert call_count == 2
@pytest.mark.asyncio
async def test_ttl_expiration(self):
"""Test that cache entries expire after TTL."""
call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=1) # Short TTL
async def short_lived_cache(x: int) -> int:
nonlocal call_count
call_count += 1
return x * 2
# First call
result1 = await short_lived_cache(5)
assert result1 == 10
assert call_count == 1
# Second call immediately - should use cache
result2 = await short_lived_cache(5)
assert result2 == 10
assert call_count == 1
# Wait for TTL to expire
await asyncio.sleep(1.1)
# Third call after expiration - should call function again
result3 = await short_lived_cache(5)
assert result3 == 10
assert call_count == 2
@pytest.mark.asyncio
async def test_cache_info(self):
"""Test cache info functionality."""
@async_ttl_cache(maxsize=5, ttl_seconds=300)
async def info_test_function(x: int) -> int:
return x * 3
# Check initial cache info
info = info_test_function.cache_info()
assert info["size"] == 0
assert info["maxsize"] == 5
assert info["ttl_seconds"] == 300
# Add an entry
await info_test_function(1)
info = info_test_function.cache_info()
assert info["size"] == 1
@pytest.mark.asyncio
async def test_cache_clear(self):
"""Test cache clearing functionality."""
call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=60)
async def clearable_function(x: int) -> int:
nonlocal call_count
call_count += 1
return x * 4
# First call
result1 = await clearable_function(2)
assert result1 == 8
assert call_count == 1
# Second call - should use cache
result2 = await clearable_function(2)
assert result2 == 8
assert call_count == 1
# Clear cache
clearable_function.cache_clear()
# Third call after clear - should call function again
result3 = await clearable_function(2)
assert result3 == 8
assert call_count == 2
@pytest.mark.asyncio
async def test_maxsize_cleanup(self):
"""Test that cache cleans up when maxsize is exceeded."""
call_count = 0
@async_ttl_cache(maxsize=3, ttl_seconds=60)
async def size_limited_function(x: int) -> int:
nonlocal call_count
call_count += 1
return x**2
# Fill cache to maxsize
await size_limited_function(1) # call_count: 1
await size_limited_function(2) # call_count: 2
await size_limited_function(3) # call_count: 3
info = size_limited_function.cache_info()
assert info["size"] == 3
# Add one more entry - should trigger cleanup
await size_limited_function(4) # call_count: 4
# Cache size should be reduced (cleanup removes oldest entries)
info = size_limited_function.cache_info()
assert info["size"] is not None and info["size"] <= 3 # Should be cleaned up
@pytest.mark.asyncio
async def test_argument_variations(self):
"""Test caching with different argument patterns."""
call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=60)
async def arg_test_function(a: int, b: str = "default", *, c: int = 100) -> str:
nonlocal call_count
call_count += 1
return f"{a}-{b}-{c}"
# Different ways to call with same logical arguments
result1 = await arg_test_function(1, "test", c=200)
assert call_count == 1
# Same arguments, same order - should use cache
result2 = await arg_test_function(1, "test", c=200)
assert call_count == 1
assert result1 == result2
# Different arguments - should call function
result3 = await arg_test_function(2, "test", c=200)
assert call_count == 2
assert result1 != result3
@pytest.mark.asyncio
async def test_exception_handling(self):
"""Test that exceptions are not cached."""
call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=60)
async def exception_function(x: int) -> int:
nonlocal call_count
call_count += 1
if x < 0:
raise ValueError("Negative value not allowed")
return x * 2
# Successful call - should be cached
result1 = await exception_function(5)
assert result1 == 10
assert call_count == 1
# Same successful call - should use cache
result2 = await exception_function(5)
assert result2 == 10
assert call_count == 1
# Exception call - should not be cached
with pytest.raises(ValueError):
await exception_function(-1)
assert call_count == 2
# Same exception call - should call again (not cached)
with pytest.raises(ValueError):
await exception_function(-1)
assert call_count == 3
@pytest.mark.asyncio
async def test_concurrent_calls(self):
"""Test caching behavior with concurrent calls."""
call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=60)
async def concurrent_function(x: int) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.05) # Simulate work
return x * x
# Launch concurrent calls with same arguments
tasks = [concurrent_function(3) for _ in range(5)]
results = await asyncio.gather(*tasks)
# All results should be the same
assert all(result == 9 for result in results)
# Note: Due to race conditions, call_count might be up to 5 for concurrent calls
# This tests that the cache doesn't break under concurrent access
assert 1 <= call_count <= 5
class TestAsyncCache:
"""Tests for the @async_cache decorator (no TTL)."""
@pytest.mark.asyncio
async def test_basic_caching_no_ttl(self):
"""Test basic caching functionality without TTL."""
call_count = 0
@async_cache(maxsize=10)
async def cached_function(x: int, y: int = 0) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.01) # Simulate async work
return x + y
# First call
result1 = await cached_function(1, 2)
assert result1 == 3
assert call_count == 1
# Second call with same args - should use cache
result2 = await cached_function(1, 2)
assert result2 == 3
assert call_count == 1 # No additional call
# Third call after some time - should still use cache (no TTL)
await asyncio.sleep(0.05)
result3 = await cached_function(1, 2)
assert result3 == 3
assert call_count == 1 # Still no additional call
# Different args - should call function again
result4 = await cached_function(2, 3)
assert result4 == 5
assert call_count == 2
@pytest.mark.asyncio
async def test_no_ttl_vs_ttl_behavior(self):
"""Test the difference between TTL and no-TTL caching."""
ttl_call_count = 0
no_ttl_call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=1) # Short TTL
async def ttl_function(x: int) -> int:
nonlocal ttl_call_count
ttl_call_count += 1
return x * 2
@async_cache(maxsize=10) # No TTL
async def no_ttl_function(x: int) -> int:
nonlocal no_ttl_call_count
no_ttl_call_count += 1
return x * 2
# First calls
await ttl_function(5)
await no_ttl_function(5)
assert ttl_call_count == 1
assert no_ttl_call_count == 1
# Wait for TTL to expire
await asyncio.sleep(1.1)
# Second calls after TTL expiry
await ttl_function(5) # Should call function again (TTL expired)
await no_ttl_function(5) # Should use cache (no TTL)
assert ttl_call_count == 2 # TTL function called again
assert no_ttl_call_count == 1 # No-TTL function still cached
@pytest.mark.asyncio
async def test_async_cache_info(self):
"""Test cache info for no-TTL cache."""
@async_cache(maxsize=5)
async def info_test_function(x: int) -> int:
return x * 3
# Check initial cache info
info = info_test_function.cache_info()
assert info["size"] == 0
assert info["maxsize"] == 5
assert info["ttl_seconds"] is None # No TTL
# Add an entry
await info_test_function(1)
info = info_test_function.cache_info()
assert info["size"] == 1
class TestTTLOptional:
"""Tests for optional TTL functionality."""
@pytest.mark.asyncio
async def test_ttl_none_behavior(self):
"""Test that ttl_seconds=None works like no TTL."""
call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=None)
async def no_ttl_via_none(x: int) -> int:
nonlocal call_count
call_count += 1
return x**2
# First call
result1 = await no_ttl_via_none(3)
assert result1 == 9
assert call_count == 1
# Wait (would expire if there was TTL)
await asyncio.sleep(0.1)
# Second call - should still use cache
result2 = await no_ttl_via_none(3)
assert result2 == 9
assert call_count == 1 # No additional call
# Check cache info
info = no_ttl_via_none.cache_info()
assert info["ttl_seconds"] is None
@pytest.mark.asyncio
async def test_cache_options_comparison(self):
"""Test different cache options work as expected."""
ttl_calls = 0
no_ttl_calls = 0
@async_ttl_cache(maxsize=10, ttl_seconds=1) # With TTL
async def ttl_function(x: int) -> int:
nonlocal ttl_calls
ttl_calls += 1
return x * 10
@async_cache(maxsize=10) # Process-level cache (no TTL)
async def process_function(x: int) -> int:
nonlocal no_ttl_calls
no_ttl_calls += 1
return x * 10
# Both should cache initially
await ttl_function(3)
await process_function(3)
assert ttl_calls == 1
assert no_ttl_calls == 1
# Immediate second calls - both should use cache
await ttl_function(3)
await process_function(3)
assert ttl_calls == 1
assert no_ttl_calls == 1
# Wait for TTL to expire
await asyncio.sleep(1.1)
# After TTL expiry
await ttl_function(3) # Should call function again
await process_function(3) # Should still use cache
assert ttl_calls == 2 # TTL cache expired, called again
assert no_ttl_calls == 1 # Process cache never expires

File diff suppressed because it is too large

View File

@@ -1,8 +1,8 @@
[tool.poetry]
name = "autogpt-libs"
version = "0.2.0"
description = "Shared libraries across AutoGPT Platform"
authors = ["AutoGPT team <info@agpt.co>"]
description = "Shared libraries across NextGen AutoGPT"
authors = ["Aarushi <aarushik93@gmail.com>"]
readme = "README.md"
packages = [{ include = "autogpt_libs" }]
@@ -10,20 +10,20 @@ packages = [{ include = "autogpt_libs" }]
python = ">=3.10,<4.0"
colorama = "^0.4.6"
expiringdict = "^1.2.2"
fastapi = "^0.116.1"
google-cloud-logging = "^3.12.1"
launchdarkly-server-sdk = "^9.12.0"
pydantic = "^2.11.7"
pydantic-settings = "^2.10.1"
pydantic = "^2.11.4"
pydantic-settings = "^2.9.1"
pyjwt = "^2.10.1"
pytest-asyncio = "^1.1.0"
pytest-mock = "^3.14.1"
redis = "^6.2.0"
supabase = "^2.16.0"
uvicorn = "^0.35.0"
pytest-asyncio = "^0.26.0"
pytest-mock = "^3.14.0"
supabase = "^2.15.1"
launchdarkly-server-sdk = "^9.11.1"
fastapi = "^0.115.12"
uvicorn = "^0.34.3"
[tool.poetry.group.dev.dependencies]
ruff = "^0.12.9"
redis = "^5.2.1"
ruff = "^0.11.10"
[build-system]
requires = ["poetry-core"]

View File

@@ -1,52 +0,0 @@
# Development and testing files
**/__pycache__
**/*.pyc
**/*.pyo
**/*.pyd
**/.Python
**/env/
**/venv/
**/.venv/
**/pip-log.txt
**/.pytest_cache/
**/test-results/
**/snapshots/
**/test/
# IDE and editor files
**/.vscode/
**/.idea/
**/*.swp
**/*.swo
*~
# OS files
.DS_Store
Thumbs.db
# Logs
**/*.log
**/logs/
# Git
.git/
.gitignore
# Documentation
**/*.md
!README.md
# Local development files
.env
.env.local
**/.env.test
# Build artifacts
**/dist/
**/build/
**/target/
# Docker files (avoid recursion)
Dockerfile*
docker-compose*
.dockerignore

View File

@@ -1,9 +1,3 @@
# Backend Configuration
# This file contains environment variables that MUST be set for the AutoGPT platform
# Variables with working defaults in settings.py are not included here
## ===== REQUIRED DATABASE CONFIGURATION ===== ##
# PostgreSQL Database Connection
DB_USER=postgres
DB_PASS=your-super-secret-and-long-postgres-password
DB_NAME=postgres
@@ -16,50 +10,72 @@ DB_SCHEMA=platform
DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
DIRECT_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
PRISMA_SCHEMA="postgres/schema.prisma"
ENABLE_AUTH=true
## ===== REQUIRED SERVICE CREDENTIALS ===== ##
# Redis Configuration
# EXECUTOR
NUM_GRAPH_WORKERS=10
BACKEND_CORS_ALLOW_ORIGINS=["http://localhost:3000"]
# generate using `from cryptography.fernet import Fernet;Fernet.generate_key().decode()`
ENCRYPTION_KEY='dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw='
UNSUBSCRIBE_SECRET_KEY = 'HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio='
REDIS_HOST=localhost
REDIS_PORT=6379
REDIS_PASSWORD=password
# RabbitMQ Credentials
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
ENABLE_CREDIT=false
STRIPE_API_KEY=
STRIPE_WEBHOOK_SECRET=
# Supabase Authentication
# What environment things should be logged under: local dev or prod
APP_ENV=local
# What environment to behave as: "local" or "cloud"
BEHAVE_AS=local
PYRO_HOST=localhost
SENTRY_DSN=
# Email For Postmark so we can send emails
POSTMARK_SERVER_API_TOKEN=
POSTMARK_SENDER_EMAIL=invalid@invalid.com
POSTMARK_WEBHOOK_TOKEN=
## User auth with Supabase is required for any of the 3rd party integrations with auth to work.
ENABLE_AUTH=true
SUPABASE_URL=http://localhost:8000
SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
## ===== REQUIRED SECURITY KEYS ===== ##
# Generate using: from cryptography.fernet import Fernet;Fernet.generate_key().decode()
ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
# RabbitMQ credentials -- Used for communication between services
RABBITMQ_HOST=localhost
RABBITMQ_PORT=5672
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
# Platform URLs (set these for webhooks and OAuth to work)
PLATFORM_BASE_URL=http://localhost:8000
FRONTEND_BASE_URL=http://localhost:3000
# Media Storage (required for marketplace and library functionality)
## GCS bucket is required for marketplace and library functionality
MEDIA_GCS_BUCKET_NAME=
## ===== API KEYS AND OAUTH CREDENTIALS ===== ##
# All API keys below are optional - only add what you need
## For local development, you may need to set FRONTEND_BASE_URL for the OAuth flow
## for integrations to work. Defaults to the value of PLATFORM_BASE_URL if not set.
# FRONTEND_BASE_URL=http://localhost:3000
# AI/LLM Services
OPENAI_API_KEY=
ANTHROPIC_API_KEY=
GROQ_API_KEY=
LLAMA_API_KEY=
AIML_API_KEY=
V0_API_KEY=
OPEN_ROUTER_API_KEY=
NVIDIA_API_KEY=
## PLATFORM_BASE_URL must be set to a *publicly accessible* URL pointing to your backend
## to use the platform's webhook-related functionality.
## If you are developing locally, you can use something like ngrok to get a public URL
## and tunnel it to your locally running backend.
PLATFORM_BASE_URL=http://localhost:3000
## Cloudflare Turnstile (CAPTCHA) Configuration
## Get these from the Cloudflare Turnstile dashboard: https://dash.cloudflare.com/?to=/:account/turnstile
## This is the backend secret key
TURNSTILE_SECRET_KEY=
## This is the verify URL
TURNSTILE_VERIFY_URL=https://challenges.cloudflare.com/turnstile/v0/siteverify
## == INTEGRATION CREDENTIALS == ##
# Each set of server side credentials is required for the corresponding 3rd party
# integration to work.
# OAuth Credentials
# For the OAuth callback URL, use <your_frontend_url>/auth/integrations/oauth_callback,
# e.g. http://localhost:3000/auth/integrations/oauth_callback
@@ -69,6 +85,7 @@ GITHUB_CLIENT_SECRET=
# Google OAuth App server credentials - https://console.cloud.google.com/apis/credentials, and enable gmail api and set scopes
# https://console.cloud.google.com/apis/credentials/consent ?project=<your_project_id>
# You'll need to add/enable the following scopes (minimum):
# https://console.developers.google.com/apis/api/gmail.googleapis.com/overview ?project=<your_project_id>
# https://console.cloud.google.com/apis/library/sheets.googleapis.com/ ?project=<your_project_id>
@@ -104,75 +121,87 @@ LINEAR_CLIENT_SECRET=
TODOIST_CLIENT_ID=
TODOIST_CLIENT_SECRET=
NOTION_CLIENT_ID=
NOTION_CLIENT_SECRET=
## ===== OPTIONAL API KEYS ===== ##
# Discord OAuth App credentials
# 1. Go to https://discord.com/developers/applications
# 2. Create a new application
# 3. Go to OAuth2 section and add redirect URI: http://localhost:3000/auth/integrations/oauth_callback
# 4. Copy Client ID and Client Secret below
DISCORD_CLIENT_ID=
DISCORD_CLIENT_SECRET=
# LLM
OPENAI_API_KEY=
ANTHROPIC_API_KEY=
AIML_API_KEY=
GROQ_API_KEY=
OPEN_ROUTER_API_KEY=
LLAMA_API_KEY=
# Reddit
# Go to https://www.reddit.com/prefs/apps and create a new app
# Choose "script" for the type
# Fill in the redirect uri as <your_frontend_url>/auth/integrations/oauth_callback, e.g. http://localhost:3000/auth/integrations/oauth_callback
REDDIT_CLIENT_ID=
REDDIT_CLIENT_SECRET=
REDDIT_USER_AGENT="AutoGPT:1.0 (by /u/autogpt)"
# Payment Processing
STRIPE_API_KEY=
STRIPE_WEBHOOK_SECRET=
# Email Service (for sending notifications and confirmations)
POSTMARK_SERVER_API_TOKEN=
POSTMARK_SENDER_EMAIL=invalid@invalid.com
POSTMARK_WEBHOOK_TOKEN=
# Error Tracking
SENTRY_DSN=
# Cloudflare Turnstile (CAPTCHA) Configuration
# Get these from the Cloudflare Turnstile dashboard: https://dash.cloudflare.com/?to=/:account/turnstile
# This is the backend secret key
TURNSTILE_SECRET_KEY=
# This is the verify URL
TURNSTILE_VERIFY_URL=https://challenges.cloudflare.com/turnstile/v0/siteverify
# Feature Flags
LAUNCH_DARKLY_SDK_KEY=
# Content Generation & Media
DID_API_KEY=
FAL_API_KEY=
IDEOGRAM_API_KEY=
REPLICATE_API_KEY=
REVID_API_KEY=
SCREENSHOTONE_API_KEY=
UNREAL_SPEECH_API_KEY=
# Data & Search Services
E2B_API_KEY=
EXA_API_KEY=
JINA_API_KEY=
MEM0_API_KEY=
OPENWEATHERMAP_API_KEY=
GOOGLE_MAPS_API_KEY=
# Communication Services
# Discord
DISCORD_BOT_TOKEN=
MEDIUM_API_KEY=
MEDIUM_AUTHOR_ID=
# SMTP/Email
SMTP_SERVER=
SMTP_PORT=
SMTP_USERNAME=
SMTP_PASSWORD=
# Business & Marketing Tools
# D-ID
DID_API_KEY=
# Open Weather Map
OPENWEATHERMAP_API_KEY=
# SMTP
SMTP_SERVER=
SMTP_PORT=
SMTP_USERNAME=
SMTP_PASSWORD=
# Medium
MEDIUM_API_KEY=
MEDIUM_AUTHOR_ID=
# Google Maps
GOOGLE_MAPS_API_KEY=
# Replicate
REPLICATE_API_KEY=
# Ideogram
IDEOGRAM_API_KEY=
# Fal
FAL_API_KEY=
# Exa
EXA_API_KEY=
# E2B
E2B_API_KEY=
# Mem0
MEM0_API_KEY=
# Nvidia
NVIDIA_API_KEY=
# Apollo
APOLLO_API_KEY=
ENRICHLAYER_API_KEY=
AYRSHARE_API_KEY=
AYRSHARE_JWT_KEY=
# SmartLead
SMARTLEAD_API_KEY=
# ZeroBounce
ZEROBOUNCE_API_KEY=
# Other Services
AUTOMOD_API_KEY=
## ===== OPTIONAL API KEYS END ===== ##
# Logging Configuration
LOG_LEVEL=INFO
ENABLE_CLOUD_LOGGING=false
ENABLE_FILE_LOGGING=false
# Use to manually set the log directory
# LOG_DIR=./logs

View File

@@ -1,4 +1,3 @@
.env
database.db
database.db-journal
dev.db

View File

@@ -1,34 +1,31 @@
FROM debian:13-slim AS builder
FROM python:3.11.10-slim-bookworm AS builder
# Set environment variables
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONDONTWRITEBYTECODE 1
ENV PYTHONUNBUFFERED 1
WORKDIR /app
RUN echo 'Acquire::http::Pipeline-Depth 0;\nAcquire::http::No-Cache true;\nAcquire::BrokenProxy true;\n' > /etc/apt/apt.conf.d/99fixbadproxy
# Update package list and install Python and build dependencies
RUN apt-get update --allow-releaseinfo-change --fix-missing \
&& apt-get install -y \
python3.13 \
python3.13-dev \
python3.13-venv \
python3-pip \
build-essential \
libpq5 \
libz-dev \
libssl-dev \
postgresql-client
RUN apt-get update --allow-releaseinfo-change --fix-missing
# Install build dependencies
RUN apt-get install -y build-essential
RUN apt-get install -y libpq5
RUN apt-get install -y libz-dev
RUN apt-get install -y libssl-dev
RUN apt-get install -y postgresql-client
ENV POETRY_HOME=/opt/poetry
ENV POETRY_NO_INTERACTION=1
ENV POETRY_VIRTUALENVS_CREATE=true
ENV POETRY_VIRTUALENVS_IN_PROJECT=true
ENV POETRY_VIRTUALENVS_CREATE=false
ENV PATH=/opt/poetry/bin:$PATH
RUN pip3 install poetry --break-system-packages
# Upgrade pip and setuptools to fix security vulnerabilities
RUN pip3 install --upgrade pip setuptools
RUN pip3 install poetry
# Copy and install dependencies
COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
@@ -40,30 +37,27 @@ RUN poetry install --no-ansi --no-root
COPY autogpt_platform/backend/schema.prisma ./
RUN poetry run prisma generate
FROM debian:13-slim AS server_dependencies
FROM python:3.11.10-slim-bookworm AS server_dependencies
WORKDIR /app
ENV POETRY_HOME=/opt/poetry \
POETRY_NO_INTERACTION=1 \
POETRY_VIRTUALENVS_CREATE=true \
POETRY_VIRTUALENVS_IN_PROJECT=true \
DEBIAN_FRONTEND=noninteractive
POETRY_VIRTUALENVS_CREATE=false
ENV PATH=/opt/poetry/bin:$PATH
# Install Python without upgrading system-managed packages
RUN apt-get update && apt-get install -y \
python3.13 \
python3-pip
# Upgrade pip and setuptools to fix security vulnerabilities
RUN pip3 install --upgrade pip setuptools
# Copy only necessary files from builder
COPY --from=builder /app /app
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
COPY --from=builder /usr/local/lib/python3.11 /usr/local/lib/python3.11
COPY --from=builder /usr/local/bin /usr/local/bin
# Copy Prisma binaries
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
ENV PATH="/app/.venv/bin:$PATH"
RUN mkdir -p /app/autogpt_platform/autogpt_libs
RUN mkdir -p /app/autogpt_platform/backend
@@ -74,12 +68,6 @@ COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.tom
WORKDIR /app/autogpt_platform/backend
FROM server_dependencies AS migrate
# Migration stage only needs schema and migrations - much lighter than full backend
COPY autogpt_platform/backend/schema.prisma /app/autogpt_platform/backend/
COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migrations
FROM server_dependencies AS server
COPY autogpt_platform/backend /app/autogpt_platform/backend

View File

@@ -1,150 +0,0 @@
# Test Data Scripts
This directory contains scripts for creating and updating test data in the AutoGPT Platform database, specifically designed to test the materialized views for the store functionality.
## Scripts
### test_data_creator.py
Creates a comprehensive set of test data including:
- Users with profiles
- Agent graphs, nodes, and executions
- Store listings with multiple versions
- Reviews and ratings
- Library agents
- Integration webhooks
- Onboarding data
- Credit transactions
**Image/Video Domains Used:**
- Images: `picsum.photos` (for all image URLs)
- Videos: `youtube.com` (for store listing videos)
### test_data_updater.py
Updates existing test data to simulate real-world changes:
- Adds new agent graph executions
- Creates new store listing reviews
- Updates store listing versions
- Adds credit transactions
- Refreshes materialized views
### check_db.py
Tests and verifies materialized views functionality:
- Checks pg_cron job status (for automatic refresh)
- Displays current materialized view counts
- Adds test data (executions and reviews)
- Creates store listings if none exist
- Manually refreshes materialized views
- Compares before/after counts to verify updates
- Provides a summary of test results
## Materialized Views
The scripts test three key database views:
1. **mv_agent_run_counts**: Tracks execution counts by agent
2. **mv_review_stats**: Tracks review statistics (count, average rating) by store listing
3. **StoreAgent**: A view that combines store listing data with execution counts and ratings for display
The materialized views (mv_agent_run_counts and mv_review_stats) are automatically refreshed every 15 minutes via pg_cron, or can be manually refreshed using the `refresh_store_materialized_views()` function.
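If pg_cron is available, you can confirm the scheduled refresh directly from its `cron.job` table. This is a minimal sketch; it assumes pg_cron is installed and uses placeholder `psql` connection flags that you should adjust to your local database:

```bash
# Sketch: list the pg_cron job that refreshes the store materialized views.
# Connection flags are placeholders - point them at your platform database.
psql -h localhost -U postgres -d postgres -c \
  "SELECT jobid, schedule, command FROM cron.job
   WHERE command LIKE '%refresh_store_materialized_views%';"
```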
## Usage
### Prerequisites
1. Ensure the database is running:
```bash
docker compose up -d
# or for test database:
docker compose -f docker-compose.test.yaml --env-file ../.env up -d
```
2. Run database migrations:
```bash
poetry run prisma migrate deploy
```
### Running the Scripts
#### Option 1: Use the helper script (from backend directory)
```bash
poetry run python run_test_data.py
```
#### Option 2: Run individually
```bash
# From backend/test directory:
# Create initial test data
poetry run python test_data_creator.py
# Update data to test materialized view changes
poetry run python test_data_updater.py
# From backend directory:
# Test materialized views functionality
poetry run python check_db.py
# Check store data status
poetry run python check_store_data.py
```
#### Option 3: Use the shell script (from backend directory)
```bash
./run_test_data_scripts.sh
```
### Manual Materialized View Refresh
To manually refresh the materialized views:
```sql
SELECT refresh_store_materialized_views();
```
## Configuration
The scripts use the database configuration from your `.env` file:
- `DATABASE_URL`: PostgreSQL connection string (see the example below)
- Database should have the platform schema
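For a local setup the connection string usually takes roughly the following shape. This is an illustration only; the user, password, port, and `schema` parameter are placeholders, not values read from this repository:

```bash
# Placeholder values only - substitute the credentials from your own .env /
# docker compose configuration before using this.
export DATABASE_URL="postgresql://postgres:your-password@localhost:5432/postgres?schema=platform"
```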
## Data Generation Limits
Configured in `test_data_creator.py`:
- 100 users
- 100 agent blocks
- 1-5 graphs per user
- 2-5 nodes per graph
- 1-5 presets per user
- 1-10 library agents per user
- 1-20 executions per graph
- 1-5 reviews per store listing version
## Notes
- All image URLs use `picsum.photos` for consistency with Next.js image configuration
- The scripts create realistic relationships between entities
- Materialized views are refreshed at the end of each script
- Data is designed to test both happy paths and edge cases
## Troubleshooting
### Reviews and StoreAgent view showing 0
If `check_db.py` shows that reviews remain at 0 and the StoreAgent view shows 0 store agents:
1. **No store listings exist**: The script will automatically create test store listings if none exist
2. **No approved versions**: Store listings need approved versions to appear in the StoreAgent view (a quick count check is sketched after this list)
3. **Check with `check_store_data.py`**: This script provides detailed information about:
- Total store listings
- Store listing versions by status
- Existing reviews
- StoreAgent view contents
- Agent graph executions
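A quick way to confirm point 2 is to count the rows the StoreAgent view currently exposes. The connection flags below are placeholders; adjust them to your local database:

```bash
# Quick sanity check: how many rows does the StoreAgent view expose right now?
psql -h localhost -U postgres -d postgres -c 'SELECT COUNT(*) FROM "StoreAgent";'
```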
### pg_cron not installed
The warning "pg_cron extension is not installed" is normal in local development environments. The materialized views can still be refreshed manually using the `refresh_store_materialized_views()` function, which all scripts do automatically.
### Common Issues
- **Type errors with None values**: Fixed in the latest version of check_db.py by using `or 0` for nullable numeric fields
- **Missing relations**: Ensure you're using the correct field names (e.g., `StoreListing` not `storeListing` in includes)
- **Column name mismatches**: The database uses camelCase for column names (e.g., `agentGraphId` not `agent_graph_id`) (see the quick check below)
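As a quick check on the two naming points above: camelCase identifiers have to be double-quoted in raw SQL. The sketch below uses table and column names that are illustrative of the schema described here and may not match your database exactly:

```bash
# Table/column names below are illustrative; the double quotes are what matters
# for camelCase identifiers. Connection flags are placeholders.
psql -h localhost -U postgres -d postgres -c \
  'SELECT "agentGraphId", COUNT(*) AS executions
     FROM "AgentGraphExecution"
    GROUP BY "agentGraphId"
    LIMIT 5;'
```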

View File

@@ -1,10 +1,6 @@
import logging
from typing import TYPE_CHECKING
from dotenv import load_dotenv
load_dotenv()
if TYPE_CHECKING:
from backend.util.process import AppProcess
@@ -42,12 +38,12 @@ def main(**kwargs):
from backend.server.ws_api import WebsocketServer
run_processes(
DatabaseManager().set_log_level("warning"),
DatabaseManager(),
ExecutionManager(),
Scheduler(),
NotificationManager(),
WebsocketServer(),
AgentServer(),
ExecutionManager(),
**kwargs,
)

View File

@@ -1,14 +1,10 @@
import functools
import importlib
import logging
import os
import re
from pathlib import Path
from typing import TYPE_CHECKING, TypeVar
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from backend.data.block import Block
@@ -18,27 +14,14 @@ T = TypeVar("T")
@functools.cache
def load_all_blocks() -> dict[str, type["Block"]]:
from backend.data.block import Block
from backend.util.settings import Config
# Check if example blocks should be loaded from settings
config = Config()
load_examples = config.enable_example_blocks
# Dynamically load all modules under backend.blocks
current_dir = Path(__file__).parent
modules = []
for f in current_dir.rglob("*.py"):
if not f.is_file() or f.name == "__init__.py" or f.name.startswith("test_"):
continue
# Skip examples directory if not enabled
relative_path = f.relative_to(current_dir)
if not load_examples and relative_path.parts[0] == "examples":
continue
module_path = str(relative_path)[:-3].replace(os.path.sep, ".")
modules.append(module_path)
modules = [
str(f.relative_to(current_dir))[:-3].replace(os.path.sep, ".")
for f in current_dir.rglob("*.py")
if f.is_file() and f.name != "__init__.py"
]
for module in modules:
if not re.match("^[a-z0-9_.]+$", module):
raise ValueError(
@@ -103,15 +86,7 @@ def load_all_blocks() -> dict[str, type["Block"]]:
available_blocks[block.id] = block_cls
# Filter out blocks with incomplete auth configs, e.g. missing OAuth server secrets
from backend.data.block import is_block_auth_configured
filtered_blocks = {}
for block_id, block_cls in available_blocks.items():
if is_block_auth_configured(block_cls):
filtered_blocks[block_id] = block_cls
return filtered_blocks
return available_blocks
__all__ = ["load_all_blocks"]

View File

@@ -1,8 +1,7 @@
import asyncio
import logging
from typing import Any, Optional
from pydantic import JsonValue
from backend.data.block import (
Block,
BlockCategory,
@@ -13,11 +12,10 @@ from backend.data.block import (
get_block,
)
from backend.data.execution import ExecutionStatus
from backend.data.model import NodeExecutionStats, SchemaField
from backend.util.json import validate_with_jsonschema
from backend.util.retry import func_retry
from backend.data.model import CredentialsMetaInput, SchemaField
from backend.util import json
_logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
class AgentExecutorBlock(Block):
@@ -33,9 +31,9 @@ class AgentExecutorBlock(Block):
input_schema: dict = SchemaField(description="Input schema for the graph")
output_schema: dict = SchemaField(description="Output schema for the graph")
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = SchemaField(
default=None, hidden=True
)
node_credentials_input_map: Optional[
dict[str, dict[str, CredentialsMetaInput]]
] = SchemaField(default=None, hidden=True)
@classmethod
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
@@ -52,7 +50,7 @@ class AgentExecutorBlock(Block):
@classmethod
def get_mismatch_error(cls, data: BlockInput) -> str | None:
return validate_with_jsonschema(cls.get_input_schema(data), data)
return json.validate_with_jsonschema(cls.get_input_schema(data), data)
class Output(BlockSchema):
pass
@@ -76,17 +74,8 @@ class AgentExecutorBlock(Block):
graph_version=input_data.graph_version,
user_id=input_data.user_id,
inputs=input_data.inputs,
nodes_input_masks=input_data.nodes_input_masks,
)
logger = execution_utils.LogMetadata(
logger=_logger,
user_id=input_data.user_id,
graph_eid=graph_exec.id,
graph_id=input_data.graph_id,
node_eid="*",
node_id="*",
block_name=self.name,
node_credentials_input_map=input_data.node_credentials_input_map,
use_db_query=False,
)
try:
@@ -95,17 +84,21 @@ class AgentExecutorBlock(Block):
graph_version=input_data.graph_version,
graph_exec_id=graph_exec.id,
user_id=input_data.user_id,
logger=logger,
):
yield name, data
except BaseException as e:
await self._stop(
graph_exec_id=graph_exec.id,
user_id=input_data.user_id,
logger=logger,
)
except asyncio.CancelledError:
logger.warning(
f"Execution of graph {input_data.graph_id}v{input_data.graph_version} failed: {e.__class__.__name__} {str(e)}; execution is stopped."
f"Execution of graph {input_data.graph_id} version {input_data.graph_version} was cancelled."
)
await execution_utils.stop_graph_execution(
graph_exec.id, use_db_query=False
)
except Exception as e:
logger.error(
f"Execution of graph {input_data.graph_id} version {input_data.graph_version} failed: {e}, stopping execution."
)
await execution_utils.stop_graph_execution(
graph_exec.id, use_db_query=False
)
raise
@@ -115,7 +108,6 @@ class AgentExecutorBlock(Block):
graph_version: int,
graph_exec_id: str,
user_id: str,
logger,
) -> BlockOutput:
from backend.data.execution import ExecutionEventType
@@ -125,7 +117,6 @@ class AgentExecutorBlock(Block):
log_id = f"Graph #{graph_id}-V{graph_version}, exec-id: {graph_exec_id}"
logger.info(f"Starting execution of {log_id}")
yielded_node_exec_ids = set()
async for event in event_bus.listen(
user_id=user_id,
@@ -145,26 +136,12 @@ class AgentExecutorBlock(Block):
if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
# If the graph execution is COMPLETED, TERMINATED, or FAILED,
# we can stop listening for further events.
self.merge_stats(
NodeExecutionStats(
extra_cost=event.stats.cost if event.stats else 0,
extra_steps=event.stats.node_exec_count if event.stats else 0,
)
)
break
logger.debug(
f"Execution {log_id} produced input {event.input_data} output {event.output_data}"
)
if event.node_exec_id in yielded_node_exec_ids:
logger.warning(
f"{log_id} received duplicate event for node execution {event.node_exec_id}"
)
continue
else:
yielded_node_exec_ids.add(event.node_exec_id)
if not event.block_id:
logger.warning(f"{log_id} received event without block_id {event}")
continue
@@ -183,25 +160,3 @@ class AgentExecutorBlock(Block):
f"Execution {log_id} produced {output_name}: {output_data}"
)
yield output_name, output_data
@func_retry
async def _stop(
self,
graph_exec_id: str,
user_id: str,
logger,
) -> None:
from backend.executor import utils as execution_utils
log_id = f"Graph exec-id: {graph_exec_id}"
logger.info(f"Stopping execution of {log_id}")
try:
await execution_utils.stop_graph_execution(
graph_exec_id=graph_exec_id,
user_id=user_id,
wait_timeout=3600,
)
logger.info(f"Execution {log_id} stopped successfully.")
except TimeoutError as e:
logger.error(f"Execution {log_id} stop timed out: {e}")

View File

@@ -166,7 +166,7 @@ class AIMusicGeneratorBlock(Block):
output_format=input_data.output_format,
normalization_strategy=input_data.normalization_strategy,
)
if result and isinstance(result, str) and result.startswith("http"):
if result and result != "No output received":
yield "result", result
return
else:

View File

@@ -53,7 +53,6 @@ class AudioTrack(str, Enum):
REFRESHER = ("Refresher",)
TOURIST = ("Tourist",)
TWIN_TYCHES = ("Twin Tyches",)
DONT_STOP_ME_ABSTRACT_FUTURE_BASS = ("Dont Stop Me Abstract Future Bass",)
@property
def audio_url(self):
@@ -79,7 +78,6 @@ class AudioTrack(str, Enum):
AudioTrack.REFRESHER: "https://cdn.tfrv.xyz/audio/refresher.mp3",
AudioTrack.TOURIST: "https://cdn.tfrv.xyz/audio/tourist.mp3",
AudioTrack.TWIN_TYCHES: "https://cdn.tfrv.xyz/audio/twin-tynches.mp3",
AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS: "https://cdn.revid.ai/audio/_dont-stop-me-abstract-future-bass.mp3",
}
return audio_urls[self]
@@ -107,7 +105,6 @@ class GenerationPreset(str, Enum):
MOVIE = ("Movie",)
STYLIZED_ILLUSTRATION = ("Stylized Illustration",)
MANGA = ("Manga",)
DEFAULT = ("DEFAULT",)
class Voice(str, Enum):
@@ -117,7 +114,6 @@ class Voice(str, Enum):
JESSICA = "Jessica"
CHARLOTTE = "Charlotte"
CALLUM = "Callum"
EVA = "Eva"
@property
def voice_id(self):
@@ -128,7 +124,6 @@ class Voice(str, Enum):
Voice.JESSICA: "cgSgspJ2msm6clMCkdW9",
Voice.CHARLOTTE: "XB0fDUnXU5powFXDhCwa",
Voice.CALLUM: "N2lVS1w4EtoT3dr4eOWO",
Voice.EVA: "FGY2WhTYpPnrIDTdsKH5",
}
return voice_id_map[self]
@@ -146,8 +141,6 @@ logger = logging.getLogger(__name__)
class AIShortformVideoCreatorBlock(Block):
"""Creates a shortform texttovideo clip using stock or AI imagery."""
class Input(BlockSchema):
credentials: CredentialsMetaInput[
Literal[ProviderName.REVID], Literal["api_key"]
@@ -191,58 +184,6 @@ class AIShortformVideoCreatorBlock(Block):
video_url: str = SchemaField(description="The URL of the created video")
error: str = SchemaField(description="Error message if the request failed")
async def create_webhook(self) -> tuple[str, str]:
"""Create a new webhook URL for receiving notifications."""
url = "https://webhook.site/token"
headers = {"Accept": "application/json", "Content-Type": "application/json"}
response = await Requests().post(url, headers=headers)
webhook_data = response.json()
return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"
async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
"""Create a video using the Revid API."""
url = "https://www.revid.ai/api/public/v2/render"
headers = {"key": api_key.get_secret_value()}
response = await Requests().post(url, json=payload, headers=headers)
logger.debug(
f"API Response Status Code: {response.status}, Content: {response.text}"
)
return response.json()
async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
"""Check the status of a video creation job."""
url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
headers = {"key": api_key.get_secret_value()}
response = await Requests().get(url, headers=headers)
return response.json()
async def wait_for_video(
self,
api_key: SecretStr,
pid: str,
max_wait_time: int = 1000,
) -> str:
"""Wait for video creation to complete and return the video URL."""
start_time = time.time()
while time.time() - start_time < max_wait_time:
status = await self.check_video_status(api_key, pid)
logger.debug(f"Video status: {status}")
if status.get("status") == "ready" and "videoUrl" in status:
return status["videoUrl"]
elif status.get("status") == "error":
error_message = status.get("error", "Unknown error occurred")
logger.error(f"Video creation failed: {error_message}")
raise ValueError(f"Video creation failed: {error_message}")
elif status.get("status") in ["FAILED", "CANCELED"]:
logger.error(f"Video creation failed: {status.get('message')}")
raise ValueError(f"Video creation failed: {status.get('message')}")
await asyncio.sleep(10)
logger.error("Video creation timed out")
raise TimeoutError("Video creation timed out")
def __init__(self):
super().__init__(
id="361697fb-0c4f-4feb-aed3-8320c88c771b",
@@ -261,22 +202,70 @@ class AIShortformVideoCreatorBlock(Block):
"voice": Voice.LILY,
"video_style": VisualMediaType.STOCK_VIDEOS,
},
test_output=("video_url", "https://example.com/video.mp4"),
test_output=(
"video_url",
"https://example.com/video.mp4",
),
test_mock={
"create_webhook": lambda *args, **kwargs: (
"create_webhook": lambda: (
"test_uuid",
"https://webhook.site/test_uuid",
),
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
"check_video_status": lambda *args, **kwargs: {
"status": "ready",
"videoUrl": "https://example.com/video.mp4",
},
"wait_for_video": lambda *args, **kwargs: "https://example.com/video.mp4",
"create_video": lambda api_key, payload: {"pid": "test_pid"},
"wait_for_video": lambda api_key, pid, webhook_token, max_wait_time=1000: "https://example.com/video.mp4",
},
test_credentials=TEST_CREDENTIALS,
)
async def create_webhook(self):
url = "https://webhook.site/token"
headers = {"Accept": "application/json", "Content-Type": "application/json"}
response = await Requests().post(url, headers=headers)
webhook_data = response.json()
return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"
async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
url = "https://www.revid.ai/api/public/v2/render"
headers = {"key": api_key.get_secret_value()}
response = await Requests().post(url, json=payload, headers=headers)
logger.debug(
f"API Response Status Code: {response.status}, Content: {response.text}"
)
return response.json()
async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
headers = {"key": api_key.get_secret_value()}
response = await Requests().get(url, headers=headers)
return response.json()
async def wait_for_video(
self,
api_key: SecretStr,
pid: str,
webhook_token: str,
max_wait_time: int = 1000,
) -> str:
start_time = time.time()
while time.time() - start_time < max_wait_time:
status = await self.check_video_status(api_key, pid)
logger.debug(f"Video status: {status}")
if status.get("status") == "ready" and "videoUrl" in status:
return status["videoUrl"]
elif status.get("status") == "error":
error_message = status.get("error", "Unknown error occurred")
logger.error(f"Video creation failed: {error_message}")
raise ValueError(f"Video creation failed: {error_message}")
elif status.get("status") in ["FAILED", "CANCELED"]:
logger.error(f"Video creation failed: {status.get('message')}")
raise ValueError(f"Video creation failed: {status.get('message')}")
await asyncio.sleep(10)
logger.error("Video creation timed out")
raise TimeoutError("Video creation timed out")
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
@@ -284,18 +273,20 @@ class AIShortformVideoCreatorBlock(Block):
webhook_token, webhook_url = await self.create_webhook()
logger.debug(f"Webhook URL: {webhook_url}")
audio_url = input_data.background_music.audio_url
payload = {
"frameRate": input_data.frame_rate,
"resolution": input_data.resolution,
"frameDurationMultiplier": 18,
"webhook": None,
"webhook": webhook_url,
"creationParams": {
"mediaType": input_data.video_style,
"captionPresetName": "Wrap 1",
"selectedVoice": input_data.voice.voice_id,
"hasEnhancedGeneration": True,
"generationPreset": input_data.generation_preset.name,
"selectedAudio": input_data.background_music.value,
"selectedAudio": input_data.background_music,
"origin": "/create",
"inputText": input_data.script,
"flowType": "text-to-video",
@@ -311,7 +302,7 @@ class AIShortformVideoCreatorBlock(Block):
"selectedStoryStyle": {"value": "custom", "label": "Custom"},
"hasToGenerateVideos": input_data.video_style
!= VisualMediaType.STOCK_VIDEOS,
"audioUrl": input_data.background_music.audio_url,
"audioUrl": audio_url,
},
}
@@ -328,370 +319,8 @@ class AIShortformVideoCreatorBlock(Block):
logger.debug(
f"Video created with project ID: {pid}. Waiting for completion..."
)
video_url = await self.wait_for_video(credentials.api_key, pid)
video_url = await self.wait_for_video(
credentials.api_key, pid, webhook_token
)
logger.debug(f"Video ready: {video_url}")
yield "video_url", video_url
class AIAdMakerVideoCreatorBlock(Block):
"""Generates a 30second vertical AI advert using optional usersupplied imagery."""
class Input(BlockSchema):
credentials: CredentialsMetaInput[
Literal[ProviderName.REVID], Literal["api_key"]
] = CredentialsField(
description="Credentials for Revid.ai API access.",
)
script: str = SchemaField(
description="Short advertising copy. Line breaks create new scenes.",
placeholder="Introducing Foobar [show product photo] the gadget that does it all.",
)
ratio: str = SchemaField(description="Aspect ratio", default="9 / 16")
target_duration: int = SchemaField(
description="Desired length of the ad in seconds.", default=30
)
voice: Voice = SchemaField(
description="Narration voice", default=Voice.EVA, placeholder=Voice.EVA
)
background_music: AudioTrack = SchemaField(
description="Background track",
default=AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS,
)
input_media_urls: list[str] = SchemaField(
description="List of image URLs to feature in the advert.", default=[]
)
use_only_provided_media: bool = SchemaField(
description="Restrict visuals to supplied images only.", default=True
)
class Output(BlockSchema):
video_url: str = SchemaField(description="URL of the finished advert")
error: str = SchemaField(description="Error message on failure")
async def create_webhook(self) -> tuple[str, str]:
"""Create a new webhook URL for receiving notifications."""
url = "https://webhook.site/token"
headers = {"Accept": "application/json", "Content-Type": "application/json"}
response = await Requests().post(url, headers=headers)
webhook_data = response.json()
return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"
async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
"""Create a video using the Revid API."""
url = "https://www.revid.ai/api/public/v2/render"
headers = {"key": api_key.get_secret_value()}
response = await Requests().post(url, json=payload, headers=headers)
logger.debug(
f"API Response Status Code: {response.status}, Content: {response.text}"
)
return response.json()
async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
"""Check the status of a video creation job."""
url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
headers = {"key": api_key.get_secret_value()}
response = await Requests().get(url, headers=headers)
return response.json()
async def wait_for_video(
self,
api_key: SecretStr,
pid: str,
max_wait_time: int = 1000,
) -> str:
"""Wait for video creation to complete and return the video URL."""
start_time = time.time()
while time.time() - start_time < max_wait_time:
status = await self.check_video_status(api_key, pid)
logger.debug(f"Video status: {status}")
if status.get("status") == "ready" and "videoUrl" in status:
return status["videoUrl"]
elif status.get("status") == "error":
error_message = status.get("error", "Unknown error occurred")
logger.error(f"Video creation failed: {error_message}")
raise ValueError(f"Video creation failed: {error_message}")
elif status.get("status") in ["FAILED", "CANCELED"]:
logger.error(f"Video creation failed: {status.get('message')}")
raise ValueError(f"Video creation failed: {status.get('message')}")
await asyncio.sleep(10)
logger.error("Video creation timed out")
raise TimeoutError("Video creation timed out")
def __init__(self):
super().__init__(
id="58bd2a19-115d-4fd1-8ca4-13b9e37fa6a0",
description="Creates an AIgenerated 30second advert (text + images)",
categories={BlockCategory.MARKETING, BlockCategory.AI},
input_schema=AIAdMakerVideoCreatorBlock.Input,
output_schema=AIAdMakerVideoCreatorBlock.Output,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"script": "Test product launch!",
"input_media_urls": [
"https://cdn.revid.ai/uploads/1747076315114-image.png",
],
},
test_output=("video_url", "https://example.com/ad.mp4"),
test_mock={
"create_webhook": lambda *args, **kwargs: (
"test_uuid",
"https://webhook.site/test_uuid",
),
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
"check_video_status": lambda *args, **kwargs: {
"status": "ready",
"videoUrl": "https://example.com/ad.mp4",
},
"wait_for_video": lambda *args, **kwargs: "https://example.com/ad.mp4",
},
test_credentials=TEST_CREDENTIALS,
)
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
webhook_token, webhook_url = await self.create_webhook()
payload = {
"webhook": webhook_url,
"creationParams": {
"targetDuration": input_data.target_duration,
"ratio": input_data.ratio,
"mediaType": "aiVideo",
"inputText": input_data.script,
"flowType": "text-to-video",
"slug": "ai-ad-generator",
"slugNew": "",
"isCopiedFrom": False,
"hasToGenerateVoice": True,
"hasToTranscript": False,
"hasToSearchMedia": True,
"hasAvatar": False,
"hasWebsiteRecorder": False,
"hasTextSmallAtBottom": False,
"selectedAudio": input_data.background_music.value,
"selectedVoice": input_data.voice.voice_id,
"selectedAvatar": "https://cdn.revid.ai/avatars/young-woman.mp4",
"selectedAvatarType": "video/mp4",
"websiteToRecord": "",
"hasToGenerateCover": True,
"nbGenerations": 1,
"disableCaptions": False,
"mediaMultiplier": "medium",
"characters": [],
"captionPresetName": "Revid",
"sourceType": "contentScraping",
"selectedStoryStyle": {"value": "custom", "label": "General"},
"generationPreset": "DEFAULT",
"hasToGenerateMusic": False,
"isOptimizedForChinese": False,
"generationUserPrompt": "",
"enableNsfwFilter": False,
"addStickers": False,
"typeMovingImageAnim": "dynamic",
"hasToGenerateSoundEffects": False,
"forceModelType": "gpt-image-1",
"selectedCharacters": [],
"lang": "",
"voiceSpeed": 1,
"disableAudio": False,
"disableVoice": False,
"useOnlyProvidedMedia": input_data.use_only_provided_media,
"imageGenerationModel": "ultra",
"videoGenerationModel": "pro",
"hasEnhancedGeneration": True,
"hasEnhancedGenerationPro": True,
"inputMedias": [
{"url": url, "title": "", "type": "image"}
for url in input_data.input_media_urls
],
"hasToGenerateVideos": True,
"audioUrl": input_data.background_music.audio_url,
"watermark": None,
},
}
response = await self.create_video(credentials.api_key, payload)
pid = response.get("pid")
if not pid:
raise RuntimeError("Failed to create video: No project ID returned")
video_url = await self.wait_for_video(credentials.api_key, pid)
yield "video_url", video_url
class AIScreenshotToVideoAdBlock(Block):
"""Creates an advert where the supplied screenshot is narrated by an AI avatar."""
class Input(BlockSchema):
credentials: CredentialsMetaInput[
Literal[ProviderName.REVID], Literal["api_key"]
] = CredentialsField(description="Revid.ai API key")
script: str = SchemaField(
description="Narration that will accompany the screenshot.",
placeholder="Check out these amazing stats!",
)
screenshot_url: str = SchemaField(
description="Screenshot or image URL to showcase."
)
ratio: str = SchemaField(default="9 / 16")
target_duration: int = SchemaField(default=30)
voice: Voice = SchemaField(default=Voice.EVA)
background_music: AudioTrack = SchemaField(
default=AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS
)
class Output(BlockSchema):
video_url: str = SchemaField(description="Rendered video URL")
error: str = SchemaField(description="Error, if encountered")
async def create_webhook(self) -> tuple[str, str]:
"""Create a new webhook URL for receiving notifications."""
url = "https://webhook.site/token"
headers = {"Accept": "application/json", "Content-Type": "application/json"}
response = await Requests().post(url, headers=headers)
webhook_data = response.json()
return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"
async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
"""Create a video using the Revid API."""
url = "https://www.revid.ai/api/public/v2/render"
headers = {"key": api_key.get_secret_value()}
response = await Requests().post(url, json=payload, headers=headers)
logger.debug(
f"API Response Status Code: {response.status}, Content: {response.text}"
)
return response.json()
async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
"""Check the status of a video creation job."""
url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
headers = {"key": api_key.get_secret_value()}
response = await Requests().get(url, headers=headers)
return response.json()
async def wait_for_video(
self,
api_key: SecretStr,
pid: str,
max_wait_time: int = 1000,
) -> str:
"""Wait for video creation to complete and return the video URL."""
start_time = time.time()
while time.time() - start_time < max_wait_time:
status = await self.check_video_status(api_key, pid)
logger.debug(f"Video status: {status}")
if status.get("status") == "ready" and "videoUrl" in status:
return status["videoUrl"]
elif status.get("status") == "error":
error_message = status.get("error", "Unknown error occurred")
logger.error(f"Video creation failed: {error_message}")
raise ValueError(f"Video creation failed: {error_message}")
elif status.get("status") in ["FAILED", "CANCELED"]:
logger.error(f"Video creation failed: {status.get('message')}")
raise ValueError(f"Video creation failed: {status.get('message')}")
await asyncio.sleep(10)
logger.error("Video creation timed out")
raise TimeoutError("Video creation timed out")
def __init__(self):
super().__init__(
id="0f3e4635-e810-43d9-9e81-49e6f4e83b7c",
description="Turns a screenshot into an engaging, avatarnarrated video advert.",
categories={BlockCategory.AI, BlockCategory.MARKETING},
input_schema=AIScreenshotToVideoAdBlock.Input,
output_schema=AIScreenshotToVideoAdBlock.Output,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"script": "Amazing numbers!",
"screenshot_url": "https://cdn.revid.ai/uploads/1747080376028-image.png",
},
test_output=("video_url", "https://example.com/screenshot.mp4"),
test_mock={
"create_webhook": lambda *args, **kwargs: (
"test_uuid",
"https://webhook.site/test_uuid",
),
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
"check_video_status": lambda *args, **kwargs: {
"status": "ready",
"videoUrl": "https://example.com/screenshot.mp4",
},
"wait_for_video": lambda *args, **kwargs: "https://example.com/screenshot.mp4",
},
test_credentials=TEST_CREDENTIALS,
)
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
webhook_token, webhook_url = await self.create_webhook()
payload = {
"webhook": webhook_url,
"creationParams": {
"targetDuration": input_data.target_duration,
"ratio": input_data.ratio,
"mediaType": "aiVideo",
"hasAvatar": True,
"removeAvatarBackground": True,
"inputText": input_data.script,
"flowType": "text-to-video",
"slug": "ai-ad-generator",
"slugNew": "screenshot-to-video-ad",
"isCopiedFrom": "ai-ad-generator",
"hasToGenerateVoice": True,
"hasToTranscript": False,
"hasToSearchMedia": True,
"hasWebsiteRecorder": False,
"hasTextSmallAtBottom": False,
"selectedAudio": input_data.background_music.value,
"selectedVoice": input_data.voice.voice_id,
"selectedAvatar": "https://cdn.revid.ai/avatars/young-woman.mp4",
"selectedAvatarType": "video/mp4",
"websiteToRecord": "",
"hasToGenerateCover": True,
"nbGenerations": 1,
"disableCaptions": False,
"mediaMultiplier": "medium",
"characters": [],
"captionPresetName": "Revid",
"sourceType": "contentScraping",
"selectedStoryStyle": {"value": "custom", "label": "General"},
"generationPreset": "DEFAULT",
"hasToGenerateMusic": False,
"isOptimizedForChinese": False,
"generationUserPrompt": "",
"enableNsfwFilter": False,
"addStickers": False,
"typeMovingImageAnim": "dynamic",
"hasToGenerateSoundEffects": False,
"forceModelType": "gpt-image-1",
"selectedCharacters": [],
"lang": "",
"voiceSpeed": 1,
"disableAudio": False,
"disableVoice": False,
"useOnlyProvidedMedia": True,
"imageGenerationModel": "ultra",
"videoGenerationModel": "ultra",
"hasEnhancedGeneration": True,
"hasEnhancedGenerationPro": True,
"inputMedias": [
{"url": input_data.screenshot_url, "title": "", "type": "image"}
],
"hasToGenerateVideos": True,
"audioUrl": input_data.background_music.audio_url,
"watermark": None,
},
}
response = await self.create_video(credentials.api_key, payload)
pid = response.get("pid")
if not pid:
raise RuntimeError("Failed to create video: No project ID returned")
video_url = await self.wait_for_video(credentials.api_key, pid)
yield "video_url", video_url

File diff suppressed because it is too large

View File

@@ -1,323 +0,0 @@
from os import getenv
from uuid import uuid4
import pytest
from backend.sdk import APIKeyCredentials, SecretStr
from ._api import (
TableFieldType,
WebhookFilters,
WebhookSpecification,
create_base,
create_field,
create_record,
create_table,
create_webhook,
delete_multiple_records,
delete_record,
delete_webhook,
get_record,
list_bases,
list_records,
list_webhook_payloads,
update_field,
update_multiple_records,
update_record,
update_table,
)
@pytest.mark.asyncio
async def test_create_update_table():
key = getenv("AIRTABLE_API_KEY")
if not key:
return pytest.skip("AIRTABLE_API_KEY is not set")
credentials = APIKeyCredentials(
provider="airtable",
api_key=SecretStr(key),
)
postfix = uuid4().hex[:4]
workspace_id = "wsphuHmfllg7V3Brd"
response = await create_base(credentials, workspace_id, "API Testing Base")
assert response is not None, f"Checking create base response: {response}"
assert (
response.get("id") is not None
), f"Checking create base response id: {response}"
base_id = response.get("id")
assert base_id is not None, f"Checking create base response id: {base_id}"
response = await list_bases(credentials)
assert response is not None, f"Checking list bases response: {response}"
assert "API Testing Base" in [
base.get("name") for base in response.get("bases", [])
], f"Checking list bases response bases: {response}"
table_name = f"test_table_{postfix}"
table_fields = [{"name": "test_field", "type": "singleLineText"}]
table = await create_table(credentials, base_id, table_name, table_fields)
assert table.get("name") == table_name
table_id = table.get("id")
assert table_id is not None
table_name = f"test_table_updated_{postfix}"
table_description = "test_description_updated"
table = await update_table(
credentials,
base_id,
table_id,
table_name=table_name,
table_description=table_description,
)
assert table.get("name") == table_name
assert table.get("description") == table_description
@pytest.mark.asyncio
async def test_invalid_field_type():
key = getenv("AIRTABLE_API_KEY")
if not key:
return pytest.skip("AIRTABLE_API_KEY is not set")
credentials = APIKeyCredentials(
provider="airtable",
api_key=SecretStr(key),
)
postfix = uuid4().hex[:4]
base_id = "appZPxegHEU3kDc1S"
table_name = f"test_table_{postfix}"
table_fields = [{"name": "test_field", "type": "notValid"}]
with pytest.raises(AssertionError):
await create_table(credentials, base_id, table_name, table_fields)
@pytest.mark.asyncio
async def test_create_and_update_field():
key = getenv("AIRTABLE_API_KEY")
if not key:
return pytest.skip("AIRTABLE_API_KEY is not set")
credentials = APIKeyCredentials(
provider="airtable",
api_key=SecretStr(key),
)
postfix = uuid4().hex[:4]
base_id = "appZPxegHEU3kDc1S"
table_name = f"test_table_{postfix}"
table_fields = [{"name": "test_field", "type": "singleLineText"}]
table = await create_table(credentials, base_id, table_name, table_fields)
assert table.get("name") == table_name
table_id = table.get("id")
assert table_id is not None
field_name = f"test_field_{postfix}"
field_type = TableFieldType.SINGLE_LINE_TEXT
field = await create_field(credentials, base_id, table_id, field_type, field_name)
assert field.get("name") == field_name
field_id = field.get("id")
assert field_id is not None
assert isinstance(field_id, str)
field_name = f"test_field_updated_{postfix}"
field = await update_field(credentials, base_id, table_id, field_id, field_name)
assert field.get("name") == field_name
field_description = "test_description_updated"
field = await update_field(
credentials, base_id, table_id, field_id, description=field_description
)
assert field.get("description") == field_description
@pytest.mark.asyncio
async def test_record_management():
key = getenv("AIRTABLE_API_KEY")
if not key:
return pytest.skip("AIRTABLE_API_KEY is not set")
credentials = APIKeyCredentials(
provider="airtable",
api_key=SecretStr(key),
)
postfix = uuid4().hex[:4]
base_id = "appZPxegHEU3kDc1S"
table_name = f"test_table_{postfix}"
table_fields = [{"name": "test_field", "type": "singleLineText"}]
table = await create_table(credentials, base_id, table_name, table_fields)
assert table.get("name") == table_name
table_id = table.get("id")
assert table_id is not None
# Create a record
record_fields = {"test_field": "test_value"}
record = await create_record(credentials, base_id, table_id, fields=record_fields)
fields = record.get("fields")
assert fields is not None
assert isinstance(fields, dict)
assert fields.get("test_field") == "test_value"
record_id = record.get("id")
assert record_id is not None
assert isinstance(record_id, str)
# Get a record
record = await get_record(credentials, base_id, table_id, record_id)
fields = record.get("fields")
assert fields is not None
assert isinstance(fields, dict)
assert fields.get("test_field") == "test_value"
# Update a record
record_fields = {"test_field": "test_value_updated"}
record = await update_record(
credentials, base_id, table_id, record_id, fields=record_fields
)
fields = record.get("fields")
assert fields is not None
assert isinstance(fields, dict)
assert fields.get("test_field") == "test_value_updated"
# Delete a record
record = await delete_record(credentials, base_id, table_id, record_id)
assert record is not None
assert record.get("id") == record_id
assert record.get("deleted")
# Create 2 records
records = [
{"fields": {"test_field": "test_value_1"}},
{"fields": {"test_field": "test_value_2"}},
]
response = await create_record(credentials, base_id, table_id, records=records)
created_records = response.get("records")
assert created_records is not None
assert isinstance(created_records, list)
assert len(created_records) == 2, f"Created records: {created_records}"
first_record = created_records[0] # type: ignore
second_record = created_records[1] # type: ignore
first_record_id = first_record.get("id")
second_record_id = second_record.get("id")
assert first_record_id is not None
assert second_record_id is not None
assert first_record_id != second_record_id
first_fields = first_record.get("fields")
second_fields = second_record.get("fields")
assert first_fields is not None
assert second_fields is not None
assert first_fields.get("test_field") == "test_value_1" # type: ignore
assert second_fields.get("test_field") == "test_value_2" # type: ignore
# List records
response = await list_records(credentials, base_id, table_id)
records = response.get("records")
assert records is not None
assert len(records) == 2, f"Records: {records}"
assert isinstance(records, list), f"Type of records: {type(records)}"
# Update multiple records
records = [
{"id": first_record_id, "fields": {"test_field": "test_value_1_updated"}},
{"id": second_record_id, "fields": {"test_field": "test_value_2_updated"}},
]
response = await update_multiple_records(
credentials, base_id, table_id, records=records
)
updated_records = response.get("records")
assert updated_records is not None
assert len(updated_records) == 2, f"Updated records: {updated_records}"
assert isinstance(
updated_records, list
), f"Type of updated records: {type(updated_records)}"
first_updated = updated_records[0] # type: ignore
second_updated = updated_records[1] # type: ignore
first_updated_fields = first_updated.get("fields")
second_updated_fields = second_updated.get("fields")
assert first_updated_fields is not None
assert second_updated_fields is not None
assert first_updated_fields.get("test_field") == "test_value_1_updated" # type: ignore
assert second_updated_fields.get("test_field") == "test_value_2_updated" # type: ignore
# Delete multiple records
assert isinstance(first_record_id, str)
assert isinstance(second_record_id, str)
response = await delete_multiple_records(
credentials, base_id, table_id, records=[first_record_id, second_record_id]
)
deleted_records = response.get("records")
assert deleted_records is not None
assert len(deleted_records) == 2, f"Deleted records: {deleted_records}"
assert isinstance(
deleted_records, list
), f"Type of deleted records: {type(deleted_records)}"
first_deleted = deleted_records[0] # type: ignore
second_deleted = deleted_records[1] # type: ignore
assert first_deleted.get("deleted")
assert second_deleted.get("deleted")
@pytest.mark.asyncio
async def test_webhook_management():
key = getenv("AIRTABLE_API_KEY")
if not key:
return pytest.skip("AIRTABLE_API_KEY is not set")
credentials = APIKeyCredentials(
provider="airtable",
api_key=SecretStr(key),
)
postfix = uuid4().hex[:4]
base_id = "appZPxegHEU3kDc1S"
table_name = f"test_table_{postfix}"
table_fields = [{"name": "test_field", "type": "singleLineText"}]
table = await create_table(credentials, base_id, table_name, table_fields)
assert table.get("name") == table_name
table_id = table.get("id")
assert table_id is not None
webhook_specification = WebhookSpecification(
filters=WebhookFilters(
dataTypes=["tableData", "tableFields", "tableMetadata"],
changeTypes=["add", "update", "remove"],
)
)
response = await create_webhook(credentials, base_id, webhook_specification)
assert response is not None, f"Checking create webhook response: {response}"
assert (
response.get("id") is not None
), f"Checking create webhook response id: {response}"
assert (
response.get("macSecretBase64") is not None
), f"Checking create webhook response macSecretBase64: {response}"
webhook_id = response.get("id")
assert webhook_id is not None, f"Webhook ID: {webhook_id}"
assert isinstance(webhook_id, str)
response = await create_record(
credentials, base_id, table_id, fields={"test_field": "test_value"}
)
assert response is not None, f"Checking create record response: {response}"
assert (
response.get("id") is not None
), f"Checking create record response id: {response}"
fields = response.get("fields")
assert fields is not None, f"Checking create record response fields: {response}"
assert (
fields.get("test_field") == "test_value"
), f"Checking create record response fields test_field: {response}"
response = await list_webhook_payloads(credentials, base_id, webhook_id)
assert response is not None, f"Checking list webhook payloads response: {response}"
response = await delete_webhook(credentials, base_id, webhook_id)

View File

@@ -1,32 +0,0 @@
"""
Shared configuration for all Airtable blocks using the SDK pattern.
"""
from backend.sdk import BlockCostType, ProviderBuilder
from ._oauth import AirtableOAuthHandler, AirtableScope
from ._webhook import AirtableWebhookManager
# Configure the Airtable provider with API key authentication
airtable = (
ProviderBuilder("airtable")
.with_api_key("AIRTABLE_API_KEY", "Airtable Personal Access Token")
.with_webhook_manager(AirtableWebhookManager)
.with_base_cost(1, BlockCostType.RUN)
.with_oauth(
AirtableOAuthHandler,
scopes=[
v.value
for v in [
AirtableScope.DATA_RECORDS_READ,
AirtableScope.DATA_RECORDS_WRITE,
AirtableScope.SCHEMA_BASES_READ,
AirtableScope.SCHEMA_BASES_WRITE,
AirtableScope.WEBHOOK_MANAGE,
]
],
client_id_env_var="AIRTABLE_CLIENT_ID",
client_secret_env_var="AIRTABLE_CLIENT_SECRET",
)
.build()
)

View File

@@ -1,185 +0,0 @@
"""
Airtable OAuth handler implementation.
"""
import time
from enum import Enum
from logging import getLogger
from typing import Optional
from backend.sdk import BaseOAuthHandler, OAuth2Credentials, ProviderName, SecretStr
from ._api import (
OAuthTokenResponse,
make_oauth_authorize_url,
oauth_exchange_code_for_tokens,
oauth_refresh_tokens,
)
logger = getLogger(__name__)
class AirtableScope(str, Enum):
# Basic scopes
DATA_RECORDS_READ = "data.records:read"
DATA_RECORDS_WRITE = "data.records:write"
DATA_RECORD_COMMENTS_READ = "data.recordComments:read"
DATA_RECORD_COMMENTS_WRITE = "data.recordComments:write"
SCHEMA_BASES_READ = "schema.bases:read"
SCHEMA_BASES_WRITE = "schema.bases:write"
WEBHOOK_MANAGE = "webhook:manage"
BLOCK_MANAGE = "block:manage"
USER_EMAIL_READ = "user.email:read"
# Enterprise member scopes
ENTERPRISE_GROUPS_READ = "enterprise.groups:read"
WORKSPACES_AND_BASES_READ = "workspacesAndBases:read"
WORKSPACES_AND_BASES_WRITE = "workspacesAndBases:write"
WORKSPACES_AND_BASES_SHARES_MANAGE = "workspacesAndBases.shares:manage"
# Enterprise admin scopes
ENTERPRISE_SCIM_USERS_AND_GROUPS_MANAGE = "enterprise.scim.usersAndGroups:manage"
ENTERPRISE_AUDIT_LOGS_READ = "enterprise.auditLogs:read"
ENTERPRISE_CHANGE_EVENTS_READ = "enterprise.changeEvents:read"
ENTERPRISE_EXPORTS_MANAGE = "enterprise.exports:manage"
ENTERPRISE_ACCOUNT_READ = "enterprise.account:read"
ENTERPRISE_ACCOUNT_WRITE = "enterprise.account:write"
ENTERPRISE_USER_READ = "enterprise.user:read"
ENTERPRISE_USER_WRITE = "enterprise.user:write"
ENTERPRISE_GROUPS_MANAGE = "enterprise.groups:manage"
WORKSPACES_AND_BASES_MANAGE = "workspacesAndBases:manage"
HYPERDB_RECORDS_READ = "hyperDB.records:read"
HYPERDB_RECORDS_WRITE = "hyperDB.records:write"
class AirtableOAuthHandler(BaseOAuthHandler):
"""
OAuth2 handler for Airtable with PKCE support.
"""
PROVIDER_NAME = ProviderName("airtable")
DEFAULT_SCOPES = [
v.value
for v in [
AirtableScope.DATA_RECORDS_READ,
AirtableScope.DATA_RECORDS_WRITE,
AirtableScope.SCHEMA_BASES_READ,
AirtableScope.SCHEMA_BASES_WRITE,
AirtableScope.WEBHOOK_MANAGE,
]
]
def __init__(self, client_id: str, client_secret: Optional[str], redirect_uri: str):
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = redirect_uri
self.scopes = self.DEFAULT_SCOPES
self.auth_base_url = "https://airtable.com/oauth2/v1/authorize"
self.token_url = "https://airtable.com/oauth2/v1/token"
def get_login_url(
self, scopes: list[str], state: str, code_challenge: Optional[str]
) -> str:
logger.debug("Generating Airtable OAuth login URL")
# Generate code_challenge if not provided (PKCE is required)
if not scopes:
logger.debug("No scopes provided, using default scopes")
scopes = self.scopes
logger.debug(f"Using scopes: {scopes}")
logger.debug(f"State: {state}")
logger.debug(f"Code challenge: {code_challenge}")
if not code_challenge:
logger.error("Code challenge is required but none was provided")
raise ValueError("No code challenge provided")
try:
url = make_oauth_authorize_url(
self.client_id, self.redirect_uri, scopes, state, code_challenge
)
logger.debug(f"Generated OAuth URL: {url}")
return url
except Exception as e:
logger.error(f"Failed to generate OAuth URL: {str(e)}")
raise
async def exchange_code_for_tokens(
self, code: str, scopes: list[str], code_verifier: Optional[str]
) -> OAuth2Credentials:
logger.debug("Exchanging authorization code for tokens")
logger.debug(f"Code: {code[:4]}...") # Log first 4 chars only for security
logger.debug(f"Scopes: {scopes}")
if not code_verifier:
logger.error("Code verifier is required but none was provided")
raise ValueError("No code verifier provided")
try:
response: OAuthTokenResponse = await oauth_exchange_code_for_tokens(
client_id=self.client_id,
code=code,
code_verifier=code_verifier.encode("utf-8"),
redirect_uri=self.redirect_uri,
client_secret=self.client_secret,
)
logger.info("Successfully exchanged code for tokens")
credentials = OAuth2Credentials(
access_token=SecretStr(response.access_token),
refresh_token=SecretStr(response.refresh_token),
access_token_expires_at=int(time.time()) + response.expires_in,
refresh_token_expires_at=int(time.time()) + response.refresh_expires_in,
provider=self.PROVIDER_NAME,
scopes=scopes,
)
logger.debug(f"Access token expires in {response.expires_in} seconds")
logger.debug(
f"Refresh token expires in {response.refresh_expires_in} seconds"
)
return credentials
except Exception as e:
logger.error(f"Failed to exchange code for tokens: {str(e)}")
raise
async def _refresh_tokens(
self, credentials: OAuth2Credentials
) -> OAuth2Credentials:
logger.debug("Attempting to refresh OAuth tokens")
if credentials.refresh_token is None:
logger.error("Cannot refresh tokens - no refresh token available")
raise ValueError("No refresh token available")
try:
response: OAuthTokenResponse = await oauth_refresh_tokens(
client_id=self.client_id,
refresh_token=credentials.refresh_token.get_secret_value(),
client_secret=self.client_secret,
)
logger.info("Successfully refreshed tokens")
new_credentials = OAuth2Credentials(
id=credentials.id,
access_token=SecretStr(response.access_token),
refresh_token=SecretStr(response.refresh_token),
access_token_expires_at=int(time.time()) + response.expires_in,
refresh_token_expires_at=int(time.time()) + response.refresh_expires_in,
provider=self.PROVIDER_NAME,
scopes=self.scopes,
)
logger.debug(f"New access token expires in {response.expires_in} seconds")
logger.debug(
f"New refresh token expires in {response.refresh_expires_in} seconds"
)
return new_credentials
except Exception as e:
logger.error(f"Failed to refresh tokens: {str(e)}")
raise
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
logger.debug("Token revocation requested")
logger.info(
"Airtable doesn't provide a token revocation endpoint - tokens will expire naturally after 60 minutes"
)
return False

View File

@@ -1,154 +0,0 @@
"""
Webhook management for Airtable blocks.
"""
import hashlib
import hmac
import logging
from enum import Enum
from backend.sdk import (
BaseWebhooksManager,
Credentials,
ProviderName,
Webhook,
update_webhook,
)
from ._api import (
WebhookFilters,
WebhookSpecification,
create_webhook,
delete_webhook,
list_webhook_payloads,
)
logger = logging.getLogger(__name__)
class AirtableWebhookEvent(str, Enum):
TABLE_DATA = "tableData"
TABLE_FIELDS = "tableFields"
TABLE_METADATA = "tableMetadata"
class AirtableWebhookManager(BaseWebhooksManager):
"""Webhook manager for Airtable API."""
PROVIDER_NAME = ProviderName("airtable")
@classmethod
async def validate_payload(
cls, webhook: Webhook, request, credentials: Credentials | None
) -> tuple[dict, str]:
"""Validate incoming webhook payload and signature."""
if not credentials:
raise ValueError("Missing credentials in webhook metadata")
payload = await request.json()
# Verify webhook signature using HMAC-SHA256
if webhook.secret:
mac_secret = webhook.config.get("mac_secret")
if mac_secret:
# Get the raw body for signature verification
body = await request.body()
# Calculate expected signature
mac_secret_decoded = mac_secret.encode()
hmac_obj = hmac.new(mac_secret_decoded, body, hashlib.sha256)
expected_mac = f"hmac-sha256={hmac_obj.hexdigest()}"
# Get signature from headers
signature = request.headers.get("X-Airtable-Content-MAC")
if signature and not hmac.compare_digest(signature, expected_mac):
raise ValueError("Invalid webhook signature")
# Validate payload structure
required_fields = ["base", "webhook", "timestamp"]
if not all(field in payload for field in required_fields):
raise ValueError("Invalid webhook payload structure")
if "id" not in payload["base"] or "id" not in payload["webhook"]:
raise ValueError("Missing required IDs in webhook payload")
base_id = payload["base"]["id"]
webhook_id = payload["webhook"]["id"]
# get payload request parameters
cursor = webhook.config.get("cursor", 1)
response = await list_webhook_payloads(credentials, base_id, webhook_id, cursor)
# update webhook config
await update_webhook(
webhook.id,
config={"base_id": base_id, "cursor": response.cursor},
)
event_type = "notification"
return response.model_dump(), event_type
async def _register_webhook(
self,
credentials: Credentials,
webhook_type: str,
resource: str,
events: list[str],
ingress_url: str,
secret: str,
) -> tuple[str, dict]:
"""Register webhook with Airtable API."""
# Parse resource to get base_id and table_id/name
# Resource format: "{base_id}/{table_id_or_name}"
parts = resource.split("/", 1)
if len(parts) != 2:
raise ValueError("Resource must be in format: {base_id}/{table_id_or_name}")
base_id, table_id_or_name = parts
# Prepare webhook specification
webhook_specification = WebhookSpecification(
filters=WebhookFilters(
dataTypes=events,
)
)
# Create webhook
webhook_data = await create_webhook(
credentials=credentials,
base_id=base_id,
webhook_specification=webhook_specification,
notification_url=ingress_url,
)
webhook_id = webhook_data["id"]
mac_secret = webhook_data.get("macSecretBase64")
return webhook_id, {
"webhook_id": webhook_id,
"base_id": base_id,
"table_id_or_name": table_id_or_name,
"events": events,
"mac_secret": mac_secret,
"cursor": 1,
"expiration_time": webhook_data.get("expirationTime"),
}
async def _deregister_webhook(
self, webhook: Webhook, credentials: Credentials
) -> None:
"""Deregister webhook from Airtable API."""
base_id = webhook.config.get("base_id")
webhook_id = webhook.config.get("webhook_id")
if not base_id:
raise ValueError("Missing base_id in webhook metadata")
if not webhook_id:
raise ValueError("Missing webhook_id in webhook metadata")
await delete_webhook(credentials, base_id, webhook_id)
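For reference, the HMAC check in validate_payload above can be reproduced in isolation with the standard library; the secret and body values below are illustrative placeholders, not taken from this diff:

import hashlib
import hmac

mac_secret = "example-mac-secret"  # placeholder for the secret stored in webhook.config["mac_secret"]
body = b'{"base": {"id": "app123"}, "webhook": {"id": "ach123"}, "timestamp": "2022-02-01T21:25:05.663Z"}'

expected_mac = "hmac-sha256=" + hmac.new(mac_secret.encode(), body, hashlib.sha256).hexdigest()
received_mac = expected_mac  # in practice, the value of the X-Airtable-Content-MAC request header

# compare_digest avoids leaking timing information while comparing signatures
print(hmac.compare_digest(received_mac, expected_mac))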

View File

@@ -1,122 +0,0 @@
"""
Airtable base operation blocks.
"""
from typing import Optional
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
SchemaField,
)
from ._api import create_base, list_bases
from ._config import airtable
class AirtableCreateBaseBlock(Block):
"""
Creates a new base in an Airtable workspace.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = airtable.credentials_field(
description="Airtable API credentials"
)
workspace_id: str = SchemaField(
description="The workspace ID where the base will be created"
)
name: str = SchemaField(description="The name of the new base")
tables: list[dict] = SchemaField(
description="At least one table and field must be specified. Array of table objects to create in the base. Each table should have 'name' and 'fields' properties",
default=[
{
"description": "Default table",
"name": "Default table",
"fields": [
{
"name": "ID",
"type": "number",
"description": "Auto-incrementing ID field",
"options": {"precision": 0},
}
],
}
],
)
class Output(BlockSchema):
base_id: str = SchemaField(description="The ID of the created base")
tables: list[dict] = SchemaField(description="Array of table objects")
table: dict = SchemaField(description="A single table object")
def __init__(self):
super().__init__(
id="f59b88a8-54ce-4676-a508-fd614b4e8dce",
description="Create a new base in Airtable",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
data = await create_base(
credentials,
input_data.workspace_id,
input_data.name,
input_data.tables,
)
yield "base_id", data.get("id", None)
yield "tables", data.get("tables", [])
for table in data.get("tables", []):
yield "table", table
class AirtableListBasesBlock(Block):
"""
Lists all bases in an Airtable workspace that the user has access to.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = airtable.credentials_field(
description="Airtable API credentials"
)
trigger: str = SchemaField(
description="Trigger the block to run - value is ignored", default="manual"
)
offset: str = SchemaField(
description="Pagination offset from previous request", default=""
)
class Output(BlockSchema):
bases: list[dict] = SchemaField(description="Array of base objects")
offset: Optional[str] = SchemaField(
description="Offset for next page (null if no more bases)", default=None
)
def __init__(self):
super().__init__(
id="4bd8d466-ed5d-4e44-8083-97f25a8044e7",
description="List all bases in Airtable",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
data = await list_bases(
credentials,
offset=input_data.offset if input_data.offset else None,
)
yield "bases", data.get("bases", [])
yield "offset", data.get("offset", None)

View File

@@ -1,283 +0,0 @@
"""
Airtable record operation blocks.
"""
from typing import Optional
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
SchemaField,
)
from ._api import (
create_record,
delete_multiple_records,
get_record,
list_records,
update_multiple_records,
)
from ._config import airtable
class AirtableListRecordsBlock(Block):
"""
Lists records from an Airtable table with optional filtering, sorting, and pagination.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = airtable.credentials_field(
description="Airtable API credentials"
)
base_id: str = SchemaField(description="The Airtable base ID")
table_id_or_name: str = SchemaField(description="Table ID or name")
filter_formula: str = SchemaField(
description="Airtable formula to filter records", default=""
)
view: str = SchemaField(description="View ID or name to use", default="")
sort: list[dict] = SchemaField(
description="Sort configuration (array of {field, direction})", default=[]
)
max_records: int = SchemaField(
description="Maximum number of records to return", default=100
)
page_size: int = SchemaField(
description="Number of records per page (max 100)", default=100
)
offset: str = SchemaField(
description="Pagination offset from previous request", default=""
)
return_fields: list[str] = SchemaField(
description="Specific fields to return (comma-separated)", default=[]
)
class Output(BlockSchema):
records: list[dict] = SchemaField(description="Array of record objects")
offset: Optional[str] = SchemaField(
description="Offset for next page (null if no more records)", default=None
)
def __init__(self):
super().__init__(
id="588a9fde-5733-4da7-b03c-35f5671e960f",
description="List records from an Airtable table",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
data = await list_records(
credentials,
input_data.base_id,
input_data.table_id_or_name,
filter_by_formula=(
input_data.filter_formula if input_data.filter_formula else None
),
view=input_data.view if input_data.view else None,
sort=input_data.sort if input_data.sort else None,
max_records=input_data.max_records if input_data.max_records else None,
page_size=min(input_data.page_size, 100) if input_data.page_size else None,
offset=input_data.offset if input_data.offset else None,
fields=input_data.return_fields if input_data.return_fields else None,
)
yield "records", data.get("records", [])
yield "offset", data.get("offset", None)
class AirtableGetRecordBlock(Block):
"""
Retrieves a single record from an Airtable table by its ID.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = airtable.credentials_field(
description="Airtable API credentials"
)
base_id: str = SchemaField(description="The Airtable base ID")
table_id_or_name: str = SchemaField(description="Table ID or name")
record_id: str = SchemaField(description="The record ID to retrieve")
class Output(BlockSchema):
id: str = SchemaField(description="The record ID")
fields: dict = SchemaField(description="The record fields")
created_time: str = SchemaField(description="The record created time")
def __init__(self):
super().__init__(
id="c29c5cbf-0aff-40f9-bbb5-f26061792d2b",
description="Get a single record from Airtable",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
record = await get_record(
credentials,
input_data.base_id,
input_data.table_id_or_name,
input_data.record_id,
)
yield "id", record.get("id", None)
yield "fields", record.get("fields", None)
yield "created_time", record.get("createdTime", None)
class AirtableCreateRecordsBlock(Block):
"""
Creates one or more records in an Airtable table.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = airtable.credentials_field(
description="Airtable API credentials"
)
base_id: str = SchemaField(description="The Airtable base ID")
table_id_or_name: str = SchemaField(description="Table ID or name")
records: list[dict] = SchemaField(
description="Array of records to create (each with 'fields' object)"
)
typecast: bool = SchemaField(
description="Automatically convert string values to appropriate types",
default=False,
)
return_fields_by_field_id: bool | None = SchemaField(
description="Return fields by field ID",
default=None,
)
class Output(BlockSchema):
records: list[dict] = SchemaField(description="Array of created record objects")
details: dict = SchemaField(description="Details of the created records")
def __init__(self):
super().__init__(
id="42527e98-47b6-44ce-ac0e-86b4883721d3",
description="Create records in an Airtable table",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
# The create_record API expects records in a specific format
data = await create_record(
credentials,
input_data.base_id,
input_data.table_id_or_name,
records=[{"fields": record} for record in input_data.records],
typecast=input_data.typecast if input_data.typecast else None,
return_fields_by_field_id=input_data.return_fields_by_field_id,
)
yield "records", data.get("records", [])
details = data.get("details", None)
if details:
yield "details", details
class AirtableUpdateRecordsBlock(Block):
"""
Updates one or more existing records in an Airtable table.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = airtable.credentials_field(
description="Airtable API credentials"
)
base_id: str = SchemaField(description="The Airtable base ID")
table_id_or_name: str = SchemaField(
description="Table ID or name - It's better to use the table ID instead of the name"
)
records: list[dict] = SchemaField(
description="Array of records to update (each with 'id' and 'fields')"
)
typecast: bool | None = SchemaField(
description="Automatically convert string values to appropriate types",
default=None,
)
class Output(BlockSchema):
records: list[dict] = SchemaField(description="Array of updated record objects")
def __init__(self):
super().__init__(
id="6e7d2590-ac2b-4b5d-b08c-fc039cd77e1f",
description="Update records in an Airtable table",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
# The update_multiple_records API expects records with id and fields
data = await update_multiple_records(
credentials,
input_data.base_id,
input_data.table_id_or_name,
records=input_data.records,
typecast=input_data.typecast if input_data.typecast else None,
return_fields_by_field_id=False, # Use field names, not IDs
)
yield "records", data.get("records", [])
class AirtableDeleteRecordsBlock(Block):
"""
Deletes one or more records from an Airtable table.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = airtable.credentials_field(
description="Airtable API credentials"
)
base_id: str = SchemaField(description="The Airtable base ID")
table_id_or_name: str = SchemaField(
description="Table ID or name - It's better to use the table ID instead of the name"
)
record_ids: list[str] = SchemaField(
description="Array of upto 10 record IDs to delete"
)
class Output(BlockSchema):
records: list[dict] = SchemaField(description="Array of deletion results")
def __init__(self):
super().__init__(
id="93e22b8b-3642-4477-aefb-1c0929a4a3a6",
description="Delete records from an Airtable table",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
if len(input_data.record_ids) > 10:
yield "error", "Only upto 10 record IDs can be deleted at a time"
else:
data = await delete_multiple_records(
credentials,
input_data.base_id,
input_data.table_id_or_name,
input_data.record_ids,
)
yield "records", data.get("records", [])

View File

@@ -1,252 +0,0 @@
"""
Airtable schema and table management blocks.
"""
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
Requests,
SchemaField,
)
from ._api import TableFieldType, create_field, create_table, update_field, update_table
from ._config import airtable
class AirtableListSchemaBlock(Block):
"""
Retrieves the complete schema of an Airtable base, including all tables,
fields, and views.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = airtable.credentials_field(
description="Airtable API credentials"
)
base_id: str = SchemaField(description="The Airtable base ID")
class Output(BlockSchema):
base_schema: dict = SchemaField(
description="Complete base schema with tables, fields, and views"
)
tables: list[dict] = SchemaField(description="Array of table objects")
def __init__(self):
super().__init__(
id="64291d3c-99b5-47b7-a976-6d94293cdb2d",
description="Get the complete schema of an Airtable base",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
api_key = credentials.api_key.get_secret_value()
# Get base schema
response = await Requests().get(
f"https://api.airtable.com/v0/meta/bases/{input_data.base_id}/tables",
headers={"Authorization": f"Bearer {api_key}"},
)
data = response.json()
yield "base_schema", data
yield "tables", data.get("tables", [])
class AirtableCreateTableBlock(Block):
"""
Creates a new table in an Airtable base with specified fields and views.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = airtable.credentials_field(
description="Airtable API credentials"
)
base_id: str = SchemaField(description="The Airtable base ID")
table_name: str = SchemaField(description="The name of the table to create")
table_fields: list[dict] = SchemaField(
description="Table fields with name, type, and options",
default=[{"name": "Name", "type": "singleLineText"}],
)
class Output(BlockSchema):
table: dict = SchemaField(description="Created table object")
table_id: str = SchemaField(description="ID of the created table")
def __init__(self):
super().__init__(
id="fcc20ced-d817-42ea-9b40-c35e7bf34b4f",
description="Create a new table in an Airtable base",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
table_data = await create_table(
credentials,
input_data.base_id,
input_data.table_name,
input_data.table_fields,
)
yield "table", table_data
yield "table_id", table_data.get("id", "")
class AirtableUpdateTableBlock(Block):
"""
Updates an existing table's properties such as name or description.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = airtable.credentials_field(
description="Airtable API credentials"
)
base_id: str = SchemaField(description="The Airtable base ID")
table_id: str = SchemaField(description="The table ID to update")
table_name: str | None = SchemaField(
description="The name of the table to update", default=None
)
table_description: str | None = SchemaField(
description="The description of the table to update", default=None
)
date_dependency: dict | None = SchemaField(
description="The date dependency of the table to update", default=None
)
class Output(BlockSchema):
table: dict = SchemaField(description="Updated table object")
def __init__(self):
super().__init__(
id="34077c5f-f962-49f2-9ec6-97c67077013a",
description="Update table properties",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
table_data = await update_table(
credentials,
input_data.base_id,
input_data.table_id,
input_data.table_name,
input_data.table_description,
input_data.date_dependency,
)
yield "table", table_data
class AirtableCreateFieldBlock(Block):
"""
Adds a new field (column) to an existing Airtable table.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = airtable.credentials_field(
description="Airtable API credentials"
)
base_id: str = SchemaField(description="The Airtable base ID")
table_id: str = SchemaField(description="The table ID to add field to")
field_type: TableFieldType = SchemaField(
description="The type of the field to create",
default=TableFieldType.SINGLE_LINE_TEXT,
advanced=False,
)
name: str = SchemaField(description="The name of the field to create")
description: str | None = SchemaField(
description="The description of the field to create", default=None
)
options: dict[str, str] | None = SchemaField(
description="The options of the field to create", default=None
)
class Output(BlockSchema):
field: dict = SchemaField(description="Created field object")
field_id: str = SchemaField(description="ID of the created field")
def __init__(self):
super().__init__(
id="6c98a32f-dbf9-45d8-a2a8-5e97e8326351",
description="Add a new field to an Airtable table",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
field_data = await create_field(
credentials,
input_data.base_id,
input_data.table_id,
input_data.field_type,
input_data.name,
)
yield "field", field_data
yield "field_id", field_data.get("id", "")
class AirtableUpdateFieldBlock(Block):
"""
Updates an existing field's properties in an Airtable table.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = airtable.credentials_field(
description="Airtable API credentials"
)
base_id: str = SchemaField(description="The Airtable base ID")
table_id: str = SchemaField(description="The table ID containing the field")
field_id: str = SchemaField(description="The field ID to update")
name: str | None = SchemaField(
description="The name of the field to update", default=None, advanced=False
)
description: str | None = SchemaField(
description="The description of the field to update",
default=None,
advanced=False,
)
class Output(BlockSchema):
field: dict = SchemaField(description="Updated field object")
def __init__(self):
super().__init__(
id="f46ac716-3b18-4da1-92e4-34ca9a464d48",
description="Update field properties in an Airtable table",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
field_data = await update_field(
credentials,
input_data.base_id,
input_data.table_id,
input_data.field_id,
input_data.name,
input_data.description,
)
yield "field", field_data

View File

@@ -1,113 +0,0 @@
from backend.sdk import (
BaseModel,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockType,
BlockWebhookConfig,
CredentialsMetaInput,
ProviderName,
SchemaField,
)
from ._api import WebhookPayload
from ._config import airtable
class AirtableEventSelector(BaseModel):
"""
Selects the Airtable webhook event to trigger on.
"""
tableData: bool = True
tableFields: bool = True
tableMetadata: bool = True
class AirtableWebhookTriggerBlock(Block):
"""
Starts a flow whenever Airtable emits a webhook event.
A thin wrapper that forwards the payloads one at a time to the next block.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = airtable.credentials_field(
description="Airtable API credentials"
)
base_id: str = SchemaField(description="Airtable base ID")
table_id_or_name: str = SchemaField(description="Airtable table ID or name")
payload: dict = SchemaField(hidden=True, default_factory=dict)
events: AirtableEventSelector = SchemaField(
description="Airtable webhook event filter"
)
class Output(BlockSchema):
payload: WebhookPayload = SchemaField(description="Airtable webhook payload")
def __init__(self):
example_payload = {
"payloads": [
{
"timestamp": "2022-02-01T21:25:05.663Z",
"baseTransactionNumber": 4,
"actionMetadata": {
"source": "client",
"sourceMetadata": {
"user": {
"id": "usr00000000000000",
"email": "foo@bar.com",
"permissionLevel": "create",
}
},
},
"payloadFormat": "v0",
}
],
"cursor": 5,
"mightHaveMore": False,
}
super().__init__(
# NOTE: This is disabled whilst the webhook system is finalised.
disabled=False,
id="d0180ce6-ccb9-48c7-8256-b39e93e62801",
description="Starts a flow whenever Airtable emits a webhook event",
categories={BlockCategory.INPUT, BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
block_type=BlockType.WEBHOOK,
webhook_config=BlockWebhookConfig(
provider=ProviderName("airtable"),
webhook_type="not-used",
event_filter_input="events",
event_format="{event}",
resource_format="{base_id}/{table_id_or_name}",
),
test_input={
"credentials": airtable.get_test_credentials().model_dump(),
"base_id": "app1234567890",
"table_id_or_name": "table1234567890",
"events": AirtableEventSelector(
tableData=True,
tableFields=True,
tableMetadata=False,
).model_dump(),
"payload": example_payload,
},
test_credentials=airtable.get_test_credentials(),
test_output=[
(
"payload",
WebhookPayload.model_validate(example_payload["payloads"][0]),
),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
if len(input_data.payload["payloads"]) > 0:
for item in input_data.payload["payloads"]:
yield "payload", WebhookPayload.model_validate(item)
else:
yield "error", "No valid payloads found in webhook payload"

View File

@@ -4,7 +4,6 @@ from typing import List
from backend.blocks.apollo._auth import ApolloCredentials
from backend.blocks.apollo.models import (
Contact,
EnrichPersonRequest,
Organization,
SearchOrganizationsRequest,
SearchOrganizationsResponse,
@@ -30,10 +29,10 @@ class ApolloClient:
async def search_people(self, query: SearchPeopleRequest) -> List[Contact]:
"""Search for people in Apollo"""
response = await self.requests.post(
response = await self.requests.get(
f"{self.API_URL}/mixed_people/search",
headers=self._get_headers(),
json=query.model_dump(exclude={"max_results"}),
params=query.model_dump(exclude={"credentials", "max_results"}),
)
data = response.json()
parsed_response = SearchPeopleResponse(**data)
@@ -54,10 +53,10 @@ class ApolloClient:
and len(parsed_response.people) > 0
):
query.page += 1
response = await self.requests.post(
response = await self.requests.get(
f"{self.API_URL}/mixed_people/search",
headers=self._get_headers(),
json=query.model_dump(exclude={"max_results"}),
params=query.model_dump(exclude={"credentials", "max_results"}),
)
data = response.json()
parsed_response = SearchPeopleResponse(**data)
@@ -70,10 +69,10 @@ class ApolloClient:
self, query: SearchOrganizationsRequest
) -> List[Organization]:
"""Search for organizations in Apollo"""
response = await self.requests.post(
response = await self.requests.get(
f"{self.API_URL}/mixed_companies/search",
headers=self._get_headers(),
json=query.model_dump(exclude={"max_results"}),
params=query.model_dump(exclude={"credentials", "max_results"}),
)
data = response.json()
parsed_response = SearchOrganizationsResponse(**data)
@@ -94,10 +93,10 @@ class ApolloClient:
and len(parsed_response.organizations) > 0
):
query.page += 1
response = await self.requests.post(
response = await self.requests.get(
f"{self.API_URL}/mixed_companies/search",
headers=self._get_headers(),
json=query.model_dump(exclude={"max_results"}),
params=query.model_dump(exclude={"credentials", "max_results"}),
)
data = response.json()
parsed_response = SearchOrganizationsResponse(**data)
@@ -111,21 +110,3 @@ class ApolloClient:
return (
organizations[: query.max_results] if query.max_results else organizations
)
async def enrich_person(self, query: EnrichPersonRequest) -> Contact:
"""Enrich a person's data including email & phone reveal"""
response = await self.requests.post(
f"{self.API_URL}/people/match",
headers=self._get_headers(),
json=query.model_dump(),
params={
"reveal_personal_emails": "true",
},
)
data = response.json()
if "person" not in data:
raise ValueError(f"Person not found or enrichment failed: {data}")
contact = Contact(**data["person"])
contact.email = contact.email or "-"
return contact
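To illustrate the call-shape change above (JSON body on POST versus query parameters on GET), a minimal sketch with the requests package; the URL and payload are placeholders, not Apollo's real endpoint or headers:

import requests

query = {"q_keywords": "marketing manager", "page": 1}

post_req = requests.Request("POST", "https://example.invalid/search", json=query).prepare()
get_req = requests.Request("GET", "https://example.invalid/search", params=query).prepare()

print(post_req.body)  # JSON-encoded body, e.g. b'{"q_keywords": "marketing manager", "page": 1}'
print(get_req.url)    # parameters in the query string, e.g. https://example.invalid/search?q_keywords=marketing+manager&page=1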

View File

@@ -1,31 +1,17 @@
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel as OriginalBaseModel
from pydantic import ConfigDict
from pydantic import BaseModel, ConfigDict
from backend.data.model import SchemaField
class BaseModel(OriginalBaseModel):
def model_dump(self, *args, exclude: set[str] | None = None, **kwargs):
if exclude is None:
exclude = set("credentials")
else:
exclude.add("credentials")
kwargs.setdefault("exclude_none", True)
kwargs.setdefault("exclude_unset", True)
kwargs.setdefault("exclude_defaults", True)
return super().model_dump(*args, exclude=exclude, **kwargs)
class PrimaryPhone(BaseModel):
"""A primary phone in Apollo"""
number: Optional[str] = ""
source: Optional[str] = ""
sanitized_number: Optional[str] = ""
number: str
source: str
sanitized_number: str
class SenorityLevels(str, Enum):
@@ -56,102 +42,102 @@ class ContactEmailStatuses(str, Enum):
class RuleConfigStatus(BaseModel):
"""A rule config status in Apollo"""
_id: Optional[str] = ""
created_at: Optional[str] = ""
rule_action_config_id: Optional[str] = ""
rule_config_id: Optional[str] = ""
status_cd: Optional[str] = ""
updated_at: Optional[str] = ""
id: Optional[str] = ""
key: Optional[str] = ""
_id: str
created_at: str
rule_action_config_id: str
rule_config_id: str
status_cd: str
updated_at: str
id: str
key: str
class ContactCampaignStatus(BaseModel):
"""A contact campaign status in Apollo"""
id: Optional[str] = ""
emailer_campaign_id: Optional[str] = ""
send_email_from_user_id: Optional[str] = ""
inactive_reason: Optional[str] = ""
status: Optional[str] = ""
added_at: Optional[str] = ""
added_by_user_id: Optional[str] = ""
finished_at: Optional[str] = ""
paused_at: Optional[str] = ""
auto_unpause_at: Optional[str] = ""
send_email_from_email_address: Optional[str] = ""
send_email_from_email_account_id: Optional[str] = ""
manually_set_unpause: Optional[str] = ""
failure_reason: Optional[str] = ""
current_step_id: Optional[str] = ""
in_response_to_emailer_message_id: Optional[str] = ""
cc_emails: Optional[str] = ""
bcc_emails: Optional[str] = ""
to_emails: Optional[str] = ""
id: str
emailer_campaign_id: str
send_email_from_user_id: str
inactive_reason: str
status: str
added_at: str
added_by_user_id: str
finished_at: str
paused_at: str
auto_unpause_at: str
send_email_from_email_address: str
send_email_from_email_account_id: str
manually_set_unpause: str
failure_reason: str
current_step_id: str
in_response_to_emailer_message_id: str
cc_emails: str
bcc_emails: str
to_emails: str
class Account(BaseModel):
"""An account in Apollo"""
id: Optional[str] = ""
name: Optional[str] = ""
website_url: Optional[str] = ""
blog_url: Optional[str] = ""
angellist_url: Optional[str] = ""
linkedin_url: Optional[str] = ""
twitter_url: Optional[str] = ""
facebook_url: Optional[str] = ""
primary_phone: Optional[PrimaryPhone] = PrimaryPhone()
languages: Optional[list[str]] = []
alexa_ranking: Optional[int] = 0
phone: Optional[str] = ""
linkedin_uid: Optional[str] = ""
founded_year: Optional[int] = 0
publicly_traded_symbol: Optional[str] = ""
publicly_traded_exchange: Optional[str] = ""
logo_url: Optional[str] = ""
chrunchbase_url: Optional[str] = ""
primary_domain: Optional[str] = ""
domain: Optional[str] = ""
team_id: Optional[str] = ""
organization_id: Optional[str] = ""
account_stage_id: Optional[str] = ""
source: Optional[str] = ""
original_source: Optional[str] = ""
creator_id: Optional[str] = ""
owner_id: Optional[str] = ""
created_at: Optional[str] = ""
phone_status: Optional[str] = ""
hubspot_id: Optional[str] = ""
salesforce_id: Optional[str] = ""
crm_owner_id: Optional[str] = ""
parent_account_id: Optional[str] = ""
sanitized_phone: Optional[str] = ""
id: str
name: str
website_url: str
blog_url: str
angellist_url: str
linkedin_url: str
twitter_url: str
facebook_url: str
primary_phone: PrimaryPhone
languages: list[str]
alexa_ranking: int
phone: str
linkedin_uid: str
founded_year: int
publicly_traded_symbol: str
publicly_traded_exchange: str
logo_url: str
chrunchbase_url: str
primary_domain: str
domain: str
team_id: str
organization_id: str
account_stage_id: str
source: str
original_source: str
creator_id: str
owner_id: str
created_at: str
phone_status: str
hubspot_id: str
salesforce_id: str
crm_owner_id: str
parent_account_id: str
sanitized_phone: str
# no type listed in the API docs
account_playbook_statues: Optional[list[Any]] = []
account_rule_config_statuses: Optional[list[RuleConfigStatus]] = []
existence_level: Optional[str] = ""
label_ids: Optional[list[str]] = []
typed_custom_fields: Optional[Any] = {}
custom_field_errors: Optional[Any] = {}
modality: Optional[str] = ""
source_display_name: Optional[str] = ""
salesforce_record_id: Optional[str] = ""
crm_record_url: Optional[str] = ""
account_playbook_statues: list[Any]
account_rule_config_statuses: list[RuleConfigStatus]
existence_level: str
label_ids: list[str]
typed_custom_fields: Any
custom_field_errors: Any
modality: str
source_display_name: str
salesforce_record_id: str
crm_record_url: str
class ContactEmail(BaseModel):
"""A contact email in Apollo"""
email: Optional[str] = ""
email_md5: Optional[str] = ""
email_sha256: Optional[str] = ""
email_status: Optional[str] = ""
email_source: Optional[str] = ""
extrapolated_email_confidence: Optional[str] = ""
position: Optional[int] = 0
email_from_customer: Optional[str] = ""
free_domain: Optional[bool] = True
email: str = ""
email_md5: str = ""
email_sha256: str = ""
email_status: str = ""
email_source: str = ""
extrapolated_email_confidence: str = ""
position: int = 0
email_from_customer: str = ""
free_domain: bool = True
class EmploymentHistory(BaseModel):
@@ -164,40 +150,40 @@ class EmploymentHistory(BaseModel):
populate_by_name=True,
)
_id: Optional[str] = ""
created_at: Optional[str] = ""
current: Optional[bool] = False
degree: Optional[str] = ""
description: Optional[str] = ""
emails: Optional[str] = ""
end_date: Optional[str] = ""
grade_level: Optional[str] = ""
kind: Optional[str] = ""
major: Optional[str] = ""
organization_id: Optional[str] = ""
organization_name: Optional[str] = ""
raw_address: Optional[str] = ""
start_date: Optional[str] = ""
title: Optional[str] = ""
updated_at: Optional[str] = ""
id: Optional[str] = ""
key: Optional[str] = ""
_id: Optional[str] = None
created_at: Optional[str] = None
current: Optional[bool] = None
degree: Optional[str] = None
description: Optional[str] = None
emails: Optional[str] = None
end_date: Optional[str] = None
grade_level: Optional[str] = None
kind: Optional[str] = None
major: Optional[str] = None
organization_id: Optional[str] = None
organization_name: Optional[str] = None
raw_address: Optional[str] = None
start_date: Optional[str] = None
title: Optional[str] = None
updated_at: Optional[str] = None
id: Optional[str] = None
key: Optional[str] = None
class Breadcrumb(BaseModel):
"""A breadcrumb in Apollo"""
label: Optional[str] = ""
signal_field_name: Optional[str] = ""
value: str | list | None = ""
display_name: Optional[str] = ""
label: Optional[str] = "N/A"
signal_field_name: Optional[str] = "N/A"
value: str | list | None = "N/A"
display_name: Optional[str] = "N/A"
class TypedCustomField(BaseModel):
"""A typed custom field in Apollo"""
id: Optional[str] = ""
value: Optional[str] = ""
id: Optional[str] = "N/A"
value: Optional[str] = "N/A"
class Pagination(BaseModel):
@@ -219,23 +205,23 @@ class Pagination(BaseModel):
class DialerFlags(BaseModel):
"""A dialer flags in Apollo"""
country_name: Optional[str] = ""
country_enabled: Optional[bool] = True
high_risk_calling_enabled: Optional[bool] = True
potential_high_risk_number: Optional[bool] = True
country_name: str
country_enabled: bool
high_risk_calling_enabled: bool
potential_high_risk_number: bool
class PhoneNumber(BaseModel):
"""A phone number in Apollo"""
raw_number: Optional[str] = ""
sanitized_number: Optional[str] = ""
type: Optional[str] = ""
position: Optional[int] = 0
status: Optional[str] = ""
dnc_status: Optional[str] = ""
dnc_other_info: Optional[str] = ""
dailer_flags: Optional[DialerFlags] = DialerFlags(
raw_number: str = ""
sanitized_number: str = ""
type: str = ""
position: int = 0
status: str = ""
dnc_status: str = ""
dnc_other_info: str = ""
dailer_flags: DialerFlags = DialerFlags(
country_name="",
country_enabled=True,
high_risk_calling_enabled=True,
@@ -253,31 +239,33 @@ class Organization(BaseModel):
populate_by_name=True,
)
id: Optional[str] = ""
name: Optional[str] = ""
website_url: Optional[str] = ""
blog_url: Optional[str] = ""
angellist_url: Optional[str] = ""
linkedin_url: Optional[str] = ""
twitter_url: Optional[str] = ""
facebook_url: Optional[str] = ""
primary_phone: Optional[PrimaryPhone] = PrimaryPhone()
languages: Optional[list[str]] = []
id: Optional[str] = "N/A"
name: Optional[str] = "N/A"
website_url: Optional[str] = "N/A"
blog_url: Optional[str] = "N/A"
angellist_url: Optional[str] = "N/A"
linkedin_url: Optional[str] = "N/A"
twitter_url: Optional[str] = "N/A"
facebook_url: Optional[str] = "N/A"
primary_phone: Optional[PrimaryPhone] = PrimaryPhone(
number="N/A", source="N/A", sanitized_number="N/A"
)
languages: list[str] = []
alexa_ranking: Optional[int] = 0
phone: Optional[str] = ""
linkedin_uid: Optional[str] = ""
phone: Optional[str] = "N/A"
linkedin_uid: Optional[str] = "N/A"
founded_year: Optional[int] = 0
publicly_traded_symbol: Optional[str] = ""
publicly_traded_exchange: Optional[str] = ""
logo_url: Optional[str] = ""
chrunchbase_url: Optional[str] = ""
primary_domain: Optional[str] = ""
sanitized_phone: Optional[str] = ""
owned_by_organization_id: Optional[str] = ""
intent_strength: Optional[str] = ""
show_intent: Optional[bool] = True
publicly_traded_symbol: Optional[str] = "N/A"
publicly_traded_exchange: Optional[str] = "N/A"
logo_url: Optional[str] = "N/A"
chrunchbase_url: Optional[str] = "N/A"
primary_domain: Optional[str] = "N/A"
sanitized_phone: Optional[str] = "N/A"
owned_by_organization_id: Optional[str] = "N/A"
intent_strength: Optional[str] = "N/A"
show_intent: bool = True
has_intent_signal_account: Optional[bool] = True
intent_signal_account: Optional[str] = ""
intent_signal_account: Optional[str] = "N/A"
class Contact(BaseModel):
@@ -290,95 +278,95 @@ class Contact(BaseModel):
populate_by_name=True,
)
contact_roles: Optional[list[Any]] = []
id: Optional[str] = ""
first_name: Optional[str] = ""
last_name: Optional[str] = ""
name: Optional[str] = ""
linkedin_url: Optional[str] = ""
title: Optional[str] = ""
contact_stage_id: Optional[str] = ""
owner_id: Optional[str] = ""
creator_id: Optional[str] = ""
person_id: Optional[str] = ""
email_needs_tickling: Optional[bool] = True
organization_name: Optional[str] = ""
source: Optional[str] = ""
original_source: Optional[str] = ""
organization_id: Optional[str] = ""
headline: Optional[str] = ""
photo_url: Optional[str] = ""
present_raw_address: Optional[str] = ""
linkededin_uid: Optional[str] = ""
extrapolated_email_confidence: Optional[float] = 0.0
salesforce_id: Optional[str] = ""
salesforce_lead_id: Optional[str] = ""
salesforce_contact_id: Optional[str] = ""
saleforce_account_id: Optional[str] = ""
crm_owner_id: Optional[str] = ""
created_at: Optional[str] = ""
emailer_campaign_ids: Optional[list[str]] = []
direct_dial_status: Optional[str] = ""
direct_dial_enrichment_failed_at: Optional[str] = ""
email_status: Optional[str] = ""
email_source: Optional[str] = ""
account_id: Optional[str] = ""
last_activity_date: Optional[str] = ""
hubspot_vid: Optional[str] = ""
hubspot_company_id: Optional[str] = ""
crm_id: Optional[str] = ""
sanitized_phone: Optional[str] = ""
merged_crm_ids: Optional[str] = ""
updated_at: Optional[str] = ""
queued_for_crm_push: Optional[bool] = True
suggested_from_rule_engine_config_id: Optional[str] = ""
email_unsubscribed: Optional[str] = ""
label_ids: Optional[list[Any]] = []
has_pending_email_arcgate_request: Optional[bool] = True
has_email_arcgate_request: Optional[bool] = True
existence_level: Optional[str] = ""
email: Optional[str] = ""
email_from_customer: Optional[str] = ""
typed_custom_fields: Optional[list[TypedCustomField]] = []
custom_field_errors: Optional[Any] = {}
salesforce_record_id: Optional[str] = ""
crm_record_url: Optional[str] = ""
email_status_unavailable_reason: Optional[str] = ""
email_true_status: Optional[str] = ""
updated_email_true_status: Optional[bool] = True
contact_rule_config_statuses: Optional[list[RuleConfigStatus]] = []
source_display_name: Optional[str] = ""
twitter_url: Optional[str] = ""
contact_campaign_statuses: Optional[list[ContactCampaignStatus]] = []
state: Optional[str] = ""
city: Optional[str] = ""
country: Optional[str] = ""
account: Optional[Account] = Account()
contact_emails: Optional[list[ContactEmail]] = []
organization: Optional[Organization] = Organization()
employment_history: Optional[list[EmploymentHistory]] = []
time_zone: Optional[str] = ""
intent_strength: Optional[str] = ""
show_intent: Optional[bool] = True
phone_numbers: Optional[list[PhoneNumber]] = []
account_phone_note: Optional[str] = ""
free_domain: Optional[bool] = True
is_likely_to_engage: Optional[bool] = True
email_domain_catchall: Optional[bool] = True
contact_job_change_event: Optional[str] = ""
contact_roles: list[Any] = []
id: Optional[str] = None
first_name: Optional[str] = None
last_name: Optional[str] = None
name: Optional[str] = None
linkedin_url: Optional[str] = None
title: Optional[str] = None
contact_stage_id: Optional[str] = None
owner_id: Optional[str] = None
creator_id: Optional[str] = None
person_id: Optional[str] = None
email_needs_tickling: bool = True
organization_name: Optional[str] = None
source: Optional[str] = None
original_source: Optional[str] = None
organization_id: Optional[str] = None
headline: Optional[str] = None
photo_url: Optional[str] = None
present_raw_address: Optional[str] = None
linkededin_uid: Optional[str] = None
extrapolated_email_confidence: Optional[float] = None
salesforce_id: Optional[str] = None
salesforce_lead_id: Optional[str] = None
salesforce_contact_id: Optional[str] = None
saleforce_account_id: Optional[str] = None
crm_owner_id: Optional[str] = None
created_at: Optional[str] = None
emailer_campaign_ids: list[str] = []
direct_dial_status: Optional[str] = None
direct_dial_enrichment_failed_at: Optional[str] = None
email_status: Optional[str] = None
email_source: Optional[str] = None
account_id: Optional[str] = None
last_activity_date: Optional[str] = None
hubspot_vid: Optional[str] = None
hubspot_company_id: Optional[str] = None
crm_id: Optional[str] = None
sanitized_phone: Optional[str] = None
merged_crm_ids: Optional[str] = None
updated_at: Optional[str] = None
queued_for_crm_push: bool = True
suggested_from_rule_engine_config_id: Optional[str] = None
email_unsubscribed: Optional[str] = None
label_ids: list[Any] = []
has_pending_email_arcgate_request: bool = True
has_email_arcgate_request: bool = True
existence_level: Optional[str] = None
email: Optional[str] = None
email_from_customer: Optional[str] = None
typed_custom_fields: list[TypedCustomField] = []
custom_field_errors: Any = None
salesforce_record_id: Optional[str] = None
crm_record_url: Optional[str] = None
email_status_unavailable_reason: Optional[str] = None
email_true_status: Optional[str] = None
updated_email_true_status: bool = True
contact_rule_config_statuses: list[RuleConfigStatus] = []
source_display_name: Optional[str] = None
twitter_url: Optional[str] = None
contact_campaign_statuses: list[ContactCampaignStatus] = []
state: Optional[str] = None
city: Optional[str] = None
country: Optional[str] = None
account: Optional[Account] = None
contact_emails: list[ContactEmail] = []
organization: Optional[Organization] = None
employment_history: list[EmploymentHistory] = []
time_zone: Optional[str] = None
intent_strength: Optional[str] = None
show_intent: bool = True
phone_numbers: list[PhoneNumber] = []
account_phone_note: Optional[str] = None
free_domain: bool = True
is_likely_to_engage: bool = True
email_domain_catchall: bool = True
contact_job_change_event: Optional[str] = None
class SearchOrganizationsRequest(BaseModel):
"""Request for Apollo's search organizations API"""
organization_num_employees_range: Optional[list[int]] = SchemaField(
organization_num_empoloyees_range: list[int] = SchemaField(
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
default=[0, 1000000],
)
organization_locations: Optional[list[str]] = SchemaField(
organization_locations: list[str] = SchemaField(
description="""The location of the company headquarters. You can search across cities, US states, and countries.
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, any Boston-based companies will not appear in your search results, even if they match other parameters.
@@ -387,30 +375,28 @@ To exclude companies based on location, use the organization_not_locations param
""",
default_factory=list,
)
organizations_not_locations: Optional[list[str]] = SchemaField(
organizations_not_locations: list[str] = SchemaField(
description="""Exclude companies from search results based on the location of the company headquarters. You can use cities, US states, and countries as locations to exclude.
This parameter is useful for ensuring you do not prospect in an undesirable territory. For example, if you use ireland as a value, no Ireland-based companies will appear in your search results.
""",
default_factory=list,
)
q_organization_keyword_tags: Optional[list[str]] = SchemaField(
description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry.""",
default_factory=list,
q_organization_keyword_tags: list[str] = SchemaField(
description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry."""
)
q_organization_name: Optional[str] = SchemaField(
q_organization_name: str = SchemaField(
description="""Filter search results to include a specific company name.
If the value you enter for this parameter does not match with a company's name, the company will not appear in search results, even if it matches other parameters. Partial matches are accepted. For example, if you filter by the value marketing, a company called NY Marketing Unlimited would still be eligible as a search result, but NY Market Analysis would not be eligible.""",
default="",
If the value you enter for this parameter does not match with a company's name, the company will not appear in search results, even if it matches other parameters. Partial matches are accepted. For example, if you filter by the value marketing, a company called NY Marketing Unlimited would still be eligible as a search result, but NY Market Analysis would not be eligible."""
)
organization_ids: Optional[list[str]] = SchemaField(
organization_ids: list[str] = SchemaField(
description="""The Apollo IDs for the companies you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
To find IDs, identify the values for organization_id when you call this endpoint.""",
default_factory=list,
)
max_results: Optional[int] = SchemaField(
max_results: int = SchemaField(
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
default=100,
ge=1,
@@ -435,11 +421,11 @@ Use the page parameter to search the different pages of data.""",
class SearchOrganizationsResponse(BaseModel):
"""Response from Apollo's search organizations API"""
breadcrumbs: Optional[list[Breadcrumb]] = []
partial_results_only: Optional[bool] = True
has_join: Optional[bool] = True
disable_eu_prospecting: Optional[bool] = True
partial_results_limit: Optional[int] = 0
breadcrumbs: list[Breadcrumb] = []
partial_results_only: bool = True
has_join: bool = True
disable_eu_prospecting: bool = True
partial_results_limit: int = 0
pagination: Pagination = Pagination(
page=0, per_page=0, total_entries=0, total_pages=0
)
@@ -447,14 +433,14 @@ class SearchOrganizationsResponse(BaseModel):
accounts: list[Any] = []
organizations: list[Organization] = []
models_ids: list[str] = []
num_fetch_result: Optional[str] = ""
derived_params: Optional[str] = ""
num_fetch_result: Optional[str] = "N/A"
derived_params: Optional[str] = "N/A"
class SearchPeopleRequest(BaseModel):
"""Request for Apollo's search people API"""
person_titles: Optional[list[str]] = SchemaField(
person_titles: list[str] = SchemaField(
description="""Job titles held by the people you want to find. For a person to be included in search results, they only need to match 1 of the job titles you add. Adding more job titles expands your search results.
Results also include job titles with the same terms, even if they are not exact matches. For example, searching for marketing manager might return people with the job title content marketing manager.
@@ -464,13 +450,13 @@ Use this parameter in combination with the person_seniorities[] parameter to fin
default_factory=list,
placeholder="marketing manager",
)
person_locations: Optional[list[str]] = SchemaField(
person_locations: list[str] = SchemaField(
description="""The location where people live. You can search across cities, US states, and countries.
To find people based on the headquarters locations of their current employer, use the organization_locations parameter.""",
default_factory=list,
)
person_seniorities: Optional[list[SenorityLevels]] = SchemaField(
person_seniorities: list[SenorityLevels] = SchemaField(
description="""The job seniority that people hold within their current employer. This enables you to find people that currently hold positions at certain reporting levels, such as Director level or senior IC level.
For a person to be included in search results, they only need to match 1 of the seniorities you add. Adding more seniorities expands your search results.
@@ -480,7 +466,7 @@ Searches only return results based on their current job title, so searching for
Use this parameter in combination with the person_titles[] parameter to find people based on specific job functions and seniority levels.""",
default_factory=list,
)
organization_locations: Optional[list[str]] = SchemaField(
organization_locations: list[str] = SchemaField(
description="""The location of the company headquarters for a person's current employer. You can search across cities, US states, and countries.
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, people that work for the Boston-based company will not appear in your results, even if they match other parameters.
@@ -488,7 +474,7 @@ If a company has several office locations, results are still based on the headqu
To find people based on their personal location, use the person_locations parameter.""",
default_factory=list,
)
q_organization_domains: Optional[list[str]] = SchemaField(
q_organization_domains: list[str] = SchemaField(
description="""The domain name for the person's employer. This can be the current employer or a previous employer. Do not include www., the @ symbol, or similar.
You can add multiple domains to search across companies.
@@ -496,23 +482,23 @@ You can add multiple domains to search across companies.
Examples: apollo.io and microsoft.com""",
default_factory=list,
)
contact_email_statuses: Optional[list[ContactEmailStatuses]] = SchemaField(
contact_email_statuses: list[ContactEmailStatuses] = SchemaField(
description="""The email statuses for the people you want to find. You can add multiple statuses to expand your search.""",
default_factory=list,
)
organization_ids: Optional[list[str]] = SchemaField(
organization_ids: list[str] = SchemaField(
description="""The Apollo IDs for the companies (employers) you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
To find IDs, call the Organization Search endpoint and identify the values for organization_id.""",
default_factory=list,
)
organization_num_employees_range: Optional[list[int]] = SchemaField(
organization_num_empoloyees_range: list[int] = SchemaField(
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
default_factory=list,
)
q_keywords: Optional[str] = SchemaField(
q_keywords: str = SchemaField(
description="""A string of words over which we want to filter the results""",
default="",
)
@@ -528,7 +514,7 @@ Use this parameter in combination with the per_page parameter to make search res
Use the page parameter to search the different pages of data.""",
default=100,
)
max_results: Optional[int] = SchemaField(
max_results: int = SchemaField(
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
default=100,
ge=1,
@@ -547,61 +533,16 @@ class SearchPeopleResponse(BaseModel):
populate_by_name=True,
)
breadcrumbs: Optional[list[Breadcrumb]] = []
partial_results_only: Optional[bool] = True
has_join: Optional[bool] = True
disable_eu_prospecting: Optional[bool] = True
partial_results_limit: Optional[int] = 0
breadcrumbs: list[Breadcrumb] = []
partial_results_only: bool = True
has_join: bool = True
disable_eu_prospecting: bool = True
partial_results_limit: int = 0
pagination: Pagination = Pagination(
page=0, per_page=0, total_entries=0, total_pages=0
)
contacts: list[Contact] = []
people: list[Contact] = []
model_ids: list[str] = []
num_fetch_result: Optional[str] = ""
derived_params: Optional[str] = ""
class EnrichPersonRequest(BaseModel):
"""Request for Apollo's person enrichment API"""
person_id: Optional[str] = SchemaField(
description="Apollo person ID to enrich (most accurate method)",
default="",
)
first_name: Optional[str] = SchemaField(
description="First name of the person to enrich",
default="",
)
last_name: Optional[str] = SchemaField(
description="Last name of the person to enrich",
default="",
)
name: Optional[str] = SchemaField(
description="Full name of the person to enrich",
default="",
)
email: Optional[str] = SchemaField(
description="Email address of the person to enrich",
default="",
)
domain: Optional[str] = SchemaField(
description="Company domain of the person to enrich",
default="",
)
company: Optional[str] = SchemaField(
description="Company name of the person to enrich",
default="",
)
linkedin_url: Optional[str] = SchemaField(
description="LinkedIn URL of the person to enrich",
default="",
)
organization_id: Optional[str] = SchemaField(
description="Apollo organization ID of the person's company",
default="",
)
title: Optional[str] = SchemaField(
description="Job title of the person to enrich",
default="",
)
num_fetch_result: Optional[str] = "N/A"
derived_params: Optional[str] = "N/A"

View File

@@ -11,14 +11,14 @@ from backend.blocks.apollo.models import (
SearchOrganizationsRequest,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import CredentialsField, SchemaField
from backend.data.model import SchemaField
class SearchOrganizationsBlock(Block):
"""Search for organizations in Apollo"""
class Input(BlockSchema):
organization_num_employees_range: list[int] = SchemaField(
organization_num_empoloyees_range: list[int] = SchemaField(
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
@@ -65,7 +65,7 @@ To find IDs, identify the values for organization_id when you call this endpoint
le=50000,
advanced=True,
)
credentials: ApolloCredentialsInput = CredentialsField(
credentials: ApolloCredentialsInput = SchemaField(
description="Apollo credentials",
)
@@ -210,7 +210,9 @@ To find IDs, identify the values for organization_id when you call this endpoint
async def run(
self, input_data: Input, *, credentials: ApolloCredentials, **kwargs
) -> BlockOutput:
query = SearchOrganizationsRequest(**input_data.model_dump())
query = SearchOrganizationsRequest(
**input_data.model_dump(exclude={"credentials"})
)
organizations = await self.search_organizations(query, credentials)
for organization in organizations:
yield "organization", organization

View File

@@ -1,5 +1,3 @@
import asyncio
from backend.blocks.apollo._api import ApolloClient
from backend.blocks.apollo._auth import (
TEST_CREDENTIALS,
@@ -10,12 +8,11 @@ from backend.blocks.apollo._auth import (
from backend.blocks.apollo.models import (
Contact,
ContactEmailStatuses,
EnrichPersonRequest,
SearchPeopleRequest,
SenorityLevels,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import CredentialsField, SchemaField
from backend.data.model import SchemaField
class SearchPeopleBlock(Block):
@@ -80,7 +77,7 @@ class SearchPeopleBlock(Block):
default_factory=list,
advanced=False,
)
organization_num_employees_range: list[int] = SchemaField(
organization_num_empoloyees_range: list[int] = SchemaField(
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
@@ -93,19 +90,14 @@ class SearchPeopleBlock(Block):
advanced=False,
)
max_results: int = SchemaField(
description="""The maximum number of results to return. If you don't specify this parameter, the default is 25. Limited to 500 to prevent overspending.""",
default=25,
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
default=100,
ge=1,
le=500,
advanced=True,
)
enrich_info: bool = SchemaField(
description="""Whether to enrich contacts with detailed information including real email addresses. This will double the search cost.""",
default=False,
le=50000,
advanced=True,
)
credentials: ApolloCredentialsInput = CredentialsField(
credentials: ApolloCredentialsInput = SchemaField(
description="Apollo credentials",
)
@@ -114,6 +106,9 @@ class SearchPeopleBlock(Block):
description="List of people found",
default_factory=list,
)
person: Contact = SchemaField(
description="Each found person, one at a time",
)
error: str = SchemaField(
description="Error message if the search failed",
default="",
@@ -129,6 +124,87 @@ class SearchPeopleBlock(Block):
test_credentials=TEST_CREDENTIALS,
test_input={"credentials": TEST_CREDENTIALS_INPUT},
test_output=[
(
"person",
Contact(
contact_roles=[],
id="1",
name="John Doe",
first_name="John",
last_name="Doe",
linkedin_url="https://www.linkedin.com/in/johndoe",
title="Software Engineer",
organization_name="Google",
organization_id="123456",
contact_stage_id="1",
owner_id="1",
creator_id="1",
person_id="1",
email_needs_tickling=True,
source="apollo",
original_source="apollo",
headline="Software Engineer",
photo_url="https://www.linkedin.com/in/johndoe",
present_raw_address="123 Main St, Anytown, USA",
linkededin_uid="123456",
extrapolated_email_confidence=0.8,
salesforce_id="123456",
salesforce_lead_id="123456",
salesforce_contact_id="123456",
saleforce_account_id="123456",
crm_owner_id="123456",
created_at="2021-01-01",
emailer_campaign_ids=[],
direct_dial_status="active",
direct_dial_enrichment_failed_at="2021-01-01",
email_status="active",
email_source="apollo",
account_id="123456",
last_activity_date="2021-01-01",
hubspot_vid="123456",
hubspot_company_id="123456",
crm_id="123456",
sanitized_phone="123456",
merged_crm_ids="123456",
updated_at="2021-01-01",
queued_for_crm_push=True,
suggested_from_rule_engine_config_id="123456",
email_unsubscribed=None,
label_ids=[],
has_pending_email_arcgate_request=True,
has_email_arcgate_request=True,
existence_level=None,
email=None,
email_from_customer=None,
typed_custom_fields=[],
custom_field_errors=None,
salesforce_record_id=None,
crm_record_url=None,
email_status_unavailable_reason=None,
email_true_status=None,
updated_email_true_status=True,
contact_rule_config_statuses=[],
source_display_name=None,
twitter_url=None,
contact_campaign_statuses=[],
state=None,
city=None,
country=None,
account=None,
contact_emails=[],
organization=None,
employment_history=[],
time_zone=None,
intent_strength=None,
show_intent=True,
phone_numbers=[],
account_phone_note=None,
free_domain=True,
is_likely_to_engage=True,
email_domain_catchall=True,
contact_job_change_event=None,
),
),
(
"people",
[
@@ -303,34 +379,6 @@ class SearchPeopleBlock(Block):
client = ApolloClient(credentials)
return await client.search_people(query)
@staticmethod
async def enrich_person(
query: EnrichPersonRequest, credentials: ApolloCredentials
) -> Contact:
client = ApolloClient(credentials)
return await client.enrich_person(query)
@staticmethod
def merge_contact_data(original: Contact, enriched: Contact) -> Contact:
"""
Merge contact data from original search with enriched data.
Enriched data complements original data, only filling in missing values.
"""
merged_data = original.model_dump()
enriched_data = enriched.model_dump()
# Only update fields that are None, empty string, empty list, or default values in original
for key, enriched_value in enriched_data.items():
# Skip if enriched value is None, empty string, or empty list
if enriched_value is None or enriched_value == "" or enriched_value == []:
continue
# Update if original value is None, empty string, empty list, or zero
if enriched_value:
merged_data[key] = enriched_value
return Contact(**merged_data)
async def run(
self,
input_data: Input,
@@ -339,25 +387,8 @@ class SearchPeopleBlock(Block):
**kwargs,
) -> BlockOutput:
query = SearchPeopleRequest(**input_data.model_dump())
query = SearchPeopleRequest(**input_data.model_dump(exclude={"credentials"}))
people = await self.search_people(query, credentials)
# Enrich with detailed info if requested
if input_data.enrich_info:
async def enrich_or_fallback(person: Contact):
try:
enrich_query = EnrichPersonRequest(person_id=person.id)
enriched_person = await self.enrich_person(
enrich_query, credentials
)
# Merge enriched data with original data, complementing instead of replacing
return self.merge_contact_data(person, enriched_person)
except Exception:
return person # If enrichment fails, use original person data
people = await asyncio.gather(
*(enrich_or_fallback(person) for person in people)
)
for person in people:
yield "person", person
yield "people", people

View File

@@ -1,138 +0,0 @@
from backend.blocks.apollo._api import ApolloClient
from backend.blocks.apollo._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
ApolloCredentials,
ApolloCredentialsInput,
)
from backend.blocks.apollo.models import Contact, EnrichPersonRequest
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import CredentialsField, SchemaField
class GetPersonDetailBlock(Block):
"""Get detailed person data with Apollo API, including email reveal"""
class Input(BlockSchema):
person_id: str = SchemaField(
description="Apollo person ID to enrich (most accurate method)",
default="",
advanced=False,
)
first_name: str = SchemaField(
description="First name of the person to enrich",
default="",
advanced=False,
)
last_name: str = SchemaField(
description="Last name of the person to enrich",
default="",
advanced=False,
)
name: str = SchemaField(
description="Full name of the person to enrich (alternative to first_name + last_name)",
default="",
advanced=False,
)
email: str = SchemaField(
description="Known email address of the person (helps with matching)",
default="",
advanced=False,
)
domain: str = SchemaField(
description="Company domain of the person (e.g., 'google.com')",
default="",
advanced=False,
)
company: str = SchemaField(
description="Company name of the person",
default="",
advanced=False,
)
linkedin_url: str = SchemaField(
description="LinkedIn URL of the person",
default="",
advanced=False,
)
organization_id: str = SchemaField(
description="Apollo organization ID of the person's company",
default="",
advanced=True,
)
title: str = SchemaField(
description="Job title of the person to enrich",
default="",
advanced=True,
)
credentials: ApolloCredentialsInput = CredentialsField(
description="Apollo credentials",
)
class Output(BlockSchema):
contact: Contact = SchemaField(
description="Enriched contact information",
)
error: str = SchemaField(
description="Error message if enrichment failed",
default="",
)
def __init__(self):
super().__init__(
id="3b18d46c-3db6-42ae-a228-0ba441bdd176",
description="Get detailed person data with Apollo API, including email reveal",
categories={BlockCategory.SEARCH},
input_schema=GetPersonDetailBlock.Input,
output_schema=GetPersonDetailBlock.Output,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"first_name": "John",
"last_name": "Doe",
"company": "Google",
},
test_output=[
(
"contact",
Contact(
id="1",
name="John Doe",
first_name="John",
last_name="Doe",
email="john.doe@gmail.com",
title="Software Engineer",
organization_name="Google",
linkedin_url="https://www.linkedin.com/in/johndoe",
),
),
],
test_mock={
"enrich_person": lambda query, credentials: Contact(
id="1",
name="John Doe",
first_name="John",
last_name="Doe",
email="john.doe@gmail.com",
title="Software Engineer",
organization_name="Google",
linkedin_url="https://www.linkedin.com/in/johndoe",
)
},
)
@staticmethod
async def enrich_person(
query: EnrichPersonRequest, credentials: ApolloCredentials
) -> Contact:
client = ApolloClient(credentials)
return await client.enrich_person(query)
async def run(
self,
input_data: Input,
*,
credentials: ApolloCredentials,
**kwargs,
) -> BlockOutput:
query = EnrichPersonRequest(**input_data.model_dump())
yield "contact", await self.enrich_person(query, credentials)

View File

@@ -1,15 +0,0 @@
AYRSHARE_BLOCK_IDS = [
"cbd52c2a-06d2-43ed-9560-6576cc163283", # PostToBlueskyBlock
"3352f512-3524-49ed-a08f-003042da2fc1", # PostToFacebookBlock
"9e8f844e-b4a5-4b25-80f2-9e1dd7d67625", # PostToXBlock
"589af4e4-507f-42fd-b9ac-a67ecef25811", # PostToLinkedInBlock
"89b02b96-a7cb-46f4-9900-c48b32fe1552", # PostToInstagramBlock
"0082d712-ff1b-4c3d-8a8d-6c7721883b83", # PostToYouTubeBlock
"c7733580-3c72-483e-8e47-a8d58754d853", # PostToRedditBlock
"47bc74eb-4af2-452c-b933-af377c7287df", # PostToTelegramBlock
"2c38c783-c484-4503-9280-ef5d1d345a7e", # PostToGMBBlock
"3ca46e05-dbaa-4afb-9e95-5a429c4177e6", # PostToPinterestBlock
"7faf4b27-96b0-4f05-bf64-e0de54ae74e1", # PostToTikTokBlock
"f8c3b2e1-9d4a-4e5f-8c7b-6a9e8d2f1c3b", # PostToThreadsBlock
"a9d7f854-2c83-4e96-b3a1-7f2e9c5d4b8e", # PostToSnapchatBlock
]

View File

@@ -1,152 +0,0 @@
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, Field
from backend.data.block import BlockSchema
from backend.data.model import SchemaField, UserIntegrations
from backend.integrations.ayrshare import AyrshareClient
from backend.util.clients import get_database_manager_async_client
from backend.util.exceptions import MissingConfigError
async def get_profile_key(user_id: str):
user_integrations: UserIntegrations = (
await get_database_manager_async_client().get_user_integrations(user_id)
)
return user_integrations.managed_credentials.ayrshare_profile_key
class BaseAyrshareInput(BlockSchema):
"""Base input model for Ayrshare social media posts with common fields."""
post: str = SchemaField(
description="The post text to be published", default="", advanced=False
)
media_urls: list[str] = SchemaField(
description="Optional list of media URLs to include. Set is_video in advanced settings to true if you want to upload videos.",
default_factory=list,
advanced=False,
)
is_video: bool = SchemaField(
description="Whether the media is a video", default=False, advanced=True
)
schedule_date: Optional[datetime] = SchemaField(
description="UTC datetime for scheduling (YYYY-MM-DDThh:mm:ssZ)",
default=None,
advanced=True,
)
disable_comments: bool = SchemaField(
description="Whether to disable comments", default=False, advanced=True
)
shorten_links: bool = SchemaField(
description="Whether to shorten links", default=False, advanced=True
)
unsplash: Optional[str] = SchemaField(
description="Unsplash image configuration", default=None, advanced=True
)
requires_approval: bool = SchemaField(
description="Whether to enable approval workflow",
default=False,
advanced=True,
)
random_post: bool = SchemaField(
description="Whether to generate random post text",
default=False,
advanced=True,
)
random_media_url: bool = SchemaField(
description="Whether to generate random media", default=False, advanced=True
)
notes: Optional[str] = SchemaField(
description="Additional notes for the post", default=None, advanced=True
)
class CarouselItem(BaseModel):
"""Model for Facebook carousel items."""
name: str = Field(..., description="The name of the item")
link: str = Field(..., description="The link of the item")
picture: str = Field(..., description="The picture URL of the item")
class CallToAction(BaseModel):
"""Model for Google My Business Call to Action."""
action_type: str = Field(
..., description="Type of action (book, order, shop, learn_more, sign_up, call)"
)
url: Optional[str] = Field(
description="URL for the action (not required for 'call' action)"
)
class EventDetails(BaseModel):
"""Model for Google My Business Event details."""
title: str = Field(..., description="Event title")
start_date: str = Field(..., description="Event start date (ISO format)")
end_date: str = Field(..., description="Event end date (ISO format)")
class OfferDetails(BaseModel):
"""Model for Google My Business Offer details."""
title: str = Field(..., description="Offer title")
start_date: str = Field(..., description="Offer start date (ISO format)")
end_date: str = Field(..., description="Offer end date (ISO format)")
coupon_code: str = Field(..., description="Coupon code (max 58 characters)")
redeem_online_url: str = Field(..., description="URL to redeem the offer")
terms_conditions: str = Field(..., description="Terms and conditions")
class InstagramUserTag(BaseModel):
"""Model for Instagram user tags."""
username: str = Field(..., description="Instagram username (without @)")
x: Optional[float] = Field(description="X coordinate (0.0-1.0) for image posts")
y: Optional[float] = Field(description="Y coordinate (0.0-1.0) for image posts")
class LinkedInTargeting(BaseModel):
"""Model for LinkedIn audience targeting."""
countries: Optional[list[str]] = Field(
description="Country codes (e.g., ['US', 'IN', 'DE', 'GB'])"
)
seniorities: Optional[list[str]] = Field(
description="Seniority levels (e.g., ['Senior', 'VP'])"
)
degrees: Optional[list[str]] = Field(description="Education degrees")
fields_of_study: Optional[list[str]] = Field(description="Fields of study")
industries: Optional[list[str]] = Field(description="Industry categories")
job_functions: Optional[list[str]] = Field(description="Job function categories")
staff_count_ranges: Optional[list[str]] = Field(description="Company size ranges")
class PinterestCarouselOption(BaseModel):
"""Model for Pinterest carousel image options."""
title: Optional[str] = Field(description="Image title")
link: Optional[str] = Field(description="External destination link for the image")
description: Optional[str] = Field(description="Image description")
class YouTubeTargeting(BaseModel):
"""Model for YouTube country targeting."""
block: Optional[list[str]] = Field(
description="Country codes to block (e.g., ['US', 'CA'])"
)
allow: Optional[list[str]] = Field(
description="Country codes to allow (e.g., ['GB', 'AU'])"
)
def create_ayrshare_client():
"""Create an Ayrshare client instance."""
try:
return AyrshareClient()
except MissingConfigError:
return None
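Every PostTo*Block below repeats the same two preconditions before calling create_post: a linked Ayrshare profile key for the user and a configured client. A minimal sketch of that shared guard, using the helpers defined in this file; the platform and post text are illustrative, and it assumes the remaining create_post keyword arguments have usable defaults.

from backend.integrations.ayrshare import SocialPlatform

async def post_if_configured(user_id: str, post_text: str):
    """Illustrative guard mirroring the checks each PostTo*Block performs."""
    profile_key = await get_profile_key(user_id)
    if not profile_key:
        return "error", "Please link a social account via Ayrshare"

    client = create_ayrshare_client()
    if not client:
        return "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."

    # Assumes the other create_post keyword arguments default sensibly.
    response = await client.create_post(
        post=post_text,
        platforms=[SocialPlatform.BLUESKY],  # each block pins its own platform
        profile_key=profile_key.get_secret_value(),
    )
    return "post_result", response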

View File

@@ -1,114 +0,0 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockType,
SchemaField,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
class PostToBlueskyBlock(Block):
"""Block for posting to Bluesky with Bluesky-specific options."""
class Input(BaseAyrshareInput):
"""Input schema for Bluesky posts."""
# Override post field to include character limit information
post: str = SchemaField(
description="The post text to be published (max 300 characters for Bluesky)",
default="",
advanced=False,
)
# Override media_urls to include Bluesky-specific constraints
media_urls: list[str] = SchemaField(
description="Optional list of media URLs to include. Bluesky supports up to 4 images or 1 video.",
default_factory=list,
advanced=False,
)
# Bluesky-specific options
alt_text: list[str] = SchemaField(
description="Alt text for each media item (accessibility)",
default_factory=list,
advanced=True,
)
class Output(BlockSchema):
post_result: PostResponse = SchemaField(description="The result of the post")
post: PostIds = SchemaField(description="The result of the post")
def __init__(self):
super().__init__(
disabled=True,
id="cbd52c2a-06d2-43ed-9560-6576cc163283",
description="Post to Bluesky using Ayrshare",
categories={BlockCategory.SOCIAL},
block_type=BlockType.AYRSHARE,
input_schema=PostToBlueskyBlock.Input,
output_schema=PostToBlueskyBlock.Output,
)
async def run(
self,
input_data: "PostToBlueskyBlock.Input",
*,
user_id: str,
**kwargs,
) -> BlockOutput:
"""Post to Bluesky with Bluesky-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
return
# Validate character limit for Bluesky
if len(input_data.post) > 300:
yield "error", f"Post text exceeds Bluesky's 300 character limit ({len(input_data.post)} characters)"
return
# Validate media constraints for Bluesky
if len(input_data.media_urls) > 4:
yield "error", "Bluesky supports a maximum of 4 images or 1 video"
return
# Convert datetime to ISO format if provided
iso_date = (
input_data.schedule_date.isoformat() if input_data.schedule_date else None
)
# Build Bluesky-specific options
bluesky_options = {}
if input_data.alt_text:
bluesky_options["altText"] = input_data.alt_text
response = await client.create_post(
post=input_data.post,
platforms=[SocialPlatform.BLUESKY],
media_urls=input_data.media_urls,
is_video=input_data.is_video,
schedule_date=iso_date,
disable_comments=input_data.disable_comments,
shorten_links=input_data.shorten_links,
unsplash=input_data.unsplash,
requires_approval=input_data.requires_approval,
random_post=input_data.random_post,
random_media_url=input_data.random_media_url,
notes=input_data.notes,
bluesky_options=bluesky_options if bluesky_options else None,
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:
for p in response.postIds:
yield "post", p

View File

@@ -1,212 +0,0 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockType,
SchemaField,
)
from ._util import (
BaseAyrshareInput,
CarouselItem,
create_ayrshare_client,
get_profile_key,
)
class PostToFacebookBlock(Block):
"""Block for posting to Facebook with Facebook-specific options."""
class Input(BaseAyrshareInput):
"""Input schema for Facebook posts."""
# Facebook-specific options
is_carousel: bool = SchemaField(
description="Whether to post a carousel", default=False, advanced=True
)
carousel_link: str = SchemaField(
description="The URL for the 'See More At' button in the carousel",
default="",
advanced=True,
)
carousel_items: list[CarouselItem] = SchemaField(
description="List of carousel items with name, link and picture URLs. Min 2, max 10 items.",
default_factory=list,
advanced=True,
)
is_reels: bool = SchemaField(
description="Whether to post to Facebook Reels",
default=False,
advanced=True,
)
reels_title: str = SchemaField(
description="Title for the Reels video (max 255 chars)",
default="",
advanced=True,
)
reels_thumbnail: str = SchemaField(
description="Thumbnail URL for Reels video (JPEG/PNG, <10MB)",
default="",
advanced=True,
)
is_story: bool = SchemaField(
description="Whether to post as a Facebook Story",
default=False,
advanced=True,
)
media_captions: list[str] = SchemaField(
description="Captions for each media item",
default_factory=list,
advanced=True,
)
location_id: str = SchemaField(
description="Facebook Page ID or name for location tagging",
default="",
advanced=True,
)
age_min: int = SchemaField(
description="Minimum age for audience targeting (13,15,18,21,25)",
default=0,
advanced=True,
)
target_countries: list[str] = SchemaField(
description="List of country codes to target (max 25)",
default_factory=list,
advanced=True,
)
alt_text: list[str] = SchemaField(
description="Alt text for each media item",
default_factory=list,
advanced=True,
)
video_title: str = SchemaField(
description="Title for video post", default="", advanced=True
)
video_thumbnail: str = SchemaField(
description="Thumbnail URL for video post", default="", advanced=True
)
is_draft: bool = SchemaField(
description="Save as draft in Meta Business Suite",
default=False,
advanced=True,
)
scheduled_publish_date: str = SchemaField(
description="Schedule publish time in Meta Business Suite (UTC)",
default="",
advanced=True,
)
preview_link: str = SchemaField(
description="URL for custom link preview", default="", advanced=True
)
class Output(BlockSchema):
post_result: PostResponse = SchemaField(description="The result of the post")
post: PostIds = SchemaField(description="The result of the post")
def __init__(self):
super().__init__(
disabled=True,
id="3352f512-3524-49ed-a08f-003042da2fc1",
description="Post to Facebook using Ayrshare",
categories={BlockCategory.SOCIAL},
block_type=BlockType.AYRSHARE,
input_schema=PostToFacebookBlock.Input,
output_schema=PostToFacebookBlock.Output,
)
async def run(
self,
input_data: "PostToFacebookBlock.Input",
*,
user_id: str,
**kwargs,
) -> BlockOutput:
"""Post to Facebook with Facebook-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
return
# Convert datetime to ISO format if provided
iso_date = (
input_data.schedule_date.isoformat() if input_data.schedule_date else None
)
# Build Facebook-specific options
facebook_options = {}
if input_data.is_carousel:
facebook_options["isCarousel"] = True
if input_data.carousel_link:
facebook_options["carouselLink"] = input_data.carousel_link
if input_data.carousel_items:
facebook_options["carouselItems"] = [
item.dict() for item in input_data.carousel_items
]
if input_data.is_reels:
facebook_options["isReels"] = True
if input_data.reels_title:
facebook_options["reelsTitle"] = input_data.reels_title
if input_data.reels_thumbnail:
facebook_options["reelsThumbnail"] = input_data.reels_thumbnail
if input_data.is_story:
facebook_options["isStory"] = True
if input_data.media_captions:
facebook_options["mediaCaptions"] = input_data.media_captions
if input_data.location_id:
facebook_options["locationId"] = input_data.location_id
if input_data.age_min > 0:
facebook_options["ageMin"] = input_data.age_min
if input_data.target_countries:
facebook_options["targetCountries"] = input_data.target_countries
if input_data.alt_text:
facebook_options["altText"] = input_data.alt_text
if input_data.video_title:
facebook_options["videoTitle"] = input_data.video_title
if input_data.video_thumbnail:
facebook_options["videoThumbnail"] = input_data.video_thumbnail
if input_data.is_draft:
facebook_options["isDraft"] = True
if input_data.scheduled_publish_date:
facebook_options["scheduledPublishDate"] = input_data.scheduled_publish_date
if input_data.preview_link:
facebook_options["previewLink"] = input_data.preview_link
response = await client.create_post(
post=input_data.post,
platforms=[SocialPlatform.FACEBOOK],
media_urls=input_data.media_urls,
is_video=input_data.is_video,
schedule_date=iso_date,
disable_comments=input_data.disable_comments,
shorten_links=input_data.shorten_links,
unsplash=input_data.unsplash,
requires_approval=input_data.requires_approval,
random_post=input_data.random_post,
random_media_url=input_data.random_media_url,
notes=input_data.notes,
facebook_options=facebook_options if facebook_options else None,
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:
for p in response.postIds:
yield "post", p

View File

@@ -1,210 +0,0 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockType,
SchemaField,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
class PostToGMBBlock(Block):
"""Block for posting to Google My Business with GMB-specific options."""
class Input(BaseAyrshareInput):
"""Input schema for Google My Business posts."""
# Override media_urls to include GMB-specific constraints
media_urls: list[str] = SchemaField(
description="Optional list of media URLs. GMB supports only one image or video per post.",
default_factory=list,
advanced=False,
)
# GMB-specific options
is_photo_video: bool = SchemaField(
description="Whether this is a photo/video post (appears in Photos section)",
default=False,
advanced=True,
)
photo_category: str = SchemaField(
description="Category for photo/video: cover, profile, logo, exterior, interior, product, at_work, food_and_drink, menu, common_area, rooms, teams",
default="",
advanced=True,
)
# Call to action options (flattened from CallToAction object)
call_to_action_type: str = SchemaField(
description="Type of action button: 'book', 'order', 'shop', 'learn_more', 'sign_up', or 'call'",
default="",
advanced=True,
)
call_to_action_url: str = SchemaField(
description="URL for the action button (not required for 'call' action)",
default="",
advanced=True,
)
# Event details options (flattened from EventDetails object)
event_title: str = SchemaField(
description="Event title for event posts",
default="",
advanced=True,
)
event_start_date: str = SchemaField(
description="Event start date in ISO format (e.g., '2024-03-15T09:00:00Z')",
default="",
advanced=True,
)
event_end_date: str = SchemaField(
description="Event end date in ISO format (e.g., '2024-03-15T17:00:00Z')",
default="",
advanced=True,
)
# Offer details options (flattened from OfferDetails object)
offer_title: str = SchemaField(
description="Offer title for promotional posts",
default="",
advanced=True,
)
offer_start_date: str = SchemaField(
description="Offer start date in ISO format (e.g., '2024-03-15T00:00:00Z')",
default="",
advanced=True,
)
offer_end_date: str = SchemaField(
description="Offer end date in ISO format (e.g., '2024-04-15T23:59:59Z')",
default="",
advanced=True,
)
offer_coupon_code: str = SchemaField(
description="Coupon code for the offer (max 58 characters)",
default="",
advanced=True,
)
offer_redeem_online_url: str = SchemaField(
description="URL where customers can redeem the offer online",
default="",
advanced=True,
)
offer_terms_conditions: str = SchemaField(
description="Terms and conditions for the offer",
default="",
advanced=True,
)
class Output(BlockSchema):
post_result: PostResponse = SchemaField(description="The result of the post")
post: PostIds = SchemaField(description="The result of the post")
def __init__(self):
super().__init__(
disabled=True,
id="2c38c783-c484-4503-9280-ef5d1d345a7e",
description="Post to Google My Business using Ayrshare",
categories={BlockCategory.SOCIAL},
block_type=BlockType.AYRSHARE,
input_schema=PostToGMBBlock.Input,
output_schema=PostToGMBBlock.Output,
)
async def run(
self, input_data: "PostToGMBBlock.Input", *, user_id: str, **kwargs
) -> BlockOutput:
"""Post to Google My Business with GMB-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
return
# Validate GMB constraints
if len(input_data.media_urls) > 1:
yield "error", "Google My Business supports only one image or video per post"
return
# Validate offer coupon code length
if input_data.offer_coupon_code and len(input_data.offer_coupon_code) > 58:
yield "error", "GMB offer coupon code cannot exceed 58 characters"
return
# Convert datetime to ISO format if provided
iso_date = (
input_data.schedule_date.isoformat() if input_data.schedule_date else None
)
# Build GMB-specific options
gmb_options = {}
# Photo/Video post options
if input_data.is_photo_video:
gmb_options["isPhotoVideo"] = True
if input_data.photo_category:
gmb_options["category"] = input_data.photo_category
# Call to Action (from flattened fields)
if input_data.call_to_action_type:
cta_dict = {"actionType": input_data.call_to_action_type}
# URL not required for 'call' action type
if (
input_data.call_to_action_type != "call"
and input_data.call_to_action_url
):
cta_dict["url"] = input_data.call_to_action_url
gmb_options["callToAction"] = cta_dict
# Event details (from flattened fields)
if (
input_data.event_title
and input_data.event_start_date
and input_data.event_end_date
):
gmb_options["event"] = {
"title": input_data.event_title,
"startDate": input_data.event_start_date,
"endDate": input_data.event_end_date,
}
# Offer details (from flattened fields)
if (
input_data.offer_title
and input_data.offer_start_date
and input_data.offer_end_date
and input_data.offer_coupon_code
and input_data.offer_redeem_online_url
and input_data.offer_terms_conditions
):
gmb_options["offer"] = {
"title": input_data.offer_title,
"startDate": input_data.offer_start_date,
"endDate": input_data.offer_end_date,
"couponCode": input_data.offer_coupon_code,
"redeemOnlineUrl": input_data.offer_redeem_online_url,
"termsConditions": input_data.offer_terms_conditions,
}
response = await client.create_post(
post=input_data.post,
platforms=[SocialPlatform.GOOGLE_MY_BUSINESS],
media_urls=input_data.media_urls,
is_video=input_data.is_video,
schedule_date=iso_date,
disable_comments=input_data.disable_comments,
shorten_links=input_data.shorten_links,
unsplash=input_data.unsplash,
requires_approval=input_data.requires_approval,
random_post=input_data.random_post,
random_media_url=input_data.random_media_url,
notes=input_data.notes,
gmb_options=gmb_options if gmb_options else None,
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:
for p in response.postIds:
yield "post", p

View File

@@ -1,249 +0,0 @@
from typing import Any
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockType,
SchemaField,
)
from ._util import (
BaseAyrshareInput,
InstagramUserTag,
create_ayrshare_client,
get_profile_key,
)
class PostToInstagramBlock(Block):
"""Block for posting to Instagram with Instagram-specific options."""
class Input(BaseAyrshareInput):
"""Input schema for Instagram posts."""
# Override post field to include Instagram-specific information
post: str = SchemaField(
description="The post text (max 2,200 chars, up to 30 hashtags, 3 @mentions)",
default="",
advanced=False,
)
# Override media_urls to include Instagram-specific constraints
media_urls: list[str] = SchemaField(
description="Optional list of media URLs. Instagram supports up to 10 images/videos in a carousel.",
default_factory=list,
advanced=False,
)
# Instagram-specific options
is_story: bool | None = SchemaField(
description="Whether to post as Instagram Story (24-hour expiration)",
default=None,
advanced=True,
)
# ------- REELS OPTIONS -------
share_reels_feed: bool | None = SchemaField(
description="Whether Reel should appear in both Feed and Reels tabs",
default=None,
advanced=True,
)
audio_name: str | None = SchemaField(
description="Audio name for Reels (e.g., 'The Weeknd - Blinding Lights')",
default=None,
advanced=True,
)
thumbnail: str | None = SchemaField(
description="Thumbnail URL for Reel video", default=None, advanced=True
)
thumbnail_offset: int | None = SchemaField(
description="Thumbnail frame offset in milliseconds (default: 0)",
default=0,
advanced=True,
)
# ------- POST OPTIONS -------
alt_text: list[str] = SchemaField(
description="Alt text for each media item (up to 1,000 chars each, accessibility feature), each item in the list corresponds to a media item in the media_urls list",
default_factory=list,
advanced=True,
)
location_id: str | None = SchemaField(
description="Facebook Page ID or name for location tagging (e.g., '7640348500' or '@guggenheimmuseum')",
default=None,
advanced=True,
)
user_tags: list[dict[str, Any]] = SchemaField(
description="List of users to tag with coordinates for images",
default_factory=list,
advanced=True,
)
collaborators: list[str] = SchemaField(
description="Instagram usernames to invite as collaborators (max 3, public accounts only)",
default_factory=list,
advanced=True,
)
auto_resize: bool | None = SchemaField(
description="Auto-resize images to 1080x1080px for Instagram",
default=None,
advanced=True,
)
class Output(BlockSchema):
post_result: PostResponse = SchemaField(description="The result of the post")
post: PostIds = SchemaField(description="The result of the post")
def __init__(self):
super().__init__(
id="89b02b96-a7cb-46f4-9900-c48b32fe1552",
description="Post to Instagram using Ayrshare. Requires a Business or Creator Instagram Account connected with a Facebook Page",
categories={BlockCategory.SOCIAL},
block_type=BlockType.AYRSHARE,
input_schema=PostToInstagramBlock.Input,
output_schema=PostToInstagramBlock.Output,
)
async def run(
self,
input_data: "PostToInstagramBlock.Input",
*,
user_id: str,
**kwargs,
) -> BlockOutput:
"""Post to Instagram with Instagram-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
return
# Validate Instagram constraints
if len(input_data.post) > 2200:
yield "error", f"Instagram post text exceeds 2,200 character limit ({len(input_data.post)} characters)"
return
if len(input_data.media_urls) > 10:
yield "error", "Instagram supports a maximum of 10 images/videos in a carousel"
return
if len(input_data.collaborators) > 3:
yield "error", "Instagram supports a maximum of 3 collaborators"
return
# Validate that if any reel option is set, all required reel options are set
reel_options = [
input_data.share_reels_feed,
input_data.audio_name,
input_data.thumbnail,
]
if any(reel_options) and not all(reel_options):
yield "error", "When posting a reel, all reel options must be set: share_reels_feed, audio_name, and either thumbnail or thumbnail_offset"
return
# Count hashtags and mentions
hashtag_count = input_data.post.count("#")
mention_count = input_data.post.count("@")
if hashtag_count > 30:
yield "error", f"Instagram allows maximum 30 hashtags ({hashtag_count} found)"
return
if mention_count > 3:
yield "error", f"Instagram allows maximum 3 @mentions ({mention_count} found)"
return
# Convert datetime to ISO format if provided
iso_date = (
input_data.schedule_date.isoformat() if input_data.schedule_date else None
)
# Build Instagram-specific options
instagram_options = {}
# Stories
if input_data.is_story:
instagram_options["stories"] = True
# Reels options
if input_data.share_reels_feed is not None:
instagram_options["shareReelsFeed"] = input_data.share_reels_feed
if input_data.audio_name:
instagram_options["audioName"] = input_data.audio_name
if input_data.thumbnail:
instagram_options["thumbNail"] = input_data.thumbnail
elif input_data.thumbnail_offset and input_data.thumbnail_offset > 0:
instagram_options["thumbNailOffset"] = input_data.thumbnail_offset
# Alt text
if input_data.alt_text:
# Validate alt text length
for i, alt in enumerate(input_data.alt_text):
if len(alt) > 1000:
yield "error", f"Alt text {i+1} exceeds 1,000 character limit ({len(alt)} characters)"
return
instagram_options["altText"] = input_data.alt_text
# Location
if input_data.location_id:
instagram_options["locationId"] = input_data.location_id
# User tags
if input_data.user_tags:
user_tags_list = []
for tag in input_data.user_tags:
try:
tag_obj = InstagramUserTag(**tag)
except Exception as e:
yield "error", f"Invalid user tag: {e}, tages need to be a dictionary with a 3 items: username (str), x (float) and y (float)"
return
tag_dict: dict[str, float | str] = {"username": tag_obj.username}
if tag_obj.x is not None and tag_obj.y is not None:
# Validate coordinates
if not (0.0 <= tag_obj.x <= 1.0) or not (0.0 <= tag_obj.y <= 1.0):
yield "error", f"User tag coordinates must be between 0.0 and 1.0 (user: {tag_obj.username})"
return
tag_dict["x"] = tag_obj.x
tag_dict["y"] = tag_obj.y
user_tags_list.append(tag_dict)
instagram_options["userTags"] = user_tags_list
# Collaborators
if input_data.collaborators:
instagram_options["collaborators"] = input_data.collaborators
# Auto resize
if input_data.auto_resize:
instagram_options["autoResize"] = True
response = await client.create_post(
post=input_data.post,
platforms=[SocialPlatform.INSTAGRAM],
media_urls=input_data.media_urls,
is_video=input_data.is_video,
schedule_date=iso_date,
disable_comments=input_data.disable_comments,
shorten_links=input_data.shorten_links,
unsplash=input_data.unsplash,
requires_approval=input_data.requires_approval,
random_post=input_data.random_post,
random_media_url=input_data.random_media_url,
notes=input_data.notes,
instagram_options=instagram_options if instagram_options else None,
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:
for p in response.postIds:
yield "post", p

View File

@@ -1,222 +0,0 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockType,
SchemaField,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
class PostToLinkedInBlock(Block):
"""Block for posting to LinkedIn with LinkedIn-specific options."""
class Input(BaseAyrshareInput):
"""Input schema for LinkedIn posts."""
# Override post field to include LinkedIn-specific information
post: str = SchemaField(
description="The post text (max 3,000 chars, hashtags supported with #)",
default="",
advanced=False,
)
# Override media_urls to include LinkedIn-specific constraints
media_urls: list[str] = SchemaField(
description="Optional list of media URLs. LinkedIn supports up to 9 images, videos, or documents (PPT, PPTX, DOC, DOCX, PDF <100MB, <300 pages).",
default_factory=list,
advanced=False,
)
# LinkedIn-specific options
visibility: str = SchemaField(
description="Post visibility: 'public' (default), 'connections' (personal only), 'loggedin'",
default="public",
advanced=True,
)
alt_text: list[str] = SchemaField(
description="Alt text for each image (accessibility feature, not supported for videos/documents)",
default_factory=list,
advanced=True,
)
titles: list[str] = SchemaField(
description="Title/caption for each image or video",
default_factory=list,
advanced=True,
)
document_title: str = SchemaField(
description="Title for document posts (max 400 chars, uses filename if not specified)",
default="",
advanced=True,
)
thumbnail: str = SchemaField(
description="Thumbnail URL for video (PNG/JPG, same dimensions as video, <10MB)",
default="",
advanced=True,
)
# LinkedIn targeting options (flattened from LinkedInTargeting object)
targeting_countries: list[str] | None = SchemaField(
description="Country codes for targeting (e.g., ['US', 'IN', 'DE', 'GB']). Requires 300+ followers in target audience.",
default=None,
advanced=True,
)
targeting_seniorities: list[str] | None = SchemaField(
description="Seniority levels for targeting (e.g., ['Senior', 'VP']). Requires 300+ followers in target audience.",
default=None,
advanced=True,
)
targeting_degrees: list[str] | None = SchemaField(
description="Education degrees for targeting. Requires 300+ followers in target audience.",
default=None,
advanced=True,
)
targeting_fields_of_study: list[str] | None = SchemaField(
description="Fields of study for targeting. Requires 300+ followers in target audience.",
default=None,
advanced=True,
)
targeting_industries: list[str] | None = SchemaField(
description="Industry categories for targeting. Requires 300+ followers in target audience.",
default=None,
advanced=True,
)
targeting_job_functions: list[str] | None = SchemaField(
description="Job function categories for targeting. Requires 300+ followers in target audience.",
default=None,
advanced=True,
)
targeting_staff_count_ranges: list[str] | None = SchemaField(
description="Company size ranges for targeting. Requires 300+ followers in target audience.",
default=None,
advanced=True,
)
class Output(BlockSchema):
post_result: PostResponse = SchemaField(description="The result of the post")
post: PostIds = SchemaField(description="The result of the post")
def __init__(self):
super().__init__(
id="589af4e4-507f-42fd-b9ac-a67ecef25811",
description="Post to LinkedIn using Ayrshare",
categories={BlockCategory.SOCIAL},
block_type=BlockType.AYRSHARE,
input_schema=PostToLinkedInBlock.Input,
output_schema=PostToLinkedInBlock.Output,
)
async def run(
self,
input_data: "PostToLinkedInBlock.Input",
*,
user_id: str,
**kwargs,
) -> BlockOutput:
"""Post to LinkedIn with LinkedIn-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
return
# Validate LinkedIn constraints
if len(input_data.post) > 3000:
yield "error", f"LinkedIn post text exceeds 3,000 character limit ({len(input_data.post)} characters)"
return
if len(input_data.media_urls) > 9:
yield "error", "LinkedIn supports a maximum of 9 images/videos/documents"
return
if input_data.document_title and len(input_data.document_title) > 400:
yield "error", f"LinkedIn document title exceeds 400 character limit ({len(input_data.document_title)} characters)"
return
# Validate visibility option
valid_visibility = ["public", "connections", "loggedin"]
if input_data.visibility not in valid_visibility:
yield "error", f"LinkedIn visibility must be one of: {', '.join(valid_visibility)}"
return
# Check for document extensions
document_extensions = [".ppt", ".pptx", ".doc", ".docx", ".pdf"]
has_documents = any(
any(url.lower().endswith(ext) for ext in document_extensions)
for url in input_data.media_urls
)
# Convert datetime to ISO format if provided
iso_date = (
input_data.schedule_date.isoformat() if input_data.schedule_date else None
)
# Build LinkedIn-specific options
linkedin_options = {}
# Visibility
if input_data.visibility != "public":
linkedin_options["visibility"] = input_data.visibility
# Alt text (not supported for videos or documents)
if input_data.alt_text and not has_documents:
linkedin_options["altText"] = input_data.alt_text
# Titles/captions
if input_data.titles:
linkedin_options["titles"] = input_data.titles
# Document title
if input_data.document_title and has_documents:
linkedin_options["title"] = input_data.document_title
# Video thumbnail
if input_data.thumbnail:
linkedin_options["thumbNail"] = input_data.thumbnail
# Audience targeting (from flattened fields)
targeting_dict = {}
if input_data.targeting_countries:
targeting_dict["countries"] = input_data.targeting_countries
if input_data.targeting_seniorities:
targeting_dict["seniorities"] = input_data.targeting_seniorities
if input_data.targeting_degrees:
targeting_dict["degrees"] = input_data.targeting_degrees
if input_data.targeting_fields_of_study:
targeting_dict["fieldsOfStudy"] = input_data.targeting_fields_of_study
if input_data.targeting_industries:
targeting_dict["industries"] = input_data.targeting_industries
if input_data.targeting_job_functions:
targeting_dict["jobFunctions"] = input_data.targeting_job_functions
if input_data.targeting_staff_count_ranges:
targeting_dict["staffCountRanges"] = input_data.targeting_staff_count_ranges
if targeting_dict:
linkedin_options["targeting"] = targeting_dict
response = await client.create_post(
post=input_data.post,
platforms=[SocialPlatform.LINKEDIN],
media_urls=input_data.media_urls,
is_video=input_data.is_video,
schedule_date=iso_date,
disable_comments=input_data.disable_comments,
shorten_links=input_data.shorten_links,
unsplash=input_data.unsplash,
requires_approval=input_data.requires_approval,
random_post=input_data.random_post,
random_media_url=input_data.random_media_url,
notes=input_data.notes,
linkedin_options=linkedin_options if linkedin_options else None,
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:
for p in response.postIds:
yield "post", p

View File

@@ -1,214 +0,0 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockType,
SchemaField,
)
from ._util import (
BaseAyrshareInput,
PinterestCarouselOption,
create_ayrshare_client,
get_profile_key,
)
class PostToPinterestBlock(Block):
"""Block for posting to Pinterest with Pinterest-specific options."""
class Input(BaseAyrshareInput):
"""Input schema for Pinterest posts."""
# Override post field to include Pinterest-specific information
post: str = SchemaField(
description="Pin description (max 500 chars, links not clickable - use link field instead)",
default="",
advanced=False,
)
# Override media_urls to include Pinterest-specific constraints
media_urls: list[str] = SchemaField(
description="Required image/video URLs. Pinterest requires at least one image. Videos need thumbnail. Up to 5 images for carousel.",
default_factory=list,
advanced=False,
)
# Pinterest-specific options
pin_title: str = SchemaField(
description="Pin title displayed in 'Add your title' section (max 100 chars)",
default="",
advanced=True,
)
link: str = SchemaField(
description="Clickable destination URL when users click the pin (max 2048 chars)",
default="",
advanced=True,
)
board_id: str = SchemaField(
description="Pinterest Board ID to post to (from /user/details endpoint, uses default board if not specified)",
default="",
advanced=True,
)
note: str = SchemaField(
description="Private note for the pin (only visible to you and board collaborators)",
default="",
advanced=True,
)
thumbnail: str = SchemaField(
description="Required thumbnail URL for video pins (must have valid image Content-Type)",
default="",
advanced=True,
)
carousel_options: list[PinterestCarouselOption] = SchemaField(
description="Options for each image in carousel (title, link, description per image)",
default_factory=list,
advanced=True,
)
alt_text: list[str] = SchemaField(
description="Alt text for each image/video (max 500 chars each, accessibility feature)",
default_factory=list,
advanced=True,
)
class Output(BlockSchema):
post_result: PostResponse = SchemaField(description="The result of the post")
post: PostIds = SchemaField(description="The result of the post")
def __init__(self):
super().__init__(
disabled=True,
id="3ca46e05-dbaa-4afb-9e95-5a429c4177e6",
description="Post to Pinterest using Ayrshare",
categories={BlockCategory.SOCIAL},
block_type=BlockType.AYRSHARE,
input_schema=PostToPinterestBlock.Input,
output_schema=PostToPinterestBlock.Output,
)
async def run(
self,
input_data: "PostToPinterestBlock.Input",
*,
user_id: str,
**kwargs,
) -> BlockOutput:
"""Post to Pinterest with Pinterest-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
return
# Validate Pinterest constraints
if len(input_data.post) > 500:
yield "error", f"Pinterest pin description exceeds 500 character limit ({len(input_data.post)} characters)"
return
if len(input_data.pin_title) > 100:
yield "error", f"Pinterest pin title exceeds 100 character limit ({len(input_data.pin_title)} characters)"
return
if len(input_data.link) > 2048:
yield "error", f"Pinterest link URL exceeds 2048 character limit ({len(input_data.link)} characters)"
return
if len(input_data.media_urls) == 0:
yield "error", "Pinterest requires at least one image or video"
return
if len(input_data.media_urls) > 5:
yield "error", "Pinterest supports a maximum of 5 images in a carousel"
return
# Check if video is included and thumbnail is provided
video_extensions = [".mp4", ".mov", ".avi", ".mkv", ".wmv", ".flv", ".webm"]
has_video = any(
any(url.lower().endswith(ext) for ext in video_extensions)
for url in input_data.media_urls
)
if (has_video or input_data.is_video) and not input_data.thumbnail:
yield "error", "Pinterest video pins require a thumbnail URL"
return
# Validate alt text length
for i, alt in enumerate(input_data.alt_text):
if len(alt) > 500:
yield "error", f"Pinterest alt text {i+1} exceeds 500 character limit ({len(alt)} characters)"
return
# Convert datetime to ISO format if provided
iso_date = (
input_data.schedule_date.isoformat() if input_data.schedule_date else None
)
# Build Pinterest-specific options
pinterest_options = {}
# Pin title
if input_data.pin_title:
pinterest_options["title"] = input_data.pin_title
# Clickable link
if input_data.link:
pinterest_options["link"] = input_data.link
# Board ID
if input_data.board_id:
pinterest_options["boardId"] = input_data.board_id
# Private note
if input_data.note:
pinterest_options["note"] = input_data.note
# Video thumbnail
if input_data.thumbnail:
pinterest_options["thumbNail"] = input_data.thumbnail
# Carousel options
if input_data.carousel_options:
carousel_list = []
for option in input_data.carousel_options:
carousel_dict = {}
if option.title:
carousel_dict["title"] = option.title
if option.link:
carousel_dict["link"] = option.link
if option.description:
carousel_dict["description"] = option.description
if carousel_dict: # Only add if not empty
carousel_list.append(carousel_dict)
if carousel_list:
pinterest_options["carouselOptions"] = carousel_list
# Alt text
if input_data.alt_text:
pinterest_options["altText"] = input_data.alt_text
response = await client.create_post(
post=input_data.post,
platforms=[SocialPlatform.PINTEREST],
media_urls=input_data.media_urls,
is_video=input_data.is_video,
schedule_date=iso_date,
disable_comments=input_data.disable_comments,
shorten_links=input_data.shorten_links,
unsplash=input_data.unsplash,
requires_approval=input_data.requires_approval,
random_post=input_data.random_post,
random_media_url=input_data.random_media_url,
notes=input_data.notes,
pinterest_options=pinterest_options if pinterest_options else None,
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:
for p in response.postIds:
yield "post", p

View File

@@ -1,69 +0,0 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockType,
SchemaField,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
class PostToRedditBlock(Block):
"""Block for posting to Reddit."""
class Input(BaseAyrshareInput):
"""Input schema for Reddit posts."""
pass # Uses all base fields
class Output(BlockSchema):
post_result: PostResponse = SchemaField(description="The result of the post")
post: PostIds = SchemaField(description="The result of the post")
def __init__(self):
super().__init__(
disabled=True,
id="c7733580-3c72-483e-8e47-a8d58754d853",
description="Post to Reddit using Ayrshare",
categories={BlockCategory.SOCIAL},
block_type=BlockType.AYRSHARE,
input_schema=PostToRedditBlock.Input,
output_schema=PostToRedditBlock.Output,
)
async def run(
self, input_data: "PostToRedditBlock.Input", *, user_id: str, **kwargs
) -> BlockOutput:
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured."
return
iso_date = (
input_data.schedule_date.isoformat() if input_data.schedule_date else None
)
response = await client.create_post(
post=input_data.post,
platforms=[SocialPlatform.REDDIT],
media_urls=input_data.media_urls,
is_video=input_data.is_video,
schedule_date=iso_date,
disable_comments=input_data.disable_comments,
shorten_links=input_data.shorten_links,
unsplash=input_data.unsplash,
requires_approval=input_data.requires_approval,
random_post=input_data.random_post,
random_media_url=input_data.random_media_url,
notes=input_data.notes,
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:
for p in response.postIds:
yield "post", p

View File

@@ -1,129 +0,0 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockType,
SchemaField,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
class PostToSnapchatBlock(Block):
"""Block for posting to Snapchat with Snapchat-specific options."""
class Input(BaseAyrshareInput):
"""Input schema for Snapchat posts."""
# Override post field to include Snapchat-specific information
post: str = SchemaField(
description="The post text (optional for video-only content)",
default="",
advanced=False,
)
# Override media_urls to include Snapchat-specific constraints
media_urls: list[str] = SchemaField(
description="Required video URL for Snapchat posts. Snapchat only supports video content.",
default_factory=list,
advanced=False,
)
# Snapchat-specific options
story_type: str = SchemaField(
description="Type of Snapchat content: 'story' (24-hour Stories), 'saved_story' (Saved Stories), or 'spotlight' (Spotlight posts)",
default="story",
advanced=True,
)
video_thumbnail: str = SchemaField(
description="Thumbnail URL for video content (optional, auto-generated if not provided)",
default="",
advanced=True,
)
class Output(BlockSchema):
post_result: PostResponse = SchemaField(description="The result of the post")
post: PostIds = SchemaField(description="The result of the post")
def __init__(self):
super().__init__(
disabled=True,
id="a9d7f854-2c83-4e96-b3a1-7f2e9c5d4b8e",
description="Post to Snapchat using Ayrshare",
categories={BlockCategory.SOCIAL},
block_type=BlockType.AYRSHARE,
input_schema=PostToSnapchatBlock.Input,
output_schema=PostToSnapchatBlock.Output,
)
async def run(
self,
input_data: "PostToSnapchatBlock.Input",
*,
user_id: str,
**kwargs,
) -> BlockOutput:
"""Post to Snapchat with Snapchat-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
return
# Validate Snapchat constraints
if not input_data.media_urls:
yield "error", "Snapchat requires at least one video URL"
return
if len(input_data.media_urls) > 1:
yield "error", "Snapchat supports only one video per post"
return
# Validate story type
valid_story_types = ["story", "saved_story", "spotlight"]
if input_data.story_type not in valid_story_types:
yield "error", f"Snapchat story type must be one of: {', '.join(valid_story_types)}"
return
# Convert datetime to ISO format if provided
iso_date = (
input_data.schedule_date.isoformat() if input_data.schedule_date else None
)
# Build Snapchat-specific options
snapchat_options = {}
# Story type
if input_data.story_type != "story":
snapchat_options["storyType"] = input_data.story_type
# Video thumbnail
if input_data.video_thumbnail:
snapchat_options["videoThumbnail"] = input_data.video_thumbnail
response = await client.create_post(
post=input_data.post,
platforms=[SocialPlatform.SNAPCHAT],
media_urls=input_data.media_urls,
is_video=True, # Snapchat only supports video
schedule_date=iso_date,
disable_comments=input_data.disable_comments,
shorten_links=input_data.shorten_links,
unsplash=input_data.unsplash,
requires_approval=input_data.requires_approval,
random_post=input_data.random_post,
random_media_url=input_data.random_media_url,
notes=input_data.notes,
snapchat_options=snapchat_options if snapchat_options else None,
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:
for p in response.postIds:
yield "post", p

View File

@@ -1,116 +0,0 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockType,
SchemaField,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
class PostToTelegramBlock(Block):
"""Block for posting to Telegram with Telegram-specific options."""
class Input(BaseAyrshareInput):
"""Input schema for Telegram posts."""
# Override post field to include Telegram-specific information
post: str = SchemaField(
description="The post text (empty string allowed). Use @handle to mention other Telegram users.",
default="",
advanced=False,
)
# Override media_urls to include Telegram-specific constraints
media_urls: list[str] = SchemaField(
description="Optional list of media URLs. For animated GIFs, only one URL is allowed. Telegram will auto-preview links unless image/video is included.",
default_factory=list,
advanced=False,
)
# Override is_video to include GIF-specific information
is_video: bool = SchemaField(
description="Whether the media is a video. Set to true for animated GIFs that don't end in .gif/.GIF extension.",
default=False,
advanced=True,
)
class Output(BlockSchema):
post_result: PostResponse = SchemaField(description="The result of the post")
post: PostIds = SchemaField(description="The result of the post")
def __init__(self):
super().__init__(
disabled=True,
id="47bc74eb-4af2-452c-b933-af377c7287df",
description="Post to Telegram using Ayrshare",
categories={BlockCategory.SOCIAL},
block_type=BlockType.AYRSHARE,
input_schema=PostToTelegramBlock.Input,
output_schema=PostToTelegramBlock.Output,
)
async def run(
self,
input_data: "PostToTelegramBlock.Input",
*,
user_id: str,
**kwargs,
) -> BlockOutput:
"""Post to Telegram with Telegram-specific validation."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
return
# Validate Telegram constraints
# Check for animated GIFs - only one URL allowed
gif_extensions = [".gif", ".GIF"]
has_gif = any(
any(url.endswith(ext) for ext in gif_extensions)
for url in input_data.media_urls
)
if has_gif and len(input_data.media_urls) > 1:
yield "error", "Telegram animated GIFs support only one URL per post"
return
# Auto-detect if we need to set is_video for GIFs without proper extension
detected_is_video = input_data.is_video
if input_data.media_urls and not has_gif and not input_data.is_video:
# Check if this might be a GIF without proper extension
# This is just informational - user needs to set is_video manually
pass
# Convert datetime to ISO format if provided
iso_date = (
input_data.schedule_date.isoformat() if input_data.schedule_date else None
)
response = await client.create_post(
post=input_data.post,
platforms=[SocialPlatform.TELEGRAM],
media_urls=input_data.media_urls,
is_video=detected_is_video,
schedule_date=iso_date,
disable_comments=input_data.disable_comments,
shorten_links=input_data.shorten_links,
unsplash=input_data.unsplash,
requires_approval=input_data.requires_approval,
random_post=input_data.random_post,
random_media_url=input_data.random_media_url,
notes=input_data.notes,
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:
for p in response.postIds:
yield "post", p

View File

@@ -1,111 +0,0 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockType,
SchemaField,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
class PostToThreadsBlock(Block):
"""Block for posting to Threads with Threads-specific options."""
class Input(BaseAyrshareInput):
"""Input schema for Threads posts."""
# Override post field to include Threads-specific information
post: str = SchemaField(
description="The post text (max 500 chars, empty string allowed). Only 1 hashtag allowed. Use @handle to mention users.",
default="",
advanced=False,
)
# Override media_urls to include Threads-specific constraints
media_urls: list[str] = SchemaField(
description="Optional list of media URLs. Supports up to 20 images/videos in a carousel. Auto-preview links unless media is included.",
default_factory=list,
advanced=False,
)
class Output(BlockSchema):
post_result: PostResponse = SchemaField(description="The result of the post")
post: PostIds = SchemaField(description="The individual post IDs of the created post")
def __init__(self):
super().__init__(
disabled=True,
id="f8c3b2e1-9d4a-4e5f-8c7b-6a9e8d2f1c3b",
description="Post to Threads using Ayrshare",
categories={BlockCategory.SOCIAL},
block_type=BlockType.AYRSHARE,
input_schema=PostToThreadsBlock.Input,
output_schema=PostToThreadsBlock.Output,
)
async def run(
self,
input_data: "PostToThreadsBlock.Input",
*,
user_id: str,
**kwargs,
) -> BlockOutput:
"""Post to Threads with Threads-specific validation."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
return
# Validate Threads constraints
if len(input_data.post) > 500:
yield "error", f"Threads post text exceeds 500 character limit ({len(input_data.post)} characters)"
return
if len(input_data.media_urls) > 20:
yield "error", "Threads supports a maximum of 20 images/videos in a carousel"
return
# Count hashtags (only 1 allowed)
hashtag_count = input_data.post.count("#")
if hashtag_count > 1:
yield "error", f"Threads allows only 1 hashtag per post ({hashtag_count} found)"
return
# Convert datetime to ISO format if provided
iso_date = (
input_data.schedule_date.isoformat() if input_data.schedule_date else None
)
# Build Threads-specific options
threads_options = {}
# Note: Based on the documentation, Threads doesn't seem to have specific options
# beyond the standard ones. The main constraints are validation-based.
response = await client.create_post(
post=input_data.post,
platforms=[SocialPlatform.THREADS],
media_urls=input_data.media_urls,
is_video=input_data.is_video,
schedule_date=iso_date,
disable_comments=input_data.disable_comments,
shorten_links=input_data.shorten_links,
unsplash=input_data.unsplash,
requires_approval=input_data.requires_approval,
random_post=input_data.random_post,
random_media_url=input_data.random_media_url,
notes=input_data.notes,
threads_options=threads_options if threads_options else None,
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:
for p in response.postIds:
yield "post", p

View File

@@ -1,243 +0,0 @@
from enum import Enum
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockType,
SchemaField,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
class TikTokVisibility(str, Enum):
PUBLIC = "public"
PRIVATE = "private"
FOLLOWERS = "followers"
class PostToTikTokBlock(Block):
"""Block for posting to TikTok with TikTok-specific options."""
class Input(BaseAyrshareInput):
"""Input schema for TikTok posts."""
# Override post field to include TikTok-specific information
post: str = SchemaField(
description="The post text (max 2,200 chars, empty string allowed). Use @handle to mention users. Line breaks will be ignored.",
advanced=False,
)
# Override media_urls to include TikTok-specific constraints
media_urls: list[str] = SchemaField(
description="Required media URLs. Either 1 video OR up to 35 images (JPG/JPEG/WEBP only). Cannot mix video and images.",
default_factory=list,
advanced=False,
)
# TikTok-specific options
auto_add_music: bool = SchemaField(
description="Whether to automatically add recommended music to the post. If you set this field to true, you can change the music later in the TikTok app.",
default=False,
advanced=True,
)
disable_comments: bool = SchemaField(
description="Disable comments on the published post",
default=False,
advanced=True,
)
disable_duet: bool = SchemaField(
description="Disable duets on published video (video only)",
default=False,
advanced=True,
)
disable_stitch: bool = SchemaField(
description="Disable stitch on published video (video only)",
default=False,
advanced=True,
)
is_ai_generated: bool = SchemaField(
description="If you enable the toggle, your video will be labeled as “Creator labeled as AI-generated” once posted and cant be changed. The “Creator labeled as AI-generated” label indicates that the content was completely AI-generated or significantly edited with AI.",
default=False,
advanced=True,
)
is_branded_content: bool = SchemaField(
description="Whether to enable the Branded Content toggle. If this field is set to true, the video will be labeled as Branded Content, indicating you are in a paid partnership with a brand. A “Paid partnership” label will be attached to the video.",
default=False,
advanced=True,
)
is_brand_organic: bool = SchemaField(
description="Whether to enable the Brand Organic Content toggle. If this field is set to true, the video will be labeled as Brand Organic Content, indicating you are promoting yourself or your own business. A “Promotional content” label will be attached to the video.",
default=False,
advanced=True,
)
image_cover_index: int = SchemaField(
description="Index of image to use as cover (0-based, image posts only)",
default=0,
advanced=True,
)
title: str = SchemaField(
description="Title for image posts", default="", advanced=True
)
thumbnail_offset: int = SchemaField(
description="Video thumbnail frame offset in milliseconds (video only)",
default=0,
advanced=True,
)
visibility: TikTokVisibility = SchemaField(
description="Post visibility: 'public', 'private', 'followers', or 'friends'",
default=TikTokVisibility.PUBLIC,
advanced=True,
)
draft: bool = SchemaField(
description="Create as draft post (video only)",
default=False,
advanced=True,
)
class Output(BlockSchema):
post_result: PostResponse = SchemaField(description="The result of the post")
post: PostIds = SchemaField(description="The individual post IDs of the created post")
def __init__(self):
super().__init__(
id="7faf4b27-96b0-4f05-bf64-e0de54ae74e1",
description="Post to TikTok using Ayrshare",
categories={BlockCategory.SOCIAL},
block_type=BlockType.AYRSHARE,
input_schema=PostToTikTokBlock.Input,
output_schema=PostToTikTokBlock.Output,
)
async def run(
self, input_data: "PostToTikTokBlock.Input", *, user_id: str, **kwargs
) -> BlockOutput:
"""Post to TikTok with TikTok-specific validation and options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
return
# Validate TikTok constraints
if len(input_data.post) > 2200:
yield "error", f"TikTok post text exceeds 2,200 character limit ({len(input_data.post)} characters)"
return
if not input_data.media_urls:
yield "error", "TikTok requires at least one media URL (either 1 video or up to 35 images)"
return
# Check for video vs image constraints
video_extensions = [".mp4", ".mov", ".avi", ".mkv", ".wmv", ".flv", ".webm"]
image_extensions = [".jpg", ".jpeg", ".webp"]
has_video = input_data.is_video or any(
any(url.lower().endswith(ext) for ext in video_extensions)
for url in input_data.media_urls
)
has_images = any(
any(url.lower().endswith(ext) for ext in image_extensions)
for url in input_data.media_urls
)
if has_video and has_images:
yield "error", "TikTok does not support mixing video and images in the same post"
return
if has_video and len(input_data.media_urls) > 1:
yield "error", "TikTok supports only 1 video per post"
return
if has_images and len(input_data.media_urls) > 35:
yield "error", "TikTok supports a maximum of 35 images per post"
return
# Validate image cover index
if has_images and input_data.image_cover_index >= len(input_data.media_urls):
yield "error", f"Image cover index {input_data.image_cover_index} is out of range (max: {len(input_data.media_urls) - 1})"
return
# Check for PNG files (not supported)
has_png = any(url.lower().endswith(".png") for url in input_data.media_urls)
if has_png:
yield "error", "TikTok does not support PNG files. Please use JPG, JPEG, or WEBP for images."
return
# Convert datetime to ISO format if provided
iso_date = (
input_data.schedule_date.isoformat() if input_data.schedule_date else None
)
# Build TikTok-specific options
tiktok_options = {}
# Common options
if input_data.auto_add_music and has_images:
tiktok_options["autoAddMusic"] = True
if input_data.disable_comments:
tiktok_options["disableComments"] = True
if input_data.is_branded_content:
tiktok_options["isBrandedContent"] = True
if input_data.is_brand_organic:
tiktok_options["isBrandOrganic"] = True
# Video-specific options
if has_video:
if input_data.disable_duet:
tiktok_options["disableDuet"] = True
if input_data.disable_stitch:
tiktok_options["disableStitch"] = True
if input_data.is_ai_generated:
tiktok_options["isAIGenerated"] = True
if input_data.thumbnail_offset > 0:
tiktok_options["thumbNailOffset"] = input_data.thumbnail_offset
if input_data.draft:
tiktok_options["draft"] = True
# Image-specific options
if has_images:
if input_data.image_cover_index > 0:
tiktok_options["imageCoverIndex"] = input_data.image_cover_index
if input_data.title:
tiktok_options["title"] = input_data.title
if input_data.visibility != TikTokVisibility.PUBLIC:
tiktok_options["visibility"] = input_data.visibility.value
response = await client.create_post(
post=input_data.post,
platforms=[SocialPlatform.TIKTOK],
media_urls=input_data.media_urls,
is_video=has_video,
schedule_date=iso_date,
disable_comments=input_data.disable_comments,
shorten_links=input_data.shorten_links,
unsplash=input_data.unsplash,
requires_approval=input_data.requires_approval,
random_post=input_data.random_post,
random_media_url=input_data.random_media_url,
notes=input_data.notes,
tiktok_options=tiktok_options if tiktok_options else None,
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:
for p in response.postIds:
yield "post", p

View File

@@ -1,241 +0,0 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockType,
SchemaField,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
class PostToXBlock(Block):
"""Block for posting to X / Twitter with Twitter-specific options."""
class Input(BaseAyrshareInput):
"""Input schema for X / Twitter posts."""
# Override post field to include X-specific information
post: str = SchemaField(
description="The post text (max 280 chars, up to 25,000 for Premium users). Use @handle to mention users. Use \\n\\n for thread breaks.",
advanced=False,
)
# Override media_urls to include X-specific constraints
media_urls: list[str] = SchemaField(
description="Optional list of media URLs. X supports up to 4 images or videos per tweet. Auto-preview links unless media is included.",
default_factory=list,
advanced=False,
)
# X-specific options
reply_to_id: str | None = SchemaField(
description="ID of the tweet to reply to",
default=None,
advanced=True,
)
quote_tweet_id: str | None = SchemaField(
description="ID of the tweet to quote (low-level Tweet ID)",
default=None,
advanced=True,
)
poll_options: list[str] = SchemaField(
description="Poll options (2-4 choices)",
default_factory=list,
advanced=True,
)
poll_duration: int = SchemaField(
description="Poll duration in minutes (1-10080)",
default=1440,
advanced=True,
)
alt_text: list[str] = SchemaField(
description="Alt text for each image (max 1,000 chars each, not supported for videos)",
default_factory=list,
advanced=True,
)
is_thread: bool = SchemaField(
description="Whether to automatically break post into thread based on line breaks",
default=False,
advanced=True,
)
thread_number: bool = SchemaField(
description="Add thread numbers (1/n format) to each thread post",
default=False,
advanced=True,
)
thread_media_urls: list[str] = SchemaField(
description="Media URLs for thread posts (one per thread, use 'null' to skip)",
default_factory=list,
advanced=True,
)
long_post: bool = SchemaField(
description="Force long form post (requires Premium X account)",
default=False,
advanced=True,
)
long_video: bool = SchemaField(
description="Enable long video upload (requires approval and Business/Enterprise plan)",
default=False,
advanced=True,
)
subtitle_url: str = SchemaField(
description="URL to SRT subtitle file for videos (must be HTTPS and end in .srt)",
default="",
advanced=True,
)
subtitle_language: str = SchemaField(
description="Language code for subtitles (default: 'en')",
default="en",
advanced=True,
)
subtitle_name: str = SchemaField(
description="Name of caption track (max 150 chars, default: 'English')",
default="English",
advanced=True,
)
class Output(BlockSchema):
post_result: PostResponse = SchemaField(description="The result of the post")
post: PostIds = SchemaField(description="The individual post IDs of the created post")
def __init__(self):
super().__init__(
id="9e8f844e-b4a5-4b25-80f2-9e1dd7d67625",
description="Post to X / Twitter using Ayrshare",
categories={BlockCategory.SOCIAL},
block_type=BlockType.AYRSHARE,
input_schema=PostToXBlock.Input,
output_schema=PostToXBlock.Output,
)
async def run(
self,
input_data: "PostToXBlock.Input",
*,
user_id: str,
**kwargs,
) -> BlockOutput:
"""Post to X / Twitter with enhanced X-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
return
# Validate X constraints
if not input_data.long_post and len(input_data.post) > 280:
yield "error", f"X post text exceeds 280 character limit ({len(input_data.post)} characters). Enable 'long_post' for Premium accounts."
return
if input_data.long_post and len(input_data.post) > 25000:
yield "error", f"X long post text exceeds 25,000 character limit ({len(input_data.post)} characters)"
return
if len(input_data.media_urls) > 4:
yield "error", "X supports a maximum of 4 images or videos per tweet"
return
# Validate poll options
if input_data.poll_options:
if len(input_data.poll_options) < 2 or len(input_data.poll_options) > 4:
yield "error", "X polls require 2-4 options"
return
if input_data.poll_duration < 1 or input_data.poll_duration > 10080:
yield "error", "X poll duration must be between 1 and 10,080 minutes (7 days)"
return
# Validate alt text
if input_data.alt_text:
for i, alt in enumerate(input_data.alt_text):
if len(alt) > 1000:
yield "error", f"X alt text {i+1} exceeds 1,000 character limit ({len(alt)} characters)"
return
# Validate subtitle settings
if input_data.subtitle_url:
if not input_data.subtitle_url.startswith(
"https://"
) or not input_data.subtitle_url.endswith(".srt"):
yield "error", "Subtitle URL must start with https:// and end with .srt"
return
if len(input_data.subtitle_name) > 150:
yield "error", f"Subtitle name exceeds 150 character limit ({len(input_data.subtitle_name)} characters)"
return
# Convert datetime to ISO format if provided
iso_date = (
input_data.schedule_date.isoformat() if input_data.schedule_date else None
)
# Build X-specific options
twitter_options = {}
# Basic options
if input_data.reply_to_id:
twitter_options["replyToId"] = input_data.reply_to_id
if input_data.quote_tweet_id:
twitter_options["quoteTweetId"] = input_data.quote_tweet_id
if input_data.long_post:
twitter_options["longPost"] = True
if input_data.long_video:
twitter_options["longVideo"] = True
# Poll options
if input_data.poll_options:
twitter_options["poll"] = {
"duration": input_data.poll_duration,
"options": input_data.poll_options,
}
# Alt text for images
if input_data.alt_text:
twitter_options["altText"] = input_data.alt_text
# Thread options
if input_data.is_thread:
twitter_options["thread"] = True
if input_data.thread_number:
twitter_options["threadNumber"] = True
if input_data.thread_media_urls:
twitter_options["mediaUrls"] = input_data.thread_media_urls
# Video subtitle options
if input_data.subtitle_url:
twitter_options["subTitleUrl"] = input_data.subtitle_url
twitter_options["subTitleLanguage"] = input_data.subtitle_language
twitter_options["subTitleName"] = input_data.subtitle_name
response = await client.create_post(
post=input_data.post,
platforms=[SocialPlatform.TWITTER],
media_urls=input_data.media_urls,
is_video=input_data.is_video,
schedule_date=iso_date,
disable_comments=input_data.disable_comments,
shorten_links=input_data.shorten_links,
unsplash=input_data.unsplash,
requires_approval=input_data.requires_approval,
random_post=input_data.random_post,
random_media_url=input_data.random_media_url,
notes=input_data.notes,
twitter_options=twitter_options if twitter_options else None,
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:
for p in response.postIds:
yield "post", p

View File

@@ -1,310 +0,0 @@
from enum import Enum
from typing import Any
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockType,
SchemaField,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
class YouTubeVisibility(str, Enum):
PRIVATE = "private"
PUBLIC = "public"
UNLISTED = "unlisted"
class PostToYouTubeBlock(Block):
"""Block for posting to YouTube with YouTube-specific options."""
class Input(BaseAyrshareInput):
"""Input schema for YouTube posts."""
# Override post field to include YouTube-specific information
post: str = SchemaField(
description="Video description (max 5,000 chars, empty string allowed). Cannot contain < or > characters.",
advanced=False,
)
# Override media_urls to include YouTube-specific constraints
media_urls: list[str] = SchemaField(
description="Required video URL. YouTube only supports 1 video per post.",
default_factory=list,
advanced=False,
)
# YouTube-specific required options
title: str = SchemaField(
description="Video title (max 100 chars, required). Cannot contain < or > characters.",
advanced=False,
)
# YouTube-specific optional options
visibility: YouTubeVisibility = SchemaField(
description="Video visibility: 'private' (default), 'public' , or 'unlisted'",
default=YouTubeVisibility.PRIVATE,
advanced=False,
)
thumbnail: str | None = SchemaField(
description="Thumbnail URL (JPEG/PNG under 2MB, must end in .png/.jpg/.jpeg). Requires phone verification.",
default=None,
advanced=True,
)
playlist_id: str | None = SchemaField(
description="Playlist ID to add video (user must own playlist)",
default=None,
advanced=True,
)
tags: list[str] | None = SchemaField(
description="Video tags (min 2 chars each, max 500 chars total)",
default=None,
advanced=True,
)
made_for_kids: bool | None = SchemaField(
description="Self-declared kids content", default=None, advanced=True
)
is_shorts: bool | None = SchemaField(
description="Post as YouTube Short (max 3 minutes, adds #shorts)",
default=None,
advanced=True,
)
notify_subscribers: bool | None = SchemaField(
description="Send notification to subscribers", default=None, advanced=True
)
category_id: int | None = SchemaField(
description="Video category ID (e.g., 24 = Entertainment)",
default=None,
advanced=True,
)
contains_synthetic_media: bool | None = SchemaField(
description="Disclose realistic AI/synthetic content",
default=None,
advanced=True,
)
publish_at: str | None = SchemaField(
description="UTC publish time (YouTube controlled, format: 2022-10-08T21:18:36Z)",
default=None,
advanced=True,
)
# YouTube targeting options (flattened from YouTubeTargeting object)
targeting_block_countries: list[str] | None = SchemaField(
description="Country codes to block from viewing (e.g., ['US', 'CA'])",
default=None,
advanced=True,
)
targeting_allow_countries: list[str] | None = SchemaField(
description="Country codes to allow viewing (e.g., ['GB', 'AU'])",
default=None,
advanced=True,
)
subtitle_url: str | None = SchemaField(
description="URL to SRT or SBV subtitle file (must be HTTPS and end in .srt/.sbv, under 100MB)",
default=None,
advanced=True,
)
subtitle_language: str | None = SchemaField(
description="Language code for subtitles (default: 'en')",
default=None,
advanced=True,
)
subtitle_name: str | None = SchemaField(
description="Name of caption track (max 150 chars, default: 'English')",
default=None,
advanced=True,
)
class Output(BlockSchema):
post_result: PostResponse = SchemaField(description="The result of the post")
post: PostIds = SchemaField(description="The individual post IDs of the created post")
def __init__(self):
super().__init__(
id="0082d712-ff1b-4c3d-8a8d-6c7721883b83",
description="Post to YouTube using Ayrshare",
categories={BlockCategory.SOCIAL},
block_type=BlockType.AYRSHARE,
input_schema=PostToYouTubeBlock.Input,
output_schema=PostToYouTubeBlock.Output,
)
async def run(
self,
input_data: "PostToYouTubeBlock.Input",
*,
user_id: str,
**kwargs,
) -> BlockOutput:
"""Post to YouTube with YouTube-specific validation and options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
return
# Validate YouTube constraints
if not input_data.title:
yield "error", "YouTube requires a video title"
return
if len(input_data.title) > 100:
yield "error", f"YouTube title exceeds 100 character limit ({len(input_data.title)} characters)"
return
if len(input_data.post) > 5000:
yield "error", f"YouTube description exceeds 5,000 character limit ({len(input_data.post)} characters)"
return
# Check for forbidden characters
forbidden_chars = ["<", ">"]
for char in forbidden_chars:
if char in input_data.title:
yield "error", f"YouTube title cannot contain '{char}' character"
return
if char in input_data.post:
yield "error", f"YouTube description cannot contain '{char}' character"
return
if not input_data.media_urls:
yield "error", "YouTube requires exactly one video URL"
return
if len(input_data.media_urls) > 1:
yield "error", "YouTube supports only 1 video per post"
return
# Validate visibility option
valid_visibility = ["private", "public", "unlisted"]
if input_data.visibility not in valid_visibility:
yield "error", f"YouTube visibility must be one of: {', '.join(valid_visibility)}"
return
# Validate thumbnail URL format
if input_data.thumbnail:
valid_extensions = [".png", ".jpg", ".jpeg"]
if not any(
input_data.thumbnail.lower().endswith(ext) for ext in valid_extensions
):
yield "error", "YouTube thumbnail must end in .png, .jpg, or .jpeg"
return
# Validate tags
if input_data.tags:
total_tag_length = sum(len(tag) for tag in input_data.tags)
if total_tag_length > 500:
yield "error", f"YouTube tags total length exceeds 500 characters ({total_tag_length} characters)"
return
for tag in input_data.tags:
if len(tag) < 2:
yield "error", f"YouTube tag '{tag}' is too short (minimum 2 characters)"
return
# Validate subtitle URL
if input_data.subtitle_url:
if not input_data.subtitle_url.startswith("https://"):
yield "error", "YouTube subtitle URL must start with https://"
return
valid_subtitle_extensions = [".srt", ".sbv"]
if not any(
input_data.subtitle_url.lower().endswith(ext)
for ext in valid_subtitle_extensions
):
yield "error", "YouTube subtitle URL must end in .srt or .sbv"
return
if input_data.subtitle_name and len(input_data.subtitle_name) > 150:
yield "error", f"YouTube subtitle name exceeds 150 character limit ({len(input_data.subtitle_name)} characters)"
return
# Validate publish_at format if provided
if input_data.publish_at and input_data.schedule_date:
yield "error", "Cannot use both 'publish_at' and 'schedule_date'. Use 'publish_at' for YouTube-controlled publishing."
return
# Convert datetime to ISO format if provided (only if not using publish_at)
iso_date = None
if not input_data.publish_at and input_data.schedule_date:
iso_date = input_data.schedule_date.isoformat()
# Build YouTube-specific options
youtube_options: dict[str, Any] = {"title": input_data.title}
# Basic options
if input_data.visibility != "private":
youtube_options["visibility"] = input_data.visibility
if input_data.thumbnail:
youtube_options["thumbNail"] = input_data.thumbnail
if input_data.playlist_id:
youtube_options["playListId"] = input_data.playlist_id
if input_data.tags:
youtube_options["tags"] = input_data.tags
if input_data.made_for_kids:
youtube_options["madeForKids"] = True
if input_data.is_shorts:
youtube_options["shorts"] = True
if not input_data.notify_subscribers:
youtube_options["notifySubscribers"] = False
if input_data.category_id and input_data.category_id > 0:
youtube_options["categoryId"] = input_data.category_id
if input_data.contains_synthetic_media:
youtube_options["containsSyntheticMedia"] = True
if input_data.publish_at:
youtube_options["publishAt"] = input_data.publish_at
# Country targeting (from flattened fields)
targeting_dict = {}
if input_data.targeting_block_countries:
targeting_dict["block"] = input_data.targeting_block_countries
if input_data.targeting_allow_countries:
targeting_dict["allow"] = input_data.targeting_allow_countries
if targeting_dict:
youtube_options["targeting"] = targeting_dict
# Subtitle options
if input_data.subtitle_url:
youtube_options["subTitleUrl"] = input_data.subtitle_url
youtube_options["subTitleLanguage"] = input_data.subtitle_language
youtube_options["subTitleName"] = input_data.subtitle_name
response = await client.create_post(
post=input_data.post,
platforms=[SocialPlatform.YOUTUBE],
media_urls=input_data.media_urls,
is_video=True, # YouTube only supports videos
schedule_date=iso_date,
disable_comments=input_data.disable_comments,
shorten_links=input_data.shorten_links,
unsplash=input_data.unsplash,
requires_approval=input_data.requires_approval,
random_post=input_data.random_post,
random_media_url=input_data.random_media_url,
notes=input_data.notes,
youtube_options=youtube_options,
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:
for p in response.postIds:
yield "post", p

View File

@@ -1,205 +0,0 @@
"""
Meeting BaaS API client module.
All API calls centralized for consistency and maintainability.
"""
from typing import Any, Dict, List, Optional
from backend.sdk import Requests
class MeetingBaasAPI:
"""Client for Meeting BaaS API endpoints."""
BASE_URL = "https://api.meetingbaas.com"
def __init__(self, api_key: str):
"""Initialize API client with authentication key."""
self.api_key = api_key
self.headers = {"x-meeting-baas-api-key": api_key}
self.requests = Requests()
# Bot Management Endpoints
async def join_meeting(
self,
bot_name: str,
meeting_url: str,
reserved: bool = False,
bot_image: Optional[str] = None,
entry_message: Optional[str] = None,
start_time: Optional[int] = None,
speech_to_text: Optional[Dict[str, Any]] = None,
webhook_url: Optional[str] = None,
automatic_leave: Optional[Dict[str, Any]] = None,
extra: Optional[Dict[str, Any]] = None,
recording_mode: str = "speaker_view",
streaming: Optional[Dict[str, Any]] = None,
deduplication_key: Optional[str] = None,
zoom_sdk_id: Optional[str] = None,
zoom_sdk_pwd: Optional[str] = None,
) -> Dict[str, Any]:
"""
Deploy a bot to join and record a meeting.
POST /bots
"""
body = {
"bot_name": bot_name,
"meeting_url": meeting_url,
"reserved": reserved,
"recording_mode": recording_mode,
}
# Add optional fields if provided
if bot_image is not None:
body["bot_image"] = bot_image
if entry_message is not None:
body["entry_message"] = entry_message
if start_time is not None:
body["start_time"] = start_time
if speech_to_text is not None:
body["speech_to_text"] = speech_to_text
if webhook_url is not None:
body["webhook_url"] = webhook_url
if automatic_leave is not None:
body["automatic_leave"] = automatic_leave
if extra is not None:
body["extra"] = extra
if streaming is not None:
body["streaming"] = streaming
if deduplication_key is not None:
body["deduplication_key"] = deduplication_key
if zoom_sdk_id is not None:
body["zoom_sdk_id"] = zoom_sdk_id
if zoom_sdk_pwd is not None:
body["zoom_sdk_pwd"] = zoom_sdk_pwd
response = await self.requests.post(
f"{self.BASE_URL}/bots",
headers=self.headers,
json=body,
)
return response.json()
async def leave_meeting(self, bot_id: str) -> bool:
"""
Remove a bot from an ongoing meeting.
DELETE /bots/{uuid}
"""
response = await self.requests.delete(
f"{self.BASE_URL}/bots/{bot_id}",
headers=self.headers,
)
return response.status in [200, 204]
async def retranscribe(
self,
bot_uuid: str,
speech_to_text: Optional[Dict[str, Any]] = None,
webhook_url: Optional[str] = None,
) -> Dict[str, Any]:
"""
Re-run transcription on a bot's audio.
POST /bots/retranscribe
"""
body: Dict[str, Any] = {"bot_uuid": bot_uuid}
if speech_to_text is not None:
body["speech_to_text"] = speech_to_text
if webhook_url is not None:
body["webhook_url"] = webhook_url
response = await self.requests.post(
f"{self.BASE_URL}/bots/retranscribe",
headers=self.headers,
json=body,
)
if response.status == 202:
return {"accepted": True}
return response.json()
# Data Retrieval Endpoints
async def get_meeting_data(
self, bot_id: str, include_transcripts: bool = True
) -> Dict[str, Any]:
"""
Retrieve meeting data including recording and transcripts.
GET /bots/meeting_data
"""
params = {
"bot_id": bot_id,
"include_transcripts": str(include_transcripts).lower(),
}
response = await self.requests.get(
f"{self.BASE_URL}/bots/meeting_data",
headers=self.headers,
params=params,
)
return response.json()
async def get_screenshots(self, bot_id: str) -> List[Dict[str, Any]]:
"""
Retrieve screenshots captured during a meeting.
GET /bots/{uuid}/screenshots
"""
response = await self.requests.get(
f"{self.BASE_URL}/bots/{bot_id}/screenshots",
headers=self.headers,
)
result = response.json()
# Ensure we return a list
if isinstance(result, list):
return result
return []
async def delete_data(self, bot_id: str) -> bool:
"""
Delete a bot's recorded data.
POST /bots/{uuid}/delete_data
"""
response = await self.requests.post(
f"{self.BASE_URL}/bots/{bot_id}/delete_data",
headers=self.headers,
)
return response.status == 200
async def list_bots_with_metadata(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
sort_by: Optional[str] = None,
sort_order: Optional[str] = None,
filter_by: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""
List bots with metadata including IDs, names, and meeting details.
GET /bots/bots_with_metadata
"""
params = {}
if limit is not None:
params["limit"] = limit
if offset is not None:
params["offset"] = offset
if sort_by is not None:
params["sort_by"] = sort_by
if sort_order is not None:
params["sort_order"] = sort_order
if filter_by is not None:
params.update(filter_by)
response = await self.requests.get(
f"{self.BASE_URL}/bots/bots_with_metadata",
headers=self.headers,
params=params,
)
return response.json()
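# Minimal usage sketch of this client outside the block layer. The API key, meeting URL,
# and bot name are hypothetical placeholders; only methods defined above are used.
import asyncio

async def record_and_fetch():
    api = MeetingBaasAPI(api_key="mb_test_key")
    joined = await api.join_meeting(
        bot_name="Notetaker",
        meeting_url="https://meet.example.com/abc-defg-hij",
        entry_message="Recording this meeting.",
    )
    bot_id = joined.get("bot_id", "")
    # ... wait for the meeting to end, e.g. via a webhook_url callback ...
    data = await api.get_meeting_data(bot_id=bot_id, include_transcripts=True)
    return data.get("mp4", ""), data.get("bot_data", {}).get("transcripts", [])

# asyncio.run(record_and_fetch())  # requires a valid MEETING_BAAS_API_KEY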

View File

@@ -1,13 +0,0 @@
"""
Shared configuration for all Meeting BaaS blocks using the SDK pattern.
"""
from backend.sdk import BlockCostType, ProviderBuilder
# Configure the Meeting BaaS provider with API key authentication
baas = (
ProviderBuilder("baas")
.with_api_key("MEETING_BAAS_API_KEY", "Meeting BaaS API Key")
.with_base_cost(5, BlockCostType.RUN) # Higher cost for meeting recording service
.build()
)

View File

@@ -1,217 +0,0 @@
"""
Meeting BaaS bot (recording) blocks.
"""
from typing import Optional
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
SchemaField,
)
from ._api import MeetingBaasAPI
from ._config import baas
class BaasBotJoinMeetingBlock(Block):
"""
Deploy a bot immediately or at a scheduled start_time to join and record a meeting.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = baas.credentials_field(
description="Meeting BaaS API credentials"
)
meeting_url: str = SchemaField(
description="The URL of the meeting the bot should join"
)
bot_name: str = SchemaField(
description="Display name for the bot in the meeting"
)
bot_image: str = SchemaField(
description="URL to an image for the bot's avatar (16:9 ratio recommended)",
default="",
)
entry_message: str = SchemaField(
description="Chat message the bot will post upon entry", default=""
)
reserved: bool = SchemaField(
description="Use a reserved bot slot (joins 4 min before meeting)",
default=False,
)
start_time: Optional[int] = SchemaField(
description="Unix timestamp (ms) when bot should join", default=None
)
webhook_url: str | None = SchemaField(
description="URL to receive webhook events for this bot", default=None
)
timeouts: dict = SchemaField(
description="Automatic leave timeouts configuration", default={}
)
extra: dict = SchemaField(
description="Custom metadata to attach to the bot", default={}
)
class Output(BlockSchema):
bot_id: str = SchemaField(description="UUID of the deployed bot")
join_response: dict = SchemaField(
description="Full response from join operation"
)
def __init__(self):
super().__init__(
id="377d1a6a-a99b-46cf-9af3-1d1b12758e04",
description="Deploy a bot to join and record a meeting",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
api_key = credentials.api_key.get_secret_value()
api = MeetingBaasAPI(api_key)
# Call API with all parameters
data = await api.join_meeting(
bot_name=input_data.bot_name,
meeting_url=input_data.meeting_url,
reserved=input_data.reserved,
bot_image=input_data.bot_image if input_data.bot_image else None,
entry_message=(
input_data.entry_message if input_data.entry_message else None
),
start_time=input_data.start_time,
speech_to_text={"provider": "Default"},
webhook_url=input_data.webhook_url if input_data.webhook_url else None,
automatic_leave=input_data.timeouts if input_data.timeouts else None,
extra=input_data.extra if input_data.extra else None,
)
yield "bot_id", data.get("bot_id", "")
yield "join_response", data
class BaasBotLeaveMeetingBlock(Block):
"""
Force the bot to exit the call.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = baas.credentials_field(
description="Meeting BaaS API credentials"
)
bot_id: str = SchemaField(description="UUID of the bot to remove from meeting")
class Output(BlockSchema):
left: bool = SchemaField(description="Whether the bot successfully left")
def __init__(self):
super().__init__(
id="bf77d128-8b25-4280-b5c7-2d553ba7e482",
description="Remove a bot from an ongoing meeting",
categories={BlockCategory.COMMUNICATION},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
api_key = credentials.api_key.get_secret_value()
api = MeetingBaasAPI(api_key)
# Leave meeting
left = await api.leave_meeting(input_data.bot_id)
yield "left", left
class BaasBotFetchMeetingDataBlock(Block):
"""
Pull MP4 URL, transcript & metadata for a completed meeting.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = baas.credentials_field(
description="Meeting BaaS API credentials"
)
bot_id: str = SchemaField(description="UUID of the bot whose data to fetch")
include_transcripts: bool = SchemaField(
description="Include transcript data in response", default=True
)
class Output(BlockSchema):
mp4_url: str = SchemaField(
description="URL to download the meeting recording (time-limited)"
)
transcript: list = SchemaField(description="Meeting transcript data")
metadata: dict = SchemaField(description="Meeting metadata and bot information")
def __init__(self):
super().__init__(
id="ea7c1309-303c-4da1-893f-89c0e9d64e78",
description="Retrieve recorded meeting data",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
api_key = credentials.api_key.get_secret_value()
api = MeetingBaasAPI(api_key)
# Fetch meeting data
data = await api.get_meeting_data(
bot_id=input_data.bot_id,
include_transcripts=input_data.include_transcripts,
)
yield "mp4_url", data.get("mp4", "")
yield "transcript", data.get("bot_data", {}).get("transcripts", [])
yield "metadata", data.get("bot_data", {}).get("bot", {})
class BaasBotDeleteRecordingBlock(Block):
"""
Purge MP4 + transcript data for privacy or storage management.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = baas.credentials_field(
description="Meeting BaaS API credentials"
)
bot_id: str = SchemaField(description="UUID of the bot whose data to delete")
class Output(BlockSchema):
deleted: bool = SchemaField(
description="Whether the data was successfully deleted"
)
def __init__(self):
super().__init__(
id="bf8d1aa6-42d8-4944-b6bd-6bac554c0d3b",
description="Permanently delete a meeting's recorded data",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
api_key = credentials.api_key.get_secret_value()
api = MeetingBaasAPI(api_key)
# Delete recording data
deleted = await api.delete_data(input_data.bot_id)
yield "deleted", deleted

View File

@@ -1,9 +1,11 @@
import enum
from typing import Any
from typing import Any, List
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
from backend.data.model import SchemaField
from backend.util import json
from backend.util.file import store_media_file
from backend.util.mock import MockObject
from backend.util.type import MediaFileType, convert
@@ -12,12 +14,6 @@ class FileStoreBlock(Block):
file_in: MediaFileType = SchemaField(
description="The file to store in the temporary directory, it can be a URL, data URI, or local path."
)
base_64: bool = SchemaField(
description="Whether produce an output in base64 format (not recommended, you can pass the string path just fine accross blocks).",
default=False,
advanced=True,
title="Produce Base64 Output",
)
class Output(BlockSchema):
file_out: MediaFileType = SchemaField(
@@ -39,15 +35,14 @@ class FileStoreBlock(Block):
input_data: Input,
*,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
yield "file_out", await store_media_file(
file_path = await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.file_in,
user_id=user_id,
return_content=input_data.base_64,
return_content=False,
)
yield "file_out", file_path
class StoreValueBlock(Block):
@@ -120,6 +115,266 @@ class PrintToConsoleBlock(Block):
yield "status", "printed"
class FindInDictionaryBlock(Block):
class Input(BlockSchema):
input: Any = SchemaField(description="Dictionary to lookup from")
key: str | int = SchemaField(description="Key to lookup in the dictionary")
class Output(BlockSchema):
output: Any = SchemaField(description="Value found for the given key")
missing: Any = SchemaField(
description="Value of the input that missing the key"
)
def __init__(self):
super().__init__(
id="0e50422c-6dee-4145-83d6-3a5a392f65de",
description="Lookup the given key in the input dictionary/object/list and return the value.",
input_schema=FindInDictionaryBlock.Input,
output_schema=FindInDictionaryBlock.Output,
test_input=[
{"input": {"apple": 1, "banana": 2, "cherry": 3}, "key": "banana"},
{"input": {"x": 10, "y": 20, "z": 30}, "key": "w"},
{"input": [1, 2, 3], "key": 1},
{"input": [1, 2, 3], "key": 3},
{"input": MockObject(value="!!", key="key"), "key": "key"},
{"input": [{"k1": "v1"}, {"k2": "v2"}, {"k1": "v3"}], "key": "k1"},
],
test_output=[
("output", 2),
("missing", {"x": 10, "y": 20, "z": 30}),
("output", 2),
("missing", [1, 2, 3]),
("output", "key"),
("output", ["v1", "v3"]),
],
categories={BlockCategory.BASIC},
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
obj = input_data.input
key = input_data.key
if isinstance(obj, str):
obj = json.loads(obj)
if isinstance(obj, dict) and key in obj:
yield "output", obj[key]
elif isinstance(obj, list) and isinstance(key, int) and 0 <= key < len(obj):
yield "output", obj[key]
elif isinstance(obj, list) and isinstance(key, str):
if len(obj) == 0:
yield "output", []
elif isinstance(obj[0], dict) and key in obj[0]:
yield "output", [item[key] for item in obj if key in item]
else:
yield "output", [getattr(val, key) for val in obj if hasattr(val, key)]
elif isinstance(obj, object) and isinstance(key, str) and hasattr(obj, key):
yield "output", getattr(obj, key)
else:
yield "missing", input_data.input
class AddToDictionaryBlock(Block):
class Input(BlockSchema):
dictionary: dict[Any, Any] = SchemaField(
default_factory=dict,
description="The dictionary to add the entry to. If not provided, a new dictionary will be created.",
)
key: str = SchemaField(
default="",
description="The key for the new entry.",
placeholder="new_key",
advanced=False,
)
value: Any = SchemaField(
default=None,
description="The value for the new entry.",
placeholder="new_value",
advanced=False,
)
entries: dict[Any, Any] = SchemaField(
default_factory=dict,
description="The entries to add to the dictionary. This is the batch version of the `key` and `value` fields.",
advanced=True,
)
class Output(BlockSchema):
updated_dictionary: dict = SchemaField(
description="The dictionary with the new entry added."
)
error: str = SchemaField(description="Error message if the operation failed.")
def __init__(self):
super().__init__(
id="31d1064e-7446-4693-a7d4-65e5ca1180d1",
description="Adds a new key-value pair to a dictionary. If no dictionary is provided, a new one is created.",
categories={BlockCategory.BASIC},
input_schema=AddToDictionaryBlock.Input,
output_schema=AddToDictionaryBlock.Output,
test_input=[
{
"dictionary": {"existing_key": "existing_value"},
"key": "new_key",
"value": "new_value",
},
{"key": "first_key", "value": "first_value"},
{
"dictionary": {"existing_key": "existing_value"},
"entries": {"new_key": "new_value", "first_key": "first_value"},
},
],
test_output=[
(
"updated_dictionary",
{"existing_key": "existing_value", "new_key": "new_value"},
),
("updated_dictionary", {"first_key": "first_value"}),
(
"updated_dictionary",
{
"existing_key": "existing_value",
"new_key": "new_value",
"first_key": "first_value",
},
),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
updated_dict = input_data.dictionary.copy()
if input_data.value is not None and input_data.key:
updated_dict[input_data.key] = input_data.value
for key, value in input_data.entries.items():
updated_dict[key] = value
yield "updated_dictionary", updated_dict
class AddToListBlock(Block):
class Input(BlockSchema):
list: List[Any] = SchemaField(
default_factory=list,
advanced=False,
description="The list to add the entry to. If not provided, a new list will be created.",
)
entry: Any = SchemaField(
description="The entry to add to the list. Can be of any type (string, int, dict, etc.).",
advanced=False,
default=None,
)
entries: List[Any] = SchemaField(
default_factory=lambda: list(),
description="The entries to add to the list. This is the batch version of the `entry` field.",
advanced=True,
)
position: int | None = SchemaField(
default=None,
description="The position to insert the new entry. If not provided, the entry will be appended to the end of the list.",
)
class Output(BlockSchema):
updated_list: List[Any] = SchemaField(
description="The list with the new entry added."
)
error: str = SchemaField(description="Error message if the operation failed.")
def __init__(self):
super().__init__(
id="aeb08fc1-2fc1-4141-bc8e-f758f183a822",
description="Adds a new entry to a list. The entry can be of any type. If no list is provided, a new one is created.",
categories={BlockCategory.BASIC},
input_schema=AddToListBlock.Input,
output_schema=AddToListBlock.Output,
test_input=[
{
"list": [1, "string", {"existing_key": "existing_value"}],
"entry": {"new_key": "new_value"},
"position": 1,
},
{"entry": "first_entry"},
{"list": ["a", "b", "c"], "entry": "d"},
{
"entry": "e",
"entries": ["f", "g"],
"list": ["a", "b"],
"position": 1,
},
],
test_output=[
(
"updated_list",
[
1,
{"new_key": "new_value"},
"string",
{"existing_key": "existing_value"},
],
),
("updated_list", ["first_entry"]),
("updated_list", ["a", "b", "c", "d"]),
("updated_list", ["a", "f", "g", "e", "b"]),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
entries_added = input_data.entries.copy()
if input_data.entry is not None:  # allow falsy entries such as 0, "", or False
entries_added.append(input_data.entry)
updated_list = input_data.list.copy()
if (pos := input_data.position) is not None:
updated_list = updated_list[:pos] + entries_added + updated_list[pos:]
else:
updated_list += entries_added
yield "updated_list", updated_list
class FindInListBlock(Block):
class Input(BlockSchema):
list: List[Any] = SchemaField(description="The list to search in.")
value: Any = SchemaField(description="The value to search for.")
class Output(BlockSchema):
index: int = SchemaField(description="The index of the value in the list.")
found: bool = SchemaField(
description="Whether the value was found in the list."
)
not_found_value: Any = SchemaField(
description="The value that was not found in the list."
)
def __init__(self):
super().__init__(
id="5e2c6d0a-1e37-489f-b1d0-8e1812b23333",
description="Finds the index of the value in the list.",
categories={BlockCategory.BASIC},
input_schema=FindInListBlock.Input,
output_schema=FindInListBlock.Output,
test_input=[
{"list": [1, 2, 3, 4, 5], "value": 3},
{"list": [1, 2, 3, 4, 5], "value": 6},
],
test_output=[
("index", 2),
("found", True),
("found", False),
("not_found_value", 6),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
try:
yield "index", input_data.list.index(input_data.value)
yield "found", True
except ValueError:
yield "found", False
yield "not_found_value", input_data.value
class NoteBlock(Block):
class Input(BlockSchema):
text: str = SchemaField(description="The text to display in the sticky note.")
@@ -145,6 +400,104 @@ class NoteBlock(Block):
yield "output", input_data.text
class CreateDictionaryBlock(Block):
class Input(BlockSchema):
values: dict[str, Any] = SchemaField(
description="Key-value pairs to create the dictionary with",
placeholder="e.g., {'name': 'Alice', 'age': 25}",
)
class Output(BlockSchema):
dictionary: dict[str, Any] = SchemaField(
description="The created dictionary containing the specified key-value pairs"
)
error: str = SchemaField(
description="Error message if dictionary creation failed"
)
def __init__(self):
super().__init__(
id="b924ddf4-de4f-4b56-9a85-358930dcbc91",
description="Creates a dictionary with the specified key-value pairs. Use this when you know all the values you want to add upfront.",
categories={BlockCategory.DATA},
input_schema=CreateDictionaryBlock.Input,
output_schema=CreateDictionaryBlock.Output,
test_input=[
{
"values": {"name": "Alice", "age": 25, "city": "New York"},
},
{
"values": {"numbers": [1, 2, 3], "active": True, "score": 95.5},
},
],
test_output=[
(
"dictionary",
{"name": "Alice", "age": 25, "city": "New York"},
),
(
"dictionary",
{"numbers": [1, 2, 3], "active": True, "score": 95.5},
),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
try:
# The values are already validated by Pydantic schema
yield "dictionary", input_data.values
except Exception as e:
yield "error", f"Failed to create dictionary: {str(e)}"
class CreateListBlock(Block):
class Input(BlockSchema):
values: List[Any] = SchemaField(
description="A list of values to be combined into a new list.",
placeholder="e.g., ['Alice', 25, True]",
)
class Output(BlockSchema):
list: List[Any] = SchemaField(
description="The created list containing the specified values."
)
error: str = SchemaField(description="Error message if list creation failed.")
def __init__(self):
super().__init__(
id="a912d5c7-6e00-4542-b2a9-8034136930e4",
description="Creates a list with the specified values. Use this when you know all the values you want to add upfront.",
categories={BlockCategory.DATA},
input_schema=CreateListBlock.Input,
output_schema=CreateListBlock.Output,
test_input=[
{
"values": ["Alice", 25, True],
},
{
"values": [1, 2, 3, "four", {"key": "value"}],
},
],
test_output=[
(
"list",
["Alice", 25, True],
),
(
"list",
[1, 2, 3, "four", {"key": "value"}],
),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
try:
# The values are already validated by Pydantic schema
yield "list", input_data.values
except Exception as e:
yield "error", f"Failed to create list: {str(e)}"
class TypeOptions(enum.Enum):
STRING = "string"
NUMBER = "number"
@@ -162,7 +515,6 @@ class UniversalTypeConverterBlock(Block):
class Output(BlockSchema):
value: Any = SchemaField(description="The converted value.")
error: str = SchemaField(description="Error message if conversion failed.")
def __init__(self):
super().__init__(
@@ -188,31 +540,3 @@ class UniversalTypeConverterBlock(Block):
yield "value", converted_value
except Exception as e:
yield "error", f"Failed to convert value: {str(e)}"
class ReverseListOrderBlock(Block):
"""
A block which takes in a list and returns it in the opposite order.
"""
class Input(BlockSchema):
input_list: list[Any] = SchemaField(description="The list to reverse")
class Output(BlockSchema):
reversed_list: list[Any] = SchemaField(description="The list in reversed order")
def __init__(self):
super().__init__(
id="422cb708-3109-4277-bfe3-bc2ae5812777",
description="Reverses the order of elements in a list",
categories={BlockCategory.BASIC},
input_schema=ReverseListOrderBlock.Input,
output_schema=ReverseListOrderBlock.Output,
test_input={"input_list": [1, 2, 3, 4, 5]},
test_output=[("reversed_list", [5, 4, 3, 2, 1])],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
reversed_list = list(input_data.input_list)
reversed_list.reverse()
yield "reversed_list", reversed_list

View File

@@ -3,7 +3,6 @@ from typing import Any
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.type import convert
class ComparisonOperator(Enum):
@@ -164,7 +163,7 @@ class IfInputMatchesBlock(Block):
},
{
"input": 10,
"value": "None",
"value": None,
"yes_value": "Yes",
"no_value": "No",
},
@@ -182,23 +181,7 @@ class IfInputMatchesBlock(Block):
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
# If input_data.value is not matching input_data.input, convert value to type of input
if (
input_data.input != input_data.value
and input_data.input is not input_data.value
):
try:
# Only attempt conversion if input is not None and value is not None
if input_data.input is not None and input_data.value is not None:
input_type = type(input_data.input)
# Avoid converting if input_type is Any or object
if input_type not in (Any, object):
input_data.value = convert(input_data.value, input_type)
except Exception:
pass # If conversion fails, just leave value as is
if input_data.input == input_data.value:
if input_data.input == input_data.value or input_data.input is input_data.value:
yield "result", True
yield "yes_output", input_data.yes_value
else:

View File

@@ -0,0 +1,109 @@
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import ContributorDetails, SchemaField
class ReadCsvBlock(Block):
class Input(BlockSchema):
contents: str = SchemaField(
description="The contents of the CSV file to read",
placeholder="a, b, c\n1,2,3\n4,5,6",
)
delimiter: str = SchemaField(
description="The delimiter used in the CSV file",
default=",",
)
quotechar: str = SchemaField(
description="The character used to quote fields",
default='"',
)
escapechar: str = SchemaField(
description="The character used to escape the delimiter",
default="\\",
)
has_header: bool = SchemaField(
description="Whether the CSV file has a header row",
default=True,
)
skip_rows: int = SchemaField(
description="The number of rows to skip from the start of the file",
default=0,
)
strip: bool = SchemaField(
description="Whether to strip whitespace from the values",
default=True,
)
skip_columns: list[str] = SchemaField(
description="The columns to skip from the start of the row",
default_factory=list,
)
class Output(BlockSchema):
row: dict[str, str] = SchemaField(
description="The data produced from each row in the CSV file"
)
all_data: list[dict[str, str]] = SchemaField(
description="All the data in the CSV file as a list of rows"
)
def __init__(self):
super().__init__(
id="acf7625e-d2cb-4941-bfeb-2819fc6fc015",
input_schema=ReadCsvBlock.Input,
output_schema=ReadCsvBlock.Output,
description="Reads a CSV file and outputs the data as a list of dictionaries and individual rows via rows.",
contributors=[ContributorDetails(name="Nicholas Tindle")],
categories={BlockCategory.TEXT, BlockCategory.DATA},
test_input={
"contents": "a, b, c\n1,2,3\n4,5,6",
},
test_output=[
("row", {"a": "1", "b": "2", "c": "3"}),
("row", {"a": "4", "b": "5", "c": "6"}),
(
"all_data",
[
{"a": "1", "b": "2", "c": "3"},
{"a": "4", "b": "5", "c": "6"},
],
),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
import csv
from io import StringIO
csv_file = StringIO(input_data.contents)
reader = csv.reader(
csv_file,
delimiter=input_data.delimiter,
quotechar=input_data.quotechar,
escapechar=input_data.escapechar,
)
header = None
if input_data.has_header:
header = next(reader)
if input_data.strip:
header = [h.strip() for h in header]
for _ in range(input_data.skip_rows):
next(reader)
def process_row(row):
    data = {}
    for i, value in enumerate(row):
        # Resolve the column name: header name when available, otherwise the string index
        column_name = header[i] if input_data.has_header and header else str(i)
        if column_name in input_data.skip_columns:
            continue
        data[column_name] = value.strip() if input_data.strip else value
    return data
all_data = []
for row in reader:
processed_row = process_row(row)
all_data.append(processed_row)
yield "row", processed_row
yield "all_data", all_data

View File

@@ -1,683 +0,0 @@
from typing import Any, List
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.json import loads
from backend.util.mock import MockObject
from backend.util.prompt import estimate_token_count_str
# =============================================================================
# Dictionary Manipulation Blocks
# =============================================================================
class CreateDictionaryBlock(Block):
class Input(BlockSchema):
values: dict[str, Any] = SchemaField(
description="Key-value pairs to create the dictionary with",
placeholder="e.g., {'name': 'Alice', 'age': 25}",
)
class Output(BlockSchema):
dictionary: dict[str, Any] = SchemaField(
description="The created dictionary containing the specified key-value pairs"
)
error: str = SchemaField(
description="Error message if dictionary creation failed"
)
def __init__(self):
super().__init__(
id="b924ddf4-de4f-4b56-9a85-358930dcbc91",
description="Creates a dictionary with the specified key-value pairs. Use this when you know all the values you want to add upfront.",
categories={BlockCategory.DATA},
input_schema=CreateDictionaryBlock.Input,
output_schema=CreateDictionaryBlock.Output,
test_input=[
{
"values": {"name": "Alice", "age": 25, "city": "New York"},
},
{
"values": {"numbers": [1, 2, 3], "active": True, "score": 95.5},
},
],
test_output=[
(
"dictionary",
{"name": "Alice", "age": 25, "city": "New York"},
),
(
"dictionary",
{"numbers": [1, 2, 3], "active": True, "score": 95.5},
),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
try:
# The values are already validated by Pydantic schema
yield "dictionary", input_data.values
except Exception as e:
yield "error", f"Failed to create dictionary: {str(e)}"
class AddToDictionaryBlock(Block):
class Input(BlockSchema):
dictionary: dict[Any, Any] = SchemaField(
default_factory=dict,
description="The dictionary to add the entry to. If not provided, a new dictionary will be created.",
)
key: str = SchemaField(
default="",
description="The key for the new entry.",
placeholder="new_key",
advanced=False,
)
value: Any = SchemaField(
default=None,
description="The value for the new entry.",
placeholder="new_value",
advanced=False,
)
entries: dict[Any, Any] = SchemaField(
default_factory=dict,
description="The entries to add to the dictionary. This is the batch version of the `key` and `value` fields.",
advanced=True,
)
class Output(BlockSchema):
updated_dictionary: dict = SchemaField(
description="The dictionary with the new entry added."
)
error: str = SchemaField(description="Error message if the operation failed.")
def __init__(self):
super().__init__(
id="31d1064e-7446-4693-a7d4-65e5ca1180d1",
description="Adds a new key-value pair to a dictionary. If no dictionary is provided, a new one is created.",
categories={BlockCategory.BASIC},
input_schema=AddToDictionaryBlock.Input,
output_schema=AddToDictionaryBlock.Output,
test_input=[
{
"dictionary": {"existing_key": "existing_value"},
"key": "new_key",
"value": "new_value",
},
{"key": "first_key", "value": "first_value"},
{
"dictionary": {"existing_key": "existing_value"},
"entries": {"new_key": "new_value", "first_key": "first_value"},
},
],
test_output=[
(
"updated_dictionary",
{"existing_key": "existing_value", "new_key": "new_value"},
),
("updated_dictionary", {"first_key": "first_value"}),
(
"updated_dictionary",
{
"existing_key": "existing_value",
"new_key": "new_value",
"first_key": "first_value",
},
),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
updated_dict = input_data.dictionary.copy()
if input_data.value is not None and input_data.key:
updated_dict[input_data.key] = input_data.value
for key, value in input_data.entries.items():
updated_dict[key] = value
yield "updated_dictionary", updated_dict
class FindInDictionaryBlock(Block):
class Input(BlockSchema):
input: Any = SchemaField(description="Dictionary to lookup from")
key: str | int = SchemaField(description="Key to lookup in the dictionary")
class Output(BlockSchema):
output: Any = SchemaField(description="Value found for the given key")
missing: Any = SchemaField(
description="Value of the input that missing the key"
)
def __init__(self):
super().__init__(
id="0e50422c-6dee-4145-83d6-3a5a392f65de",
description="Lookup the given key in the input dictionary/object/list and return the value.",
input_schema=FindInDictionaryBlock.Input,
output_schema=FindInDictionaryBlock.Output,
test_input=[
{"input": {"apple": 1, "banana": 2, "cherry": 3}, "key": "banana"},
{"input": {"x": 10, "y": 20, "z": 30}, "key": "w"},
{"input": [1, 2, 3], "key": 1},
{"input": [1, 2, 3], "key": 3},
{"input": MockObject(value="!!", key="key"), "key": "key"},
{"input": [{"k1": "v1"}, {"k2": "v2"}, {"k1": "v3"}], "key": "k1"},
],
test_output=[
("output", 2),
("missing", {"x": 10, "y": 20, "z": 30}),
("output", 2),
("missing", [1, 2, 3]),
("output", "key"),
("output", ["v1", "v3"]),
],
categories={BlockCategory.BASIC},
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
obj = input_data.input
key = input_data.key
if isinstance(obj, str):
obj = loads(obj)
if isinstance(obj, dict) and key in obj:
yield "output", obj[key]
elif isinstance(obj, list) and isinstance(key, int) and 0 <= key < len(obj):
yield "output", obj[key]
elif isinstance(obj, list) and isinstance(key, str):
if len(obj) == 0:
yield "output", []
elif isinstance(obj[0], dict) and key in obj[0]:
yield "output", [item[key] for item in obj if key in item]
else:
yield "output", [getattr(val, key) for val in obj if hasattr(val, key)]
elif isinstance(obj, object) and isinstance(key, str) and hasattr(obj, key):
yield "output", getattr(obj, key)
else:
yield "missing", input_data.input
class RemoveFromDictionaryBlock(Block):
class Input(BlockSchema):
dictionary: dict[Any, Any] = SchemaField(
description="The dictionary to modify."
)
key: str | int = SchemaField(description="Key to remove from the dictionary.")
return_value: bool = SchemaField(
default=False, description="Whether to return the removed value."
)
class Output(BlockSchema):
updated_dictionary: dict[Any, Any] = SchemaField(
description="The dictionary after removal."
)
removed_value: Any = SchemaField(description="The removed value if requested.")
error: str = SchemaField(description="Error message if the operation failed.")
def __init__(self):
super().__init__(
id="46afe2ea-c613-43f8-95ff-6692c3ef6876",
description="Removes a key-value pair from a dictionary.",
categories={BlockCategory.BASIC},
input_schema=RemoveFromDictionaryBlock.Input,
output_schema=RemoveFromDictionaryBlock.Output,
test_input=[
{
"dictionary": {"a": 1, "b": 2, "c": 3},
"key": "b",
"return_value": True,
},
{"dictionary": {"x": "hello", "y": "world"}, "key": "x"},
],
test_output=[
("updated_dictionary", {"a": 1, "c": 3}),
("removed_value", 2),
("updated_dictionary", {"y": "world"}),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
updated_dict = input_data.dictionary.copy()
try:
removed_value = updated_dict.pop(input_data.key)
yield "updated_dictionary", updated_dict
if input_data.return_value:
yield "removed_value", removed_value
except KeyError:
yield "error", f"Key '{input_data.key}' not found in dictionary"
class ReplaceDictionaryValueBlock(Block):
class Input(BlockSchema):
dictionary: dict[Any, Any] = SchemaField(
description="The dictionary to modify."
)
key: str | int = SchemaField(description="Key to replace the value for.")
value: Any = SchemaField(description="The new value for the given key.")
class Output(BlockSchema):
updated_dictionary: dict[Any, Any] = SchemaField(
description="The dictionary after replacement."
)
old_value: Any = SchemaField(description="The value that was replaced.")
error: str = SchemaField(description="Error message if the operation failed.")
def __init__(self):
super().__init__(
id="27e31876-18b6-44f3-ab97-f6226d8b3889",
description="Replaces the value for a specified key in a dictionary.",
categories={BlockCategory.BASIC},
input_schema=ReplaceDictionaryValueBlock.Input,
output_schema=ReplaceDictionaryValueBlock.Output,
test_input=[
{"dictionary": {"a": 1, "b": 2, "c": 3}, "key": "b", "value": 99},
{
"dictionary": {"x": "hello", "y": "world"},
"key": "y",
"value": "universe",
},
],
test_output=[
("updated_dictionary", {"a": 1, "b": 99, "c": 3}),
("old_value", 2),
("updated_dictionary", {"x": "hello", "y": "universe"}),
("old_value", "world"),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
updated_dict = input_data.dictionary.copy()
try:
old_value = updated_dict[input_data.key]
updated_dict[input_data.key] = input_data.value
yield "updated_dictionary", updated_dict
yield "old_value", old_value
except KeyError:
yield "error", f"Key '{input_data.key}' not found in dictionary"
class DictionaryIsEmptyBlock(Block):
class Input(BlockSchema):
dictionary: dict[Any, Any] = SchemaField(description="The dictionary to check.")
class Output(BlockSchema):
is_empty: bool = SchemaField(description="True if the dictionary is empty.")
def __init__(self):
super().__init__(
id="a3cf3f64-6bb9-4cc6-9900-608a0b3359b0",
description="Checks if a dictionary is empty.",
categories={BlockCategory.BASIC},
input_schema=DictionaryIsEmptyBlock.Input,
output_schema=DictionaryIsEmptyBlock.Output,
test_input=[{"dictionary": {}}, {"dictionary": {"a": 1}}],
test_output=[("is_empty", True), ("is_empty", False)],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
yield "is_empty", len(input_data.dictionary) == 0
# =============================================================================
# List Manipulation Blocks
# =============================================================================
class CreateListBlock(Block):
class Input(BlockSchema):
values: List[Any] = SchemaField(
description="A list of values to be combined into a new list.",
placeholder="e.g., ['Alice', 25, True]",
)
max_size: int | None = SchemaField(
default=None,
description="Maximum size of the list. If provided, the list will be yielded in chunks of this size.",
advanced=True,
)
max_tokens: int | None = SchemaField(
default=None,
description="Maximum tokens for the list. If provided, the list will be yielded in chunks that fit within this token limit.",
advanced=True,
)
class Output(BlockSchema):
list: List[Any] = SchemaField(
description="The created list containing the specified values."
)
error: str = SchemaField(description="Error message if list creation failed.")
def __init__(self):
super().__init__(
id="a912d5c7-6e00-4542-b2a9-8034136930e4",
description="Creates a list with the specified values. Use this when you know all the values you want to add upfront. This block can also yield the list in batches based on a maximum size or token limit.",
categories={BlockCategory.DATA},
input_schema=CreateListBlock.Input,
output_schema=CreateListBlock.Output,
test_input=[
{
"values": ["Alice", 25, True],
},
{
"values": [1, 2, 3, "four", {"key": "value"}],
},
],
test_output=[
(
"list",
["Alice", 25, True],
),
(
"list",
[1, 2, 3, "four", {"key": "value"}],
),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
chunk = []
cur_tokens, max_tokens = 0, input_data.max_tokens
cur_size, max_size = 0, input_data.max_size
for value in input_data.values:
if max_tokens:
tokens = estimate_token_count_str(value)
else:
tokens = 0
# Check if adding this value would exceed either limit
if (max_tokens and (cur_tokens + tokens > max_tokens)) or (
max_size and (cur_size + 1 > max_size)
):
yield "list", chunk
chunk = [value]
cur_size, cur_tokens = 1, tokens
else:
chunk.append(value)
cur_size, cur_tokens = cur_size + 1, cur_tokens + tokens
# Yield final chunk if any
if chunk or not input_data.values:
yield "list", chunk
class AddToListBlock(Block):
class Input(BlockSchema):
list: List[Any] = SchemaField(
default_factory=list,
advanced=False,
description="The list to add the entry to. If not provided, a new list will be created.",
)
entry: Any = SchemaField(
description="The entry to add to the list. Can be of any type (string, int, dict, etc.).",
advanced=False,
default=None,
)
entries: List[Any] = SchemaField(
default_factory=lambda: list(),
description="The entries to add to the list. This is the batch version of the `entry` field.",
advanced=True,
)
position: int | None = SchemaField(
default=None,
description="The position to insert the new entry. If not provided, the entry will be appended to the end of the list.",
)
class Output(BlockSchema):
updated_list: List[Any] = SchemaField(
description="The list with the new entry added."
)
error: str = SchemaField(description="Error message if the operation failed.")
def __init__(self):
super().__init__(
id="aeb08fc1-2fc1-4141-bc8e-f758f183a822",
description="Adds a new entry to a list. The entry can be of any type. If no list is provided, a new one is created.",
categories={BlockCategory.BASIC},
input_schema=AddToListBlock.Input,
output_schema=AddToListBlock.Output,
test_input=[
{
"list": [1, "string", {"existing_key": "existing_value"}],
"entry": {"new_key": "new_value"},
"position": 1,
},
{"entry": "first_entry"},
{"list": ["a", "b", "c"], "entry": "d"},
{
"entry": "e",
"entries": ["f", "g"],
"list": ["a", "b"],
"position": 1,
},
],
test_output=[
(
"updated_list",
[
1,
{"new_key": "new_value"},
"string",
{"existing_key": "existing_value"},
],
),
("updated_list", ["first_entry"]),
("updated_list", ["a", "b", "c", "d"]),
("updated_list", ["a", "f", "g", "e", "b"]),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
entries_added = input_data.entries.copy()
if input_data.entry:
entries_added.append(input_data.entry)
updated_list = input_data.list.copy()
if (pos := input_data.position) is not None:
updated_list = updated_list[:pos] + entries_added + updated_list[pos:]
else:
updated_list += entries_added
yield "updated_list", updated_list
class FindInListBlock(Block):
class Input(BlockSchema):
list: List[Any] = SchemaField(description="The list to search in.")
value: Any = SchemaField(description="The value to search for.")
class Output(BlockSchema):
index: int = SchemaField(description="The index of the value in the list.")
found: bool = SchemaField(
description="Whether the value was found in the list."
)
not_found_value: Any = SchemaField(
description="The value that was not found in the list."
)
def __init__(self):
super().__init__(
id="5e2c6d0a-1e37-489f-b1d0-8e1812b23333",
description="Finds the index of the value in the list.",
categories={BlockCategory.BASIC},
input_schema=FindInListBlock.Input,
output_schema=FindInListBlock.Output,
test_input=[
{"list": [1, 2, 3, 4, 5], "value": 3},
{"list": [1, 2, 3, 4, 5], "value": 6},
],
test_output=[
("index", 2),
("found", True),
("found", False),
("not_found_value", 6),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
try:
yield "index", input_data.list.index(input_data.value)
yield "found", True
except ValueError:
yield "found", False
yield "not_found_value", input_data.value
class GetListItemBlock(Block):
class Input(BlockSchema):
list: List[Any] = SchemaField(description="The list to get the item from.")
index: int = SchemaField(
description="The 0-based index of the item (supports negative indices)."
)
class Output(BlockSchema):
item: Any = SchemaField(description="The item at the specified index.")
error: str = SchemaField(description="Error message if the operation failed.")
def __init__(self):
super().__init__(
id="262ca24c-1025-43cf-a578-534e23234e97",
description="Returns the element at the given index.",
categories={BlockCategory.BASIC},
input_schema=GetListItemBlock.Input,
output_schema=GetListItemBlock.Output,
test_input=[
{"list": [1, 2, 3], "index": 1},
{"list": [1, 2, 3], "index": -1},
],
test_output=[
("item", 2),
("item", 3),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
try:
yield "item", input_data.list[input_data.index]
except IndexError:
yield "error", "Index out of range"
class RemoveFromListBlock(Block):
class Input(BlockSchema):
list: List[Any] = SchemaField(description="The list to modify.")
value: Any = SchemaField(
default=None, description="Value to remove from the list."
)
index: int | None = SchemaField(
default=None,
description="Index of the item to pop (supports negative indices).",
)
return_item: bool = SchemaField(
default=False, description="Whether to return the removed item."
)
class Output(BlockSchema):
updated_list: List[Any] = SchemaField(description="The list after removal.")
removed_item: Any = SchemaField(description="The removed item if requested.")
error: str = SchemaField(description="Error message if the operation failed.")
def __init__(self):
super().__init__(
id="d93c5a93-ac7e-41c1-ae5c-ef67e6e9b826",
description="Removes an item from a list by value or index.",
categories={BlockCategory.BASIC},
input_schema=RemoveFromListBlock.Input,
output_schema=RemoveFromListBlock.Output,
test_input=[
{"list": [1, 2, 3], "index": 1, "return_item": True},
{"list": ["a", "b", "c"], "value": "b"},
],
test_output=[
("updated_list", [1, 3]),
("removed_item", 2),
("updated_list", ["a", "c"]),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
lst = input_data.list.copy()
removed = None
try:
if input_data.index is not None:
removed = lst.pop(input_data.index)
elif input_data.value is not None:
lst.remove(input_data.value)
removed = input_data.value
else:
raise ValueError("No index or value provided for removal")
except (IndexError, ValueError):
yield "error", "Index or value not found"
return
yield "updated_list", lst
if input_data.return_item:
yield "removed_item", removed
class ReplaceListItemBlock(Block):
class Input(BlockSchema):
list: List[Any] = SchemaField(description="The list to modify.")
index: int = SchemaField(
description="Index of the item to replace (supports negative indices)."
)
value: Any = SchemaField(description="The new value for the given index.")
class Output(BlockSchema):
updated_list: List[Any] = SchemaField(description="The list after replacement.")
old_item: Any = SchemaField(description="The item that was replaced.")
error: str = SchemaField(description="Error message if the operation failed.")
def __init__(self):
super().__init__(
id="fbf62922-bea1-4a3d-8bac-23587f810b38",
description="Replaces an item at the specified index.",
categories={BlockCategory.BASIC},
input_schema=ReplaceListItemBlock.Input,
output_schema=ReplaceListItemBlock.Output,
test_input=[
{"list": [1, 2, 3], "index": 1, "value": 99},
{"list": ["a", "b"], "index": -1, "value": "c"},
],
test_output=[
("updated_list", [1, 99, 3]),
("old_item", 2),
("updated_list", ["a", "c"]),
("old_item", "b"),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
lst = input_data.list.copy()
try:
old = lst[input_data.index]
lst[input_data.index] = input_data.value
except IndexError:
yield "error", "Index out of range"
return
yield "updated_list", lst
yield "old_item", old
class ListIsEmptyBlock(Block):
class Input(BlockSchema):
list: List[Any] = SchemaField(description="The list to check.")
class Output(BlockSchema):
is_empty: bool = SchemaField(description="True if the list is empty.")
def __init__(self):
super().__init__(
id="896ed73b-27d0-41be-813c-c1c1dc856c03",
description="Checks if a list is empty.",
categories={BlockCategory.BASIC},
input_schema=ListIsEmptyBlock.Input,
output_schema=ListIsEmptyBlock.Output,
test_input=[{"list": []}, {"list": [1]}],
test_output=[("is_empty", True), ("is_empty", False)],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
yield "is_empty", len(input_data.list) == 0


@@ -1,178 +0,0 @@
"""
DataForSEO API client with async support using the SDK patterns.
"""
import base64
from typing import Any, Dict, List, Optional
from backend.sdk import Requests, UserPasswordCredentials
class DataForSeoClient:
"""Client for the DataForSEO API using async requests."""
API_URL = "https://api.dataforseo.com"
def __init__(self, credentials: UserPasswordCredentials):
self.credentials = credentials
self.requests = Requests(
trusted_origins=["https://api.dataforseo.com"],
raise_for_status=False,
)
def _get_headers(self) -> Dict[str, str]:
"""Generate the authorization header using Basic Auth."""
username = self.credentials.username.get_secret_value()
password = self.credentials.password.get_secret_value()
credentials_str = f"{username}:{password}"
encoded = base64.b64encode(credentials_str.encode("ascii")).decode("ascii")
return {
"Authorization": f"Basic {encoded}",
"Content-Type": "application/json",
}
async def keyword_suggestions(
self,
keyword: str,
location_code: Optional[int] = None,
language_code: Optional[str] = None,
include_seed_keyword: bool = True,
include_serp_info: bool = False,
include_clickstream_data: bool = False,
limit: int = 100,
) -> List[Dict[str, Any]]:
"""
Get keyword suggestions from DataForSEO Labs.
Args:
keyword: Seed keyword
location_code: Location code for targeting
language_code: Language code (e.g., "en")
include_seed_keyword: Include seed keyword in results
include_serp_info: Include SERP data
include_clickstream_data: Include clickstream metrics
limit: Maximum number of results (up to 3000)
Returns:
API response with keyword suggestions
"""
endpoint = f"{self.API_URL}/v3/dataforseo_labs/google/keyword_suggestions/live"
# Build payload only with non-None values to avoid sending null fields
task_data: dict[str, Any] = {
"keyword": keyword,
}
if location_code is not None:
task_data["location_code"] = location_code
if language_code is not None:
task_data["language_code"] = language_code
if include_seed_keyword is not None:
task_data["include_seed_keyword"] = include_seed_keyword
if include_serp_info is not None:
task_data["include_serp_info"] = include_serp_info
if include_clickstream_data is not None:
task_data["include_clickstream_data"] = include_clickstream_data
if limit is not None:
task_data["limit"] = limit
payload = [task_data]
response = await self.requests.post(
endpoint,
headers=self._get_headers(),
json=payload,
)
data = response.json()
# Check for API errors
if response.status != 200:
error_message = data.get("status_message", "Unknown error")
raise Exception(
f"DataForSEO API error ({response.status}): {error_message}"
)
# Extract the results from the response
if data.get("tasks") and len(data["tasks"]) > 0:
task = data["tasks"][0]
if task.get("status_code") == 20000: # Success code
return task.get("result", [])
else:
error_msg = task.get("status_message", "Task failed")
raise Exception(f"DataForSEO task error: {error_msg}")
return []
async def related_keywords(
self,
keyword: str,
location_code: Optional[int] = None,
language_code: Optional[str] = None,
include_seed_keyword: bool = True,
include_serp_info: bool = False,
include_clickstream_data: bool = False,
limit: int = 100,
) -> List[Dict[str, Any]]:
"""
Get related keywords from DataForSEO Labs.
Args:
keyword: Seed keyword
location_code: Location code for targeting
language_code: Language code (e.g., "en")
include_seed_keyword: Include seed keyword in results
include_serp_info: Include SERP data
include_clickstream_data: Include clickstream metrics
limit: Maximum number of results (up to 3000)
Returns:
API response with related keywords
"""
endpoint = f"{self.API_URL}/v3/dataforseo_labs/google/related_keywords/live"
# Build payload only with non-None values to avoid sending null fields
task_data: dict[str, Any] = {
"keyword": keyword,
}
if location_code is not None:
task_data["location_code"] = location_code
if language_code is not None:
task_data["language_code"] = language_code
if include_seed_keyword is not None:
task_data["include_seed_keyword"] = include_seed_keyword
if include_serp_info is not None:
task_data["include_serp_info"] = include_serp_info
if include_clickstream_data is not None:
task_data["include_clickstream_data"] = include_clickstream_data
if limit is not None:
task_data["limit"] = limit
payload = [task_data]
response = await self.requests.post(
endpoint,
headers=self._get_headers(),
json=payload,
)
data = response.json()
# Check for API errors
if response.status != 200:
error_message = data.get("status_message", "Unknown error")
raise Exception(
f"DataForSEO API error ({response.status}): {error_message}"
)
# Extract the results from the response
if data.get("tasks") and len(data["tasks"]) > 0:
task = data["tasks"][0]
if task.get("status_code") == 20000: # Success code
return task.get("result", [])
else:
error_msg = task.get("status_message", "Task failed")
raise Exception(f"DataForSEO task error: {error_msg}")
return []
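A rough usage sketch for the client above. The credential object here is a stand-in exposing only the two attributes the client reads; real callers pass a UserPasswordCredentials instance, and the keyword and limit values are placeholders:
import asyncio
from pydantic import SecretStr

class _FakeCreds:
    # Minimal stand-in for UserPasswordCredentials (illustrative only).
    username = SecretStr("user@example.com")
    password = SecretStr("api-password")

async def main():
    client = DataForSeoClient(_FakeCreds())  # type: ignore[arg-type]
    results = await client.keyword_suggestions("digital marketing", limit=10)
    print(len(results))

# asyncio.run(main())  # needs network access and valid DataForSEO credentials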


@@ -1,17 +0,0 @@
"""
Configuration for all DataForSEO blocks using the new SDK pattern.
"""
from backend.sdk import BlockCostType, ProviderBuilder
# Build the DataForSEO provider with username/password authentication
dataforseo = (
ProviderBuilder("dataforseo")
.with_user_password(
username_env_var="DATAFORSEO_USERNAME",
password_env_var="DATAFORSEO_PASSWORD",
title="DataForSEO Credentials",
)
.with_base_cost(1, BlockCostType.RUN)
.build()
)


@@ -1,273 +0,0 @@
"""
DataForSEO Google Keyword Suggestions block.
"""
from typing import Any, Dict, List, Optional
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
SchemaField,
UserPasswordCredentials,
)
from ._api import DataForSeoClient
from ._config import dataforseo
class KeywordSuggestion(BlockSchema):
"""Schema for a keyword suggestion result."""
keyword: str = SchemaField(description="The keyword suggestion")
search_volume: Optional[int] = SchemaField(
description="Monthly search volume", default=None
)
competition: Optional[float] = SchemaField(
description="Competition level (0-1)", default=None
)
cpc: Optional[float] = SchemaField(
description="Cost per click in USD", default=None
)
keyword_difficulty: Optional[int] = SchemaField(
description="Keyword difficulty score", default=None
)
serp_info: Optional[Dict[str, Any]] = SchemaField(
description="data from SERP for each keyword", default=None
)
clickstream_data: Optional[Dict[str, Any]] = SchemaField(
description="Clickstream data metrics", default=None
)
class DataForSeoKeywordSuggestionsBlock(Block):
"""Block for getting keyword suggestions from DataForSEO Labs."""
class Input(BlockSchema):
credentials: CredentialsMetaInput = dataforseo.credentials_field(
description="DataForSEO credentials (username and password)"
)
keyword: str = SchemaField(description="Seed keyword to get suggestions for")
location_code: Optional[int] = SchemaField(
description="Location code for targeting (e.g., 2840 for USA)",
default=2840, # USA
)
language_code: Optional[str] = SchemaField(
description="Language code (e.g., 'en' for English)",
default="en",
)
include_seed_keyword: bool = SchemaField(
description="Include the seed keyword in results",
default=True,
)
include_serp_info: bool = SchemaField(
description="Include SERP information",
default=False,
)
include_clickstream_data: bool = SchemaField(
description="Include clickstream metrics",
default=False,
)
limit: int = SchemaField(
description="Maximum number of results (up to 3000)",
default=100,
ge=1,
le=3000,
)
class Output(BlockSchema):
suggestions: List[KeywordSuggestion] = SchemaField(
description="List of keyword suggestions with metrics"
)
suggestion: KeywordSuggestion = SchemaField(
description="A single keyword suggestion with metrics"
)
total_count: int = SchemaField(
description="Total number of suggestions returned"
)
seed_keyword: str = SchemaField(
description="The seed keyword used for the query"
)
def __init__(self):
super().__init__(
id="73c3e7c4-2b3f-4e9f-9e3e-8f7a5c3e2d45",
description="Get keyword suggestions from DataForSEO Labs Google API",
categories={BlockCategory.SEARCH, BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
test_input={
"credentials": dataforseo.get_test_credentials().model_dump(),
"keyword": "digital marketing",
"location_code": 2840,
"language_code": "en",
"limit": 1,
},
test_credentials=dataforseo.get_test_credentials(),
test_output=[
(
"suggestion",
lambda x: hasattr(x, "keyword")
and x.keyword == "digital marketing strategy",
),
("suggestions", lambda x: isinstance(x, list) and len(x) == 1),
("total_count", 1),
("seed_keyword", "digital marketing"),
],
test_mock={
"_fetch_keyword_suggestions": lambda *args, **kwargs: [
{
"items": [
{
"keyword": "digital marketing strategy",
"keyword_info": {
"search_volume": 10000,
"competition": 0.5,
"cpc": 2.5,
},
"keyword_properties": {
"keyword_difficulty": 50,
},
}
]
}
]
},
)
async def _fetch_keyword_suggestions(
self,
client: DataForSeoClient,
input_data: Input,
) -> Any:
"""Private method to fetch keyword suggestions - can be mocked for testing."""
return await client.keyword_suggestions(
keyword=input_data.keyword,
location_code=input_data.location_code,
language_code=input_data.language_code,
include_seed_keyword=input_data.include_seed_keyword,
include_serp_info=input_data.include_serp_info,
include_clickstream_data=input_data.include_clickstream_data,
limit=input_data.limit,
)
async def run(
self,
input_data: Input,
*,
credentials: UserPasswordCredentials,
**kwargs,
) -> BlockOutput:
"""Execute the keyword suggestions query."""
client = DataForSeoClient(credentials)
results = await self._fetch_keyword_suggestions(client, input_data)
# Process and format the results
suggestions = []
if results and len(results) > 0:
# results is a list, get the first element
first_result = results[0] if isinstance(results, list) else results
items = (
first_result.get("items", []) if isinstance(first_result, dict) else []
)
for item in items:
# Create the KeywordSuggestion object
suggestion = KeywordSuggestion(
keyword=item.get("keyword", ""),
search_volume=item.get("keyword_info", {}).get("search_volume"),
competition=item.get("keyword_info", {}).get("competition"),
cpc=item.get("keyword_info", {}).get("cpc"),
keyword_difficulty=item.get("keyword_properties", {}).get(
"keyword_difficulty"
),
serp_info=(
item.get("serp_info") if input_data.include_serp_info else None
),
clickstream_data=(
item.get("clickstream_keyword_info")
if input_data.include_clickstream_data
else None
),
)
yield "suggestion", suggestion
suggestions.append(suggestion)
yield "suggestions", suggestions
yield "total_count", len(suggestions)
yield "seed_keyword", input_data.keyword
class KeywordSuggestionExtractorBlock(Block):
"""Extracts individual fields from a KeywordSuggestion object."""
class Input(BlockSchema):
suggestion: KeywordSuggestion = SchemaField(
description="The keyword suggestion object to extract fields from"
)
class Output(BlockSchema):
keyword: str = SchemaField(description="The keyword suggestion")
search_volume: Optional[int] = SchemaField(
description="Monthly search volume", default=None
)
competition: Optional[float] = SchemaField(
description="Competition level (0-1)", default=None
)
cpc: Optional[float] = SchemaField(
description="Cost per click in USD", default=None
)
keyword_difficulty: Optional[int] = SchemaField(
description="Keyword difficulty score", default=None
)
serp_info: Optional[Dict[str, Any]] = SchemaField(
description="data from SERP for each keyword", default=None
)
clickstream_data: Optional[Dict[str, Any]] = SchemaField(
description="Clickstream data metrics", default=None
)
def __init__(self):
super().__init__(
id="4193cb94-677c-48b0-9eec-6ac72fffd0f2",
description="Extract individual fields from a KeywordSuggestion object",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
test_input={
"suggestion": KeywordSuggestion(
keyword="test keyword",
search_volume=1000,
competition=0.5,
cpc=2.5,
keyword_difficulty=60,
).model_dump()
},
test_output=[
("keyword", "test keyword"),
("search_volume", 1000),
("competition", 0.5),
("cpc", 2.5),
("keyword_difficulty", 60),
("serp_info", None),
("clickstream_data", None),
],
)
async def run(
self,
input_data: Input,
**kwargs,
) -> BlockOutput:
"""Extract fields from the KeywordSuggestion object."""
suggestion = input_data.suggestion
yield "keyword", suggestion.keyword
yield "search_volume", suggestion.search_volume
yield "competition", suggestion.competition
yield "cpc", suggestion.cpc
yield "keyword_difficulty", suggestion.keyword_difficulty
yield "serp_info", suggestion.serp_info
yield "clickstream_data", suggestion.clickstream_data


@@ -1,283 +0,0 @@
"""
DataForSEO Google Related Keywords block.
"""
from typing import Any, Dict, List, Optional
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
SchemaField,
UserPasswordCredentials,
)
from ._api import DataForSeoClient
from ._config import dataforseo
class RelatedKeyword(BlockSchema):
"""Schema for a related keyword result."""
keyword: str = SchemaField(description="The related keyword")
search_volume: Optional[int] = SchemaField(
description="Monthly search volume", default=None
)
competition: Optional[float] = SchemaField(
description="Competition level (0-1)", default=None
)
cpc: Optional[float] = SchemaField(
description="Cost per click in USD", default=None
)
keyword_difficulty: Optional[int] = SchemaField(
description="Keyword difficulty score", default=None
)
serp_info: Optional[Dict[str, Any]] = SchemaField(
description="SERP data for the keyword", default=None
)
clickstream_data: Optional[Dict[str, Any]] = SchemaField(
description="Clickstream data metrics", default=None
)
class DataForSeoRelatedKeywordsBlock(Block):
"""Block for getting related keywords from DataForSEO Labs."""
class Input(BlockSchema):
credentials: CredentialsMetaInput = dataforseo.credentials_field(
description="DataForSEO credentials (username and password)"
)
keyword: str = SchemaField(
description="Seed keyword to find related keywords for"
)
location_code: Optional[int] = SchemaField(
description="Location code for targeting (e.g., 2840 for USA)",
default=2840, # USA
)
language_code: Optional[str] = SchemaField(
description="Language code (e.g., 'en' for English)",
default="en",
)
include_seed_keyword: bool = SchemaField(
description="Include the seed keyword in results",
default=True,
)
include_serp_info: bool = SchemaField(
description="Include SERP information",
default=False,
)
include_clickstream_data: bool = SchemaField(
description="Include clickstream metrics",
default=False,
)
limit: int = SchemaField(
description="Maximum number of results (up to 3000)",
default=100,
ge=1,
le=3000,
)
class Output(BlockSchema):
related_keywords: List[RelatedKeyword] = SchemaField(
description="List of related keywords with metrics"
)
related_keyword: RelatedKeyword = SchemaField(
description="A related keyword with metrics"
)
total_count: int = SchemaField(
description="Total number of related keywords returned"
)
seed_keyword: str = SchemaField(
description="The seed keyword used for the query"
)
def __init__(self):
super().__init__(
id="8f2e4d6a-1b3c-4a5e-9d7f-2c8e6a4b3f1d",
description="Get related keywords from DataForSEO Labs Google API",
categories={BlockCategory.SEARCH, BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
test_input={
"credentials": dataforseo.get_test_credentials().model_dump(),
"keyword": "content marketing",
"location_code": 2840,
"language_code": "en",
"limit": 1,
},
test_credentials=dataforseo.get_test_credentials(),
test_output=[
(
"related_keyword",
lambda x: hasattr(x, "keyword") and x.keyword == "content strategy",
),
("related_keywords", lambda x: isinstance(x, list) and len(x) == 1),
("total_count", 1),
("seed_keyword", "content marketing"),
],
test_mock={
"_fetch_related_keywords": lambda *args, **kwargs: [
{
"items": [
{
"keyword_data": {
"keyword": "content strategy",
"keyword_info": {
"search_volume": 8000,
"competition": 0.4,
"cpc": 3.0,
},
"keyword_properties": {
"keyword_difficulty": 45,
},
}
}
]
}
]
},
)
async def _fetch_related_keywords(
self,
client: DataForSeoClient,
input_data: Input,
) -> Any:
"""Private method to fetch related keywords - can be mocked for testing."""
return await client.related_keywords(
keyword=input_data.keyword,
location_code=input_data.location_code,
language_code=input_data.language_code,
include_seed_keyword=input_data.include_seed_keyword,
include_serp_info=input_data.include_serp_info,
include_clickstream_data=input_data.include_clickstream_data,
limit=input_data.limit,
)
async def run(
self,
input_data: Input,
*,
credentials: UserPasswordCredentials,
**kwargs,
) -> BlockOutput:
"""Execute the related keywords query."""
client = DataForSeoClient(credentials)
results = await self._fetch_related_keywords(client, input_data)
# Process and format the results
related_keywords = []
if results and len(results) > 0:
# results is a list, get the first element
first_result = results[0] if isinstance(results, list) else results
items = (
first_result.get("items", []) if isinstance(first_result, dict) else []
)
for item in items:
# Extract keyword_data from the item
keyword_data = item.get("keyword_data", {})
# Create the RelatedKeyword object
keyword = RelatedKeyword(
keyword=keyword_data.get("keyword", ""),
search_volume=keyword_data.get("keyword_info", {}).get(
"search_volume"
),
competition=keyword_data.get("keyword_info", {}).get("competition"),
cpc=keyword_data.get("keyword_info", {}).get("cpc"),
keyword_difficulty=keyword_data.get("keyword_properties", {}).get(
"keyword_difficulty"
),
serp_info=(
keyword_data.get("serp_info")
if input_data.include_serp_info
else None
),
clickstream_data=(
keyword_data.get("clickstream_keyword_info")
if input_data.include_clickstream_data
else None
),
)
yield "related_keyword", keyword
related_keywords.append(keyword)
yield "related_keywords", related_keywords
yield "total_count", len(related_keywords)
yield "seed_keyword", input_data.keyword
class RelatedKeywordExtractorBlock(Block):
"""Extracts individual fields from a RelatedKeyword object."""
class Input(BlockSchema):
related_keyword: RelatedKeyword = SchemaField(
description="The related keyword object to extract fields from"
)
class Output(BlockSchema):
keyword: str = SchemaField(description="The related keyword")
search_volume: Optional[int] = SchemaField(
description="Monthly search volume", default=None
)
competition: Optional[float] = SchemaField(
description="Competition level (0-1)", default=None
)
cpc: Optional[float] = SchemaField(
description="Cost per click in USD", default=None
)
keyword_difficulty: Optional[int] = SchemaField(
description="Keyword difficulty score", default=None
)
serp_info: Optional[Dict[str, Any]] = SchemaField(
description="SERP data for the keyword", default=None
)
clickstream_data: Optional[Dict[str, Any]] = SchemaField(
description="Clickstream data metrics", default=None
)
def __init__(self):
super().__init__(
id="98342061-09d2-4952-bf77-0761fc8cc9a8",
description="Extract individual fields from a RelatedKeyword object",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
test_input={
"related_keyword": RelatedKeyword(
keyword="test related keyword",
search_volume=800,
competition=0.4,
cpc=3.0,
keyword_difficulty=55,
).model_dump()
},
test_output=[
("keyword", "test related keyword"),
("search_volume", 800),
("competition", 0.4),
("cpc", 3.0),
("keyword_difficulty", 55),
("serp_info", None),
("clickstream_data", None),
],
)
async def run(
self,
input_data: Input,
**kwargs,
) -> BlockOutput:
"""Extract fields from the RelatedKeyword object."""
related_keyword = input_data.related_keyword
yield "keyword", related_keyword.keyword
yield "search_volume", related_keyword.search_volume
yield "competition", related_keyword.competition
yield "cpc", related_keyword.cpc
yield "keyword_difficulty", related_keyword.keyword_difficulty
yield "serp_info", related_keyword.serp_info
yield "clickstream_data", related_keyword.clickstream_data


@@ -0,0 +1,237 @@
from typing import Literal
import aiohttp
import discord
from pydantic import SecretStr
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
DiscordCredentials = CredentialsMetaInput[
Literal[ProviderName.DISCORD], Literal["api_key"]
]
def DiscordCredentialsField() -> DiscordCredentials:
return CredentialsField(description="Discord bot token")
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="discord",
api_key=SecretStr("test_api_key"),
title="Mock Discord API key",
expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.type,
}
class ReadDiscordMessagesBlock(Block):
class Input(BlockSchema):
credentials: DiscordCredentials = DiscordCredentialsField()
class Output(BlockSchema):
message_content: str = SchemaField(
description="The content of the message received"
)
channel_name: str = SchemaField(
description="The name of the channel the message was received from"
)
username: str = SchemaField(
description="The username of the user who sent the message"
)
def __init__(self):
super().__init__(
id="df06086a-d5ac-4abb-9996-2ad0acb2eff7",
input_schema=ReadDiscordMessagesBlock.Input, # Assign input schema
output_schema=ReadDiscordMessagesBlock.Output, # Assign output schema
description="Reads messages from a Discord channel using a bot token.",
categories={BlockCategory.SOCIAL},
test_input={
"continuous_read": False,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"message_content",
"Hello!\n\nFile from user: example.txt\nContent: This is the content of the file.",
),
("channel_name", "general"),
("username", "test_user"),
],
test_mock={
"run_bot": lambda token: {
"output_data": "Hello!\n\nFile from user: example.txt\nContent: This is the content of the file.",
"channel_name": "general",
"username": "test_user",
}
},
)
async def run_bot(self, token: SecretStr):
intents = discord.Intents.default()
intents.message_content = True
client = discord.Client(intents=intents)
self.output_data = None
self.channel_name = None
self.username = None
@client.event
async def on_ready():
print(f"Logged in as {client.user}")
@client.event
async def on_message(message):
if message.author == client.user:
return
self.output_data = message.content
self.channel_name = message.channel.name
self.username = message.author.name
if message.attachments:
attachment = message.attachments[0] # Process the first attachment
if attachment.filename.endswith((".txt", ".py")):
async with aiohttp.ClientSession() as session:
async with session.get(attachment.url) as response:
file_content = await response.text()
self.output_data += f"\n\nFile from user: {attachment.filename}\nContent: {file_content}"
await client.close()
await client.start(token.get_secret_value())
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
async for output_name, output_value in self.__run(input_data, credentials):
yield output_name, output_value
async def __run(
self, input_data: Input, credentials: APIKeyCredentials
) -> BlockOutput:
try:
result = await self.run_bot(credentials.api_key)
# For testing purposes, use the mocked result
if isinstance(result, dict):
self.output_data = result.get("output_data")
self.channel_name = result.get("channel_name")
self.username = result.get("username")
if (
self.output_data is None
or self.channel_name is None
or self.username is None
):
raise ValueError("No message, channel name, or username received.")
yield "message_content", self.output_data
yield "channel_name", self.channel_name
yield "username", self.username
except discord.errors.LoginFailure as login_err:
raise ValueError(f"Login error occurred: {login_err}")
except Exception as e:
raise ValueError(f"An error occurred: {e}")
class SendDiscordMessageBlock(Block):
class Input(BlockSchema):
credentials: DiscordCredentials = DiscordCredentialsField()
message_content: str = SchemaField(
description="The content of the message received"
)
channel_name: str = SchemaField(
description="The name of the channel the message was received from"
)
class Output(BlockSchema):
status: str = SchemaField(
description="The status of the operation (e.g., 'Message sent', 'Error')"
)
def __init__(self):
super().__init__(
id="d0822ab5-9f8a-44a3-8971-531dd0178b6b",
input_schema=SendDiscordMessageBlock.Input, # Assign input schema
output_schema=SendDiscordMessageBlock.Output, # Assign output schema
description="Sends a message to a Discord channel using a bot token.",
categories={BlockCategory.SOCIAL},
test_input={
"channel_name": "general",
"message_content": "Hello, Discord!",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[("status", "Message sent")],
test_mock={
"send_message": lambda token, channel_name, message_content: "Message sent"
},
test_credentials=TEST_CREDENTIALS,
)
async def send_message(self, token: str, channel_name: str, message_content: str):
intents = discord.Intents.default()
intents.guilds = True # Required for fetching guild/channel information
client = discord.Client(intents=intents)
@client.event
async def on_ready():
print(f"Logged in as {client.user}")
for guild in client.guilds:
for channel in guild.text_channels:
if channel.name == channel_name:
# Split message into chunks if it exceeds 2000 characters
for chunk in self.chunk_message(message_content):
await channel.send(chunk)
self.output_data = "Message sent"
await client.close()
return
self.output_data = "Channel not found"
await client.close()
await client.start(token)
def chunk_message(self, message: str, limit: int = 2000) -> list:
"""Splits a message into chunks not exceeding the Discord limit."""
return [message[i : i + limit] for i in range(0, len(message), limit)]
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
try:
result = await self.send_message(
credentials.api_key.get_secret_value(),
input_data.channel_name,
input_data.message_content,
)
# For testing purposes, use the mocked result
if isinstance(result, str):
self.output_data = result
if self.output_data is None:
raise ValueError("No status message received.")
yield "status", self.output_data
except discord.errors.LoginFailure as login_err:
raise ValueError(f"Login error occurred: {login_err}")
except Exception as e:
raise ValueError(f"An error occurred: {e}")


@@ -1,117 +0,0 @@
"""
Discord API helper functions for making authenticated requests.
"""
import logging
from typing import Optional
from pydantic import BaseModel
from backend.data.model import OAuth2Credentials
from backend.util.request import Requests
logger = logging.getLogger(__name__)
class DiscordAPIException(Exception):
"""Exception raised for Discord API errors."""
def __init__(self, message: str, status_code: int):
super().__init__(message)
self.status_code = status_code
class DiscordOAuthUser(BaseModel):
"""Model for Discord OAuth user response."""
user_id: str
username: str
avatar_url: str
banner: Optional[str] = None
accent_color: Optional[int] = None
def get_api(credentials: OAuth2Credentials) -> Requests:
"""
Create a Requests instance configured for Discord API calls with OAuth2 credentials.
Args:
credentials: The OAuth2 credentials containing the access token.
Returns:
A configured Requests instance for Discord API calls.
"""
return Requests(
trusted_origins=[],
extra_headers={
"Authorization": f"Bearer {credentials.access_token.get_secret_value()}",
"Content-Type": "application/json",
},
raise_for_status=False,
)
async def get_current_user(credentials: OAuth2Credentials) -> DiscordOAuthUser:
"""
Fetch the current user's information using Discord OAuth2 API.
Reference: https://discord.com/developers/docs/resources/user#get-current-user
Args:
credentials: The OAuth2 credentials.
Returns:
A model containing user data with avatar URL.
Raises:
DiscordAPIException: If the API request fails.
"""
api = get_api(credentials)
response = await api.get("https://discord.com/api/oauth2/@me")
if not response.ok:
error_text = response.text()
raise DiscordAPIException(
f"Failed to fetch user info: {response.status} - {error_text}",
response.status,
)
data = response.json()
logger.info(f"Discord OAuth2 API Response: {data}")
# The /api/oauth2/@me endpoint returns a user object nested in the response
user_info = data.get("user", {})
logger.info(f"User info extracted: {user_info}")
# Build avatar URL
user_id = user_info.get("id")
avatar_hash = user_info.get("avatar")
if avatar_hash:
# Custom avatar
avatar_ext = "gif" if avatar_hash.startswith("a_") else "png"
avatar_url = (
f"https://cdn.discordapp.com/avatars/{user_id}/{avatar_hash}.{avatar_ext}"
)
else:
# Default avatar based on discriminator or user ID
discriminator = user_info.get("discriminator", "0")
if discriminator == "0":
# New username system - use user ID for default avatar
default_avatar_index = (int(user_id) >> 22) % 6
else:
# Legacy discriminator system
default_avatar_index = int(discriminator) % 5
avatar_url = (
f"https://cdn.discordapp.com/embed/avatars/{default_avatar_index}.png"
)
result = DiscordOAuthUser(
user_id=user_id,
username=user_info.get("username", ""),
avatar_url=avatar_url,
banner=user_info.get("banner"),
accent_color=user_info.get("accent_color"),
)
logger.info(f"Returning user data: {result.model_dump()}")
return result
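A worked example of the avatar URL branches above. The user ID and discriminator are made up; the CDN URL patterns are the ones used in the helper:
user_id, avatar_hash, discriminator = "123456789012345678", None, "0"

if avatar_hash:
    ext = "gif" if avatar_hash.startswith("a_") else "png"
    avatar_url = f"https://cdn.discordapp.com/avatars/{user_id}/{avatar_hash}.{ext}"
elif discriminator == "0":
    index = (int(user_id) >> 22) % 6   # new username system: derive from the user ID
    avatar_url = f"https://cdn.discordapp.com/embed/avatars/{index}.png"
else:
    index = int(discriminator) % 5     # legacy discriminator system
    avatar_url = f"https://cdn.discordapp.com/embed/avatars/{index}.png"

print(avatar_url)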


@@ -1,74 +0,0 @@
from typing import Literal
from pydantic import SecretStr
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
OAuth2Credentials,
)
from backend.integrations.providers import ProviderName
from backend.util.settings import Secrets
secrets = Secrets()
DISCORD_OAUTH_IS_CONFIGURED = bool(
secrets.discord_client_id and secrets.discord_client_secret
)
# Bot token credentials (existing)
DiscordBotCredentials = APIKeyCredentials
DiscordBotCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.DISCORD], Literal["api_key"]
]
# OAuth2 credentials (new)
DiscordOAuthCredentials = OAuth2Credentials
DiscordOAuthCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.DISCORD], Literal["oauth2"]
]
def DiscordBotCredentialsField() -> DiscordBotCredentialsInput:
"""Creates a Discord bot token credentials field."""
return CredentialsField(description="Discord bot token")
def DiscordOAuthCredentialsField(scopes: list[str]) -> DiscordOAuthCredentialsInput:
"""Creates a Discord OAuth2 credentials field."""
return CredentialsField(
description="Discord OAuth2 credentials",
required_scopes=set(scopes) | {"identify"}, # Basic user info scope
)
# Test credentials for bot tokens
TEST_BOT_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="discord",
api_key=SecretStr("test_api_key"),
title="Mock Discord API key",
expires_at=None,
)
TEST_BOT_CREDENTIALS_INPUT = {
"provider": TEST_BOT_CREDENTIALS.provider,
"id": TEST_BOT_CREDENTIALS.id,
"type": TEST_BOT_CREDENTIALS.type,
"title": TEST_BOT_CREDENTIALS.type,
}
# Test credentials for OAuth2
TEST_OAUTH_CREDENTIALS = OAuth2Credentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="discord",
access_token=SecretStr("test_access_token"),
title="Mock Discord OAuth",
scopes=["identify"],
username="testuser",
)
TEST_OAUTH_CREDENTIALS_INPUT = {
"provider": TEST_OAUTH_CREDENTIALS.provider,
"id": TEST_OAUTH_CREDENTIALS.id,
"type": TEST_OAUTH_CREDENTIALS.type,
"title": TEST_OAUTH_CREDENTIALS.type,
}

File diff suppressed because it is too large


@@ -1,99 +0,0 @@
"""
Discord OAuth-based blocks.
"""
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import OAuth2Credentials, SchemaField
from ._api import DiscordOAuthUser, get_current_user
from ._auth import (
DISCORD_OAUTH_IS_CONFIGURED,
TEST_OAUTH_CREDENTIALS,
TEST_OAUTH_CREDENTIALS_INPUT,
DiscordOAuthCredentialsField,
DiscordOAuthCredentialsInput,
)
class DiscordGetCurrentUserBlock(Block):
"""
Gets information about the currently authenticated Discord user using OAuth2.
This block requires Discord OAuth2 credentials (not bot tokens).
"""
class Input(BlockSchema):
credentials: DiscordOAuthCredentialsInput = DiscordOAuthCredentialsField(
["identify"]
)
class Output(BlockSchema):
user_id: str = SchemaField(description="The authenticated user's Discord ID")
username: str = SchemaField(description="The user's username")
avatar_url: str = SchemaField(description="URL to the user's avatar image")
banner_url: str = SchemaField(
description="URL to the user's banner image (if set)", default=""
)
accent_color: int = SchemaField(
description="The user's accent color as an integer", default=0
)
def __init__(self):
super().__init__(
id="8c7e39b8-4e9d-4f3a-b4e1-2a8c9d5f6e3b",
input_schema=DiscordGetCurrentUserBlock.Input,
output_schema=DiscordGetCurrentUserBlock.Output,
description="Gets information about the currently authenticated Discord user using OAuth2 credentials.",
categories={BlockCategory.SOCIAL},
disabled=not DISCORD_OAUTH_IS_CONFIGURED,
test_input={
"credentials": TEST_OAUTH_CREDENTIALS_INPUT,
},
test_credentials=TEST_OAUTH_CREDENTIALS,
test_output=[
("user_id", "123456789012345678"),
("username", "testuser"),
(
"avatar_url",
"https://cdn.discordapp.com/avatars/123456789012345678/avatar.png",
),
("banner_url", ""),
("accent_color", 0),
],
test_mock={
"get_user": lambda _: DiscordOAuthUser(
user_id="123456789012345678",
username="testuser",
avatar_url="https://cdn.discordapp.com/avatars/123456789012345678/avatar.png",
banner=None,
accent_color=0,
)
},
)
@staticmethod
async def get_user(credentials: OAuth2Credentials) -> DiscordOAuthUser:
user_info = await get_current_user(credentials)
return user_info
async def run(
self, input_data: Input, *, credentials: OAuth2Credentials, **kwargs
) -> BlockOutput:
try:
result = await self.get_user(credentials)
# Yield each output field
yield "user_id", result.user_id
yield "username", result.username
yield "avatar_url", result.avatar_url
# Handle banner URL if banner hash exists
if result.banner:
banner_url = f"https://cdn.discordapp.com/banners/{result.user_id}/{result.banner}.png"
yield "banner_url", banner_url
else:
yield "banner_url", ""
yield "accent_color", result.accent_color or 0
except Exception as e:
raise ValueError(f"Failed to get Discord user info: {e}")


@@ -1,408 +0,0 @@
"""
API module for Enrichlayer integration.
This module provides a client for interacting with the Enrichlayer API,
which allows fetching LinkedIn profile data and related information.
"""
import datetime
import enum
import logging
from json import JSONDecodeError
from typing import Any, Optional, TypeVar
from pydantic import BaseModel, Field
from backend.data.model import APIKeyCredentials
from backend.util.request import Requests
logger = logging.getLogger(__name__)
T = TypeVar("T")
class EnrichlayerAPIException(Exception):
"""Exception raised for Enrichlayer API errors."""
def __init__(self, message: str, status_code: int):
super().__init__(message)
self.status_code = status_code
class FallbackToCache(enum.Enum):
ON_ERROR = "on-error"
NEVER = "never"
class UseCache(enum.Enum):
IF_PRESENT = "if-present"
NEVER = "never"
class SocialMediaProfiles(BaseModel):
"""Social media profiles model."""
twitter: Optional[str] = None
facebook: Optional[str] = None
github: Optional[str] = None
class Experience(BaseModel):
"""Experience model for LinkedIn profiles."""
company: Optional[str] = None
title: Optional[str] = None
description: Optional[str] = None
location: Optional[str] = None
starts_at: Optional[dict[str, int]] = None
ends_at: Optional[dict[str, int]] = None
company_linkedin_profile_url: Optional[str] = None
class Education(BaseModel):
"""Education model for LinkedIn profiles."""
school: Optional[str] = None
degree_name: Optional[str] = None
field_of_study: Optional[str] = None
starts_at: Optional[dict[str, int]] = None
ends_at: Optional[dict[str, int]] = None
school_linkedin_profile_url: Optional[str] = None
class PersonProfileResponse(BaseModel):
"""Response model for LinkedIn person profile.
This model represents the response from Enrichlayer's LinkedIn profile API.
The API returns comprehensive profile data including work experience,
education, skills, and contact information (when available).
Example API Response:
{
"public_identifier": "johnsmith",
"full_name": "John Smith",
"occupation": "Software Engineer at Tech Corp",
"experiences": [
{
"company": "Tech Corp",
"title": "Software Engineer",
"starts_at": {"year": 2020, "month": 1}
}
],
"education": [...],
"skills": ["Python", "JavaScript", ...]
}
"""
public_identifier: Optional[str] = None
profile_pic_url: Optional[str] = None
full_name: Optional[str] = None
first_name: Optional[str] = None
last_name: Optional[str] = None
occupation: Optional[str] = None
headline: Optional[str] = None
summary: Optional[str] = None
country: Optional[str] = None
country_full_name: Optional[str] = None
city: Optional[str] = None
state: Optional[str] = None
experiences: Optional[list[Experience]] = None
education: Optional[list[Education]] = None
languages: Optional[list[str]] = None
skills: Optional[list[str]] = None
inferred_salary: Optional[dict[str, Any]] = None
personal_email: Optional[str] = None
personal_contact_number: Optional[str] = None
social_media_profiles: Optional[SocialMediaProfiles] = None
extra: Optional[dict[str, Any]] = None
class SimilarProfile(BaseModel):
"""Similar profile model for LinkedIn person lookup."""
similarity: float
linkedin_profile_url: str
class PersonLookupResponse(BaseModel):
"""Response model for LinkedIn person lookup.
This model represents the response from Enrichlayer's person lookup API.
The API returns a LinkedIn profile URL and similarity scores when
searching for a person by name and company.
Example API Response:
{
"url": "https://www.linkedin.com/in/johnsmith/",
"name_similarity_score": 0.95,
"company_similarity_score": 0.88,
"title_similarity_score": 0.75,
"location_similarity_score": 0.60
}
"""
url: str | None = None
name_similarity_score: float | None
company_similarity_score: float | None
title_similarity_score: float | None
location_similarity_score: float | None
last_updated: datetime.datetime | None = None
profile: PersonProfileResponse | None = None
class RoleLookupResponse(BaseModel):
"""Response model for LinkedIn role lookup.
This model represents the response from Enrichlayer's role lookup API.
The API returns LinkedIn profile data for a specific role at a company.
Example API Response:
{
"linkedin_profile_url": "https://www.linkedin.com/in/johnsmith/",
"profile_data": {...} // Full PersonProfileResponse data when enrich_profile=True
}
"""
linkedin_profile_url: Optional[str] = None
profile_data: Optional[PersonProfileResponse] = None
class ProfilePictureResponse(BaseModel):
"""Response model for LinkedIn profile picture.
This model represents the response from Enrichlayer's profile picture API.
The API returns a URL to the person's LinkedIn profile picture.
Example API Response:
{
"tmp_profile_pic_url": "https://media.licdn.com/dms/image/..."
}
"""
tmp_profile_pic_url: str = Field(
..., description="URL of the profile picture", alias="tmp_profile_pic_url"
)
@property
def profile_picture_url(self) -> str:
"""Backward compatibility property for profile_picture_url."""
return self.tmp_profile_pic_url
class EnrichlayerClient:
"""Client for interacting with the Enrichlayer API."""
API_BASE_URL = "https://enrichlayer.com/api/v2"
def __init__(
self,
credentials: Optional[APIKeyCredentials] = None,
custom_requests: Optional[Requests] = None,
):
"""
Initialize the Enrichlayer client.
Args:
credentials: The credentials to use for authentication.
custom_requests: Custom Requests instance for testing.
"""
if custom_requests:
self._requests = custom_requests
else:
headers: dict[str, str] = {
"Content-Type": "application/json",
}
if credentials:
headers["Authorization"] = (
f"Bearer {credentials.api_key.get_secret_value()}"
)
self._requests = Requests(
extra_headers=headers,
raise_for_status=False,
)
async def _handle_response(self, response) -> Any:
"""
Handle API response and check for errors.
Args:
response: The response object from the request.
Returns:
The response data.
Raises:
EnrichlayerAPIException: If the API request fails.
"""
if not response.ok:
try:
error_data = response.json()
error_message = error_data.get("message", "")
except JSONDecodeError:
error_message = response.text
raise EnrichlayerAPIException(
f"Enrichlayer API request failed ({response.status_code}): {error_message}",
response.status_code,
)
return response.json()
async def fetch_profile(
self,
linkedin_url: str,
fallback_to_cache: FallbackToCache = FallbackToCache.ON_ERROR,
use_cache: UseCache = UseCache.IF_PRESENT,
include_skills: bool = False,
include_inferred_salary: bool = False,
include_personal_email: bool = False,
include_personal_contact_number: bool = False,
include_social_media: bool = False,
include_extra: bool = False,
) -> PersonProfileResponse:
"""
Fetch a LinkedIn profile with optional parameters.
Args:
linkedin_url: The LinkedIn profile URL to fetch.
fallback_to_cache: Cache usage if live fetch fails ('on-error' or 'never').
use_cache: Cache utilization ('if-present' or 'never').
include_skills: Whether to include skills data.
include_inferred_salary: Whether to include inferred salary data.
include_personal_email: Whether to include personal email.
include_personal_contact_number: Whether to include personal contact number.
include_social_media: Whether to include social media profiles.
include_extra: Whether to include additional data.
Returns:
The LinkedIn profile data.
Raises:
EnrichlayerAPIException: If the API request fails.
"""
params = {
"url": linkedin_url,
"fallback_to_cache": fallback_to_cache.value.lower(),
"use_cache": use_cache.value.lower(),
}
if include_skills:
params["skills"] = "include"
if include_inferred_salary:
params["inferred_salary"] = "include"
if include_personal_email:
params["personal_email"] = "include"
if include_personal_contact_number:
params["personal_contact_number"] = "include"
if include_social_media:
params["twitter_profile_id"] = "include"
params["facebook_profile_id"] = "include"
params["github_profile_id"] = "include"
if include_extra:
params["extra"] = "include"
response = await self._requests.get(
f"{self.API_BASE_URL}/profile", params=params
)
return PersonProfileResponse(**await self._handle_response(response))
async def lookup_person(
self,
first_name: str,
company_domain: str,
last_name: str | None = None,
location: Optional[str] = None,
title: Optional[str] = None,
include_similarity_checks: bool = False,
enrich_profile: bool = False,
) -> PersonLookupResponse:
"""
Look up a LinkedIn profile by person's information.
Args:
first_name: The person's first name.
last_name: The person's last name.
company_domain: The domain of the company they work for.
location: The person's location.
title: The person's job title.
include_similarity_checks: Whether to include similarity checks.
enrich_profile: Whether to enrich the profile.
Returns:
The LinkedIn profile lookup result.
Raises:
EnrichlayerAPIException: If the API request fails.
"""
params = {"first_name": first_name, "company_domain": company_domain}
if last_name:
params["last_name"] = last_name
if location:
params["location"] = location
if title:
params["title"] = title
if include_similarity_checks:
params["similarity_checks"] = "include"
if enrich_profile:
params["enrich_profile"] = "enrich"
response = await self._requests.get(
f"{self.API_BASE_URL}/profile/resolve", params=params
)
return PersonLookupResponse(**await self._handle_response(response))
async def lookup_role(
self, role: str, company_name: str, enrich_profile: bool = False
) -> RoleLookupResponse:
"""
Look up a LinkedIn profile by role in a company.
Args:
role: The role title (e.g., CEO, CTO).
company_name: The name of the company.
enrich_profile: Whether to enrich the profile.
Returns:
The LinkedIn profile lookup result.
Raises:
EnrichlayerAPIException: If the API request fails.
"""
params = {
"role": role,
"company_name": company_name,
}
if enrich_profile:
params["enrich_profile"] = "enrich"
response = await self._requests.get(
f"{self.API_BASE_URL}/find/company/role", params=params
)
return RoleLookupResponse(**await self._handle_response(response))
async def get_profile_picture(
self, linkedin_profile_url: str
) -> ProfilePictureResponse:
"""
Get a LinkedIn profile picture URL.
Args:
linkedin_profile_url: The LinkedIn profile URL.
Returns:
The profile picture URL.
Raises:
EnrichlayerAPIException: If the API request fails.
"""
params = {
"linkedin_person_profile_url": linkedin_profile_url,
}
response = await self._requests.get(
f"{self.API_BASE_URL}/person/profile-picture", params=params
)
return ProfilePictureResponse(**await self._handle_response(response))
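# Illustrative usage sketch, not part of the original module: exercising
# EnrichlayerClient end to end. The credential values are placeholders; the fields
# mirror how APIKeyCredentials is constructed for the test credentials in _auth.py.
if __name__ == "__main__":
    import asyncio

    from pydantic import SecretStr

    async def _example() -> None:
        creds = APIKeyCredentials(
            id="00000000-0000-0000-0000-000000000000",
            provider="enrichlayer",
            api_key=SecretStr("your-enrichlayer-api-key"),
            title="Example Enrichlayer API key",
            expires_at=None,
        )
        client = EnrichlayerClient(credentials=creds)
        profile = await client.fetch_profile(
            linkedin_url="https://www.linkedin.com/in/williamhgates/",
            include_skills=True,
        )
        print(profile.full_name, profile.skills)

    asyncio.run(_example())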


@@ -1,34 +0,0 @@
"""
Authentication module for Enrichlayer API integration.
This module provides credential types and test credentials for the Enrichlayer API.
"""
from typing import Literal
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, CredentialsMetaInput
from backend.integrations.providers import ProviderName
# Define the type of credentials input expected for Enrichlayer API
EnrichlayerCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.ENRICHLAYER], Literal["api_key"]
]
# Mock credentials for testing Enrichlayer API integration
TEST_CREDENTIALS = APIKeyCredentials(
id="1234a567-89bc-4def-ab12-3456cdef7890",
provider="enrichlayer",
api_key=SecretStr("mock-enrichlayer-api-key"),
title="Mock Enrichlayer API key",
expires_at=None,
)
# Dictionary representation of test credentials for input fields
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}


@@ -1,527 +0,0 @@
"""
Block definitions for Enrichlayer API integration.
This module implements blocks for interacting with the Enrichlayer API,
which provides access to LinkedIn profile data and related information.
"""
import logging
from typing import Optional
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
from backend.util.type import MediaFileType
from ._api import (
EnrichlayerClient,
Experience,
FallbackToCache,
PersonLookupResponse,
PersonProfileResponse,
RoleLookupResponse,
UseCache,
)
from ._auth import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, EnrichlayerCredentialsInput
logger = logging.getLogger(__name__)
class GetLinkedinProfileBlock(Block):
"""Block to fetch LinkedIn profile data using Enrichlayer API."""
class Input(BlockSchema):
"""Input schema for GetLinkedinProfileBlock."""
linkedin_url: str = SchemaField(
description="LinkedIn profile URL to fetch data from",
placeholder="https://www.linkedin.com/in/username/",
)
fallback_to_cache: FallbackToCache = SchemaField(
description="Cache usage if live fetch fails",
default=FallbackToCache.ON_ERROR,
advanced=True,
)
use_cache: UseCache = SchemaField(
description="Cache utilization strategy",
default=UseCache.IF_PRESENT,
advanced=True,
)
include_skills: bool = SchemaField(
description="Include skills data",
default=False,
advanced=True,
)
include_inferred_salary: bool = SchemaField(
description="Include inferred salary data",
default=False,
advanced=True,
)
include_personal_email: bool = SchemaField(
description="Include personal email",
default=False,
advanced=True,
)
include_personal_contact_number: bool = SchemaField(
description="Include personal contact number",
default=False,
advanced=True,
)
include_social_media: bool = SchemaField(
description="Include social media profiles",
default=False,
advanced=True,
)
include_extra: bool = SchemaField(
description="Include additional data",
default=False,
advanced=True,
)
credentials: EnrichlayerCredentialsInput = CredentialsField(
description="Enrichlayer API credentials"
)
class Output(BlockSchema):
"""Output schema for GetLinkedinProfileBlock."""
profile: PersonProfileResponse = SchemaField(
description="LinkedIn profile data"
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
"""Initialize GetLinkedinProfileBlock."""
super().__init__(
id="f6e0ac73-4f1d-4acb-b4b7-b67066c5984e",
description="Fetch LinkedIn profile data using Enrichlayer",
categories={BlockCategory.SOCIAL},
input_schema=GetLinkedinProfileBlock.Input,
output_schema=GetLinkedinProfileBlock.Output,
test_input={
"linkedin_url": "https://www.linkedin.com/in/williamhgates/",
"include_skills": True,
"include_social_media": True,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
(
"profile",
PersonProfileResponse(
public_identifier="williamhgates",
full_name="Bill Gates",
occupation="Co-chair at Bill & Melinda Gates Foundation",
experiences=[
Experience(
company="Bill & Melinda Gates Foundation",
title="Co-chair",
starts_at={"year": 2000},
)
],
),
)
],
test_credentials=TEST_CREDENTIALS,
test_mock={
"_fetch_profile": lambda *args, **kwargs: PersonProfileResponse(
public_identifier="williamhgates",
full_name="Bill Gates",
occupation="Co-chair at Bill & Melinda Gates Foundation",
experiences=[
Experience(
company="Bill & Melinda Gates Foundation",
title="Co-chair",
starts_at={"year": 2000},
)
],
),
},
)
@staticmethod
async def _fetch_profile(
credentials: APIKeyCredentials,
linkedin_url: str,
fallback_to_cache: FallbackToCache = FallbackToCache.ON_ERROR,
use_cache: UseCache = UseCache.IF_PRESENT,
include_skills: bool = False,
include_inferred_salary: bool = False,
include_personal_email: bool = False,
include_personal_contact_number: bool = False,
include_social_media: bool = False,
include_extra: bool = False,
):
client = EnrichlayerClient(credentials)
profile = await client.fetch_profile(
linkedin_url=linkedin_url,
fallback_to_cache=fallback_to_cache,
use_cache=use_cache,
include_skills=include_skills,
include_inferred_salary=include_inferred_salary,
include_personal_email=include_personal_email,
include_personal_contact_number=include_personal_contact_number,
include_social_media=include_social_media,
include_extra=include_extra,
)
return profile
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
"""
Run the block to fetch LinkedIn profile data.
Args:
input_data: Input parameters for the block
credentials: API key credentials for Enrichlayer
**kwargs: Additional keyword arguments
Yields:
Tuples of (output_name, output_value)
"""
try:
profile = await self._fetch_profile(
credentials=credentials,
linkedin_url=input_data.linkedin_url,
fallback_to_cache=input_data.fallback_to_cache,
use_cache=input_data.use_cache,
include_skills=input_data.include_skills,
include_inferred_salary=input_data.include_inferred_salary,
include_personal_email=input_data.include_personal_email,
include_personal_contact_number=input_data.include_personal_contact_number,
include_social_media=input_data.include_social_media,
include_extra=input_data.include_extra,
)
yield "profile", profile
except Exception as e:
logger.error(f"Error fetching LinkedIn profile: {str(e)}")
yield "error", str(e)
class LinkedinPersonLookupBlock(Block):
"""Block to look up LinkedIn profiles by person's information using Enrichlayer API."""
class Input(BlockSchema):
"""Input schema for LinkedinPersonLookupBlock."""
first_name: str = SchemaField(
description="Person's first name",
placeholder="John",
advanced=False,
)
last_name: str | None = SchemaField(
description="Person's last name",
placeholder="Doe",
default=None,
advanced=False,
)
company_domain: str = SchemaField(
description="Domain of the company they work for (optional)",
placeholder="example.com",
advanced=False,
)
location: Optional[str] = SchemaField(
description="Person's location (optional)",
placeholder="San Francisco",
default=None,
)
title: Optional[str] = SchemaField(
description="Person's job title (optional)",
placeholder="CEO",
default=None,
)
include_similarity_checks: bool = SchemaField(
description="Include similarity checks",
default=False,
advanced=True,
)
enrich_profile: bool = SchemaField(
description="Enrich the profile with additional data",
default=False,
advanced=True,
)
credentials: EnrichlayerCredentialsInput = CredentialsField(
description="Enrichlayer API credentials"
)
class Output(BlockSchema):
"""Output schema for LinkedinPersonLookupBlock."""
lookup_result: PersonLookupResponse = SchemaField(
description="LinkedIn profile lookup result"
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
"""Initialize LinkedinPersonLookupBlock."""
super().__init__(
id="d237a98a-5c4b-4a1c-b9e3-e6f9a6c81df7",
description="Look up LinkedIn profiles by person information using Enrichlayer",
categories={BlockCategory.SOCIAL},
input_schema=LinkedinPersonLookupBlock.Input,
output_schema=LinkedinPersonLookupBlock.Output,
test_input={
"first_name": "Bill",
"last_name": "Gates",
"company_domain": "gatesfoundation.org",
"include_similarity_checks": True,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
(
"lookup_result",
PersonLookupResponse(
url="https://www.linkedin.com/in/williamhgates/",
name_similarity_score=0.93,
company_similarity_score=0.83,
title_similarity_score=0.3,
location_similarity_score=0.20,
),
)
],
test_credentials=TEST_CREDENTIALS,
test_mock={
"_lookup_person": lambda *args, **kwargs: PersonLookupResponse(
url="https://www.linkedin.com/in/williamhgates/",
name_similarity_score=0.93,
company_similarity_score=0.83,
title_similarity_score=0.3,
location_similarity_score=0.20,
)
},
)
@staticmethod
async def _lookup_person(
credentials: APIKeyCredentials,
first_name: str,
company_domain: str,
last_name: str | None = None,
location: Optional[str] = None,
title: Optional[str] = None,
include_similarity_checks: bool = False,
enrich_profile: bool = False,
):
client = EnrichlayerClient(credentials=credentials)
lookup_result = await client.lookup_person(
first_name=first_name,
last_name=last_name,
company_domain=company_domain,
location=location,
title=title,
include_similarity_checks=include_similarity_checks,
enrich_profile=enrich_profile,
)
return lookup_result
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
"""
Run the block to look up LinkedIn profiles.
Args:
input_data: Input parameters for the block
credentials: API key credentials for Enrichlayer
**kwargs: Additional keyword arguments
Yields:
Tuples of (output_name, output_value)
"""
try:
lookup_result = await self._lookup_person(
credentials=credentials,
first_name=input_data.first_name,
last_name=input_data.last_name,
company_domain=input_data.company_domain,
location=input_data.location,
title=input_data.title,
include_similarity_checks=input_data.include_similarity_checks,
enrich_profile=input_data.enrich_profile,
)
yield "lookup_result", lookup_result
except Exception as e:
logger.error(f"Error looking up LinkedIn profile: {str(e)}")
yield "error", str(e)
class LinkedinRoleLookupBlock(Block):
"""Block to look up LinkedIn profiles by role in a company using Enrichlayer API."""
class Input(BlockSchema):
"""Input schema for LinkedinRoleLookupBlock."""
role: str = SchemaField(
description="Role title (e.g., CEO, CTO)",
placeholder="CEO",
)
company_name: str = SchemaField(
description="Name of the company",
placeholder="Microsoft",
)
enrich_profile: bool = SchemaField(
description="Enrich the profile with additional data",
default=False,
advanced=True,
)
credentials: EnrichlayerCredentialsInput = CredentialsField(
description="Enrichlayer API credentials"
)
class Output(BlockSchema):
"""Output schema for LinkedinRoleLookupBlock."""
role_lookup_result: RoleLookupResponse = SchemaField(
description="LinkedIn role lookup result"
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
"""Initialize LinkedinRoleLookupBlock."""
super().__init__(
id="3b9fc742-06d4-49c7-b5ce-7e302dd7c8a7",
description="Look up LinkedIn profiles by role in a company using Enrichlayer",
categories={BlockCategory.SOCIAL},
input_schema=LinkedinRoleLookupBlock.Input,
output_schema=LinkedinRoleLookupBlock.Output,
test_input={
"role": "Co-chair",
"company_name": "Gates Foundation",
"enrich_profile": True,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
(
"role_lookup_result",
RoleLookupResponse(
linkedin_profile_url="https://www.linkedin.com/in/williamhgates/",
),
)
],
test_credentials=TEST_CREDENTIALS,
test_mock={
"_lookup_role": lambda *args, **kwargs: RoleLookupResponse(
linkedin_profile_url="https://www.linkedin.com/in/williamhgates/",
),
},
)
@staticmethod
async def _lookup_role(
credentials: APIKeyCredentials,
role: str,
company_name: str,
enrich_profile: bool = False,
):
client = EnrichlayerClient(credentials=credentials)
role_lookup_result = await client.lookup_role(
role=role,
company_name=company_name,
enrich_profile=enrich_profile,
)
return role_lookup_result
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
"""
Run the block to look up LinkedIn profiles by role.
Args:
input_data: Input parameters for the block
credentials: API key credentials for Enrichlayer
**kwargs: Additional keyword arguments
Yields:
Tuples of (output_name, output_value)
"""
try:
role_lookup_result = await self._lookup_role(
credentials=credentials,
role=input_data.role,
company_name=input_data.company_name,
enrich_profile=input_data.enrich_profile,
)
yield "role_lookup_result", role_lookup_result
except Exception as e:
logger.error(f"Error looking up role in company: {str(e)}")
yield "error", str(e)
class GetLinkedinProfilePictureBlock(Block):
"""Block to get LinkedIn profile pictures using Enrichlayer API."""
class Input(BlockSchema):
"""Input schema for GetLinkedinProfilePictureBlock."""
linkedin_profile_url: str = SchemaField(
description="LinkedIn profile URL",
placeholder="https://www.linkedin.com/in/username/",
)
credentials: EnrichlayerCredentialsInput = CredentialsField(
description="Enrichlayer API credentials"
)
class Output(BlockSchema):
"""Output schema for GetLinkedinProfilePictureBlock."""
profile_picture_url: MediaFileType = SchemaField(
description="LinkedIn profile picture URL"
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
"""Initialize GetLinkedinProfilePictureBlock."""
super().__init__(
id="68d5a942-9b3f-4e9a-b7c1-d96ea4321f0d",
description="Get LinkedIn profile pictures using Enrichlayer",
categories={BlockCategory.SOCIAL},
input_schema=GetLinkedinProfilePictureBlock.Input,
output_schema=GetLinkedinProfilePictureBlock.Output,
test_input={
"linkedin_profile_url": "https://www.linkedin.com/in/williamhgates/",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
(
"profile_picture_url",
"https://media.licdn.com/dms/image/C4D03AQFj-xjuXrLFSQ/profile-displayphoto-shrink_800_800/0/1576881858598?e=1686787200&v=beta&t=zrQC76QwsfQQIWthfOnrKRBMZ5D-qIAvzLXLmWgYvTk",
)
],
test_credentials=TEST_CREDENTIALS,
test_mock={
"_get_profile_picture": lambda *args, **kwargs: "https://media.licdn.com/dms/image/C4D03AQFj-xjuXrLFSQ/profile-displayphoto-shrink_800_800/0/1576881858598?e=1686787200&v=beta&t=zrQC76QwsfQQIWthfOnrKRBMZ5D-qIAvzLXLmWgYvTk",
},
)
@staticmethod
async def _get_profile_picture(
credentials: APIKeyCredentials, linkedin_profile_url: str
):
client = EnrichlayerClient(credentials=credentials)
profile_picture_response = await client.get_profile_picture(
linkedin_profile_url=linkedin_profile_url,
)
return profile_picture_response.profile_picture_url
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
"""
Run the block to get LinkedIn profile pictures.
Args:
input_data: Input parameters for the block
credentials: API key credentials for Enrichlayer
**kwargs: Additional keyword arguments
Yields:
Tuples of (output_name, output_value)
"""
try:
profile_picture = await self._get_profile_picture(
credentials=credentials,
linkedin_profile_url=input_data.linkedin_profile_url,
)
yield "profile_picture_url", profile_picture
except Exception as e:
logger.error(f"Error getting profile picture: {str(e)}")
yield "error", str(e)


@@ -0,0 +1,32 @@
from typing import Literal
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName
ExaCredentials = APIKeyCredentials
ExaCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.EXA],
Literal["api_key"],
]
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="exa",
api_key=SecretStr("mock-exa-api-key"),
title="Mock Exa API key",
expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}
def ExaCredentialsField() -> ExaCredentialsInput:
"""Creates an Exa credentials input on a block."""
return CredentialsField(description="The Exa integration requires an API Key.")


@@ -1,16 +0,0 @@
"""
Shared configuration for all Exa blocks using the new SDK pattern.
"""
from backend.sdk import BlockCostType, ProviderBuilder
from ._webhook import ExaWebhookManager
# Configure the Exa provider once for all blocks
exa = (
ProviderBuilder("exa")
.with_api_key("EXA_API_KEY", "Exa API Key")
.with_webhook_manager(ExaWebhookManager)
.with_base_cost(1, BlockCostType.RUN)
.build()
)
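# Illustrative sketch, not part of the original module: blocks built on the SDK
# pattern declare their credentials through the shared provider's credentials_field,
# as the SDK-based Exa blocks elsewhere in this diff do. The class and field names
# below are placeholders; `exa` refers to the provider configured just above.
from backend.sdk import BlockSchema, CredentialsMetaInput, SchemaField


class ExampleExaInput(BlockSchema):
    credentials: CredentialsMetaInput = exa.credentials_field(
        description="The Exa integration requires an API Key."
    )
    query: str = SchemaField(description="The search query")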


@@ -1,136 +0,0 @@
"""
Exa Webhook Manager implementation.
"""
import hashlib
import hmac
from enum import Enum
from backend.data.model import Credentials
from backend.sdk import (
APIKeyCredentials,
BaseWebhooksManager,
ProviderName,
Requests,
Webhook,
)
class ExaWebhookType(str, Enum):
"""Available webhook types for Exa."""
WEBSET = "webset"
class ExaEventType(str, Enum):
"""Available event types for Exa webhooks."""
WEBSET_CREATED = "webset.created"
WEBSET_DELETED = "webset.deleted"
WEBSET_PAUSED = "webset.paused"
WEBSET_IDLE = "webset.idle"
WEBSET_SEARCH_CREATED = "webset.search.created"
WEBSET_SEARCH_CANCELED = "webset.search.canceled"
WEBSET_SEARCH_COMPLETED = "webset.search.completed"
WEBSET_SEARCH_UPDATED = "webset.search.updated"
IMPORT_CREATED = "import.created"
IMPORT_COMPLETED = "import.completed"
IMPORT_PROCESSING = "import.processing"
WEBSET_ITEM_CREATED = "webset.item.created"
WEBSET_ITEM_ENRICHED = "webset.item.enriched"
WEBSET_EXPORT_CREATED = "webset.export.created"
WEBSET_EXPORT_COMPLETED = "webset.export.completed"
class ExaWebhookManager(BaseWebhooksManager):
"""Webhook manager for Exa API."""
PROVIDER_NAME = ProviderName("exa")
class WebhookType(str, Enum):
WEBSET = "webset"
@classmethod
async def validate_payload(
cls, webhook: Webhook, request, credentials: Credentials | None
) -> tuple[dict, str]:
"""Validate incoming webhook payload and signature."""
payload = await request.json()
# Get event type from payload
event_type = payload.get("eventType", "unknown")
# Verify webhook signature if secret is available
if webhook.secret:
signature = request.headers.get("X-Exa-Signature")
if signature:
# Compute expected signature
body = await request.body()
expected_signature = hmac.new(
webhook.secret.encode(), body, hashlib.sha256
).hexdigest()
# Compare signatures
if not hmac.compare_digest(signature, expected_signature):
raise ValueError("Invalid webhook signature")
return payload, event_type
async def _register_webhook(
self,
credentials: Credentials,
webhook_type: str,
resource: str,
events: list[str],
ingress_url: str,
secret: str,
) -> tuple[str, dict]:
"""Register webhook with Exa API."""
if not isinstance(credentials, APIKeyCredentials):
raise ValueError("Exa webhooks require API key credentials")
api_key = credentials.api_key.get_secret_value()
# Create webhook via Exa API
response = await Requests().post(
"https://api.exa.ai/v0/webhooks",
headers={"x-api-key": api_key},
json={
"url": ingress_url,
"events": events,
"metadata": {
"resource": resource,
"webhook_type": webhook_type,
},
},
)
if not response.ok:
error_data = response.json()
raise Exception(f"Failed to create Exa webhook: {error_data}")
webhook_data = response.json()
# Store the secret returned by Exa
return webhook_data["id"], {
"events": events,
"resource": resource,
"exa_secret": webhook_data.get("secret"),
}
async def _deregister_webhook(
self, webhook: Webhook, credentials: Credentials
) -> None:
"""Deregister webhook from Exa API."""
if not isinstance(credentials, APIKeyCredentials):
raise ValueError("Exa webhooks require API key credentials")
api_key = credentials.api_key.get_secret_value()
# Delete webhook via Exa API
response = await Requests().delete(
f"https://api.exa.ai/v0/webhooks/{webhook.provider_webhook_id}",
headers={"x-api-key": api_key},
)
if not response.ok and response.status != 404:
error_data = response.json()
raise Exception(f"Failed to delete Exa webhook: {error_data}")


@@ -1,121 +0,0 @@
from backend.sdk import (
APIKeyCredentials,
BaseModel,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
Requests,
SchemaField,
)
from ._config import exa
class CostBreakdown(BaseModel):
keywordSearch: float
neuralSearch: float
contentText: float
contentHighlight: float
contentSummary: float
class SearchBreakdown(BaseModel):
search: float
contents: float
breakdown: CostBreakdown
class PerRequestPrices(BaseModel):
neuralSearch_1_25_results: float
neuralSearch_26_100_results: float
neuralSearch_100_plus_results: float
keywordSearch_1_100_results: float
keywordSearch_100_plus_results: float
class PerPagePrices(BaseModel):
contentText: float
contentHighlight: float
contentSummary: float
class CostDollars(BaseModel):
total: float
breakDown: list[SearchBreakdown]
perRequestPrices: PerRequestPrices
perPagePrices: PerPagePrices
class ExaAnswerBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
query: str = SchemaField(
description="The question or query to answer",
placeholder="What is the latest valuation of SpaceX?",
)
text: bool = SchemaField(
default=False,
description="If true, the response includes full text content in the search results",
advanced=True,
)
model: str = SchemaField(
default="exa",
description="The search model to use (exa or exa-pro)",
placeholder="exa",
advanced=True,
)
class Output(BlockSchema):
answer: str = SchemaField(
description="The generated answer based on search results"
)
citations: list[dict] = SchemaField(
description="Search results used to generate the answer",
default_factory=list,
)
cost_dollars: CostDollars = SchemaField(
description="Cost breakdown of the request"
)
error: str = SchemaField(
description="Error message if the request failed", default=""
)
def __init__(self):
super().__init__(
id="b79ca4cc-9d5e-47d1-9d4f-e3a2d7f28df5",
description="Get an LLM answer to a question informed by Exa search results",
categories={BlockCategory.SEARCH, BlockCategory.AI},
input_schema=ExaAnswerBlock.Input,
output_schema=ExaAnswerBlock.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
url = "https://api.exa.ai/answer"
headers = {
"Content-Type": "application/json",
"x-api-key": credentials.api_key.get_secret_value(),
}
# Build the payload
payload = {
"query": input_data.query,
"text": input_data.text,
"model": input_data.model,
}
try:
response = await Requests().post(url, headers=headers, json=payload)
data = response.json()
yield "answer", data.get("answer", "")
yield "citations", data.get("citations", [])
yield "cost_dollars", data.get("costDollars", {})
except Exception as e:
yield "error", str(e)


@@ -1,39 +1,57 @@
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
Requests,
SchemaField,
)
from typing import List
from ._config import exa
from .helpers import ContentSettings
from pydantic import BaseModel
from backend.blocks.exa._auth import (
ExaCredentials,
ExaCredentialsField,
ExaCredentialsInput,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import Requests
class ContentRetrievalSettings(BaseModel):
text: dict = SchemaField(
description="Text content settings",
default={"maxCharacters": 1000, "includeHtmlTags": False},
advanced=True,
)
highlights: dict = SchemaField(
description="Highlight settings",
default={
"numSentences": 3,
"highlightsPerUrl": 3,
"query": "",
},
advanced=True,
)
summary: dict = SchemaField(
description="Summary settings",
default={"query": ""},
advanced=True,
)
class ExaContentsBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
credentials: ExaCredentialsInput = ExaCredentialsField()
ids: List[str] = SchemaField(
description="Array of document IDs obtained from searches",
)
ids: list[str] = SchemaField(
description="Array of document IDs obtained from searches"
)
contents: ContentSettings = SchemaField(
contents: ContentRetrievalSettings = SchemaField(
description="Content retrieval settings",
default=ContentSettings(),
default=ContentRetrievalSettings(),
advanced=True,
)
class Output(BlockSchema):
results: list = SchemaField(
description="List of document contents", default_factory=list
)
error: str = SchemaField(
description="Error message if the request failed", default=""
description="List of document contents",
default_factory=list,
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
@@ -45,7 +63,7 @@ class ExaContentsBlock(Block):
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
self, input_data: Input, *, credentials: ExaCredentials, **kwargs
) -> BlockOutput:
url = "https://api.exa.ai/contents"
headers = {
@@ -53,7 +71,6 @@ class ExaContentsBlock(Block):
"x-api-key": credentials.api_key.get_secret_value(),
}
# Convert ContentSettings to API format
payload = {
"ids": input_data.ids,
"text": input_data.contents.text,


@@ -1,6 +1,8 @@
from typing import Optional
from backend.sdk import BaseModel, SchemaField
from pydantic import BaseModel
from backend.data.model import SchemaField
class TextSettings(BaseModel):
@@ -40,90 +42,13 @@ class SummarySettings(BaseModel):
class ContentSettings(BaseModel):
text: TextSettings = SchemaField(
default=TextSettings(),
description="Text content settings",
)
highlights: HighlightSettings = SchemaField(
default=HighlightSettings(),
description="Highlight settings",
)
summary: SummarySettings = SchemaField(
default=SummarySettings(),
)
# Websets Models
class WebsetEntitySettings(BaseModel):
type: Optional[str] = SchemaField(
default=None,
description="Entity type (e.g., 'company', 'person')",
placeholder="company",
)
class WebsetCriterion(BaseModel):
description: str = SchemaField(
description="Description of the criterion",
placeholder="Must be based in the US",
)
success_rate: Optional[int] = SchemaField(
default=None,
description="Success rate percentage",
ge=0,
le=100,
)
class WebsetSearchConfig(BaseModel):
query: str = SchemaField(
description="Search query",
placeholder="Marketing agencies based in the US",
)
count: int = SchemaField(
default=10,
description="Number of results to return",
ge=1,
le=100,
)
entity: Optional[WebsetEntitySettings] = SchemaField(
default=None,
description="Entity settings for the search",
)
criteria: Optional[list[WebsetCriterion]] = SchemaField(
default=None,
description="Search criteria",
)
behavior: Optional[str] = SchemaField(
default="override",
description="Behavior when updating results ('override' or 'append')",
placeholder="override",
)
class EnrichmentOption(BaseModel):
label: str = SchemaField(
description="Label for the enrichment option",
placeholder="Option 1",
)
class WebsetEnrichmentConfig(BaseModel):
title: str = SchemaField(
description="Title of the enrichment",
placeholder="Company Details",
)
description: str = SchemaField(
description="Description of what this enrichment does",
placeholder="Extract company information",
)
format: str = SchemaField(
default="text",
description="Format of the enrichment result",
placeholder="text",
)
instructions: Optional[str] = SchemaField(
default=None,
description="Instructions for the enrichment",
placeholder="Extract key company metrics",
)
options: Optional[list[EnrichmentOption]] = SchemaField(
default=None,
description="Options for the enrichment",
description="Summary settings",
)


@@ -1,247 +0,0 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
# Enum definitions based on available options
class WebsetStatus(str, Enum):
IDLE = "idle"
PENDING = "pending"
RUNNING = "running"
PAUSED = "paused"
class WebsetSearchStatus(str, Enum):
CREATED = "created"
# Add more if known; based on the example payload, only "created" is observed
class ImportStatus(str, Enum):
PENDING = "pending"
# Add more if known
class ImportFormat(str, Enum):
CSV = "csv"
# Add more if known
class EnrichmentStatus(str, Enum):
PENDING = "pending"
# Add more if known
class EnrichmentFormat(str, Enum):
TEXT = "text"
# Add more if known
class MonitorStatus(str, Enum):
ENABLED = "enabled"
# Add more if known
class MonitorBehaviorType(str, Enum):
SEARCH = "search"
# Add more if known
class MonitorRunStatus(str, Enum):
CREATED = "created"
# Add more if known
class CanceledReason(str, Enum):
WEBSET_DELETED = "webset_deleted"
# Add more if known
class FailedReason(str, Enum):
INVALID_FORMAT = "invalid_format"
# Add more if known
class Confidence(str, Enum):
HIGH = "high"
# Add more if known
# Nested models
class Entity(BaseModel):
type: str
class Criterion(BaseModel):
description: str
successRate: Optional[int] = None
class ExcludeItem(BaseModel):
source: str = Field(default="import")
id: str
class Relationship(BaseModel):
definition: str
limit: Optional[float] = None
class ScopeItem(BaseModel):
source: str = Field(default="import")
id: str
relationship: Optional[Relationship] = None
class Progress(BaseModel):
found: int
analyzed: int
completion: int
timeLeft: int
class Bounds(BaseModel):
min: int
max: int
class Expected(BaseModel):
total: int
confidence: str = Field(default="high") # Use str or Confidence enum
bounds: Bounds
class Recall(BaseModel):
expected: Expected
reasoning: str
class WebsetSearch(BaseModel):
id: str
object: str = Field(default="webset_search")
status: str = Field(default="created") # Or use WebsetSearchStatus
websetId: str
query: str
entity: Entity
criteria: List[Criterion]
count: int
behavior: str = Field(default="override")
exclude: List[ExcludeItem]
scope: List[ScopeItem]
progress: Progress
recall: Recall
metadata: Dict[str, Any] = Field(default_factory=dict)
canceledAt: Optional[datetime] = None
canceledReason: Optional[str] = Field(default=None) # Or use CanceledReason
createdAt: datetime
updatedAt: datetime
class ImportEntity(BaseModel):
type: str
class Import(BaseModel):
id: str
object: str = Field(default="import")
status: str = Field(default="pending") # Or use ImportStatus
format: str = Field(default="csv") # Or use ImportFormat
entity: ImportEntity
title: str
count: int
metadata: Dict[str, Any] = Field(default_factory=dict)
failedReason: Optional[str] = Field(default=None) # Or use FailedReason
failedAt: Optional[datetime] = None
failedMessage: Optional[str] = None
createdAt: datetime
updatedAt: datetime
class Option(BaseModel):
label: str
class WebsetEnrichment(BaseModel):
id: str
object: str = Field(default="webset_enrichment")
status: str = Field(default="pending") # Or use EnrichmentStatus
websetId: str
title: str
description: str
format: str = Field(default="text") # Or use EnrichmentFormat
options: List[Option]
instructions: str
metadata: Dict[str, Any] = Field(default_factory=dict)
createdAt: datetime
updatedAt: datetime
class Cadence(BaseModel):
cron: str
timezone: str = Field(default="Etc/UTC")
class BehaviorConfig(BaseModel):
query: Optional[str] = None
criteria: Optional[List[Criterion]] = None
entity: Optional[Entity] = None
count: Optional[int] = None
behavior: Optional[str] = Field(default=None)
class Behavior(BaseModel):
type: str = Field(default="search") # Or use MonitorBehaviorType
config: BehaviorConfig
class MonitorRun(BaseModel):
id: str
object: str = Field(default="monitor_run")
status: str = Field(default="created") # Or use MonitorRunStatus
monitorId: str
type: str = Field(default="search")
completedAt: Optional[datetime] = None
failedAt: Optional[datetime] = None
failedReason: Optional[str] = None
canceledAt: Optional[datetime] = None
createdAt: datetime
updatedAt: datetime
class Monitor(BaseModel):
id: str
object: str = Field(default="monitor")
status: str = Field(default="enabled") # Or use MonitorStatus
websetId: str
cadence: Cadence
behavior: Behavior
lastRun: Optional[MonitorRun] = None
nextRunAt: Optional[datetime] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
createdAt: datetime
updatedAt: datetime
class Webset(BaseModel):
id: str
object: str = Field(default="webset")
status: WebsetStatus
externalId: Optional[str] = None
title: Optional[str] = None
searches: List[WebsetSearch]
imports: List[Import]
enrichments: List[WebsetEnrichment]
monitors: List[Monitor]
streams: List[Any]
createdAt: datetime
updatedAt: datetime
metadata: Dict[str, Any] = Field(default_factory=dict)
class ListWebsets(BaseModel):
data: List[Webset]
hasMore: bool
nextCursor: Optional[str] = None


@@ -1,61 +1,71 @@
from datetime import datetime
from typing import List
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
Requests,
SchemaField,
from backend.blocks.exa._auth import (
ExaCredentials,
ExaCredentialsField,
ExaCredentialsInput,
)
from ._config import exa
from .helpers import ContentSettings
from backend.blocks.exa.helpers import ContentSettings
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import Requests
class ExaSearchBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
credentials: ExaCredentialsInput = ExaCredentialsField()
query: str = SchemaField(description="The search query")
use_auto_prompt: bool = SchemaField(
description="Whether to use autoprompt", default=True, advanced=True
description="Whether to use autoprompt",
default=True,
advanced=True,
)
type: str = SchemaField(
description="Type of search",
default="",
advanced=True,
)
type: str = SchemaField(description="Type of search", default="", advanced=True)
category: str = SchemaField(
description="Category to search within", default="", advanced=True
description="Category to search within",
default="",
advanced=True,
)
number_of_results: int = SchemaField(
description="Number of results to return", default=10, advanced=True
description="Number of results to return",
default=10,
advanced=True,
)
include_domains: list[str] = SchemaField(
description="Domains to include in search", default_factory=list
include_domains: List[str] = SchemaField(
description="Domains to include in search",
default_factory=list,
)
exclude_domains: list[str] = SchemaField(
exclude_domains: List[str] = SchemaField(
description="Domains to exclude from search",
default_factory=list,
advanced=True,
)
start_crawl_date: datetime = SchemaField(
description="Start date for crawled content"
description="Start date for crawled content",
)
end_crawl_date: datetime = SchemaField(
description="End date for crawled content"
description="End date for crawled content",
)
start_published_date: datetime = SchemaField(
description="Start date for published content"
description="Start date for published content",
)
end_published_date: datetime = SchemaField(
description="End date for published content"
description="End date for published content",
)
include_text: list[str] = SchemaField(
description="Text patterns to include", default_factory=list, advanced=True
include_text: List[str] = SchemaField(
description="Text patterns to include",
default_factory=list,
advanced=True,
)
exclude_text: list[str] = SchemaField(
description="Text patterns to exclude", default_factory=list, advanced=True
exclude_text: List[str] = SchemaField(
description="Text patterns to exclude",
default_factory=list,
advanced=True,
)
contents: ContentSettings = SchemaField(
description="Content retrieval settings",
@@ -65,7 +75,8 @@ class ExaSearchBlock(Block):
class Output(BlockSchema):
results: list = SchemaField(
description="List of search results", default_factory=list
description="List of search results",
default_factory=list,
)
error: str = SchemaField(
description="Error message if the request failed",
@@ -81,7 +92,7 @@ class ExaSearchBlock(Block):
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
self, input_data: Input, *, credentials: ExaCredentials, **kwargs
) -> BlockOutput:
url = "https://api.exa.ai/search"
headers = {
@@ -93,7 +104,7 @@ class ExaSearchBlock(Block):
"query": input_data.query,
"useAutoprompt": input_data.use_auto_prompt,
"numResults": input_data.number_of_results,
"contents": input_data.contents.model_dump(),
"contents": input_data.contents.dict(),
}
date_field_mapping = {


@@ -1,60 +1,57 @@
from datetime import datetime
from typing import Any
from typing import Any, List
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
Requests,
SchemaField,
from backend.blocks.exa._auth import (
ExaCredentials,
ExaCredentialsField,
ExaCredentialsInput,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import Requests
from ._config import exa
from .helpers import ContentSettings
class ExaFindSimilarBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
credentials: ExaCredentialsInput = ExaCredentialsField()
url: str = SchemaField(
description="The url for which you would like to find similar links"
)
number_of_results: int = SchemaField(
description="Number of results to return", default=10, advanced=True
description="Number of results to return",
default=10,
advanced=True,
)
include_domains: list[str] = SchemaField(
include_domains: List[str] = SchemaField(
description="Domains to include in search",
default_factory=list,
advanced=True,
)
exclude_domains: list[str] = SchemaField(
exclude_domains: List[str] = SchemaField(
description="Domains to exclude from search",
default_factory=list,
advanced=True,
)
start_crawl_date: datetime = SchemaField(
description="Start date for crawled content"
description="Start date for crawled content",
)
end_crawl_date: datetime = SchemaField(
description="End date for crawled content"
description="End date for crawled content",
)
start_published_date: datetime = SchemaField(
description="Start date for published content"
description="Start date for published content",
)
end_published_date: datetime = SchemaField(
description="End date for published content"
description="End date for published content",
)
include_text: list[str] = SchemaField(
include_text: List[str] = SchemaField(
description="Text patterns to include (max 1 string, up to 5 words)",
default_factory=list,
advanced=True,
)
exclude_text: list[str] = SchemaField(
exclude_text: List[str] = SchemaField(
description="Text patterns to exclude (max 1 string, up to 5 words)",
default_factory=list,
advanced=True,
@@ -66,13 +63,11 @@ class ExaFindSimilarBlock(Block):
)
class Output(BlockSchema):
results: list[Any] = SchemaField(
results: List[Any] = SchemaField(
description="List of similar documents with title, URL, published date, author, and score",
default_factory=list,
)
error: str = SchemaField(
description="Error message if the request failed", default=""
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
@@ -84,7 +79,7 @@ class ExaFindSimilarBlock(Block):
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
self, input_data: Input, *, credentials: ExaCredentials, **kwargs
) -> BlockOutput:
url = "https://api.exa.ai/findSimilar"
headers = {
@@ -95,7 +90,7 @@ class ExaFindSimilarBlock(Block):
payload = {
"url": input_data.url,
"numResults": input_data.number_of_results,
"contents": input_data.contents.model_dump(),
"contents": input_data.contents.dict(),
}
optional_field_mapping = {


@@ -1,202 +0,0 @@
"""
Exa Webhook Blocks
These blocks handle webhook events from Exa's API for websets and other events.
"""
from backend.sdk import (
BaseModel,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockType,
BlockWebhookConfig,
CredentialsMetaInput,
Field,
ProviderName,
SchemaField,
)
from ._config import exa
from ._webhook import ExaEventType
class WebsetEventFilter(BaseModel):
"""Filter configuration for Exa webset events."""
webset_created: bool = Field(
default=True, description="Receive notifications when websets are created"
)
webset_deleted: bool = Field(
default=False, description="Receive notifications when websets are deleted"
)
webset_paused: bool = Field(
default=False, description="Receive notifications when websets are paused"
)
webset_idle: bool = Field(
default=False, description="Receive notifications when websets become idle"
)
search_created: bool = Field(
default=True,
description="Receive notifications when webset searches are created",
)
search_completed: bool = Field(
default=True, description="Receive notifications when webset searches complete"
)
search_canceled: bool = Field(
default=False,
description="Receive notifications when webset searches are canceled",
)
search_updated: bool = Field(
default=False,
description="Receive notifications when webset searches are updated",
)
item_created: bool = Field(
default=True, description="Receive notifications when webset items are created"
)
item_enriched: bool = Field(
default=True, description="Receive notifications when webset items are enriched"
)
export_created: bool = Field(
default=False,
description="Receive notifications when webset exports are created",
)
export_completed: bool = Field(
default=True, description="Receive notifications when webset exports complete"
)
import_created: bool = Field(
default=False, description="Receive notifications when imports are created"
)
import_completed: bool = Field(
default=True, description="Receive notifications when imports complete"
)
import_processing: bool = Field(
default=False, description="Receive notifications when imports are processing"
)
class ExaWebsetWebhookBlock(Block):
"""
Receives webhook notifications for Exa webset events.
This block allows you to monitor various events related to Exa websets,
including creation, updates, searches, and exports.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = exa.credentials_field(
description="Exa API credentials for webhook management"
)
webhook_url: str = SchemaField(
description="URL to receive webhooks (auto-generated)",
default="",
hidden=True,
)
webset_id: str = SchemaField(
description="The webset ID to monitor (optional, monitors all if empty)",
default="",
)
event_filter: WebsetEventFilter = SchemaField(
description="Configure which events to receive", default=WebsetEventFilter()
)
payload: dict = SchemaField(
description="Webhook payload data", default={}, hidden=True
)
class Output(BlockSchema):
event_type: str = SchemaField(description="Type of event that occurred")
event_id: str = SchemaField(description="Unique identifier for this event")
webset_id: str = SchemaField(description="ID of the affected webset")
data: dict = SchemaField(description="Event-specific data")
timestamp: str = SchemaField(description="When the event occurred")
metadata: dict = SchemaField(description="Additional event metadata")
def __init__(self):
super().__init__(
disabled=True,
id="d0204ed8-8b81-408d-8b8d-ed087a546228",
description="Receive webhook notifications for Exa webset events",
categories={BlockCategory.INPUT},
input_schema=ExaWebsetWebhookBlock.Input,
output_schema=ExaWebsetWebhookBlock.Output,
block_type=BlockType.WEBHOOK,
webhook_config=BlockWebhookConfig(
provider=ProviderName("exa"),
webhook_type="webset",
event_filter_input="event_filter",
resource_format="{webset_id}",
),
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
"""Process incoming Exa webhook payload."""
try:
payload = input_data.payload
# Extract event details
event_type = payload.get("eventType", "unknown")
event_id = payload.get("eventId", "")
# Get webset ID from payload or input
webset_id = payload.get("websetId", input_data.webset_id)
# Check if we should process this event based on filter
should_process = self._should_process_event(
event_type, input_data.event_filter
)
if not should_process:
# Skip events that don't match our filter
return
# Extract event data
event_data = payload.get("data", {})
timestamp = payload.get("occurredAt", payload.get("createdAt", ""))
metadata = payload.get("metadata", {})
yield "event_type", event_type
yield "event_id", event_id
yield "webset_id", webset_id
yield "data", event_data
yield "timestamp", timestamp
yield "metadata", metadata
except Exception as e:
# Handle errors gracefully
yield "event_type", "error"
yield "event_id", ""
yield "webset_id", input_data.webset_id
yield "data", {"error": str(e)}
yield "timestamp", ""
yield "metadata", {}
def _should_process_event(
self, event_type: str, event_filter: WebsetEventFilter
) -> bool:
"""Check if an event should be processed based on the filter."""
filter_mapping = {
ExaEventType.WEBSET_CREATED: event_filter.webset_created,
ExaEventType.WEBSET_DELETED: event_filter.webset_deleted,
ExaEventType.WEBSET_PAUSED: event_filter.webset_paused,
ExaEventType.WEBSET_IDLE: event_filter.webset_idle,
ExaEventType.WEBSET_SEARCH_CREATED: event_filter.search_created,
ExaEventType.WEBSET_SEARCH_COMPLETED: event_filter.search_completed,
ExaEventType.WEBSET_SEARCH_CANCELED: event_filter.search_canceled,
ExaEventType.WEBSET_SEARCH_UPDATED: event_filter.search_updated,
ExaEventType.WEBSET_ITEM_CREATED: event_filter.item_created,
ExaEventType.WEBSET_ITEM_ENRICHED: event_filter.item_enriched,
ExaEventType.WEBSET_EXPORT_CREATED: event_filter.export_created,
ExaEventType.WEBSET_EXPORT_COMPLETED: event_filter.export_completed,
ExaEventType.IMPORT_CREATED: event_filter.import_created,
ExaEventType.IMPORT_COMPLETED: event_filter.import_completed,
ExaEventType.IMPORT_PROCESSING: event_filter.import_processing,
}
# Try to convert string to ExaEventType enum
try:
event_type_enum = ExaEventType(event_type)
return filter_mapping.get(event_type_enum, True)
except ValueError:
# If event_type is not a valid enum value, process it by default
return True
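# Illustrative sketch, not part of the original module: how the default filter gates
# events. _should_process_event never touches instance state, so passing None as
# `self` below is purely an illustration shortcut.
_default_filter = WebsetEventFilter()
assert ExaWebsetWebhookBlock._should_process_event(
    None, "webset.item.created", _default_filter
)  # item_created defaults to True
assert not ExaWebsetWebhookBlock._should_process_event(
    None, "webset.deleted", _default_filter
)  # webset_deleted defaults to False
assert ExaWebsetWebhookBlock._should_process_event(
    None, "not.a.known.event", _default_filter
)  # unknown event types are processed by default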


@@ -1,752 +0,0 @@
from datetime import datetime
from enum import Enum
from typing import Annotated, Any, Dict, List, Optional
from exa_py import Exa
from exa_py.websets.types import (
CreateCriterionParameters,
CreateEnrichmentParameters,
CreateWebsetParameters,
CreateWebsetParametersSearch,
ExcludeItem,
Format,
ImportItem,
ImportSource,
Option,
ScopeItem,
ScopeRelationship,
ScopeSourceType,
WebsetArticleEntity,
WebsetCompanyEntity,
WebsetCustomEntity,
WebsetPersonEntity,
WebsetResearchPaperEntity,
WebsetStatus,
)
from pydantic import Field
from backend.sdk import (
APIKeyCredentials,
BaseModel,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
Requests,
SchemaField,
)
from ._config import exa
class SearchEntityType(str, Enum):
COMPANY = "company"
PERSON = "person"
ARTICLE = "article"
RESEARCH_PAPER = "research_paper"
CUSTOM = "custom"
AUTO = "auto"
class SearchType(str, Enum):
IMPORT = "import"
WEBSET = "webset"
class EnrichmentFormat(str, Enum):
TEXT = "text"
DATE = "date"
NUMBER = "number"
OPTIONS = "options"
EMAIL = "email"
PHONE = "phone"
class Webset(BaseModel):
id: str
status: WebsetStatus | None = Field(..., title="WebsetStatus")
"""
The status of the webset
"""
external_id: Annotated[Optional[str], Field(alias="externalId")] = None
"""
The external identifier for the webset
"""
searches: List[dict[str, Any]] | None = None
"""
The searches that have been performed on the webset.
NOTE: Returning dict to avoid ui crashing due to nested objects
"""
enrichments: List[dict[str, Any]] | None = None
"""
The Enrichments to apply to the Webset Items.
NOTE: Returning dict to avoid ui crashing due to nested objects
"""
monitors: List[dict[str, Any]] | None = None
"""
The Monitors for the Webset.
NOTE: Returning dict to avoid ui crashing due to nested objects
"""
metadata: Optional[Dict[str, Any]] = {}
"""
Set of key-value pairs you want to associate with this object.
"""
created_at: Annotated[datetime, Field(alias="createdAt")] | None = None
"""
The date and time the webset was created
"""
updated_at: Annotated[datetime, Field(alias="updatedAt")] | None = None
"""
The date and time the webset was last updated
"""
class ExaCreateWebsetBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
# Search parameters (flattened)
search_query: str = SchemaField(
description="Your search query. Use this to describe what you are looking for. Any URL provided will be crawled and used as context for the search.",
placeholder="Marketing agencies based in the US, that focus on consumer products",
)
search_count: Optional[int] = SchemaField(
default=10,
description="Number of items the search will attempt to find. The actual number of items found may be less than this number depending on the search complexity.",
ge=1,
le=1000,
)
search_entity_type: SearchEntityType = SchemaField(
default=SearchEntityType.AUTO,
description="Entity type: 'company', 'person', 'article', 'research_paper', or 'custom'. If not provided, we automatically detect the entity from the query.",
advanced=True,
)
search_entity_description: Optional[str] = SchemaField(
default=None,
description="Description for custom entity type (required when search_entity_type is 'custom')",
advanced=True,
)
# Search criteria (flattened)
search_criteria: list[str] = SchemaField(
default_factory=list,
description="List of criteria descriptions that every item will be evaluated against. If not provided, we automatically detect the criteria from the query.",
advanced=True,
)
# Search exclude sources (flattened)
search_exclude_sources: list[str] = SchemaField(
default_factory=list,
description="List of source IDs (imports or websets) to exclude from search results",
advanced=True,
)
search_exclude_types: list[SearchType] = SchemaField(
default_factory=list,
description="List of source types corresponding to exclude sources ('import' or 'webset')",
advanced=True,
)
# Search scope sources (flattened)
search_scope_sources: list[str] = SchemaField(
default_factory=list,
description="List of source IDs (imports or websets) to limit search scope to",
advanced=True,
)
search_scope_types: list[SearchType] = SchemaField(
default_factory=list,
description="List of source types corresponding to scope sources ('import' or 'webset')",
advanced=True,
)
search_scope_relationships: list[str] = SchemaField(
default_factory=list,
description="List of relationship definitions for hop searches (optional, one per scope source)",
advanced=True,
)
search_scope_relationship_limits: list[int] = SchemaField(
default_factory=list,
description="List of limits on the number of related entities to find (optional, one per scope relationship)",
advanced=True,
)
# Import parameters (flattened)
import_sources: list[str] = SchemaField(
default_factory=list,
description="List of source IDs to import from",
advanced=True,
)
import_types: list[SearchType] = SchemaField(
default_factory=list,
description="List of source types corresponding to import sources ('import' or 'webset')",
advanced=True,
)
# Enrichment parameters (flattened)
enrichment_descriptions: list[str] = SchemaField(
default_factory=list,
description="List of enrichment task descriptions to perform on each webset item",
advanced=True,
)
enrichment_formats: list[EnrichmentFormat] = SchemaField(
default_factory=list,
description="List of formats for enrichment responses ('text', 'date', 'number', 'options', 'email', 'phone'). If not specified, we automatically select the best format.",
advanced=True,
)
enrichment_options: list[list[str]] = SchemaField(
default_factory=list,
description="List of option lists for enrichments with 'options' format. Each inner list contains the option labels.",
advanced=True,
)
enrichment_metadata: list[dict] = SchemaField(
default_factory=list,
description="List of metadata dictionaries for enrichments",
advanced=True,
)
# Webset metadata
external_id: Optional[str] = SchemaField(
default=None,
description="External identifier for the webset. You can use this to reference the webset by your own internal identifiers.",
placeholder="my-webset-123",
advanced=True,
)
metadata: Optional[dict] = SchemaField(
default_factory=dict,
description="Key-value pairs to associate with this webset",
advanced=True,
)
class Output(BlockSchema):
webset: Webset = SchemaField(
description="The newly created webset"
)
def __init__(self):
super().__init__(
id="0cda29ff-c549-4a19-8805-c982b7d4ec34",
description="Create a new Exa Webset for persistent web search collections",
categories={BlockCategory.SEARCH},
input_schema=ExaCreateWebsetBlock.Input,
output_schema=ExaCreateWebsetBlock.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
exa = Exa(credentials.api_key.get_secret_value())
# ------------------------------------------------------------
# Build entity (if explicitly provided)
# ------------------------------------------------------------
entity = None
if input_data.search_entity_type == SearchEntityType.COMPANY:
entity = WebsetCompanyEntity(type="company")
elif input_data.search_entity_type == SearchEntityType.PERSON:
entity = WebsetPersonEntity(type="person")
elif input_data.search_entity_type == SearchEntityType.ARTICLE:
entity = WebsetArticleEntity(type="article")
elif input_data.search_entity_type == SearchEntityType.RESEARCH_PAPER:
entity = WebsetResearchPaperEntity(type="research_paper")
elif (
input_data.search_entity_type == SearchEntityType.CUSTOM
and input_data.search_entity_description
):
entity = WebsetCustomEntity(
type="custom", description=input_data.search_entity_description
)
# ------------------------------------------------------------
# Build criteria list
# ------------------------------------------------------------
criteria = None
if input_data.search_criteria:
criteria = [
CreateCriterionParameters(description=item)
for item in input_data.search_criteria
]
# ------------------------------------------------------------
# Build exclude sources list
# ------------------------------------------------------------
exclude_items = None
if input_data.search_exclude_sources:
exclude_items = []
for idx, src_id in enumerate(input_data.search_exclude_sources):
src_type = None
if input_data.search_exclude_types and idx < len(
input_data.search_exclude_types
):
src_type = input_data.search_exclude_types[idx]
# Use WEBSET only when explicitly specified; otherwise default to IMPORT
if src_type == SearchType.WEBSET:
source_enum = ImportSource.webset
else:
source_enum = ImportSource.import_
exclude_items.append(ExcludeItem(source=source_enum, id=src_id))
# ------------------------------------------------------------
# Build scope list
# ------------------------------------------------------------
scope_items = None
if input_data.search_scope_sources:
scope_items = []
for idx, src_id in enumerate(input_data.search_scope_sources):
src_type = None
if input_data.search_scope_types and idx < len(
input_data.search_scope_types
):
src_type = input_data.search_scope_types[idx]
relationship = None
if input_data.search_scope_relationships and idx < len(
input_data.search_scope_relationships
):
rel_def = input_data.search_scope_relationships[idx]
lim = None
if input_data.search_scope_relationship_limits and idx < len(
input_data.search_scope_relationship_limits
):
lim = input_data.search_scope_relationship_limits[idx]
relationship = ScopeRelationship(definition=rel_def, limit=lim)
if src_type == SearchType.WEBSET:
src_enum = ScopeSourceType.webset
else:
src_enum = ScopeSourceType.import_
scope_items.append(
ScopeItem(source=src_enum, id=src_id, relationship=relationship)
)
# ------------------------------------------------------------
# Assemble search parameters (only if a query is provided)
# ------------------------------------------------------------
search_params = None
if input_data.search_query:
search_params = CreateWebsetParametersSearch(
query=input_data.search_query,
count=input_data.search_count,
entity=entity,
criteria=criteria,
exclude=exclude_items,
scope=scope_items,
)
# ------------------------------------------------------------
# Build imports list
# ------------------------------------------------------------
imports_params = None
if input_data.import_sources:
imports_params = []
for idx, src_id in enumerate(input_data.import_sources):
src_type = None
if input_data.import_types and idx < len(input_data.import_types):
src_type = input_data.import_types[idx]
if src_type == SearchType.WEBSET:
source_enum = ImportSource.webset
else:
source_enum = ImportSource.import_
imports_params.append(ImportItem(source=source_enum, id=src_id))
# ------------------------------------------------------------
# Build enrichment list
# ------------------------------------------------------------
enrichments_params = None
if input_data.enrichment_descriptions:
enrichments_params = []
for idx, desc in enumerate(input_data.enrichment_descriptions):
fmt = None
if input_data.enrichment_formats and idx < len(
input_data.enrichment_formats
):
fmt_enum = input_data.enrichment_formats[idx]
if fmt_enum is not None:
fmt = Format(
fmt_enum.value if isinstance(fmt_enum, Enum) else fmt_enum
)
options_list = None
if input_data.enrichment_options and idx < len(
input_data.enrichment_options
):
raw_opts = input_data.enrichment_options[idx]
if raw_opts:
options_list = [Option(label=o) for o in raw_opts]
metadata_obj = None
if input_data.enrichment_metadata and idx < len(
input_data.enrichment_metadata
):
metadata_obj = input_data.enrichment_metadata[idx]
enrichments_params.append(
CreateEnrichmentParameters(
description=desc,
format=fmt,
options=options_list,
metadata=metadata_obj,
)
)
# ------------------------------------------------------------
# Create the webset
# ------------------------------------------------------------
webset = exa.websets.create(
params=CreateWebsetParameters(
search=search_params,
imports=imports_params,
enrichments=enrichments_params,
external_id=input_data.external_id,
metadata=input_data.metadata,
)
)
# Use alias field names returned from Exa SDK so that nested models validate correctly
yield "webset", Webset.model_validate(webset.model_dump(by_alias=True))
class ExaUpdateWebsetBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
webset_id: str = SchemaField(
description="The ID or external ID of the Webset to update",
placeholder="webset-id-or-external-id",
)
metadata: Optional[dict] = SchemaField(
default=None,
description="Key-value pairs to associate with this webset (set to null to clear)",
)
class Output(BlockSchema):
webset_id: str = SchemaField(description="The unique identifier for the webset")
status: str = SchemaField(description="The status of the webset")
external_id: Optional[str] = SchemaField(
description="The external identifier for the webset", default=None
)
metadata: dict = SchemaField(
description="Updated metadata for the webset", default_factory=dict
)
updated_at: str = SchemaField(
description="The date and time the webset was updated"
)
error: str = SchemaField(
description="Error message if the request failed", default=""
)
def __init__(self):
super().__init__(
id="89ccd99a-3c2b-4fbf-9e25-0ffa398d0314",
description="Update metadata for an existing Webset",
categories={BlockCategory.SEARCH},
input_schema=ExaUpdateWebsetBlock.Input,
output_schema=ExaUpdateWebsetBlock.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
url = f"https://api.exa.ai/websets/v0/websets/{input_data.webset_id}"
headers = {
"Content-Type": "application/json",
"x-api-key": credentials.api_key.get_secret_value(),
}
# Build the payload
payload = {}
if input_data.metadata is not None:
payload["metadata"] = input_data.metadata
try:
response = await Requests().post(url, headers=headers, json=payload)
data = response.json()
yield "webset_id", data.get("id", "")
yield "status", data.get("status", "")
yield "external_id", data.get("externalId")
yield "metadata", data.get("metadata", {})
yield "updated_at", data.get("updatedAt", "")
except Exception as e:
yield "error", str(e)
yield "webset_id", ""
yield "status", ""
yield "metadata", {}
yield "updated_at", ""
class ExaListWebsetsBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
trigger: Any | None = SchemaField(
default=None,
description="Trigger input used to run this block; the value itself is ignored",
advanced=False,
)
cursor: Optional[str] = SchemaField(
default=None,
description="Cursor for pagination through results",
advanced=True,
)
limit: int = SchemaField(
default=25,
description="Number of websets to return (1-100)",
ge=1,
le=100,
advanced=True,
)
class Output(BlockSchema):
websets: list[Webset] = SchemaField(
description="List of websets", default_factory=list
)
has_more: bool = SchemaField(
description="Whether there are more results to paginate through",
default=False,
)
next_cursor: Optional[str] = SchemaField(
description="Cursor for the next page of results", default=None
)
error: str = SchemaField(
description="Error message if the request failed", default=""
)
def __init__(self):
super().__init__(
id="1dcd8fd6-c13f-4e6f-bd4c-654428fa4757",
description="List all Websets with pagination support",
categories={BlockCategory.SEARCH},
input_schema=ExaListWebsetsBlock.Input,
output_schema=ExaListWebsetsBlock.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
url = "https://api.exa.ai/websets/v0/websets"
headers = {
"x-api-key": credentials.api_key.get_secret_value(),
}
params: dict[str, Any] = {
"limit": input_data.limit,
}
if input_data.cursor:
params["cursor"] = input_data.cursor
try:
response = await Requests().get(url, headers=headers, params=params)
data = response.json()
yield "websets", data.get("data", [])
yield "has_more", data.get("hasMore", False)
yield "next_cursor", data.get("nextCursor")
except Exception as e:
yield "error", str(e)
yield "websets", []
yield "has_more", False
class ExaGetWebsetBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
webset_id: str = SchemaField(
description="The ID or external ID of the Webset to retrieve",
placeholder="webset-id-or-external-id",
)
class Output(BlockSchema):
webset_id: str = SchemaField(description="The unique identifier for the webset")
status: str = SchemaField(description="The status of the webset")
external_id: Optional[str] = SchemaField(
description="The external identifier for the webset", default=None
)
searches: list[dict] = SchemaField(
description="The searches performed on the webset", default_factory=list
)
enrichments: list[dict] = SchemaField(
description="The enrichments applied to the webset", default_factory=list
)
monitors: list[dict] = SchemaField(
description="The monitors for the webset", default_factory=list
)
items: Optional[list[dict]] = SchemaField(
description="The items in the webset (only present when the API response includes expanded items)",
default=None,
)
metadata: dict = SchemaField(
description="Key-value pairs associated with the webset",
default_factory=dict,
)
created_at: str = SchemaField(
description="The date and time the webset was created"
)
updated_at: str = SchemaField(
description="The date and time the webset was last updated"
)
error: str = SchemaField(
description="Error message if the request failed", default=""
)
def __init__(self):
super().__init__(
id="6ab8e12a-132c-41bf-b5f3-d662620fa832",
description="Retrieve a Webset by ID or external ID",
categories={BlockCategory.SEARCH},
input_schema=ExaGetWebsetBlock.Input,
output_schema=ExaGetWebsetBlock.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
url = f"https://api.exa.ai/websets/v0/websets/{input_data.webset_id}"
headers = {
"x-api-key": credentials.api_key.get_secret_value(),
}
try:
response = await Requests().get(url, headers=headers)
data = response.json()
yield "webset_id", data.get("id", "")
yield "status", data.get("status", "")
yield "external_id", data.get("externalId")
yield "searches", data.get("searches", [])
yield "enrichments", data.get("enrichments", [])
yield "monitors", data.get("monitors", [])
yield "items", data.get("items")
yield "metadata", data.get("metadata", {})
yield "created_at", data.get("createdAt", "")
yield "updated_at", data.get("updatedAt", "")
except Exception as e:
yield "error", str(e)
yield "webset_id", ""
yield "status", ""
yield "searches", []
yield "enrichments", []
yield "monitors", []
yield "metadata", {}
yield "created_at", ""
yield "updated_at", ""
class ExaDeleteWebsetBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
webset_id: str = SchemaField(
description="The ID or external ID of the Webset to delete",
placeholder="webset-id-or-external-id",
)
class Output(BlockSchema):
webset_id: str = SchemaField(
description="The unique identifier for the deleted webset"
)
external_id: Optional[str] = SchemaField(
description="The external identifier for the deleted webset", default=None
)
status: str = SchemaField(description="The status of the deleted webset")
success: str = SchemaField(
description="Whether the deletion was successful", default="true"
)
error: str = SchemaField(
description="Error message if the request failed", default=""
)
def __init__(self):
super().__init__(
id="aa6994a2-e986-421f-8d4c-7671d3be7b7e",
description="Delete a Webset and all its items",
categories={BlockCategory.SEARCH},
input_schema=ExaDeleteWebsetBlock.Input,
output_schema=ExaDeleteWebsetBlock.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
url = f"https://api.exa.ai/websets/v0/websets/{input_data.webset_id}"
headers = {
"x-api-key": credentials.api_key.get_secret_value(),
}
try:
response = await Requests().delete(url, headers=headers)
data = response.json()
yield "webset_id", data.get("id", "")
yield "external_id", data.get("externalId")
yield "status", data.get("status", "")
yield "success", "true"
except Exception as e:
yield "error", str(e)
yield "webset_id", ""
yield "status", ""
yield "success", "false"
class ExaCancelWebsetBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
webset_id: str = SchemaField(
description="The ID or external ID of the Webset to cancel",
placeholder="webset-id-or-external-id",
)
class Output(BlockSchema):
webset_id: str = SchemaField(description="The unique identifier for the webset")
status: str = SchemaField(
description="The status of the webset after cancellation"
)
external_id: Optional[str] = SchemaField(
description="The external identifier for the webset", default=None
)
success: str = SchemaField(
description="Whether the cancellation was successful", default="true"
)
error: str = SchemaField(
description="Error message if the request failed", default=""
)
def __init__(self):
super().__init__(
id="e40a6420-1db8-47bb-b00a-0e6aecd74176",
description="Cancel all operations being performed on a Webset",
categories={BlockCategory.SEARCH},
input_schema=ExaCancelWebsetBlock.Input,
output_schema=ExaCancelWebsetBlock.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
url = f"https://api.exa.ai/websets/v0/websets/{input_data.webset_id}/cancel"
headers = {
"x-api-key": credentials.api_key.get_secret_value(),
}
try:
response = await Requests().post(url, headers=headers)
data = response.json()
yield "webset_id", data.get("id", "")
yield "status", data.get("status", "")
yield "external_id", data.get("externalId")
yield "success", "true"
except Exception as e:
yield "error", str(e)
yield "webset_id", ""
yield "status", ""
yield "success", "false"
