Compare commits

..

1 Commit

Author | SHA1 | Message | Date
SwiftyOS | 73f3464c6f | added a summary of the current system design | 2025-07-11 09:55:41 +02:00
824 changed files with 37074 additions and 45146 deletions

View File

@@ -9,13 +9,11 @@
# Platform - Backend
!autogpt_platform/backend/backend/
!autogpt_platform/backend/test/e2e_test_data.py
!autogpt_platform/backend/migrations/
!autogpt_platform/backend/schema.prisma
!autogpt_platform/backend/pyproject.toml
!autogpt_platform/backend/poetry.lock
!autogpt_platform/backend/README.md
!autogpt_platform/backend/.env
# Platform - Market
!autogpt_platform/market/market/
@@ -28,7 +26,6 @@
# Platform - Frontend
!autogpt_platform/frontend/src/
!autogpt_platform/frontend/public/
!autogpt_platform/frontend/scripts/
!autogpt_platform/frontend/package.json
!autogpt_platform/frontend/pnpm-lock.yaml
!autogpt_platform/frontend/tsconfig.json
@@ -36,7 +33,6 @@
## config
!autogpt_platform/frontend/*.config.*
!autogpt_platform/frontend/.env.*
!autogpt_platform/frontend/.env
# Classic - AutoGPT
!classic/original_autogpt/autogpt/

View File

@@ -24,8 +24,7 @@
</details>
#### For configuration changes:
- [ ] `.env.default` is updated or already compatible with my changes
- [ ] `.env.example` is updated or already compatible with my changes
- [ ] `docker-compose.yml` is updated or already compatible with my changes
- [ ] I have included a list of my configuration changes in the PR description (under **Changes**)

View File

@@ -37,7 +37,7 @@ jobs:
- name: Generate cache key
id: cache-key
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('**/pnpm-lock.yaml') }}" >> $GITHUB_OUTPUT
- name: Cache dependencies
uses: actions/cache@v4
@@ -45,7 +45,6 @@ jobs:
path: ~/.pnpm-store
key: ${{ steps.cache-key.outputs.key }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install dependencies
@@ -73,7 +72,6 @@ jobs:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install dependencies
@@ -82,6 +80,36 @@ jobs:
- name: Run lint
run: pnpm lint
type-check:
runs-on: ubuntu-latest
needs: setup
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "21"
- name: Enable corepack
run: corepack enable
- name: Restore dependencies cache
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
restore-keys: |
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Run tsc check
run: pnpm type-check
chromatic:
runs-on: ubuntu-latest
needs: setup
@@ -108,7 +136,6 @@ jobs:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install dependencies
@@ -124,10 +151,12 @@ jobs:
exitOnceUploaded: true
test:
runs-on: big-boi
runs-on: ubuntu-latest
needs: setup
strategy:
fail-fast: false
matrix:
browser: [chromium, webkit]
steps:
- name: Checkout repository
@@ -143,67 +172,23 @@ jobs:
- name: Enable corepack
run: corepack enable
- name: Free Disk Space (Ubuntu)
uses: jlumbroso/free-disk-space@main
with:
large-packages: false # slow
docker-images: false # limited benefit
- name: Copy default supabase .env
run: |
cp ../.env.default ../.env
cp ../.env.example ../.env
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Cache Docker layers
uses: actions/cache@v4
with:
path: /tmp/.buildx-cache
key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }}
restore-keys: |
${{ runner.os }}-buildx-frontend-test-
- name: Copy backend .env
run: |
cp ../backend/.env.example ../backend/.env
- name: Run docker compose
run: |
docker compose -f ../docker-compose.yml up -d
env:
DOCKER_BUILDKIT: 1
BUILDX_CACHE_FROM: type=local,src=/tmp/.buildx-cache
BUILDX_CACHE_TO: type=local,dest=/tmp/.buildx-cache-new,mode=max
- name: Move cache
run: |
rm -rf /tmp/.buildx-cache
if [ -d "/tmp/.buildx-cache-new" ]; then
mv /tmp/.buildx-cache-new /tmp/.buildx-cache
fi
- name: Wait for services to be ready
run: |
echo "Waiting for rest_server to be ready..."
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
echo "Waiting for database to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
- name: Create E2E test data
run: |
echo "Creating E2E test data..."
# First try to run the script from inside the container
if docker compose -f ../docker-compose.yml exec -T rest_server test -f /app/autogpt_platform/backend/test/e2e_test_data.py; then
echo "✅ Found e2e_test_data.py in container, running it..."
docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python backend/test/e2e_test_data.py" || {
echo "❌ E2E test data creation failed!"
docker compose -f ../docker-compose.yml logs --tail=50 rest_server
exit 1
}
else
echo "⚠️ e2e_test_data.py not found in container, copying and running..."
# Copy the script into the container and run it
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.yml ps -q rest_server):/tmp/e2e_test_data.py || {
echo "❌ Failed to copy script to container"
exit 1
}
docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
echo "❌ E2E test data creation failed!"
docker compose -f ../docker-compose.yml logs --tail=50 rest_server
exit 1
}
fi
- name: Restore dependencies cache
uses: actions/cache@v4
@@ -211,25 +196,33 @@ jobs:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Install Browser 'chromium'
run: pnpm playwright install --with-deps chromium
- name: Setup .env
run: cp .env.example .env
- name: Build frontend
run: pnpm build --turbo
# uses Turbopack, much faster and safe enough for a test pipeline
- name: Install Browser '${{ matrix.browser }}'
run: pnpm playwright install --with-deps ${{ matrix.browser }}
- name: Run Playwright tests
run: pnpm test:no-build
- name: Upload Playwright artifacts
if: failure()
uses: actions/upload-artifact@v4
with:
name: playwright-report
path: playwright-report
run: pnpm test:no-build --project=${{ matrix.browser }}
env:
BROWSER_TYPE: ${{ matrix.browser }}
- name: Print Final Docker Compose logs
if: always()
run: docker compose -f ../docker-compose.yml logs
- uses: actions/upload-artifact@v4
if: ${{ !cancelled() }}
with:
name: playwright-report-${{ matrix.browser }}
path: playwright-report/
retention-days: 30

View File

@@ -1,132 +0,0 @@
name: AutoGPT Platform - Frontend CI
on:
push:
branches: [master, dev]
paths:
- ".github/workflows/platform-fullstack-ci.yml"
- "autogpt_platform/**"
pull_request:
paths:
- ".github/workflows/platform-fullstack-ci.yml"
- "autogpt_platform/**"
merge_group:
defaults:
run:
shell: bash
working-directory: autogpt_platform/frontend
jobs:
setup:
runs-on: ubuntu-latest
outputs:
cache-key: ${{ steps.cache-key.outputs.key }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "21"
- name: Enable corepack
run: corepack enable
- name: Generate cache key
id: cache-key
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
- name: Cache dependencies
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ steps.cache-key.outputs.key }}
restore-keys: |
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
types:
runs-on: ubuntu-latest
needs: setup
strategy:
fail-fast: false
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
submodules: recursive
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "21"
- name: Enable corepack
run: corepack enable
- name: Copy default supabase .env
run: |
cp ../.env.default ../.env
- name: Copy backend .env
run: |
cp ../backend/.env.default ../backend/.env
- name: Run docker compose
run: |
docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
- name: Restore dependencies cache
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
restore-keys: |
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Setup .env
run: cp .env.default .env
- name: Wait for services to be ready
run: |
echo "Waiting for rest_server to be ready..."
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
echo "Waiting for database to be ready..."
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
- name: Generate API queries
run: pnpm generate:api:force
- name: Check for API schema changes
run: |
if ! git diff --exit-code src/app/api/openapi.json; then
echo "❌ API schema changes detected in src/app/api/openapi.json"
echo ""
echo "The openapi.json file has been modified after running 'pnpm generate:api-all'."
echo "This usually means changes have been made in the BE endpoints without updating the Frontend."
echo "The API schema is now out of sync with the Front-end queries."
echo ""
echo "To fix this:"
echo "1. Pull the backend 'docker compose pull && docker compose up -d --build --force-recreate'"
echo "2. Run 'pnpm generate:api' locally"
echo "3. Run 'pnpm types' locally"
echo "4. Fix any TypeScript errors that may have been introduced"
echo "5. Commit and push your changes"
echo ""
exit 1
else
echo "✅ No API schema changes detected"
fi
- name: Run Typescript checks
run: pnpm types

3
.gitignore vendored
View File

@@ -5,8 +5,6 @@ classic/original_autogpt/*.json
auto_gpt_workspace/*
*.mpeg
.env
# Root .env files
/.env
azure.yaml
.vscode
.idea/*
@@ -123,6 +121,7 @@ celerybeat.pid
# Environments
.direnv/
.env
.venv
env/
venv*/

View File

@@ -235,7 +235,7 @@ repos:
hooks:
- id: tsc
name: Typecheck - AutoGPT Platform - Frontend
entry: bash -c 'cd autogpt_platform/frontend && pnpm types'
entry: bash -c 'cd autogpt_platform/frontend && pnpm type-check'
files: ^autogpt_platform/frontend/
types: [file]
language: system

6
.vscode/launch.json vendored
View File

@@ -6,7 +6,7 @@
"type": "node-terminal",
"request": "launch",
"cwd": "${workspaceFolder}/autogpt_platform/frontend",
"command": "pnpm dev"
"command": "yarn dev"
},
{
"name": "Frontend: Client Side",
@@ -19,12 +19,12 @@
"type": "node-terminal",
"request": "launch",
"command": "pnpm dev",
"command": "yarn dev",
"cwd": "${workspaceFolder}/autogpt_platform/frontend",
"serverReadyAction": {
"pattern": "- Local:.+(https?://.+)",
"uriFormat": "%s",
"action": "debugWithChrome"
"action": "debugWithEdge"
}
},
{

195
LICENSE
View File

@@ -1,197 +1,6 @@
All portions of this repository are under one of two licenses.
All portions of this repository are under one of two licenses. The majority of the AutoGPT repository is under the MIT License below. The autogpt_platform folder is under the
Polyform Shield License.
- Everything inside the autogpt_platform folder is under the Polyform Shield License.
- Everything outside the autogpt_platform folder is under the MIT License.
More info:
**Polyform Shield License:**
Code and content within the `autogpt_platform` folder is licensed under the Polyform Shield License. This new project is our in-development platform for building, deploying and managing agents.
Read more about this effort here: https://agpt.co/blog/introducing-the-autogpt-platform
**MIT License:**
All other portions of the AutoGPT repository (i.e., everything outside the `autogpt_platform` folder) are licensed under the MIT License. This includes:
- The Original, stand-alone AutoGPT Agent
- Forge: https://github.com/Significant-Gravitas/AutoGPT/tree/master/classic/forge
- AG Benchmark: https://github.com/Significant-Gravitas/AutoGPT/tree/master/classic/benchmark
- AutoGPT Classic GUI: https://github.com/Significant-Gravitas/AutoGPT/tree/master/classic/frontend.
We also publish additional work under the MIT Licence in other repositories, such as GravitasML (https://github.com/Significant-Gravitas/gravitasml) which is developed for and used in the AutoGPT Platform, and our [Code Ability](https://github.com/Significant-Gravitas/AutoGPT-Code-Ability) project.
Both licences are available to read below:
=====================================================
-----------------------------------------------------
=====================================================
# PolyForm Shield License 1.0.0
<https://polyformproject.org/licenses/shield/1.0.0>
## Acceptance
In order to get any license under these terms, you must agree
to them as both strict obligations and conditions to all
your licenses.
## Copyright License
The licensor grants you a copyright license for the
software to do everything you might do with the software
that would otherwise infringe the licensor's copyright
in it for any permitted purpose. However, you may
only distribute the software according to [Distribution
License](#distribution-license) and make changes or new works
based on the software according to [Changes and New Works
License](#changes-and-new-works-license).
## Distribution License
The licensor grants you an additional copyright license
to distribute copies of the software. Your license
to distribute covers distributing the software with
changes and new works permitted by [Changes and New Works
License](#changes-and-new-works-license).
## Notices
You must ensure that anyone who gets a copy of any part of
the software from you also gets a copy of these terms or the
URL for them above, as well as copies of any plain-text lines
beginning with `Required Notice:` that the licensor provided
with the software. For example:
> Required Notice: Copyright Yoyodyne, Inc. (http://example.com)
## Changes and New Works License
The licensor grants you an additional copyright license to
make changes and new works based on the software for any
permitted purpose.
## Patent License
The licensor grants you a patent license for the software that
covers patent claims the licensor can license, or becomes able
to license, that you would infringe by using the software.
## Noncompete
Any purpose is a permitted purpose, except for providing any
product that competes with the software or any product the
licensor or any of its affiliates provides using the software.
## Competition
Goods and services compete even when they provide functionality
through different kinds of interfaces or for different technical
platforms. Applications can compete with services, libraries
with plugins, frameworks with development tools, and so on,
even if they're written in different programming languages
or for different computer architectures. Goods and services
compete even when provided free of charge. If you market a
product as a practical substitute for the software or another
product, it definitely competes.
## New Products
If you are using the software to provide a product that does
not compete, but the licensor or any of its affiliates brings
your product into competition by providing a new version of
the software or another product using the software, you may
continue using versions of the software available under these
terms beforehand to provide your competing product, but not
any later versions.
## Discontinued Products
You may begin using the software to compete with a product
or service that the licensor or any of its affiliates has
stopped providing, unless the licensor includes a plain-text
line beginning with `Licensor Line of Business:` with the
software that mentions that line of business. For example:
> Licensor Line of Business: YoyodyneCMS Content Management
System (http://example.com/cms)
## Sales of Business
If the licensor or any of its affiliates sells a line of
business developing the software or using the software
to provide a product, the buyer can also enforce
[Noncompete](#noncompete) for that product.
## Fair Use
You may have "fair use" rights for the software under the
law. These terms do not limit them.
## No Other Rights
These terms do not allow you to sublicense or transfer any of
your licenses to anyone else, or prevent the licensor from
granting licenses to anyone else. These terms do not imply
any other licenses.
## Patent Defense
If you make any written claim that the software infringes or
contributes to infringement of any patent, your patent license
for the software granted under these terms ends immediately. If
your company makes such a claim, your patent license ends
immediately for work on behalf of your company.
## Violations
The first time you are notified in writing that you have
violated any of these terms, or done anything with the software
not covered by your licenses, your licenses can nonetheless
continue if you come into full compliance with these terms,
and take practical steps to correct past violations, within
32 days of receiving notice. Otherwise, all your licenses
end immediately.
## No Liability
***As far as the law allows, the software comes as is, without
any warranty or condition, and the licensor will not be liable
to you for any damages arising out of these terms or the use
or nature of the software, under any kind of legal claim.***
## Definitions
The **licensor** is the individual or entity offering these
terms, and the **software** is the software the licensor makes
available under these terms.
A **product** can be a good or service, or a combination
of them.
**You** refers to the individual or entity agreeing to these
terms.
**Your company** is any legal entity, sole proprietorship,
or other kind of organization that you work for, plus all
its affiliates.
**Affiliates** means the other organizations than an
organization has control over, is under the control of, or is
under common control with.
**Control** means ownership of substantially all the assets of
an entity, or the power to direct its management and policies
by vote, contract, or otherwise. Control can be direct or
indirect.
**Your licenses** are all the licenses granted to you for the
software under these terms.
**Use** means anything you do with the software requiring one
of your licenses.
=====================================================
-----------------------------------------------------
=====================================================
MIT License

View File

@@ -1,25 +1,16 @@
# AutoGPT: Build, Deploy, and Run AI Agents
[![Discord Follow](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fautogpt%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&label=total%20members&logo=discord&logoColor=white&color=7289da)](https://discord.gg/autogpt) &ensp;
[![Discord Follow](https://dcbadge.vercel.app/api/server/autogpt?style=flat)](https://discord.gg/autogpt) &ensp;
[![Twitter Follow](https://img.shields.io/twitter/follow/Auto_GPT?style=social)](https://twitter.com/Auto_GPT) &ensp;
<!-- Keep these links. Translations will automatically update with the README. -->
[Deutsch](https://zdoc.app/de/Significant-Gravitas/AutoGPT) |
[Español](https://zdoc.app/es/Significant-Gravitas/AutoGPT) |
[français](https://zdoc.app/fr/Significant-Gravitas/AutoGPT) |
[日本語](https://zdoc.app/ja/Significant-Gravitas/AutoGPT) |
[한국어](https://zdoc.app/ko/Significant-Gravitas/AutoGPT) |
[Português](https://zdoc.app/pt/Significant-Gravitas/AutoGPT) |
[Русский](https://zdoc.app/ru/Significant-Gravitas/AutoGPT) |
[中文](https://zdoc.app/zh/Significant-Gravitas/AutoGPT)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
**AutoGPT** is a powerful platform that allows you to create, deploy, and manage continuous AI agents that automate complex workflows.
## Hosting Options
- Download to self-host (Free!)
- [Join the Waitlist](https://bit.ly/3ZDijAI) for the cloud-hosted beta (Closed Beta - Public release Coming Soon!)
- Download to self-host
- [Join the Waitlist](https://bit.ly/3ZDijAI) for the cloud-hosted beta
## How to Self-Host the AutoGPT Platform
## How to Setup for Self-Hosting
> [!NOTE]
> Setting up and hosting the AutoGPT Platform yourself is a technical process.
> If you'd rather something that just works, we recommend [joining the waitlist](https://bit.ly/3ZDijAI) for the cloud-hosted beta.
@@ -59,24 +50,6 @@ We've moved to a fully maintained and regularly updated documentation site.
This tutorial assumes you have Docker, VSCode, git and npm installed.
---
#### ⚡ Quick Setup with One-Line Script (Recommended for Local Hosting)
Skip the manual steps and get started in minutes using our automatic setup script.
For macOS/Linux:
```
curl -fsSL https://setup.agpt.co/install.sh -o install.sh && bash install.sh
```
For Windows (PowerShell):
```
powershell -c "iwr https://setup.agpt.co/install.bat -o install.bat; ./install.bat"
```
This will install dependencies, configure Docker, and launch your local instance — all in one go.
### 🧱 AutoGPT Frontend
The AutoGPT frontend is where users interact with our powerful AI automation platform. It offers multiple ways to engage with and leverage our AI agents. This is the interface where you'll bring your AI automation ideas to life:
@@ -123,17 +96,7 @@ Here are two examples of what you can do with AutoGPT:
These examples show just a glimpse of what you can achieve with AutoGPT! You can create customized workflows to build agents for any use case.
---
### **License Overview:**
🛡️ **Polyform Shield License:**
All code and content within the `autogpt_platform` folder is licensed under the Polyform Shield License. This new project is our in-development platform for building, deploying and managing agents.</br>_[Read more about this effort](https://agpt.co/blog/introducing-the-autogpt-platform)_
🦉 **MIT License:**
All other portions of the AutoGPT repository (i.e., everything outside the `autogpt_platform` folder) are licensed under the MIT License. This includes the original stand-alone AutoGPT Agent, along with projects such as [Forge](https://github.com/Significant-Gravitas/AutoGPT/tree/master/classic/forge), [agbenchmark](https://github.com/Significant-Gravitas/AutoGPT/tree/master/classic/benchmark) and the [AutoGPT Classic GUI](https://github.com/Significant-Gravitas/AutoGPT/tree/master/classic/frontend).</br>We also publish additional work under the MIT Licence in other repositories, such as [GravitasML](https://github.com/Significant-Gravitas/gravitasml) which is developed for and used in the AutoGPT Platform. See also our MIT Licenced [Code Ability](https://github.com/Significant-Gravitas/AutoGPT-Code-Ability) project.
---
### Mission
### Mission and Licensing
Our mission is to provide the tools, so that you can focus on what matters:
- 🏗️ **Building** - Lay the foundation for something amazing.
@@ -146,6 +109,14 @@ Be part of the revolution! **AutoGPT** is here to stay, at the forefront of AI i
&ensp;|&ensp;
**🚀 [Contributing](CONTRIBUTING.md)**
**Licensing:**
MIT License: The majority of the AutoGPT repository is under the MIT License.
Polyform Shield License: This license applies to the autogpt_platform folder.
For more information, see https://agpt.co/blog/introducing-the-autogpt-platform
---
## 🤖 AutoGPT Classic
> Below is information about the classic version of AutoGPT.

View File

@@ -1,11 +1,9 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Repository Overview
AutoGPT Platform is a monorepo containing:
- **Backend** (`/backend`): Python FastAPI server with async support
- **Frontend** (`/frontend`): Next.js React application
- **Shared Libraries** (`/autogpt_libs`): Common Python utilities
@@ -13,7 +11,6 @@ AutoGPT Platform is a monorepo containing:
## Essential Commands
### Backend Development
```bash
# Install dependencies
cd backend && poetry install
@@ -33,18 +30,11 @@ poetry run test
# Run specific test
poetry run pytest path/to/test_file.py::test_function_name
# Run block tests (tests that validate all blocks work correctly)
poetry run pytest backend/blocks/test/test_block.py -xvs
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
# Lint and format
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
poetry run format # Black + isort
poetry run lint # ruff
```
More details can be found in TESTING.md
#### Creating/Updating Snapshots
@@ -57,8 +47,8 @@ poetry run pytest path/to/test.py --snapshot-update
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
### Frontend Development
### Frontend Development
```bash
# Install dependencies
cd frontend && npm install
@@ -76,13 +66,12 @@ npm run storybook
npm run build
# Type checking
npm run types
npm run type-check
```
## Architecture Overview
### Backend Architecture
- **API Layer**: FastAPI with REST and WebSocket endpoints
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
- **Queue System**: RabbitMQ for async task processing
@@ -91,7 +80,6 @@ npm run types
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
### Frontend Architecture
- **Framework**: Next.js App Router with React Server Components
- **State Management**: React hooks + Supabase client for real-time updates
- **Workflow Builder**: Visual graph editor using @xyflow/react
@@ -99,7 +87,6 @@ npm run types
- **Feature Flags**: LaunchDarkly integration
### Key Concepts
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
2. **Blocks**: Reusable components in `/backend/blocks/` that perform specific tasks
3. **Integrations**: OAuth and API connections stored per user
@@ -107,16 +94,13 @@ npm run types
5. **Virus Scanning**: ClamAV integration for file upload security
### Testing Approach
- Backend uses pytest with snapshot testing for API responses
- Test files are colocated with source files (`*_test.py`)
- Frontend uses Playwright for E2E tests
- Component testing via Storybook
### Database Schema
Key models (defined in `/backend/schema.prisma`):
- `User`: Authentication and profile data
- `AgentGraph`: Workflow definitions with version control
- `AgentGraphExecution`: Execution history and results
@@ -124,31 +108,13 @@ Key models (defined in `/backend/schema.prisma`):
- `StoreListing`: Marketplace listings for sharing agents
### Environment Configuration
#### Configuration Files
- **Backend**: `/backend/.env.default` (defaults) → `/backend/.env` (user overrides)
- **Frontend**: `/frontend/.env.default` (defaults) → `/frontend/.env` (user overrides)
- **Platform**: `/.env.default` (Supabase/shared defaults) → `/.env` (user overrides)
#### Docker Environment Loading Order
1. `.env.default` files provide base configuration (tracked in git)
2. `.env` files provide user-specific overrides (gitignored)
3. Docker Compose `environment:` sections provide service-specific overrides
4. Shell environment variables have the highest precedence (pictured in the sketch below)
#### Key Points
- All services use hardcoded defaults in docker-compose files (no `${VARIABLE}` substitutions)
- The `env_file` directive loads variables INTO containers at runtime
- Backend/Frontend services use YAML anchors for consistent configuration
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
- Backend: `.env` file in `/backend`
- Frontend: `.env.local` file in `/frontend`
- Both require Supabase credentials and API keys for various services
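To make the layering concrete, here is a purely illustrative Python sketch of the same precedence using `python-dotenv`'s `dotenv_values` (the `DATABASE_URL` key is a made-up example, and this is not how Docker Compose itself merges values; Compose `environment:` entries would sit between the `.env` file and the shell):
```python
import os
from dotenv import dotenv_values  # python-dotenv

# Later sources win: tracked defaults < gitignored user overrides < shell env.
config = {
    **dotenv_values(".env.default"),  # base configuration, tracked in git
    **dotenv_values(".env"),          # user-specific overrides, gitignored
    **os.environ,                     # shell environment has highest precedence
}

database_url = config.get("DATABASE_URL")  # illustrative key only
```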
### Common Development Tasks
**Adding a new block:**
1. Create new file in `/backend/backend/blocks/`
2. Inherit from `Block` base class
3. Define input/output schemas
@@ -156,18 +122,13 @@ Key models (defined in `/backend/schema.prisma`):
5. Register in block registry
6. Generate the block uuid using `uuid.uuid4()`
Note: when making many new blocks, analyze the interfaces of each block and consider whether they would work well together in a graph-based editor or would struggle to connect productively.
e.g.: do the inputs and outputs tie well together? (A minimal sketch follows below.)
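A minimal sketch of what such a block could look like; the `Block`/`BlockSchema` details, the class name, and the `run` signature are assumptions based on the steps above, not the platform's exact API:
```python
import uuid
from pydantic import BaseModel

# Hypothetical stand-in for the real base schema in /backend/backend/blocks/.
class BlockSchema(BaseModel):
    pass

class GetGreetingBlock:
    """Illustrative block: takes a name, outputs a greeting."""

    class Input(BlockSchema):
        name: str

    class Output(BlockSchema):
        greeting: str

    def __init__(self):
        self.id = str(uuid.uuid4())  # step 6: generate the block uuid

    def run(self, input_data: Input) -> Output:
        return self.Output(greeting=f"Hello, {input_data.name}!")
```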
**Modifying the API:**
1. Update route in `/backend/backend/server/routers/`
2. Add/update Pydantic models in the same directory (see the route sketch below)
3. Write tests alongside the route file
4. Run `poetry run test` to verify
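A rough, hypothetical illustration of steps 1–2; the route path, model, and handler are invented, and the real routers live in `/backend/backend/server/routers/`:
```python
from fastapi import APIRouter
from pydantic import BaseModel

router = APIRouter()

class GraphSummary(BaseModel):
    id: str
    name: str

@router.get("/graphs/{graph_id}", response_model=GraphSummary)
async def get_graph(graph_id: str) -> GraphSummary:
    # A real handler would query the database; this just echoes the id.
    return GraphSummary(id=graph_id, name="example")
```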
**Frontend feature development:**
1. Components go in `/frontend/src/components/`
2. Use existing UI components from `/frontend/src/components/ui/`
3. Add Storybook stories for new components
@@ -176,7 +137,6 @@ ex: do the inputs and outputs tie well together?
### Security Implementation
**Cache Protection Middleware:**
- Located in `/backend/backend/server/middleware/security.py`
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses an allow list approach - only explicitly permitted paths can be cached
@@ -184,47 +144,3 @@ ex: do the inputs and outputs tie well together?
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
- Applied to both the main API server and external API applications (a minimal middleware sketch follows below)
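A minimal sketch of the allow-list idea described above; the class name, paths, and wiring are illustrative, not the actual middleware implementation:
```python
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request

CACHEABLE_PATHS = {"/static", "/docs"}  # illustrative allow list

class CacheProtectionMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        # Default-deny: only explicitly allowed paths may be cached.
        if not any(request.url.path.startswith(p) for p in CACHEABLE_PATHS):
            response.headers["Cache-Control"] = (
                "no-store, no-cache, must-revalidate, private"
            )
        return response
```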
### Creating Pull Requests
- Create the PR against the `dev` branch of the repository.
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`).
- Use conventional commit messages (see below).
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description.
- Run the GitHub pre-commit hooks to ensure code quality.
### Reviewing/Revising Pull Requests
- When the user runs /pr-comments or tries to fetch them, also run gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews to get the reviews
- Use gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews/[review_id]/comments to get the review contents
- Use gh api /repos/Significant-Gravitas/AutoGPT/issues/[issuenum]/comments to get the PR-specific comments
### Conventional Commits
Use this format for commit messages and Pull Request titles:
**Conventional Commit Types:**
- `feat`: Introduces a new feature to the codebase
- `fix`: Patches a bug in the codebase
- `refactor`: Code change that neither fixes a bug nor adds a feature; also applies to removing features
- `ci`: Changes to CI configuration
- `docs`: Documentation-only changes
- `dx`: Improvements to the developer experience
**Recommended Base Scopes:**
- `platform`: Changes affecting both frontend and backend
- `frontend`
- `backend`
- `infra`
- `blocks`: Modifications/additions of individual blocks
**Subscope Examples:**
- `backend/executor`
- `backend/db`
- `frontend/builder` (includes changes to the block UI component)
- `infra/prod`
Use these scopes and subscopes for clarity and consistency in commit messages; an example title is shown below.
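For example, a commit touching the execution engine might be titled (illustrative only):
```
feat(backend/executor): add retry logic for failed node executions
```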

View File

@@ -8,6 +8,7 @@ Welcome to the AutoGPT Platform - a powerful system for creating and running AI
- Docker
- Docker Compose V2 (comes with Docker Desktop, or can be installed separately)
- Node.js & NPM (for running the frontend application)
### Running the System
@@ -23,10 +24,10 @@ To run the AutoGPT Platform, follow these steps:
2. Run the following command:
```
cp .env.default .env
cp .env.example .env
```
This command will copy the `.env.default` file to `.env`. You can modify the `.env` file to add your own environment variables.
This command will copy the `.env.example` file to `.env`. You can modify the `.env` file to add your own environment variables.
3. Run the following command:
@@ -36,7 +37,44 @@ To run the AutoGPT Platform, follow these steps:
This command will start all the necessary backend services defined in the `docker-compose.yml` file in detached mode.
4. After all the services are in ready state, open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
4. Navigate to `frontend` within the `autogpt_platform` directory:
```
cd frontend
```
You will need to run your frontend application separately on your local machine.
5. Run the following command:
```
cp .env.example .env.local
```
This command will copy the `.env.example` file to `.env.local` in the `frontend` directory. You can modify the `.env.local` within this folder to add your own environment variables for the frontend application.
6. Run the following command:
Enable corepack and install dependencies by running:
```
corepack enable
pnpm i
```
Generate the API client (this step is required before running the frontend):
```
pnpm generate:api-client
```
Then start the frontend application in development mode:
```
pnpm dev
```
7. Open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
### Docker Compose Commands
@@ -139,21 +177,20 @@ The platform includes scripts for generating and managing the API client:
- `pnpm fetch:openapi`: Fetches the OpenAPI specification from the backend service (requires backend to be running on port 8006)
- `pnpm generate:api-client`: Generates the TypeScript API client from the OpenAPI specification using Orval
- `pnpm generate:api`: Runs both fetch and generate commands in sequence
- `pnpm generate:api-all`: Runs both fetch and generate commands in sequence
#### Manual API Client Updates
If you need to update the API client after making changes to the backend API:
1. Ensure the backend services are running:
```
docker compose up -d
```
2. Generate the updated API client:
```
pnpm generate:api
pnpm generate:api-all
```
This will fetch the latest OpenAPI specification and regenerate the TypeScript client code.

File diff suppressed because it is too large.

View File

@@ -7,5 +7,9 @@ class Settings:
self.ENABLE_AUTH: bool = os.getenv("ENABLE_AUTH", "false").lower() == "true"
self.JWT_ALGORITHM: str = "HS256"
@property
def is_configured(self) -> bool:
return bool(self.JWT_SECRET_KEY)
settings = Settings()

View File

@@ -0,0 +1,166 @@
import asyncio
import contextlib
import logging
from functools import wraps
from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar, Union, cast
import ldclient
from fastapi import HTTPException
from ldclient import Context, LDClient
from ldclient.config import Config
from typing_extensions import ParamSpec
from .config import SETTINGS
logger = logging.getLogger(__name__)
P = ParamSpec("P")
T = TypeVar("T")
def get_client() -> LDClient:
"""Get the LaunchDarkly client singleton."""
return ldclient.get()
def initialize_launchdarkly() -> None:
sdk_key = SETTINGS.launch_darkly_sdk_key
logger.debug(
f"Initializing LaunchDarkly with SDK key: {'present' if sdk_key else 'missing'}"
)
if not sdk_key:
logger.warning("LaunchDarkly SDK key not configured")
return
config = Config(sdk_key)
ldclient.set_config(config)
if ldclient.get().is_initialized():
logger.info("LaunchDarkly client initialized successfully")
else:
logger.error("LaunchDarkly client failed to initialize")
def shutdown_launchdarkly() -> None:
"""Shutdown the LaunchDarkly client."""
if ldclient.get().is_initialized():
ldclient.get().close()
logger.info("LaunchDarkly client closed successfully")
def create_context(
user_id: str, additional_attributes: Optional[Dict[str, Any]] = None
) -> Context:
"""Create LaunchDarkly context with optional additional attributes."""
builder = Context.builder(str(user_id)).kind("user")
if additional_attributes:
for key, value in additional_attributes.items():
builder.set(key, value)
return builder.build()
def feature_flag(
flag_key: str,
default: bool = False,
) -> Callable[
[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
]:
"""
Decorator for feature flag protected endpoints.
"""
def decorator(
func: Callable[P, Union[T, Awaitable[T]]],
) -> Callable[P, Union[T, Awaitable[T]]]:
@wraps(func)
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
try:
user_id = kwargs.get("user_id")
if not user_id:
raise ValueError("user_id is required")
if not get_client().is_initialized():
logger.warning(
f"LaunchDarkly not initialized, using default={default}"
)
is_enabled = default
else:
context = create_context(str(user_id))
is_enabled = get_client().variation(flag_key, context, default)
if not is_enabled:
raise HTTPException(status_code=404, detail="Feature not available")
result = func(*args, **kwargs)
if asyncio.iscoroutine(result):
return await result
return cast(T, result)
except Exception as e:
logger.error(f"Error evaluating feature flag {flag_key}: {e}")
raise
@wraps(func)
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
try:
user_id = kwargs.get("user_id")
if not user_id:
raise ValueError("user_id is required")
if not get_client().is_initialized():
logger.warning(
f"LaunchDarkly not initialized, using default={default}"
)
is_enabled = default
else:
context = create_context(str(user_id))
is_enabled = get_client().variation(flag_key, context, default)
if not is_enabled:
raise HTTPException(status_code=404, detail="Feature not available")
return cast(T, func(*args, **kwargs))
except Exception as e:
logger.error(f"Error evaluating feature flag {flag_key}: {e}")
raise
return cast(
Callable[P, Union[T, Awaitable[T]]],
async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper,
)
return decorator
def percentage_rollout(
flag_key: str,
default: bool = False,
) -> Callable[
[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
]:
"""Decorator for percentage-based rollouts."""
return feature_flag(flag_key, default)
def beta_feature(
flag_key: Optional[str] = None,
unauthorized_response: Any = {"message": "Not available in beta"},
) -> Callable[
[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
]:
"""Decorator for beta features."""
actual_key = f"beta-{flag_key}" if flag_key else "beta"
return feature_flag(actual_key, False)
@contextlib.contextmanager
def mock_flag_variation(flag_key: str, return_value: Any):
"""Context manager for testing feature flags."""
original_variation = get_client().variation
get_client().variation = lambda key, context, default: (
return_value if key == flag_key else original_variation(key, context, default)
)
try:
yield
finally:
get_client().variation = original_variation
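A sketch of how `feature_flag` and `mock_flag_variation` might be used together; the endpoint path, flag key, and response body are invented for illustration:
```python
from fastapi import FastAPI

from autogpt_libs.feature_flag.client import feature_flag, mock_flag_variation

app = FastAPI()

@app.get("/beta/agent-templates")
@feature_flag("agent-templates", default=False)
async def list_agent_templates(user_id: str):
    # Reached only when the flag evaluates to True for this user; otherwise
    # the decorator raises HTTPException(404, "Feature not available").
    return {"templates": []}

# In tests, a flag value can be pinned without a LaunchDarkly connection:
# with mock_flag_variation("agent-templates", True):
#     ...call the endpoint...
```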

View File

@@ -0,0 +1,45 @@
import pytest
from ldclient import LDClient
from autogpt_libs.feature_flag.client import feature_flag, mock_flag_variation
@pytest.fixture
def ld_client(mocker):
client = mocker.Mock(spec=LDClient)
mocker.patch("ldclient.get", return_value=client)
client.is_initialized.return_value = True
return client
@pytest.mark.asyncio
async def test_feature_flag_enabled(ld_client):
ld_client.variation.return_value = True
@feature_flag("test-flag")
async def test_function(user_id: str):
return "success"
result = await test_function(user_id="test-user")
assert result == "success"
ld_client.variation.assert_called_once()
@pytest.mark.asyncio
async def test_feature_flag_unauthorized_response(ld_client):
ld_client.variation.return_value = False
@feature_flag("test-flag")
async def test_function(user_id: str):
return "success"
result = test_function(user_id="test-user")
assert result == {"error": "disabled"}
def test_mock_flag_variation(ld_client):
with mock_flag_variation("test-flag", True):
assert ld_client.variation("test-flag", None, False)
with mock_flag_variation("test-flag", False):
assert ld_client.variation("test-flag", None, False)

View File

@@ -0,0 +1,15 @@
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
launch_darkly_sdk_key: str = Field(
default="",
description="The Launch Darkly SDK key",
validation_alias="LAUNCH_DARKLY_SDK_KEY",
)
model_config = SettingsConfigDict(case_sensitive=True, extra="ignore")
SETTINGS = Settings()

View File

@@ -1,8 +1,6 @@
"""Logging module for Auto-GPT."""
import logging
import os
import socket
import sys
from pathlib import Path
@@ -12,15 +10,6 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
from .filters import BelowLevelFilter
from .formatters import AGPTFormatter
# Configure global socket timeout and gRPC keepalive to prevent deadlocks
# This must be done at import time before any gRPC connections are established
socket.setdefaulttimeout(30) # 30-second socket timeout
# Enable gRPC keepalive to detect dead connections faster
os.environ.setdefault("GRPC_KEEPALIVE_TIME_MS", "30000") # 30 seconds
os.environ.setdefault("GRPC_KEEPALIVE_TIMEOUT_MS", "5000") # 5 seconds
os.environ.setdefault("GRPC_KEEPALIVE_PERMIT_WITHOUT_CALLS", "true")
LOG_DIR = Path(__file__).parent.parent.parent.parent / "logs"
LOG_FILE = "activity.log"
DEBUG_LOG_FILE = "debug.log"
@@ -90,6 +79,7 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
Note: This function is typically called at the start of the application
to set up the logging infrastructure.
"""
config = LoggingConfig()
log_handlers: list[logging.Handler] = []
@@ -115,17 +105,13 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
if config.enable_cloud_logging or force_cloud_logging:
import google.cloud.logging
from google.cloud.logging.handlers import CloudLoggingHandler
from google.cloud.logging_v2.handlers.transports import (
BackgroundThreadTransport,
)
from google.cloud.logging_v2.handlers.transports.sync import SyncTransport
client = google.cloud.logging.Client()
# Use BackgroundThreadTransport to prevent blocking the main thread
# and deadlocks when gRPC calls to Google Cloud Logging hang
cloud_handler = CloudLoggingHandler(
client,
name="autogpt_logs",
transport=BackgroundThreadTransport,
transport=SyncTransport,
)
cloud_handler.setLevel(config.level)
log_handlers.append(cloud_handler)

View File

@@ -1,5 +1,39 @@
import logging
import re
from typing import Any
import uvicorn.config
from colorama import Fore
def remove_color_codes(s: str) -> str:
return re.sub(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])", "", s)
def fmt_kwargs(kwargs: dict) -> str:
return ", ".join(f"{n}={repr(v)}" for n, v in kwargs.items())
def print_attribute(
title: str, value: Any, title_color: str = Fore.GREEN, value_color: str = ""
) -> None:
logger = logging.getLogger()
logger.info(
str(value),
extra={
"title": f"{title.rstrip(':')}:",
"title_color": title_color,
"color": value_color,
},
)
def generate_uvicorn_config():
"""
Generates a uvicorn logging config that silences uvicorn's default logging and tells it to use the native logging module.
"""
log_config = dict(uvicorn.config.LOGGING_CONFIG)
log_config["loggers"]["uvicorn"] = {"handlers": []}
log_config["loggers"]["uvicorn.error"] = {"handlers": []}
log_config["loggers"]["uvicorn.access"] = {"handlers": []}
return log_config
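A brief sketch of how `generate_uvicorn_config()` could be passed to uvicorn; the app import string and port are placeholders, and the function is assumed to be in scope or imported from this module:
```python
import uvicorn

# Route uvicorn's log records through the application's own handlers
# instead of uvicorn's default console logging.
uvicorn.run(
    "backend.server.app:app",  # placeholder import string
    host="0.0.0.0",
    port=8006,
    log_config=generate_uvicorn_config(),
)
```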

View File

@@ -1,34 +1,17 @@
import inspect
import logging
import threading
import time
from functools import wraps
from typing import (
Awaitable,
Callable,
ParamSpec,
Protocol,
Tuple,
TypeVar,
cast,
overload,
runtime_checkable,
)
from typing import Awaitable, Callable, ParamSpec, TypeVar, cast, overload
P = ParamSpec("P")
R = TypeVar("R")
logger = logging.getLogger(__name__)
@overload
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: ...
@overload
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
pass
@overload
def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
pass
def thread_cached(func: Callable[P, R]) -> Callable[P, R]: ...
def thread_cached(
@@ -74,193 +57,3 @@ def thread_cached(
def clear_thread_cache(func: Callable) -> None:
if clear := getattr(func, "clear_cache", None):
clear()
FuncT = TypeVar("FuncT")
R_co = TypeVar("R_co", covariant=True)
@runtime_checkable
class AsyncCachedFunction(Protocol[P, R_co]):
"""Protocol for async functions with cache management methods."""
def cache_clear(self) -> None:
"""Clear all cached entries."""
return None
def cache_info(self) -> dict[str, int | None]:
"""Get cache statistics."""
return {}
async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co:
"""Call the cached function."""
return None # type: ignore
def async_ttl_cache(
maxsize: int = 128, ttl_seconds: int | None = None
) -> Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]:
"""
TTL (Time To Live) cache decorator for async functions.
Similar to functools.lru_cache but works with async functions and includes optional TTL.
Args:
maxsize: Maximum number of cached entries
ttl_seconds: Time to live in seconds. If None, entries never expire (like lru_cache)
Returns:
Decorator function
Example:
# With TTL
@async_ttl_cache(maxsize=1000, ttl_seconds=300)
async def api_call(param: str) -> dict:
return {"result": param}
# Without TTL (permanent cache like lru_cache)
@async_ttl_cache(maxsize=1000)
async def expensive_computation(param: str) -> dict:
return {"result": param}
"""
def decorator(
async_func: Callable[P, Awaitable[R]],
) -> AsyncCachedFunction[P, R]:
# Cache storage - use union type to handle both cases
cache_storage: dict[tuple, R | Tuple[R, float]] = {}
@wraps(async_func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
# Create cache key from arguments
key = (args, tuple(sorted(kwargs.items())))
current_time = time.time()
# Check if we have a valid cached entry
if key in cache_storage:
if ttl_seconds is None:
# No TTL - return cached result directly
logger.debug(
f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
)
return cast(R, cache_storage[key])
else:
# With TTL - check expiration
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
logger.debug(
f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
)
return cast(R, result)
else:
# Expired entry
del cache_storage[key]
logger.debug(
f"Cache entry expired for {async_func.__name__}"
)
# Cache miss or expired - fetch fresh data
logger.debug(
f"Cache miss for {async_func.__name__} with key: {str(key)[:50]}"
)
result = await async_func(*args, **kwargs)
# Store in cache
if ttl_seconds is None:
cache_storage[key] = result
else:
cache_storage[key] = (result, current_time)
# Simple cleanup when cache gets too large
if len(cache_storage) > maxsize:
# Remove oldest entries (simple FIFO cleanup)
cutoff = maxsize // 2
oldest_keys = list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
for old_key in oldest_keys:
cache_storage.pop(old_key, None)
logger.debug(
f"Cache cleanup: removed {len(oldest_keys)} entries for {async_func.__name__}"
)
return result
# Add cache management methods (similar to functools.lru_cache)
def cache_clear() -> None:
cache_storage.clear()
def cache_info() -> dict[str, int | None]:
return {
"size": len(cache_storage),
"maxsize": maxsize,
"ttl_seconds": ttl_seconds,
}
# Attach methods to wrapper
setattr(wrapper, "cache_clear", cache_clear)
setattr(wrapper, "cache_info", cache_info)
return cast(AsyncCachedFunction[P, R], wrapper)
return decorator
@overload
def async_cache(
func: Callable[P, Awaitable[R]],
) -> AsyncCachedFunction[P, R]:
pass
@overload
def async_cache(
func: None = None,
*,
maxsize: int = 128,
) -> Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]:
pass
def async_cache(
func: Callable[P, Awaitable[R]] | None = None,
*,
maxsize: int = 128,
) -> (
AsyncCachedFunction[P, R]
| Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]
):
"""
Process-level cache decorator for async functions (no TTL).
Similar to functools.lru_cache but works with async functions.
This is a convenience wrapper around async_ttl_cache with ttl_seconds=None.
Args:
func: The async function to cache (when used without parentheses)
maxsize: Maximum number of cached entries
Returns:
Decorated function or decorator
Example:
# Without parentheses (uses default maxsize=128)
@async_cache
async def get_data(param: str) -> dict:
return {"result": param}
# With parentheses and custom maxsize
@async_cache(maxsize=1000)
async def expensive_computation(param: str) -> dict:
# Expensive computation here
return {"result": param}
"""
if func is None:
# Called with parentheses @async_cache() or @async_cache(maxsize=...)
return async_ttl_cache(maxsize=maxsize, ttl_seconds=None)
else:
# Called without parentheses @async_cache
decorator = async_ttl_cache(maxsize=maxsize, ttl_seconds=None)
return decorator(func)

View File

@@ -1,705 +0,0 @@
"""Tests for the @thread_cached decorator.
This module tests the thread-local caching functionality including:
- Basic caching for sync and async functions
- Thread isolation (each thread has its own cache)
- Cache clearing functionality
- Exception handling (exceptions are not cached)
- Argument handling (positional vs keyword arguments)
"""
import asyncio
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import Mock
import pytest
from autogpt_libs.utils.cache import (
async_cache,
async_ttl_cache,
clear_thread_cache,
thread_cached,
)
class TestThreadCached:
def test_sync_function_caching(self):
call_count = 0
@thread_cached
def expensive_function(x: int, y: int = 0) -> int:
nonlocal call_count
call_count += 1
return x + y
assert expensive_function(1, 2) == 3
assert call_count == 1
assert expensive_function(1, 2) == 3
assert call_count == 1
assert expensive_function(1, y=2) == 3
assert call_count == 2
assert expensive_function(2, 3) == 5
assert call_count == 3
assert expensive_function(1) == 1
assert call_count == 4
@pytest.mark.asyncio
async def test_async_function_caching(self):
call_count = 0
@thread_cached
async def expensive_async_function(x: int, y: int = 0) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.01)
return x + y
assert await expensive_async_function(1, 2) == 3
assert call_count == 1
assert await expensive_async_function(1, 2) == 3
assert call_count == 1
assert await expensive_async_function(1, y=2) == 3
assert call_count == 2
assert await expensive_async_function(2, 3) == 5
assert call_count == 3
def test_thread_isolation(self):
call_count = 0
results = {}
@thread_cached
def thread_specific_function(x: int) -> str:
nonlocal call_count
call_count += 1
return f"{threading.current_thread().name}-{x}"
def worker(thread_id: int):
result1 = thread_specific_function(1)
result2 = thread_specific_function(1)
result3 = thread_specific_function(2)
results[thread_id] = (result1, result2, result3)
with ThreadPoolExecutor(max_workers=3) as executor:
futures = [executor.submit(worker, i) for i in range(3)]
for future in futures:
future.result()
assert call_count >= 2
for thread_id, (r1, r2, r3) in results.items():
assert r1 == r2
assert r1 != r3
@pytest.mark.asyncio
async def test_async_thread_isolation(self):
call_count = 0
results = {}
@thread_cached
async def async_thread_specific_function(x: int) -> str:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.01)
return f"{threading.current_thread().name}-{x}"
async def async_worker(worker_id: int):
result1 = await async_thread_specific_function(1)
result2 = await async_thread_specific_function(1)
result3 = await async_thread_specific_function(2)
results[worker_id] = (result1, result2, result3)
tasks = [async_worker(i) for i in range(3)]
await asyncio.gather(*tasks)
for worker_id, (r1, r2, r3) in results.items():
assert r1 == r2
assert r1 != r3
def test_clear_cache_sync(self):
call_count = 0
@thread_cached
def clearable_function(x: int) -> int:
nonlocal call_count
call_count += 1
return x * 2
assert clearable_function(5) == 10
assert call_count == 1
assert clearable_function(5) == 10
assert call_count == 1
clear_thread_cache(clearable_function)
assert clearable_function(5) == 10
assert call_count == 2
@pytest.mark.asyncio
async def test_clear_cache_async(self):
call_count = 0
@thread_cached
async def clearable_async_function(x: int) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.01)
return x * 2
assert await clearable_async_function(5) == 10
assert call_count == 1
assert await clearable_async_function(5) == 10
assert call_count == 1
clear_thread_cache(clearable_async_function)
assert await clearable_async_function(5) == 10
assert call_count == 2
def test_simple_arguments(self):
call_count = 0
@thread_cached
def simple_function(a: str, b: int, c: str = "default") -> str:
nonlocal call_count
call_count += 1
return f"{a}-{b}-{c}"
# First call with all positional args
result1 = simple_function("test", 42, "custom")
assert call_count == 1
# Same args, all positional - should hit cache
result2 = simple_function("test", 42, "custom")
assert call_count == 1
assert result1 == result2
# Same values but last arg as keyword - creates different cache key
result3 = simple_function("test", 42, c="custom")
assert call_count == 2
assert result1 == result3 # Same result, different cache entry
# Different value - new cache entry
result4 = simple_function("test", 43, "custom")
assert call_count == 3
assert result1 != result4
def test_positional_vs_keyword_args(self):
"""Test that positional and keyword arguments create different cache entries."""
call_count = 0
@thread_cached
def func(a: int, b: int = 10) -> str:
nonlocal call_count
call_count += 1
return f"result-{a}-{b}"
# All positional
result1 = func(1, 2)
assert call_count == 1
assert result1 == "result-1-2"
# Same values, but second arg as keyword
result2 = func(1, b=2)
assert call_count == 2 # Different cache key!
assert result2 == "result-1-2" # Same result
# Verify both are cached separately
func(1, 2) # Uses first cache entry
assert call_count == 2
func(1, b=2) # Uses second cache entry
assert call_count == 2
def test_exception_handling(self):
call_count = 0
@thread_cached
def failing_function(x: int) -> int:
nonlocal call_count
call_count += 1
if x < 0:
raise ValueError("Negative value")
return x * 2
assert failing_function(5) == 10
assert call_count == 1
with pytest.raises(ValueError):
failing_function(-1)
assert call_count == 2
with pytest.raises(ValueError):
failing_function(-1)
assert call_count == 3
assert failing_function(5) == 10
assert call_count == 3
@pytest.mark.asyncio
async def test_async_exception_handling(self):
call_count = 0
@thread_cached
async def async_failing_function(x: int) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.01)
if x < 0:
raise ValueError("Negative value")
return x * 2
assert await async_failing_function(5) == 10
assert call_count == 1
with pytest.raises(ValueError):
await async_failing_function(-1)
assert call_count == 2
with pytest.raises(ValueError):
await async_failing_function(-1)
assert call_count == 3
def test_sync_caching_performance(self):
@thread_cached
def slow_function(x: int) -> int:
print(f"slow_function called with x={x}")
time.sleep(0.1)
return x * 2
start = time.time()
result1 = slow_function(5)
first_call_time = time.time() - start
print(f"First call took {first_call_time:.4f} seconds")
start = time.time()
result2 = slow_function(5)
second_call_time = time.time() - start
print(f"Second call took {second_call_time:.4f} seconds")
assert result1 == result2 == 10
assert first_call_time > 0.09
assert second_call_time < 0.01
@pytest.mark.asyncio
async def test_async_caching_performance(self):
@thread_cached
async def slow_async_function(x: int) -> int:
print(f"slow_async_function called with x={x}")
await asyncio.sleep(0.1)
return x * 2
start = time.time()
result1 = await slow_async_function(5)
first_call_time = time.time() - start
print(f"First async call took {first_call_time:.4f} seconds")
start = time.time()
result2 = await slow_async_function(5)
second_call_time = time.time() - start
print(f"Second async call took {second_call_time:.4f} seconds")
assert result1 == result2 == 10
assert first_call_time > 0.09
assert second_call_time < 0.01
def test_with_mock_objects(self):
mock = Mock(return_value=42)
@thread_cached
def function_using_mock(x: int) -> int:
return mock(x)
assert function_using_mock(1) == 42
assert mock.call_count == 1
assert function_using_mock(1) == 42
assert mock.call_count == 1
assert function_using_mock(2) == 42
assert mock.call_count == 2
class TestAsyncTTLCache:
"""Tests for the @async_ttl_cache decorator."""
@pytest.mark.asyncio
async def test_basic_caching(self):
"""Test basic caching functionality."""
call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=60)
async def cached_function(x: int, y: int = 0) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.01) # Simulate async work
return x + y
# First call
result1 = await cached_function(1, 2)
assert result1 == 3
assert call_count == 1
# Second call with same args - should use cache
result2 = await cached_function(1, 2)
assert result2 == 3
assert call_count == 1 # No additional call
# Different args - should call function again
result3 = await cached_function(2, 3)
assert result3 == 5
assert call_count == 2
@pytest.mark.asyncio
async def test_ttl_expiration(self):
"""Test that cache entries expire after TTL."""
call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=1) # Short TTL
async def short_lived_cache(x: int) -> int:
nonlocal call_count
call_count += 1
return x * 2
# First call
result1 = await short_lived_cache(5)
assert result1 == 10
assert call_count == 1
# Second call immediately - should use cache
result2 = await short_lived_cache(5)
assert result2 == 10
assert call_count == 1
# Wait for TTL to expire
await asyncio.sleep(1.1)
# Third call after expiration - should call function again
result3 = await short_lived_cache(5)
assert result3 == 10
assert call_count == 2
@pytest.mark.asyncio
async def test_cache_info(self):
"""Test cache info functionality."""
@async_ttl_cache(maxsize=5, ttl_seconds=300)
async def info_test_function(x: int) -> int:
return x * 3
# Check initial cache info
info = info_test_function.cache_info()
assert info["size"] == 0
assert info["maxsize"] == 5
assert info["ttl_seconds"] == 300
# Add an entry
await info_test_function(1)
info = info_test_function.cache_info()
assert info["size"] == 1
@pytest.mark.asyncio
async def test_cache_clear(self):
"""Test cache clearing functionality."""
call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=60)
async def clearable_function(x: int) -> int:
nonlocal call_count
call_count += 1
return x * 4
# First call
result1 = await clearable_function(2)
assert result1 == 8
assert call_count == 1
# Second call - should use cache
result2 = await clearable_function(2)
assert result2 == 8
assert call_count == 1
# Clear cache
clearable_function.cache_clear()
# Third call after clear - should call function again
result3 = await clearable_function(2)
assert result3 == 8
assert call_count == 2
@pytest.mark.asyncio
async def test_maxsize_cleanup(self):
"""Test that cache cleans up when maxsize is exceeded."""
call_count = 0
@async_ttl_cache(maxsize=3, ttl_seconds=60)
async def size_limited_function(x: int) -> int:
nonlocal call_count
call_count += 1
return x**2
# Fill cache to maxsize
await size_limited_function(1) # call_count: 1
await size_limited_function(2) # call_count: 2
await size_limited_function(3) # call_count: 3
info = size_limited_function.cache_info()
assert info["size"] == 3
# Add one more entry - should trigger cleanup
await size_limited_function(4) # call_count: 4
# Cache size should be reduced (cleanup removes oldest entries)
info = size_limited_function.cache_info()
assert info["size"] is not None and info["size"] <= 3 # Should be cleaned up
@pytest.mark.asyncio
async def test_argument_variations(self):
"""Test caching with different argument patterns."""
call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=60)
async def arg_test_function(a: int, b: str = "default", *, c: int = 100) -> str:
nonlocal call_count
call_count += 1
return f"{a}-{b}-{c}"
# Different ways to call with same logical arguments
result1 = await arg_test_function(1, "test", c=200)
assert call_count == 1
# Same arguments, same order - should use cache
result2 = await arg_test_function(1, "test", c=200)
assert call_count == 1
assert result1 == result2
# Different arguments - should call function
result3 = await arg_test_function(2, "test", c=200)
assert call_count == 2
assert result1 != result3
@pytest.mark.asyncio
async def test_exception_handling(self):
"""Test that exceptions are not cached."""
call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=60)
async def exception_function(x: int) -> int:
nonlocal call_count
call_count += 1
if x < 0:
raise ValueError("Negative value not allowed")
return x * 2
# Successful call - should be cached
result1 = await exception_function(5)
assert result1 == 10
assert call_count == 1
# Same successful call - should use cache
result2 = await exception_function(5)
assert result2 == 10
assert call_count == 1
# Exception call - should not be cached
with pytest.raises(ValueError):
await exception_function(-1)
assert call_count == 2
# Same exception call - should call again (not cached)
with pytest.raises(ValueError):
await exception_function(-1)
assert call_count == 3
@pytest.mark.asyncio
async def test_concurrent_calls(self):
"""Test caching behavior with concurrent calls."""
call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=60)
async def concurrent_function(x: int) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.05) # Simulate work
return x * x
# Launch concurrent calls with same arguments
tasks = [concurrent_function(3) for _ in range(5)]
results = await asyncio.gather(*tasks)
# All results should be the same
assert all(result == 9 for result in results)
# Note: Due to race conditions, call_count might be up to 5 for concurrent calls
# This tests that the cache doesn't break under concurrent access
assert 1 <= call_count <= 5
class TestAsyncCache:
"""Tests for the @async_cache decorator (no TTL)."""
@pytest.mark.asyncio
async def test_basic_caching_no_ttl(self):
"""Test basic caching functionality without TTL."""
call_count = 0
@async_cache(maxsize=10)
async def cached_function(x: int, y: int = 0) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.01) # Simulate async work
return x + y
# First call
result1 = await cached_function(1, 2)
assert result1 == 3
assert call_count == 1
# Second call with same args - should use cache
result2 = await cached_function(1, 2)
assert result2 == 3
assert call_count == 1 # No additional call
# Third call after some time - should still use cache (no TTL)
await asyncio.sleep(0.05)
result3 = await cached_function(1, 2)
assert result3 == 3
assert call_count == 1 # Still no additional call
# Different args - should call function again
result4 = await cached_function(2, 3)
assert result4 == 5
assert call_count == 2
@pytest.mark.asyncio
async def test_no_ttl_vs_ttl_behavior(self):
"""Test the difference between TTL and no-TTL caching."""
ttl_call_count = 0
no_ttl_call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=1) # Short TTL
async def ttl_function(x: int) -> int:
nonlocal ttl_call_count
ttl_call_count += 1
return x * 2
@async_cache(maxsize=10) # No TTL
async def no_ttl_function(x: int) -> int:
nonlocal no_ttl_call_count
no_ttl_call_count += 1
return x * 2
# First calls
await ttl_function(5)
await no_ttl_function(5)
assert ttl_call_count == 1
assert no_ttl_call_count == 1
# Wait for TTL to expire
await asyncio.sleep(1.1)
# Second calls after TTL expiry
await ttl_function(5) # Should call function again (TTL expired)
await no_ttl_function(5) # Should use cache (no TTL)
assert ttl_call_count == 2 # TTL function called again
assert no_ttl_call_count == 1 # No-TTL function still cached
@pytest.mark.asyncio
async def test_async_cache_info(self):
"""Test cache info for no-TTL cache."""
@async_cache(maxsize=5)
async def info_test_function(x: int) -> int:
return x * 3
# Check initial cache info
info = info_test_function.cache_info()
assert info["size"] == 0
assert info["maxsize"] == 5
assert info["ttl_seconds"] is None # No TTL
# Add an entry
await info_test_function(1)
info = info_test_function.cache_info()
assert info["size"] == 1
class TestTTLOptional:
"""Tests for optional TTL functionality."""
@pytest.mark.asyncio
async def test_ttl_none_behavior(self):
"""Test that ttl_seconds=None works like no TTL."""
call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=None)
async def no_ttl_via_none(x: int) -> int:
nonlocal call_count
call_count += 1
return x**2
# First call
result1 = await no_ttl_via_none(3)
assert result1 == 9
assert call_count == 1
# Wait (would expire if there was TTL)
await asyncio.sleep(0.1)
# Second call - should still use cache
result2 = await no_ttl_via_none(3)
assert result2 == 9
assert call_count == 1 # No additional call
# Check cache info
info = no_ttl_via_none.cache_info()
assert info["ttl_seconds"] is None
@pytest.mark.asyncio
async def test_cache_options_comparison(self):
"""Test different cache options work as expected."""
ttl_calls = 0
no_ttl_calls = 0
@async_ttl_cache(maxsize=10, ttl_seconds=1) # With TTL
async def ttl_function(x: int) -> int:
nonlocal ttl_calls
ttl_calls += 1
return x * 10
@async_cache(maxsize=10) # Process-level cache (no TTL)
async def process_function(x: int) -> int:
nonlocal no_ttl_calls
no_ttl_calls += 1
return x * 10
# Both should cache initially
await ttl_function(3)
await process_function(3)
assert ttl_calls == 1
assert no_ttl_calls == 1
# Immediate second calls - both should use cache
await ttl_function(3)
await process_function(3)
assert ttl_calls == 1
assert no_ttl_calls == 1
# Wait for TTL to expire
await asyncio.sleep(1.1)
# After TTL expiry
await ttl_function(3) # Should call function again
await process_function(3) # Should still use cache
assert ttl_calls == 2 # TTL cache expired, called again
assert no_ttl_calls == 1 # Process cache never expires
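# ---------------------------------------------------------------------------
# Illustrative usage sketch (editorial example, not an additional test case).
# It reuses the decorators imported at the top of this file and summarizes the
# behaviours the tests above verify.
# ---------------------------------------------------------------------------


@thread_cached
def expensive_lookup(key: str) -> str:
    # Cached per thread; the cache key is built from the call arguments, so
    # positional vs. keyword calls create distinct entries (see tests above).
    return key.upper()


@async_ttl_cache(maxsize=128, ttl_seconds=300)
async def fetch_profile(user_id: str) -> dict:
    # Entries expire 300 seconds after being stored; exceptions are never cached.
    return {"id": user_id}


@async_cache(maxsize=128)
async def load_static_config(name: str) -> dict:
    # Process-level cache: entries never expire; only maxsize-based cleanup applies.
    return {"name": name}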

File diff suppressed because it is too large

View File

@@ -1,8 +1,8 @@
[tool.poetry]
name = "autogpt-libs"
version = "0.2.0"
description = "Shared libraries across AutoGPT Platform"
authors = ["AutoGPT team <info@agpt.co>"]
description = "Shared libraries across NextGen AutoGPT"
authors = ["Aarushi <aarushik93@gmail.com>"]
readme = "README.md"
packages = [{ include = "autogpt_libs" }]
@@ -10,20 +10,20 @@ packages = [{ include = "autogpt_libs" }]
python = ">=3.10,<4.0"
colorama = "^0.4.6"
expiringdict = "^1.2.2"
fastapi = "^0.116.1"
google-cloud-logging = "^3.12.1"
launchdarkly-server-sdk = "^9.12.0"
pydantic = "^2.11.7"
pydantic-settings = "^2.10.1"
pydantic = "^2.11.4"
pydantic-settings = "^2.9.1"
pyjwt = "^2.10.1"
pytest-asyncio = "^1.1.0"
pytest-mock = "^3.14.1"
redis = "^6.2.0"
supabase = "^2.16.0"
uvicorn = "^0.35.0"
pytest-asyncio = "^0.26.0"
pytest-mock = "^3.14.0"
supabase = "^2.15.1"
launchdarkly-server-sdk = "^9.11.1"
fastapi = "^0.115.12"
uvicorn = "^0.34.3"
[tool.poetry.group.dev.dependencies]
ruff = "^0.12.3"
redis = "^5.2.1"
ruff = "^0.12.2"
[build-system]
requires = ["poetry-core"]

View File

@@ -1,52 +0,0 @@
# Development and testing files
**/__pycache__
**/*.pyc
**/*.pyo
**/*.pyd
**/.Python
**/env/
**/venv/
**/.venv/
**/pip-log.txt
**/.pytest_cache/
**/test-results/
**/snapshots/
**/test/
# IDE and editor files
**/.vscode/
**/.idea/
**/*.swp
**/*.swo
*~
# OS files
.DS_Store
Thumbs.db
# Logs
**/*.log
**/logs/
# Git
.git/
.gitignore
# Documentation
**/*.md
!README.md
# Local development files
.env
.env.local
**/.env.test
# Build artifacts
**/dist/
**/build/
**/target/
# Docker files (avoid recursion)
Dockerfile*
docker-compose*
.dockerignore

View File

@@ -1,9 +1,3 @@
# Backend Configuration
# This file contains environment variables that MUST be set for the AutoGPT platform
# Variables with working defaults in settings.py are not included here
## ===== REQUIRED DATABASE CONFIGURATION ===== ##
# PostgreSQL Database Connection
DB_USER=postgres
DB_PASS=your-super-secret-and-long-postgres-password
DB_NAME=postgres
@@ -16,50 +10,72 @@ DB_SCHEMA=platform
DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
DIRECT_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
PRISMA_SCHEMA="postgres/schema.prisma"
ENABLE_AUTH=true
## ===== REQUIRED SERVICE CREDENTIALS ===== ##
# Redis Configuration
# EXECUTOR
NUM_GRAPH_WORKERS=10
BACKEND_CORS_ALLOW_ORIGINS=["http://localhost:3000"]
# generate using `from cryptography.fernet import Fernet;Fernet.generate_key().decode()`
ENCRYPTION_KEY='dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw='
UNSUBSCRIBE_SECRET_KEY = 'HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio='
REDIS_HOST=localhost
REDIS_PORT=6379
REDIS_PASSWORD=password
# RabbitMQ Credentials
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
ENABLE_CREDIT=false
STRIPE_API_KEY=
STRIPE_WEBHOOK_SECRET=
# Supabase Authentication
# What environment things should be logged under: local dev or prod
APP_ENV=local
# What environment to behave as: "local" or "cloud"
BEHAVE_AS=local
PYRO_HOST=localhost
SENTRY_DSN=
# Email for Postmark so we can send emails
POSTMARK_SERVER_API_TOKEN=
POSTMARK_SENDER_EMAIL=invalid@invalid.com
POSTMARK_WEBHOOK_TOKEN=
## User auth with Supabase is required for any of the 3rd party integrations with auth to work.
ENABLE_AUTH=true
SUPABASE_URL=http://localhost:8000
SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
## ===== REQUIRED SECURITY KEYS ===== ##
# Generate using: from cryptography.fernet import Fernet;Fernet.generate_key().decode()
ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
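# Example (illustrative): generate fresh values with the `cryptography` package, e.g.
#   python3 -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
# and paste the output into ENCRYPTION_KEY and UNSUBSCRIBE_SECRET_KEY above.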
# RabbitMQ credentials -- Used for communication between services
RABBITMQ_HOST=localhost
RABBITMQ_PORT=5672
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
# Platform URLs (set these for webhooks and OAuth to work)
PLATFORM_BASE_URL=http://localhost:8000
FRONTEND_BASE_URL=http://localhost:3000
# Media Storage (required for marketplace and library functionality)
## GCS bucket is required for marketplace and library functionality
MEDIA_GCS_BUCKET_NAME=
## ===== API KEYS AND OAUTH CREDENTIALS ===== ##
# All API keys below are optional - only add what you need
## For local development, you may need to set FRONTEND_BASE_URL for the OAuth flow
## for integrations to work. Defaults to the value of PLATFORM_BASE_URL if not set.
# FRONTEND_BASE_URL=http://localhost:3000
# AI/LLM Services
OPENAI_API_KEY=
ANTHROPIC_API_KEY=
GROQ_API_KEY=
LLAMA_API_KEY=
AIML_API_KEY=
V0_API_KEY=
OPEN_ROUTER_API_KEY=
NVIDIA_API_KEY=
## PLATFORM_BASE_URL must be set to a *publicly accessible* URL pointing to your backend
## to use the platform's webhook-related functionality.
## If you are developing locally, you can use something like ngrok to get a public URL
## and tunnel it to your locally running backend.
PLATFORM_BASE_URL=http://localhost:3000
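## Example (illustrative, assuming ngrok is installed):
##   ngrok http <backend-port>
## then set PLATFORM_BASE_URL to the https forwarding URL that ngrok prints.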
## Cloudflare Turnstile (CAPTCHA) Configuration
## Get these from the Cloudflare Turnstile dashboard: https://dash.cloudflare.com/?to=/:account/turnstile
## This is the backend secret key
TURNSTILE_SECRET_KEY=
## This is the verify URL
TURNSTILE_VERIFY_URL=https://challenges.cloudflare.com/turnstile/v0/siteverify
## == INTEGRATION CREDENTIALS == ##
# Each set of server side credentials is required for the corresponding 3rd party
# integration to work.
# OAuth Credentials
# For the OAuth callback URL, use <your_frontend_url>/auth/integrations/oauth_callback,
# e.g. http://localhost:3000/auth/integrations/oauth_callback
@@ -69,6 +85,7 @@ GITHUB_CLIENT_SECRET=
# Google OAuth App server credentials - https://console.cloud.google.com/apis/credentials, and enable gmail api and set scopes
# https://console.cloud.google.com/apis/credentials/consent?project=<your_project_id>
# You'll need to add/enable the following scopes (minimum):
# https://console.developers.google.com/apis/api/gmail.googleapis.com/overview?project=<your_project_id>
# https://console.cloud.google.com/apis/library/sheets.googleapis.com/?project=<your_project_id>
@@ -104,66 +121,96 @@ LINEAR_CLIENT_SECRET=
TODOIST_CLIENT_ID=
TODOIST_CLIENT_SECRET=
NOTION_CLIENT_ID=
NOTION_CLIENT_SECRET=
## ===== OPTIONAL API KEYS ===== ##
# LLM
OPENAI_API_KEY=
ANTHROPIC_API_KEY=
AIML_API_KEY=
GROQ_API_KEY=
OPEN_ROUTER_API_KEY=
LLAMA_API_KEY=
# Reddit
# Go to https://www.reddit.com/prefs/apps and create a new app
# Choose "script" for the type
# Fill in the redirect uri as <your_frontend_url>/auth/integrations/oauth_callback, e.g. http://localhost:3000/auth/integrations/oauth_callback
REDDIT_CLIENT_ID=
REDDIT_CLIENT_SECRET=
REDDIT_USER_AGENT="AutoGPT:1.0 (by /u/autogpt)"
# Payment Processing
STRIPE_API_KEY=
STRIPE_WEBHOOK_SECRET=
# Email Service (for sending notifications and confirmations)
POSTMARK_SERVER_API_TOKEN=
POSTMARK_SENDER_EMAIL=invalid@invalid.com
POSTMARK_WEBHOOK_TOKEN=
# Error Tracking
SENTRY_DSN=
# Cloudflare Turnstile (CAPTCHA) Configuration
# Get these from the Cloudflare Turnstile dashboard: https://dash.cloudflare.com/?to=/:account/turnstile
# This is the backend secret key
TURNSTILE_SECRET_KEY=
# This is the verify URL
TURNSTILE_VERIFY_URL=https://challenges.cloudflare.com/turnstile/v0/siteverify
# Feature Flags
LAUNCH_DARKLY_SDK_KEY=
# Content Generation & Media
DID_API_KEY=
FAL_API_KEY=
IDEOGRAM_API_KEY=
REPLICATE_API_KEY=
REVID_API_KEY=
SCREENSHOTONE_API_KEY=
UNREAL_SPEECH_API_KEY=
# Data & Search Services
E2B_API_KEY=
EXA_API_KEY=
JINA_API_KEY=
MEM0_API_KEY=
OPENWEATHERMAP_API_KEY=
GOOGLE_MAPS_API_KEY=
# Communication Services
# Discord
DISCORD_BOT_TOKEN=
MEDIUM_API_KEY=
MEDIUM_AUTHOR_ID=
# SMTP/Email
SMTP_SERVER=
SMTP_PORT=
SMTP_USERNAME=
SMTP_PASSWORD=
# Business & Marketing Tools
# D-ID
DID_API_KEY=
# Open Weather Map
OPENWEATHERMAP_API_KEY=
# SMTP
SMTP_SERVER=
SMTP_PORT=
SMTP_USERNAME=
SMTP_PASSWORD=
# Medium
MEDIUM_API_KEY=
MEDIUM_AUTHOR_ID=
# Google Maps
GOOGLE_MAPS_API_KEY=
# Replicate
REPLICATE_API_KEY=
# Ideogram
IDEOGRAM_API_KEY=
# Fal
FAL_API_KEY=
# Exa
EXA_API_KEY=
# E2B
E2B_API_KEY=
# Mem0
MEM0_API_KEY=
# Nvidia
NVIDIA_API_KEY=
# Apollo
APOLLO_API_KEY=
ENRICHLAYER_API_KEY=
AYRSHARE_API_KEY=
AYRSHARE_JWT_KEY=
# SmartLead
SMARTLEAD_API_KEY=
# ZeroBounce
ZEROBOUNCE_API_KEY=
# Other Services
AUTOMOD_API_KEY=
## ===== OPTIONAL API KEYS END ===== ##
# Block Error Rate Monitoring
BLOCK_ERROR_RATE_THRESHOLD=0.5
BLOCK_ERROR_RATE_CHECK_INTERVAL_SECS=86400
# Logging Configuration
LOG_LEVEL=INFO
ENABLE_CLOUD_LOGGING=false
ENABLE_FILE_LOGGING=false
# Use to manually set the log directory
# LOG_DIR=./logs
# Example Blocks Configuration
# Set to true to enable example blocks in development
# These blocks are disabled by default in production
ENABLE_EXAMPLE_BLOCKS=false

View File

@@ -1,4 +1,3 @@
.env
database.db
database.db-journal
dev.db

View File

@@ -8,14 +8,14 @@ WORKDIR /app
RUN echo 'Acquire::http::Pipeline-Depth 0;\nAcquire::http::No-Cache true;\nAcquire::BrokenProxy true;\n' > /etc/apt/apt.conf.d/99fixbadproxy
# Update package list and install build dependencies in a single layer
RUN apt-get update --allow-releaseinfo-change --fix-missing \
&& apt-get install -y \
build-essential \
libpq5 \
libz-dev \
libssl-dev \
postgresql-client
RUN apt-get update --allow-releaseinfo-change --fix-missing
# Install build dependencies
RUN apt-get install -y build-essential
RUN apt-get install -y libpq5
RUN apt-get install -y libz-dev
RUN apt-get install -y libssl-dev
RUN apt-get install -y postgresql-client
ENV POETRY_HOME=/opt/poetry
ENV POETRY_NO_INTERACTION=1
@@ -68,12 +68,6 @@ COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.tom
WORKDIR /app/autogpt_platform/backend
FROM server_dependencies AS migrate
# Migration stage only needs schema and migrations - much lighter than full backend
COPY autogpt_platform/backend/schema.prisma /app/autogpt_platform/backend/
COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migrations
FROM server_dependencies AS server
COPY autogpt_platform/backend /app/autogpt_platform/backend

View File

@@ -1,10 +1,6 @@
import logging
from typing import TYPE_CHECKING
from dotenv import load_dotenv
load_dotenv()
if TYPE_CHECKING:
from backend.util.process import AppProcess
@@ -42,12 +38,12 @@ def main(**kwargs):
from backend.server.ws_api import WebsocketServer
run_processes(
DatabaseManager().set_log_level("warning"),
DatabaseManager(),
ExecutionManager(),
Scheduler(),
NotificationManager(),
WebsocketServer(),
AgentServer(),
ExecutionManager(),
**kwargs,
)

View File

@@ -1,14 +1,10 @@
import functools
import importlib
import logging
import os
import re
from pathlib import Path
from typing import TYPE_CHECKING, TypeVar
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from backend.data.block import Block
@@ -103,15 +99,7 @@ def load_all_blocks() -> dict[str, type["Block"]]:
available_blocks[block.id] = block_cls
# Filter out blocks with incomplete auth configs, e.g. missing OAuth server secrets
from backend.data.block import is_block_auth_configured
filtered_blocks = {}
for block_id, block_cls in available_blocks.items():
if is_block_auth_configured(block_cls):
filtered_blocks[block_id] = block_cls
return filtered_blocks
return available_blocks
__all__ = ["load_all_blocks"]

View File

@@ -1,3 +1,4 @@
import asyncio
import logging
from typing import Any, Optional
@@ -13,9 +14,8 @@ from backend.data.block import (
get_block,
)
from backend.data.execution import ExecutionStatus
from backend.data.model import NodeExecutionStats, SchemaField
from backend.util.json import validate_with_jsonschema
from backend.util.retry import func_retry
from backend.data.model import SchemaField
from backend.util import json, retry
_logger = logging.getLogger(__name__)
@@ -49,7 +49,7 @@ class AgentExecutorBlock(Block):
@classmethod
def get_mismatch_error(cls, data: BlockInput) -> str | None:
return validate_with_jsonschema(cls.get_input_schema(data), data)
return json.validate_with_jsonschema(cls.get_input_schema(data), data)
class Output(BlockSchema):
pass
@@ -74,6 +74,7 @@ class AgentExecutorBlock(Block):
user_id=input_data.user_id,
inputs=input_data.inputs,
nodes_input_masks=input_data.nodes_input_masks,
use_db_query=False,
)
logger = execution_utils.LogMetadata(
@@ -95,14 +96,23 @@ class AgentExecutorBlock(Block):
logger=logger,
):
yield name, data
except BaseException as e:
except asyncio.CancelledError:
await self._stop(
graph_exec_id=graph_exec.id,
user_id=input_data.user_id,
logger=logger,
)
logger.warning(
f"Execution of graph {input_data.graph_id}v{input_data.graph_version} failed: {e.__class__.__name__} {str(e)}; execution is stopped."
f"Execution of graph {input_data.graph_id}v{input_data.graph_version} was cancelled."
)
except Exception as e:
await self._stop(
graph_exec_id=graph_exec.id,
user_id=input_data.user_id,
logger=logger,
)
logger.error(
f"Execution of graph {input_data.graph_id}v{input_data.graph_version} failed: {e}, execution is stopped."
)
raise
@@ -122,7 +132,6 @@ class AgentExecutorBlock(Block):
log_id = f"Graph #{graph_id}-V{graph_version}, exec-id: {graph_exec_id}"
logger.info(f"Starting execution of {log_id}")
yielded_node_exec_ids = set()
async for event in event_bus.listen(
user_id=user_id,
@@ -142,26 +151,12 @@ class AgentExecutorBlock(Block):
if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
# If the graph execution is COMPLETED, TERMINATED, or FAILED,
# we can stop listening for further events.
self.merge_stats(
NodeExecutionStats(
extra_cost=event.stats.cost if event.stats else 0,
extra_steps=event.stats.node_exec_count if event.stats else 0,
)
)
break
logger.debug(
f"Execution {log_id} produced input {event.input_data} output {event.output_data}"
)
if event.node_exec_id in yielded_node_exec_ids:
logger.warning(
f"{log_id} received duplicate event for node execution {event.node_exec_id}"
)
continue
else:
yielded_node_exec_ids.add(event.node_exec_id)
if not event.block_id:
logger.warning(f"{log_id} received event without block_id {event}")
continue
@@ -181,7 +176,7 @@ class AgentExecutorBlock(Block):
)
yield output_name, output_data
@func_retry
@retry.func_retry
async def _stop(
self,
graph_exec_id: str,
@@ -197,8 +192,8 @@ class AgentExecutorBlock(Block):
await execution_utils.stop_graph_execution(
graph_exec_id=graph_exec_id,
user_id=user_id,
wait_timeout=3600,
use_db_query=False,
)
logger.info(f"Execution {log_id} stopped successfully.")
except TimeoutError as e:
logger.error(f"Execution {log_id} stop timed out: {e}")
except Exception as e:
logger.error(f"Failed to stop execution {log_id}: {e}")

File diff suppressed because it is too large

View File

@@ -1,323 +0,0 @@
from os import getenv
from uuid import uuid4
import pytest
from backend.sdk import APIKeyCredentials, SecretStr
from ._api import (
TableFieldType,
WebhookFilters,
WebhookSpecification,
create_base,
create_field,
create_record,
create_table,
create_webhook,
delete_multiple_records,
delete_record,
delete_webhook,
get_record,
list_bases,
list_records,
list_webhook_payloads,
update_field,
update_multiple_records,
update_record,
update_table,
)
@pytest.mark.asyncio
async def test_create_update_table():
key = getenv("AIRTABLE_API_KEY")
if not key:
return pytest.skip("AIRTABLE_API_KEY is not set")
credentials = APIKeyCredentials(
provider="airtable",
api_key=SecretStr(key),
)
postfix = uuid4().hex[:4]
workspace_id = "wsphuHmfllg7V3Brd"
response = await create_base(credentials, workspace_id, "API Testing Base")
assert response is not None, f"Checking create base response: {response}"
assert (
response.get("id") is not None
), f"Checking create base response id: {response}"
base_id = response.get("id")
assert base_id is not None, f"Checking create base response id: {base_id}"
response = await list_bases(credentials)
assert response is not None, f"Checking list bases response: {response}"
assert "API Testing Base" in [
base.get("name") for base in response.get("bases", [])
], f"Checking list bases response bases: {response}"
table_name = f"test_table_{postfix}"
table_fields = [{"name": "test_field", "type": "singleLineText"}]
table = await create_table(credentials, base_id, table_name, table_fields)
assert table.get("name") == table_name
table_id = table.get("id")
assert table_id is not None
table_name = f"test_table_updated_{postfix}"
table_description = "test_description_updated"
table = await update_table(
credentials,
base_id,
table_id,
table_name=table_name,
table_description=table_description,
)
assert table.get("name") == table_name
assert table.get("description") == table_description
@pytest.mark.asyncio
async def test_invalid_field_type():
key = getenv("AIRTABLE_API_KEY")
if not key:
return pytest.skip("AIRTABLE_API_KEY is not set")
credentials = APIKeyCredentials(
provider="airtable",
api_key=SecretStr(key),
)
postfix = uuid4().hex[:4]
base_id = "appZPxegHEU3kDc1S"
table_name = f"test_table_{postfix}"
table_fields = [{"name": "test_field", "type": "notValid"}]
with pytest.raises(AssertionError):
await create_table(credentials, base_id, table_name, table_fields)
@pytest.mark.asyncio
async def test_create_and_update_field():
key = getenv("AIRTABLE_API_KEY")
if not key:
return pytest.skip("AIRTABLE_API_KEY is not set")
credentials = APIKeyCredentials(
provider="airtable",
api_key=SecretStr(key),
)
postfix = uuid4().hex[:4]
base_id = "appZPxegHEU3kDc1S"
table_name = f"test_table_{postfix}"
table_fields = [{"name": "test_field", "type": "singleLineText"}]
table = await create_table(credentials, base_id, table_name, table_fields)
assert table.get("name") == table_name
table_id = table.get("id")
assert table_id is not None
field_name = f"test_field_{postfix}"
field_type = TableFieldType.SINGLE_LINE_TEXT
field = await create_field(credentials, base_id, table_id, field_type, field_name)
assert field.get("name") == field_name
field_id = field.get("id")
assert field_id is not None
assert isinstance(field_id, str)
field_name = f"test_field_updated_{postfix}"
field = await update_field(credentials, base_id, table_id, field_id, field_name)
assert field.get("name") == field_name
field_description = "test_description_updated"
field = await update_field(
credentials, base_id, table_id, field_id, description=field_description
)
assert field.get("description") == field_description
@pytest.mark.asyncio
async def test_record_management():
key = getenv("AIRTABLE_API_KEY")
if not key:
return pytest.skip("AIRTABLE_API_KEY is not set")
credentials = APIKeyCredentials(
provider="airtable",
api_key=SecretStr(key),
)
postfix = uuid4().hex[:4]
base_id = "appZPxegHEU3kDc1S"
table_name = f"test_table_{postfix}"
table_fields = [{"name": "test_field", "type": "singleLineText"}]
table = await create_table(credentials, base_id, table_name, table_fields)
assert table.get("name") == table_name
table_id = table.get("id")
assert table_id is not None
# Create a record
record_fields = {"test_field": "test_value"}
record = await create_record(credentials, base_id, table_id, fields=record_fields)
fields = record.get("fields")
assert fields is not None
assert isinstance(fields, dict)
assert fields.get("test_field") == "test_value"
record_id = record.get("id")
assert record_id is not None
assert isinstance(record_id, str)
# Get a record
record = await get_record(credentials, base_id, table_id, record_id)
fields = record.get("fields")
assert fields is not None
assert isinstance(fields, dict)
assert fields.get("test_field") == "test_value"
# Update a record
record_fields = {"test_field": "test_value_updated"}
record = await update_record(
credentials, base_id, table_id, record_id, fields=record_fields
)
fields = record.get("fields")
assert fields is not None
assert isinstance(fields, dict)
assert fields.get("test_field") == "test_value_updated"
# Delete a record
record = await delete_record(credentials, base_id, table_id, record_id)
assert record is not None
assert record.get("id") == record_id
assert record.get("deleted")
# Create 2 records
records = [
{"fields": {"test_field": "test_value_1"}},
{"fields": {"test_field": "test_value_2"}},
]
response = await create_record(credentials, base_id, table_id, records=records)
created_records = response.get("records")
assert created_records is not None
assert isinstance(created_records, list)
assert len(created_records) == 2, f"Created records: {created_records}"
first_record = created_records[0] # type: ignore
second_record = created_records[1] # type: ignore
first_record_id = first_record.get("id")
second_record_id = second_record.get("id")
assert first_record_id is not None
assert second_record_id is not None
assert first_record_id != second_record_id
first_fields = first_record.get("fields")
second_fields = second_record.get("fields")
assert first_fields is not None
assert second_fields is not None
assert first_fields.get("test_field") == "test_value_1" # type: ignore
assert second_fields.get("test_field") == "test_value_2" # type: ignore
# List records
response = await list_records(credentials, base_id, table_id)
records = response.get("records")
assert records is not None
assert len(records) == 2, f"Records: {records}"
assert isinstance(records, list), f"Type of records: {type(records)}"
# Update multiple records
records = [
{"id": first_record_id, "fields": {"test_field": "test_value_1_updated"}},
{"id": second_record_id, "fields": {"test_field": "test_value_2_updated"}},
]
response = await update_multiple_records(
credentials, base_id, table_id, records=records
)
updated_records = response.get("records")
assert updated_records is not None
assert len(updated_records) == 2, f"Updated records: {updated_records}"
assert isinstance(
updated_records, list
), f"Type of updated records: {type(updated_records)}"
first_updated = updated_records[0] # type: ignore
second_updated = updated_records[1] # type: ignore
first_updated_fields = first_updated.get("fields")
second_updated_fields = second_updated.get("fields")
assert first_updated_fields is not None
assert second_updated_fields is not None
assert first_updated_fields.get("test_field") == "test_value_1_updated" # type: ignore
assert second_updated_fields.get("test_field") == "test_value_2_updated" # type: ignore
# Delete multiple records
assert isinstance(first_record_id, str)
assert isinstance(second_record_id, str)
response = await delete_multiple_records(
credentials, base_id, table_id, records=[first_record_id, second_record_id]
)
deleted_records = response.get("records")
assert deleted_records is not None
assert len(deleted_records) == 2, f"Deleted records: {deleted_records}"
assert isinstance(
deleted_records, list
), f"Type of deleted records: {type(deleted_records)}"
first_deleted = deleted_records[0] # type: ignore
second_deleted = deleted_records[1] # type: ignore
assert first_deleted.get("deleted")
assert second_deleted.get("deleted")
@pytest.mark.asyncio
async def test_webhook_management():
key = getenv("AIRTABLE_API_KEY")
if not key:
return pytest.skip("AIRTABLE_API_KEY is not set")
credentials = APIKeyCredentials(
provider="airtable",
api_key=SecretStr(key),
)
postfix = uuid4().hex[:4]
base_id = "appZPxegHEU3kDc1S"
table_name = f"test_table_{postfix}"
table_fields = [{"name": "test_field", "type": "singleLineText"}]
table = await create_table(credentials, base_id, table_name, table_fields)
assert table.get("name") == table_name
table_id = table.get("id")
assert table_id is not None
webhook_specification = WebhookSpecification(
filters=WebhookFilters(
dataTypes=["tableData", "tableFields", "tableMetadata"],
changeTypes=["add", "update", "remove"],
)
)
response = await create_webhook(credentials, base_id, webhook_specification)
assert response is not None, f"Checking create webhook response: {response}"
assert (
response.get("id") is not None
), f"Checking create webhook response id: {response}"
assert (
response.get("macSecretBase64") is not None
), f"Checking create webhook response macSecretBase64: {response}"
webhook_id = response.get("id")
assert webhook_id is not None, f"Webhook ID: {webhook_id}"
assert isinstance(webhook_id, str)
response = await create_record(
credentials, base_id, table_id, fields={"test_field": "test_value"}
)
assert response is not None, f"Checking create record response: {response}"
assert (
response.get("id") is not None
), f"Checking create record response id: {response}"
fields = response.get("fields")
assert fields is not None, f"Checking create record response fields: {response}"
assert (
fields.get("test_field") == "test_value"
), f"Checking create record response fields test_field: {response}"
response = await list_webhook_payloads(credentials, base_id, webhook_id)
assert response is not None, f"Checking list webhook payloads response: {response}"
response = await delete_webhook(credentials, base_id, webhook_id)

View File

@@ -1,32 +0,0 @@
"""
Shared configuration for all Airtable blocks using the SDK pattern.
"""
from backend.sdk import BlockCostType, ProviderBuilder
from ._oauth import AirtableOAuthHandler, AirtableScope
from ._webhook import AirtableWebhookManager
# Configure the Airtable provider with API key authentication
airtable = (
ProviderBuilder("airtable")
.with_api_key("AIRTABLE_API_KEY", "Airtable Personal Access Token")
.with_webhook_manager(AirtableWebhookManager)
.with_base_cost(1, BlockCostType.RUN)
.with_oauth(
AirtableOAuthHandler,
scopes=[
v.value
for v in [
AirtableScope.DATA_RECORDS_READ,
AirtableScope.DATA_RECORDS_WRITE,
AirtableScope.SCHEMA_BASES_READ,
AirtableScope.SCHEMA_BASES_WRITE,
AirtableScope.WEBHOOK_MANAGE,
]
],
client_id_env_var="AIRTABLE_CLIENT_ID",
client_secret_env_var="AIRTABLE_CLIENT_SECRET",
)
.build()
)

View File

@@ -1,185 +0,0 @@
"""
Airtable OAuth handler implementation.
"""
import time
from enum import Enum
from logging import getLogger
from typing import Optional
from backend.sdk import BaseOAuthHandler, OAuth2Credentials, ProviderName, SecretStr
from ._api import (
OAuthTokenResponse,
make_oauth_authorize_url,
oauth_exchange_code_for_tokens,
oauth_refresh_tokens,
)
logger = getLogger(__name__)
class AirtableScope(str, Enum):
# Basic scopes
DATA_RECORDS_READ = "data.records:read"
DATA_RECORDS_WRITE = "data.records:write"
DATA_RECORD_COMMENTS_READ = "data.recordComments:read"
DATA_RECORD_COMMENTS_WRITE = "data.recordComments:write"
SCHEMA_BASES_READ = "schema.bases:read"
SCHEMA_BASES_WRITE = "schema.bases:write"
WEBHOOK_MANAGE = "webhook:manage"
BLOCK_MANAGE = "block:manage"
USER_EMAIL_READ = "user.email:read"
# Enterprise member scopes
ENTERPRISE_GROUPS_READ = "enterprise.groups:read"
WORKSPACES_AND_BASES_READ = "workspacesAndBases:read"
WORKSPACES_AND_BASES_WRITE = "workspacesAndBases:write"
WORKSPACES_AND_BASES_SHARES_MANAGE = "workspacesAndBases.shares:manage"
# Enterprise admin scopes
ENTERPRISE_SCIM_USERS_AND_GROUPS_MANAGE = "enterprise.scim.usersAndGroups:manage"
ENTERPRISE_AUDIT_LOGS_READ = "enterprise.auditLogs:read"
ENTERPRISE_CHANGE_EVENTS_READ = "enterprise.changeEvents:read"
ENTERPRISE_EXPORTS_MANAGE = "enterprise.exports:manage"
ENTERPRISE_ACCOUNT_READ = "enterprise.account:read"
ENTERPRISE_ACCOUNT_WRITE = "enterprise.account:write"
ENTERPRISE_USER_READ = "enterprise.user:read"
ENTERPRISE_USER_WRITE = "enterprise.user:write"
ENTERPRISE_GROUPS_MANAGE = "enterprise.groups:manage"
WORKSPACES_AND_BASES_MANAGE = "workspacesAndBases:manage"
HYPERDB_RECORDS_READ = "hyperDB.records:read"
HYPERDB_RECORDS_WRITE = "hyperDB.records:write"
class AirtableOAuthHandler(BaseOAuthHandler):
"""
OAuth2 handler for Airtable with PKCE support.
"""
PROVIDER_NAME = ProviderName("airtable")
DEFAULT_SCOPES = [
v.value
for v in [
AirtableScope.DATA_RECORDS_READ,
AirtableScope.DATA_RECORDS_WRITE,
AirtableScope.SCHEMA_BASES_READ,
AirtableScope.SCHEMA_BASES_WRITE,
AirtableScope.WEBHOOK_MANAGE,
]
]
def __init__(self, client_id: str, client_secret: Optional[str], redirect_uri: str):
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = redirect_uri
self.scopes = self.DEFAULT_SCOPES
self.auth_base_url = "https://airtable.com/oauth2/v1/authorize"
self.token_url = "https://airtable.com/oauth2/v1/token"
def get_login_url(
self, scopes: list[str], state: str, code_challenge: Optional[str]
) -> str:
logger.debug("Generating Airtable OAuth login URL")
# PKCE is required: the caller must provide a code_challenge (validated below)
if not scopes:
logger.debug("No scopes provided, using default scopes")
scopes = self.scopes
logger.debug(f"Using scopes: {scopes}")
logger.debug(f"State: {state}")
logger.debug(f"Code challenge: {code_challenge}")
if not code_challenge:
logger.error("Code challenge is required but none was provided")
raise ValueError("No code challenge provided")
try:
url = make_oauth_authorize_url(
self.client_id, self.redirect_uri, scopes, state, code_challenge
)
logger.debug(f"Generated OAuth URL: {url}")
return url
except Exception as e:
logger.error(f"Failed to generate OAuth URL: {str(e)}")
raise
async def exchange_code_for_tokens(
self, code: str, scopes: list[str], code_verifier: Optional[str]
) -> OAuth2Credentials:
logger.debug("Exchanging authorization code for tokens")
logger.debug(f"Code: {code[:4]}...") # Log first 4 chars only for security
logger.debug(f"Scopes: {scopes}")
if not code_verifier:
logger.error("Code verifier is required but none was provided")
raise ValueError("No code verifier provided")
try:
response: OAuthTokenResponse = await oauth_exchange_code_for_tokens(
client_id=self.client_id,
code=code,
code_verifier=code_verifier.encode("utf-8"),
redirect_uri=self.redirect_uri,
client_secret=self.client_secret,
)
logger.info("Successfully exchanged code for tokens")
credentials = OAuth2Credentials(
access_token=SecretStr(response.access_token),
refresh_token=SecretStr(response.refresh_token),
access_token_expires_at=int(time.time()) + response.expires_in,
refresh_token_expires_at=int(time.time()) + response.refresh_expires_in,
provider=self.PROVIDER_NAME,
scopes=scopes,
)
logger.debug(f"Access token expires in {response.expires_in} seconds")
logger.debug(
f"Refresh token expires in {response.refresh_expires_in} seconds"
)
return credentials
except Exception as e:
logger.error(f"Failed to exchange code for tokens: {str(e)}")
raise
async def _refresh_tokens(
self, credentials: OAuth2Credentials
) -> OAuth2Credentials:
logger.debug("Attempting to refresh OAuth tokens")
if credentials.refresh_token is None:
logger.error("Cannot refresh tokens - no refresh token available")
raise ValueError("No refresh token available")
try:
response: OAuthTokenResponse = await oauth_refresh_tokens(
client_id=self.client_id,
refresh_token=credentials.refresh_token.get_secret_value(),
client_secret=self.client_secret,
)
logger.info("Successfully refreshed tokens")
new_credentials = OAuth2Credentials(
id=credentials.id,
access_token=SecretStr(response.access_token),
refresh_token=SecretStr(response.refresh_token),
access_token_expires_at=int(time.time()) + response.expires_in,
refresh_token_expires_at=int(time.time()) + response.refresh_expires_in,
provider=self.PROVIDER_NAME,
scopes=self.scopes,
)
logger.debug(f"New access token expires in {response.expires_in} seconds")
logger.debug(
f"New refresh token expires in {response.refresh_expires_in} seconds"
)
return new_credentials
except Exception as e:
logger.error(f"Failed to refresh tokens: {str(e)}")
raise
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
logger.debug("Token revocation requested")
logger.info(
"Airtable doesn't provide a token revocation endpoint - tokens will expire naturally after 60 minutes"
)
return False
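# ---------------------------------------------------------------------------
# Illustrative sketch (editorial example, not part of this handler): the OAuth
# flow above requires the caller to supply a PKCE pair. One standard way to
# derive an S256 code_challenge from a random code_verifier (per RFC 7636):
# ---------------------------------------------------------------------------
import base64
import hashlib
import secrets


def make_pkce_pair() -> tuple[str, str]:
    # 32 random bytes, base64url-encoded without padding, is a valid code_verifier.
    code_verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).rstrip(b"=").decode()
    # The S256 challenge is the base64url-encoded SHA-256 digest of the verifier.
    digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
    code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
    return code_verifier, code_challenge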

View File

@@ -1,154 +0,0 @@
"""
Webhook management for Airtable blocks.
"""
import hashlib
import hmac
import logging
from enum import Enum
from backend.sdk import (
BaseWebhooksManager,
Credentials,
ProviderName,
Webhook,
update_webhook,
)
from ._api import (
WebhookFilters,
WebhookSpecification,
create_webhook,
delete_webhook,
list_webhook_payloads,
)
logger = logging.getLogger(__name__)
class AirtableWebhookEvent(str, Enum):
TABLE_DATA = "tableData"
TABLE_FIELDS = "tableFields"
TABLE_METADATA = "tableMetadata"
class AirtableWebhookManager(BaseWebhooksManager):
"""Webhook manager for Airtable API."""
PROVIDER_NAME = ProviderName("airtable")
@classmethod
async def validate_payload(
cls, webhook: Webhook, request, credentials: Credentials | None
) -> tuple[dict, str]:
"""Validate incoming webhook payload and signature."""
if not credentials:
raise ValueError("Missing credentials in webhook metadata")
payload = await request.json()
# Verify webhook signature using HMAC-SHA256
if webhook.secret:
mac_secret = webhook.config.get("mac_secret")
if mac_secret:
# Get the raw body for signature verification
body = await request.body()
# Calculate expected signature
mac_secret_decoded = mac_secret.encode()
hmac_obj = hmac.new(mac_secret_decoded, body, hashlib.sha256)
expected_mac = f"hmac-sha256={hmac_obj.hexdigest()}"
# Get signature from headers
signature = request.headers.get("X-Airtable-Content-MAC")
if signature and not hmac.compare_digest(signature, expected_mac):
raise ValueError("Invalid webhook signature")
# Validate payload structure
required_fields = ["base", "webhook", "timestamp"]
if not all(field in payload for field in required_fields):
raise ValueError("Invalid webhook payload structure")
if "id" not in payload["base"] or "id" not in payload["webhook"]:
raise ValueError("Missing required IDs in webhook payload")
base_id = payload["base"]["id"]
webhook_id = payload["webhook"]["id"]
# get payload request parameters
cursor = webhook.config.get("cursor", 1)
response = await list_webhook_payloads(credentials, base_id, webhook_id, cursor)
# update webhook config
await update_webhook(
webhook.id,
config={"base_id": base_id, "cursor": response.cursor},
)
event_type = "notification"
return response.model_dump(), event_type
async def _register_webhook(
self,
credentials: Credentials,
webhook_type: str,
resource: str,
events: list[str],
ingress_url: str,
secret: str,
) -> tuple[str, dict]:
"""Register webhook with Airtable API."""
# Parse resource to get base_id and table_id/name
# Resource format: "{base_id}/{table_id_or_name}"
parts = resource.split("/", 1)
if len(parts) != 2:
raise ValueError("Resource must be in format: {base_id}/{table_id_or_name}")
base_id, table_id_or_name = parts
# Prepare webhook specification
webhook_specification = WebhookSpecification(
filters=WebhookFilters(
dataTypes=events,
)
)
# Create webhook
webhook_data = await create_webhook(
credentials=credentials,
base_id=base_id,
webhook_specification=webhook_specification,
notification_url=ingress_url,
)
webhook_id = webhook_data["id"]
mac_secret = webhook_data.get("macSecretBase64")
return webhook_id, {
"webhook_id": webhook_id,
"base_id": base_id,
"table_id_or_name": table_id_or_name,
"events": events,
"mac_secret": mac_secret,
"cursor": 1,
"expiration_time": webhook_data.get("expirationTime"),
}
async def _deregister_webhook(
self, webhook: Webhook, credentials: Credentials
) -> None:
"""Deregister webhook from Airtable API."""
base_id = webhook.config.get("base_id")
webhook_id = webhook.config.get("webhook_id")
if not base_id:
raise ValueError("Missing base_id in webhook metadata")
if not webhook_id:
raise ValueError("Missing webhook_id in webhook metadata")
await delete_webhook(credentials, base_id, webhook_id)

View File

@@ -1,122 +0,0 @@
"""
Airtable base operation blocks.
"""
from typing import Optional
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
SchemaField,
)
from ._api import create_base, list_bases
from ._config import airtable
class AirtableCreateBaseBlock(Block):
"""
Creates a new base in an Airtable workspace.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = airtable.credentials_field(
description="Airtable API credentials"
)
workspace_id: str = SchemaField(
description="The workspace ID where the base will be created"
)
name: str = SchemaField(description="The name of the new base")
tables: list[dict] = SchemaField(
description="At least one table and field must be specified. Array of table objects to create in the base. Each table should have 'name' and 'fields' properties",
default=[
{
"description": "Default table",
"name": "Default table",
"fields": [
{
"name": "ID",
"type": "number",
"description": "Auto-incrementing ID field",
"options": {"precision": 0},
}
],
}
],
)
class Output(BlockSchema):
base_id: str = SchemaField(description="The ID of the created base")
tables: list[dict] = SchemaField(description="Array of table objects")
table: dict = SchemaField(description="A single table object")
def __init__(self):
super().__init__(
id="f59b88a8-54ce-4676-a508-fd614b4e8dce",
description="Create a new base in Airtable",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
data = await create_base(
credentials,
input_data.workspace_id,
input_data.name,
input_data.tables,
)
yield "base_id", data.get("id", None)
yield "tables", data.get("tables", [])
for table in data.get("tables", []):
yield "table", table
class AirtableListBasesBlock(Block):
"""
Lists all bases in an Airtable workspace that the user has access to.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = airtable.credentials_field(
description="Airtable API credentials"
)
trigger: str = SchemaField(
description="Trigger the block to run - value is ignored", default="manual"
)
offset: str = SchemaField(
description="Pagination offset from previous request", default=""
)
class Output(BlockSchema):
bases: list[dict] = SchemaField(description="Array of base objects")
offset: Optional[str] = SchemaField(
description="Offset for next page (null if no more bases)", default=None
)
def __init__(self):
super().__init__(
id="4bd8d466-ed5d-4e44-8083-97f25a8044e7",
description="List all bases in Airtable",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
data = await list_bases(
credentials,
offset=input_data.offset if input_data.offset else None,
)
yield "bases", data.get("bases", [])
yield "offset", data.get("offset", None)

View File

@@ -1,283 +0,0 @@
"""
Airtable record operation blocks.
"""
from typing import Optional
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
SchemaField,
)
from ._api import (
create_record,
delete_multiple_records,
get_record,
list_records,
update_multiple_records,
)
from ._config import airtable
class AirtableListRecordsBlock(Block):
"""
Lists records from an Airtable table with optional filtering, sorting, and pagination.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = airtable.credentials_field(
description="Airtable API credentials"
)
base_id: str = SchemaField(description="The Airtable base ID")
table_id_or_name: str = SchemaField(description="Table ID or name")
filter_formula: str = SchemaField(
description="Airtable formula to filter records", default=""
)
view: str = SchemaField(description="View ID or name to use", default="")
sort: list[dict] = SchemaField(
description="Sort configuration (array of {field, direction})", default=[]
)
max_records: int = SchemaField(
description="Maximum number of records to return", default=100
)
page_size: int = SchemaField(
description="Number of records per page (max 100)", default=100
)
offset: str = SchemaField(
description="Pagination offset from previous request", default=""
)
return_fields: list[str] = SchemaField(
description="Specific fields to return (comma-separated)", default=[]
)
class Output(BlockSchema):
records: list[dict] = SchemaField(description="Array of record objects")
offset: Optional[str] = SchemaField(
description="Offset for next page (null if no more records)", default=None
)
def __init__(self):
super().__init__(
id="588a9fde-5733-4da7-b03c-35f5671e960f",
description="List records from an Airtable table",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
data = await list_records(
credentials,
input_data.base_id,
input_data.table_id_or_name,
filter_by_formula=(
input_data.filter_formula if input_data.filter_formula else None
),
view=input_data.view if input_data.view else None,
sort=input_data.sort if input_data.sort else None,
max_records=input_data.max_records if input_data.max_records else None,
page_size=min(input_data.page_size, 100) if input_data.page_size else None,
offset=input_data.offset if input_data.offset else None,
fields=input_data.return_fields if input_data.return_fields else None,
)
yield "records", data.get("records", [])
yield "offset", data.get("offset", None)
class AirtableGetRecordBlock(Block):
"""
Retrieves a single record from an Airtable table by its ID.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = airtable.credentials_field(
description="Airtable API credentials"
)
base_id: str = SchemaField(description="The Airtable base ID")
table_id_or_name: str = SchemaField(description="Table ID or name")
record_id: str = SchemaField(description="The record ID to retrieve")
class Output(BlockSchema):
id: str = SchemaField(description="The record ID")
fields: dict = SchemaField(description="The record fields")
created_time: str = SchemaField(description="The record created time")
def __init__(self):
super().__init__(
id="c29c5cbf-0aff-40f9-bbb5-f26061792d2b",
description="Get a single record from Airtable",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
record = await get_record(
credentials,
input_data.base_id,
input_data.table_id_or_name,
input_data.record_id,
)
yield "id", record.get("id", None)
yield "fields", record.get("fields", None)
yield "created_time", record.get("createdTime", None)
class AirtableCreateRecordsBlock(Block):
"""
Creates one or more records in an Airtable table.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = airtable.credentials_field(
description="Airtable API credentials"
)
base_id: str = SchemaField(description="The Airtable base ID")
table_id_or_name: str = SchemaField(description="Table ID or name")
records: list[dict] = SchemaField(
description="Array of records to create (each with 'fields' object)"
)
typecast: bool = SchemaField(
description="Automatically convert string values to appropriate types",
default=False,
)
return_fields_by_field_id: bool | None = SchemaField(
description="Return fields by field ID",
default=None,
)
class Output(BlockSchema):
records: list[dict] = SchemaField(description="Array of created record objects")
details: dict = SchemaField(description="Details of the created records")
def __init__(self):
super().__init__(
id="42527e98-47b6-44ce-ac0e-86b4883721d3",
description="Create records in an Airtable table",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
# The create_record API expects records in a specific format
data = await create_record(
credentials,
input_data.base_id,
input_data.table_id_or_name,
records=[{"fields": record} for record in input_data.records],
typecast=input_data.typecast if input_data.typecast else None,
return_fields_by_field_id=input_data.return_fields_by_field_id,
)
yield "records", data.get("records", [])
details = data.get("details", None)
if details:
yield "details", details
class AirtableUpdateRecordsBlock(Block):
"""
Updates one or more existing records in an Airtable table.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = airtable.credentials_field(
description="Airtable API credentials"
)
base_id: str = SchemaField(description="The Airtable base ID")
table_id_or_name: str = SchemaField(
description="Table ID or name - It's better to use the table ID instead of the name"
)
records: list[dict] = SchemaField(
description="Array of records to update (each with 'id' and 'fields')"
)
typecast: bool | None = SchemaField(
description="Automatically convert string values to appropriate types",
default=None,
)
class Output(BlockSchema):
records: list[dict] = SchemaField(description="Array of updated record objects")
def __init__(self):
super().__init__(
id="6e7d2590-ac2b-4b5d-b08c-fc039cd77e1f",
description="Update records in an Airtable table",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
# The update_multiple_records API expects records with id and fields
data = await update_multiple_records(
credentials,
input_data.base_id,
input_data.table_id_or_name,
records=input_data.records,
typecast=input_data.typecast if input_data.typecast else None,
return_fields_by_field_id=False, # Use field names, not IDs
)
yield "records", data.get("records", [])
class AirtableDeleteRecordsBlock(Block):
"""
Deletes one or more records from an Airtable table.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = airtable.credentials_field(
description="Airtable API credentials"
)
base_id: str = SchemaField(description="The Airtable base ID")
table_id_or_name: str = SchemaField(
description="Table ID or name - It's better to use the table ID instead of the name"
)
record_ids: list[str] = SchemaField(
description="Array of upto 10 record IDs to delete"
)
class Output(BlockSchema):
records: list[dict] = SchemaField(description="Array of deletion results")
def __init__(self):
super().__init__(
id="93e22b8b-3642-4477-aefb-1c0929a4a3a6",
description="Delete records from an Airtable table",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
if len(input_data.record_ids) > 10:
yield "error", "Only upto 10 record IDs can be deleted at a time"
else:
data = await delete_multiple_records(
credentials,
input_data.base_id,
input_data.table_id_or_name,
input_data.record_ids,
)
yield "records", data.get("records", [])

View File

@@ -1,252 +0,0 @@
"""
Airtable schema and table management blocks.
"""
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
Requests,
SchemaField,
)
from ._api import TableFieldType, create_field, create_table, update_field, update_table
from ._config import airtable
class AirtableListSchemaBlock(Block):
"""
Retrieves the complete schema of an Airtable base, including all tables,
fields, and views.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = airtable.credentials_field(
description="Airtable API credentials"
)
base_id: str = SchemaField(description="The Airtable base ID")
class Output(BlockSchema):
base_schema: dict = SchemaField(
description="Complete base schema with tables, fields, and views"
)
tables: list[dict] = SchemaField(description="Array of table objects")
def __init__(self):
super().__init__(
id="64291d3c-99b5-47b7-a976-6d94293cdb2d",
description="Get the complete schema of an Airtable base",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
api_key = credentials.api_key.get_secret_value()
# Get base schema
response = await Requests().get(
f"https://api.airtable.com/v0/meta/bases/{input_data.base_id}/tables",
headers={"Authorization": f"Bearer {api_key}"},
)
data = response.json()
yield "base_schema", data
yield "tables", data.get("tables", [])
class AirtableCreateTableBlock(Block):
"""
Creates a new table in an Airtable base with specified fields and views.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = airtable.credentials_field(
description="Airtable API credentials"
)
base_id: str = SchemaField(description="The Airtable base ID")
table_name: str = SchemaField(description="The name of the table to create")
table_fields: list[dict] = SchemaField(
description="Table fields with name, type, and options",
default=[{"name": "Name", "type": "singleLineText"}],
)
class Output(BlockSchema):
table: dict = SchemaField(description="Created table object")
table_id: str = SchemaField(description="ID of the created table")
def __init__(self):
super().__init__(
id="fcc20ced-d817-42ea-9b40-c35e7bf34b4f",
description="Create a new table in an Airtable base",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
table_data = await create_table(
credentials,
input_data.base_id,
input_data.table_name,
input_data.table_fields,
)
yield "table", table_data
yield "table_id", table_data.get("id", "")
class AirtableUpdateTableBlock(Block):
"""
Updates an existing table's properties such as name or description.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = airtable.credentials_field(
description="Airtable API credentials"
)
base_id: str = SchemaField(description="The Airtable base ID")
table_id: str = SchemaField(description="The table ID to update")
table_name: str | None = SchemaField(
description="The name of the table to update", default=None
)
table_description: str | None = SchemaField(
description="The description of the table to update", default=None
)
date_dependency: dict | None = SchemaField(
description="The date dependency of the table to update", default=None
)
class Output(BlockSchema):
table: dict = SchemaField(description="Updated table object")
def __init__(self):
super().__init__(
id="34077c5f-f962-49f2-9ec6-97c67077013a",
description="Update table properties",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
table_data = await update_table(
credentials,
input_data.base_id,
input_data.table_id,
input_data.table_name,
input_data.table_description,
input_data.date_dependency,
)
yield "table", table_data
class AirtableCreateFieldBlock(Block):
"""
Adds a new field (column) to an existing Airtable table.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = airtable.credentials_field(
description="Airtable API credentials"
)
base_id: str = SchemaField(description="The Airtable base ID")
table_id: str = SchemaField(description="The table ID to add field to")
field_type: TableFieldType = SchemaField(
description="The type of the field to create",
default=TableFieldType.SINGLE_LINE_TEXT,
advanced=False,
)
name: str = SchemaField(description="The name of the field to create")
description: str | None = SchemaField(
description="The description of the field to create", default=None
)
options: dict[str, str] | None = SchemaField(
description="The options of the field to create", default=None
)
class Output(BlockSchema):
field: dict = SchemaField(description="Created field object")
field_id: str = SchemaField(description="ID of the created field")
def __init__(self):
super().__init__(
id="6c98a32f-dbf9-45d8-a2a8-5e97e8326351",
description="Add a new field to an Airtable table",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
field_data = await create_field(
credentials,
input_data.base_id,
input_data.table_id,
input_data.field_type,
input_data.name,
)
yield "field", field_data
yield "field_id", field_data.get("id", "")
class AirtableUpdateFieldBlock(Block):
"""
Updates an existing field's properties in an Airtable table.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = airtable.credentials_field(
description="Airtable API credentials"
)
base_id: str = SchemaField(description="The Airtable base ID")
table_id: str = SchemaField(description="The table ID containing the field")
field_id: str = SchemaField(description="The field ID to update")
name: str | None = SchemaField(
description="The name of the field to update", default=None, advanced=False
)
description: str | None = SchemaField(
description="The description of the field to update",
default=None,
advanced=False,
)
class Output(BlockSchema):
field: dict = SchemaField(description="Updated field object")
def __init__(self):
super().__init__(
id="f46ac716-3b18-4da1-92e4-34ca9a464d48",
description="Update field properties in an Airtable table",
categories={BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
field_data = await update_field(
credentials,
input_data.base_id,
input_data.table_id,
input_data.field_id,
input_data.name,
input_data.description,
)
yield "field", field_data

View File

@@ -1,113 +0,0 @@
from backend.sdk import (
BaseModel,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockType,
BlockWebhookConfig,
CredentialsMetaInput,
ProviderName,
SchemaField,
)
from ._api import WebhookPayload
from ._config import airtable
class AirtableEventSelector(BaseModel):
"""
Selects the Airtable webhook event to trigger on.
"""
tableData: bool = True
tableFields: bool = True
tableMetadata: bool = True
class AirtableWebhookTriggerBlock(Block):
"""
Starts a flow whenever Airtable emits a webhook event.
A thin wrapper that forwards the payloads one at a time to the next block.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = airtable.credentials_field(
description="Airtable API credentials"
)
base_id: str = SchemaField(description="Airtable base ID")
table_id_or_name: str = SchemaField(description="Airtable table ID or name")
payload: dict = SchemaField(hidden=True, default_factory=dict)
events: AirtableEventSelector = SchemaField(
description="Airtable webhook event filter"
)
class Output(BlockSchema):
payload: WebhookPayload = SchemaField(description="Airtable webhook payload")
def __init__(self):
example_payload = {
"payloads": [
{
"timestamp": "2022-02-01T21:25:05.663Z",
"baseTransactionNumber": 4,
"actionMetadata": {
"source": "client",
"sourceMetadata": {
"user": {
"id": "usr00000000000000",
"email": "foo@bar.com",
"permissionLevel": "create",
}
},
},
"payloadFormat": "v0",
}
],
"cursor": 5,
"mightHaveMore": False,
}
super().__init__(
# NOTE: Set disabled=True to take this block offline while the webhook system is finalised.
disabled=False,
id="d0180ce6-ccb9-48c7-8256-b39e93e62801",
description="Starts a flow whenever Airtable emits a webhook event",
categories={BlockCategory.INPUT, BlockCategory.DATA},
input_schema=self.Input,
output_schema=self.Output,
block_type=BlockType.WEBHOOK,
webhook_config=BlockWebhookConfig(
provider=ProviderName("airtable"),
webhook_type="not-used",
event_filter_input="events",
event_format="{event}",
resource_format="{base_id}/{table_id_or_name}",
),
test_input={
"credentials": airtable.get_test_credentials().model_dump(),
"base_id": "app1234567890",
"table_id_or_name": "table1234567890",
"events": AirtableEventSelector(
tableData=True,
tableFields=True,
tableMetadata=False,
).model_dump(),
"payload": example_payload,
},
test_credentials=airtable.get_test_credentials(),
test_output=[
(
"payload",
WebhookPayload.model_validate(example_payload["payloads"][0]),
),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
payloads = input_data.payload.get("payloads", [])
if payloads:
for item in payloads:
yield "payload", WebhookPayload.model_validate(item)
else:
yield "error", "No valid payloads found in webhook payload"

View File

@@ -1,15 +0,0 @@
AYRSHARE_BLOCK_IDS = [
"cbd52c2a-06d2-43ed-9560-6576cc163283", # PostToBlueskyBlock
"3352f512-3524-49ed-a08f-003042da2fc1", # PostToFacebookBlock
"9e8f844e-b4a5-4b25-80f2-9e1dd7d67625", # PostToXBlock
"589af4e4-507f-42fd-b9ac-a67ecef25811", # PostToLinkedInBlock
"89b02b96-a7cb-46f4-9900-c48b32fe1552", # PostToInstagramBlock
"0082d712-ff1b-4c3d-8a8d-6c7721883b83", # PostToYouTubeBlock
"c7733580-3c72-483e-8e47-a8d58754d853", # PostToRedditBlock
"47bc74eb-4af2-452c-b933-af377c7287df", # PostToTelegramBlock
"2c38c783-c484-4503-9280-ef5d1d345a7e", # PostToGMBBlock
"3ca46e05-dbaa-4afb-9e95-5a429c4177e6", # PostToPinterestBlock
"7faf4b27-96b0-4f05-bf64-e0de54ae74e1", # PostToTikTokBlock
"f8c3b2e1-9d4a-4e5f-8c7b-6a9e8d2f1c3b", # PostToThreadsBlock
"a9d7f854-2c83-4e96-b3a1-7f2e9c5d4b8e", # PostToSnapchatBlock
]

View File

@@ -1,152 +0,0 @@
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, Field
from backend.data.block import BlockSchema
from backend.data.model import SchemaField, UserIntegrations
from backend.integrations.ayrshare import AyrshareClient
from backend.util.clients import get_database_manager_async_client
from backend.util.exceptions import MissingConfigError
async def get_profile_key(user_id: str):
user_integrations: UserIntegrations = (
await get_database_manager_async_client().get_user_integrations(user_id)
)
return user_integrations.managed_credentials.ayrshare_profile_key
class BaseAyrshareInput(BlockSchema):
"""Base input model for Ayrshare social media posts with common fields."""
post: str = SchemaField(
description="The post text to be published", default="", advanced=False
)
media_urls: list[str] = SchemaField(
description="Optional list of media URLs to include. Set is_video in advanced settings to true if you want to upload videos.",
default_factory=list,
advanced=False,
)
is_video: bool = SchemaField(
description="Whether the media is a video", default=False, advanced=True
)
schedule_date: Optional[datetime] = SchemaField(
description="UTC datetime for scheduling (YYYY-MM-DDThh:mm:ssZ)",
default=None,
advanced=True,
)
disable_comments: bool = SchemaField(
description="Whether to disable comments", default=False, advanced=True
)
shorten_links: bool = SchemaField(
description="Whether to shorten links", default=False, advanced=True
)
unsplash: Optional[str] = SchemaField(
description="Unsplash image configuration", default=None, advanced=True
)
requires_approval: bool = SchemaField(
description="Whether to enable approval workflow",
default=False,
advanced=True,
)
random_post: bool = SchemaField(
description="Whether to generate random post text",
default=False,
advanced=True,
)
random_media_url: bool = SchemaField(
description="Whether to generate random media", default=False, advanced=True
)
notes: Optional[str] = SchemaField(
description="Additional notes for the post", default=None, advanced=True
)
class CarouselItem(BaseModel):
"""Model for Facebook carousel items."""
name: str = Field(..., description="The name of the item")
link: str = Field(..., description="The link of the item")
picture: str = Field(..., description="The picture URL of the item")
class CallToAction(BaseModel):
"""Model for Google My Business Call to Action."""
action_type: str = Field(
..., description="Type of action (book, order, shop, learn_more, sign_up, call)"
)
url: Optional[str] = Field(
description="URL for the action (not required for 'call' action)"
)
class EventDetails(BaseModel):
"""Model for Google My Business Event details."""
title: str = Field(..., description="Event title")
start_date: str = Field(..., description="Event start date (ISO format)")
end_date: str = Field(..., description="Event end date (ISO format)")
class OfferDetails(BaseModel):
"""Model for Google My Business Offer details."""
title: str = Field(..., description="Offer title")
start_date: str = Field(..., description="Offer start date (ISO format)")
end_date: str = Field(..., description="Offer end date (ISO format)")
coupon_code: str = Field(..., description="Coupon code (max 58 characters)")
redeem_online_url: str = Field(..., description="URL to redeem the offer")
terms_conditions: str = Field(..., description="Terms and conditions")
class InstagramUserTag(BaseModel):
"""Model for Instagram user tags."""
username: str = Field(..., description="Instagram username (without @)")
x: Optional[float] = Field(description="X coordinate (0.0-1.0) for image posts")
y: Optional[float] = Field(description="Y coordinate (0.0-1.0) for image posts")
class LinkedInTargeting(BaseModel):
"""Model for LinkedIn audience targeting."""
countries: Optional[list[str]] = Field(
description="Country codes (e.g., ['US', 'IN', 'DE', 'GB'])"
)
seniorities: Optional[list[str]] = Field(
description="Seniority levels (e.g., ['Senior', 'VP'])"
)
degrees: Optional[list[str]] = Field(description="Education degrees")
fields_of_study: Optional[list[str]] = Field(description="Fields of study")
industries: Optional[list[str]] = Field(description="Industry categories")
job_functions: Optional[list[str]] = Field(description="Job function categories")
staff_count_ranges: Optional[list[str]] = Field(description="Company size ranges")
class PinterestCarouselOption(BaseModel):
"""Model for Pinterest carousel image options."""
title: Optional[str] = Field(description="Image title")
link: Optional[str] = Field(description="External destination link for the image")
description: Optional[str] = Field(description="Image description")
class YouTubeTargeting(BaseModel):
"""Model for YouTube country targeting."""
block: Optional[list[str]] = Field(
description="Country codes to block (e.g., ['US', 'CA'])"
)
allow: Optional[list[str]] = Field(
description="Country codes to allow (e.g., ['GB', 'AU'])"
)
def create_ayrshare_client():
"""Create an Ayrshare client instance."""
try:
return AyrshareClient()
except MissingConfigError:
return None

View File

@@ -1,114 +0,0 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockType,
SchemaField,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
class PostToBlueskyBlock(Block):
"""Block for posting to Bluesky with Bluesky-specific options."""
class Input(BaseAyrshareInput):
"""Input schema for Bluesky posts."""
# Override post field to include character limit information
post: str = SchemaField(
description="The post text to be published (max 300 characters for Bluesky)",
default="",
advanced=False,
)
# Override media_urls to include Bluesky-specific constraints
media_urls: list[str] = SchemaField(
description="Optional list of media URLs to include. Bluesky supports up to 4 images or 1 video.",
default_factory=list,
advanced=False,
)
# Bluesky-specific options
alt_text: list[str] = SchemaField(
description="Alt text for each media item (accessibility)",
default_factory=list,
advanced=True,
)
class Output(BlockSchema):
post_result: PostResponse = SchemaField(description="The result of the post")
post: PostIds = SchemaField(description="The post IDs returned for the platform")
def __init__(self):
super().__init__(
disabled=True,
id="cbd52c2a-06d2-43ed-9560-6576cc163283",
description="Post to Bluesky using Ayrshare",
categories={BlockCategory.SOCIAL},
block_type=BlockType.AYRSHARE,
input_schema=PostToBlueskyBlock.Input,
output_schema=PostToBlueskyBlock.Output,
)
async def run(
self,
input_data: "PostToBlueskyBlock.Input",
*,
user_id: str,
**kwargs,
) -> BlockOutput:
"""Post to Bluesky with Bluesky-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
return
# Validate character limit for Bluesky
if len(input_data.post) > 300:
yield "error", f"Post text exceeds Bluesky's 300 character limit ({len(input_data.post)} characters)"
return
# Validate media constraints for Bluesky
if len(input_data.media_urls) > 4:
yield "error", "Bluesky supports a maximum of 4 images or 1 video"
return
# Convert datetime to ISO format if provided
iso_date = (
input_data.schedule_date.isoformat() if input_data.schedule_date else None
)
# Build Bluesky-specific options
bluesky_options = {}
if input_data.alt_text:
bluesky_options["altText"] = input_data.alt_text
response = await client.create_post(
post=input_data.post,
platforms=[SocialPlatform.BLUESKY],
media_urls=input_data.media_urls,
is_video=input_data.is_video,
schedule_date=iso_date,
disable_comments=input_data.disable_comments,
shorten_links=input_data.shorten_links,
unsplash=input_data.unsplash,
requires_approval=input_data.requires_approval,
random_post=input_data.random_post,
random_media_url=input_data.random_media_url,
notes=input_data.notes,
bluesky_options=bluesky_options if bluesky_options else None,
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:
for p in response.postIds:
yield "post", p

View File

@@ -1,212 +0,0 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockType,
SchemaField,
)
from ._util import (
BaseAyrshareInput,
CarouselItem,
create_ayrshare_client,
get_profile_key,
)
class PostToFacebookBlock(Block):
"""Block for posting to Facebook with Facebook-specific options."""
class Input(BaseAyrshareInput):
"""Input schema for Facebook posts."""
# Facebook-specific options
is_carousel: bool = SchemaField(
description="Whether to post a carousel", default=False, advanced=True
)
carousel_link: str = SchemaField(
description="The URL for the 'See More At' button in the carousel",
default="",
advanced=True,
)
carousel_items: list[CarouselItem] = SchemaField(
description="List of carousel items with name, link and picture URLs. Min 2, max 10 items.",
default_factory=list,
advanced=True,
)
is_reels: bool = SchemaField(
description="Whether to post to Facebook Reels",
default=False,
advanced=True,
)
reels_title: str = SchemaField(
description="Title for the Reels video (max 255 chars)",
default="",
advanced=True,
)
reels_thumbnail: str = SchemaField(
description="Thumbnail URL for Reels video (JPEG/PNG, <10MB)",
default="",
advanced=True,
)
is_story: bool = SchemaField(
description="Whether to post as a Facebook Story",
default=False,
advanced=True,
)
media_captions: list[str] = SchemaField(
description="Captions for each media item",
default_factory=list,
advanced=True,
)
location_id: str = SchemaField(
description="Facebook Page ID or name for location tagging",
default="",
advanced=True,
)
age_min: int = SchemaField(
description="Minimum age for audience targeting (13,15,18,21,25)",
default=0,
advanced=True,
)
target_countries: list[str] = SchemaField(
description="List of country codes to target (max 25)",
default_factory=list,
advanced=True,
)
alt_text: list[str] = SchemaField(
description="Alt text for each media item",
default_factory=list,
advanced=True,
)
video_title: str = SchemaField(
description="Title for video post", default="", advanced=True
)
video_thumbnail: str = SchemaField(
description="Thumbnail URL for video post", default="", advanced=True
)
is_draft: bool = SchemaField(
description="Save as draft in Meta Business Suite",
default=False,
advanced=True,
)
scheduled_publish_date: str = SchemaField(
description="Schedule publish time in Meta Business Suite (UTC)",
default="",
advanced=True,
)
preview_link: str = SchemaField(
description="URL for custom link preview", default="", advanced=True
)
class Output(BlockSchema):
post_result: PostResponse = SchemaField(description="The result of the post")
post: PostIds = SchemaField(description="The post IDs returned for the platform")
def __init__(self):
super().__init__(
disabled=True,
id="3352f512-3524-49ed-a08f-003042da2fc1",
description="Post to Facebook using Ayrshare",
categories={BlockCategory.SOCIAL},
block_type=BlockType.AYRSHARE,
input_schema=PostToFacebookBlock.Input,
output_schema=PostToFacebookBlock.Output,
)
async def run(
self,
input_data: "PostToFacebookBlock.Input",
*,
user_id: str,
**kwargs,
) -> BlockOutput:
"""Post to Facebook with Facebook-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
return
# Convert datetime to ISO format if provided
iso_date = (
input_data.schedule_date.isoformat() if input_data.schedule_date else None
)
# Build Facebook-specific options
facebook_options = {}
if input_data.is_carousel:
facebook_options["isCarousel"] = True
if input_data.carousel_link:
facebook_options["carouselLink"] = input_data.carousel_link
if input_data.carousel_items:
facebook_options["carouselItems"] = [
item.dict() for item in input_data.carousel_items
]
if input_data.is_reels:
facebook_options["isReels"] = True
if input_data.reels_title:
facebook_options["reelsTitle"] = input_data.reels_title
if input_data.reels_thumbnail:
facebook_options["reelsThumbnail"] = input_data.reels_thumbnail
if input_data.is_story:
facebook_options["isStory"] = True
if input_data.media_captions:
facebook_options["mediaCaptions"] = input_data.media_captions
if input_data.location_id:
facebook_options["locationId"] = input_data.location_id
if input_data.age_min > 0:
facebook_options["ageMin"] = input_data.age_min
if input_data.target_countries:
facebook_options["targetCountries"] = input_data.target_countries
if input_data.alt_text:
facebook_options["altText"] = input_data.alt_text
if input_data.video_title:
facebook_options["videoTitle"] = input_data.video_title
if input_data.video_thumbnail:
facebook_options["videoThumbnail"] = input_data.video_thumbnail
if input_data.is_draft:
facebook_options["isDraft"] = True
if input_data.scheduled_publish_date:
facebook_options["scheduledPublishDate"] = input_data.scheduled_publish_date
if input_data.preview_link:
facebook_options["previewLink"] = input_data.preview_link
response = await client.create_post(
post=input_data.post,
platforms=[SocialPlatform.FACEBOOK],
media_urls=input_data.media_urls,
is_video=input_data.is_video,
schedule_date=iso_date,
disable_comments=input_data.disable_comments,
shorten_links=input_data.shorten_links,
unsplash=input_data.unsplash,
requires_approval=input_data.requires_approval,
random_post=input_data.random_post,
random_media_url=input_data.random_media_url,
notes=input_data.notes,
facebook_options=facebook_options if facebook_options else None,
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:
for p in response.postIds:
yield "post", p

View File

@@ -1,210 +0,0 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockType,
SchemaField,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
class PostToGMBBlock(Block):
"""Block for posting to Google My Business with GMB-specific options."""
class Input(BaseAyrshareInput):
"""Input schema for Google My Business posts."""
# Override media_urls to include GMB-specific constraints
media_urls: list[str] = SchemaField(
description="Optional list of media URLs. GMB supports only one image or video per post.",
default_factory=list,
advanced=False,
)
# GMB-specific options
is_photo_video: bool = SchemaField(
description="Whether this is a photo/video post (appears in Photos section)",
default=False,
advanced=True,
)
photo_category: str = SchemaField(
description="Category for photo/video: cover, profile, logo, exterior, interior, product, at_work, food_and_drink, menu, common_area, rooms, teams",
default="",
advanced=True,
)
# Call to action options (flattened from CallToAction object)
call_to_action_type: str = SchemaField(
description="Type of action button: 'book', 'order', 'shop', 'learn_more', 'sign_up', or 'call'",
default="",
advanced=True,
)
call_to_action_url: str = SchemaField(
description="URL for the action button (not required for 'call' action)",
default="",
advanced=True,
)
# Event details options (flattened from EventDetails object)
event_title: str = SchemaField(
description="Event title for event posts",
default="",
advanced=True,
)
event_start_date: str = SchemaField(
description="Event start date in ISO format (e.g., '2024-03-15T09:00:00Z')",
default="",
advanced=True,
)
event_end_date: str = SchemaField(
description="Event end date in ISO format (e.g., '2024-03-15T17:00:00Z')",
default="",
advanced=True,
)
# Offer details options (flattened from OfferDetails object)
offer_title: str = SchemaField(
description="Offer title for promotional posts",
default="",
advanced=True,
)
offer_start_date: str = SchemaField(
description="Offer start date in ISO format (e.g., '2024-03-15T00:00:00Z')",
default="",
advanced=True,
)
offer_end_date: str = SchemaField(
description="Offer end date in ISO format (e.g., '2024-04-15T23:59:59Z')",
default="",
advanced=True,
)
offer_coupon_code: str = SchemaField(
description="Coupon code for the offer (max 58 characters)",
default="",
advanced=True,
)
offer_redeem_online_url: str = SchemaField(
description="URL where customers can redeem the offer online",
default="",
advanced=True,
)
offer_terms_conditions: str = SchemaField(
description="Terms and conditions for the offer",
default="",
advanced=True,
)
class Output(BlockSchema):
post_result: PostResponse = SchemaField(description="The result of the post")
post: PostIds = SchemaField(description="The post IDs returned for the platform")
def __init__(self):
super().__init__(
disabled=True,
id="2c38c783-c484-4503-9280-ef5d1d345a7e",
description="Post to Google My Business using Ayrshare",
categories={BlockCategory.SOCIAL},
block_type=BlockType.AYRSHARE,
input_schema=PostToGMBBlock.Input,
output_schema=PostToGMBBlock.Output,
)
async def run(
self, input_data: "PostToGMBBlock.Input", *, user_id: str, **kwargs
) -> BlockOutput:
"""Post to Google My Business with GMB-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
return
# Validate GMB constraints
if len(input_data.media_urls) > 1:
yield "error", "Google My Business supports only one image or video per post"
return
# Validate offer coupon code length
if input_data.offer_coupon_code and len(input_data.offer_coupon_code) > 58:
yield "error", "GMB offer coupon code cannot exceed 58 characters"
return
# Convert datetime to ISO format if provided
iso_date = (
input_data.schedule_date.isoformat() if input_data.schedule_date else None
)
# Build GMB-specific options
gmb_options = {}
# Photo/Video post options
if input_data.is_photo_video:
gmb_options["isPhotoVideo"] = True
if input_data.photo_category:
gmb_options["category"] = input_data.photo_category
# Call to Action (from flattened fields)
if input_data.call_to_action_type:
cta_dict = {"actionType": input_data.call_to_action_type}
# URL not required for 'call' action type
if (
input_data.call_to_action_type != "call"
and input_data.call_to_action_url
):
cta_dict["url"] = input_data.call_to_action_url
gmb_options["callToAction"] = cta_dict
# Event details (from flattened fields)
if (
input_data.event_title
and input_data.event_start_date
and input_data.event_end_date
):
gmb_options["event"] = {
"title": input_data.event_title,
"startDate": input_data.event_start_date,
"endDate": input_data.event_end_date,
}
# Offer details (from flattened fields)
if (
input_data.offer_title
and input_data.offer_start_date
and input_data.offer_end_date
and input_data.offer_coupon_code
and input_data.offer_redeem_online_url
and input_data.offer_terms_conditions
):
gmb_options["offer"] = {
"title": input_data.offer_title,
"startDate": input_data.offer_start_date,
"endDate": input_data.offer_end_date,
"couponCode": input_data.offer_coupon_code,
"redeemOnlineUrl": input_data.offer_redeem_online_url,
"termsConditions": input_data.offer_terms_conditions,
}
response = await client.create_post(
post=input_data.post,
platforms=[SocialPlatform.GOOGLE_MY_BUSINESS],
media_urls=input_data.media_urls,
is_video=input_data.is_video,
schedule_date=iso_date,
disable_comments=input_data.disable_comments,
shorten_links=input_data.shorten_links,
unsplash=input_data.unsplash,
requires_approval=input_data.requires_approval,
random_post=input_data.random_post,
random_media_url=input_data.random_media_url,
notes=input_data.notes,
gmb_options=gmb_options if gmb_options else None,
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:
for p in response.postIds:
yield "post", p

View File

@@ -1,249 +0,0 @@
from typing import Any
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockType,
SchemaField,
)
from ._util import (
BaseAyrshareInput,
InstagramUserTag,
create_ayrshare_client,
get_profile_key,
)
class PostToInstagramBlock(Block):
"""Block for posting to Instagram with Instagram-specific options."""
class Input(BaseAyrshareInput):
"""Input schema for Instagram posts."""
# Override post field to include Instagram-specific information
post: str = SchemaField(
description="The post text (max 2,200 chars, up to 30 hashtags, 3 @mentions)",
default="",
advanced=False,
)
# Override media_urls to include Instagram-specific constraints
media_urls: list[str] = SchemaField(
description="Optional list of media URLs. Instagram supports up to 10 images/videos in a carousel.",
default_factory=list,
advanced=False,
)
# Instagram-specific options
is_story: bool | None = SchemaField(
description="Whether to post as Instagram Story (24-hour expiration)",
default=None,
advanced=True,
)
# ------- REELS OPTIONS -------
share_reels_feed: bool | None = SchemaField(
description="Whether Reel should appear in both Feed and Reels tabs",
default=None,
advanced=True,
)
audio_name: str | None = SchemaField(
description="Audio name for Reels (e.g., 'The Weeknd - Blinding Lights')",
default=None,
advanced=True,
)
thumbnail: str | None = SchemaField(
description="Thumbnail URL for Reel video", default=None, advanced=True
)
thumbnail_offset: int | None = SchemaField(
description="Thumbnail frame offset in milliseconds (default: 0)",
default=0,
advanced=True,
)
# ------- POST OPTIONS -------
alt_text: list[str] = SchemaField(
description="Alt text for each media item (up to 1,000 chars each, accessibility feature), each item in the list corresponds to a media item in the media_urls list",
default_factory=list,
advanced=True,
)
location_id: str | None = SchemaField(
description="Facebook Page ID or name for location tagging (e.g., '7640348500' or '@guggenheimmuseum')",
default=None,
advanced=True,
)
user_tags: list[dict[str, Any]] = SchemaField(
description="List of users to tag with coordinates for images",
default_factory=list,
advanced=True,
)
collaborators: list[str] = SchemaField(
description="Instagram usernames to invite as collaborators (max 3, public accounts only)",
default_factory=list,
advanced=True,
)
auto_resize: bool | None = SchemaField(
description="Auto-resize images to 1080x1080px for Instagram",
default=None,
advanced=True,
)
class Output(BlockSchema):
post_result: PostResponse = SchemaField(description="The result of the post")
post: PostIds = SchemaField(description="The post IDs returned for the platform")
def __init__(self):
super().__init__(
id="89b02b96-a7cb-46f4-9900-c48b32fe1552",
description="Post to Instagram using Ayrshare. Requires a Business or Creator Instagram Account connected with a Facebook Page",
categories={BlockCategory.SOCIAL},
block_type=BlockType.AYRSHARE,
input_schema=PostToInstagramBlock.Input,
output_schema=PostToInstagramBlock.Output,
)
async def run(
self,
input_data: "PostToInstagramBlock.Input",
*,
user_id: str,
**kwargs,
) -> BlockOutput:
"""Post to Instagram with Instagram-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
return
# Validate Instagram constraints
if len(input_data.post) > 2200:
yield "error", f"Instagram post text exceeds 2,200 character limit ({len(input_data.post)} characters)"
return
if len(input_data.media_urls) > 10:
yield "error", "Instagram supports a maximum of 10 images/videos in a carousel"
return
if len(input_data.collaborators) > 3:
yield "error", "Instagram supports a maximum of 3 collaborators"
return
# Validate that if any reel option is set, all required reel options are set
reel_options = [
input_data.share_reels_feed,
input_data.audio_name,
input_data.thumbnail,
]
if any(reel_options) and not all(reel_options):
yield "error", "When posting a reel, all reel options must be set: share_reels_feed, audio_name, and either thumbnail or thumbnail_offset"
return
# Count hashtags and mentions
hashtag_count = input_data.post.count("#")
mention_count = input_data.post.count("@")
if hashtag_count > 30:
yield "error", f"Instagram allows maximum 30 hashtags ({hashtag_count} found)"
return
if mention_count > 3:
yield "error", f"Instagram allows maximum 3 @mentions ({mention_count} found)"
return
# Convert datetime to ISO format if provided
iso_date = (
input_data.schedule_date.isoformat() if input_data.schedule_date else None
)
# Build Instagram-specific options
instagram_options = {}
# Stories
if input_data.is_story:
instagram_options["stories"] = True
# Reels options
if input_data.share_reels_feed is not None:
instagram_options["shareReelsFeed"] = input_data.share_reels_feed
if input_data.audio_name:
instagram_options["audioName"] = input_data.audio_name
if input_data.thumbnail:
instagram_options["thumbNail"] = input_data.thumbnail
elif input_data.thumbnail_offset and input_data.thumbnail_offset > 0:
instagram_options["thumbNailOffset"] = input_data.thumbnail_offset
# Alt text
if input_data.alt_text:
# Validate alt text length
for i, alt in enumerate(input_data.alt_text):
if len(alt) > 1000:
yield "error", f"Alt text {i+1} exceeds 1,000 character limit ({len(alt)} characters)"
return
instagram_options["altText"] = input_data.alt_text
# Location
if input_data.location_id:
instagram_options["locationId"] = input_data.location_id
# User tags
if input_data.user_tags:
user_tags_list = []
for tag in input_data.user_tags:
try:
tag_obj = InstagramUserTag(**tag)
except Exception as e:
yield "error", f"Invalid user tag: {e}, tages need to be a dictionary with a 3 items: username (str), x (float) and y (float)"
return
tag_dict: dict[str, float | str] = {"username": tag_obj.username}
if tag_obj.x is not None and tag_obj.y is not None:
# Validate coordinates
if not (0.0 <= tag_obj.x <= 1.0) or not (0.0 <= tag_obj.y <= 1.0):
yield "error", f"User tag coordinates must be between 0.0 and 1.0 (user: {tag_obj.username})"
return
tag_dict["x"] = tag_obj.x
tag_dict["y"] = tag_obj.y
user_tags_list.append(tag_dict)
instagram_options["userTags"] = user_tags_list
# Collaborators
if input_data.collaborators:
instagram_options["collaborators"] = input_data.collaborators
# Auto resize
if input_data.auto_resize:
instagram_options["autoResize"] = True
response = await client.create_post(
post=input_data.post,
platforms=[SocialPlatform.INSTAGRAM],
media_urls=input_data.media_urls,
is_video=input_data.is_video,
schedule_date=iso_date,
disable_comments=input_data.disable_comments,
shorten_links=input_data.shorten_links,
unsplash=input_data.unsplash,
requires_approval=input_data.requires_approval,
random_post=input_data.random_post,
random_media_url=input_data.random_media_url,
notes=input_data.notes,
instagram_options=instagram_options if instagram_options else None,
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:
for p in response.postIds:
yield "post", p

View File

@@ -1,222 +0,0 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockType,
SchemaField,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
class PostToLinkedInBlock(Block):
"""Block for posting to LinkedIn with LinkedIn-specific options."""
class Input(BaseAyrshareInput):
"""Input schema for LinkedIn posts."""
# Override post field to include LinkedIn-specific information
post: str = SchemaField(
description="The post text (max 3,000 chars, hashtags supported with #)",
default="",
advanced=False,
)
# Override media_urls to include LinkedIn-specific constraints
media_urls: list[str] = SchemaField(
description="Optional list of media URLs. LinkedIn supports up to 9 images, videos, or documents (PPT, PPTX, DOC, DOCX, PDF <100MB, <300 pages).",
default_factory=list,
advanced=False,
)
# LinkedIn-specific options
visibility: str = SchemaField(
description="Post visibility: 'public' (default), 'connections' (personal only), 'loggedin'",
default="public",
advanced=True,
)
alt_text: list[str] = SchemaField(
description="Alt text for each image (accessibility feature, not supported for videos/documents)",
default_factory=list,
advanced=True,
)
titles: list[str] = SchemaField(
description="Title/caption for each image or video",
default_factory=list,
advanced=True,
)
document_title: str = SchemaField(
description="Title for document posts (max 400 chars, uses filename if not specified)",
default="",
advanced=True,
)
thumbnail: str = SchemaField(
description="Thumbnail URL for video (PNG/JPG, same dimensions as video, <10MB)",
default="",
advanced=True,
)
# LinkedIn targeting options (flattened from LinkedInTargeting object)
targeting_countries: list[str] | None = SchemaField(
description="Country codes for targeting (e.g., ['US', 'IN', 'DE', 'GB']). Requires 300+ followers in target audience.",
default=None,
advanced=True,
)
targeting_seniorities: list[str] | None = SchemaField(
description="Seniority levels for targeting (e.g., ['Senior', 'VP']). Requires 300+ followers in target audience.",
default=None,
advanced=True,
)
targeting_degrees: list[str] | None = SchemaField(
description="Education degrees for targeting. Requires 300+ followers in target audience.",
default=None,
advanced=True,
)
targeting_fields_of_study: list[str] | None = SchemaField(
description="Fields of study for targeting. Requires 300+ followers in target audience.",
default=None,
advanced=True,
)
targeting_industries: list[str] | None = SchemaField(
description="Industry categories for targeting. Requires 300+ followers in target audience.",
default=None,
advanced=True,
)
targeting_job_functions: list[str] | None = SchemaField(
description="Job function categories for targeting. Requires 300+ followers in target audience.",
default=None,
advanced=True,
)
targeting_staff_count_ranges: list[str] | None = SchemaField(
description="Company size ranges for targeting. Requires 300+ followers in target audience.",
default=None,
advanced=True,
)
class Output(BlockSchema):
post_result: PostResponse = SchemaField(description="The result of the post")
post: PostIds = SchemaField(description="The post IDs returned for the platform")
def __init__(self):
super().__init__(
id="589af4e4-507f-42fd-b9ac-a67ecef25811",
description="Post to LinkedIn using Ayrshare",
categories={BlockCategory.SOCIAL},
block_type=BlockType.AYRSHARE,
input_schema=PostToLinkedInBlock.Input,
output_schema=PostToLinkedInBlock.Output,
)
async def run(
self,
input_data: "PostToLinkedInBlock.Input",
*,
user_id: str,
**kwargs,
) -> BlockOutput:
"""Post to LinkedIn with LinkedIn-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
return
# Validate LinkedIn constraints
if len(input_data.post) > 3000:
yield "error", f"LinkedIn post text exceeds 3,000 character limit ({len(input_data.post)} characters)"
return
if len(input_data.media_urls) > 9:
yield "error", "LinkedIn supports a maximum of 9 images/videos/documents"
return
if input_data.document_title and len(input_data.document_title) > 400:
yield "error", f"LinkedIn document title exceeds 400 character limit ({len(input_data.document_title)} characters)"
return
# Validate visibility option
valid_visibility = ["public", "connections", "loggedin"]
if input_data.visibility not in valid_visibility:
yield "error", f"LinkedIn visibility must be one of: {', '.join(valid_visibility)}"
return
# Check for document extensions
document_extensions = [".ppt", ".pptx", ".doc", ".docx", ".pdf"]
has_documents = any(
any(url.lower().endswith(ext) for ext in document_extensions)
for url in input_data.media_urls
)
# Convert datetime to ISO format if provided
iso_date = (
input_data.schedule_date.isoformat() if input_data.schedule_date else None
)
# Build LinkedIn-specific options
linkedin_options = {}
# Visibility
if input_data.visibility != "public":
linkedin_options["visibility"] = input_data.visibility
# Alt text (not supported for videos or documents)
if input_data.alt_text and not has_documents:
linkedin_options["altText"] = input_data.alt_text
# Titles/captions
if input_data.titles:
linkedin_options["titles"] = input_data.titles
# Document title
if input_data.document_title and has_documents:
linkedin_options["title"] = input_data.document_title
# Video thumbnail
if input_data.thumbnail:
linkedin_options["thumbNail"] = input_data.thumbnail
# Audience targeting (from flattened fields)
targeting_dict = {}
if input_data.targeting_countries:
targeting_dict["countries"] = input_data.targeting_countries
if input_data.targeting_seniorities:
targeting_dict["seniorities"] = input_data.targeting_seniorities
if input_data.targeting_degrees:
targeting_dict["degrees"] = input_data.targeting_degrees
if input_data.targeting_fields_of_study:
targeting_dict["fieldsOfStudy"] = input_data.targeting_fields_of_study
if input_data.targeting_industries:
targeting_dict["industries"] = input_data.targeting_industries
if input_data.targeting_job_functions:
targeting_dict["jobFunctions"] = input_data.targeting_job_functions
if input_data.targeting_staff_count_ranges:
targeting_dict["staffCountRanges"] = input_data.targeting_staff_count_ranges
if targeting_dict:
linkedin_options["targeting"] = targeting_dict
response = await client.create_post(
post=input_data.post,
platforms=[SocialPlatform.LINKEDIN],
media_urls=input_data.media_urls,
is_video=input_data.is_video,
schedule_date=iso_date,
disable_comments=input_data.disable_comments,
shorten_links=input_data.shorten_links,
unsplash=input_data.unsplash,
requires_approval=input_data.requires_approval,
random_post=input_data.random_post,
random_media_url=input_data.random_media_url,
notes=input_data.notes,
linkedin_options=linkedin_options if linkedin_options else None,
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:
for p in response.postIds:
yield "post", p

View File

@@ -1,214 +0,0 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockType,
SchemaField,
)
from ._util import (
BaseAyrshareInput,
PinterestCarouselOption,
create_ayrshare_client,
get_profile_key,
)
class PostToPinterestBlock(Block):
"""Block for posting to Pinterest with Pinterest-specific options."""
class Input(BaseAyrshareInput):
"""Input schema for Pinterest posts."""
# Override post field to include Pinterest-specific information
post: str = SchemaField(
description="Pin description (max 500 chars, links not clickable - use link field instead)",
default="",
advanced=False,
)
# Override media_urls to include Pinterest-specific constraints
media_urls: list[str] = SchemaField(
description="Required image/video URLs. Pinterest requires at least one image. Videos need thumbnail. Up to 5 images for carousel.",
default_factory=list,
advanced=False,
)
# Pinterest-specific options
pin_title: str = SchemaField(
description="Pin title displayed in 'Add your title' section (max 100 chars)",
default="",
advanced=True,
)
link: str = SchemaField(
description="Clickable destination URL when users click the pin (max 2048 chars)",
default="",
advanced=True,
)
board_id: str = SchemaField(
description="Pinterest Board ID to post to (from /user/details endpoint, uses default board if not specified)",
default="",
advanced=True,
)
note: str = SchemaField(
description="Private note for the pin (only visible to you and board collaborators)",
default="",
advanced=True,
)
thumbnail: str = SchemaField(
description="Required thumbnail URL for video pins (must have valid image Content-Type)",
default="",
advanced=True,
)
carousel_options: list[PinterestCarouselOption] = SchemaField(
description="Options for each image in carousel (title, link, description per image)",
default_factory=list,
advanced=True,
)
alt_text: list[str] = SchemaField(
description="Alt text for each image/video (max 500 chars each, accessibility feature)",
default_factory=list,
advanced=True,
)
class Output(BlockSchema):
post_result: PostResponse = SchemaField(description="The result of the post")
post: PostIds = SchemaField(description="The post IDs returned for the platform")
def __init__(self):
super().__init__(
disabled=True,
id="3ca46e05-dbaa-4afb-9e95-5a429c4177e6",
description="Post to Pinterest using Ayrshare",
categories={BlockCategory.SOCIAL},
block_type=BlockType.AYRSHARE,
input_schema=PostToPinterestBlock.Input,
output_schema=PostToPinterestBlock.Output,
)
async def run(
self,
input_data: "PostToPinterestBlock.Input",
*,
user_id: str,
**kwargs,
) -> BlockOutput:
"""Post to Pinterest with Pinterest-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
return
# Validate Pinterest constraints
if len(input_data.post) > 500:
yield "error", f"Pinterest pin description exceeds 500 character limit ({len(input_data.post)} characters)"
return
if len(input_data.pin_title) > 100:
yield "error", f"Pinterest pin title exceeds 100 character limit ({len(input_data.pin_title)} characters)"
return
if len(input_data.link) > 2048:
yield "error", f"Pinterest link URL exceeds 2048 character limit ({len(input_data.link)} characters)"
return
if len(input_data.media_urls) == 0:
yield "error", "Pinterest requires at least one image or video"
return
if len(input_data.media_urls) > 5:
yield "error", "Pinterest supports a maximum of 5 images in a carousel"
return
# Check if video is included and thumbnail is provided
video_extensions = [".mp4", ".mov", ".avi", ".mkv", ".wmv", ".flv", ".webm"]
has_video = any(
any(url.lower().endswith(ext) for ext in video_extensions)
for url in input_data.media_urls
)
if (has_video or input_data.is_video) and not input_data.thumbnail:
yield "error", "Pinterest video pins require a thumbnail URL"
return
# Validate alt text length
for i, alt in enumerate(input_data.alt_text):
if len(alt) > 500:
yield "error", f"Pinterest alt text {i+1} exceeds 500 character limit ({len(alt)} characters)"
return
# Convert datetime to ISO format if provided
iso_date = (
input_data.schedule_date.isoformat() if input_data.schedule_date else None
)
# Build Pinterest-specific options
pinterest_options = {}
# Pin title
if input_data.pin_title:
pinterest_options["title"] = input_data.pin_title
# Clickable link
if input_data.link:
pinterest_options["link"] = input_data.link
# Board ID
if input_data.board_id:
pinterest_options["boardId"] = input_data.board_id
# Private note
if input_data.note:
pinterest_options["note"] = input_data.note
# Video thumbnail
if input_data.thumbnail:
pinterest_options["thumbNail"] = input_data.thumbnail
# Carousel options
if input_data.carousel_options:
carousel_list = []
for option in input_data.carousel_options:
carousel_dict = {}
if option.title:
carousel_dict["title"] = option.title
if option.link:
carousel_dict["link"] = option.link
if option.description:
carousel_dict["description"] = option.description
if carousel_dict: # Only add if not empty
carousel_list.append(carousel_dict)
if carousel_list:
pinterest_options["carouselOptions"] = carousel_list
# Alt text
if input_data.alt_text:
pinterest_options["altText"] = input_data.alt_text
response = await client.create_post(
post=input_data.post,
platforms=[SocialPlatform.PINTEREST],
media_urls=input_data.media_urls,
is_video=input_data.is_video,
schedule_date=iso_date,
disable_comments=input_data.disable_comments,
shorten_links=input_data.shorten_links,
unsplash=input_data.unsplash,
requires_approval=input_data.requires_approval,
random_post=input_data.random_post,
random_media_url=input_data.random_media_url,
notes=input_data.notes,
pinterest_options=pinterest_options if pinterest_options else None,
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:
for p in response.postIds:
yield "post", p

View File

@@ -1,69 +0,0 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockType,
SchemaField,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
class PostToRedditBlock(Block):
"""Block for posting to Reddit."""
class Input(BaseAyrshareInput):
"""Input schema for Reddit posts."""
pass # Uses all base fields
class Output(BlockSchema):
post_result: PostResponse = SchemaField(description="The result of the post")
post: PostIds = SchemaField(description="The post IDs returned for the platform")
def __init__(self):
super().__init__(
disabled=True,
id="c7733580-3c72-483e-8e47-a8d58754d853",
description="Post to Reddit using Ayrshare",
categories={BlockCategory.SOCIAL},
block_type=BlockType.AYRSHARE,
input_schema=PostToRedditBlock.Input,
output_schema=PostToRedditBlock.Output,
)
async def run(
self, input_data: "PostToRedditBlock.Input", *, user_id: str, **kwargs
) -> BlockOutput:
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured."
return
iso_date = (
input_data.schedule_date.isoformat() if input_data.schedule_date else None
)
response = await client.create_post(
post=input_data.post,
platforms=[SocialPlatform.REDDIT],
media_urls=input_data.media_urls,
is_video=input_data.is_video,
schedule_date=iso_date,
disable_comments=input_data.disable_comments,
shorten_links=input_data.shorten_links,
unsplash=input_data.unsplash,
requires_approval=input_data.requires_approval,
random_post=input_data.random_post,
random_media_url=input_data.random_media_url,
notes=input_data.notes,
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:
for p in response.postIds:
yield "post", p

View File

@@ -1,129 +0,0 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockType,
SchemaField,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
class PostToSnapchatBlock(Block):
"""Block for posting to Snapchat with Snapchat-specific options."""
class Input(BaseAyrshareInput):
"""Input schema for Snapchat posts."""
# Override post field to include Snapchat-specific information
post: str = SchemaField(
description="The post text (optional for video-only content)",
default="",
advanced=False,
)
# Override media_urls to include Snapchat-specific constraints
media_urls: list[str] = SchemaField(
description="Required video URL for Snapchat posts. Snapchat only supports video content.",
default_factory=list,
advanced=False,
)
# Snapchat-specific options
story_type: str = SchemaField(
description="Type of Snapchat content: 'story' (24-hour Stories), 'saved_story' (Saved Stories), or 'spotlight' (Spotlight posts)",
default="story",
advanced=True,
)
video_thumbnail: str = SchemaField(
description="Thumbnail URL for video content (optional, auto-generated if not provided)",
default="",
advanced=True,
)
class Output(BlockSchema):
post_result: PostResponse = SchemaField(description="The result of the post")
post: PostIds = SchemaField(description="The post IDs returned for the platform")
def __init__(self):
super().__init__(
disabled=True,
id="a9d7f854-2c83-4e96-b3a1-7f2e9c5d4b8e",
description="Post to Snapchat using Ayrshare",
categories={BlockCategory.SOCIAL},
block_type=BlockType.AYRSHARE,
input_schema=PostToSnapchatBlock.Input,
output_schema=PostToSnapchatBlock.Output,
)
async def run(
self,
input_data: "PostToSnapchatBlock.Input",
*,
user_id: str,
**kwargs,
) -> BlockOutput:
"""Post to Snapchat with Snapchat-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
return
# Validate Snapchat constraints
if not input_data.media_urls:
yield "error", "Snapchat requires at least one video URL"
return
if len(input_data.media_urls) > 1:
yield "error", "Snapchat supports only one video per post"
return
# Validate story type
valid_story_types = ["story", "saved_story", "spotlight"]
if input_data.story_type not in valid_story_types:
yield "error", f"Snapchat story type must be one of: {', '.join(valid_story_types)}"
return
# Convert datetime to ISO format if provided
iso_date = (
input_data.schedule_date.isoformat() if input_data.schedule_date else None
)
# Build Snapchat-specific options
snapchat_options = {}
# Story type
if input_data.story_type != "story":
snapchat_options["storyType"] = input_data.story_type
# Video thumbnail
if input_data.video_thumbnail:
snapchat_options["videoThumbnail"] = input_data.video_thumbnail
response = await client.create_post(
post=input_data.post,
platforms=[SocialPlatform.SNAPCHAT],
media_urls=input_data.media_urls,
is_video=True, # Snapchat only supports video
schedule_date=iso_date,
disable_comments=input_data.disable_comments,
shorten_links=input_data.shorten_links,
unsplash=input_data.unsplash,
requires_approval=input_data.requires_approval,
random_post=input_data.random_post,
random_media_url=input_data.random_media_url,
notes=input_data.notes,
snapchat_options=snapchat_options if snapchat_options else None,
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:
for p in response.postIds:
yield "post", p

View File

@@ -1,116 +0,0 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockType,
SchemaField,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
class PostToTelegramBlock(Block):
"""Block for posting to Telegram with Telegram-specific options."""
class Input(BaseAyrshareInput):
"""Input schema for Telegram posts."""
# Override post field to include Telegram-specific information
post: str = SchemaField(
description="The post text (empty string allowed). Use @handle to mention other Telegram users.",
default="",
advanced=False,
)
# Override media_urls to include Telegram-specific constraints
media_urls: list[str] = SchemaField(
description="Optional list of media URLs. For animated GIFs, only one URL is allowed. Telegram will auto-preview links unless image/video is included.",
default_factory=list,
advanced=False,
)
# Override is_video to include GIF-specific information
is_video: bool = SchemaField(
description="Whether the media is a video. Set to true for animated GIFs that don't end in .gif/.GIF extension.",
default=False,
advanced=True,
)
class Output(BlockSchema):
post_result: PostResponse = SchemaField(description="The result of the post")
post: PostIds = SchemaField(description="The result of the post")
def __init__(self):
super().__init__(
disabled=True,
id="47bc74eb-4af2-452c-b933-af377c7287df",
description="Post to Telegram using Ayrshare",
categories={BlockCategory.SOCIAL},
block_type=BlockType.AYRSHARE,
input_schema=PostToTelegramBlock.Input,
output_schema=PostToTelegramBlock.Output,
)
async def run(
self,
input_data: "PostToTelegramBlock.Input",
*,
user_id: str,
**kwargs,
) -> BlockOutput:
"""Post to Telegram with Telegram-specific validation."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
return
# Validate Telegram constraints
# Check for animated GIFs - only one URL allowed
gif_extensions = [".gif", ".GIF"]
has_gif = any(
any(url.endswith(ext) for ext in gif_extensions)
for url in input_data.media_urls
)
if has_gif and len(input_data.media_urls) > 1:
yield "error", "Telegram animated GIFs support only one URL per post"
return
# Telegram needs is_video=True for animated GIFs whose URLs don't end in .gif/.GIF;
# this can't be reliably auto-detected from the URL, so the caller must set is_video manually.
detected_is_video = input_data.is_video
# Convert datetime to ISO format if provided
iso_date = (
input_data.schedule_date.isoformat() if input_data.schedule_date else None
)
response = await client.create_post(
post=input_data.post,
platforms=[SocialPlatform.TELEGRAM],
media_urls=input_data.media_urls,
is_video=detected_is_video,
schedule_date=iso_date,
disable_comments=input_data.disable_comments,
shorten_links=input_data.shorten_links,
unsplash=input_data.unsplash,
requires_approval=input_data.requires_approval,
random_post=input_data.random_post,
random_media_url=input_data.random_media_url,
notes=input_data.notes,
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:
for p in response.postIds:
yield "post", p

View File

@@ -1,111 +0,0 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockType,
SchemaField,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
class PostToThreadsBlock(Block):
"""Block for posting to Threads with Threads-specific options."""
class Input(BaseAyrshareInput):
"""Input schema for Threads posts."""
# Override post field to include Threads-specific information
post: str = SchemaField(
description="The post text (max 500 chars, empty string allowed). Only 1 hashtag allowed. Use @handle to mention users.",
default="",
advanced=False,
)
# Override media_urls to include Threads-specific constraints
media_urls: list[str] = SchemaField(
description="Optional list of media URLs. Supports up to 20 images/videos in a carousel. Auto-preview links unless media is included.",
default_factory=list,
advanced=False,
)
class Output(BlockSchema):
post_result: PostResponse = SchemaField(description="The result of the post")
post: PostIds = SchemaField(description="The result of the post")
def __init__(self):
super().__init__(
disabled=True,
id="f8c3b2e1-9d4a-4e5f-8c7b-6a9e8d2f1c3b",
description="Post to Threads using Ayrshare",
categories={BlockCategory.SOCIAL},
block_type=BlockType.AYRSHARE,
input_schema=PostToThreadsBlock.Input,
output_schema=PostToThreadsBlock.Output,
)
async def run(
self,
input_data: "PostToThreadsBlock.Input",
*,
user_id: str,
**kwargs,
) -> BlockOutput:
"""Post to Threads with Threads-specific validation."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
return
# Validate Threads constraints
if len(input_data.post) > 500:
yield "error", f"Threads post text exceeds 500 character limit ({len(input_data.post)} characters)"
return
if len(input_data.media_urls) > 20:
yield "error", "Threads supports a maximum of 20 images/videos in a carousel"
return
# Count hashtags (only 1 allowed)
hashtag_count = input_data.post.count("#")
if hashtag_count > 1:
yield "error", f"Threads allows only 1 hashtag per post ({hashtag_count} found)"
return
# Convert datetime to ISO format if provided
iso_date = (
input_data.schedule_date.isoformat() if input_data.schedule_date else None
)
# Build Threads-specific options
threads_options = {}
# Note: Based on the documentation, Threads doesn't seem to have specific options
# beyond the standard ones. The main constraints are validation-based.
response = await client.create_post(
post=input_data.post,
platforms=[SocialPlatform.THREADS],
media_urls=input_data.media_urls,
is_video=input_data.is_video,
schedule_date=iso_date,
disable_comments=input_data.disable_comments,
shorten_links=input_data.shorten_links,
unsplash=input_data.unsplash,
requires_approval=input_data.requires_approval,
random_post=input_data.random_post,
random_media_url=input_data.random_media_url,
notes=input_data.notes,
threads_options=threads_options if threads_options else None,
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:
for p in response.postIds:
yield "post", p

View File

@@ -1,243 +0,0 @@
from enum import Enum
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockType,
SchemaField,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
class TikTokVisibility(str, Enum):
PUBLIC = "public"
PRIVATE = "private"
FOLLOWERS = "followers"
class PostToTikTokBlock(Block):
"""Block for posting to TikTok with TikTok-specific options."""
class Input(BaseAyrshareInput):
"""Input schema for TikTok posts."""
# Override post field to include TikTok-specific information
post: str = SchemaField(
description="The post text (max 2,200 chars, empty string allowed). Use @handle to mention users. Line breaks will be ignored.",
advanced=False,
)
# Override media_urls to include TikTok-specific constraints
media_urls: list[str] = SchemaField(
description="Required media URLs. Either 1 video OR up to 35 images (JPG/JPEG/WEBP only). Cannot mix video and images.",
default_factory=list,
advanced=False,
)
# TikTok-specific options
auto_add_music: bool = SchemaField(
description="Whether to automatically add recommended music to the post. If you set this field to true, you can change the music later in the TikTok app.",
default=False,
advanced=True,
)
disable_comments: bool = SchemaField(
description="Disable comments on the published post",
default=False,
advanced=True,
)
disable_duet: bool = SchemaField(
description="Disable duets on published video (video only)",
default=False,
advanced=True,
)
disable_stitch: bool = SchemaField(
description="Disable stitch on published video (video only)",
default=False,
advanced=True,
)
is_ai_generated: bool = SchemaField(
description="If you enable the toggle, your video will be labeled as “Creator labeled as AI-generated” once posted and cant be changed. The “Creator labeled as AI-generated” label indicates that the content was completely AI-generated or significantly edited with AI.",
default=False,
advanced=True,
)
is_branded_content: bool = SchemaField(
description="Whether to enable the Branded Content toggle. If this field is set to true, the video will be labeled as Branded Content, indicating you are in a paid partnership with a brand. A “Paid partnership” label will be attached to the video.",
default=False,
advanced=True,
)
is_brand_organic: bool = SchemaField(
description="Whether to enable the Brand Organic Content toggle. If this field is set to true, the video will be labeled as Brand Organic Content, indicating you are promoting yourself or your own business. A “Promotional content” label will be attached to the video.",
default=False,
advanced=True,
)
image_cover_index: int = SchemaField(
description="Index of image to use as cover (0-based, image posts only)",
default=0,
advanced=True,
)
title: str = SchemaField(
description="Title for image posts", default="", advanced=True
)
thumbnail_offset: int = SchemaField(
description="Video thumbnail frame offset in milliseconds (video only)",
default=0,
advanced=True,
)
visibility: TikTokVisibility = SchemaField(
description="Post visibility: 'public', 'private', 'followers', or 'friends'",
default=TikTokVisibility.PUBLIC,
advanced=True,
)
draft: bool = SchemaField(
description="Create as draft post (video only)",
default=False,
advanced=True,
)
class Output(BlockSchema):
post_result: PostResponse = SchemaField(description="The result of the post")
post: PostIds = SchemaField(description="The result of the post")
def __init__(self):
super().__init__(
id="7faf4b27-96b0-4f05-bf64-e0de54ae74e1",
description="Post to TikTok using Ayrshare",
categories={BlockCategory.SOCIAL},
block_type=BlockType.AYRSHARE,
input_schema=PostToTikTokBlock.Input,
output_schema=PostToTikTokBlock.Output,
)
async def run(
self, input_data: "PostToTikTokBlock.Input", *, user_id: str, **kwargs
) -> BlockOutput:
"""Post to TikTok with TikTok-specific validation and options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
return
# Validate TikTok constraints
if len(input_data.post) > 2200:
yield "error", f"TikTok post text exceeds 2,200 character limit ({len(input_data.post)} characters)"
return
if not input_data.media_urls:
yield "error", "TikTok requires at least one media URL (either 1 video or up to 35 images)"
return
# Check for video vs image constraints
video_extensions = [".mp4", ".mov", ".avi", ".mkv", ".wmv", ".flv", ".webm"]
image_extensions = [".jpg", ".jpeg", ".webp"]
has_video = input_data.is_video or any(
any(url.lower().endswith(ext) for ext in video_extensions)
for url in input_data.media_urls
)
has_images = any(
any(url.lower().endswith(ext) for ext in image_extensions)
for url in input_data.media_urls
)
if has_video and has_images:
yield "error", "TikTok does not support mixing video and images in the same post"
return
if has_video and len(input_data.media_urls) > 1:
yield "error", "TikTok supports only 1 video per post"
return
if has_images and len(input_data.media_urls) > 35:
yield "error", "TikTok supports a maximum of 35 images per post"
return
# Validate image cover index
if has_images and input_data.image_cover_index >= len(input_data.media_urls):
yield "error", f"Image cover index {input_data.image_cover_index} is out of range (max: {len(input_data.media_urls) - 1})"
return
# Check for PNG files (not supported)
has_png = any(url.lower().endswith(".png") for url in input_data.media_urls)
if has_png:
yield "error", "TikTok does not support PNG files. Please use JPG, JPEG, or WEBP for images."
return
# Convert datetime to ISO format if provided
iso_date = (
input_data.schedule_date.isoformat() if input_data.schedule_date else None
)
# Build TikTok-specific options
tiktok_options = {}
# Common options
if input_data.auto_add_music and has_images:
tiktok_options["autoAddMusic"] = True
if input_data.disable_comments:
tiktok_options["disableComments"] = True
if input_data.is_branded_content:
tiktok_options["isBrandedContent"] = True
if input_data.is_brand_organic:
tiktok_options["isBrandOrganic"] = True
# Video-specific options
if has_video:
if input_data.disable_duet:
tiktok_options["disableDuet"] = True
if input_data.disable_stitch:
tiktok_options["disableStitch"] = True
if input_data.is_ai_generated:
tiktok_options["isAIGenerated"] = True
if input_data.thumbnail_offset > 0:
tiktok_options["thumbNailOffset"] = input_data.thumbnail_offset
if input_data.draft:
tiktok_options["draft"] = True
# Image-specific options
if has_images:
if input_data.image_cover_index > 0:
tiktok_options["imageCoverIndex"] = input_data.image_cover_index
if input_data.title:
tiktok_options["title"] = input_data.title
if input_data.visibility != TikTokVisibility.PUBLIC:
tiktok_options["visibility"] = input_data.visibility.value
response = await client.create_post(
post=input_data.post,
platforms=[SocialPlatform.TIKTOK],
media_urls=input_data.media_urls,
is_video=has_video,
schedule_date=iso_date,
disable_comments=input_data.disable_comments,
shorten_links=input_data.shorten_links,
unsplash=input_data.unsplash,
requires_approval=input_data.requires_approval,
random_post=input_data.random_post,
random_media_url=input_data.random_media_url,
notes=input_data.notes,
tiktok_options=tiktok_options if tiktok_options else None,
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:
for p in response.postIds:
yield "post", p

View File

@@ -1,241 +0,0 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockType,
SchemaField,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
class PostToXBlock(Block):
"""Block for posting to X / Twitter with Twitter-specific options."""
class Input(BaseAyrshareInput):
"""Input schema for X / Twitter posts."""
# Override post field to include X-specific information
post: str = SchemaField(
description="The post text (max 280 chars, up to 25,000 for Premium users). Use @handle to mention users. Use \\n\\n for thread breaks.",
advanced=False,
)
# Override media_urls to include X-specific constraints
media_urls: list[str] = SchemaField(
description="Optional list of media URLs. X supports up to 4 images or videos per tweet. Auto-preview links unless media is included.",
default_factory=list,
advanced=False,
)
# X-specific options
reply_to_id: str | None = SchemaField(
description="ID of the tweet to reply to",
default=None,
advanced=True,
)
quote_tweet_id: str | None = SchemaField(
description="ID of the tweet to quote (low-level Tweet ID)",
default=None,
advanced=True,
)
poll_options: list[str] = SchemaField(
description="Poll options (2-4 choices)",
default_factory=list,
advanced=True,
)
poll_duration: int = SchemaField(
description="Poll duration in minutes (1-10080)",
default=1440,
advanced=True,
)
alt_text: list[str] = SchemaField(
description="Alt text for each image (max 1,000 chars each, not supported for videos)",
default_factory=list,
advanced=True,
)
is_thread: bool = SchemaField(
description="Whether to automatically break post into thread based on line breaks",
default=False,
advanced=True,
)
thread_number: bool = SchemaField(
description="Add thread numbers (1/n format) to each thread post",
default=False,
advanced=True,
)
thread_media_urls: list[str] = SchemaField(
description="Media URLs for thread posts (one per thread, use 'null' to skip)",
default_factory=list,
advanced=True,
)
long_post: bool = SchemaField(
description="Force long form post (requires Premium X account)",
default=False,
advanced=True,
)
long_video: bool = SchemaField(
description="Enable long video upload (requires approval and Business/Enterprise plan)",
default=False,
advanced=True,
)
subtitle_url: str = SchemaField(
description="URL to SRT subtitle file for videos (must be HTTPS and end in .srt)",
default="",
advanced=True,
)
subtitle_language: str = SchemaField(
description="Language code for subtitles (default: 'en')",
default="en",
advanced=True,
)
subtitle_name: str = SchemaField(
description="Name of caption track (max 150 chars, default: 'English')",
default="English",
advanced=True,
)
class Output(BlockSchema):
post_result: PostResponse = SchemaField(description="The result of the post")
post: PostIds = SchemaField(description="The result of the post")
def __init__(self):
super().__init__(
id="9e8f844e-b4a5-4b25-80f2-9e1dd7d67625",
description="Post to X / Twitter using Ayrshare",
categories={BlockCategory.SOCIAL},
block_type=BlockType.AYRSHARE,
input_schema=PostToXBlock.Input,
output_schema=PostToXBlock.Output,
)
async def run(
self,
input_data: "PostToXBlock.Input",
*,
user_id: str,
**kwargs,
) -> BlockOutput:
"""Post to X / Twitter with enhanced X-specific options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
return
# Validate X constraints
if not input_data.long_post and len(input_data.post) > 280:
yield "error", f"X post text exceeds 280 character limit ({len(input_data.post)} characters). Enable 'long_post' for Premium accounts."
return
if input_data.long_post and len(input_data.post) > 25000:
yield "error", f"X long post text exceeds 25,000 character limit ({len(input_data.post)} characters)"
return
if len(input_data.media_urls) > 4:
yield "error", "X supports a maximum of 4 images or videos per tweet"
return
# Validate poll options
if input_data.poll_options:
if len(input_data.poll_options) < 2 or len(input_data.poll_options) > 4:
yield "error", "X polls require 2-4 options"
return
if input_data.poll_duration < 1 or input_data.poll_duration > 10080:
yield "error", "X poll duration must be between 1 and 10,080 minutes (7 days)"
return
# Validate alt text
if input_data.alt_text:
for i, alt in enumerate(input_data.alt_text):
if len(alt) > 1000:
yield "error", f"X alt text {i+1} exceeds 1,000 character limit ({len(alt)} characters)"
return
# Validate subtitle settings
if input_data.subtitle_url:
if not input_data.subtitle_url.startswith(
"https://"
) or not input_data.subtitle_url.endswith(".srt"):
yield "error", "Subtitle URL must start with https:// and end with .srt"
return
if len(input_data.subtitle_name) > 150:
yield "error", f"Subtitle name exceeds 150 character limit ({len(input_data.subtitle_name)} characters)"
return
# Convert datetime to ISO format if provided
iso_date = (
input_data.schedule_date.isoformat() if input_data.schedule_date else None
)
# Build X-specific options
twitter_options = {}
# Basic options
if input_data.reply_to_id:
twitter_options["replyToId"] = input_data.reply_to_id
if input_data.quote_tweet_id:
twitter_options["quoteTweetId"] = input_data.quote_tweet_id
if input_data.long_post:
twitter_options["longPost"] = True
if input_data.long_video:
twitter_options["longVideo"] = True
# Poll options
if input_data.poll_options:
twitter_options["poll"] = {
"duration": input_data.poll_duration,
"options": input_data.poll_options,
}
# Alt text for images
if input_data.alt_text:
twitter_options["altText"] = input_data.alt_text
# Thread options
if input_data.is_thread:
twitter_options["thread"] = True
if input_data.thread_number:
twitter_options["threadNumber"] = True
if input_data.thread_media_urls:
twitter_options["mediaUrls"] = input_data.thread_media_urls
# Video subtitle options
if input_data.subtitle_url:
twitter_options["subTitleUrl"] = input_data.subtitle_url
twitter_options["subTitleLanguage"] = input_data.subtitle_language
twitter_options["subTitleName"] = input_data.subtitle_name
response = await client.create_post(
post=input_data.post,
platforms=[SocialPlatform.TWITTER],
media_urls=input_data.media_urls,
is_video=input_data.is_video,
schedule_date=iso_date,
disable_comments=input_data.disable_comments,
shorten_links=input_data.shorten_links,
unsplash=input_data.unsplash,
requires_approval=input_data.requires_approval,
random_post=input_data.random_post,
random_media_url=input_data.random_media_url,
notes=input_data.notes,
twitter_options=twitter_options if twitter_options else None,
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:
for p in response.postIds:
yield "post", p

View File

@@ -1,310 +0,0 @@
from enum import Enum
from typing import Any
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockType,
SchemaField,
)
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
class YouTubeVisibility(str, Enum):
PRIVATE = "private"
PUBLIC = "public"
UNLISTED = "unlisted"
class PostToYouTubeBlock(Block):
"""Block for posting to YouTube with YouTube-specific options."""
class Input(BaseAyrshareInput):
"""Input schema for YouTube posts."""
# Override post field to include YouTube-specific information
post: str = SchemaField(
description="Video description (max 5,000 chars, empty string allowed). Cannot contain < or > characters.",
advanced=False,
)
# Override media_urls to include YouTube-specific constraints
media_urls: list[str] = SchemaField(
description="Required video URL. YouTube only supports 1 video per post.",
default_factory=list,
advanced=False,
)
# YouTube-specific required options
title: str = SchemaField(
description="Video title (max 100 chars, required). Cannot contain < or > characters.",
advanced=False,
)
# YouTube-specific optional options
visibility: YouTubeVisibility = SchemaField(
description="Video visibility: 'private' (default), 'public' , or 'unlisted'",
default=YouTubeVisibility.PRIVATE,
advanced=False,
)
thumbnail: str | None = SchemaField(
description="Thumbnail URL (JPEG/PNG under 2MB, must end in .png/.jpg/.jpeg). Requires phone verification.",
default=None,
advanced=True,
)
playlist_id: str | None = SchemaField(
description="Playlist ID to add video (user must own playlist)",
default=None,
advanced=True,
)
tags: list[str] | None = SchemaField(
description="Video tags (min 2 chars each, max 500 chars total)",
default=None,
advanced=True,
)
made_for_kids: bool | None = SchemaField(
description="Self-declared kids content", default=None, advanced=True
)
is_shorts: bool | None = SchemaField(
description="Post as YouTube Short (max 3 minutes, adds #shorts)",
default=None,
advanced=True,
)
notify_subscribers: bool | None = SchemaField(
description="Send notification to subscribers", default=None, advanced=True
)
category_id: int | None = SchemaField(
description="Video category ID (e.g., 24 = Entertainment)",
default=None,
advanced=True,
)
contains_synthetic_media: bool | None = SchemaField(
description="Disclose realistic AI/synthetic content",
default=None,
advanced=True,
)
publish_at: str | None = SchemaField(
description="UTC publish time (YouTube controlled, format: 2022-10-08T21:18:36Z)",
default=None,
advanced=True,
)
# YouTube targeting options (flattened from YouTubeTargeting object)
targeting_block_countries: list[str] | None = SchemaField(
description="Country codes to block from viewing (e.g., ['US', 'CA'])",
default=None,
advanced=True,
)
targeting_allow_countries: list[str] | None = SchemaField(
description="Country codes to allow viewing (e.g., ['GB', 'AU'])",
default=None,
advanced=True,
)
subtitle_url: str | None = SchemaField(
description="URL to SRT or SBV subtitle file (must be HTTPS and end in .srt/.sbv, under 100MB)",
default=None,
advanced=True,
)
subtitle_language: str | None = SchemaField(
description="Language code for subtitles (default: 'en')",
default=None,
advanced=True,
)
subtitle_name: str | None = SchemaField(
description="Name of caption track (max 150 chars, default: 'English')",
default=None,
advanced=True,
)
class Output(BlockSchema):
post_result: PostResponse = SchemaField(description="The result of the post")
post: PostIds = SchemaField(description="The result of the post")
def __init__(self):
super().__init__(
id="0082d712-ff1b-4c3d-8a8d-6c7721883b83",
description="Post to YouTube using Ayrshare",
categories={BlockCategory.SOCIAL},
block_type=BlockType.AYRSHARE,
input_schema=PostToYouTubeBlock.Input,
output_schema=PostToYouTubeBlock.Output,
)
async def run(
self,
input_data: "PostToYouTubeBlock.Input",
*,
user_id: str,
**kwargs,
) -> BlockOutput:
"""Post to YouTube with YouTube-specific validation and options."""
profile_key = await get_profile_key(user_id)
if not profile_key:
yield "error", "Please link a social account via Ayrshare"
return
client = create_ayrshare_client()
if not client:
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
return
# Validate YouTube constraints
if not input_data.title:
yield "error", "YouTube requires a video title"
return
if len(input_data.title) > 100:
yield "error", f"YouTube title exceeds 100 character limit ({len(input_data.title)} characters)"
return
if len(input_data.post) > 5000:
yield "error", f"YouTube description exceeds 5,000 character limit ({len(input_data.post)} characters)"
return
# Check for forbidden characters
forbidden_chars = ["<", ">"]
for char in forbidden_chars:
if char in input_data.title:
yield "error", f"YouTube title cannot contain '{char}' character"
return
if char in input_data.post:
yield "error", f"YouTube description cannot contain '{char}' character"
return
if not input_data.media_urls:
yield "error", "YouTube requires exactly one video URL"
return
if len(input_data.media_urls) > 1:
yield "error", "YouTube supports only 1 video per post"
return
# Validate visibility option
valid_visibility = ["private", "public", "unlisted"]
if input_data.visibility not in valid_visibility:
yield "error", f"YouTube visibility must be one of: {', '.join(valid_visibility)}"
return
# Validate thumbnail URL format
if input_data.thumbnail:
valid_extensions = [".png", ".jpg", ".jpeg"]
if not any(
input_data.thumbnail.lower().endswith(ext) for ext in valid_extensions
):
yield "error", "YouTube thumbnail must end in .png, .jpg, or .jpeg"
return
# Validate tags
if input_data.tags:
total_tag_length = sum(len(tag) for tag in input_data.tags)
if total_tag_length > 500:
yield "error", f"YouTube tags total length exceeds 500 characters ({total_tag_length} characters)"
return
for tag in input_data.tags:
if len(tag) < 2:
yield "error", f"YouTube tag '{tag}' is too short (minimum 2 characters)"
return
# Validate subtitle URL
if input_data.subtitle_url:
if not input_data.subtitle_url.startswith("https://"):
yield "error", "YouTube subtitle URL must start with https://"
return
valid_subtitle_extensions = [".srt", ".sbv"]
if not any(
input_data.subtitle_url.lower().endswith(ext)
for ext in valid_subtitle_extensions
):
yield "error", "YouTube subtitle URL must end in .srt or .sbv"
return
if input_data.subtitle_name and len(input_data.subtitle_name) > 150:
yield "error", f"YouTube subtitle name exceeds 150 character limit ({len(input_data.subtitle_name)} characters)"
return
# Validate publish_at format if provided
if input_data.publish_at and input_data.schedule_date:
yield "error", "Cannot use both 'publish_at' and 'schedule_date'. Use 'publish_at' for YouTube-controlled publishing."
return
# Convert datetime to ISO format if provided (only if not using publish_at)
iso_date = None
if not input_data.publish_at and input_data.schedule_date:
iso_date = input_data.schedule_date.isoformat()
# Build YouTube-specific options
youtube_options: dict[str, Any] = {"title": input_data.title}
# Basic options
if input_data.visibility != "private":
youtube_options["visibility"] = input_data.visibility
if input_data.thumbnail:
youtube_options["thumbNail"] = input_data.thumbnail
if input_data.playlist_id:
youtube_options["playListId"] = input_data.playlist_id
if input_data.tags:
youtube_options["tags"] = input_data.tags
if input_data.made_for_kids:
youtube_options["madeForKids"] = True
if input_data.is_shorts:
youtube_options["shorts"] = True
if not input_data.notify_subscribers:
youtube_options["notifySubscribers"] = False
if input_data.category_id and input_data.category_id > 0:
youtube_options["categoryId"] = input_data.category_id
if input_data.contains_synthetic_media:
youtube_options["containsSyntheticMedia"] = True
if input_data.publish_at:
youtube_options["publishAt"] = input_data.publish_at
# Country targeting (from flattened fields)
targeting_dict = {}
if input_data.targeting_block_countries:
targeting_dict["block"] = input_data.targeting_block_countries
if input_data.targeting_allow_countries:
targeting_dict["allow"] = input_data.targeting_allow_countries
if targeting_dict:
youtube_options["targeting"] = targeting_dict
# Subtitle options
if input_data.subtitle_url:
youtube_options["subTitleUrl"] = input_data.subtitle_url
youtube_options["subTitleLanguage"] = input_data.subtitle_language
youtube_options["subTitleName"] = input_data.subtitle_name
response = await client.create_post(
post=input_data.post,
platforms=[SocialPlatform.YOUTUBE],
media_urls=input_data.media_urls,
is_video=True, # YouTube only supports videos
schedule_date=iso_date,
disable_comments=input_data.disable_comments,
shorten_links=input_data.shorten_links,
unsplash=input_data.unsplash,
requires_approval=input_data.requires_approval,
random_post=input_data.random_post,
random_media_url=input_data.random_media_url,
notes=input_data.notes,
youtube_options=youtube_options,
profile_key=profile_key.get_secret_value(),
)
yield "post_result", response
if response.postIds:
for p in response.postIds:
yield "post", p

View File

@@ -39,13 +39,11 @@ class FileStoreBlock(Block):
input_data: Input,
*,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
yield "file_out", await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.file_in,
user_id=user_id,
return_content=input_data.base_64,
)
@@ -188,31 +186,3 @@ class UniversalTypeConverterBlock(Block):
yield "value", converted_value
except Exception as e:
yield "error", f"Failed to convert value: {str(e)}"
class ReverseListOrderBlock(Block):
"""
A block which takes in a list and returns it in the opposite order.
"""
class Input(BlockSchema):
input_list: list[Any] = SchemaField(description="The list to reverse")
class Output(BlockSchema):
reversed_list: list[Any] = SchemaField(description="The list in reversed order")
def __init__(self):
super().__init__(
id="422cb708-3109-4277-bfe3-bc2ae5812777",
description="Reverses the order of elements in a list",
categories={BlockCategory.BASIC},
input_schema=ReverseListOrderBlock.Input,
output_schema=ReverseListOrderBlock.Output,
test_input={"input_list": [1, 2, 3, 4, 5]},
test_output=[("reversed_list", [5, 4, 3, 2, 1])],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
reversed_list = list(input_data.input_list)
reversed_list.reverse()
yield "reversed_list", reversed_list

View File

@@ -0,0 +1,109 @@
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import ContributorDetails, SchemaField
class ReadCsvBlock(Block):
class Input(BlockSchema):
contents: str = SchemaField(
description="The contents of the CSV file to read",
placeholder="a, b, c\n1,2,3\n4,5,6",
)
delimiter: str = SchemaField(
description="The delimiter used in the CSV file",
default=",",
)
quotechar: str = SchemaField(
description="The character used to quote fields",
default='"',
)
escapechar: str = SchemaField(
description="The character used to escape the delimiter",
default="\\",
)
has_header: bool = SchemaField(
description="Whether the CSV file has a header row",
default=True,
)
skip_rows: int = SchemaField(
description="The number of rows to skip from the start of the file",
default=0,
)
strip: bool = SchemaField(
description="Whether to strip whitespace from the values",
default=True,
)
skip_columns: list[str] = SchemaField(
description="The names of columns to skip (use the stringified index, e.g. '0', when there is no header)",
default_factory=list,
)
class Output(BlockSchema):
row: dict[str, str] = SchemaField(
description="The data produced from each row in the CSV file"
)
all_data: list[dict[str, str]] = SchemaField(
description="All the data in the CSV file as a list of rows"
)
def __init__(self):
super().__init__(
id="acf7625e-d2cb-4941-bfeb-2819fc6fc015",
input_schema=ReadCsvBlock.Input,
output_schema=ReadCsvBlock.Output,
description="Reads a CSV file and outputs the data as a list of dictionaries and individual rows via rows.",
contributors=[ContributorDetails(name="Nicholas Tindle")],
categories={BlockCategory.TEXT, BlockCategory.DATA},
test_input={
"contents": "a, b, c\n1,2,3\n4,5,6",
},
test_output=[
("row", {"a": "1", "b": "2", "c": "3"}),
("row", {"a": "4", "b": "5", "c": "6"}),
(
"all_data",
[
{"a": "1", "b": "2", "c": "3"},
{"a": "4", "b": "5", "c": "6"},
],
),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
import csv
from io import StringIO
csv_file = StringIO(input_data.contents)
reader = csv.reader(
csv_file,
delimiter=input_data.delimiter,
quotechar=input_data.quotechar,
escapechar=input_data.escapechar,
)
header = None
if input_data.has_header:
header = next(reader)
if input_data.strip:
header = [h.strip() for h in header]
for _ in range(input_data.skip_rows):
next(reader)
def process_row(row):
data = {}
for i, value in enumerate(row):
# Resolve the column name: the header value when present, otherwise the stringified index.
column_name = header[i] if (input_data.has_header and header) else str(i)
# skip_columns holds column names, so compare against the resolved name rather than the integer index.
if column_name in input_data.skip_columns:
continue
data[column_name] = value.strip() if input_data.strip else value
return data
all_data = []
for row in reader:
processed_row = process_row(row)
all_data.append(processed_row)
yield "row", processed_row
yield "all_data", all_data

File diff suppressed because it is too large

View File

@@ -1,408 +0,0 @@
"""
API module for Enrichlayer integration.
This module provides a client for interacting with the Enrichlayer API,
which allows fetching LinkedIn profile data and related information.
"""
import datetime
import enum
import logging
from json import JSONDecodeError
from typing import Any, Optional, TypeVar
from pydantic import BaseModel, Field
from backend.data.model import APIKeyCredentials
from backend.util.request import Requests
logger = logging.getLogger(__name__)
T = TypeVar("T")
class EnrichlayerAPIException(Exception):
"""Exception raised for Enrichlayer API errors."""
def __init__(self, message: str, status_code: int):
super().__init__(message)
self.status_code = status_code
class FallbackToCache(enum.Enum):
ON_ERROR = "on-error"
NEVER = "never"
class UseCache(enum.Enum):
IF_PRESENT = "if-present"
NEVER = "never"
class SocialMediaProfiles(BaseModel):
"""Social media profiles model."""
twitter: Optional[str] = None
facebook: Optional[str] = None
github: Optional[str] = None
class Experience(BaseModel):
"""Experience model for LinkedIn profiles."""
company: Optional[str] = None
title: Optional[str] = None
description: Optional[str] = None
location: Optional[str] = None
starts_at: Optional[dict[str, int]] = None
ends_at: Optional[dict[str, int]] = None
company_linkedin_profile_url: Optional[str] = None
class Education(BaseModel):
"""Education model for LinkedIn profiles."""
school: Optional[str] = None
degree_name: Optional[str] = None
field_of_study: Optional[str] = None
starts_at: Optional[dict[str, int]] = None
ends_at: Optional[dict[str, int]] = None
school_linkedin_profile_url: Optional[str] = None
class PersonProfileResponse(BaseModel):
"""Response model for LinkedIn person profile.
This model represents the response from Enrichlayer's LinkedIn profile API.
The API returns comprehensive profile data including work experience,
education, skills, and contact information (when available).
Example API Response:
{
"public_identifier": "johnsmith",
"full_name": "John Smith",
"occupation": "Software Engineer at Tech Corp",
"experiences": [
{
"company": "Tech Corp",
"title": "Software Engineer",
"starts_at": {"year": 2020, "month": 1}
}
],
"education": [...],
"skills": ["Python", "JavaScript", ...]
}
"""
public_identifier: Optional[str] = None
profile_pic_url: Optional[str] = None
full_name: Optional[str] = None
first_name: Optional[str] = None
last_name: Optional[str] = None
occupation: Optional[str] = None
headline: Optional[str] = None
summary: Optional[str] = None
country: Optional[str] = None
country_full_name: Optional[str] = None
city: Optional[str] = None
state: Optional[str] = None
experiences: Optional[list[Experience]] = None
education: Optional[list[Education]] = None
languages: Optional[list[str]] = None
skills: Optional[list[str]] = None
inferred_salary: Optional[dict[str, Any]] = None
personal_email: Optional[str] = None
personal_contact_number: Optional[str] = None
social_media_profiles: Optional[SocialMediaProfiles] = None
extra: Optional[dict[str, Any]] = None
class SimilarProfile(BaseModel):
"""Similar profile model for LinkedIn person lookup."""
similarity: float
linkedin_profile_url: str
class PersonLookupResponse(BaseModel):
"""Response model for LinkedIn person lookup.
This model represents the response from Enrichlayer's person lookup API.
The API returns a LinkedIn profile URL and similarity scores when
searching for a person by name and company.
Example API Response:
{
"url": "https://www.linkedin.com/in/johnsmith/",
"name_similarity_score": 0.95,
"company_similarity_score": 0.88,
"title_similarity_score": 0.75,
"location_similarity_score": 0.60
}
"""
url: str | None = None
name_similarity_score: float | None
company_similarity_score: float | None
title_similarity_score: float | None
location_similarity_score: float | None
last_updated: datetime.datetime | None = None
profile: PersonProfileResponse | None = None
class RoleLookupResponse(BaseModel):
"""Response model for LinkedIn role lookup.
This model represents the response from Enrichlayer's role lookup API.
The API returns LinkedIn profile data for a specific role at a company.
Example API Response:
{
"linkedin_profile_url": "https://www.linkedin.com/in/johnsmith/",
"profile_data": {...} // Full PersonProfileResponse data when enrich_profile=True
}
"""
linkedin_profile_url: Optional[str] = None
profile_data: Optional[PersonProfileResponse] = None
class ProfilePictureResponse(BaseModel):
"""Response model for LinkedIn profile picture.
This model represents the response from Enrichlayer's profile picture API.
The API returns a URL to the person's LinkedIn profile picture.
Example API Response:
{
"tmp_profile_pic_url": "https://media.licdn.com/dms/image/..."
}
"""
tmp_profile_pic_url: str = Field(
..., description="URL of the profile picture", alias="tmp_profile_pic_url"
)
@property
def profile_picture_url(self) -> str:
"""Backward compatibility property for profile_picture_url."""
return self.tmp_profile_pic_url
class EnrichlayerClient:
"""Client for interacting with the Enrichlayer API."""
API_BASE_URL = "https://enrichlayer.com/api/v2"
def __init__(
self,
credentials: Optional[APIKeyCredentials] = None,
custom_requests: Optional[Requests] = None,
):
"""
Initialize the Enrichlayer client.
Args:
credentials: The credentials to use for authentication.
custom_requests: Custom Requests instance for testing.
"""
if custom_requests:
self._requests = custom_requests
else:
headers: dict[str, str] = {
"Content-Type": "application/json",
}
if credentials:
headers["Authorization"] = (
f"Bearer {credentials.api_key.get_secret_value()}"
)
self._requests = Requests(
extra_headers=headers,
raise_for_status=False,
)
async def _handle_response(self, response) -> Any:
"""
Handle API response and check for errors.
Args:
response: The response object from the request.
Returns:
The response data.
Raises:
EnrichlayerAPIException: If the API request fails.
"""
if not response.ok:
try:
error_data = response.json()
error_message = error_data.get("message", "")
except JSONDecodeError:
error_message = response.text
raise EnrichlayerAPIException(
f"Enrichlayer API request failed ({response.status_code}): {error_message}",
response.status_code,
)
return response.json()
async def fetch_profile(
self,
linkedin_url: str,
fallback_to_cache: FallbackToCache = FallbackToCache.ON_ERROR,
use_cache: UseCache = UseCache.IF_PRESENT,
include_skills: bool = False,
include_inferred_salary: bool = False,
include_personal_email: bool = False,
include_personal_contact_number: bool = False,
include_social_media: bool = False,
include_extra: bool = False,
) -> PersonProfileResponse:
"""
Fetch a LinkedIn profile with optional parameters.
Args:
linkedin_url: The LinkedIn profile URL to fetch.
fallback_to_cache: Cache usage if live fetch fails ('on-error' or 'never').
use_cache: Cache utilization ('if-present' or 'never').
include_skills: Whether to include skills data.
include_inferred_salary: Whether to include inferred salary data.
include_personal_email: Whether to include personal email.
include_personal_contact_number: Whether to include personal contact number.
include_social_media: Whether to include social media profiles.
include_extra: Whether to include additional data.
Returns:
The LinkedIn profile data.
Raises:
EnrichlayerAPIException: If the API request fails.
"""
params = {
"url": linkedin_url,
"fallback_to_cache": fallback_to_cache.value.lower(),
"use_cache": use_cache.value.lower(),
}
if include_skills:
params["skills"] = "include"
if include_inferred_salary:
params["inferred_salary"] = "include"
if include_personal_email:
params["personal_email"] = "include"
if include_personal_contact_number:
params["personal_contact_number"] = "include"
if include_social_media:
params["twitter_profile_id"] = "include"
params["facebook_profile_id"] = "include"
params["github_profile_id"] = "include"
if include_extra:
params["extra"] = "include"
response = await self._requests.get(
f"{self.API_BASE_URL}/profile", params=params
)
return PersonProfileResponse(**await self._handle_response(response))
async def lookup_person(
self,
first_name: str,
company_domain: str,
last_name: str | None = None,
location: Optional[str] = None,
title: Optional[str] = None,
include_similarity_checks: bool = False,
enrich_profile: bool = False,
) -> PersonLookupResponse:
"""
Look up a LinkedIn profile by person's information.
Args:
first_name: The person's first name.
last_name: The person's last name.
company_domain: The domain of the company they work for.
location: The person's location.
title: The person's job title.
include_similarity_checks: Whether to include similarity checks.
enrich_profile: Whether to enrich the profile.
Returns:
The LinkedIn profile lookup result.
Raises:
EnrichlayerAPIException: If the API request fails.
"""
params = {"first_name": first_name, "company_domain": company_domain}
if last_name:
params["last_name"] = last_name
if location:
params["location"] = location
if title:
params["title"] = title
if include_similarity_checks:
params["similarity_checks"] = "include"
if enrich_profile:
params["enrich_profile"] = "enrich"
response = await self._requests.get(
f"{self.API_BASE_URL}/profile/resolve", params=params
)
return PersonLookupResponse(**await self._handle_response(response))
async def lookup_role(
self, role: str, company_name: str, enrich_profile: bool = False
) -> RoleLookupResponse:
"""
Look up a LinkedIn profile by role in a company.
Args:
role: The role title (e.g., CEO, CTO).
company_name: The name of the company.
enrich_profile: Whether to enrich the profile.
Returns:
The LinkedIn profile lookup result.
Raises:
EnrichlayerAPIException: If the API request fails.
"""
params = {
"role": role,
"company_name": company_name,
}
if enrich_profile:
params["enrich_profile"] = "enrich"
response = await self._requests.get(
f"{self.API_BASE_URL}/find/company/role", params=params
)
return RoleLookupResponse(**await self._handle_response(response))
async def get_profile_picture(
self, linkedin_profile_url: str
) -> ProfilePictureResponse:
"""
Get a LinkedIn profile picture URL.
Args:
linkedin_profile_url: The LinkedIn profile URL.
Returns:
The profile picture URL.
Raises:
EnrichlayerAPIException: If the API request fails.
"""
params = {
"linkedin_person_profile_url": linkedin_profile_url,
}
response = await self._requests.get(
f"{self.API_BASE_URL}/person/profile-picture", params=params
)
return ProfilePictureResponse(**await self._handle_response(response))
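A minimal usage sketch of the client, assuming a valid Enrichlayer API key; the placeholder credential values are invented, while the constructor, fetch_profile parameters, and response fields come from the code above.

import asyncio

from pydantic import SecretStr

from backend.data.model import APIKeyCredentials

async def main():
    credentials = APIKeyCredentials(
        id="00000000-0000-0000-0000-000000000000",  # placeholder
        provider="enrichlayer",
        api_key=SecretStr("your-enrichlayer-api-key"),  # placeholder
        title="Enrichlayer API key",
        expires_at=None,
    )
    client = EnrichlayerClient(credentials)
    profile = await client.fetch_profile(
        linkedin_url="https://www.linkedin.com/in/williamhgates/",
        include_skills=True,
    )
    print(profile.full_name, profile.occupation)

asyncio.run(main())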

View File

@@ -1,34 +0,0 @@
"""
Authentication module for Enrichlayer API integration.
This module provides credential types and test credentials for the Enrichlayer API.
"""
from typing import Literal
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, CredentialsMetaInput
from backend.integrations.providers import ProviderName
# Define the type of credentials input expected for Enrichlayer API
EnrichlayerCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.ENRICHLAYER], Literal["api_key"]
]
# Mock credentials for testing Enrichlayer API integration
TEST_CREDENTIALS = APIKeyCredentials(
id="1234a567-89bc-4def-ab12-3456cdef7890",
provider="enrichlayer",
api_key=SecretStr("mock-enrichlayer-api-key"),
title="Mock Enrichlayer API key",
expires_at=None,
)
# Dictionary representation of test credentials for input fields
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}

View File

@@ -1,527 +0,0 @@
"""
Block definitions for Enrichlayer API integration.
This module implements blocks for interacting with the Enrichlayer API,
which provides access to LinkedIn profile data and related information.
"""
import logging
from typing import Optional
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
from backend.util.type import MediaFileType
from ._api import (
EnrichlayerClient,
Experience,
FallbackToCache,
PersonLookupResponse,
PersonProfileResponse,
RoleLookupResponse,
UseCache,
)
from ._auth import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, EnrichlayerCredentialsInput
logger = logging.getLogger(__name__)
class GetLinkedinProfileBlock(Block):
"""Block to fetch LinkedIn profile data using Enrichlayer API."""
class Input(BlockSchema):
"""Input schema for GetLinkedinProfileBlock."""
linkedin_url: str = SchemaField(
description="LinkedIn profile URL to fetch data from",
placeholder="https://www.linkedin.com/in/username/",
)
fallback_to_cache: FallbackToCache = SchemaField(
description="Cache usage if live fetch fails",
default=FallbackToCache.ON_ERROR,
advanced=True,
)
use_cache: UseCache = SchemaField(
description="Cache utilization strategy",
default=UseCache.IF_PRESENT,
advanced=True,
)
include_skills: bool = SchemaField(
description="Include skills data",
default=False,
advanced=True,
)
include_inferred_salary: bool = SchemaField(
description="Include inferred salary data",
default=False,
advanced=True,
)
include_personal_email: bool = SchemaField(
description="Include personal email",
default=False,
advanced=True,
)
include_personal_contact_number: bool = SchemaField(
description="Include personal contact number",
default=False,
advanced=True,
)
include_social_media: bool = SchemaField(
description="Include social media profiles",
default=False,
advanced=True,
)
include_extra: bool = SchemaField(
description="Include additional data",
default=False,
advanced=True,
)
credentials: EnrichlayerCredentialsInput = CredentialsField(
description="Enrichlayer API credentials"
)
class Output(BlockSchema):
"""Output schema for GetLinkedinProfileBlock."""
profile: PersonProfileResponse = SchemaField(
description="LinkedIn profile data"
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
"""Initialize GetLinkedinProfileBlock."""
super().__init__(
id="f6e0ac73-4f1d-4acb-b4b7-b67066c5984e",
description="Fetch LinkedIn profile data using Enrichlayer",
categories={BlockCategory.SOCIAL},
input_schema=GetLinkedinProfileBlock.Input,
output_schema=GetLinkedinProfileBlock.Output,
test_input={
"linkedin_url": "https://www.linkedin.com/in/williamhgates/",
"include_skills": True,
"include_social_media": True,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
(
"profile",
PersonProfileResponse(
public_identifier="williamhgates",
full_name="Bill Gates",
occupation="Co-chair at Bill & Melinda Gates Foundation",
experiences=[
Experience(
company="Bill & Melinda Gates Foundation",
title="Co-chair",
starts_at={"year": 2000},
)
],
),
)
],
test_credentials=TEST_CREDENTIALS,
test_mock={
"_fetch_profile": lambda *args, **kwargs: PersonProfileResponse(
public_identifier="williamhgates",
full_name="Bill Gates",
occupation="Co-chair at Bill & Melinda Gates Foundation",
experiences=[
Experience(
company="Bill & Melinda Gates Foundation",
title="Co-chair",
starts_at={"year": 2000},
)
],
),
},
)
@staticmethod
async def _fetch_profile(
credentials: APIKeyCredentials,
linkedin_url: str,
fallback_to_cache: FallbackToCache = FallbackToCache.ON_ERROR,
use_cache: UseCache = UseCache.IF_PRESENT,
include_skills: bool = False,
include_inferred_salary: bool = False,
include_personal_email: bool = False,
include_personal_contact_number: bool = False,
include_social_media: bool = False,
include_extra: bool = False,
):
client = EnrichlayerClient(credentials)
profile = await client.fetch_profile(
linkedin_url=linkedin_url,
fallback_to_cache=fallback_to_cache,
use_cache=use_cache,
include_skills=include_skills,
include_inferred_salary=include_inferred_salary,
include_personal_email=include_personal_email,
include_personal_contact_number=include_personal_contact_number,
include_social_media=include_social_media,
include_extra=include_extra,
)
return profile
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
"""
Run the block to fetch LinkedIn profile data.
Args:
input_data: Input parameters for the block
credentials: API key credentials for Enrichlayer
**kwargs: Additional keyword arguments
Yields:
Tuples of (output_name, output_value)
"""
try:
profile = await self._fetch_profile(
credentials=credentials,
linkedin_url=input_data.linkedin_url,
fallback_to_cache=input_data.fallback_to_cache,
use_cache=input_data.use_cache,
include_skills=input_data.include_skills,
include_inferred_salary=input_data.include_inferred_salary,
include_personal_email=input_data.include_personal_email,
include_personal_contact_number=input_data.include_personal_contact_number,
include_social_media=input_data.include_social_media,
include_extra=input_data.include_extra,
)
yield "profile", profile
except Exception as e:
logger.error(f"Error fetching LinkedIn profile: {str(e)}")
yield "error", str(e)
class LinkedinPersonLookupBlock(Block):
"""Block to look up LinkedIn profiles by person's information using Enrichlayer API."""
class Input(BlockSchema):
"""Input schema for LinkedinPersonLookupBlock."""
first_name: str = SchemaField(
description="Person's first name",
placeholder="John",
advanced=False,
)
last_name: str | None = SchemaField(
description="Person's last name",
placeholder="Doe",
default=None,
advanced=False,
)
company_domain: str = SchemaField(
description="Domain of the company they work for (optional)",
placeholder="example.com",
advanced=False,
)
location: Optional[str] = SchemaField(
description="Person's location (optional)",
placeholder="San Francisco",
default=None,
)
title: Optional[str] = SchemaField(
description="Person's job title (optional)",
placeholder="CEO",
default=None,
)
include_similarity_checks: bool = SchemaField(
description="Include similarity checks",
default=False,
advanced=True,
)
enrich_profile: bool = SchemaField(
description="Enrich the profile with additional data",
default=False,
advanced=True,
)
credentials: EnrichlayerCredentialsInput = CredentialsField(
description="Enrichlayer API credentials"
)
class Output(BlockSchema):
"""Output schema for LinkedinPersonLookupBlock."""
lookup_result: PersonLookupResponse = SchemaField(
description="LinkedIn profile lookup result"
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
"""Initialize LinkedinPersonLookupBlock."""
super().__init__(
id="d237a98a-5c4b-4a1c-b9e3-e6f9a6c81df7",
description="Look up LinkedIn profiles by person information using Enrichlayer",
categories={BlockCategory.SOCIAL},
input_schema=LinkedinPersonLookupBlock.Input,
output_schema=LinkedinPersonLookupBlock.Output,
test_input={
"first_name": "Bill",
"last_name": "Gates",
"company_domain": "gatesfoundation.org",
"include_similarity_checks": True,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
(
"lookup_result",
PersonLookupResponse(
url="https://www.linkedin.com/in/williamhgates/",
name_similarity_score=0.93,
company_similarity_score=0.83,
title_similarity_score=0.3,
location_similarity_score=0.20,
),
)
],
test_credentials=TEST_CREDENTIALS,
test_mock={
"_lookup_person": lambda *args, **kwargs: PersonLookupResponse(
url="https://www.linkedin.com/in/williamhgates/",
name_similarity_score=0.93,
company_similarity_score=0.83,
title_similarity_score=0.3,
location_similarity_score=0.20,
)
},
)
@staticmethod
async def _lookup_person(
credentials: APIKeyCredentials,
first_name: str,
company_domain: str,
last_name: str | None = None,
location: Optional[str] = None,
title: Optional[str] = None,
include_similarity_checks: bool = False,
enrich_profile: bool = False,
):
client = EnrichlayerClient(credentials=credentials)
lookup_result = await client.lookup_person(
first_name=first_name,
last_name=last_name,
company_domain=company_domain,
location=location,
title=title,
include_similarity_checks=include_similarity_checks,
enrich_profile=enrich_profile,
)
return lookup_result
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
"""
Run the block to look up LinkedIn profiles.
Args:
input_data: Input parameters for the block
credentials: API key credentials for Enrichlayer
**kwargs: Additional keyword arguments
Yields:
Tuples of (output_name, output_value)
"""
try:
lookup_result = await self._lookup_person(
credentials=credentials,
first_name=input_data.first_name,
last_name=input_data.last_name,
company_domain=input_data.company_domain,
location=input_data.location,
title=input_data.title,
include_similarity_checks=input_data.include_similarity_checks,
enrich_profile=input_data.enrich_profile,
)
yield "lookup_result", lookup_result
except Exception as e:
logger.error(f"Error looking up LinkedIn profile: {str(e)}")
yield "error", str(e)
class LinkedinRoleLookupBlock(Block):
"""Block to look up LinkedIn profiles by role in a company using Enrichlayer API."""
class Input(BlockSchema):
"""Input schema for LinkedinRoleLookupBlock."""
role: str = SchemaField(
description="Role title (e.g., CEO, CTO)",
placeholder="CEO",
)
company_name: str = SchemaField(
description="Name of the company",
placeholder="Microsoft",
)
enrich_profile: bool = SchemaField(
description="Enrich the profile with additional data",
default=False,
advanced=True,
)
credentials: EnrichlayerCredentialsInput = CredentialsField(
description="Enrichlayer API credentials"
)
class Output(BlockSchema):
"""Output schema for LinkedinRoleLookupBlock."""
role_lookup_result: RoleLookupResponse = SchemaField(
description="LinkedIn role lookup result"
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
"""Initialize LinkedinRoleLookupBlock."""
super().__init__(
id="3b9fc742-06d4-49c7-b5ce-7e302dd7c8a7",
description="Look up LinkedIn profiles by role in a company using Enrichlayer",
categories={BlockCategory.SOCIAL},
input_schema=LinkedinRoleLookupBlock.Input,
output_schema=LinkedinRoleLookupBlock.Output,
test_input={
"role": "Co-chair",
"company_name": "Gates Foundation",
"enrich_profile": True,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
(
"role_lookup_result",
RoleLookupResponse(
linkedin_profile_url="https://www.linkedin.com/in/williamhgates/",
),
)
],
test_credentials=TEST_CREDENTIALS,
test_mock={
"_lookup_role": lambda *args, **kwargs: RoleLookupResponse(
linkedin_profile_url="https://www.linkedin.com/in/williamhgates/",
),
},
)
@staticmethod
async def _lookup_role(
credentials: APIKeyCredentials,
role: str,
company_name: str,
enrich_profile: bool = False,
):
client = EnrichlayerClient(credentials=credentials)
role_lookup_result = await client.lookup_role(
role=role,
company_name=company_name,
enrich_profile=enrich_profile,
)
return role_lookup_result
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
"""
Run the block to look up LinkedIn profiles by role.
Args:
input_data: Input parameters for the block
credentials: API key credentials for Enrichlayer
**kwargs: Additional keyword arguments
Yields:
Tuples of (output_name, output_value)
"""
try:
role_lookup_result = await self._lookup_role(
credentials=credentials,
role=input_data.role,
company_name=input_data.company_name,
enrich_profile=input_data.enrich_profile,
)
yield "role_lookup_result", role_lookup_result
except Exception as e:
logger.error(f"Error looking up role in company: {str(e)}")
yield "error", str(e)
class GetLinkedinProfilePictureBlock(Block):
"""Block to get LinkedIn profile pictures using Enrichlayer API."""
class Input(BlockSchema):
"""Input schema for GetLinkedinProfilePictureBlock."""
linkedin_profile_url: str = SchemaField(
description="LinkedIn profile URL",
placeholder="https://www.linkedin.com/in/username/",
)
credentials: EnrichlayerCredentialsInput = CredentialsField(
description="Enrichlayer API credentials"
)
class Output(BlockSchema):
"""Output schema for GetLinkedinProfilePictureBlock."""
profile_picture_url: MediaFileType = SchemaField(
description="LinkedIn profile picture URL"
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
"""Initialize GetLinkedinProfilePictureBlock."""
super().__init__(
id="68d5a942-9b3f-4e9a-b7c1-d96ea4321f0d",
description="Get LinkedIn profile pictures using Enrichlayer",
categories={BlockCategory.SOCIAL},
input_schema=GetLinkedinProfilePictureBlock.Input,
output_schema=GetLinkedinProfilePictureBlock.Output,
test_input={
"linkedin_profile_url": "https://www.linkedin.com/in/williamhgates/",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
(
"profile_picture_url",
"https://media.licdn.com/dms/image/C4D03AQFj-xjuXrLFSQ/profile-displayphoto-shrink_800_800/0/1576881858598?e=1686787200&v=beta&t=zrQC76QwsfQQIWthfOnrKRBMZ5D-qIAvzLXLmWgYvTk",
)
],
test_credentials=TEST_CREDENTIALS,
test_mock={
"_get_profile_picture": lambda *args, **kwargs: "https://media.licdn.com/dms/image/C4D03AQFj-xjuXrLFSQ/profile-displayphoto-shrink_800_800/0/1576881858598?e=1686787200&v=beta&t=zrQC76QwsfQQIWthfOnrKRBMZ5D-qIAvzLXLmWgYvTk",
},
)
@staticmethod
async def _get_profile_picture(
credentials: APIKeyCredentials, linkedin_profile_url: str
):
client = EnrichlayerClient(credentials=credentials)
profile_picture_response = await client.get_profile_picture(
linkedin_profile_url=linkedin_profile_url,
)
return profile_picture_response.profile_picture_url
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
"""
Run the block to get LinkedIn profile pictures.
Args:
input_data: Input parameters for the block
credentials: API key credentials for Enrichlayer
**kwargs: Additional keyword arguments
Yields:
Tuples of (output_name, output_value)
"""
try:
profile_picture = await self._get_profile_picture(
credentials=credentials,
linkedin_profile_url=input_data.linkedin_profile_url,
)
yield "profile_picture_url", profile_picture
except Exception as e:
logger.error(f"Error getting profile picture: {str(e)}")
yield "error", str(e)

View File

@@ -51,9 +51,7 @@ class ExaWebhookManager(BaseWebhooksManager):
WEBSET = "webset"
@classmethod
async def validate_payload(
cls, webhook: Webhook, request, credentials: Credentials | None
) -> tuple[dict, str]:
async def validate_payload(cls, webhook: Webhook, request) -> tuple[dict, str]:
"""Validate incoming webhook payload and signature."""
payload = await request.json()

View File

@@ -119,3 +119,6 @@ class ExaAnswerBlock(Block):
except Exception as e:
yield "error", str(e)
yield "answer", ""
yield "citations", []
yield "cost_dollars", {}

View File

@@ -1,247 +0,0 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
# Enum definitions based on available options
class WebsetStatus(str, Enum):
IDLE = "idle"
PENDING = "pending"
RUNNING = "running"
PAUSED = "paused"
class WebsetSearchStatus(str, Enum):
CREATED = "created"
# Add more if known, based on example it's "created"
class ImportStatus(str, Enum):
PENDING = "pending"
# Add more if known
class ImportFormat(str, Enum):
CSV = "csv"
# Add more if known
class EnrichmentStatus(str, Enum):
PENDING = "pending"
# Add more if known
class EnrichmentFormat(str, Enum):
TEXT = "text"
# Add more if known
class MonitorStatus(str, Enum):
ENABLED = "enabled"
# Add more if known
class MonitorBehaviorType(str, Enum):
SEARCH = "search"
# Add more if known
class MonitorRunStatus(str, Enum):
CREATED = "created"
# Add more if known
class CanceledReason(str, Enum):
WEBSET_DELETED = "webset_deleted"
# Add more if known
class FailedReason(str, Enum):
INVALID_FORMAT = "invalid_format"
# Add more if known
class Confidence(str, Enum):
HIGH = "high"
# Add more if known
# Nested models
class Entity(BaseModel):
type: str
class Criterion(BaseModel):
description: str
successRate: Optional[int] = None
class ExcludeItem(BaseModel):
source: str = Field(default="import")
id: str
class Relationship(BaseModel):
definition: str
limit: Optional[float] = None
class ScopeItem(BaseModel):
source: str = Field(default="import")
id: str
relationship: Optional[Relationship] = None
class Progress(BaseModel):
found: int
analyzed: int
completion: int
timeLeft: int
class Bounds(BaseModel):
min: int
max: int
class Expected(BaseModel):
total: int
confidence: str = Field(default="high") # Use str or Confidence enum
bounds: Bounds
class Recall(BaseModel):
expected: Expected
reasoning: str
class WebsetSearch(BaseModel):
id: str
object: str = Field(default="webset_search")
status: str = Field(default="created") # Or use WebsetSearchStatus
websetId: str
query: str
entity: Entity
criteria: List[Criterion]
count: int
behavior: str = Field(default="override")
exclude: List[ExcludeItem]
scope: List[ScopeItem]
progress: Progress
recall: Recall
metadata: Dict[str, Any] = Field(default_factory=dict)
canceledAt: Optional[datetime] = None
canceledReason: Optional[str] = Field(default=None) # Or use CanceledReason
createdAt: datetime
updatedAt: datetime
class ImportEntity(BaseModel):
type: str
class Import(BaseModel):
id: str
object: str = Field(default="import")
status: str = Field(default="pending") # Or use ImportStatus
format: str = Field(default="csv") # Or use ImportFormat
entity: ImportEntity
title: str
count: int
metadata: Dict[str, Any] = Field(default_factory=dict)
failedReason: Optional[str] = Field(default=None) # Or use FailedReason
failedAt: Optional[datetime] = None
failedMessage: Optional[str] = None
createdAt: datetime
updatedAt: datetime
class Option(BaseModel):
label: str
class WebsetEnrichment(BaseModel):
id: str
object: str = Field(default="webset_enrichment")
status: str = Field(default="pending") # Or use EnrichmentStatus
websetId: str
title: str
description: str
format: str = Field(default="text") # Or use EnrichmentFormat
options: List[Option]
instructions: str
metadata: Dict[str, Any] = Field(default_factory=dict)
createdAt: datetime
updatedAt: datetime
class Cadence(BaseModel):
cron: str
timezone: str = Field(default="Etc/UTC")
class BehaviorConfig(BaseModel):
query: Optional[str] = None
criteria: Optional[List[Criterion]] = None
entity: Optional[Entity] = None
count: Optional[int] = None
behavior: Optional[str] = Field(default=None)
class Behavior(BaseModel):
type: str = Field(default="search") # Or use MonitorBehaviorType
config: BehaviorConfig
class MonitorRun(BaseModel):
id: str
object: str = Field(default="monitor_run")
status: str = Field(default="created") # Or use MonitorRunStatus
monitorId: str
type: str = Field(default="search")
completedAt: Optional[datetime] = None
failedAt: Optional[datetime] = None
failedReason: Optional[str] = None
canceledAt: Optional[datetime] = None
createdAt: datetime
updatedAt: datetime
class Monitor(BaseModel):
id: str
object: str = Field(default="monitor")
status: str = Field(default="enabled") # Or use MonitorStatus
websetId: str
cadence: Cadence
behavior: Behavior
lastRun: Optional[MonitorRun] = None
nextRunAt: Optional[datetime] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
createdAt: datetime
updatedAt: datetime
class Webset(BaseModel):
id: str
object: str = Field(default="webset")
status: WebsetStatus
externalId: Optional[str] = None
title: Optional[str] = None
searches: List[WebsetSearch]
imports: List[Import]
enrichments: List[WebsetEnrichment]
monitors: List[Monitor]
streams: List[Any]
createdAt: datetime
updatedAt: datetime
metadata: Dict[str, Any] = Field(default_factory=dict)
class ListWebsets(BaseModel):
data: List[Webset]
hasMore: bool
nextCursor: Optional[str] = None
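The file above (removed in this diff) mirrored Exa's Webset API objects as nested pydantic models. For reference, a minimal sketch of how they validated a payload; the field values are illustrative, not real API data.
from datetime import datetime, timezone

now = datetime.now(timezone.utc)
ws = Webset(
    id="ws_123",
    status=WebsetStatus.IDLE,
    searches=[],
    imports=[],
    enrichments=[],
    monitors=[],
    streams=[],
    createdAt=now,
    updatedAt=now,
)
page = ListWebsets(data=[ws], hasMore=False)
assert page.data[0].status is WebsetStatus.IDLE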

View File

@@ -114,7 +114,6 @@ class ExaWebsetWebhookBlock(Block):
def __init__(self):
super().__init__(
disabled=True,
id="d0204ed8-8b81-408d-8b8d-ed087a546228",
description="Receive webhook notifications for Exa webset events",
categories={BlockCategory.INPUT},

View File

@@ -1,33 +1,7 @@
from datetime import datetime
from enum import Enum
from typing import Annotated, Any, Dict, List, Optional
from exa_py import Exa
from exa_py.websets.types import (
CreateCriterionParameters,
CreateEnrichmentParameters,
CreateWebsetParameters,
CreateWebsetParametersSearch,
ExcludeItem,
Format,
ImportItem,
ImportSource,
Option,
ScopeItem,
ScopeRelationship,
ScopeSourceType,
WebsetArticleEntity,
WebsetCompanyEntity,
WebsetCustomEntity,
WebsetPersonEntity,
WebsetResearchPaperEntity,
WebsetStatus,
)
from pydantic import Field
from typing import Any, Optional
from backend.sdk import (
APIKeyCredentials,
BaseModel,
Block,
BlockCategory,
BlockOutput,
@@ -38,69 +12,7 @@ from backend.sdk import (
)
from ._config import exa
class SearchEntityType(str, Enum):
COMPANY = "company"
PERSON = "person"
ARTICLE = "article"
RESEARCH_PAPER = "research_paper"
CUSTOM = "custom"
AUTO = "auto"
class SearchType(str, Enum):
IMPORT = "import"
WEBSET = "webset"
class EnrichmentFormat(str, Enum):
TEXT = "text"
DATE = "date"
NUMBER = "number"
OPTIONS = "options"
EMAIL = "email"
PHONE = "phone"
class Webset(BaseModel):
id: str
status: WebsetStatus | None = Field(..., title="WebsetStatus")
"""
The status of the webset
"""
external_id: Annotated[Optional[str], Field(alias="externalId")] = None
"""
The external identifier for the webset
"""
searches: List[dict[str, Any]] | None = None
"""
The searches that have been performed on the webset.
NOTE: Returning dict to avoid ui crashing due to nested objects
"""
enrichments: List[dict[str, Any]] | None = None
"""
The Enrichments to apply to the Webset Items.
NOTE: Returning dict to avoid ui crashing due to nested objects
"""
monitors: List[dict[str, Any]] | None = None
"""
The Monitors for the Webset.
NOTE: Returning dict to avoid ui crashing due to nested objects
"""
metadata: Optional[Dict[str, Any]] = {}
"""
Set of key-value pairs you want to associate with this object.
"""
created_at: Annotated[datetime, Field(alias="createdAt")] | None = None
"""
The date and time the webset was created
"""
updated_at: Annotated[datetime, Field(alias="updatedAt")] | None = None
"""
The date and time the webset was last updated
"""
from .helpers import WebsetEnrichmentConfig, WebsetSearchConfig
class ExaCreateWebsetBlock(Block):
@@ -108,121 +20,40 @@ class ExaCreateWebsetBlock(Block):
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
# Search parameters (flattened)
search_query: str = SchemaField(
description="Your search query. Use this to describe what you are looking for. Any URL provided will be crawled and used as context for the search.",
placeholder="Marketing agencies based in the US, that focus on consumer products",
search: WebsetSearchConfig = SchemaField(
description="Initial search configuration for the Webset"
)
search_count: Optional[int] = SchemaField(
default=10,
description="Number of items the search will attempt to find. The actual number of items found may be less than this number depending on the search complexity.",
ge=1,
le=1000,
)
search_entity_type: SearchEntityType = SchemaField(
default=SearchEntityType.AUTO,
description="Entity type: 'company', 'person', 'article', 'research_paper', or 'custom'. If not provided, we automatically detect the entity from the query.",
advanced=True,
)
search_entity_description: Optional[str] = SchemaField(
enrichments: Optional[list[WebsetEnrichmentConfig]] = SchemaField(
default=None,
description="Description for custom entity type (required when search_entity_type is 'custom')",
description="Enrichments to apply to Webset items",
advanced=True,
)
# Search criteria (flattened)
search_criteria: list[str] = SchemaField(
default_factory=list,
description="List of criteria descriptions that every item will be evaluated against. If not provided, we automatically detect the criteria from the query.",
advanced=True,
)
# Search exclude sources (flattened)
search_exclude_sources: list[str] = SchemaField(
default_factory=list,
description="List of source IDs (imports or websets) to exclude from search results",
advanced=True,
)
search_exclude_types: list[SearchType] = SchemaField(
default_factory=list,
description="List of source types corresponding to exclude sources ('import' or 'webset')",
advanced=True,
)
# Search scope sources (flattened)
search_scope_sources: list[str] = SchemaField(
default_factory=list,
description="List of source IDs (imports or websets) to limit search scope to",
advanced=True,
)
search_scope_types: list[SearchType] = SchemaField(
default_factory=list,
description="List of source types corresponding to scope sources ('import' or 'webset')",
advanced=True,
)
search_scope_relationships: list[str] = SchemaField(
default_factory=list,
description="List of relationship definitions for hop searches (optional, one per scope source)",
advanced=True,
)
search_scope_relationship_limits: list[int] = SchemaField(
default_factory=list,
description="List of limits on the number of related entities to find (optional, one per scope relationship)",
advanced=True,
)
# Import parameters (flattened)
import_sources: list[str] = SchemaField(
default_factory=list,
description="List of source IDs to import from",
advanced=True,
)
import_types: list[SearchType] = SchemaField(
default_factory=list,
description="List of source types corresponding to import sources ('import' or 'webset')",
advanced=True,
)
# Enrichment parameters (flattened)
enrichment_descriptions: list[str] = SchemaField(
default_factory=list,
description="List of enrichment task descriptions to perform on each webset item",
advanced=True,
)
enrichment_formats: list[EnrichmentFormat] = SchemaField(
default_factory=list,
description="List of formats for enrichment responses ('text', 'date', 'number', 'options', 'email', 'phone'). If not specified, we automatically select the best format.",
advanced=True,
)
enrichment_options: list[list[str]] = SchemaField(
default_factory=list,
description="List of option lists for enrichments with 'options' format. Each inner list contains the option labels.",
advanced=True,
)
enrichment_metadata: list[dict] = SchemaField(
default_factory=list,
description="List of metadata dictionaries for enrichments",
advanced=True,
)
# Webset metadata
external_id: Optional[str] = SchemaField(
default=None,
description="External identifier for the webset. You can use this to reference the webset by your own internal identifiers.",
description="External identifier for the webset",
placeholder="my-webset-123",
advanced=True,
)
metadata: Optional[dict] = SchemaField(
default_factory=dict,
default=None,
description="Key-value pairs to associate with this webset",
advanced=True,
)
class Output(BlockSchema):
webset: Webset = SchemaField(
webset_id: str = SchemaField(
description="The unique identifier for the created webset"
)
status: str = SchemaField(description="The status of the webset")
external_id: Optional[str] = SchemaField(
description="The external identifier for the webset", default=None
)
created_at: str = SchemaField(
description="The date and time the webset was created"
)
error: str = SchemaField(
description="Error message if the request failed", default=""
)
def __init__(self):
super().__init__(
@@ -236,171 +67,44 @@ class ExaCreateWebsetBlock(Block):
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
url = "https://api.exa.ai/websets/v0/websets"
headers = {
"Content-Type": "application/json",
"x-api-key": credentials.api_key.get_secret_value(),
}
exa = Exa(credentials.api_key.get_secret_value())
# Build the payload
payload: dict[str, Any] = {
"search": input_data.search.model_dump(exclude_none=True),
}
# ------------------------------------------------------------
# Build entity (if explicitly provided)
# ------------------------------------------------------------
entity = None
if input_data.search_entity_type == SearchEntityType.COMPANY:
entity = WebsetCompanyEntity(type="company")
elif input_data.search_entity_type == SearchEntityType.PERSON:
entity = WebsetPersonEntity(type="person")
elif input_data.search_entity_type == SearchEntityType.ARTICLE:
entity = WebsetArticleEntity(type="article")
elif input_data.search_entity_type == SearchEntityType.RESEARCH_PAPER:
entity = WebsetResearchPaperEntity(type="research_paper")
elif (
input_data.search_entity_type == SearchEntityType.CUSTOM
and input_data.search_entity_description
):
entity = WebsetCustomEntity(
type="custom", description=input_data.search_entity_description
)
# Convert enrichments to API format
if input_data.enrichments:
enrichments_data = []
for enrichment in input_data.enrichments:
enrichments_data.append(enrichment.model_dump(exclude_none=True))
payload["enrichments"] = enrichments_data
# ------------------------------------------------------------
# Build criteria list
# ------------------------------------------------------------
criteria = None
if input_data.search_criteria:
criteria = [
CreateCriterionParameters(description=item)
for item in input_data.search_criteria
]
if input_data.external_id:
payload["externalId"] = input_data.external_id
# ------------------------------------------------------------
# Build exclude sources list
# ------------------------------------------------------------
exclude_items = None
if input_data.search_exclude_sources:
exclude_items = []
for idx, src_id in enumerate(input_data.search_exclude_sources):
src_type = None
if input_data.search_exclude_types and idx < len(
input_data.search_exclude_types
):
src_type = input_data.search_exclude_types[idx]
# Default to IMPORT if type missing
if src_type == SearchType.WEBSET:
source_enum = ImportSource.webset
else:
source_enum = ImportSource.import_
exclude_items.append(ExcludeItem(source=source_enum, id=src_id))
if input_data.metadata:
payload["metadata"] = input_data.metadata
# ------------------------------------------------------------
# Build scope list
# ------------------------------------------------------------
scope_items = None
if input_data.search_scope_sources:
scope_items = []
for idx, src_id in enumerate(input_data.search_scope_sources):
src_type = None
if input_data.search_scope_types and idx < len(
input_data.search_scope_types
):
src_type = input_data.search_scope_types[idx]
relationship = None
if input_data.search_scope_relationships and idx < len(
input_data.search_scope_relationships
):
rel_def = input_data.search_scope_relationships[idx]
lim = None
if input_data.search_scope_relationship_limits and idx < len(
input_data.search_scope_relationship_limits
):
lim = input_data.search_scope_relationship_limits[idx]
relationship = ScopeRelationship(definition=rel_def, limit=lim)
if src_type == SearchType.WEBSET:
src_enum = ScopeSourceType.webset
else:
src_enum = ScopeSourceType.import_
scope_items.append(
ScopeItem(source=src_enum, id=src_id, relationship=relationship)
)
try:
response = await Requests().post(url, headers=headers, json=payload)
data = response.json()
# ------------------------------------------------------------
# Assemble search parameters (only if a query is provided)
# ------------------------------------------------------------
search_params = None
if input_data.search_query:
search_params = CreateWebsetParametersSearch(
query=input_data.search_query,
count=input_data.search_count,
entity=entity,
criteria=criteria,
exclude=exclude_items,
scope=scope_items,
)
yield "webset_id", data.get("id", "")
yield "status", data.get("status", "")
yield "external_id", data.get("externalId")
yield "created_at", data.get("createdAt", "")
# ------------------------------------------------------------
# Build imports list
# ------------------------------------------------------------
imports_params = None
if input_data.import_sources:
imports_params = []
for idx, src_id in enumerate(input_data.import_sources):
src_type = None
if input_data.import_types and idx < len(input_data.import_types):
src_type = input_data.import_types[idx]
if src_type == SearchType.WEBSET:
source_enum = ImportSource.webset
else:
source_enum = ImportSource.import_
imports_params.append(ImportItem(source=source_enum, id=src_id))
# ------------------------------------------------------------
# Build enrichment list
# ------------------------------------------------------------
enrichments_params = None
if input_data.enrichment_descriptions:
enrichments_params = []
for idx, desc in enumerate(input_data.enrichment_descriptions):
fmt = None
if input_data.enrichment_formats and idx < len(
input_data.enrichment_formats
):
fmt_enum = input_data.enrichment_formats[idx]
if fmt_enum is not None:
fmt = Format(
fmt_enum.value if isinstance(fmt_enum, Enum) else fmt_enum
)
options_list = None
if input_data.enrichment_options and idx < len(
input_data.enrichment_options
):
raw_opts = input_data.enrichment_options[idx]
if raw_opts:
options_list = [Option(label=o) for o in raw_opts]
metadata_obj = None
if input_data.enrichment_metadata and idx < len(
input_data.enrichment_metadata
):
metadata_obj = input_data.enrichment_metadata[idx]
enrichments_params.append(
CreateEnrichmentParameters(
description=desc,
format=fmt,
options=options_list,
metadata=metadata_obj,
)
)
# ------------------------------------------------------------
# Create the webset
# ------------------------------------------------------------
webset = exa.websets.create(
params=CreateWebsetParameters(
search=search_params,
imports=imports_params,
enrichments=enrichments_params,
external_id=input_data.external_id,
metadata=input_data.metadata,
)
)
# Use alias field names returned from Exa SDK so that nested models validate correctly
yield "webset", Webset.model_validate(webset.model_dump(by_alias=True))
except Exception as e:
yield "error", str(e)
yield "webset_id", ""
yield "status", ""
yield "created_at", ""
class ExaUpdateWebsetBlock(Block):
@@ -479,11 +183,6 @@ class ExaListWebsetsBlock(Block):
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
trigger: Any | None = SchemaField(
default=None,
description="Trigger for the webset, value is ignored!",
advanced=False,
)
cursor: Optional[str] = SchemaField(
default=None,
description="Cursor for pagination through results",
@@ -498,9 +197,7 @@ class ExaListWebsetsBlock(Block):
)
class Output(BlockSchema):
websets: list[Webset] = SchemaField(
description="List of websets", default_factory=list
)
websets: list = SchemaField(description="List of websets", default_factory=list)
has_more: bool = SchemaField(
description="Whether there are more results to paginate through",
default=False,
@@ -558,6 +255,9 @@ class ExaGetWebsetBlock(Block):
description="The ID or external ID of the Webset to retrieve",
placeholder="webset-id-or-external-id",
)
expand_items: bool = SchemaField(
default=False, description="Include items in the response", advanced=True
)
class Output(BlockSchema):
webset_id: str = SchemaField(description="The unique identifier for the webset")
@@ -609,8 +309,12 @@ class ExaGetWebsetBlock(Block):
"x-api-key": credentials.api_key.get_secret_value(),
}
params = {}
if input_data.expand_items:
params["expand[]"] = "items"
try:
response = await Requests().get(url, headers=headers)
response = await Requests().get(url, headers=headers, params=params)
data = response.json()
yield "webset_id", data.get("id", "")

View File

@@ -1,8 +0,0 @@
from backend.sdk import BlockCostType, ProviderBuilder
firecrawl = (
ProviderBuilder("firecrawl")
.with_api_key("FIRECRAWL_API_KEY", "Firecrawl API Key")
.with_base_cost(1, BlockCostType.RUN)
.build()
)

View File

@@ -1,114 +0,0 @@
from enum import Enum
from typing import Any
from firecrawl import FirecrawlApp, ScrapeOptions
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
SchemaField,
)
from ._config import firecrawl
class ScrapeFormat(Enum):
MARKDOWN = "markdown"
HTML = "html"
RAW_HTML = "rawHtml"
LINKS = "links"
SCREENSHOT = "screenshot"
SCREENSHOT_FULL_PAGE = "screenshot@fullPage"
JSON = "json"
CHANGE_TRACKING = "changeTracking"
class FirecrawlCrawlBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = firecrawl.credentials_field()
url: str = SchemaField(description="The URL to crawl")
limit: int = SchemaField(description="The number of pages to crawl", default=10)
only_main_content: bool = SchemaField(
description="Only return the main content of the page excluding headers, navs, footers, etc.",
default=True,
)
max_age: int = SchemaField(
description="The maximum age of the page in milliseconds - default is 1 hour",
default=3600000,
)
wait_for: int = SchemaField(
description="Specify a delay in milliseconds before fetching the content, allowing the page sufficient time to load.",
default=0,
)
formats: list[ScrapeFormat] = SchemaField(
description="The format of the crawl", default=[ScrapeFormat.MARKDOWN]
)
class Output(BlockSchema):
data: list[dict[str, Any]] = SchemaField(description="The result of the crawl")
markdown: str = SchemaField(description="The markdown of the crawl")
html: str = SchemaField(description="The html of the crawl")
raw_html: str = SchemaField(description="The raw html of the crawl")
links: list[str] = SchemaField(description="The links of the crawl")
screenshot: str = SchemaField(description="The screenshot of the crawl")
screenshot_full_page: str = SchemaField(
description="The screenshot full page of the crawl"
)
json_data: dict[str, Any] = SchemaField(
description="The json data of the crawl"
)
change_tracking: dict[str, Any] = SchemaField(
description="The change tracking of the crawl"
)
def __init__(self):
super().__init__(
id="bdbbaba0-03b7-4971-970e-699e2de6015e",
description="Firecrawl crawls websites to extract comprehensive data while bypassing blockers.",
categories={BlockCategory.SEARCH},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
# Sync call
crawl_result = app.crawl_url(
input_data.url,
limit=input_data.limit,
scrape_options=ScrapeOptions(
formats=[format.value for format in input_data.formats],
onlyMainContent=input_data.only_main_content,
maxAge=input_data.max_age,
waitFor=input_data.wait_for,
),
)
yield "data", crawl_result.data
for data in crawl_result.data:
for f in input_data.formats:
if f == ScrapeFormat.MARKDOWN:
yield "markdown", data.markdown
elif f == ScrapeFormat.HTML:
yield "html", data.html
elif f == ScrapeFormat.RAW_HTML:
yield "raw_html", data.rawHtml
elif f == ScrapeFormat.LINKS:
yield "links", data.links
elif f == ScrapeFormat.SCREENSHOT:
yield "screenshot", data.screenshot
elif f == ScrapeFormat.SCREENSHOT_FULL_PAGE:
yield "screenshot_full_page", data.screenshot
elif f == ScrapeFormat.CHANGE_TRACKING:
yield "change_tracking", data.changeTracking
elif f == ScrapeFormat.JSON:
yield "json", data.json

View File

@@ -1,66 +0,0 @@
from typing import Any
from firecrawl import FirecrawlApp
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockCost,
BlockCostType,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
SchemaField,
cost,
)
from ._config import firecrawl
@cost(BlockCost(2, BlockCostType.RUN))
class FirecrawlExtractBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = firecrawl.credentials_field()
urls: list[str] = SchemaField(
description="The URLs to crawl - at least one is required. Wildcards are supported. (/*)"
)
prompt: str | None = SchemaField(
description="The prompt to use for the crawl", default=None, advanced=False
)
output_schema: dict | None = SchemaField(
description="A Json Schema describing the output structure if more rigid structure is desired.",
default=None,
)
enable_web_search: bool = SchemaField(
description="When true, extraction can follow links outside the specified domain.",
default=False,
)
class Output(BlockSchema):
data: dict[str, Any] = SchemaField(description="The result of the crawl")
def __init__(self):
super().__init__(
id="d1774756-4d9e-40e6-bab1-47ec0ccd81b2",
description="Firecrawl crawls websites to extract comprehensive data while bypassing blockers.",
categories={BlockCategory.SEARCH},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
extract_result = app.extract(
urls=input_data.urls,
prompt=input_data.prompt,
schema=input_data.output_schema,
enable_web_search=input_data.enable_web_search,
)
yield "data", extract_result.data

View File

@@ -1,46 +0,0 @@
from firecrawl import FirecrawlApp
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
SchemaField,
)
from ._config import firecrawl
class FirecrawlMapWebsiteBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = firecrawl.credentials_field()
url: str = SchemaField(description="The website url to map")
class Output(BlockSchema):
links: list[str] = SchemaField(description="The links of the website")
def __init__(self):
super().__init__(
id="f0f43e2b-c943-48a0-a7f1-40136ca4d3b9",
description="Firecrawl maps a website to extract all the links.",
categories={BlockCategory.SEARCH},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
# Sync call
map_result = app.map_url(
url=input_data.url,
)
yield "links", map_result.links

View File

@@ -1,109 +0,0 @@
from enum import Enum
from typing import Any
from firecrawl import FirecrawlApp
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
SchemaField,
)
from ._config import firecrawl
class ScrapeFormat(Enum):
MARKDOWN = "markdown"
HTML = "html"
RAW_HTML = "rawHtml"
LINKS = "links"
SCREENSHOT = "screenshot"
SCREENSHOT_FULL_PAGE = "screenshot@fullPage"
JSON = "json"
CHANGE_TRACKING = "changeTracking"
class FirecrawlScrapeBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = firecrawl.credentials_field()
url: str = SchemaField(description="The URL to crawl")
limit: int = SchemaField(description="The number of pages to crawl", default=10)
only_main_content: bool = SchemaField(
description="Only return the main content of the page excluding headers, navs, footers, etc.",
default=True,
)
max_age: int = SchemaField(
description="The maximum age of the page in milliseconds - default is 1 hour",
default=3600000,
)
wait_for: int = SchemaField(
description="Specify a delay in milliseconds before fetching the content, allowing the page sufficient time to load.",
default=200,
)
formats: list[ScrapeFormat] = SchemaField(
description="The format of the crawl", default=[ScrapeFormat.MARKDOWN]
)
class Output(BlockSchema):
data: dict[str, Any] = SchemaField(description="The result of the crawl")
markdown: str = SchemaField(description="The markdown of the crawl")
html: str = SchemaField(description="The html of the crawl")
raw_html: str = SchemaField(description="The raw html of the crawl")
links: list[str] = SchemaField(description="The links of the crawl")
screenshot: str = SchemaField(description="The screenshot of the crawl")
screenshot_full_page: str = SchemaField(
description="The screenshot full page of the crawl"
)
json_data: dict[str, Any] = SchemaField(
description="The json data of the crawl"
)
change_tracking: dict[str, Any] = SchemaField(
description="The change tracking of the crawl"
)
def __init__(self):
super().__init__(
id="ac444320-cf5e-4697-b586-2604c17a3e75",
description="Firecrawl scrapes a website to extract comprehensive data while bypassing blockers.",
categories={BlockCategory.SEARCH},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
scrape_result = app.scrape_url(
input_data.url,
formats=[format.value for format in input_data.formats],
only_main_content=input_data.only_main_content,
max_age=input_data.max_age,
wait_for=input_data.wait_for,
)
yield "data", scrape_result
for f in input_data.formats:
if f == ScrapeFormat.MARKDOWN:
yield "markdown", scrape_result.markdown
elif f == ScrapeFormat.HTML:
yield "html", scrape_result.html
elif f == ScrapeFormat.RAW_HTML:
yield "raw_html", scrape_result.rawHtml
elif f == ScrapeFormat.LINKS:
yield "links", scrape_result.links
elif f == ScrapeFormat.SCREENSHOT:
yield "screenshot", scrape_result.screenshot
elif f == ScrapeFormat.SCREENSHOT_FULL_PAGE:
yield "screenshot_full_page", scrape_result.screenshot
elif f == ScrapeFormat.CHANGE_TRACKING:
yield "change_tracking", scrape_result.changeTracking
elif f == ScrapeFormat.JSON:
yield "json", scrape_result.json

View File

@@ -1,79 +0,0 @@
from enum import Enum
from typing import Any
from firecrawl import FirecrawlApp, ScrapeOptions
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
SchemaField,
)
from ._config import firecrawl
class ScrapeFormat(Enum):
MARKDOWN = "markdown"
HTML = "html"
RAW_HTML = "rawHtml"
LINKS = "links"
SCREENSHOT = "screenshot"
SCREENSHOT_FULL_PAGE = "screenshot@fullPage"
JSON = "json"
CHANGE_TRACKING = "changeTracking"
class FirecrawlSearchBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = firecrawl.credentials_field()
query: str = SchemaField(description="The query to search for")
limit: int = SchemaField(description="The number of pages to crawl", default=10)
max_age: int = SchemaField(
description="The maximum age of the page in milliseconds - default is 1 hour",
default=3600000,
)
wait_for: int = SchemaField(
description="Specify a delay in milliseconds before fetching the content, allowing the page sufficient time to load.",
default=200,
)
formats: list[ScrapeFormat] = SchemaField(
description="Returns the content of the search if specified", default=[]
)
class Output(BlockSchema):
data: dict[str, Any] = SchemaField(description="The result of the search")
site: dict[str, Any] = SchemaField(description="The site of the search")
def __init__(self):
super().__init__(
id="f8d2f28d-b3a1-405b-804e-418c087d288b",
description="Firecrawl searches the web for the given query.",
categories={BlockCategory.SEARCH},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
# Sync call
scrape_result = app.search(
input_data.query,
limit=input_data.limit,
scrape_options=ScrapeOptions(
formats=[format.value for format in input_data.formats],
maxAge=input_data.max_age,
waitFor=input_data.wait_for,
),
)
yield "data", scrape_result
for site in scrape_result.data:
yield "site", site

View File

@@ -129,7 +129,6 @@ class AIImageEditorBlock(Block):
*,
credentials: APIKeyCredentials,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
result = await self.run_model(
@@ -140,7 +139,6 @@ class AIImageEditorBlock(Block):
await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.input_image,
user_id=user_id,
return_content=True,
)
if input_data.input_image

View File

@@ -3,7 +3,7 @@ import logging
from fastapi import Request
from strenum import StrEnum
from backend.sdk import Credentials, ManualWebhookManagerBase, Webhook
from backend.sdk import ManualWebhookManagerBase, Webhook
logger = logging.getLogger(__name__)
@@ -17,7 +17,7 @@ class GenericWebhooksManager(ManualWebhookManagerBase):
@classmethod
async def validate_payload(
cls, webhook: Webhook, request: Request, credentials: Credentials | None = None
cls, webhook: Webhook, request: Request
) -> tuple[dict, str]:
payload = await request.json()
event_type = GenericWebhookType.PLAIN

View File

@@ -1,388 +0,0 @@
import logging
import re
from enum import Enum
from typing import Optional
from typing_extensions import TypedDict
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from ._api import get_api
from ._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
GithubCredentials,
GithubCredentialsField,
GithubCredentialsInput,
)
logger = logging.getLogger(__name__)
class CheckRunStatus(Enum):
QUEUED = "queued"
IN_PROGRESS = "in_progress"
COMPLETED = "completed"
class CheckRunConclusion(Enum):
SUCCESS = "success"
FAILURE = "failure"
NEUTRAL = "neutral"
CANCELLED = "cancelled"
SKIPPED = "skipped"
TIMED_OUT = "timed_out"
ACTION_REQUIRED = "action_required"
class GithubGetCIResultsBlock(Block):
class Input(BlockSchema):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo: str = SchemaField(
description="GitHub repository",
placeholder="owner/repo",
)
target: str | int = SchemaField(
description="Commit SHA or PR number to get CI results for",
placeholder="abc123def or 123",
)
search_pattern: Optional[str] = SchemaField(
description="Optional regex pattern to search for in CI logs (e.g., error messages, file names)",
placeholder=".*error.*|.*warning.*",
default=None,
advanced=True,
)
check_name_filter: Optional[str] = SchemaField(
description="Optional filter for specific check names (supports wildcards)",
placeholder="*lint* or build-*",
default=None,
advanced=True,
)
class Output(BlockSchema):
class CheckRunItem(TypedDict, total=False):
id: int
name: str
status: str
conclusion: Optional[str]
started_at: Optional[str]
completed_at: Optional[str]
html_url: str
details_url: Optional[str]
output_title: Optional[str]
output_summary: Optional[str]
output_text: Optional[str]
annotations: list[dict]
class MatchedLine(TypedDict):
check_name: str
line_number: int
line: str
context: list[str]
check_run: CheckRunItem = SchemaField(
title="Check Run",
description="Individual CI check run with details",
)
check_runs: list[CheckRunItem] = SchemaField(
description="List of all CI check runs"
)
matched_line: MatchedLine = SchemaField(
title="Matched Line",
description="Line matching the search pattern with context",
)
matched_lines: list[MatchedLine] = SchemaField(
description="All lines matching the search pattern across all checks"
)
overall_status: str = SchemaField(
description="Overall CI status (pending, success, failure)"
)
overall_conclusion: str = SchemaField(
description="Overall CI conclusion if completed"
)
total_checks: int = SchemaField(description="Total number of CI checks")
passed_checks: int = SchemaField(description="Number of passed checks")
failed_checks: int = SchemaField(description="Number of failed checks")
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="8ad9e103-78f2-4fdb-ba12-3571f2c95e98",
description="This block gets CI results for a commit or PR, with optional search for specific errors/warnings in logs.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubGetCIResultsBlock.Input,
output_schema=GithubGetCIResultsBlock.Output,
test_input={
"repo": "owner/repo",
"target": "abc123def456",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("overall_status", "completed"),
("overall_conclusion", "success"),
("total_checks", 1),
("passed_checks", 1),
("failed_checks", 0),
(
"check_runs",
[
{
"id": 123456,
"name": "build",
"status": "completed",
"conclusion": "success",
"started_at": "2024-01-01T00:00:00Z",
"completed_at": "2024-01-01T00:05:00Z",
"html_url": "https://github.com/owner/repo/runs/123456",
"details_url": None,
"output_title": "Build passed",
"output_summary": "All tests passed",
"output_text": "Build log output...",
"annotations": [],
}
],
),
],
test_mock={
"get_ci_results": lambda *args, **kwargs: {
"check_runs": [
{
"id": 123456,
"name": "build",
"status": "completed",
"conclusion": "success",
"started_at": "2024-01-01T00:00:00Z",
"completed_at": "2024-01-01T00:05:00Z",
"html_url": "https://github.com/owner/repo/runs/123456",
"details_url": None,
"output_title": "Build passed",
"output_summary": "All tests passed",
"output_text": "Build log output...",
"annotations": [],
}
],
"total_count": 1,
}
},
)
@staticmethod
async def get_commit_sha(api, repo: str, target: str | int) -> str:
"""Get commit SHA from either a commit SHA or PR URL."""
# If it's already a SHA, return it
if isinstance(target, str):
if re.match(r"^[0-9a-f]{6,40}$", target, re.IGNORECASE):
return target
# If it's a PR number, fetch the PR and use its head commit SHA
if isinstance(target, int):
pr_url = f"https://api.github.com/repos/{repo}/pulls/{target}"
response = await api.get(pr_url)
pr_data = response.json()
return pr_data["head"]["sha"]
raise ValueError("Target must be a commit SHA or PR URL")
@staticmethod
async def search_in_logs(
check_runs: list,
pattern: str,
) -> list[Output.MatchedLine]:
"""Search for pattern in check run logs."""
if not pattern:
return []
matched_lines = []
regex = re.compile(pattern, re.IGNORECASE | re.MULTILINE)
for check in check_runs:
output_text = check.get("output_text", "") or ""
if not output_text:
continue
lines = output_text.split("\n")
for i, line in enumerate(lines):
if regex.search(line):
# Get context (2 lines before and after)
start = max(0, i - 2)
end = min(len(lines), i + 3)
context = lines[start:end]
matched_lines.append(
{
"check_name": check["name"],
"line_number": i + 1,
"line": line,
"context": context,
}
)
return matched_lines
@staticmethod
async def get_ci_results(
credentials: GithubCredentials,
repo: str,
target: str | int,
search_pattern: Optional[str] = None,
check_name_filter: Optional[str] = None,
) -> dict:
api = get_api(credentials, convert_urls=False)
# Get the commit SHA
commit_sha = await GithubGetCIResultsBlock.get_commit_sha(api, repo, target)
# Get check runs for the commit
check_runs_url = (
f"https://api.github.com/repos/{repo}/commits/{commit_sha}/check-runs"
)
# Get all pages of check runs
all_check_runs = []
page = 1
per_page = 100
while True:
response = await api.get(
check_runs_url, params={"per_page": per_page, "page": page}
)
data = response.json()
check_runs = data.get("check_runs", [])
all_check_runs.extend(check_runs)
if len(check_runs) < per_page:
break
page += 1
# Filter by check name if specified
if check_name_filter:
import fnmatch
filtered_runs = []
for run in all_check_runs:
if fnmatch.fnmatch(run["name"].lower(), check_name_filter.lower()):
filtered_runs.append(run)
all_check_runs = filtered_runs
# Get check run details with logs
detailed_runs = []
for run in all_check_runs:
# Get detailed output including logs
if run.get("output", {}).get("text"):
# Already has output
detailed_run = {
"id": run["id"],
"name": run["name"],
"status": run["status"],
"conclusion": run.get("conclusion"),
"started_at": run.get("started_at"),
"completed_at": run.get("completed_at"),
"html_url": run["html_url"],
"details_url": run.get("details_url"),
"output_title": run.get("output", {}).get("title"),
"output_summary": run.get("output", {}).get("summary"),
"output_text": run.get("output", {}).get("text"),
"annotations": [],
}
else:
# Try to get logs from the check run
detailed_run = {
"id": run["id"],
"name": run["name"],
"status": run["status"],
"conclusion": run.get("conclusion"),
"started_at": run.get("started_at"),
"completed_at": run.get("completed_at"),
"html_url": run["html_url"],
"details_url": run.get("details_url"),
"output_title": run.get("output", {}).get("title"),
"output_summary": run.get("output", {}).get("summary"),
"output_text": None,
"annotations": [],
}
# Get annotations if available
if run.get("output", {}).get("annotations_count", 0) > 0:
annotations_url = f"https://api.github.com/repos/{repo}/check-runs/{run['id']}/annotations"
try:
ann_response = await api.get(annotations_url)
detailed_run["annotations"] = ann_response.json()
except Exception:
pass
detailed_runs.append(detailed_run)
return {
"check_runs": detailed_runs,
"total_count": len(detailed_runs),
}
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
target = int(input_data.target)
except ValueError:
target = input_data.target
result = await self.get_ci_results(
credentials,
input_data.repo,
target,
input_data.search_pattern,
input_data.check_name_filter,
)
check_runs = result["check_runs"]
# Calculate overall status
if not check_runs:
yield "overall_status", "no_checks"
yield "overall_conclusion", "no_checks"
else:
all_completed = all(run["status"] == "completed" for run in check_runs)
if all_completed:
yield "overall_status", "completed"
# Determine overall conclusion
has_failure = any(
run["conclusion"] in ["failure", "timed_out", "action_required"]
for run in check_runs
)
if has_failure:
yield "overall_conclusion", "failure"
else:
yield "overall_conclusion", "success"
else:
yield "overall_status", "pending"
yield "overall_conclusion", "pending"
# Count checks
total = len(check_runs)
passed = sum(1 for run in check_runs if run.get("conclusion") == "success")
failed = sum(
1 for run in check_runs if run.get("conclusion") in ["failure", "timed_out"]
)
yield "total_checks", total
yield "passed_checks", passed
yield "failed_checks", failed
# Output check runs
yield "check_runs", check_runs
# Search for patterns if specified
if input_data.search_pattern:
matched_lines = await self.search_in_logs(
check_runs, input_data.search_pattern
)
if matched_lines:
yield "matched_lines", matched_lines

View File

@@ -1,840 +0,0 @@
import logging
from enum import Enum
from typing import Any, List, Optional
from typing_extensions import TypedDict
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from ._api import get_api
from ._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
GithubCredentials,
GithubCredentialsField,
GithubCredentialsInput,
)
logger = logging.getLogger(__name__)
class ReviewEvent(Enum):
COMMENT = "COMMENT"
APPROVE = "APPROVE"
REQUEST_CHANGES = "REQUEST_CHANGES"
class GithubCreatePRReviewBlock(Block):
class Input(BlockSchema):
class ReviewComment(TypedDict, total=False):
path: str
position: Optional[int]
body: str
line: Optional[int] # Will be used as position if position not provided
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo: str = SchemaField(
description="GitHub repository",
placeholder="owner/repo",
)
pr_number: int = SchemaField(
description="Pull request number",
placeholder="123",
)
body: str = SchemaField(
description="Body of the review comment",
placeholder="Enter your review comment",
)
event: ReviewEvent = SchemaField(
description="The review action to perform",
default=ReviewEvent.COMMENT,
)
create_as_draft: bool = SchemaField(
description="Create the review as a draft (pending) or post it immediately",
default=False,
advanced=False,
)
comments: Optional[List[ReviewComment]] = SchemaField(
description="Optional inline comments to add to specific files/lines. Note: Only path, body, and position are supported. Position is line number in diff from first @@ hunk.",
default=None,
advanced=True,
)
class Output(BlockSchema):
review_id: int = SchemaField(description="ID of the created review")
state: str = SchemaField(
description="State of the review (e.g., PENDING, COMMENTED, APPROVED, CHANGES_REQUESTED)"
)
html_url: str = SchemaField(description="URL of the created review")
error: str = SchemaField(
description="Error message if the review creation failed"
)
def __init__(self):
super().__init__(
id="84754b30-97d2-4c37-a3b8-eb39f268275b",
description="This block creates a review on a GitHub pull request with optional inline comments. You can create it as a draft or post immediately. Note: For inline comments, 'position' should be the line number in the diff (starting from the first @@ hunk header).",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubCreatePRReviewBlock.Input,
output_schema=GithubCreatePRReviewBlock.Output,
test_input={
"repo": "owner/repo",
"pr_number": 1,
"body": "This looks good to me!",
"event": "APPROVE",
"create_as_draft": False,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("review_id", 123456),
("state", "APPROVED"),
(
"html_url",
"https://github.com/owner/repo/pull/1#pullrequestreview-123456",
),
],
test_mock={
"create_review": lambda *args, **kwargs: (
123456,
"APPROVED",
"https://github.com/owner/repo/pull/1#pullrequestreview-123456",
)
},
)
@staticmethod
async def create_review(
credentials: GithubCredentials,
repo: str,
pr_number: int,
body: str,
event: ReviewEvent,
create_as_draft: bool,
comments: Optional[List[Input.ReviewComment]] = None,
) -> tuple[int, str, str]:
api = get_api(credentials, convert_urls=False)
# GitHub API endpoint for creating reviews
reviews_url = f"https://api.github.com/repos/{repo}/pulls/{pr_number}/reviews"
# Get commit_id if we have comments
commit_id = None
if comments:
# Get PR details to get the head commit for inline comments
pr_url = f"https://api.github.com/repos/{repo}/pulls/{pr_number}"
pr_response = await api.get(pr_url)
pr_data = pr_response.json()
commit_id = pr_data["head"]["sha"]
# Prepare the request data
# If create_as_draft is True, omit the event field (creates a PENDING review)
# Otherwise, use the actual event value which will auto-submit the review
data: dict[str, Any] = {"body": body}
# Add commit_id if we have it
if commit_id:
data["commit_id"] = commit_id
# Add comments if provided
if comments:
# Process comments to ensure they have the required fields
processed_comments = []
for comment in comments:
comment_data: dict = {
"path": comment.get("path", ""),
"body": comment.get("body", ""),
}
# Add position or line
# Note: For review comments, only position is supported (not line/side)
if "position" in comment and comment.get("position") is not None:
comment_data["position"] = comment.get("position")
elif "line" in comment and comment.get("line") is not None:
# Note: Using line as position - may not work correctly
# Position should be calculated from the diff
comment_data["position"] = comment.get("line")
# Note: side, start_line, and start_side are NOT supported for review comments
# They are only for standalone PR comments
processed_comments.append(comment_data)
data["comments"] = processed_comments
if not create_as_draft:
# Only add event field if not creating a draft
data["event"] = event.value
# Create the review
response = await api.post(reviews_url, json=data)
review_data = response.json()
return review_data["id"], review_data["state"], review_data["html_url"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
review_id, state, html_url = await self.create_review(
credentials,
input_data.repo,
input_data.pr_number,
input_data.body,
input_data.event,
input_data.create_as_draft,
input_data.comments,
)
yield "review_id", review_id
yield "state", state
yield "html_url", html_url
except Exception as e:
yield "error", str(e)
class GithubListPRReviewsBlock(Block):
class Input(BlockSchema):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo: str = SchemaField(
description="GitHub repository",
placeholder="owner/repo",
)
pr_number: int = SchemaField(
description="Pull request number",
placeholder="123",
)
class Output(BlockSchema):
class ReviewItem(TypedDict):
id: int
user: str
state: str
body: str
html_url: str
review: ReviewItem = SchemaField(
title="Review",
description="Individual review with details",
)
reviews: list[ReviewItem] = SchemaField(
description="List of all reviews on the pull request"
)
error: str = SchemaField(description="Error message if listing reviews failed")
def __init__(self):
super().__init__(
id="f79bc6eb-33c0-4099-9c0f-d664ae1ba4d0",
description="This block lists all reviews for a specified GitHub pull request.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubListPRReviewsBlock.Input,
output_schema=GithubListPRReviewsBlock.Output,
test_input={
"repo": "owner/repo",
"pr_number": 1,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"reviews",
[
{
"id": 123456,
"user": "reviewer1",
"state": "APPROVED",
"body": "Looks good!",
"html_url": "https://github.com/owner/repo/pull/1#pullrequestreview-123456",
}
],
),
(
"review",
{
"id": 123456,
"user": "reviewer1",
"state": "APPROVED",
"body": "Looks good!",
"html_url": "https://github.com/owner/repo/pull/1#pullrequestreview-123456",
},
),
],
test_mock={
"list_reviews": lambda *args, **kwargs: [
{
"id": 123456,
"user": "reviewer1",
"state": "APPROVED",
"body": "Looks good!",
"html_url": "https://github.com/owner/repo/pull/1#pullrequestreview-123456",
}
]
},
)
@staticmethod
async def list_reviews(
credentials: GithubCredentials, repo: str, pr_number: int
) -> list[Output.ReviewItem]:
api = get_api(credentials, convert_urls=False)
# GitHub API endpoint for listing reviews
reviews_url = f"https://api.github.com/repos/{repo}/pulls/{pr_number}/reviews"
response = await api.get(reviews_url)
data = response.json()
reviews: list[GithubListPRReviewsBlock.Output.ReviewItem] = [
{
"id": review["id"],
"user": review["user"]["login"],
"state": review["state"],
"body": review.get("body", ""),
"html_url": review["html_url"],
}
for review in data
]
return reviews
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
reviews = await self.list_reviews(
credentials,
input_data.repo,
input_data.pr_number,
)
yield "reviews", reviews
for review in reviews:
yield "review", review
class GithubSubmitPendingReviewBlock(Block):
class Input(BlockSchema):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo: str = SchemaField(
description="GitHub repository",
placeholder="owner/repo",
)
pr_number: int = SchemaField(
description="Pull request number",
placeholder="123",
)
review_id: int = SchemaField(
description="ID of the pending review to submit",
placeholder="123456",
)
event: ReviewEvent = SchemaField(
description="The review action to perform when submitting",
default=ReviewEvent.COMMENT,
)
class Output(BlockSchema):
state: str = SchemaField(description="State of the submitted review")
html_url: str = SchemaField(description="URL of the submitted review")
error: str = SchemaField(
description="Error message if the review submission failed"
)
def __init__(self):
super().__init__(
id="2e468217-7ca0-4201-9553-36e93eb9357a",
description="This block submits a pending (draft) review on a GitHub pull request.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubSubmitPendingReviewBlock.Input,
output_schema=GithubSubmitPendingReviewBlock.Output,
test_input={
"repo": "owner/repo",
"pr_number": 1,
"review_id": 123456,
"event": "APPROVE",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("state", "APPROVED"),
(
"html_url",
"https://github.com/owner/repo/pull/1#pullrequestreview-123456",
),
],
test_mock={
"submit_review": lambda *args, **kwargs: (
"APPROVED",
"https://github.com/owner/repo/pull/1#pullrequestreview-123456",
)
},
)
@staticmethod
async def submit_review(
credentials: GithubCredentials,
repo: str,
pr_number: int,
review_id: int,
event: ReviewEvent,
) -> tuple[str, str]:
api = get_api(credentials, convert_urls=False)
# GitHub API endpoint for submitting a review
submit_url = f"https://api.github.com/repos/{repo}/pulls/{pr_number}/reviews/{review_id}/events"
data = {"event": event.value}
response = await api.post(submit_url, json=data)
review_data = response.json()
return review_data["state"], review_data["html_url"]
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
state, html_url = await self.submit_review(
credentials,
input_data.repo,
input_data.pr_number,
input_data.review_id,
input_data.event,
)
yield "state", state
yield "html_url", html_url
except Exception as e:
yield "error", str(e)
class GithubResolveReviewDiscussionBlock(Block):
class Input(BlockSchema):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo: str = SchemaField(
description="GitHub repository",
placeholder="owner/repo",
)
pr_number: int = SchemaField(
description="Pull request number",
placeholder="123",
)
comment_id: int = SchemaField(
description="ID of the review comment to resolve/unresolve",
placeholder="123456",
)
resolve: bool = SchemaField(
description="Whether to resolve (true) or unresolve (false) the discussion",
default=True,
)
class Output(BlockSchema):
success: bool = SchemaField(description="Whether the operation was successful")
error: str = SchemaField(description="Error message if the operation failed")
def __init__(self):
super().__init__(
id="b4b8a38c-95ae-4c91-9ef8-c2cffaf2b5d1",
description="This block resolves or unresolves a review discussion thread on a GitHub pull request.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubResolveReviewDiscussionBlock.Input,
output_schema=GithubResolveReviewDiscussionBlock.Output,
test_input={
"repo": "owner/repo",
"pr_number": 1,
"comment_id": 123456,
"resolve": True,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("success", True),
],
test_mock={"resolve_discussion": lambda *args, **kwargs: True},
)
@staticmethod
async def resolve_discussion(
credentials: GithubCredentials,
repo: str,
pr_number: int,
comment_id: int,
resolve: bool,
) -> bool:
api = get_api(credentials, convert_urls=False)
# Extract owner and repo name
parts = repo.split("/")
owner = parts[0]
repo_name = parts[1]
# GitHub GraphQL API is needed for resolving/unresolving discussions
# First, we need to find the review thread that contains the comment
graphql_url = "https://api.github.com/graphql"
# Query to list the PR's review threads and their comment IDs
query = """
query($owner: String!, $repo: String!, $number: Int!) {
repository(owner: $owner, name: $repo) {
pullRequest(number: $number) {
reviewThreads(first: 100) {
nodes {
comments(first: 100) {
nodes {
databaseId
id
}
}
id
isResolved
}
}
}
}
}
"""
variables = {"owner": owner, "repo": repo_name, "number": pr_number}
response = await api.post(
graphql_url, json={"query": query, "variables": variables}
)
data = response.json()
# Find the thread containing our comment
thread_id = None
for thread in data["data"]["repository"]["pullRequest"]["reviewThreads"][
"nodes"
]:
for comment in thread["comments"]["nodes"]:
if comment["databaseId"] == comment_id:
thread_id = thread["id"]
break
if thread_id:
break
if not thread_id:
raise ValueError(f"Comment {comment_id} not found in pull request")
# Now resolve or unresolve the thread
# GitHub's GraphQL API has separate mutations for resolve and unresolve
if resolve:
mutation = """
mutation($threadId: ID!) {
resolveReviewThread(input: {threadId: $threadId}) {
thread {
isResolved
}
}
}
"""
else:
mutation = """
mutation($threadId: ID!) {
unresolveReviewThread(input: {threadId: $threadId}) {
thread {
isResolved
}
}
}
"""
mutation_variables = {"threadId": thread_id}
response = await api.post(
graphql_url, json={"query": mutation, "variables": mutation_variables}
)
result = response.json()
if "errors" in result:
raise Exception(f"GraphQL error: {result['errors']}")
return True
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
success = await self.resolve_discussion(
credentials,
input_data.repo,
input_data.pr_number,
input_data.comment_id,
input_data.resolve,
)
yield "success", success
except Exception as e:
yield "success", False
yield "error", str(e)
class GithubGetPRReviewCommentsBlock(Block):
class Input(BlockSchema):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
repo: str = SchemaField(
description="GitHub repository",
placeholder="owner/repo",
)
pr_number: int = SchemaField(
description="Pull request number",
placeholder="123",
)
review_id: Optional[int] = SchemaField(
description="ID of a specific review to get comments from (optional)",
placeholder="123456",
default=None,
advanced=True,
)
class Output(BlockSchema):
class CommentItem(TypedDict):
id: int
user: str
body: str
path: str
line: int
side: str
created_at: str
updated_at: str
in_reply_to_id: Optional[int]
html_url: str
comment: CommentItem = SchemaField(
title="Comment",
description="Individual review comment with details",
)
comments: list[CommentItem] = SchemaField(
description="List of all review comments on the pull request"
)
error: str = SchemaField(description="Error message if getting comments failed")
def __init__(self):
super().__init__(
id="1d34db7f-10c1-45c1-9d43-749f743c8bd4",
description="This block gets all review comments from a GitHub pull request or from a specific review.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubGetPRReviewCommentsBlock.Input,
output_schema=GithubGetPRReviewCommentsBlock.Output,
test_input={
"repo": "owner/repo",
"pr_number": 1,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"comments",
[
{
"id": 123456,
"user": "reviewer1",
"body": "This needs improvement",
"path": "src/main.py",
"line": 42,
"side": "RIGHT",
"created_at": "2024-01-01T00:00:00Z",
"updated_at": "2024-01-01T00:00:00Z",
"in_reply_to_id": None,
"html_url": "https://github.com/owner/repo/pull/1#discussion_r123456",
}
],
),
(
"comment",
{
"id": 123456,
"user": "reviewer1",
"body": "This needs improvement",
"path": "src/main.py",
"line": 42,
"side": "RIGHT",
"created_at": "2024-01-01T00:00:00Z",
"updated_at": "2024-01-01T00:00:00Z",
"in_reply_to_id": None,
"html_url": "https://github.com/owner/repo/pull/1#discussion_r123456",
},
),
],
test_mock={
"get_comments": lambda *args, **kwargs: [
{
"id": 123456,
"user": "reviewer1",
"body": "This needs improvement",
"path": "src/main.py",
"line": 42,
"side": "RIGHT",
"created_at": "2024-01-01T00:00:00Z",
"updated_at": "2024-01-01T00:00:00Z",
"in_reply_to_id": None,
"html_url": "https://github.com/owner/repo/pull/1#discussion_r123456",
}
]
},
)
@staticmethod
async def get_comments(
credentials: GithubCredentials,
repo: str,
pr_number: int,
review_id: Optional[int] = None,
) -> list[Output.CommentItem]:
api = get_api(credentials, convert_urls=False)
# Determine the endpoint based on whether we want comments from a specific review
if review_id:
# Get comments from a specific review
comments_url = f"https://api.github.com/repos/{repo}/pulls/{pr_number}/reviews/{review_id}/comments"
else:
# Get all review comments on the PR
comments_url = (
f"https://api.github.com/repos/{repo}/pulls/{pr_number}/comments"
)
response = await api.get(comments_url)
data = response.json()
comments: list[GithubGetPRReviewCommentsBlock.Output.CommentItem] = [
{
"id": comment["id"],
"user": comment["user"]["login"],
"body": comment["body"],
"path": comment.get("path", ""),
"line": comment.get("line", 0),
"side": comment.get("side", ""),
"created_at": comment["created_at"],
"updated_at": comment["updated_at"],
"in_reply_to_id": comment.get("in_reply_to_id"),
"html_url": comment["html_url"],
}
for comment in data
]
return comments
async def run(
self,
input_data: Input,
*,
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
comments = await self.get_comments(
credentials,
input_data.repo,
input_data.pr_number,
input_data.review_id,
)
yield "comments", comments
for comment in comments:
yield "comment", comment
except Exception as e:
yield "error", str(e)
class GithubCreateCommentObjectBlock(Block):
class Input(BlockSchema):
path: str = SchemaField(
description="The file path to comment on",
placeholder="src/main.py",
)
body: str = SchemaField(
description="The comment text",
placeholder="Please fix this issue",
)
position: Optional[int] = SchemaField(
description="Position in the diff (line number from first @@ hunk). Use this OR line.",
placeholder="6",
default=None,
advanced=True,
)
line: Optional[int] = SchemaField(
description="Line number in the file (will be used as position if position not provided)",
placeholder="42",
default=None,
advanced=True,
)
side: Optional[str] = SchemaField(
description="Side of the diff to comment on (NOTE: Only for standalone comments, not review comments)",
default="RIGHT",
advanced=True,
)
start_line: Optional[int] = SchemaField(
description="Start line for multi-line comments (NOTE: Only for standalone comments, not review comments)",
default=None,
advanced=True,
)
start_side: Optional[str] = SchemaField(
description="Side for the start of multi-line comments (NOTE: Only for standalone comments, not review comments)",
default=None,
advanced=True,
)
class Output(BlockSchema):
comment_object: dict = SchemaField(
description="The comment object formatted for GitHub API"
)
def __init__(self):
super().__init__(
id="b7d5e4f2-8c3a-4e6b-9f1d-7a8b9c5e4d3f",
description="Creates a comment object for use with GitHub blocks. Note: For review comments, only path, body, and position are used. Side fields are only for standalone PR comments.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=GithubCreateCommentObjectBlock.Input,
output_schema=GithubCreateCommentObjectBlock.Output,
test_input={
"path": "src/main.py",
"body": "Please fix this issue",
"position": 6,
},
test_output=[
(
"comment_object",
{
"path": "src/main.py",
"body": "Please fix this issue",
"position": 6,
},
),
],
)
async def run(
self,
input_data: Input,
**kwargs,
) -> BlockOutput:
# Build the comment object
comment_obj: dict = {
"path": input_data.path,
"body": input_data.body,
}
# Add position or line
if input_data.position is not None:
comment_obj["position"] = input_data.position
elif input_data.line is not None:
# Note: line will be used as position, which may not be accurate
# Position should be calculated from the diff
comment_obj["position"] = input_data.line
# Add optional fields only if they differ from defaults or are explicitly provided
if input_data.side and input_data.side != "RIGHT":
comment_obj["side"] = input_data.side
if input_data.start_line is not None:
comment_obj["start_line"] = input_data.start_line
if input_data.start_side:
comment_obj["start_side"] = input_data.start_side
yield "comment_object", comment_obj

View File

@@ -10,7 +10,7 @@ from pydantic import BaseModel
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.settings import Settings
from backend.util.settings import AppEnvironment, Settings
from ._auth import (
GOOGLE_OAUTH_IS_CONFIGURED,
@@ -21,8 +21,6 @@ from ._auth import (
GoogleCredentialsInput,
)
settings = Settings()
class CalendarEvent(BaseModel):
"""Structured representation of a Google Calendar event."""
@@ -90,6 +88,8 @@ class GoogleCalendarReadEventsBlock(Block):
)
def __init__(self):
settings = Settings()
# Create realistic test data for events
test_now = datetime.now(tz=timezone.utc)
test_tomorrow = test_now + timedelta(days=1)
@@ -116,7 +116,8 @@ class GoogleCalendarReadEventsBlock(Block):
categories={BlockCategory.PRODUCTIVITY, BlockCategory.DATA},
input_schema=GoogleCalendarReadEventsBlock.Input,
output_schema=GoogleCalendarReadEventsBlock.Output,
disabled=not GOOGLE_OAUTH_IS_CONFIGURED,
disabled=not GOOGLE_OAUTH_IS_CONFIGURED
or settings.config.app_env == AppEnvironment.PRODUCTION,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"calendar_id": "primary",
@@ -223,8 +224,8 @@ class GoogleCalendarReadEventsBlock(Block):
else None
),
token_uri="https://oauth2.googleapis.com/token",
client_id=settings.secrets.google_client_id,
client_secret=settings.secrets.google_client_secret,
client_id=Settings().secrets.google_client_id,
client_secret=Settings().secrets.google_client_secret,
scopes=credentials.scopes,
)
return build("calendar", "v3", credentials=creds)
@@ -441,13 +442,16 @@ class GoogleCalendarCreateEventBlock(Block):
error: str = SchemaField(description="Error message if event creation failed")
def __init__(self):
settings = Settings()
super().__init__(
id="ed2ec950-fbff-4204-94c0-023fb1d625e0",
description="This block creates a new event in Google Calendar with customizable parameters.",
categories={BlockCategory.PRODUCTIVITY},
input_schema=GoogleCalendarCreateEventBlock.Input,
output_schema=GoogleCalendarCreateEventBlock.Output,
disabled=not GOOGLE_OAUTH_IS_CONFIGURED,
disabled=not GOOGLE_OAUTH_IS_CONFIGURED
or settings.config.app_env == AppEnvironment.PRODUCTION,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"event_title": "Team Meeting",
@@ -571,8 +575,8 @@ class GoogleCalendarCreateEventBlock(Block):
else None
),
token_uri="https://oauth2.googleapis.com/token",
client_id=settings.secrets.google_client_id,
client_secret=settings.secrets.google_client_secret,
client_id=Settings().secrets.google_client_id,
client_secret=Settings().secrets.google_client_secret,
scopes=credentials.scopes,
)
return build("calendar", "v3", credentials=creds)

File diff suppressed because it is too large.

View File

@@ -7,7 +7,7 @@ from googleapiclient.discovery import build
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.settings import Settings
from backend.util.settings import AppEnvironment, Settings
from ._auth import (
GOOGLE_OAUTH_IS_CONFIGURED,
@@ -19,7 +19,10 @@ from ._auth import (
)
settings = Settings()
GOOGLE_SHEETS_DISABLED = not GOOGLE_OAUTH_IS_CONFIGURED
GOOGLE_SHEETS_DISABLED = (
not GOOGLE_OAUTH_IS_CONFIGURED
or settings.config.app_env == AppEnvironment.PRODUCTION
)
def parse_a1_notation(a1: str) -> tuple[str | None, str]:

View File

@@ -113,7 +113,6 @@ class SendWebRequestBlock(Block):
graph_exec_id: str,
files_name: str,
files: list[MediaFileType],
user_id: str,
) -> list[tuple[str, tuple[str, BytesIO, str]]]:
"""
Prepare files for the request by storing them and reading their content.
@@ -125,7 +124,7 @@ class SendWebRequestBlock(Block):
for media in files:
# Normalise to a list so we can repeat the same key
rel_path = await store_media_file(
graph_exec_id, media, user_id, return_content=False
graph_exec_id, media, return_content=False
)
abs_path = get_exec_file_path(graph_exec_id, rel_path)
async with aiofiles.open(abs_path, "rb") as f:
@@ -137,7 +136,7 @@ class SendWebRequestBlock(Block):
return files_payload
async def run(
self, input_data: Input, *, graph_exec_id: str, user_id: str, **kwargs
self, input_data: Input, *, graph_exec_id: str, **kwargs
) -> BlockOutput:
# ─── Parse/normalise body ────────────────────────────────────
body = input_data.body
@@ -168,7 +167,7 @@ class SendWebRequestBlock(Block):
files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
if use_files:
files_payload = await self._prepare_files(
graph_exec_id, input_data.files_name, input_data.files, user_id
graph_exec_id, input_data.files_name, input_data.files
)
# Enforce body format rules
@@ -228,7 +227,6 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
*,
graph_exec_id: str,
credentials: HostScopedCredentials,
user_id: str,
**kwargs,
) -> BlockOutput:
# Create SendWebRequestBlock.Input from our input (removing credentials field)
@@ -259,6 +257,6 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
# Use parent class run method
async for output_name, output_data in super().run(
base_input, graph_exec_id=graph_exec_id, user_id=user_id, **kwargs
base_input, graph_exec_id=graph_exec_id, **kwargs
):
yield output_name, output_data

View File

@@ -447,7 +447,6 @@ class AgentFileInputBlock(AgentInputBlock):
input_data: Input,
*,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
if not input_data.value:
@@ -456,7 +455,6 @@ class AgentFileInputBlock(AgentInputBlock):
yield "result", await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.value,
user_id=user_id,
return_content=input_data.base_64,
)

View File

@@ -2,6 +2,7 @@
Shared configuration for all Linear blocks using the new SDK pattern.
"""
import os
from enum import Enum
from backend.sdk import (
@@ -37,11 +38,21 @@ class LinearScope(str, Enum):
ADMIN = "admin"
linear = (
# Check if Linear OAuth is configured
client_id = os.getenv("LINEAR_CLIENT_ID")
client_secret = os.getenv("LINEAR_CLIENT_SECRET")
LINEAR_OAUTH_IS_CONFIGURED = bool(client_id and client_secret)
# Build the Linear provider
builder = (
ProviderBuilder("linear")
.with_api_key(env_var_name="LINEAR_API_KEY", title="Linear API Key")
.with_base_cost(1, BlockCostType.RUN)
.with_oauth(
)
# Linear only supports OAuth authentication
if LINEAR_OAUTH_IS_CONFIGURED:
builder = builder.with_oauth(
LinearOAuthHandler,
scopes=[
LinearScope.READ,
@@ -52,8 +63,9 @@ linear = (
client_id_env_var="LINEAR_CLIENT_ID",
client_secret_env_var="LINEAR_CLIENT_SECRET",
)
.build()
)
# Build the provider
linear = builder.build()
TEST_CREDENTIALS_OAUTH = OAuth2Credentials(

View File

@@ -11,6 +11,7 @@ from backend.sdk import (
from ._api import LinearAPIException, LinearClient
from ._config import (
LINEAR_OAUTH_IS_CONFIGURED,
TEST_CREDENTIALS_INPUT_OAUTH,
TEST_CREDENTIALS_OAUTH,
LinearScope,
@@ -49,6 +50,7 @@ class LinearCreateCommentBlock(Block):
"comment": "Test comment",
"credentials": TEST_CREDENTIALS_INPUT_OAUTH,
},
disabled=not LINEAR_OAUTH_IS_CONFIGURED,
test_credentials=TEST_CREDENTIALS_OAUTH,
test_output=[("comment_id", "abc123"), ("comment_body", "Test comment")],
test_mock={

View File

@@ -11,6 +11,7 @@ from backend.sdk import (
from ._api import LinearAPIException, LinearClient
from ._config import (
LINEAR_OAUTH_IS_CONFIGURED,
TEST_CREDENTIALS_INPUT_OAUTH,
TEST_CREDENTIALS_OAUTH,
LinearScope,
@@ -52,6 +53,7 @@ class LinearCreateIssueBlock(Block):
super().__init__(
id="f9c68f55-dcca-40a8-8771-abf9601680aa",
description="Creates a new issue on Linear",
disabled=not LINEAR_OAUTH_IS_CONFIGURED,
input_schema=self.Input,
output_schema=self.Output,
categories={BlockCategory.PRODUCTIVITY, BlockCategory.ISSUE_TRACKING},
@@ -145,6 +147,7 @@ class LinearSearchIssuesBlock(Block):
description="Searches for issues on Linear",
input_schema=self.Input,
output_schema=self.Output,
disabled=not LINEAR_OAUTH_IS_CONFIGURED,
test_input={
"term": "Test issue",
"credentials": TEST_CREDENTIALS_INPUT_OAUTH,

View File

@@ -11,6 +11,7 @@ from backend.sdk import (
from ._api import LinearAPIException, LinearClient
from ._config import (
LINEAR_OAUTH_IS_CONFIGURED,
TEST_CREDENTIALS_INPUT_OAUTH,
TEST_CREDENTIALS_OAUTH,
LinearScope,
@@ -44,6 +45,7 @@ class LinearSearchProjectsBlock(Block):
"term": "Test project",
"credentials": TEST_CREDENTIALS_INPUT_OAUTH,
},
disabled=not LINEAR_OAUTH_IS_CONFIGURED,
test_credentials=TEST_CREDENTIALS_OAUTH,
test_output=[
(

View File

@@ -37,7 +37,6 @@ LLMProviderName = Literal[
ProviderName.OPENAI,
ProviderName.OPEN_ROUTER,
ProviderName.LLAMA_API,
ProviderName.V0,
]
AICredentials = CredentialsMetaInput[LLMProviderName, Literal["api_key"]]
@@ -81,20 +80,14 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
O3_MINI = "o3-mini"
O3 = "o3-2025-04-16"
O1 = "o1"
O1_PREVIEW = "o1-preview"
O1_MINI = "o1-mini"
# GPT-5 models
GPT5 = "gpt-5-2025-08-07"
GPT5_MINI = "gpt-5-mini-2025-08-07"
GPT5_NANO = "gpt-5-nano-2025-08-07"
GPT5_CHAT = "gpt-5-chat-latest"
GPT41 = "gpt-4.1-2025-04-14"
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
GPT4O_MINI = "gpt-4o-mini"
GPT4O = "gpt-4o"
GPT4_TURBO = "gpt-4-turbo"
GPT3_5_TURBO = "gpt-3.5-turbo"
# Anthropic models
CLAUDE_4_1_OPUS = "claude-opus-4-1-20250805"
CLAUDE_4_OPUS = "claude-opus-4-20250514"
CLAUDE_4_SONNET = "claude-sonnet-4-20250514"
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
@@ -113,6 +106,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
LLAMA3_1_8B = "llama-3.1-8b-instant"
LLAMA3_70B = "llama3-70b-8192"
LLAMA3_8B = "llama3-8b-8192"
MIXTRAL_8X7B = "mixtral-8x7b-32768"
# Groq preview models
DEEPSEEK_LLAMA_70B = "deepseek-r1-distill-llama-70b"
# Ollama models
@@ -122,22 +116,21 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
OLLAMA_LLAMA3_405B = "llama3.1:405b"
OLLAMA_DOLPHIN = "dolphin-mistral:latest"
# OpenRouter models
OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
OPENAI_GPT_OSS_20B = "openai/gpt-oss-20b"
GEMINI_FLASH_1_5 = "google/gemini-flash-1.5"
GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
GEMINI_2_0_FLASH = "google/gemini-2.0-flash-001"
GEMINI_2_5_FLASH_LITE_PREVIEW = "google/gemini-2.5-flash-lite-preview-06-17"
GEMINI_2_0_FLASH_LITE = "google/gemini-2.0-flash-lite-001"
GROK_BETA = "x-ai/grok-beta"
MISTRAL_NEMO = "mistralai/mistral-nemo"
COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
COHERE_COMMAND_R_PLUS_08_2024 = "cohere/command-r-plus-08-2024"
EVA_QWEN_2_5_32B = "eva-unit-01/eva-qwen-2.5-32b"
DEEPSEEK_CHAT = "deepseek/deepseek-chat" # Actually: DeepSeek V3
DEEPSEEK_R1_0528 = "deepseek/deepseek-r1-0528"
PERPLEXITY_LLAMA_3_1_SONAR_LARGE_128K_ONLINE = (
"perplexity/llama-3.1-sonar-large-128k-online"
)
PERPLEXITY_SONAR = "perplexity/sonar"
PERPLEXITY_SONAR_PRO = "perplexity/sonar-pro"
PERPLEXITY_SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
QWEN_QWQ_32B_PREVIEW = "qwen/qwq-32b-preview"
NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
AMAZON_NOVA_LITE_V1 = "amazon/nova-lite-v1"
@@ -147,19 +140,11 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
META_LLAMA_4_SCOUT = "meta-llama/llama-4-scout"
META_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick"
GROK_4 = "x-ai/grok-4"
KIMI_K2 = "moonshotai/kimi-k2"
QWEN3_235B_A22B_THINKING = "qwen/qwen3-235b-a22b-thinking-2507"
QWEN3_CODER = "qwen/qwen3-coder"
# Llama API models
LLAMA_API_LLAMA_4_SCOUT = "Llama-4-Scout-17B-16E-Instruct-FP8"
LLAMA_API_LLAMA4_MAVERICK = "Llama-4-Maverick-17B-128E-Instruct-FP8"
LLAMA_API_LLAMA3_3_8B = "Llama-3.3-8B-Instruct"
LLAMA_API_LLAMA3_3_70B = "Llama-3.3-70B-Instruct"
# v0 by Vercel models
V0_1_5_MD = "v0-1.5-md"
V0_1_5_LG = "v0-1.5-lg"
V0_1_0_MD = "v0-1.0-md"
@property
def metadata(self) -> ModelMetadata:
@@ -183,14 +168,11 @@ MODEL_METADATA = {
LlmModel.O3: ModelMetadata("openai", 200000, 100000),
LlmModel.O3_MINI: ModelMetadata("openai", 200000, 100000), # o3-mini-2025-01-31
LlmModel.O1: ModelMetadata("openai", 200000, 100000), # o1-2024-12-17
LlmModel.O1_PREVIEW: ModelMetadata(
"openai", 128000, 32768
), # o1-preview-2024-09-12
LlmModel.O1_MINI: ModelMetadata("openai", 128000, 65536), # o1-mini-2024-09-12
# GPT-5 models
LlmModel.GPT5: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_MINI: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_NANO: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_CHAT: ModelMetadata("openai", 400000, 16384),
LlmModel.GPT41: ModelMetadata("openai", 1047576, 32768),
LlmModel.GPT41_MINI: ModelMetadata("openai", 1047576, 32768),
LlmModel.GPT4O_MINI: ModelMetadata(
"openai", 128000, 16384
), # gpt-4o-mini-2024-07-18
@@ -200,9 +182,6 @@ MODEL_METADATA = {
), # gpt-4-turbo-2024-04-09
LlmModel.GPT3_5_TURBO: ModelMetadata("openai", 16385, 4096), # gpt-3.5-turbo-0125
# https://docs.anthropic.com/en/docs/about-claude/models
LlmModel.CLAUDE_4_1_OPUS: ModelMetadata(
"anthropic", 200000, 32000
), # claude-opus-4-1-20250805
LlmModel.CLAUDE_4_OPUS: ModelMetadata(
"anthropic", 200000, 8192
), # claude-4-opus-20250514
@@ -233,6 +212,7 @@ MODEL_METADATA = {
LlmModel.LLAMA3_1_8B: ModelMetadata("groq", 128000, 8192),
LlmModel.LLAMA3_70B: ModelMetadata("groq", 8192, None),
LlmModel.LLAMA3_8B: ModelMetadata("groq", 8192, None),
LlmModel.MIXTRAL_8X7B: ModelMetadata("groq", 32768, None),
LlmModel.DEEPSEEK_LLAMA_70B: ModelMetadata("groq", 128000, None),
# https://ollama.com/library
LlmModel.OLLAMA_LLAMA3_3: ModelMetadata("ollama", 8192, None),
@@ -243,17 +223,15 @@ MODEL_METADATA = {
# https://openrouter.ai/models
LlmModel.GEMINI_FLASH_1_5: ModelMetadata("open_router", 1000000, 8192),
LlmModel.GEMINI_2_5_PRO: ModelMetadata("open_router", 1050000, 8192),
LlmModel.GEMINI_2_5_FLASH: ModelMetadata("open_router", 1048576, 65535),
LlmModel.GEMINI_2_0_FLASH: ModelMetadata("open_router", 1048576, 8192),
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: ModelMetadata(
"open_router", 1048576, 65535
),
LlmModel.GEMINI_2_0_FLASH_LITE: ModelMetadata("open_router", 1048576, 8192),
LlmModel.GROK_BETA: ModelMetadata("open_router", 131072, 131072),
LlmModel.MISTRAL_NEMO: ModelMetadata("open_router", 128000, 4096),
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata("open_router", 128000, 4096),
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: ModelMetadata("open_router", 128000, 4096),
LlmModel.EVA_QWEN_2_5_32B: ModelMetadata("open_router", 16384, 4096),
LlmModel.DEEPSEEK_CHAT: ModelMetadata("open_router", 64000, 2048),
LlmModel.DEEPSEEK_R1_0528: ModelMetadata("open_router", 163840, 163840),
LlmModel.PERPLEXITY_LLAMA_3_1_SONAR_LARGE_128K_ONLINE: ModelMetadata(
"open_router", 127072, 127072
),
LlmModel.PERPLEXITY_SONAR: ModelMetadata("open_router", 127000, 127000),
LlmModel.PERPLEXITY_SONAR_PRO: ModelMetadata("open_router", 200000, 8000),
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: ModelMetadata(
@@ -261,14 +239,13 @@ MODEL_METADATA = {
128000,
128000,
),
LlmModel.QWEN_QWQ_32B_PREVIEW: ModelMetadata("open_router", 32768, 32768),
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: ModelMetadata(
"open_router", 131000, 4096
),
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: ModelMetadata(
"open_router", 12288, 12288
),
LlmModel.OPENAI_GPT_OSS_120B: ModelMetadata("open_router", 131072, 131072),
LlmModel.OPENAI_GPT_OSS_20B: ModelMetadata("open_router", 131072, 32768),
LlmModel.AMAZON_NOVA_LITE_V1: ModelMetadata("open_router", 300000, 5120),
LlmModel.AMAZON_NOVA_MICRO_V1: ModelMetadata("open_router", 128000, 5120),
LlmModel.AMAZON_NOVA_PRO_V1: ModelMetadata("open_router", 300000, 5120),
@@ -276,19 +253,11 @@ MODEL_METADATA = {
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata("open_router", 4096, 4096),
LlmModel.META_LLAMA_4_SCOUT: ModelMetadata("open_router", 131072, 131072),
LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata("open_router", 1048576, 1000000),
LlmModel.GROK_4: ModelMetadata("open_router", 256000, 256000),
LlmModel.KIMI_K2: ModelMetadata("open_router", 131000, 131000),
LlmModel.QWEN3_235B_A22B_THINKING: ModelMetadata("open_router", 262144, 262144),
LlmModel.QWEN3_CODER: ModelMetadata("open_router", 262144, 262144),
# Llama API models
LlmModel.LLAMA_API_LLAMA_4_SCOUT: ModelMetadata("llama_api", 128000, 4028),
LlmModel.LLAMA_API_LLAMA4_MAVERICK: ModelMetadata("llama_api", 128000, 4028),
LlmModel.LLAMA_API_LLAMA3_3_8B: ModelMetadata("llama_api", 128000, 4028),
LlmModel.LLAMA_API_LLAMA3_3_70B: ModelMetadata("llama_api", 128000, 4028),
# v0 by Vercel models
LlmModel.V0_1_5_MD: ModelMetadata("v0", 128000, 64000),
LlmModel.V0_1_5_LG: ModelMetadata("v0", 512000, 64000),
LlmModel.V0_1_0_MD: ModelMetadata("v0", 128000, 64000),
}
for model in LlmModel:
@@ -502,7 +471,6 @@ async def llm_call(
messages=messages,
max_tokens=max_tokens,
tools=an_tools,
timeout=600,
)
if not resp.content:
@@ -685,11 +653,7 @@ async def llm_call(
client = openai.OpenAI(
base_url="https://api.aimlapi.com/v2",
api_key=credentials.api_key.get_secret_value(),
default_headers={
"X-Project": "AutoGPT",
"X-Title": "AutoGPT",
"HTTP-Referer": "https://github.com/Significant-Gravitas/AutoGPT",
},
default_headers={"X-Project": "AutoGPT"},
)
completion = client.chat.completions.create(
@@ -709,42 +673,6 @@ async def llm_call(
),
reasoning=None,
)
elif provider == "v0":
tools_param = tools if tools else openai.NOT_GIVEN
client = openai.AsyncOpenAI(
base_url="https://api.v0.dev/v1",
api_key=credentials.api_key.get_secret_value(),
)
response_format = None
if json_format:
response_format = {"type": "json_object"}
parallel_tool_calls_param = get_parallel_tool_calls_param(
llm_model, parallel_tool_calls
)
response = await client.chat.completions.create(
model=llm_model.value,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_tokens=max_tokens,
tools=tools_param, # type: ignore
parallel_tool_calls=parallel_tool_calls_param,
)
tool_calls = extract_openai_tool_calls(response)
reasoning = extract_openai_reasoning(response)
return LLMResponse(
raw_response=response.choices[0].message,
prompt=prompt,
response=response.choices[0].message.content or "",
tool_calls=tool_calls,
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
completion_tokens=response.usage.completion_tokens if response.usage else 0,
reasoning=reasoning,
)
else:
raise ValueError(f"Unsupported LLM provider: {provider}")
@@ -992,22 +920,10 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
)
if not response_error:
self.merge_stats(
NodeExecutionStats(
llm_call_count=retry_count + 1,
llm_retry_count=retry_count,
)
)
yield "response", response_obj
yield "prompt", self.prompt
return
else:
self.merge_stats(
NodeExecutionStats(
llm_call_count=retry_count + 1,
llm_retry_count=retry_count,
)
)
yield "response", {"response": response_text}
yield "prompt", self.prompt
return
@@ -1039,6 +955,13 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
f"Reducing max_tokens to {input_data.max_tokens} for next attempt"
)
retry_prompt = f"Error calling LLM: {e}"
finally:
self.merge_stats(
NodeExecutionStats(
llm_call_count=retry_count + 1,
llm_retry_count=retry_count,
)
)
raise RuntimeError(retry_prompt)
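The hunk above moves the stats merge into a `finally` clause, so the per-attempt call counts are recorded whether an attempt succeeds, fails, or raises on its way out of the loop. A stripped-down sketch of that control flow; `call`, `Stats`, and the bookkeeping fields here are simplified placeholders rather than the production API:

# Sketch: a retry loop whose bookkeeping runs after every attempt.
class Stats:
    def __init__(self) -> None:
        self.llm_call_count = 0
        self.llm_retry_count = 0

def call_with_retries(call, prompt: str, stats: Stats, max_retries: int = 3) -> str:
    retry_prompt = prompt
    for retry_count in range(max_retries):
        try:
            return call(retry_prompt)  # finally still runs before this returns
        except Exception as e:
            retry_prompt = f"Error calling LLM: {e}"  # feed the error into the next attempt
        finally:
            # Runs on success and on failure alike, mirroring the diff above.
            stats.llm_call_count = retry_count + 1
            stats.llm_retry_count = retry_count
    raise RuntimeError(retry_prompt)

The same pattern removes the need for the duplicated merge_stats calls on the two early-return paths shown earlier in the diff.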

View File

@@ -44,14 +44,12 @@ class MediaDurationBlock(Block):
input_data: Input,
*,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
# 1) Store the input media locally
local_media_path = await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.media_in,
user_id=user_id,
return_content=False,
)
media_abspath = get_exec_file_path(graph_exec_id, local_media_path)
@@ -113,14 +111,12 @@ class LoopVideoBlock(Block):
*,
node_exec_id: str,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
# 1) Store the input video locally
local_video_path = await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.video_in,
user_id=user_id,
return_content=False,
)
input_abspath = get_exec_file_path(graph_exec_id, local_video_path)
@@ -153,7 +149,6 @@ class LoopVideoBlock(Block):
video_out = await store_media_file(
graph_exec_id=graph_exec_id,
file=output_filename,
user_id=user_id,
return_content=input_data.output_return_type == "data_uri",
)
@@ -205,20 +200,17 @@ class AddAudioToVideoBlock(Block):
*,
node_exec_id: str,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
# 1) Store the inputs locally
local_video_path = await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.video_in,
user_id=user_id,
return_content=False,
)
local_audio_path = await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.audio_in,
user_id=user_id,
return_content=False,
)
@@ -247,7 +239,6 @@ class AddAudioToVideoBlock(Block):
video_out = await store_media_file(
graph_exec_id=graph_exec_id,
file=output_filename,
user_id=user_id,
return_content=input_data.output_return_type == "data_uri",
)

View File

@@ -1,13 +1,22 @@
import logging
from typing import Any, Literal
from autogpt_libs.utils.cache import thread_cached
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.clients import get_database_manager_async_client
logger = logging.getLogger(__name__)
@thread_cached
def get_database_manager_client():
from backend.executor import DatabaseManagerAsyncClient
from backend.util.service import get_service_client
return get_service_client(DatabaseManagerAsyncClient, health_check=False)
StorageScope = Literal["within_agent", "across_agents"]
@@ -79,7 +88,7 @@ class PersistInformationBlock(Block):
async def _store_data(
self, user_id: str, node_exec_id: str, key: str, data: Any
) -> Any | None:
return await get_database_manager_async_client().set_execution_kv_data(
return await get_database_manager_client().set_execution_kv_data(
user_id=user_id,
node_exec_id=node_exec_id,
key=key,
@@ -140,7 +149,7 @@ class RetrieveInformationBlock(Block):
yield "value", input_data.default_value
async def _retrieve_data(self, user_id: str, key: str) -> Any | None:
return await get_database_manager_async_client().get_execution_kv_data(
return await get_database_manager_client().get_execution_kv_data(
user_id=user_id,
key=key,
)

View File

@@ -1,24 +0,0 @@
from typing import Literal
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, CredentialsMetaInput, ProviderName
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="replicate",
api_key=SecretStr("mock-replicate-api-key"),
title="Mock Replicate API key",
expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.type,
}
ReplicateCredentials = APIKeyCredentials
ReplicateCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.REPLICATE], Literal["api_key"]
]

View File

@@ -1,39 +0,0 @@
import logging
from replicate.helpers import FileOutput
logger = logging.getLogger(__name__)
ReplicateOutputs = FileOutput | list[FileOutput] | list[str] | str | list[dict]
def extract_result(output: ReplicateOutputs) -> str:
result = (
"Unable to process result. Please contact us with the models and inputs used"
)
# Check if output is a list or a string and extract accordingly; otherwise, assign a default message
if isinstance(output, list) and len(output) > 0:
# we could verify every element with isinstance, but that would be slower, so we just type: ignore
if isinstance(output[0], FileOutput):
result = output[0].url # If output is a list, get the first element
elif isinstance(output[0], str):
result = "".join(
output # type: ignore we're already not a file output here
) # type:ignore If output is a list and a str, join the elements the first element. Happens if its text
elif isinstance(output[0], dict):
result = str(output[0])
else:
logger.error(
"Replicate generated a new output type that's not a file output or a str in a replicate block"
)
elif isinstance(output, FileOutput):
result = output.url # If output is a FileOutput, use the url
elif isinstance(output, str):
result = output # If output is a string (for some reason due to their janky type hinting), use it directly
else:
result = "No output received" # Fallback message if output is not as expected
logger.error(
"We somehow didn't get an output from a replicate block. This is almost certainly an error"
)
return result

View File

@@ -1,133 +0,0 @@
import logging
from typing import Optional
from pydantic import SecretStr
from replicate.client import Client as ReplicateClient
from backend.blocks.replicate._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
ReplicateCredentialsInput,
)
from backend.blocks.replicate._helper import ReplicateOutputs, extract_result
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
logger = logging.getLogger(__name__)
class ReplicateModelBlock(Block):
"""
Block for running any Replicate model with custom inputs.
This block allows you to:
- Use any public Replicate model
- Pass custom inputs as a dictionary
- Specify model versions
- Get structured outputs with prediction metadata
"""
class Input(BlockSchema):
credentials: ReplicateCredentialsInput = CredentialsField(
description="Enter your Replicate API key to access the model API. You can obtain an API key from https://replicate.com/account/api-tokens.",
)
model_name: str = SchemaField(
description="The Replicate model name (format: 'owner/model-name')",
placeholder="stability-ai/stable-diffusion",
advanced=False,
)
model_inputs: dict[str, str | int] = SchemaField(
default={},
description="Dictionary of inputs to pass to the model",
placeholder='{"prompt": "a beautiful landscape", "num_outputs": 1}',
advanced=False,
)
version: Optional[str] = SchemaField(
default=None,
description="Specific version hash of the model (optional)",
placeholder="db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf",
advanced=True,
)
class Output(BlockSchema):
result: str = SchemaField(description="The output from the Replicate model")
status: str = SchemaField(description="Status of the prediction")
model_name: str = SchemaField(description="Name of the model used")
error: str = SchemaField(description="Error message if any", default="")
def __init__(self):
super().__init__(
id="c40d75a2-d0ea-44c9-a4f6-634bb3bdab1a",
description="Run Replicate models synchronously",
categories={BlockCategory.AI},
input_schema=ReplicateModelBlock.Input,
output_schema=ReplicateModelBlock.Output,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"model_name": "meta/llama-2-7b-chat",
"model_inputs": {"prompt": "Hello, world!", "max_new_tokens": 50},
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("result", str),
("status", str),
("model_name", str),
],
test_mock={
"run_model": lambda model_ref, model_inputs, api_key: (
"Mock response from Replicate model"
)
},
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
"""
Execute the Replicate model with the provided inputs.
Args:
input_data: The input data containing model name and inputs
credentials: The API credentials
Yields:
BlockOutput containing the model results and metadata
"""
try:
if input_data.version:
model_ref = f"{input_data.model_name}:{input_data.version}"
else:
model_ref = input_data.model_name
logger.debug(f"Running Replicate model: {model_ref}")
result = await self.run_model(
model_ref, input_data.model_inputs, credentials.api_key
)
yield "result", result
yield "status", "succeeded"
yield "model_name", input_data.model_name
except Exception as e:
error_msg = f"Unexpected error running Replicate model: {str(e)}"
logger.error(error_msg)
raise RuntimeError(error_msg)
async def run_model(self, model_ref: str, model_inputs: dict, api_key: SecretStr):
"""
Run the Replicate model. This method can be mocked for testing.
Args:
model_ref: The model reference (e.g., "owner/model-name:version")
model_inputs: The inputs to pass to the model
api_key: The Replicate API key as SecretStr
Returns:
The extracted result string
"""
api_key_str = api_key.get_secret_value()
client = ReplicateClient(api_token=api_key_str)
output: ReplicateOutputs = await client.async_run(
model_ref, input=model_inputs, wait=False
) # type: ignore; the client's type hints don't match its runtime return type
result = extract_result(output)
return result

View File

@@ -1,17 +1,33 @@
import os
from enum import Enum
from typing import Literal
from pydantic import SecretStr
from replicate.client import Client as ReplicateClient
from replicate.helpers import FileOutput
from backend.blocks.replicate._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
ReplicateCredentialsInput,
)
from backend.blocks.replicate._helper import ReplicateOutputs, extract_result
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="replicate",
api_key=SecretStr("mock-replicate-api-key"),
title="Mock Replicate API key",
expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.type,
}
# Model name enum
@@ -39,7 +55,9 @@ class ImageType(str, Enum):
class ReplicateFluxAdvancedModelBlock(Block):
class Input(BlockSchema):
credentials: ReplicateCredentialsInput = CredentialsField(
credentials: CredentialsMetaInput[
Literal[ProviderName.REPLICATE], Literal["api_key"]
] = CredentialsField(
description="The Replicate integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)
@@ -183,7 +201,7 @@ class ReplicateFluxAdvancedModelBlock(Block):
client = ReplicateClient(api_token=api_key.get_secret_value())
# Run the model with additional parameters
output: ReplicateOutputs = await client.async_run( # type: ignore This is because they changed the return type, and didn't update the type hint! It should be overloaded depending on the value of `use_file_output` to `FileOutput | list[FileOutput]` but it's `Any | Iterator[Any]`
output: FileOutput | list[FileOutput] = await client.async_run( # type: ignore This is because they changed the return type, and didn't update the type hint! It should be overloaded depending on the value of `use_file_output` to `FileOutput | list[FileOutput]` but it's `Any | Iterator[Any]`
f"{model_name}",
input={
"prompt": prompt,
@@ -199,6 +217,21 @@ class ReplicateFluxAdvancedModelBlock(Block):
wait=False, # don't arbitrarily return data:octet-stream or sometimes a URL depending on the model
)
result = extract_result(output)
# Check if output is a list or a string and extract accordingly; otherwise, assign a default message
if isinstance(output, list) and len(output) > 0:
if isinstance(output[0], FileOutput):
result_url = output[0].url # If output is a list, get the first element
else:
result_url = output[
0
] # If output is a list and not a FileOutput, get the first element. Should never happen, but just in case.
elif isinstance(output, FileOutput):
result_url = output.url # If output is a FileOutput, use the url
elif isinstance(output, str):
result_url = output # If output is a string (for some reason due to their janky type hinting), use it directly
else:
result_url = (
"No output received" # Fallback message if output is not as expected
)
return result
return result_url

View File

@@ -108,7 +108,6 @@ class ScreenshotWebPageBlock(Block):
async def take_screenshot(
credentials: APIKeyCredentials,
graph_exec_id: str,
user_id: str,
url: str,
viewport_width: int,
viewport_height: int,
@@ -154,7 +153,6 @@ class ScreenshotWebPageBlock(Block):
file=MediaFileType(
f"data:image/{format.value};base64,{b64encode(content).decode('utf-8')}"
),
user_id=user_id,
return_content=True,
)
}
@@ -165,14 +163,12 @@ class ScreenshotWebPageBlock(Block):
*,
credentials: APIKeyCredentials,
graph_exec_id: str,
user_id: str,
**kwargs,
) -> BlockOutput:
try:
screenshot_data = await self.take_screenshot(
credentials=credentials,
graph_exec_id=graph_exec_id,
user_id=user_id,
url=input_data.url,
viewport_width=input_data.viewport_width,
viewport_height=input_data.viewport_height,

View File

@@ -3,7 +3,8 @@ from typing import List
from backend.data.block import BlockOutput, BlockSchema
from backend.data.model import APIKeyCredentials, SchemaField
from backend.util.settings import BehaveAs, Settings
from backend.util import settings
from backend.util.settings import BehaveAs
from ._api import (
TEST_CREDENTIALS,
@@ -15,8 +16,6 @@ from ._api import (
)
from .base import Slant3DBlockBase
settings = Settings()
class Slant3DCreateOrderBlock(Slant3DBlockBase):
"""Block for creating new orders"""
@@ -281,7 +280,7 @@ class Slant3DGetOrdersBlock(Slant3DBlockBase):
input_schema=self.Input,
output_schema=self.Output,
# This block is disabled for cloud hosted because it allows access to all orders for the account
disabled=settings.config.behave_as == BehaveAs.CLOUD,
disabled=settings.Settings().config.behave_as == BehaveAs.CLOUD,
test_input={"credentials": TEST_CREDENTIALS_INPUT},
test_credentials=TEST_CREDENTIALS,
test_output=[

View File

@@ -9,7 +9,8 @@ from backend.data.block import (
)
from backend.data.model import SchemaField
from backend.integrations.providers import ProviderName
from backend.util.settings import AppEnvironment, BehaveAs, Settings
from backend.util import settings
from backend.util.settings import AppEnvironment, BehaveAs
from ._api import (
TEST_CREDENTIALS,
@@ -18,8 +19,6 @@ from ._api import (
Slant3DCredentialsInput,
)
settings = Settings()
class Slant3DTriggerBase:
"""Base class for Slant3D webhook triggers"""
@@ -77,8 +76,8 @@ class Slant3DOrderWebhookBlock(Slant3DTriggerBase, Block):
),
# All webhooks are currently subscribed to for all orders. This works for self hosted, but not for cloud hosted prod
disabled=(
settings.config.behave_as == BehaveAs.CLOUD
and settings.config.app_env != AppEnvironment.LOCAL
settings.Settings().config.behave_as == BehaveAs.CLOUD
and settings.Settings().config.app_env != AppEnvironment.LOCAL
),
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=self.Input,

View File

@@ -3,6 +3,8 @@ import re
from collections import Counter
from typing import TYPE_CHECKING, Any
from autogpt_libs.utils.cache import thread_cached
import backend.blocks.llm as llm
from backend.blocks.agent import AgentExecutorBlock
from backend.data.block import (
@@ -13,9 +15,8 @@ from backend.data.block import (
BlockSchema,
BlockType,
)
from backend.data.model import NodeExecutionStats, SchemaField
from backend.data.model import SchemaField
from backend.util import json
from backend.util.clients import get_database_manager_async_client
if TYPE_CHECKING:
from backend.data.graph import Link, Node
@@ -23,6 +24,14 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
@thread_cached
def get_database_manager_client():
from backend.executor import DatabaseManagerAsyncClient
from backend.util.service import get_service_client
return get_service_client(DatabaseManagerAsyncClient, health_check=False)
def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
"""
Return a list of tool_call_ids if the entry is a tool request.
@@ -291,32 +300,9 @@ class SmartDecisionMakerBlock(Block):
for link in links:
sink_name = SmartDecisionMakerBlock.cleanup(link.sink_name)
# Handle dynamic fields (e.g., values_#_*, items_$_*, etc.)
# These are fields that get merged by the executor into their base field
if (
"_#_" in link.sink_name
or "_$_" in link.sink_name
or "_@_" in link.sink_name
):
# For dynamic fields, provide a generic string schema
# The executor will handle merging these into the appropriate structure
properties[sink_name] = {
"type": "string",
"description": f"Dynamic value for {link.sink_name}",
}
else:
# For regular fields, use the block's schema
try:
properties[sink_name] = sink_block_input_schema.get_field_schema(
link.sink_name
)
except (KeyError, AttributeError):
# If the field doesn't exist in the schema, provide a generic schema
properties[sink_name] = {
"type": "string",
"description": f"Value for {link.sink_name}",
}
properties[sink_name] = sink_block_input_schema.get_field_schema(
link.sink_name
)
tool_function["parameters"] = {
**block.input_schema.jsonschema(),
@@ -347,7 +333,7 @@ class SmartDecisionMakerBlock(Block):
if not graph_id or not graph_version:
raise ValueError("Graph ID or Graph Version not found in sink node.")
db_client = get_database_manager_async_client()
db_client = get_database_manager_client()
sink_graph_meta = await db_client.get_graph_metadata(graph_id, graph_version)
if not sink_graph_meta:
raise ValueError(
@@ -407,7 +393,7 @@ class SmartDecisionMakerBlock(Block):
ValueError: If no tool links are found for the specified node_id, or if a sink node
or its metadata cannot be found.
"""
db_client = get_database_manager_async_client()
db_client = get_database_manager_client()
tools = [
(link, node)
for link, node in await db_client.get_connected_output_nodes(node_id)
@@ -501,6 +487,10 @@ class SmartDecisionMakerBlock(Block):
}
)
prompt.extend(tool_output)
if input_data.multiple_tool_calls:
input_data.sys_prompt += "\nYou can call a tool (different tools) multiple times in a single response."
else:
input_data.sys_prompt += "\nOnly provide EXACTLY one function call, multiple tool calls is strictly prohibited."
values = input_data.prompt_values
if values:
@@ -530,15 +520,6 @@ class SmartDecisionMakerBlock(Block):
parallel_tool_calls=input_data.multiple_tool_calls,
)
# Track LLM usage stats
self.merge_stats(
NodeExecutionStats(
input_token_count=response.prompt_tokens,
output_token_count=response.completion_tokens,
llm_call_count=1,
)
)
if not response.tool_calls:
yield "finished", response.response
return

View File

@@ -1,180 +0,0 @@
from pathlib import Path
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import ContributorDetails, SchemaField
from backend.util.file import get_exec_file_path, store_media_file
from backend.util.type import MediaFileType
class ReadSpreadsheetBlock(Block):
class Input(BlockSchema):
contents: str | None = SchemaField(
description="The contents of the CSV/spreadsheet data to read",
placeholder="a, b, c\n1,2,3\n4,5,6",
default=None,
advanced=False,
)
file_input: MediaFileType | None = SchemaField(
description="CSV or Excel file to read from (URL, data URI, or local path). Excel files are automatically converted to CSV",
default=None,
advanced=False,
)
delimiter: str = SchemaField(
description="The delimiter used in the CSV/spreadsheet data",
default=",",
)
quotechar: str = SchemaField(
description="The character used to quote fields",
default='"',
)
escapechar: str = SchemaField(
description="The character used to escape the delimiter",
default="\\",
)
has_header: bool = SchemaField(
description="Whether the CSV file has a header row",
default=True,
)
skip_rows: int = SchemaField(
description="The number of rows to skip from the start of the file",
default=0,
)
strip: bool = SchemaField(
description="Whether to strip whitespace from the values",
default=True,
)
skip_columns: list[str] = SchemaField(
description="The columns to skip from the start of the row",
default_factory=list,
)
produce_singular_result: bool = SchemaField(
description="If True, yield individual 'row' outputs only (can be slow). If False, yield both 'rows' (all data)",
default=False,
)
class Output(BlockSchema):
row: dict[str, str] = SchemaField(
description="The data produced from each row in the spreadsheet"
)
rows: list[dict[str, str]] = SchemaField(
description="All the data in the spreadsheet as a list of rows"
)
def __init__(self):
super().__init__(
id="acf7625e-d2cb-4941-bfeb-2819fc6fc015",
input_schema=ReadSpreadsheetBlock.Input,
output_schema=ReadSpreadsheetBlock.Output,
description="Reads CSV and Excel files and outputs the data as a list of dictionaries and individual rows. Excel files are automatically converted to CSV format.",
contributors=[ContributorDetails(name="Nicholas Tindle")],
categories={BlockCategory.TEXT, BlockCategory.DATA},
test_input=[
{
"contents": "a, b, c\n1,2,3\n4,5,6",
"produce_singular_result": False,
},
{
"contents": "a, b, c\n1,2,3\n4,5,6",
"produce_singular_result": True,
},
],
test_output=[
(
"rows",
[
{"a": "1", "b": "2", "c": "3"},
{"a": "4", "b": "5", "c": "6"},
],
),
("row", {"a": "1", "b": "2", "c": "3"}),
("row", {"a": "4", "b": "5", "c": "6"}),
],
)
async def run(
self, input_data: Input, *, graph_exec_id: str, user_id: str, **_kwargs
) -> BlockOutput:
import csv
from io import StringIO
# Determine data source - prefer file_input if provided, otherwise use contents
if input_data.file_input:
stored_file_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=input_data.file_input,
return_content=False,
)
# Get full file path
file_path = get_exec_file_path(graph_exec_id, stored_file_path)
if not Path(file_path).exists():
raise ValueError(f"File does not exist: {file_path}")
# Check if file is an Excel file and convert to CSV
file_extension = Path(file_path).suffix.lower()
if file_extension in [".xlsx", ".xls"]:
# Handle Excel files
try:
from io import StringIO
import pandas as pd
# Read Excel file
df = pd.read_excel(file_path)
# Convert to CSV string
csv_buffer = StringIO()
df.to_csv(csv_buffer, index=False)
csv_content = csv_buffer.getvalue()
except ImportError:
raise ValueError(
"pandas library is required to read Excel files. Please install it."
)
except Exception as e:
raise ValueError(f"Unable to read Excel file: {e}")
else:
# Handle CSV/text files
csv_content = Path(file_path).read_text(encoding="utf-8")
elif input_data.contents:
# Use direct string content
csv_content = input_data.contents
else:
raise ValueError("Either 'contents' or 'file_input' must be provided")
csv_file = StringIO(csv_content)
reader = csv.reader(
csv_file,
delimiter=input_data.delimiter,
quotechar=input_data.quotechar,
escapechar=input_data.escapechar,
)
header = None
if input_data.has_header:
header = next(reader)
if input_data.strip:
header = [h.strip() for h in header]
for _ in range(input_data.skip_rows):
next(reader)
def process_row(row):
data = {}
for i, value in enumerate(row):
if i not in input_data.skip_columns:
if input_data.has_header and header:
data[header[i]] = value.strip() if input_data.strip else value
else:
data[str(i)] = value.strip() if input_data.strip else value
return data
rows = [process_row(row) for row in reader]
if input_data.produce_singular_result:
for processed_row in rows:
yield "row", processed_row
else:
yield "rows", rows

View File

@@ -106,7 +106,6 @@ class TestHttpBlockWithHostScopedCredentials:
input_data,
credentials=exact_match_credentials,
graph_exec_id="test-exec-id",
user_id="test-user-id",
):
result.append((output_name, output_data))
@@ -162,7 +161,6 @@ class TestHttpBlockWithHostScopedCredentials:
input_data,
credentials=wildcard_credentials,
graph_exec_id="test-exec-id",
user_id="test-user-id",
):
result.append((output_name, output_data))
@@ -209,7 +207,6 @@ class TestHttpBlockWithHostScopedCredentials:
input_data,
credentials=non_matching_credentials,
graph_exec_id="test-exec-id",
user_id="test-user-id",
):
result.append((output_name, output_data))
@@ -259,7 +256,6 @@ class TestHttpBlockWithHostScopedCredentials:
input_data,
credentials=exact_match_credentials,
graph_exec_id="test-exec-id",
user_id="test-user-id",
):
result.append((output_name, output_data))
@@ -319,7 +315,6 @@ class TestHttpBlockWithHostScopedCredentials:
input_data,
credentials=auto_discovered_creds, # Execution manager found these
graph_exec_id="test-exec-id",
user_id="test-user-id",
):
result.append((output_name, output_data))
@@ -383,7 +378,6 @@ class TestHttpBlockWithHostScopedCredentials:
input_data,
credentials=multi_header_creds,
graph_exec_id="test-exec-id",
user_id="test-user-id",
):
result.append((output_name, output_data))
@@ -472,7 +466,6 @@ class TestHttpBlockWithHostScopedCredentials:
input_data,
credentials=test_creds,
graph_exec_id="test-exec-id",
user_id="test-user-id",
):
result.append((output_name, output_data))

Some files were not shown because too many files have changed in this diff.