Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-01-13 09:08:02 -05:00)

Compare commits: 87 commits, update-ins ... swiftyos/i
| Author | SHA1 | Date |
|---|---|---|
| | f850ba033e | |
| | 13d7b53991 | |
| | ab4cf9d557 | |
| | 9af79f750e | |
| | c3c1ac9845 | |
| | a225b8ab72 | |
| | c432d14db9 | |
| | 6435cd340c | |
| | a2f3c322dc | |
| | 38c167ff87 | |
| | 31ae7e2838 | |
| | 1885f88a6f | |
| | c5aa147fd1 | |
| | 7790672d9f | |
| | a633c440a9 | |
| | dc9a2f84e7 | |
| | e3115dbe08 | |
| | 126498b8d0 | |
| | c5dec20e0c | |
| | 922150c7fa | |
| | 3aa04d4b96 | |
| | 03ca3f9179 | |
| | f9e0b08e19 | |
| | 8882768bbf | |
| | 249249bdcc | |
| | 163713df1a | |
| | ee91540b1a | |
| | a7503ac716 | |
| | df2ef41213 | |
| | a0da6dd09f | |
| | ec73331c79 | |
| | 39758a7ee0 | |
| | 30cebab17e | |
| | bc7ab15951 | |
| | 3fbd3d79af | |
| | c5539c8699 | |
| | dfbeb10342 | |
| | 9daf6fb765 | |
| | b3ceceda17 | |
| | 002b951c88 | |
| | 7a5c5db56f | |
| | 5fd15c74bf | |
| | 467219323a | |
| | e148063a33 | |
| | 3ccecb7f8e | |
| | eecf8c2020 | |
| | 35c50e2d4c | |
| | b478ae51c1 | |
| | e564e15701 | |
| | 748600d069 | |
| | 31aaabc1eb | |
| | 4f057c5b72 | |
| | 75309047cf | |
| | e58a4599c8 | |
| | 848990411d | |
| | ae500cd9c6 | |
| | 7f062545ba | |
| | b75967a9a1 | |
| | 7c4c9fda0c | |
| | 03289f7a84 | |
| | 088613c64b | |
| | 0aaaf55452 | |
| | aa66188a9a | |
| | 31bcdb97a7 | |
| | d1b8dcd298 | |
| | 5e27cb3147 | |
| | a09ecab7f1 | |
| | 864f76f904 | |
| | 19b979ea7f | |
| | 213f9aaa90 | |
| | 7f10fe9d70 | |
| | 31b31e00d9 | |
| | f054d2642b | |
| | 0d469bb094 | |
| | bfdc387e02 | |
| | 31b99c9572 | |
| | 617533fa1d | |
| | f99c974ea8 | |
| | 12d43fb2fe | |
| | b49b627a14 | |
| | 8073f41804 | |
| | fcf91a0721 | |
| | bce9a6ff46 | |
| | 87c802898d | |
| | e353e1e25f | |
| | ea06aed1e1 | |
| | ef9814457c | |
@@ -9,13 +9,11 @@
# Platform - Backend
!autogpt_platform/backend/backend/
!autogpt_platform/backend/test/e2e_test_data.py
!autogpt_platform/backend/migrations/
!autogpt_platform/backend/schema.prisma
!autogpt_platform/backend/pyproject.toml
!autogpt_platform/backend/poetry.lock
!autogpt_platform/backend/README.md
!autogpt_platform/backend/.env

# Platform - Market
!autogpt_platform/market/market/
@@ -28,7 +26,6 @@
# Platform - Frontend
!autogpt_platform/frontend/src/
!autogpt_platform/frontend/public/
!autogpt_platform/frontend/scripts/
!autogpt_platform/frontend/package.json
!autogpt_platform/frontend/pnpm-lock.yaml
!autogpt_platform/frontend/tsconfig.json
@@ -36,7 +33,6 @@
## config
!autogpt_platform/frontend/*.config.*
!autogpt_platform/frontend/.env.*
!autogpt_platform/frontend/.env

# Classic - AutoGPT
!classic/original_autogpt/autogpt/
3 .github/PULL_REQUEST_TEMPLATE.md (vendored)

@@ -24,8 +24,7 @@
</details>

#### For configuration changes:

- [ ] `.env.default` is updated or already compatible with my changes
- [ ] `.env.example` is updated or already compatible with my changes
- [ ] `docker-compose.yml` is updated or already compatible with my changes
- [ ] I have included a list of my configuration changes in the PR description (under **Changes**)
133 .github/workflows/platform-frontend-ci.yml (vendored)
@@ -37,7 +37,7 @@ jobs:
|
||||
|
||||
- name: Generate cache key
|
||||
id: cache-key
|
||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('**/pnpm-lock.yaml') }}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache dependencies
|
||||
uses: actions/cache@v4
|
||||
@@ -45,7 +45,6 @@ jobs:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ steps.cache-key.outputs.key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
@@ -73,7 +72,6 @@ jobs:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
@@ -82,6 +80,36 @@ jobs:
|
||||
- name: Run lint
|
||||
run: pnpm lint
|
||||
|
||||
type-check:
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Run tsc check
|
||||
run: pnpm type-check
|
||||
|
||||
chromatic:
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
@@ -108,7 +136,6 @@ jobs:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
@@ -124,10 +151,12 @@ jobs:
|
||||
exitOnceUploaded: true
|
||||
|
||||
test:
|
||||
runs-on: big-boi
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
browser: [chromium, webkit]
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -143,67 +172,23 @@ jobs:
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Free Disk Space (Ubuntu)
|
||||
uses: jlumbroso/free-disk-space@main
|
||||
with:
|
||||
large-packages: false # slow
|
||||
docker-images: false # limited benefit
|
||||
|
||||
- name: Copy default supabase .env
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
cp ../.env.example ../.env
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Cache Docker layers
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: /tmp/.buildx-cache
|
||||
key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-buildx-frontend-test-
|
||||
- name: Copy backend .env
|
||||
run: |
|
||||
cp ../backend/.env.example ../backend/.env
|
||||
|
||||
- name: Run docker compose
|
||||
run: |
|
||||
docker compose -f ../docker-compose.yml up -d
|
||||
env:
|
||||
DOCKER_BUILDKIT: 1
|
||||
BUILDX_CACHE_FROM: type=local,src=/tmp/.buildx-cache
|
||||
BUILDX_CACHE_TO: type=local,dest=/tmp/.buildx-cache-new,mode=max
|
||||
|
||||
- name: Move cache
|
||||
run: |
|
||||
rm -rf /tmp/.buildx-cache
|
||||
if [ -d "/tmp/.buildx-cache-new" ]; then
|
||||
mv /tmp/.buildx-cache-new /tmp/.buildx-cache
|
||||
fi
|
||||
|
||||
- name: Wait for services to be ready
|
||||
run: |
|
||||
echo "Waiting for rest_server to be ready..."
|
||||
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
||||
echo "Waiting for database to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
|
||||
|
||||
- name: Create E2E test data
|
||||
run: |
|
||||
echo "Creating E2E test data..."
|
||||
# First try to run the script from inside the container
|
||||
if docker compose -f ../docker-compose.yml exec -T rest_server test -f /app/autogpt_platform/backend/test/e2e_test_data.py; then
|
||||
echo "✅ Found e2e_test_data.py in container, running it..."
|
||||
docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python backend/test/e2e_test_data.py" || {
|
||||
echo "❌ E2E test data creation failed!"
|
||||
docker compose -f ../docker-compose.yml logs --tail=50 rest_server
|
||||
exit 1
|
||||
}
|
||||
else
|
||||
echo "⚠️ e2e_test_data.py not found in container, copying and running..."
|
||||
# Copy the script into the container and run it
|
||||
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.yml ps -q rest_server):/tmp/e2e_test_data.py || {
|
||||
echo "❌ Failed to copy script to container"
|
||||
exit 1
|
||||
}
|
||||
docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
|
||||
echo "❌ E2E test data creation failed!"
|
||||
docker compose -f ../docker-compose.yml logs --tail=50 rest_server
|
||||
exit 1
|
||||
}
|
||||
fi
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v4
|
||||
@@ -211,25 +196,33 @@ jobs:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Install Browser 'chromium'
|
||||
run: pnpm playwright install --with-deps chromium
|
||||
- name: Setup .env
|
||||
run: cp .env.example .env
|
||||
|
||||
- name: Build frontend
|
||||
run: pnpm build --turbo
|
||||
# uses Turbopack, much faster and safe enough for a test pipeline
|
||||
|
||||
- name: Install Browser '${{ matrix.browser }}'
|
||||
run: pnpm playwright install --with-deps ${{ matrix.browser }}
|
||||
|
||||
- name: Run Playwright tests
|
||||
run: pnpm test:no-build
|
||||
|
||||
- name: Upload Playwright artifacts
|
||||
if: failure()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-report
|
||||
path: playwright-report
|
||||
run: pnpm test:no-build --project=${{ matrix.browser }}
|
||||
env:
|
||||
BROWSER_TYPE: ${{ matrix.browser }}
|
||||
|
||||
- name: Print Final Docker Compose logs
|
||||
if: always()
|
||||
run: docker compose -f ../docker-compose.yml logs
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: ${{ !cancelled() }}
|
||||
with:
|
||||
name: playwright-report-${{ matrix.browser }}
|
||||
path: playwright-report/
|
||||
retention-days: 30
|
||||
|
||||
132 .github/workflows/platform-fullstack-ci.yml (vendored)
@@ -1,132 +0,0 @@
|
||||
name: AutoGPT Platform - Frontend CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [master, dev]
|
||||
paths:
|
||||
- ".github/workflows/platform-fullstack-ci.yml"
|
||||
- "autogpt_platform/**"
|
||||
pull_request:
|
||||
paths:
|
||||
- ".github/workflows/platform-fullstack-ci.yml"
|
||||
- "autogpt_platform/**"
|
||||
merge_group:
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: autogpt_platform/frontend
|
||||
|
||||
jobs:
|
||||
setup:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
cache-key: ${{ steps.cache-key.outputs.key }}
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Generate cache key
|
||||
id: cache-key
|
||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ steps.cache-key.outputs.key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
types:
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Copy default supabase .env
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
|
||||
- name: Copy backend .env
|
||||
run: |
|
||||
cp ../backend/.env.default ../backend/.env
|
||||
|
||||
- name: Run docker compose
|
||||
run: |
|
||||
docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Setup .env
|
||||
run: cp .env.default .env
|
||||
|
||||
- name: Wait for services to be ready
|
||||
run: |
|
||||
echo "Waiting for rest_server to be ready..."
|
||||
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
||||
echo "Waiting for database to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
|
||||
|
||||
- name: Generate API queries
|
||||
run: pnpm generate:api:force
|
||||
|
||||
- name: Check for API schema changes
|
||||
run: |
|
||||
if ! git diff --exit-code src/app/api/openapi.json; then
|
||||
echo "❌ API schema changes detected in src/app/api/openapi.json"
|
||||
echo ""
|
||||
echo "The openapi.json file has been modified after running 'pnpm generate:api-all'."
|
||||
echo "This usually means changes have been made in the BE endpoints without updating the Frontend."
|
||||
echo "The API schema is now out of sync with the Front-end queries."
|
||||
echo ""
|
||||
echo "To fix this:"
|
||||
echo "1. Pull the backend 'docker compose pull && docker compose up -d --build --force-recreate'"
|
||||
echo "2. Run 'pnpm generate:api' locally"
|
||||
echo "3. Run 'pnpm types' locally"
|
||||
echo "4. Fix any TypeScript errors that may have been introduced"
|
||||
echo "5. Commit and push your changes"
|
||||
echo ""
|
||||
exit 1
|
||||
else
|
||||
echo "✅ No API schema changes detected"
|
||||
fi
|
||||
|
||||
- name: Run Typescript checks
|
||||
run: pnpm types
|
||||
6 .gitignore (vendored)

@@ -5,8 +5,6 @@ classic/original_autogpt/*.json
auto_gpt_workspace/*
*.mpeg
.env
# Root .env files
/.env
azure.yaml
.vscode
.idea/*
@@ -123,6 +121,7 @@ celerybeat.pid

# Environments
.direnv/
.env
.venv
env/
venv*/
@@ -178,3 +177,6 @@ autogpt_platform/backend/settings.py
*.ign.*
.test-contents
.claude/settings.local.json

api.md
blocks.md

@@ -235,7 +235,7 @@ repos:
hooks:
- id: tsc
name: Typecheck - AutoGPT Platform - Frontend
entry: bash -c 'cd autogpt_platform/frontend && pnpm types'
entry: bash -c 'cd autogpt_platform/frontend && pnpm type-check'
files: ^autogpt_platform/frontend/
types: [file]
language: system
6 .vscode/launch.json (vendored)

@@ -6,7 +6,7 @@
"type": "node-terminal",
"request": "launch",
"cwd": "${workspaceFolder}/autogpt_platform/frontend",
"command": "pnpm dev"
"command": "yarn dev"
},
{
"name": "Frontend: Client Side",
@@ -19,12 +19,12 @@
"type": "node-terminal",

"request": "launch",
"command": "pnpm dev",
"command": "yarn dev",
"cwd": "${workspaceFolder}/autogpt_platform/frontend",
"serverReadyAction": {
"pattern": "- Local:.+(https?://.+)",
"uriFormat": "%s",
"action": "debugWithChrome"
"action": "debugWithEdge"
}
},
{
195 LICENSE
@@ -1,197 +1,6 @@
|
||||
All portions of this repository are under one of two licenses.
|
||||
All portions of this repository are under one of two licenses. The majority of the AutoGPT repository is under the MIT License below. The autogpt_platform folder is under the
|
||||
Polyform Shield License.
|
||||
|
||||
- Everything inside the autogpt_platform folder is under the Polyform Shield License.
|
||||
- Everything outside the autogpt_platform folder is under the MIT License.
|
||||
|
||||
More info:
|
||||
|
||||
**Polyform Shield License:**
|
||||
Code and content within the `autogpt_platform` folder is licensed under the Polyform Shield License. This new project is our in-development platform for building, deploying and managing agents.
|
||||
Read more about this effort here: https://agpt.co/blog/introducing-the-autogpt-platform
|
||||
|
||||
**MIT License:**
|
||||
All other portions of the AutoGPT repository (i.e., everything outside the `autogpt_platform` folder) are licensed under the MIT License. This includes:
|
||||
- The Original, stand-alone AutoGPT Agent
|
||||
- Forge: https://github.com/Significant-Gravitas/AutoGPT/tree/master/classic/forge
|
||||
- AG Benchmark: https://github.com/Significant-Gravitas/AutoGPT/tree/master/classic/benchmark
|
||||
- AutoGPT Classic GUI: https://github.com/Significant-Gravitas/AutoGPT/tree/master/classic/frontend.
|
||||
|
||||
We also publish additional work under the MIT Licence in other repositories, such as GravitasML (https://github.com/Significant-Gravitas/gravitasml) which is developed for and used in the AutoGPT Platform, and our [Code Ability](https://github.com/Significant-Gravitas/AutoGPT-Code-Ability) project.
|
||||
|
||||
Both licences are available to read below:
|
||||
|
||||
=====================================================
|
||||
-----------------------------------------------------
|
||||
=====================================================
|
||||
|
||||
# PolyForm Shield License 1.0.0
|
||||
|
||||
<https://polyformproject.org/licenses/shield/1.0.0>
|
||||
|
||||
## Acceptance
|
||||
|
||||
In order to get any license under these terms, you must agree
|
||||
to them as both strict obligations and conditions to all
|
||||
your licenses.
|
||||
|
||||
## Copyright License
|
||||
|
||||
The licensor grants you a copyright license for the
|
||||
software to do everything you might do with the software
|
||||
that would otherwise infringe the licensor's copyright
|
||||
in it for any permitted purpose. However, you may
|
||||
only distribute the software according to [Distribution
|
||||
License](#distribution-license) and make changes or new works
|
||||
based on the software according to [Changes and New Works
|
||||
License](#changes-and-new-works-license).
|
||||
|
||||
## Distribution License
|
||||
|
||||
The licensor grants you an additional copyright license
|
||||
to distribute copies of the software. Your license
|
||||
to distribute covers distributing the software with
|
||||
changes and new works permitted by [Changes and New Works
|
||||
License](#changes-and-new-works-license).
|
||||
|
||||
## Notices
|
||||
|
||||
You must ensure that anyone who gets a copy of any part of
|
||||
the software from you also gets a copy of these terms or the
|
||||
URL for them above, as well as copies of any plain-text lines
|
||||
beginning with `Required Notice:` that the licensor provided
|
||||
with the software. For example:
|
||||
|
||||
> Required Notice: Copyright Yoyodyne, Inc. (http://example.com)
|
||||
|
||||
## Changes and New Works License
|
||||
|
||||
The licensor grants you an additional copyright license to
|
||||
make changes and new works based on the software for any
|
||||
permitted purpose.
|
||||
|
||||
## Patent License
|
||||
|
||||
The licensor grants you a patent license for the software that
|
||||
covers patent claims the licensor can license, or becomes able
|
||||
to license, that you would infringe by using the software.
|
||||
|
||||
## Noncompete
|
||||
|
||||
Any purpose is a permitted purpose, except for providing any
|
||||
product that competes with the software or any product the
|
||||
licensor or any of its affiliates provides using the software.
|
||||
|
||||
## Competition
|
||||
|
||||
Goods and services compete even when they provide functionality
|
||||
through different kinds of interfaces or for different technical
|
||||
platforms. Applications can compete with services, libraries
|
||||
with plugins, frameworks with development tools, and so on,
|
||||
even if they're written in different programming languages
|
||||
or for different computer architectures. Goods and services
|
||||
compete even when provided free of charge. If you market a
|
||||
product as a practical substitute for the software or another
|
||||
product, it definitely competes.
|
||||
|
||||
## New Products
|
||||
|
||||
If you are using the software to provide a product that does
|
||||
not compete, but the licensor or any of its affiliates brings
|
||||
your product into competition by providing a new version of
|
||||
the software or another product using the software, you may
|
||||
continue using versions of the software available under these
|
||||
terms beforehand to provide your competing product, but not
|
||||
any later versions.
|
||||
|
||||
## Discontinued Products
|
||||
|
||||
You may begin using the software to compete with a product
|
||||
or service that the licensor or any of its affiliates has
|
||||
stopped providing, unless the licensor includes a plain-text
|
||||
line beginning with `Licensor Line of Business:` with the
|
||||
software that mentions that line of business. For example:
|
||||
|
||||
> Licensor Line of Business: YoyodyneCMS Content Management
|
||||
System (http://example.com/cms)
|
||||
|
||||
## Sales of Business
|
||||
|
||||
If the licensor or any of its affiliates sells a line of
|
||||
business developing the software or using the software
|
||||
to provide a product, the buyer can also enforce
|
||||
[Noncompete](#noncompete) for that product.
|
||||
|
||||
## Fair Use
|
||||
|
||||
You may have "fair use" rights for the software under the
|
||||
law. These terms do not limit them.
|
||||
|
||||
## No Other Rights
|
||||
|
||||
These terms do not allow you to sublicense or transfer any of
|
||||
your licenses to anyone else, or prevent the licensor from
|
||||
granting licenses to anyone else. These terms do not imply
|
||||
any other licenses.
|
||||
|
||||
## Patent Defense
|
||||
|
||||
If you make any written claim that the software infringes or
|
||||
contributes to infringement of any patent, your patent license
|
||||
for the software granted under these terms ends immediately. If
|
||||
your company makes such a claim, your patent license ends
|
||||
immediately for work on behalf of your company.
|
||||
|
||||
## Violations
|
||||
|
||||
The first time you are notified in writing that you have
|
||||
violated any of these terms, or done anything with the software
|
||||
not covered by your licenses, your licenses can nonetheless
|
||||
continue if you come into full compliance with these terms,
|
||||
and take practical steps to correct past violations, within
|
||||
32 days of receiving notice. Otherwise, all your licenses
|
||||
end immediately.
|
||||
|
||||
## No Liability
|
||||
|
||||
***As far as the law allows, the software comes as is, without
|
||||
any warranty or condition, and the licensor will not be liable
|
||||
to you for any damages arising out of these terms or the use
|
||||
or nature of the software, under any kind of legal claim.***
|
||||
|
||||
## Definitions
|
||||
|
||||
The **licensor** is the individual or entity offering these
|
||||
terms, and the **software** is the software the licensor makes
|
||||
available under these terms.
|
||||
|
||||
A **product** can be a good or service, or a combination
|
||||
of them.
|
||||
|
||||
**You** refers to the individual or entity agreeing to these
|
||||
terms.
|
||||
|
||||
**Your company** is any legal entity, sole proprietorship,
|
||||
or other kind of organization that you work for, plus all
|
||||
its affiliates.
|
||||
|
||||
**Affiliates** means the other organizations than an
|
||||
organization has control over, is under the control of, or is
|
||||
under common control with.
|
||||
|
||||
**Control** means ownership of substantially all the assets of
|
||||
an entity, or the power to direct its management and policies
|
||||
by vote, contract, or otherwise. Control can be direct or
|
||||
indirect.
|
||||
|
||||
**Your licenses** are all the licenses granted to you for the
|
||||
software under these terms.
|
||||
|
||||
**Use** means anything you do with the software requiring one
|
||||
of your licenses.
|
||||
|
||||
=====================================================
|
||||
-----------------------------------------------------
|
||||
=====================================================
|
||||
|
||||
MIT License
|
||||
|
||||
|
||||
59 README.md
@@ -1,25 +1,16 @@
|
||||
# AutoGPT: Build, Deploy, and Run AI Agents
|
||||
|
||||
[](https://discord.gg/autogpt)  
|
||||
[](https://discord.gg/autogpt)  
|
||||
[](https://twitter.com/Auto_GPT)  
|
||||
|
||||
<!-- Keep these links. Translations will automatically update with the README. -->
|
||||
[Deutsch](https://zdoc.app/de/Significant-Gravitas/AutoGPT) |
|
||||
[Español](https://zdoc.app/es/Significant-Gravitas/AutoGPT) |
|
||||
[français](https://zdoc.app/fr/Significant-Gravitas/AutoGPT) |
|
||||
[日本語](https://zdoc.app/ja/Significant-Gravitas/AutoGPT) |
|
||||
[한국어](https://zdoc.app/ko/Significant-Gravitas/AutoGPT) |
|
||||
[Português](https://zdoc.app/pt/Significant-Gravitas/AutoGPT) |
|
||||
[Русский](https://zdoc.app/ru/Significant-Gravitas/AutoGPT) |
|
||||
[中文](https://zdoc.app/zh/Significant-Gravitas/AutoGPT)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
|
||||
**AutoGPT** is a powerful platform that allows you to create, deploy, and manage continuous AI agents that automate complex workflows.
|
||||
|
||||
## Hosting Options
|
||||
- Download to self-host (Free!)
|
||||
- [Join the Waitlist](https://bit.ly/3ZDijAI) for the cloud-hosted beta (Closed Beta - Public release Coming Soon!)
|
||||
- Download to self-host
|
||||
- [Join the Waitlist](https://bit.ly/3ZDijAI) for the cloud-hosted beta
|
||||
|
||||
## How to Self-Host the AutoGPT Platform
|
||||
## How to Setup for Self-Hosting
|
||||
> [!NOTE]
|
||||
> Setting up and hosting the AutoGPT Platform yourself is a technical process.
|
||||
> If you'd rather something that just works, we recommend [joining the waitlist](https://bit.ly/3ZDijAI) for the cloud-hosted beta.
|
||||
@@ -59,24 +50,6 @@ We've moved to a fully maintained and regularly updated documentation site.
|
||||
|
||||
This tutorial assumes you have Docker, VSCode, git and npm installed.
|
||||
|
||||
---
|
||||
|
||||
#### ⚡ Quick Setup with One-Line Script (Recommended for Local Hosting)
|
||||
|
||||
Skip the manual steps and get started in minutes using our automatic setup script.
|
||||
|
||||
For macOS/Linux:
|
||||
```
|
||||
curl -fsSL https://setup.agpt.co/install.sh -o install.sh && bash install.sh
|
||||
```
|
||||
|
||||
For Windows (PowerShell):
|
||||
```
|
||||
powershell -c "iwr https://setup.agpt.co/install.bat -o install.bat; ./install.bat"
|
||||
```
|
||||
|
||||
This will install dependencies, configure Docker, and launch your local instance — all in one go.
|
||||
|
||||
### 🧱 AutoGPT Frontend
|
||||
|
||||
The AutoGPT frontend is where users interact with our powerful AI automation platform. It offers multiple ways to engage with and leverage our AI agents. This is the interface where you'll bring your AI automation ideas to life:
|
||||
@@ -123,17 +96,7 @@ Here are two examples of what you can do with AutoGPT:
|
||||
These examples show just a glimpse of what you can achieve with AutoGPT! You can create customized workflows to build agents for any use case.
|
||||
|
||||
---
|
||||
|
||||
### **License Overview:**
|
||||
|
||||
🛡️ **Polyform Shield License:**
|
||||
All code and content within the `autogpt_platform` folder is licensed under the Polyform Shield License. This new project is our in-development platform for building, deploying and managing agents.</br>_[Read more about this effort](https://agpt.co/blog/introducing-the-autogpt-platform)_
|
||||
|
||||
🦉 **MIT License:**
|
||||
All other portions of the AutoGPT repository (i.e., everything outside the `autogpt_platform` folder) are licensed under the MIT License. This includes the original stand-alone AutoGPT Agent, along with projects such as [Forge](https://github.com/Significant-Gravitas/AutoGPT/tree/master/classic/forge), [agbenchmark](https://github.com/Significant-Gravitas/AutoGPT/tree/master/classic/benchmark) and the [AutoGPT Classic GUI](https://github.com/Significant-Gravitas/AutoGPT/tree/master/classic/frontend).</br>We also publish additional work under the MIT Licence in other repositories, such as [GravitasML](https://github.com/Significant-Gravitas/gravitasml) which is developed for and used in the AutoGPT Platform. See also our MIT Licenced [Code Ability](https://github.com/Significant-Gravitas/AutoGPT-Code-Ability) project.
|
||||
|
||||
---
|
||||
### Mission
|
||||
### Mission and Licencing
|
||||
Our mission is to provide the tools, so that you can focus on what matters:
|
||||
|
||||
- 🏗️ **Building** - Lay the foundation for something amazing.
|
||||
@@ -146,6 +109,14 @@ Be part of the revolution! **AutoGPT** is here to stay, at the forefront of AI i
|
||||
 | 
|
||||
**🚀 [Contributing](CONTRIBUTING.md)**
|
||||
|
||||
**Licensing:**
|
||||
|
||||
MIT License: The majority of the AutoGPT repository is under the MIT License.
|
||||
|
||||
Polyform Shield License: This license applies to the autogpt_platform folder.
|
||||
|
||||
For more information, see https://agpt.co/blog/introducing-the-autogpt-platform
|
||||
|
||||
---
|
||||
## 🤖 AutoGPT Classic
|
||||
> Below is information about the classic version of AutoGPT.
|
||||
@@ -235,4 +206,4 @@ To maintain a uniform standard and ensure seamless compatibility with many curre
|
||||
|
||||
<a href="https://github.com/Significant-Gravitas/AutoGPT/graphs/contributors" alt="View Contributors">
|
||||
<img src="https://contrib.rocks/image?repo=Significant-Gravitas/AutoGPT&max=1000&columns=10" alt="Contributors" />
|
||||
</a>
|
||||
</a>
|
||||
@@ -1,11 +1,9 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Repository Overview
|
||||
|
||||
AutoGPT Platform is a monorepo containing:
|
||||
|
||||
- **Backend** (`/backend`): Python FastAPI server with async support
|
||||
- **Frontend** (`/frontend`): Next.js React application
|
||||
- **Shared Libraries** (`/autogpt_libs`): Common Python utilities
|
||||
@@ -13,7 +11,6 @@ AutoGPT Platform is a monorepo containing:
|
||||
## Essential Commands
|
||||
|
||||
### Backend Development
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
cd backend && poetry install
|
||||
@@ -33,18 +30,11 @@ poetry run test
|
||||
# Run specific test
|
||||
poetry run pytest path/to/test_file.py::test_function_name
|
||||
|
||||
# Run block tests (tests that validate all blocks work correctly)
|
||||
poetry run pytest backend/blocks/test/test_block.py -xvs
|
||||
|
||||
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
|
||||
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
|
||||
|
||||
# Lint and format
|
||||
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
|
||||
poetry run format # Black + isort
|
||||
poetry run lint # ruff
|
||||
```
|
||||
|
||||
More details can be found in TESTING.md
|
||||
|
||||
#### Creating/Updating Snapshots
|
||||
@@ -57,8 +47,8 @@ poetry run pytest path/to/test.py --snapshot-update
|
||||
|
||||
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
|
||||
|
||||
### Frontend Development
|
||||
|
||||
### Frontend Development
|
||||
```bash
|
||||
# Install dependencies
|
||||
cd frontend && npm install
|
||||
@@ -76,13 +66,12 @@ npm run storybook
|
||||
npm run build
|
||||
|
||||
# Type checking
|
||||
npm run types
|
||||
npm run type-check
|
||||
```
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
### Backend Architecture
|
||||
|
||||
- **API Layer**: FastAPI with REST and WebSocket endpoints
|
||||
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
|
||||
- **Queue System**: RabbitMQ for async task processing
|
||||
@@ -91,7 +80,6 @@ npm run types
|
||||
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
||||
|
||||
### Frontend Architecture
|
||||
|
||||
- **Framework**: Next.js App Router with React Server Components
|
||||
- **State Management**: React hooks + Supabase client for real-time updates
|
||||
- **Workflow Builder**: Visual graph editor using @xyflow/react
|
||||
@@ -99,7 +87,6 @@ npm run types
|
||||
- **Feature Flags**: LaunchDarkly integration
|
||||
|
||||
### Key Concepts
|
||||
|
||||
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
|
||||
2. **Blocks**: Reusable components in `/backend/blocks/` that perform specific tasks
|
||||
3. **Integrations**: OAuth and API connections stored per user
|
||||
@@ -107,16 +94,13 @@ npm run types
|
||||
5. **Virus Scanning**: ClamAV integration for file upload security
|
||||
|
||||
### Testing Approach
|
||||
|
||||
- Backend uses pytest with snapshot testing for API responses
|
||||
- Test files are colocated with source files (`*_test.py`)
|
||||
- Frontend uses Playwright for E2E tests
|
||||
- Component testing via Storybook
|
||||
|
||||
### Database Schema
|
||||
|
||||
Key models (defined in `/backend/schema.prisma`):
|
||||
|
||||
- `User`: Authentication and profile data
|
||||
- `AgentGraph`: Workflow definitions with version control
|
||||
- `AgentGraphExecution`: Execution history and results
|
||||
@@ -124,31 +108,13 @@ Key models (defined in `/backend/schema.prisma`):
|
||||
- `StoreListing`: Marketplace listings for sharing agents
|
||||
|
||||
### Environment Configuration
|
||||
|
||||
#### Configuration Files
|
||||
|
||||
- **Backend**: `/backend/.env.default` (defaults) → `/backend/.env` (user overrides)
|
||||
- **Frontend**: `/frontend/.env.default` (defaults) → `/frontend/.env` (user overrides)
|
||||
- **Platform**: `/.env.default` (Supabase/shared defaults) → `/.env` (user overrides)
|
||||
|
||||
#### Docker Environment Loading Order
|
||||
|
||||
1. `.env.default` files provide base configuration (tracked in git)
|
||||
2. `.env` files provide user-specific overrides (gitignored)
|
||||
3. Docker Compose `environment:` sections provide service-specific overrides
|
||||
4. Shell environment variables have highest precedence
|
||||
|
||||
#### Key Points
|
||||
|
||||
- All services use hardcoded defaults in docker-compose files (no `${VARIABLE}` substitutions)
|
||||
- The `env_file` directive loads variables INTO containers at runtime
|
||||
- Backend/Frontend services use YAML anchors for consistent configuration
|
||||
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
|
||||
- Backend: `.env` file in `/backend`
|
||||
- Frontend: `.env.local` file in `/frontend`
|
||||
- Both require Supabase credentials and API keys for various services
|
||||
|
||||
### Common Development Tasks
|
||||
|
||||
**Adding a new block:**
|
||||
|
||||
1. Create new file in `/backend/backend/blocks/`
|
||||
2. Inherit from `Block` base class
|
||||
3. Define input/output schemas
|
||||
@@ -156,18 +122,13 @@ Key models (defined in `/backend/schema.prisma`):
|
||||
5. Register in block registry
|
||||
6. Generate the block uuid using `uuid.uuid4()`
|
||||
|
||||
Note: when making many new blocks, analyze the interfaces for each of these blocks and consider whether they would work well together in a graph-based editor, or whether they would struggle to connect productively (for example: do the inputs and outputs tie well together?). A hedged sketch of a minimal block follows below.
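A minimal illustrative sketch of such a block. The names `Block` and `BlockSchema`, the `run` signature, and the field names are assumptions for illustration only; check existing modules in `/backend/backend/blocks/` for the platform's actual base-class API (their imports are omitted here because the exact module path is not shown in this diff).

```python
class GreetUserBlock(Block):  # hypothetical block; base class assumed
    class Input(BlockSchema):  # step 3: define the input schema
        name: str

    class Output(BlockSchema):  # step 3: define the output schema
        greeting: str

    def __init__(self):
        super().__init__(
            id="00000000-0000-0000-0000-000000000000",  # step 6: paste a value generated once with uuid.uuid4()
            input_schema=GreetUserBlock.Input,
            output_schema=GreetUserBlock.Output,
        )

    def run(self, input_data: Input, **kwargs):  # step 4: implement the run method
        # Emit outputs by name so they can be wired to other blocks in the graph editor.
        yield "greeting", f"Hello, {input_data.name}!"
```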
|
||||
|
||||
**Modifying the API:**
|
||||
|
||||
1. Update route in `/backend/backend/server/routers/`
|
||||
2. Add/update Pydantic models in same directory
|
||||
3. Write tests alongside the route file
|
||||
4. Run `poetry run test` to verify
|
||||
|
||||
**Frontend feature development:**
|
||||
|
||||
1. Components go in `/frontend/src/components/`
|
||||
2. Use existing UI components from `/frontend/src/components/ui/`
|
||||
3. Add Storybook stories for new components
|
||||
@@ -176,7 +137,6 @@ ex: do the inputs and outputs tie well together?
|
||||
### Security Implementation
|
||||
|
||||
**Cache Protection Middleware:**
|
||||
|
||||
- Located in `/backend/backend/server/middleware/security.py`
|
||||
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
||||
- Uses an allow list approach - only explicitly permitted paths can be cached
|
||||
@@ -184,47 +144,3 @@ ex: do the inputs and outputs tie well together?
|
||||
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
|
||||
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
|
||||
- Applied to both main API server and external API applications
|
||||
|
||||
### Creating Pull Requests
|
||||
|
||||
- Create the PR against the `dev` branch of the repository.
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`).
- Use conventional commit messages (see below).
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description.
- Run the GitHub pre-commit hooks to ensure code quality.
|
||||
|
||||
### Reviewing/Revising Pull Requests
|
||||
|
||||
- When the user runs /pr-comments or tries to fetch them, also run gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews to get the reviews
|
||||
- Use gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews/[review_id]/comments to get the review contents
|
||||
- Use gh api /repos/Significant-Gravitas/AutoGPT/issues/9924/comments to get the pr specific comments
|
||||
|
||||
### Conventional Commits
|
||||
|
||||
Use this format for commit messages and Pull Request titles:
|
||||
|
||||
**Conventional Commit Types:**
|
||||
|
||||
- `feat`: Introduces a new feature to the codebase
|
||||
- `fix`: Patches a bug in the codebase
|
||||
- `refactor`: Code change that neither fixes a bug nor adds a feature; also applies to removing features
|
||||
- `ci`: Changes to CI configuration
|
||||
- `docs`: Documentation-only changes
|
||||
- `dx`: Improvements to the developer experience
|
||||
|
||||
**Recommended Base Scopes:**
|
||||
|
||||
- `platform`: Changes affecting both frontend and backend
|
||||
- `frontend`
|
||||
- `backend`
|
||||
- `infra`
|
||||
- `blocks`: Modifications/additions of individual blocks
|
||||
|
||||
**Subscope Examples:**
|
||||
|
||||
- `backend/executor`
|
||||
- `backend/db`
|
||||
- `frontend/builder` (includes changes to the block UI component)
|
||||
- `infra/prod`
|
||||
|
||||
Use these scopes and subscopes for clarity and consistency in commit messages.
|
||||
|
||||
@@ -8,6 +8,7 @@ Welcome to the AutoGPT Platform - a powerful system for creating and running AI
|
||||
|
||||
- Docker
|
||||
- Docker Compose V2 (comes with Docker Desktop, or can be installed separately)
|
||||
- Node.js & NPM (for running the frontend application)
|
||||
|
||||
### Running the System
|
||||
|
||||
@@ -23,10 +24,10 @@ To run the AutoGPT Platform, follow these steps:
|
||||
2. Run the following command:
|
||||
|
||||
```
|
||||
cp .env.default .env
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
This command will copy the `.env.default` file to `.env`. You can modify the `.env` file to add your own environment variables.
|
||||
This command will copy the `.env.example` file to `.env`. You can modify the `.env` file to add your own environment variables.
|
||||
|
||||
3. Run the following command:
|
||||
|
||||
@@ -36,7 +37,44 @@ To run the AutoGPT Platform, follow these steps:
|
||||
|
||||
This command will start all the necessary backend services defined in the `docker-compose.yml` file in detached mode.
|
||||
|
||||
4. After all the services are in ready state, open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
|
||||
4. Navigate to `frontend` within the `autogpt_platform` directory:
|
||||
|
||||
```
|
||||
cd frontend
|
||||
```
|
||||
|
||||
You will need to run your frontend application separately on your local machine.
|
||||
|
||||
5. Run the following command:
|
||||
|
||||
```
|
||||
cp .env.example .env.local
|
||||
```
|
||||
|
||||
This command will copy the `.env.example` file to `.env.local` in the `frontend` directory. You can modify the `.env.local` within this folder to add your own environment variables for the frontend application.
|
||||
|
||||
6. Run the following command:
|
||||
|
||||
Enable corepack and install dependencies by running:
|
||||
|
||||
```
|
||||
corepack enable
|
||||
pnpm i
|
||||
```
|
||||
|
||||
Generate the API client (this step is required before running the frontend):
|
||||
|
||||
```
|
||||
pnpm generate:api-client
|
||||
```
|
||||
|
||||
Then start the frontend application in development mode:
|
||||
|
||||
```
|
||||
pnpm dev
|
||||
```
|
||||
|
||||
7. Open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
|
||||
|
||||
### Docker Compose Commands
|
||||
|
||||
@@ -139,21 +177,20 @@ The platform includes scripts for generating and managing the API client:
|
||||
|
||||
- `pnpm fetch:openapi`: Fetches the OpenAPI specification from the backend service (requires backend to be running on port 8006)
|
||||
- `pnpm generate:api-client`: Generates the TypeScript API client from the OpenAPI specification using Orval
|
||||
- `pnpm generate:api`: Runs both fetch and generate commands in sequence
|
||||
- `pnpm generate:api-all`: Runs both fetch and generate commands in sequence
|
||||
|
||||
#### Manual API Client Updates
|
||||
|
||||
If you need to update the API client after making changes to the backend API:
|
||||
|
||||
1. Ensure the backend services are running:
|
||||
|
||||
```
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
2. Generate the updated API client:
|
||||
```
|
||||
pnpm generate:api
|
||||
pnpm generate:api-all
|
||||
```
|
||||
|
||||
This will fetch the latest OpenAPI specification and regenerate the TypeScript client code.
|
||||
|
||||
@@ -7,5 +7,9 @@ class Settings:
        self.ENABLE_AUTH: bool = os.getenv("ENABLE_AUTH", "false").lower() == "true"
        self.JWT_ALGORITHM: str = "HS256"

    @property
    def is_configured(self) -> bool:
        return bool(self.JWT_SECRET_KEY)


settings = Settings()
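A brief usage sketch of the new `is_configured` property; the import path is an assumption, since the file path is not shown in this hunk.

```python
from autogpt_libs.auth.config import settings  # assumed module path


def ensure_auth_ready() -> None:
    # Fail fast if auth is switched on but no JWT secret has been provided.
    if settings.ENABLE_AUTH and not settings.is_configured:
        raise RuntimeError("ENABLE_AUTH is true but JWT_SECRET_KEY is not set")
```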
@@ -0,0 +1,166 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
from functools import wraps
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar, Union, cast
|
||||
|
||||
import ldclient
|
||||
from fastapi import HTTPException
|
||||
from ldclient import Context, LDClient
|
||||
from ldclient.config import Config
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from .config import SETTINGS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def get_client() -> LDClient:
|
||||
"""Get the LaunchDarkly client singleton."""
|
||||
return ldclient.get()
|
||||
|
||||
|
||||
def initialize_launchdarkly() -> None:
|
||||
sdk_key = SETTINGS.launch_darkly_sdk_key
|
||||
logger.debug(
|
||||
f"Initializing LaunchDarkly with SDK key: {'present' if sdk_key else 'missing'}"
|
||||
)
|
||||
|
||||
if not sdk_key:
|
||||
logger.warning("LaunchDarkly SDK key not configured")
|
||||
return
|
||||
|
||||
config = Config(sdk_key)
|
||||
ldclient.set_config(config)
|
||||
|
||||
if ldclient.get().is_initialized():
|
||||
logger.info("LaunchDarkly client initialized successfully")
|
||||
else:
|
||||
logger.error("LaunchDarkly client failed to initialize")
|
||||
|
||||
|
||||
def shutdown_launchdarkly() -> None:
|
||||
"""Shutdown the LaunchDarkly client."""
|
||||
if ldclient.get().is_initialized():
|
||||
ldclient.get().close()
|
||||
logger.info("LaunchDarkly client closed successfully")
|
||||
|
||||
|
||||
def create_context(
|
||||
user_id: str, additional_attributes: Optional[Dict[str, Any]] = None
|
||||
) -> Context:
|
||||
"""Create LaunchDarkly context with optional additional attributes."""
|
||||
builder = Context.builder(str(user_id)).kind("user")
|
||||
if additional_attributes:
|
||||
for key, value in additional_attributes.items():
|
||||
builder.set(key, value)
|
||||
return builder.build()
|
||||
|
||||
|
||||
def feature_flag(
|
||||
flag_key: str,
|
||||
default: bool = False,
|
||||
) -> Callable[
|
||||
[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
|
||||
]:
|
||||
"""
|
||||
Decorator for feature flag protected endpoints.
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[P, Union[T, Awaitable[T]]],
|
||||
) -> Callable[P, Union[T, Awaitable[T]]]:
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
try:
|
||||
user_id = kwargs.get("user_id")
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required")
|
||||
|
||||
if not get_client().is_initialized():
|
||||
logger.warning(
|
||||
f"LaunchDarkly not initialized, using default={default}"
|
||||
)
|
||||
is_enabled = default
|
||||
else:
|
||||
context = create_context(str(user_id))
|
||||
is_enabled = get_client().variation(flag_key, context, default)
|
||||
|
||||
if not is_enabled:
|
||||
raise HTTPException(status_code=404, detail="Feature not available")
|
||||
|
||||
result = func(*args, **kwargs)
|
||||
if asyncio.iscoroutine(result):
|
||||
return await result
|
||||
return cast(T, result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error evaluating feature flag {flag_key}: {e}")
|
||||
raise
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
try:
|
||||
user_id = kwargs.get("user_id")
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required")
|
||||
|
||||
if not get_client().is_initialized():
|
||||
logger.warning(
|
||||
f"LaunchDarkly not initialized, using default={default}"
|
||||
)
|
||||
is_enabled = default
|
||||
else:
|
||||
context = create_context(str(user_id))
|
||||
is_enabled = get_client().variation(flag_key, context, default)
|
||||
|
||||
if not is_enabled:
|
||||
raise HTTPException(status_code=404, detail="Feature not available")
|
||||
|
||||
return cast(T, func(*args, **kwargs))
|
||||
except Exception as e:
|
||||
logger.error(f"Error evaluating feature flag {flag_key}: {e}")
|
||||
raise
|
||||
|
||||
return cast(
|
||||
Callable[P, Union[T, Awaitable[T]]],
|
||||
async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper,
|
||||
)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def percentage_rollout(
|
||||
flag_key: str,
|
||||
default: bool = False,
|
||||
) -> Callable[
|
||||
[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
|
||||
]:
|
||||
"""Decorator for percentage-based rollouts."""
|
||||
return feature_flag(flag_key, default)
|
||||
|
||||
|
||||
def beta_feature(
|
||||
flag_key: Optional[str] = None,
|
||||
unauthorized_response: Any = {"message": "Not available in beta"},
|
||||
) -> Callable[
|
||||
[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
|
||||
]:
|
||||
"""Decorator for beta features."""
|
||||
actual_key = f"beta-{flag_key}" if flag_key else "beta"
|
||||
return feature_flag(actual_key, False)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def mock_flag_variation(flag_key: str, return_value: Any):
|
||||
"""Context manager for testing feature flags."""
|
||||
original_variation = get_client().variation
|
||||
get_client().variation = lambda key, context, default: (
|
||||
return_value if key == flag_key else original_variation(key, context, default)
|
||||
)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
get_client().variation = original_variation
|
||||
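A short usage sketch of the `feature_flag` decorator and the `mock_flag_variation` test helper defined above. The endpoint path and flag key are illustrative; the import path matches the test file that follows.

```python
from fastapi import FastAPI

from autogpt_libs.feature_flag.client import feature_flag, mock_flag_variation

app = FastAPI()


@app.get("/beta/agents")
@feature_flag("beta-agent-listing", default=False)  # flag key is illustrative
async def list_beta_agents(user_id: str):
    # If the flag evaluates to False for this user, the decorator raises
    # HTTPException(404, "Feature not available") before this body runs.
    return {"agents": [], "user_id": user_id}


# In tests, mock_flag_variation temporarily patches get_client().variation so a
# flag can be forced on or off without a real LaunchDarkly connection:
def test_listing_enabled_path():
    with mock_flag_variation("beta-agent-listing", True):
        ...  # call the endpoint (e.g. via TestClient) and assert a 200 response
```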
@@ -0,0 +1,45 @@
import pytest
from ldclient import LDClient

from autogpt_libs.feature_flag.client import feature_flag, mock_flag_variation


@pytest.fixture
def ld_client(mocker):
    client = mocker.Mock(spec=LDClient)
    mocker.patch("ldclient.get", return_value=client)
    client.is_initialized.return_value = True
    return client


@pytest.mark.asyncio
async def test_feature_flag_enabled(ld_client):
    ld_client.variation.return_value = True

    @feature_flag("test-flag")
    async def test_function(user_id: str):
        return "success"

    result = test_function(user_id="test-user")
    assert result == "success"
    ld_client.variation.assert_called_once()


@pytest.mark.asyncio
async def test_feature_flag_unauthorized_response(ld_client):
    ld_client.variation.return_value = False

    @feature_flag("test-flag")
    async def test_function(user_id: str):
        return "success"

    result = test_function(user_id="test-user")
    assert result == {"error": "disabled"}


def test_mock_flag_variation(ld_client):
    with mock_flag_variation("test-flag", True):
        assert ld_client.variation("test-flag", None, False)

    with mock_flag_variation("test-flag", False):
        assert ld_client.variation("test-flag", None, False)
@@ -0,0 +1,15 @@
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    launch_darkly_sdk_key: str = Field(
        default="",
        description="The Launch Darkly SDK key",
        validation_alias="LAUNCH_DARKLY_SDK_KEY",
    )

    model_config = SettingsConfigDict(case_sensitive=True, extra="ignore")


SETTINGS = Settings()
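A quick sketch of how these settings pick up the SDK key from the environment. The package path is inferred from the client module's `from .config import SETTINGS` and is an assumption.

```python
import os

# The variable must be set before the module is imported, because
# SETTINGS = Settings() reads the environment at import time.
os.environ["LAUNCH_DARKLY_SDK_KEY"] = "sdk-placeholder-key"  # illustrative value

from autogpt_libs.feature_flag.config import SETTINGS  # assumed package path

print(SETTINGS.launch_darkly_sdk_key)  # "sdk-placeholder-key"
```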
@@ -1,8 +1,6 @@
|
||||
"""Logging module for Auto-GPT."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
@@ -12,15 +10,6 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from .filters import BelowLevelFilter
|
||||
from .formatters import AGPTFormatter
|
||||
|
||||
# Configure global socket timeout and gRPC keepalive to prevent deadlocks
|
||||
# This must be done at import time before any gRPC connections are established
|
||||
socket.setdefaulttimeout(30) # 30-second socket timeout
|
||||
|
||||
# Enable gRPC keepalive to detect dead connections faster
|
||||
os.environ.setdefault("GRPC_KEEPALIVE_TIME_MS", "30000") # 30 seconds
|
||||
os.environ.setdefault("GRPC_KEEPALIVE_TIMEOUT_MS", "5000") # 5 seconds
|
||||
os.environ.setdefault("GRPC_KEEPALIVE_PERMIT_WITHOUT_CALLS", "true")
|
||||
|
||||
LOG_DIR = Path(__file__).parent.parent.parent.parent / "logs"
|
||||
LOG_FILE = "activity.log"
|
||||
DEBUG_LOG_FILE = "debug.log"
|
||||
@@ -90,6 +79,7 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
|
||||
Note: This function is typically called at the start of the application
|
||||
to set up the logging infrastructure.
|
||||
"""
|
||||
|
||||
config = LoggingConfig()
|
||||
log_handlers: list[logging.Handler] = []
|
||||
|
||||
@@ -115,17 +105,13 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
|
||||
if config.enable_cloud_logging or force_cloud_logging:
|
||||
import google.cloud.logging
|
||||
from google.cloud.logging.handlers import CloudLoggingHandler
|
||||
from google.cloud.logging_v2.handlers.transports import (
|
||||
BackgroundThreadTransport,
|
||||
)
|
||||
from google.cloud.logging_v2.handlers.transports.sync import SyncTransport
|
||||
|
||||
client = google.cloud.logging.Client()
|
||||
# Use BackgroundThreadTransport to prevent blocking the main thread
|
||||
# and deadlocks when gRPC calls to Google Cloud Logging hang
|
||||
cloud_handler = CloudLoggingHandler(
|
||||
client,
|
||||
name="autogpt_logs",
|
||||
transport=BackgroundThreadTransport,
|
||||
transport=SyncTransport,
|
||||
)
|
||||
cloud_handler.setLevel(config.level)
|
||||
log_handlers.append(cloud_handler)
|
||||
|
||||
@@ -1,5 +1,39 @@
import logging
import re
from typing import Any

import uvicorn.config
from colorama import Fore


def remove_color_codes(s: str) -> str:
    return re.sub(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])", "", s)


def fmt_kwargs(kwargs: dict) -> str:
    return ", ".join(f"{n}={repr(v)}" for n, v in kwargs.items())


def print_attribute(
    title: str, value: Any, title_color: str = Fore.GREEN, value_color: str = ""
) -> None:
    logger = logging.getLogger()
    logger.info(
        str(value),
        extra={
            "title": f"{title.rstrip(':')}:",
            "title_color": title_color,
            "color": value_color,
        },
    )


def generate_uvicorn_config():
    """
    Generates a uvicorn logging config that silences uvicorn's default logging and tells it to use the native logging module.
    """
    log_config = dict(uvicorn.config.LOGGING_CONFIG)
    log_config["loggers"]["uvicorn"] = {"handlers": []}
    log_config["loggers"]["uvicorn.error"] = {"handlers": []}
    log_config["loggers"]["uvicorn.access"] = {"handlers": []}
    return log_config
```
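A small usage sketch for the helpers above, assuming the code lives alongside them in the same module; the app import string is a placeholder, and port 8006 is the backend port mentioned elsewhere in this diff.

```python
import uvicorn

# remove_color_codes is handy when asserting on captured log output in tests:
assert remove_color_codes("\x1b[32mOK\x1b[0m") == "OK"

if __name__ == "__main__":
    # Hand uvicorn the silenced logging config so request/error logs flow through
    # the application's own logging handlers instead of uvicorn's defaults.
    uvicorn.run("backend.app:app", port=8006, log_config=generate_uvicorn_config())  # app path is illustrative
```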
@@ -1,34 +1,17 @@
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from functools import wraps
|
||||
from typing import (
|
||||
Awaitable,
|
||||
Callable,
|
||||
ParamSpec,
|
||||
Protocol,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
cast,
|
||||
overload,
|
||||
runtime_checkable,
|
||||
)
|
||||
from typing import Awaitable, Callable, ParamSpec, TypeVar, cast, overload
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@overload
|
||||
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
||||
pass
|
||||
|
||||
|
||||
@overload
|
||||
def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
|
||||
pass
|
||||
def thread_cached(func: Callable[P, R]) -> Callable[P, R]: ...
|
||||
|
||||
|
||||
def thread_cached(
|
||||
@@ -74,193 +57,3 @@ def thread_cached(
|
||||
def clear_thread_cache(func: Callable) -> None:
|
||||
if clear := getattr(func, "clear_cache", None):
|
||||
clear()
|
||||
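A brief usage sketch for `thread_cached` and `clear_thread_cache` as declared above. The comments describe the expected semantics of a per-thread cache; the exact behavior lives in the function body this hunk truncates, and `clear_thread_cache` assumes the wrapper exposes a `clear_cache` attribute.

```python
import threading

@thread_cached
def load_settings() -> dict:
    print("loading settings...")          # expected to run once per thread per argument set
    return {"env": "dev"}

load_settings()                            # first call in this thread: executes the body
load_settings()                            # second call: served from the per-thread cache

threading.Thread(target=load_settings).start()  # a new thread gets its own cache entry

clear_thread_cache(load_settings)          # drop this function's cached values
load_settings()                            # executes the body again
```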
|
||||
|
||||
FuncT = TypeVar("FuncT")
|
||||
|
||||
|
||||
R_co = TypeVar("R_co", covariant=True)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class AsyncCachedFunction(Protocol[P, R_co]):
|
||||
"""Protocol for async functions with cache management methods."""
|
||||
|
||||
def cache_clear(self) -> None:
|
||||
"""Clear all cached entries."""
|
||||
return None
|
||||
|
||||
def cache_info(self) -> dict[str, int | None]:
|
||||
"""Get cache statistics."""
|
||||
return {}
|
||||
|
||||
async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co:
|
||||
"""Call the cached function."""
|
||||
return None # type: ignore
|
||||
|
||||
|
||||
def async_ttl_cache(
|
||||
maxsize: int = 128, ttl_seconds: int | None = None
|
||||
) -> Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]:
|
||||
"""
|
||||
TTL (Time To Live) cache decorator for async functions.
|
||||
|
||||
Similar to functools.lru_cache but works with async functions and includes optional TTL.
|
||||
|
||||
Args:
|
||||
maxsize: Maximum number of cached entries
|
||||
ttl_seconds: Time to live in seconds. If None, entries never expire (like lru_cache)
|
||||
|
||||
Returns:
|
||||
Decorator function
|
||||
|
||||
Example:
|
||||
# With TTL
|
||||
@async_ttl_cache(maxsize=1000, ttl_seconds=300)
|
||||
async def api_call(param: str) -> dict:
|
||||
return {"result": param}
|
||||
|
||||
# Without TTL (permanent cache like lru_cache)
|
||||
@async_ttl_cache(maxsize=1000)
|
||||
async def expensive_computation(param: str) -> dict:
|
||||
return {"result": param}
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
async_func: Callable[P, Awaitable[R]],
|
||||
) -> AsyncCachedFunction[P, R]:
|
||||
# Cache storage - use union type to handle both cases
|
||||
cache_storage: dict[tuple, R | Tuple[R, float]] = {}
|
||||
|
||||
@wraps(async_func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
# Create cache key from arguments
|
||||
key = (args, tuple(sorted(kwargs.items())))
|
||||
current_time = time.time()
|
||||
|
||||
# Check if we have a valid cached entry
|
||||
if key in cache_storage:
|
||||
if ttl_seconds is None:
|
||||
# No TTL - return cached result directly
|
||||
logger.debug(
|
||||
f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
|
||||
)
|
||||
return cast(R, cache_storage[key])
|
||||
else:
|
||||
# With TTL - check expiration
|
||||
cached_data = cache_storage[key]
|
||||
if isinstance(cached_data, tuple):
|
||||
result, timestamp = cached_data
|
||||
if current_time - timestamp < ttl_seconds:
|
||||
logger.debug(
|
||||
f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
|
||||
)
|
||||
return cast(R, result)
|
||||
else:
|
||||
# Expired entry
|
||||
del cache_storage[key]
|
||||
logger.debug(
|
||||
f"Cache entry expired for {async_func.__name__}"
|
||||
)
|
||||
|
||||
# Cache miss or expired - fetch fresh data
|
||||
logger.debug(
|
||||
f"Cache miss for {async_func.__name__} with key: {str(key)[:50]}"
|
||||
)
|
||||
result = await async_func(*args, **kwargs)
|
||||
|
||||
# Store in cache
|
||||
if ttl_seconds is None:
|
||||
cache_storage[key] = result
|
||||
else:
|
||||
cache_storage[key] = (result, current_time)
|
||||
|
||||
# Simple cleanup when cache gets too large
|
||||
if len(cache_storage) > maxsize:
|
||||
# Remove oldest entries (simple FIFO cleanup)
|
||||
cutoff = maxsize // 2
|
||||
oldest_keys = list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
|
||||
for old_key in oldest_keys:
|
||||
cache_storage.pop(old_key, None)
|
||||
logger.debug(
|
||||
f"Cache cleanup: removed {len(oldest_keys)} entries for {async_func.__name__}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
# Add cache management methods (similar to functools.lru_cache)
|
||||
def cache_clear() -> None:
|
||||
cache_storage.clear()
|
||||
|
||||
def cache_info() -> dict[str, int | None]:
|
||||
return {
|
||||
"size": len(cache_storage),
|
||||
"maxsize": maxsize,
|
||||
"ttl_seconds": ttl_seconds,
|
||||
}
|
||||
|
||||
# Attach methods to wrapper
|
||||
setattr(wrapper, "cache_clear", cache_clear)
|
||||
setattr(wrapper, "cache_info", cache_info)
|
||||
|
||||
return cast(AsyncCachedFunction[P, R], wrapper)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@overload
|
||||
def async_cache(
|
||||
func: Callable[P, Awaitable[R]],
|
||||
) -> AsyncCachedFunction[P, R]:
|
||||
pass
|
||||
|
||||
|
||||
@overload
|
||||
def async_cache(
|
||||
func: None = None,
|
||||
*,
|
||||
maxsize: int = 128,
|
||||
) -> Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]:
|
||||
pass
|
||||
|
||||
|
||||
def async_cache(
|
||||
func: Callable[P, Awaitable[R]] | None = None,
|
||||
*,
|
||||
maxsize: int = 128,
|
||||
) -> (
|
||||
AsyncCachedFunction[P, R]
|
||||
| Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]
|
||||
):
|
||||
"""
|
||||
Process-level cache decorator for async functions (no TTL).
|
||||
|
||||
Similar to functools.lru_cache but works with async functions.
|
||||
This is a convenience wrapper around async_ttl_cache with ttl_seconds=None.
|
||||
|
||||
Args:
|
||||
func: The async function to cache (when used without parentheses)
|
||||
maxsize: Maximum number of cached entries
|
||||
|
||||
Returns:
|
||||
Decorated function or decorator
|
||||
|
||||
Example:
|
||||
# Without parentheses (uses default maxsize=128)
|
||||
@async_cache
|
||||
async def get_data(param: str) -> dict:
|
||||
return {"result": param}
|
||||
|
||||
# With parentheses and custom maxsize
|
||||
@async_cache(maxsize=1000)
|
||||
async def expensive_computation(param: str) -> dict:
|
||||
# Expensive computation here
|
||||
return {"result": param}
|
||||
"""
|
||||
if func is None:
|
||||
# Called with parentheses @async_cache() or @async_cache(maxsize=...)
|
||||
return async_ttl_cache(maxsize=maxsize, ttl_seconds=None)
|
||||
else:
|
||||
# Called without parentheses @async_cache
|
||||
decorator = async_ttl_cache(maxsize=maxsize, ttl_seconds=None)
|
||||
return decorator(func)
|
||||
|
||||
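A short usage sketch tying the two decorators together, importing them from `autogpt_libs.utils.cache` as the test module below does; the function names and values here are made up:

```python
import asyncio

from autogpt_libs.utils.cache import async_cache, async_ttl_cache


@async_ttl_cache(maxsize=256, ttl_seconds=300)
async def fetch_profile(user_id: str) -> dict:
    await asyncio.sleep(0.01)  # stand-in for a network call
    return {"user_id": user_id}


@async_cache  # permanent process-level cache, no TTL
async def load_settings(env: str) -> dict:
    return {"env": env}


async def main() -> None:
    await fetch_profile("u1")          # miss -> executes the function
    await fetch_profile("u1")          # hit  -> served from cache
    print(fetch_profile.cache_info())  # {'size': 1, 'maxsize': 256, 'ttl_seconds': 300}
    fetch_profile.cache_clear()        # drop all entries


asyncio.run(main())
```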
@@ -1,705 +0,0 @@
|
||||
"""Tests for the @thread_cached decorator.
|
||||
|
||||
This module tests the thread-local caching functionality including:
|
||||
- Basic caching for sync and async functions
|
||||
- Thread isolation (each thread has its own cache)
|
||||
- Cache clearing functionality
|
||||
- Exception handling (exceptions are not cached)
|
||||
- Argument handling (positional vs keyword arguments)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from autogpt_libs.utils.cache import (
|
||||
async_cache,
|
||||
async_ttl_cache,
|
||||
clear_thread_cache,
|
||||
thread_cached,
|
||||
)
|
||||
|
||||
|
||||
class TestThreadCached:
|
||||
def test_sync_function_caching(self):
|
||||
call_count = 0
|
||||
|
||||
@thread_cached
|
||||
def expensive_function(x: int, y: int = 0) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x + y
|
||||
|
||||
assert expensive_function(1, 2) == 3
|
||||
assert call_count == 1
|
||||
|
||||
assert expensive_function(1, 2) == 3
|
||||
assert call_count == 1
|
||||
|
||||
assert expensive_function(1, y=2) == 3
|
||||
assert call_count == 2
|
||||
|
||||
assert expensive_function(2, 3) == 5
|
||||
assert call_count == 3
|
||||
|
||||
assert expensive_function(1) == 1
|
||||
assert call_count == 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_function_caching(self):
|
||||
call_count = 0
|
||||
|
||||
@thread_cached
|
||||
async def expensive_async_function(x: int, y: int = 0) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
return x + y
|
||||
|
||||
assert await expensive_async_function(1, 2) == 3
|
||||
assert call_count == 1
|
||||
|
||||
assert await expensive_async_function(1, 2) == 3
|
||||
assert call_count == 1
|
||||
|
||||
assert await expensive_async_function(1, y=2) == 3
|
||||
assert call_count == 2
|
||||
|
||||
assert await expensive_async_function(2, 3) == 5
|
||||
assert call_count == 3
|
||||
|
||||
def test_thread_isolation(self):
|
||||
call_count = 0
|
||||
results = {}
|
||||
|
||||
@thread_cached
|
||||
def thread_specific_function(x: int) -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return f"{threading.current_thread().name}-{x}"
|
||||
|
||||
def worker(thread_id: int):
|
||||
result1 = thread_specific_function(1)
|
||||
result2 = thread_specific_function(1)
|
||||
result3 = thread_specific_function(2)
|
||||
results[thread_id] = (result1, result2, result3)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=3) as executor:
|
||||
futures = [executor.submit(worker, i) for i in range(3)]
|
||||
for future in futures:
|
||||
future.result()
|
||||
|
||||
assert call_count >= 2
|
||||
|
||||
for thread_id, (r1, r2, r3) in results.items():
|
||||
assert r1 == r2
|
||||
assert r1 != r3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_thread_isolation(self):
|
||||
call_count = 0
|
||||
results = {}
|
||||
|
||||
@thread_cached
|
||||
async def async_thread_specific_function(x: int) -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
return f"{threading.current_thread().name}-{x}"
|
||||
|
||||
async def async_worker(worker_id: int):
|
||||
result1 = await async_thread_specific_function(1)
|
||||
result2 = await async_thread_specific_function(1)
|
||||
result3 = await async_thread_specific_function(2)
|
||||
results[worker_id] = (result1, result2, result3)
|
||||
|
||||
tasks = [async_worker(i) for i in range(3)]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
for worker_id, (r1, r2, r3) in results.items():
|
||||
assert r1 == r2
|
||||
assert r1 != r3
|
||||
|
||||
def test_clear_cache_sync(self):
|
||||
call_count = 0
|
||||
|
||||
@thread_cached
|
||||
def clearable_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x * 2
|
||||
|
||||
assert clearable_function(5) == 10
|
||||
assert call_count == 1
|
||||
|
||||
assert clearable_function(5) == 10
|
||||
assert call_count == 1
|
||||
|
||||
clear_thread_cache(clearable_function)
|
||||
|
||||
assert clearable_function(5) == 10
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_cache_async(self):
|
||||
call_count = 0
|
||||
|
||||
@thread_cached
|
||||
async def clearable_async_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
return x * 2
|
||||
|
||||
assert await clearable_async_function(5) == 10
|
||||
assert call_count == 1
|
||||
|
||||
assert await clearable_async_function(5) == 10
|
||||
assert call_count == 1
|
||||
|
||||
clear_thread_cache(clearable_async_function)
|
||||
|
||||
assert await clearable_async_function(5) == 10
|
||||
assert call_count == 2
|
||||
|
||||
def test_simple_arguments(self):
|
||||
call_count = 0
|
||||
|
||||
@thread_cached
|
||||
def simple_function(a: str, b: int, c: str = "default") -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return f"{a}-{b}-{c}"
|
||||
|
||||
# First call with all positional args
|
||||
result1 = simple_function("test", 42, "custom")
|
||||
assert call_count == 1
|
||||
|
||||
# Same args, all positional - should hit cache
|
||||
result2 = simple_function("test", 42, "custom")
|
||||
assert call_count == 1
|
||||
assert result1 == result2
|
||||
|
||||
# Same values but last arg as keyword - creates different cache key
|
||||
result3 = simple_function("test", 42, c="custom")
|
||||
assert call_count == 2
|
||||
assert result1 == result3 # Same result, different cache entry
|
||||
|
||||
# Different value - new cache entry
|
||||
result4 = simple_function("test", 43, "custom")
|
||||
assert call_count == 3
|
||||
assert result1 != result4
|
||||
|
||||
def test_positional_vs_keyword_args(self):
|
||||
"""Test that positional and keyword arguments create different cache entries."""
|
||||
call_count = 0
|
||||
|
||||
@thread_cached
|
||||
def func(a: int, b: int = 10) -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return f"result-{a}-{b}"
|
||||
|
||||
# All positional
|
||||
result1 = func(1, 2)
|
||||
assert call_count == 1
|
||||
assert result1 == "result-1-2"
|
||||
|
||||
# Same values, but second arg as keyword
|
||||
result2 = func(1, b=2)
|
||||
assert call_count == 2 # Different cache key!
|
||||
assert result2 == "result-1-2" # Same result
|
||||
|
||||
# Verify both are cached separately
|
||||
func(1, 2) # Uses first cache entry
|
||||
assert call_count == 2
|
||||
|
||||
func(1, b=2) # Uses second cache entry
|
||||
assert call_count == 2
|
||||
|
||||
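The behaviour exercised above follows from the cache key shape used by these decorators, `(args, tuple(sorted(kwargs.items())))`; a tiny standalone illustration:

```python
def make_key(args: tuple, kwargs: dict) -> tuple:
    # Positional and keyword arguments are hashed separately,
    # so func(1, 2) and func(1, b=2) produce different keys.
    return (args, tuple(sorted(kwargs.items())))


assert make_key((1, 2), {}) != make_key((1,), {"b": 2})
assert make_key((1,), {"b": 2}) == make_key((1,), {"b": 2})
```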
def test_exception_handling(self):
|
||||
call_count = 0
|
||||
|
||||
@thread_cached
|
||||
def failing_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if x < 0:
|
||||
raise ValueError("Negative value")
|
||||
return x * 2
|
||||
|
||||
assert failing_function(5) == 10
|
||||
assert call_count == 1
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
failing_function(-1)
|
||||
assert call_count == 2
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
failing_function(-1)
|
||||
assert call_count == 3
|
||||
|
||||
assert failing_function(5) == 10
|
||||
assert call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_exception_handling(self):
|
||||
call_count = 0
|
||||
|
||||
@thread_cached
|
||||
async def async_failing_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
if x < 0:
|
||||
raise ValueError("Negative value")
|
||||
return x * 2
|
||||
|
||||
assert await async_failing_function(5) == 10
|
||||
assert call_count == 1
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await async_failing_function(-1)
|
||||
assert call_count == 2
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await async_failing_function(-1)
|
||||
assert call_count == 3
|
||||
|
||||
def test_sync_caching_performance(self):
|
||||
@thread_cached
|
||||
def slow_function(x: int) -> int:
|
||||
print(f"slow_function called with x={x}")
|
||||
time.sleep(0.1)
|
||||
return x * 2
|
||||
|
||||
start = time.time()
|
||||
result1 = slow_function(5)
|
||||
first_call_time = time.time() - start
|
||||
print(f"First call took {first_call_time:.4f} seconds")
|
||||
|
||||
start = time.time()
|
||||
result2 = slow_function(5)
|
||||
second_call_time = time.time() - start
|
||||
print(f"Second call took {second_call_time:.4f} seconds")
|
||||
|
||||
assert result1 == result2 == 10
|
||||
assert first_call_time > 0.09
|
||||
assert second_call_time < 0.01
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_caching_performance(self):
|
||||
@thread_cached
|
||||
async def slow_async_function(x: int) -> int:
|
||||
print(f"slow_async_function called with x={x}")
|
||||
await asyncio.sleep(0.1)
|
||||
return x * 2
|
||||
|
||||
start = time.time()
|
||||
result1 = await slow_async_function(5)
|
||||
first_call_time = time.time() - start
|
||||
print(f"First async call took {first_call_time:.4f} seconds")
|
||||
|
||||
start = time.time()
|
||||
result2 = await slow_async_function(5)
|
||||
second_call_time = time.time() - start
|
||||
print(f"Second async call took {second_call_time:.4f} seconds")
|
||||
|
||||
assert result1 == result2 == 10
|
||||
assert first_call_time > 0.09
|
||||
assert second_call_time < 0.01
|
||||
|
||||
def test_with_mock_objects(self):
|
||||
mock = Mock(return_value=42)
|
||||
|
||||
@thread_cached
|
||||
def function_using_mock(x: int) -> int:
|
||||
return mock(x)
|
||||
|
||||
assert function_using_mock(1) == 42
|
||||
assert mock.call_count == 1
|
||||
|
||||
assert function_using_mock(1) == 42
|
||||
assert mock.call_count == 1
|
||||
|
||||
assert function_using_mock(2) == 42
|
||||
assert mock.call_count == 2
|
||||
|
||||
|
||||
class TestAsyncTTLCache:
|
||||
"""Tests for the @async_ttl_cache decorator."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_caching(self):
|
||||
"""Test basic caching functionality."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def cached_function(x: int, y: int = 0) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01) # Simulate async work
|
||||
return x + y
|
||||
|
||||
# First call
|
||||
result1 = await cached_function(1, 2)
|
||||
assert result1 == 3
|
||||
assert call_count == 1
|
||||
|
||||
# Second call with same args - should use cache
|
||||
result2 = await cached_function(1, 2)
|
||||
assert result2 == 3
|
||||
assert call_count == 1 # No additional call
|
||||
|
||||
# Different args - should call function again
|
||||
result3 = await cached_function(2, 3)
|
||||
assert result3 == 5
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ttl_expiration(self):
|
||||
"""Test that cache entries expire after TTL."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=1) # Short TTL
|
||||
async def short_lived_cache(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x * 2
|
||||
|
||||
# First call
|
||||
result1 = await short_lived_cache(5)
|
||||
assert result1 == 10
|
||||
assert call_count == 1
|
||||
|
||||
# Second call immediately - should use cache
|
||||
result2 = await short_lived_cache(5)
|
||||
assert result2 == 10
|
||||
assert call_count == 1
|
||||
|
||||
# Wait for TTL to expire
|
||||
await asyncio.sleep(1.1)
|
||||
|
||||
# Third call after expiration - should call function again
|
||||
result3 = await short_lived_cache(5)
|
||||
assert result3 == 10
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_info(self):
|
||||
"""Test cache info functionality."""
|
||||
|
||||
@async_ttl_cache(maxsize=5, ttl_seconds=300)
|
||||
async def info_test_function(x: int) -> int:
|
||||
return x * 3
|
||||
|
||||
# Check initial cache info
|
||||
info = info_test_function.cache_info()
|
||||
assert info["size"] == 0
|
||||
assert info["maxsize"] == 5
|
||||
assert info["ttl_seconds"] == 300
|
||||
|
||||
# Add an entry
|
||||
await info_test_function(1)
|
||||
info = info_test_function.cache_info()
|
||||
assert info["size"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_clear(self):
|
||||
"""Test cache clearing functionality."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def clearable_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x * 4
|
||||
|
||||
# First call
|
||||
result1 = await clearable_function(2)
|
||||
assert result1 == 8
|
||||
assert call_count == 1
|
||||
|
||||
# Second call - should use cache
|
||||
result2 = await clearable_function(2)
|
||||
assert result2 == 8
|
||||
assert call_count == 1
|
||||
|
||||
# Clear cache
|
||||
clearable_function.cache_clear()
|
||||
|
||||
# Third call after clear - should call function again
|
||||
result3 = await clearable_function(2)
|
||||
assert result3 == 8
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maxsize_cleanup(self):
|
||||
"""Test that cache cleans up when maxsize is exceeded."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=3, ttl_seconds=60)
|
||||
async def size_limited_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x**2
|
||||
|
||||
# Fill cache to maxsize
|
||||
await size_limited_function(1) # call_count: 1
|
||||
await size_limited_function(2) # call_count: 2
|
||||
await size_limited_function(3) # call_count: 3
|
||||
|
||||
info = size_limited_function.cache_info()
|
||||
assert info["size"] == 3
|
||||
|
||||
# Add one more entry - should trigger cleanup
|
||||
await size_limited_function(4) # call_count: 4
|
||||
|
||||
# Cache size should be reduced (cleanup removes oldest entries)
|
||||
info = size_limited_function.cache_info()
|
||||
assert info["size"] is not None and info["size"] <= 3 # Should be cleaned up
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_argument_variations(self):
|
||||
"""Test caching with different argument patterns."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def arg_test_function(a: int, b: str = "default", *, c: int = 100) -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return f"{a}-{b}-{c}"
|
||||
|
||||
# Different ways to call with same logical arguments
|
||||
result1 = await arg_test_function(1, "test", c=200)
|
||||
assert call_count == 1
|
||||
|
||||
# Same arguments, same order - should use cache
|
||||
result2 = await arg_test_function(1, "test", c=200)
|
||||
assert call_count == 1
|
||||
assert result1 == result2
|
||||
|
||||
# Different arguments - should call function
|
||||
result3 = await arg_test_function(2, "test", c=200)
|
||||
assert call_count == 2
|
||||
assert result1 != result3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_handling(self):
|
||||
"""Test that exceptions are not cached."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def exception_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if x < 0:
|
||||
raise ValueError("Negative value not allowed")
|
||||
return x * 2
|
||||
|
||||
# Successful call - should be cached
|
||||
result1 = await exception_function(5)
|
||||
assert result1 == 10
|
||||
assert call_count == 1
|
||||
|
||||
# Same successful call - should use cache
|
||||
result2 = await exception_function(5)
|
||||
assert result2 == 10
|
||||
assert call_count == 1
|
||||
|
||||
# Exception call - should not be cached
|
||||
with pytest.raises(ValueError):
|
||||
await exception_function(-1)
|
||||
assert call_count == 2
|
||||
|
||||
# Same exception call - should call again (not cached)
|
||||
with pytest.raises(ValueError):
|
||||
await exception_function(-1)
|
||||
assert call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_calls(self):
|
||||
"""Test caching behavior with concurrent calls."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def concurrent_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.05) # Simulate work
|
||||
return x * x
|
||||
|
||||
# Launch concurrent calls with same arguments
|
||||
tasks = [concurrent_function(3) for _ in range(5)]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# All results should be the same
|
||||
assert all(result == 9 for result in results)
|
||||
|
||||
# Note: Due to race conditions, call_count might be up to 5 for concurrent calls
|
||||
# This tests that the cache doesn't break under concurrent access
|
||||
assert 1 <= call_count <= 5
|
||||
|
||||
|
||||
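As the comment above notes, the decorator does not deduplicate in-flight calls. Where that matters, a common pattern (not part of this library) is to serialise callers per key with an `asyncio.Lock`; a hedged sketch layered on top of `async_cache`:

```python
import asyncio
from collections import defaultdict

from autogpt_libs.utils.cache import async_cache

_locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)


@async_cache(maxsize=64)
async def expensive(key: str) -> int:
    await asyncio.sleep(0.05)  # simulated work
    return len(key)


async def expensive_single_flight(key: str) -> int:
    # Concurrent callers for the same key wait on one lock, so only the first
    # one executes the body on a cache miss; the rest are served from cache.
    async with _locks[key]:
        return await expensive(key)
```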
class TestAsyncCache:
|
||||
"""Tests for the @async_cache decorator (no TTL)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_caching_no_ttl(self):
|
||||
"""Test basic caching functionality without TTL."""
|
||||
call_count = 0
|
||||
|
||||
@async_cache(maxsize=10)
|
||||
async def cached_function(x: int, y: int = 0) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01) # Simulate async work
|
||||
return x + y
|
||||
|
||||
# First call
|
||||
result1 = await cached_function(1, 2)
|
||||
assert result1 == 3
|
||||
assert call_count == 1
|
||||
|
||||
# Second call with same args - should use cache
|
||||
result2 = await cached_function(1, 2)
|
||||
assert result2 == 3
|
||||
assert call_count == 1 # No additional call
|
||||
|
||||
# Third call after some time - should still use cache (no TTL)
|
||||
await asyncio.sleep(0.05)
|
||||
result3 = await cached_function(1, 2)
|
||||
assert result3 == 3
|
||||
assert call_count == 1 # Still no additional call
|
||||
|
||||
# Different args - should call function again
|
||||
result4 = await cached_function(2, 3)
|
||||
assert result4 == 5
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_ttl_vs_ttl_behavior(self):
|
||||
"""Test the difference between TTL and no-TTL caching."""
|
||||
ttl_call_count = 0
|
||||
no_ttl_call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=1) # Short TTL
|
||||
async def ttl_function(x: int) -> int:
|
||||
nonlocal ttl_call_count
|
||||
ttl_call_count += 1
|
||||
return x * 2
|
||||
|
||||
@async_cache(maxsize=10) # No TTL
|
||||
async def no_ttl_function(x: int) -> int:
|
||||
nonlocal no_ttl_call_count
|
||||
no_ttl_call_count += 1
|
||||
return x * 2
|
||||
|
||||
# First calls
|
||||
await ttl_function(5)
|
||||
await no_ttl_function(5)
|
||||
assert ttl_call_count == 1
|
||||
assert no_ttl_call_count == 1
|
||||
|
||||
# Wait for TTL to expire
|
||||
await asyncio.sleep(1.1)
|
||||
|
||||
# Second calls after TTL expiry
|
||||
await ttl_function(5) # Should call function again (TTL expired)
|
||||
await no_ttl_function(5) # Should use cache (no TTL)
|
||||
assert ttl_call_count == 2 # TTL function called again
|
||||
assert no_ttl_call_count == 1 # No-TTL function still cached
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_cache_info(self):
|
||||
"""Test cache info for no-TTL cache."""
|
||||
|
||||
@async_cache(maxsize=5)
|
||||
async def info_test_function(x: int) -> int:
|
||||
return x * 3
|
||||
|
||||
# Check initial cache info
|
||||
info = info_test_function.cache_info()
|
||||
assert info["size"] == 0
|
||||
assert info["maxsize"] == 5
|
||||
assert info["ttl_seconds"] is None # No TTL
|
||||
|
||||
# Add an entry
|
||||
await info_test_function(1)
|
||||
info = info_test_function.cache_info()
|
||||
assert info["size"] == 1
|
||||
|
||||
|
||||
class TestTTLOptional:
|
||||
"""Tests for optional TTL functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ttl_none_behavior(self):
|
||||
"""Test that ttl_seconds=None works like no TTL."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=None)
|
||||
async def no_ttl_via_none(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x**2
|
||||
|
||||
# First call
|
||||
result1 = await no_ttl_via_none(3)
|
||||
assert result1 == 9
|
||||
assert call_count == 1
|
||||
|
||||
# Wait (would expire if there was TTL)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Second call - should still use cache
|
||||
result2 = await no_ttl_via_none(3)
|
||||
assert result2 == 9
|
||||
assert call_count == 1 # No additional call
|
||||
|
||||
# Check cache info
|
||||
info = no_ttl_via_none.cache_info()
|
||||
assert info["ttl_seconds"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_options_comparison(self):
|
||||
"""Test different cache options work as expected."""
|
||||
ttl_calls = 0
|
||||
no_ttl_calls = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=1) # With TTL
|
||||
async def ttl_function(x: int) -> int:
|
||||
nonlocal ttl_calls
|
||||
ttl_calls += 1
|
||||
return x * 10
|
||||
|
||||
@async_cache(maxsize=10) # Process-level cache (no TTL)
|
||||
async def process_function(x: int) -> int:
|
||||
nonlocal no_ttl_calls
|
||||
no_ttl_calls += 1
|
||||
return x * 10
|
||||
|
||||
# Both should cache initially
|
||||
await ttl_function(3)
|
||||
await process_function(3)
|
||||
assert ttl_calls == 1
|
||||
assert no_ttl_calls == 1
|
||||
|
||||
# Immediate second calls - both should use cache
|
||||
await ttl_function(3)
|
||||
await process_function(3)
|
||||
assert ttl_calls == 1
|
||||
assert no_ttl_calls == 1
|
||||
|
||||
# Wait for TTL to expire
|
||||
await asyncio.sleep(1.1)
|
||||
|
||||
# After TTL expiry
|
||||
await ttl_function(3) # Should call function again
|
||||
await process_function(3) # Should still use cache
|
||||
assert ttl_calls == 2 # TTL cache expired, called again
|
||||
assert no_ttl_calls == 1 # Process cache never expires
|
||||
autogpt_platform/autogpt_libs/poetry.lock (generated, 1782 changes): file diff suppressed because it is too large
@@ -1,8 +1,8 @@
|
||||
[tool.poetry]
|
||||
name = "autogpt-libs"
|
||||
version = "0.2.0"
|
||||
description = "Shared libraries across AutoGPT Platform"
|
||||
authors = ["AutoGPT team <info@agpt.co>"]
|
||||
description = "Shared libraries across NextGen AutoGPT"
|
||||
authors = ["Aarushi <aarushik93@gmail.com>"]
|
||||
readme = "README.md"
|
||||
packages = [{ include = "autogpt_libs" }]
|
||||
|
||||
@@ -10,20 +10,20 @@ packages = [{ include = "autogpt_libs" }]
|
||||
python = ">=3.10,<4.0"
|
||||
colorama = "^0.4.6"
|
||||
expiringdict = "^1.2.2"
|
||||
fastapi = "^0.116.1"
|
||||
google-cloud-logging = "^3.12.1"
|
||||
launchdarkly-server-sdk = "^9.12.0"
|
||||
pydantic = "^2.11.7"
|
||||
pydantic-settings = "^2.10.1"
|
||||
pydantic = "^2.11.4"
|
||||
pydantic-settings = "^2.9.1"
|
||||
pyjwt = "^2.10.1"
|
||||
pytest-asyncio = "^1.1.0"
|
||||
pytest-mock = "^3.14.1"
|
||||
redis = "^6.2.0"
|
||||
supabase = "^2.16.0"
|
||||
uvicorn = "^0.35.0"
|
||||
pytest-asyncio = "^0.26.0"
|
||||
pytest-mock = "^3.14.0"
|
||||
supabase = "^2.15.1"
|
||||
launchdarkly-server-sdk = "^9.11.1"
|
||||
fastapi = "^0.115.12"
|
||||
uvicorn = "^0.34.3"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
ruff = "^0.12.3"
|
||||
redis = "^5.2.1"
|
||||
ruff = "^0.12.2"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
# Development and testing files
|
||||
**/__pycache__
|
||||
**/*.pyc
|
||||
**/*.pyo
|
||||
**/*.pyd
|
||||
**/.Python
|
||||
**/env/
|
||||
**/venv/
|
||||
**/.venv/
|
||||
**/pip-log.txt
|
||||
**/.pytest_cache/
|
||||
**/test-results/
|
||||
**/snapshots/
|
||||
**/test/
|
||||
|
||||
# IDE and editor files
|
||||
**/.vscode/
|
||||
**/.idea/
|
||||
**/*.swp
|
||||
**/*.swo
|
||||
*~
|
||||
|
||||
# OS files
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Logs
|
||||
**/*.log
|
||||
**/logs/
|
||||
|
||||
# Git
|
||||
.git/
|
||||
.gitignore
|
||||
|
||||
# Documentation
|
||||
**/*.md
|
||||
!README.md
|
||||
|
||||
# Local development files
|
||||
.env
|
||||
.env.local
|
||||
**/.env.test
|
||||
|
||||
# Build artifacts
|
||||
**/dist/
|
||||
**/build/
|
||||
**/target/
|
||||
|
||||
# Docker files (avoid recursion)
|
||||
Dockerfile*
|
||||
docker-compose*
|
||||
.dockerignore
|
||||
@@ -1,9 +1,3 @@
|
||||
# Backend Configuration
|
||||
# This file contains environment variables that MUST be set for the AutoGPT platform
|
||||
# Variables with working defaults in settings.py are not included here
|
||||
|
||||
## ===== REQUIRED DATABASE CONFIGURATION ===== ##
|
||||
# PostgreSQL Database Connection
|
||||
DB_USER=postgres
|
||||
DB_PASS=your-super-secret-and-long-postgres-password
|
||||
DB_NAME=postgres
|
||||
@@ -16,50 +10,72 @@ DB_SCHEMA=platform
|
||||
DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
|
||||
DIRECT_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
|
||||
PRISMA_SCHEMA="postgres/schema.prisma"
|
||||
ENABLE_AUTH=true
|
||||
|
||||
## ===== REQUIRED SERVICE CREDENTIALS ===== ##
|
||||
# Redis Configuration
|
||||
# EXECUTOR
|
||||
NUM_GRAPH_WORKERS=10
|
||||
|
||||
BACKEND_CORS_ALLOW_ORIGINS=["http://localhost:3000"]
|
||||
|
||||
# generate using `from cryptography.fernet import Fernet;Fernet.generate_key().decode()`
|
||||
ENCRYPTION_KEY='dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw='
|
||||
UNSUBSCRIBE_SECRET_KEY = 'HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio='
|
||||
|
||||
REDIS_HOST=localhost
|
||||
REDIS_PORT=6379
|
||||
REDIS_PASSWORD=password
|
||||
|
||||
# RabbitMQ Credentials
|
||||
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
|
||||
RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
|
||||
ENABLE_CREDIT=false
|
||||
STRIPE_API_KEY=
|
||||
STRIPE_WEBHOOK_SECRET=
|
||||
|
||||
# Supabase Authentication
|
||||
# Which environment to report logs under: local dev or prod
|
||||
APP_ENV=local
|
||||
# What environment to behave as: "local" or "cloud"
|
||||
BEHAVE_AS=local
|
||||
PYRO_HOST=localhost
|
||||
SENTRY_DSN=
|
||||
|
||||
# Postmark credentials for sending notification emails
|
||||
POSTMARK_SERVER_API_TOKEN=
|
||||
POSTMARK_SENDER_EMAIL=invalid@invalid.com
|
||||
POSTMARK_WEBHOOK_TOKEN=
|
||||
|
||||
## User auth with Supabase is required for any of the 3rd party integrations with auth to work.
|
||||
ENABLE_AUTH=true
|
||||
SUPABASE_URL=http://localhost:8000
|
||||
SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
|
||||
SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
|
||||
## ===== REQUIRED SECURITY KEYS ===== ##
|
||||
# Generate using: from cryptography.fernet import Fernet;Fernet.generate_key().decode()
|
||||
ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
|
||||
UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
|
||||
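The one-liner referenced in the comment above, expanded into a small script that prints fresh values for both keys (assumes the `cryptography` package is installed):

```python
from cryptography.fernet import Fernet

# Each call returns a new urlsafe base64-encoded 32-byte key.
print("ENCRYPTION_KEY=" + Fernet.generate_key().decode())
print("UNSUBSCRIBE_SECRET_KEY=" + Fernet.generate_key().decode())
```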
# RabbitMQ credentials -- Used for communication between services
|
||||
RABBITMQ_HOST=localhost
|
||||
RABBITMQ_PORT=5672
|
||||
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
|
||||
RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
|
||||
|
||||
## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
|
||||
# Platform URLs (set these for webhooks and OAuth to work)
|
||||
PLATFORM_BASE_URL=http://localhost:8000
|
||||
FRONTEND_BASE_URL=http://localhost:3000
|
||||
|
||||
# Media Storage (required for marketplace and library functionality)
|
||||
## GCS bucket is required for marketplace and library functionality
|
||||
MEDIA_GCS_BUCKET_NAME=
|
||||
|
||||
## ===== API KEYS AND OAUTH CREDENTIALS ===== ##
|
||||
# All API keys below are optional - only add what you need
|
||||
## For local development, you may need to set FRONTEND_BASE_URL for the OAuth flow
|
||||
## for integrations to work. Defaults to the value of PLATFORM_BASE_URL if not set.
|
||||
# FRONTEND_BASE_URL=http://localhost:3000
|
||||
|
||||
# AI/LLM Services
|
||||
OPENAI_API_KEY=
|
||||
ANTHROPIC_API_KEY=
|
||||
GROQ_API_KEY=
|
||||
LLAMA_API_KEY=
|
||||
AIML_API_KEY=
|
||||
V0_API_KEY=
|
||||
OPEN_ROUTER_API_KEY=
|
||||
NVIDIA_API_KEY=
|
||||
## PLATFORM_BASE_URL must be set to a *publicly accessible* URL pointing to your backend
## to use the platform's webhook-related functionality.
## If you are developing locally, you can use something like ngrok to get a public URL
## and tunnel it to your locally running backend.
PLATFORM_BASE_URL=http://localhost:3000
|
||||
|
||||
## Cloudflare Turnstile (CAPTCHA) Configuration
|
||||
## Get these from the Cloudflare Turnstile dashboard: https://dash.cloudflare.com/?to=/:account/turnstile
|
||||
## This is the backend secret key
|
||||
TURNSTILE_SECRET_KEY=
|
||||
## This is the verify URL
|
||||
TURNSTILE_VERIFY_URL=https://challenges.cloudflare.com/turnstile/v0/siteverify
|
||||
|
||||
## == INTEGRATION CREDENTIALS == ##
|
||||
# Each set of server side credentials is required for the corresponding 3rd party
|
||||
# integration to work.
|
||||
|
||||
# OAuth Credentials
|
||||
# For the OAuth callback URL, use <your_frontend_url>/auth/integrations/oauth_callback,
|
||||
# e.g. http://localhost:3000/auth/integrations/oauth_callback
|
||||
|
||||
@@ -69,6 +85,7 @@ GITHUB_CLIENT_SECRET=
|
||||
|
||||
# Google OAuth app server credentials - https://console.cloud.google.com/apis/credentials; enable the Gmail API and set the required scopes
|
||||
# https://console.cloud.google.com/apis/credentials/consent ?project=<your_project_id>
|
||||
|
||||
# You'll need to add/enable the following scopes (minimum):
|
||||
# https://console.developers.google.com/apis/api/gmail.googleapis.com/overview ?project=<your_project_id>
|
||||
# https://console.cloud.google.com/apis/library/sheets.googleapis.com/ ?project=<your_project_id>
|
||||
@@ -104,66 +121,92 @@ LINEAR_CLIENT_SECRET=
|
||||
TODOIST_CLIENT_ID=
|
||||
TODOIST_CLIENT_SECRET=
|
||||
|
||||
NOTION_CLIENT_ID=
|
||||
NOTION_CLIENT_SECRET=
|
||||
## ===== OPTIONAL API KEYS ===== ##
|
||||
|
||||
# LLM
|
||||
OPENAI_API_KEY=
|
||||
ANTHROPIC_API_KEY=
|
||||
AIML_API_KEY=
|
||||
GROQ_API_KEY=
|
||||
OPEN_ROUTER_API_KEY=
|
||||
LLAMA_API_KEY=
|
||||
|
||||
# Reddit
|
||||
# Go to https://www.reddit.com/prefs/apps and create a new app
|
||||
# Choose "script" for the type
|
||||
# Fill in the redirect uri as <your_frontend_url>/auth/integrations/oauth_callback, e.g. http://localhost:3000/auth/integrations/oauth_callback
|
||||
REDDIT_CLIENT_ID=
|
||||
REDDIT_CLIENT_SECRET=
|
||||
REDDIT_USER_AGENT="AutoGPT:1.0 (by /u/autogpt)"
|
||||
|
||||
# Payment Processing
|
||||
STRIPE_API_KEY=
|
||||
STRIPE_WEBHOOK_SECRET=
|
||||
|
||||
# Email Service (for sending notifications and confirmations)
|
||||
POSTMARK_SERVER_API_TOKEN=
|
||||
POSTMARK_SENDER_EMAIL=invalid@invalid.com
|
||||
POSTMARK_WEBHOOK_TOKEN=
|
||||
|
||||
# Error Tracking
|
||||
SENTRY_DSN=
|
||||
|
||||
# Cloudflare Turnstile (CAPTCHA) Configuration
|
||||
# Get these from the Cloudflare Turnstile dashboard: https://dash.cloudflare.com/?to=/:account/turnstile
|
||||
# This is the backend secret key
|
||||
TURNSTILE_SECRET_KEY=
|
||||
# This is the verify URL
|
||||
TURNSTILE_VERIFY_URL=https://challenges.cloudflare.com/turnstile/v0/siteverify
|
||||
|
||||
# Feature Flags
|
||||
LAUNCH_DARKLY_SDK_KEY=
|
||||
|
||||
# Content Generation & Media
|
||||
DID_API_KEY=
|
||||
FAL_API_KEY=
|
||||
IDEOGRAM_API_KEY=
|
||||
REPLICATE_API_KEY=
|
||||
REVID_API_KEY=
|
||||
SCREENSHOTONE_API_KEY=
|
||||
UNREAL_SPEECH_API_KEY=
|
||||
|
||||
# Data & Search Services
|
||||
E2B_API_KEY=
|
||||
EXA_API_KEY=
|
||||
JINA_API_KEY=
|
||||
MEM0_API_KEY=
|
||||
OPENWEATHERMAP_API_KEY=
|
||||
GOOGLE_MAPS_API_KEY=
|
||||
|
||||
# Communication Services
|
||||
# Discord
|
||||
DISCORD_BOT_TOKEN=
|
||||
MEDIUM_API_KEY=
|
||||
MEDIUM_AUTHOR_ID=
|
||||
|
||||
# SMTP/Email
|
||||
SMTP_SERVER=
|
||||
SMTP_PORT=
|
||||
SMTP_USERNAME=
|
||||
SMTP_PASSWORD=
|
||||
|
||||
# Business & Marketing Tools
|
||||
# D-ID
|
||||
DID_API_KEY=
|
||||
|
||||
# Open Weather Map
|
||||
OPENWEATHERMAP_API_KEY=
|
||||
|
||||
# SMTP
|
||||
SMTP_SERVER=
|
||||
SMTP_PORT=
|
||||
SMTP_USERNAME=
|
||||
SMTP_PASSWORD=
|
||||
|
||||
# Medium
|
||||
MEDIUM_API_KEY=
|
||||
MEDIUM_AUTHOR_ID=
|
||||
|
||||
# Google Maps
|
||||
GOOGLE_MAPS_API_KEY=
|
||||
|
||||
# Replicate
|
||||
REPLICATE_API_KEY=
|
||||
|
||||
# Ideogram
|
||||
IDEOGRAM_API_KEY=
|
||||
|
||||
# Fal
|
||||
FAL_API_KEY=
|
||||
|
||||
# Exa
|
||||
EXA_API_KEY=
|
||||
|
||||
# E2B
|
||||
E2B_API_KEY=
|
||||
|
||||
# Mem0
|
||||
MEM0_API_KEY=
|
||||
|
||||
# Nvidia
|
||||
NVIDIA_API_KEY=
|
||||
|
||||
# Apollo
|
||||
APOLLO_API_KEY=
|
||||
ENRICHLAYER_API_KEY=
|
||||
AYRSHARE_API_KEY=
|
||||
AYRSHARE_JWT_KEY=
|
||||
|
||||
# SmartLead
|
||||
SMARTLEAD_API_KEY=
|
||||
|
||||
# ZeroBounce
|
||||
ZEROBOUNCE_API_KEY=
|
||||
|
||||
# Other Services
|
||||
AUTOMOD_API_KEY=
|
||||
## ===== OPTIONAL API KEYS END ===== ##
|
||||
|
||||
# Logging Configuration
|
||||
LOG_LEVEL=INFO
|
||||
ENABLE_CLOUD_LOGGING=false
|
||||
ENABLE_FILE_LOGGING=false
|
||||
# Use to manually set the log directory
|
||||
# LOG_DIR=./logs
|
||||
|
||||
# Example Blocks Configuration
|
||||
# Set to true to enable example blocks in development
|
||||
# These blocks are disabled by default in production
|
||||
ENABLE_EXAMPLE_BLOCKS=false
|
||||
autogpt_platform/backend/.gitignore (vendored, 1 change)
@@ -1,4 +1,3 @@
|
||||
.env
|
||||
database.db
|
||||
database.db-journal
|
||||
dev.db
|
||||
|
||||
@@ -8,14 +8,14 @@ WORKDIR /app
|
||||
|
||||
RUN echo 'Acquire::http::Pipeline-Depth 0;\nAcquire::http::No-Cache true;\nAcquire::BrokenProxy true;\n' > /etc/apt/apt.conf.d/99fixbadproxy
|
||||
|
||||
# Update package list and install build dependencies in a single layer
|
||||
RUN apt-get update --allow-releaseinfo-change --fix-missing \
|
||||
&& apt-get install -y \
|
||||
build-essential \
|
||||
libpq5 \
|
||||
libz-dev \
|
||||
libssl-dev \
|
||||
postgresql-client
|
||||
RUN apt-get update --allow-releaseinfo-change --fix-missing
|
||||
|
||||
# Install build dependencies
|
||||
RUN apt-get install -y build-essential
|
||||
RUN apt-get install -y libpq5
|
||||
RUN apt-get install -y libz-dev
|
||||
RUN apt-get install -y libssl-dev
|
||||
RUN apt-get install -y postgresql-client
|
||||
|
||||
ENV POETRY_HOME=/opt/poetry
|
||||
ENV POETRY_NO_INTERACTION=1
|
||||
@@ -68,12 +68,6 @@ COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.tom
|
||||
|
||||
WORKDIR /app/autogpt_platform/backend
|
||||
|
||||
FROM server_dependencies AS migrate
|
||||
|
||||
# Migration stage only needs schema and migrations - much lighter than full backend
|
||||
COPY autogpt_platform/backend/schema.prisma /app/autogpt_platform/backend/
|
||||
COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migrations
|
||||
|
||||
FROM server_dependencies AS server
|
||||
|
||||
COPY autogpt_platform/backend /app/autogpt_platform/backend
|
||||
|
||||
@@ -1,150 +0,0 @@
|
||||
# Test Data Scripts
|
||||
|
||||
This directory contains scripts for creating and updating test data in the AutoGPT Platform database, specifically designed to test the materialized views for the store functionality.
|
||||
|
||||
## Scripts
|
||||
|
||||
### test_data_creator.py
|
||||
Creates a comprehensive set of test data including:
|
||||
- Users with profiles
|
||||
- Agent graphs, nodes, and executions
|
||||
- Store listings with multiple versions
|
||||
- Reviews and ratings
|
||||
- Library agents
|
||||
- Integration webhooks
|
||||
- Onboarding data
|
||||
- Credit transactions
|
||||
|
||||
**Image/Video Domains Used:**
|
||||
- Images: `picsum.photos` (for all image URLs)
|
||||
- Videos: `youtube.com` (for store listing videos)
|
||||
|
||||
### test_data_updater.py
|
||||
Updates existing test data to simulate real-world changes:
|
||||
- Adds new agent graph executions
|
||||
- Creates new store listing reviews
|
||||
- Updates store listing versions
|
||||
- Adds credit transactions
|
||||
- Refreshes materialized views
|
||||
|
||||
### check_db.py
|
||||
Tests and verifies materialized views functionality:
|
||||
- Checks pg_cron job status (for automatic refresh)
|
||||
- Displays current materialized view counts
|
||||
- Adds test data (executions and reviews)
|
||||
- Creates store listings if none exist
|
||||
- Manually refreshes materialized views
|
||||
- Compares before/after counts to verify updates
|
||||
- Provides a summary of test results
|
||||
|
||||
## Materialized Views
|
||||
|
||||
The scripts test three key database views:
|
||||
|
||||
1. **mv_agent_run_counts**: Tracks execution counts by agent
|
||||
2. **mv_review_stats**: Tracks review statistics (count, average rating) by store listing
|
||||
3. **StoreAgent**: A view that combines store listing data with execution counts and ratings for display
|
||||
|
||||
The materialized views (mv_agent_run_counts and mv_review_stats) are automatically refreshed every 15 minutes via pg_cron, or can be manually refreshed using the `refresh_store_materialized_views()` function.
|
||||
|
||||
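For a rough idea of how a script can trigger that manual refresh from Python, a sketch assuming the Prisma Python client used by the backend (the scripts' actual helper names may differ):

```python
import asyncio

from prisma import Prisma


async def refresh_views() -> None:
    db = Prisma()
    await db.connect()
    try:
        # Calls the same database function used for the manual refresh.
        await db.query_raw("SELECT refresh_store_materialized_views();")
    finally:
        await db.disconnect()


asyncio.run(refresh_views())
```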
## Usage
|
||||
|
||||
### Prerequisites
|
||||
|
||||
1. Ensure the database is running:
|
||||
```bash
|
||||
docker compose up -d
|
||||
# or for test database:
|
||||
docker compose -f docker-compose.test.yaml --env-file ../.env up -d
|
||||
```
|
||||
|
||||
2. Run database migrations:
|
||||
```bash
|
||||
poetry run prisma migrate deploy
|
||||
```
|
||||
|
||||
### Running the Scripts
|
||||
|
||||
#### Option 1: Use the helper script (from backend directory)
|
||||
```bash
|
||||
poetry run python run_test_data.py
|
||||
```
|
||||
|
||||
#### Option 2: Run individually
|
||||
```bash
|
||||
# From backend/test directory:
|
||||
# Create initial test data
|
||||
poetry run python test_data_creator.py
|
||||
|
||||
# Update data to test materialized view changes
|
||||
poetry run python test_data_updater.py
|
||||
|
||||
# From backend directory:
|
||||
# Test materialized views functionality
|
||||
poetry run python check_db.py
|
||||
|
||||
# Check store data status
|
||||
poetry run python check_store_data.py
|
||||
```
|
||||
|
||||
#### Option 3: Use the shell script (from backend directory)
|
||||
```bash
|
||||
./run_test_data_scripts.sh
|
||||
```
|
||||
|
||||
### Manual Materialized View Refresh
|
||||
|
||||
To manually refresh the materialized views:
|
||||
```sql
|
||||
SELECT refresh_store_materialized_views();
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
The scripts use the database configuration from your `.env` file:
|
||||
- `DATABASE_URL`: PostgreSQL connection string
|
||||
- Database should have the platform schema
|
||||
|
||||
## Data Generation Limits
|
||||
|
||||
Configured in `test_data_creator.py`:
|
||||
- 100 users
|
||||
- 100 agent blocks
|
||||
- 1-5 graphs per user
|
||||
- 2-5 nodes per graph
|
||||
- 1-5 presets per user
|
||||
- 1-10 library agents per user
|
||||
- 1-20 executions per graph
|
||||
- 1-5 reviews per store listing version
|
||||
|
||||
## Notes
|
||||
|
||||
- All image URLs use `picsum.photos` for consistency with Next.js image configuration
|
||||
- The scripts create realistic relationships between entities
|
||||
- Materialized views are refreshed at the end of each script
|
||||
- Data is designed to test both happy paths and edge cases
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Reviews and StoreAgent view showing 0
|
||||
|
||||
If `check_db.py` shows that reviews remain at 0 and StoreAgent view shows 0 store agents:
|
||||
|
||||
1. **No store listings exist**: The script will automatically create test store listings if none exist
|
||||
2. **No approved versions**: Store listings need approved versions to appear in the StoreAgent view
|
||||
3. **Check with `check_store_data.py`**: This script provides detailed information about:
|
||||
- Total store listings
|
||||
- Store listing versions by status
|
||||
- Existing reviews
|
||||
- StoreAgent view contents
|
||||
- Agent graph executions
|
||||
|
||||
### pg_cron not installed
|
||||
|
||||
The warning "pg_cron extension is not installed" is normal in local development environments. The materialized views can still be refreshed manually using the `refresh_store_materialized_views()` function, which all scripts do automatically.
|
||||
|
||||
### Common Issues
|
||||
|
||||
- **Type errors with None values**: Fixed in the latest version of check_db.py by using `or 0` for nullable numeric fields
|
||||
- **Missing relations**: Ensure you're using the correct field names (e.g., `StoreListing` not `storeListing` in includes)
|
||||
- **Column name mismatches**: The database uses camelCase for column names (e.g., `agentGraphId` not `agent_graph_id`)
|
||||
@@ -1,10 +1,6 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.util.process import AppProcess
|
||||
|
||||
@@ -42,12 +38,12 @@ def main(**kwargs):
|
||||
from backend.server.ws_api import WebsocketServer
|
||||
|
||||
run_processes(
|
||||
DatabaseManager().set_log_level("warning"),
|
||||
DatabaseManager(),
|
||||
ExecutionManager(),
|
||||
Scheduler(),
|
||||
NotificationManager(),
|
||||
WebsocketServer(),
|
||||
AgentServer(),
|
||||
ExecutionManager(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
import functools
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.block import Block
|
||||
|
||||
@@ -103,15 +99,7 @@ def load_all_blocks() -> dict[str, type["Block"]]:
|
||||
|
||||
available_blocks[block.id] = block_cls
|
||||
|
||||
# Filter out blocks with incomplete auth configs, e.g. missing OAuth server secrets
|
||||
from backend.data.block import is_block_auth_configured
|
||||
|
||||
filtered_blocks = {}
|
||||
for block_id, block_cls in available_blocks.items():
|
||||
if is_block_auth_configured(block_cls):
|
||||
filtered_blocks[block_id] = block_cls
|
||||
|
||||
return filtered_blocks
|
||||
return available_blocks
|
||||
|
||||
|
||||
__all__ = ["load_all_blocks"]
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
@@ -13,9 +14,8 @@ from backend.data.block import (
|
||||
get_block,
|
||||
)
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.data.model import NodeExecutionStats, SchemaField
|
||||
from backend.util.json import validate_with_jsonschema
|
||||
from backend.util.retry import func_retry
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util import json, retry
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -49,7 +49,7 @@ class AgentExecutorBlock(Block):
|
||||
|
||||
@classmethod
|
||||
def get_mismatch_error(cls, data: BlockInput) -> str | None:
|
||||
return validate_with_jsonschema(cls.get_input_schema(data), data)
|
||||
return json.validate_with_jsonschema(cls.get_input_schema(data), data)
|
||||
|
||||
class Output(BlockSchema):
|
||||
pass
|
||||
@@ -74,6 +74,7 @@ class AgentExecutorBlock(Block):
|
||||
user_id=input_data.user_id,
|
||||
inputs=input_data.inputs,
|
||||
nodes_input_masks=input_data.nodes_input_masks,
|
||||
use_db_query=False,
|
||||
)
|
||||
|
||||
logger = execution_utils.LogMetadata(
|
||||
@@ -95,14 +96,23 @@ class AgentExecutorBlock(Block):
|
||||
logger=logger,
|
||||
):
|
||||
yield name, data
|
||||
except BaseException as e:
|
||||
except asyncio.CancelledError:
|
||||
await self._stop(
|
||||
graph_exec_id=graph_exec.id,
|
||||
user_id=input_data.user_id,
|
||||
logger=logger,
|
||||
)
|
||||
logger.warning(
|
||||
f"Execution of graph {input_data.graph_id}v{input_data.graph_version} failed: {e.__class__.__name__} {str(e)}; execution is stopped."
|
||||
f"Execution of graph {input_data.graph_id}v{input_data.graph_version} was cancelled."
|
||||
)
|
||||
except Exception as e:
|
||||
await self._stop(
|
||||
graph_exec_id=graph_exec.id,
|
||||
user_id=input_data.user_id,
|
||||
logger=logger,
|
||||
)
|
||||
logger.error(
|
||||
f"Execution of graph {input_data.graph_id}v{input_data.graph_version} failed: {e}, execution is stopped."
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -122,7 +132,6 @@ class AgentExecutorBlock(Block):
|
||||
|
||||
log_id = f"Graph #{graph_id}-V{graph_version}, exec-id: {graph_exec_id}"
|
||||
logger.info(f"Starting execution of {log_id}")
|
||||
yielded_node_exec_ids = set()
|
||||
|
||||
async for event in event_bus.listen(
|
||||
user_id=user_id,
|
||||
@@ -142,26 +151,12 @@ class AgentExecutorBlock(Block):
|
||||
if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
|
||||
# If the graph execution is COMPLETED, TERMINATED, or FAILED,
|
||||
# we can stop listening for further events.
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
extra_cost=event.stats.cost if event.stats else 0,
|
||||
extra_steps=event.stats.node_exec_count if event.stats else 0,
|
||||
)
|
||||
)
|
||||
break
|
||||
|
||||
logger.debug(
|
||||
f"Execution {log_id} produced input {event.input_data} output {event.output_data}"
|
||||
)
|
||||
|
||||
if event.node_exec_id in yielded_node_exec_ids:
|
||||
logger.warning(
|
||||
f"{log_id} received duplicate event for node execution {event.node_exec_id}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
yielded_node_exec_ids.add(event.node_exec_id)
|
||||
|
||||
if not event.block_id:
|
||||
logger.warning(f"{log_id} received event without block_id {event}")
|
||||
continue
|
||||
@@ -181,7 +176,7 @@ class AgentExecutorBlock(Block):
|
||||
)
|
||||
yield output_name, output_data
|
||||
|
||||
@func_retry
|
||||
@retry.func_retry
|
||||
async def _stop(
|
||||
self,
|
||||
graph_exec_id: str,
|
||||
@@ -197,8 +192,8 @@ class AgentExecutorBlock(Block):
|
||||
await execution_utils.stop_graph_execution(
|
||||
graph_exec_id=graph_exec_id,
|
||||
user_id=user_id,
|
||||
wait_timeout=3600,
|
||||
use_db_query=False,
|
||||
)
|
||||
logger.info(f"Execution {log_id} stopped successfully.")
|
||||
except TimeoutError as e:
|
||||
logger.error(f"Execution {log_id} stop timed out: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to stop execution {log_id}: {e}")
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
Airtable integration for AutoGPT Platform.
|
||||
|
||||
This integration provides comprehensive access to the Airtable Web API,
|
||||
including:
|
||||
- Webhook triggers and management
|
||||
- Record CRUD operations
|
||||
- Attachment uploads
|
||||
- Schema and table management
|
||||
- Metadata operations
|
||||
"""
|
||||
|
||||
# Attachments
|
||||
from .attachments import AirtableUploadAttachmentBlock
|
||||
|
||||
# Metadata
|
||||
from .metadata import (
|
||||
AirtableGetViewBlock,
|
||||
AirtableListBasesBlock,
|
||||
AirtableListViewsBlock,
|
||||
)
|
||||
|
||||
# Record Operations
|
||||
from .records import (
|
||||
AirtableCreateRecordsBlock,
|
||||
AirtableDeleteRecordsBlock,
|
||||
AirtableGetRecordBlock,
|
||||
AirtableListRecordsBlock,
|
||||
AirtableUpdateRecordsBlock,
|
||||
AirtableUpsertRecordsBlock,
|
||||
)
|
||||
|
||||
# Schema & Table Management
|
||||
from .schema import (
|
||||
AirtableAddFieldBlock,
|
||||
AirtableCreateTableBlock,
|
||||
AirtableDeleteFieldBlock,
|
||||
AirtableDeleteTableBlock,
|
||||
AirtableListSchemaBlock,
|
||||
AirtableUpdateFieldBlock,
|
||||
AirtableUpdateTableBlock,
|
||||
)
|
||||
|
||||
# Webhook Triggers
|
||||
from .triggers import AirtableWebhookTriggerBlock
|
||||
|
||||
# Webhook Management
|
||||
from .webhooks import (
|
||||
AirtableCreateWebhookBlock,
|
||||
AirtableDeleteWebhookBlock,
|
||||
AirtableFetchWebhookPayloadsBlock,
|
||||
AirtableRefreshWebhookBlock,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Webhook Triggers
|
||||
"AirtableWebhookTriggerBlock",
|
||||
# Webhook Management
|
||||
"AirtableCreateWebhookBlock",
|
||||
"AirtableDeleteWebhookBlock",
|
||||
"AirtableFetchWebhookPayloadsBlock",
|
||||
"AirtableRefreshWebhookBlock",
|
||||
# Record Operations
|
||||
"AirtableCreateRecordsBlock",
|
||||
"AirtableDeleteRecordsBlock",
|
||||
"AirtableGetRecordBlock",
|
||||
"AirtableListRecordsBlock",
|
||||
"AirtableUpdateRecordsBlock",
|
||||
"AirtableUpsertRecordsBlock",
|
||||
# Attachments
|
||||
"AirtableUploadAttachmentBlock",
|
||||
# Schema & Table Management
|
||||
"AirtableAddFieldBlock",
|
||||
"AirtableCreateTableBlock",
|
||||
"AirtableDeleteFieldBlock",
|
||||
"AirtableDeleteTableBlock",
|
||||
"AirtableListSchemaBlock",
|
||||
"AirtableUpdateFieldBlock",
|
||||
"AirtableUpdateTableBlock",
|
||||
# Metadata
|
||||
"AirtableGetViewBlock",
|
||||
"AirtableListBasesBlock",
|
||||
"AirtableListViewsBlock",
|
||||
]
|
||||
|
||||
File diff suppressed because it is too large
@@ -1,323 +0,0 @@
|
||||
from os import getenv
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.sdk import APIKeyCredentials, SecretStr
|
||||
|
||||
from ._api import (
|
||||
TableFieldType,
|
||||
WebhookFilters,
|
||||
WebhookSpecification,
|
||||
create_base,
|
||||
create_field,
|
||||
create_record,
|
||||
create_table,
|
||||
create_webhook,
|
||||
delete_multiple_records,
|
||||
delete_record,
|
||||
delete_webhook,
|
||||
get_record,
|
||||
list_bases,
|
||||
list_records,
|
||||
list_webhook_payloads,
|
||||
update_field,
|
||||
update_multiple_records,
|
||||
update_record,
|
||||
update_table,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_update_table():
|
||||
|
||||
key = getenv("AIRTABLE_API_KEY")
|
||||
if not key:
|
||||
return pytest.skip("AIRTABLE_API_KEY is not set")
|
||||
|
||||
credentials = APIKeyCredentials(
|
||||
provider="airtable",
|
||||
api_key=SecretStr(key),
|
||||
)
|
||||
postfix = uuid4().hex[:4]
|
||||
workspace_id = "wsphuHmfllg7V3Brd"
|
||||
response = await create_base(credentials, workspace_id, "API Testing Base")
|
||||
assert response is not None, f"Checking create base response: {response}"
|
||||
assert (
|
||||
response.get("id") is not None
|
||||
), f"Checking create base response id: {response}"
|
||||
base_id = response.get("id")
|
||||
assert base_id is not None, f"Checking create base response id: {base_id}"
|
||||
|
||||
response = await list_bases(credentials)
|
||||
assert response is not None, f"Checking list bases response: {response}"
|
||||
assert "API Testing Base" in [
|
||||
base.get("name") for base in response.get("bases", [])
|
||||
], f"Checking list bases response bases: {response}"
|
||||
|
||||
table_name = f"test_table_{postfix}"
|
||||
table_fields = [{"name": "test_field", "type": "singleLineText"}]
|
||||
table = await create_table(credentials, base_id, table_name, table_fields)
|
||||
assert table.get("name") == table_name
|
||||
|
||||
table_id = table.get("id")
|
||||
|
||||
assert table_id is not None
|
||||
|
||||
table_name = f"test_table_updated_{postfix}"
|
||||
table_description = "test_description_updated"
|
||||
table = await update_table(
|
||||
credentials,
|
||||
base_id,
|
||||
table_id,
|
||||
table_name=table_name,
|
||||
table_description=table_description,
|
||||
)
|
||||
assert table.get("name") == table_name
|
||||
assert table.get("description") == table_description
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_field_type():
|
||||
|
||||
key = getenv("AIRTABLE_API_KEY")
|
||||
if not key:
|
||||
return pytest.skip("AIRTABLE_API_KEY is not set")
|
||||
|
||||
credentials = APIKeyCredentials(
|
||||
provider="airtable",
|
||||
api_key=SecretStr(key),
|
||||
)
|
||||
postfix = uuid4().hex[:4]
|
||||
base_id = "appZPxegHEU3kDc1S"
|
||||
table_name = f"test_table_{postfix}"
|
||||
table_fields = [{"name": "test_field", "type": "notValid"}]
|
||||
with pytest.raises(AssertionError):
|
||||
await create_table(credentials, base_id, table_name, table_fields)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_and_update_field():
|
||||
key = getenv("AIRTABLE_API_KEY")
|
||||
if not key:
|
||||
return pytest.skip("AIRTABLE_API_KEY is not set")
|
||||
|
||||
credentials = APIKeyCredentials(
|
||||
provider="airtable",
|
||||
api_key=SecretStr(key),
|
||||
)
|
||||
postfix = uuid4().hex[:4]
|
||||
base_id = "appZPxegHEU3kDc1S"
|
||||
table_name = f"test_table_{postfix}"
|
||||
table_fields = [{"name": "test_field", "type": "singleLineText"}]
|
||||
table = await create_table(credentials, base_id, table_name, table_fields)
|
||||
assert table.get("name") == table_name
|
||||
|
||||
table_id = table.get("id")
|
||||
|
||||
assert table_id is not None
|
||||
|
||||
field_name = f"test_field_{postfix}"
|
||||
field_type = TableFieldType.SINGLE_LINE_TEXT
|
||||
field = await create_field(credentials, base_id, table_id, field_type, field_name)
|
||||
assert field.get("name") == field_name
|
||||
|
||||
field_id = field.get("id")
|
||||
|
||||
assert field_id is not None
|
||||
assert isinstance(field_id, str)
|
||||
|
||||
field_name = f"test_field_updated_{postfix}"
|
||||
field = await update_field(credentials, base_id, table_id, field_id, field_name)
|
||||
assert field.get("name") == field_name
|
||||
|
||||
field_description = "test_description_updated"
|
||||
field = await update_field(
|
||||
credentials, base_id, table_id, field_id, description=field_description
|
||||
)
|
||||
assert field.get("description") == field_description
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_management():
|
||||
key = getenv("AIRTABLE_API_KEY")
|
||||
if not key:
|
||||
return pytest.skip("AIRTABLE_API_KEY is not set")
|
||||
|
||||
credentials = APIKeyCredentials(
|
||||
provider="airtable",
|
||||
api_key=SecretStr(key),
|
||||
)
|
||||
postfix = uuid4().hex[:4]
|
||||
base_id = "appZPxegHEU3kDc1S"
|
||||
table_name = f"test_table_{postfix}"
|
||||
table_fields = [{"name": "test_field", "type": "singleLineText"}]
|
||||
table = await create_table(credentials, base_id, table_name, table_fields)
|
||||
assert table.get("name") == table_name
|
||||
|
||||
table_id = table.get("id")
|
||||
assert table_id is not None
|
||||
|
||||
# Create a record
|
||||
record_fields = {"test_field": "test_value"}
|
||||
record = await create_record(credentials, base_id, table_id, fields=record_fields)
|
||||
fields = record.get("fields")
|
||||
assert fields is not None
|
||||
assert isinstance(fields, dict)
|
||||
assert fields.get("test_field") == "test_value"
|
||||
|
||||
record_id = record.get("id")
|
||||
|
||||
assert record_id is not None
|
||||
assert isinstance(record_id, str)
|
||||
|
||||
# Get a record
|
||||
record = await get_record(credentials, base_id, table_id, record_id)
|
||||
fields = record.get("fields")
|
||||
assert fields is not None
|
||||
assert isinstance(fields, dict)
|
||||
assert fields.get("test_field") == "test_value"
|
||||
|
||||
# Update a record
|
||||
record_fields = {"test_field": "test_value_updated"}
|
||||
record = await update_record(
|
||||
credentials, base_id, table_id, record_id, fields=record_fields
|
||||
)
|
||||
fields = record.get("fields")
|
||||
assert fields is not None
|
||||
assert isinstance(fields, dict)
|
||||
assert fields.get("test_field") == "test_value_updated"
|
||||
|
||||
# Delete a record
|
||||
record = await delete_record(credentials, base_id, table_id, record_id)
|
||||
assert record is not None
|
||||
assert record.get("id") == record_id
|
||||
assert record.get("deleted")
|
||||
|
||||
# Create 2 records
|
||||
records = [
|
||||
{"fields": {"test_field": "test_value_1"}},
|
||||
{"fields": {"test_field": "test_value_2"}},
|
||||
]
|
||||
response = await create_record(credentials, base_id, table_id, records=records)
|
||||
created_records = response.get("records")
|
||||
assert created_records is not None
|
||||
assert isinstance(created_records, list)
|
||||
assert len(created_records) == 2, f"Created records: {created_records}"
|
||||
first_record = created_records[0] # type: ignore
|
||||
second_record = created_records[1] # type: ignore
|
||||
first_record_id = first_record.get("id")
|
||||
second_record_id = second_record.get("id")
|
||||
assert first_record_id is not None
|
||||
assert second_record_id is not None
|
||||
assert first_record_id != second_record_id
|
||||
first_fields = first_record.get("fields")
|
||||
second_fields = second_record.get("fields")
|
||||
assert first_fields is not None
|
||||
assert second_fields is not None
|
||||
assert first_fields.get("test_field") == "test_value_1" # type: ignore
|
||||
assert second_fields.get("test_field") == "test_value_2" # type: ignore
|
||||
|
||||
# List records
|
||||
response = await list_records(credentials, base_id, table_id)
|
||||
records = response.get("records")
|
||||
assert records is not None
|
||||
assert len(records) == 2, f"Records: {records}"
|
||||
assert isinstance(records, list), f"Type of records: {type(records)}"
|
||||
|
||||
# Update multiple records
|
||||
records = [
|
||||
{"id": first_record_id, "fields": {"test_field": "test_value_1_updated"}},
|
||||
{"id": second_record_id, "fields": {"test_field": "test_value_2_updated"}},
|
||||
]
|
||||
response = await update_multiple_records(
|
||||
credentials, base_id, table_id, records=records
|
||||
)
|
||||
updated_records = response.get("records")
|
||||
assert updated_records is not None
|
||||
assert len(updated_records) == 2, f"Updated records: {updated_records}"
|
||||
assert isinstance(
|
||||
updated_records, list
|
||||
), f"Type of updated records: {type(updated_records)}"
|
||||
first_updated = updated_records[0] # type: ignore
|
||||
second_updated = updated_records[1] # type: ignore
|
||||
first_updated_fields = first_updated.get("fields")
|
||||
second_updated_fields = second_updated.get("fields")
|
||||
assert first_updated_fields is not None
|
||||
assert second_updated_fields is not None
|
||||
assert first_updated_fields.get("test_field") == "test_value_1_updated" # type: ignore
|
||||
assert second_updated_fields.get("test_field") == "test_value_2_updated" # type: ignore
|
||||
|
||||
# Delete multiple records
|
||||
assert isinstance(first_record_id, str)
|
||||
assert isinstance(second_record_id, str)
|
||||
response = await delete_multiple_records(
|
||||
credentials, base_id, table_id, records=[first_record_id, second_record_id]
|
||||
)
|
||||
deleted_records = response.get("records")
|
||||
assert deleted_records is not None
|
||||
assert len(deleted_records) == 2, f"Deleted records: {deleted_records}"
|
||||
assert isinstance(
|
||||
deleted_records, list
|
||||
), f"Type of deleted records: {type(deleted_records)}"
|
||||
first_deleted = deleted_records[0] # type: ignore
|
||||
second_deleted = deleted_records[1] # type: ignore
|
||||
assert first_deleted.get("deleted")
|
||||
assert second_deleted.get("deleted")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_management():
|
||||
key = getenv("AIRTABLE_API_KEY")
|
||||
if not key:
|
||||
return pytest.skip("AIRTABLE_API_KEY is not set")
|
||||
|
||||
credentials = APIKeyCredentials(
|
||||
provider="airtable",
|
||||
api_key=SecretStr(key),
|
||||
)
|
||||
postfix = uuid4().hex[:4]
|
||||
base_id = "appZPxegHEU3kDc1S"
|
||||
table_name = f"test_table_{postfix}"
|
||||
table_fields = [{"name": "test_field", "type": "singleLineText"}]
|
||||
table = await create_table(credentials, base_id, table_name, table_fields)
|
||||
assert table.get("name") == table_name
|
||||
|
||||
table_id = table.get("id")
|
||||
assert table_id is not None
|
||||
webhook_specification = WebhookSpecification(
|
||||
filters=WebhookFilters(
|
||||
dataTypes=["tableData", "tableFields", "tableMetadata"],
|
||||
changeTypes=["add", "update", "remove"],
|
||||
)
|
||||
)
|
||||
response = await create_webhook(credentials, base_id, webhook_specification)
|
||||
assert response is not None, f"Checking create webhook response: {response}"
|
||||
assert (
|
||||
response.get("id") is not None
|
||||
), f"Checking create webhook response id: {response}"
|
||||
assert (
|
||||
response.get("macSecretBase64") is not None
|
||||
), f"Checking create webhook response macSecretBase64: {response}"
|
||||
|
||||
webhook_id = response.get("id")
|
||||
assert webhook_id is not None, f"Webhook ID: {webhook_id}"
|
||||
assert isinstance(webhook_id, str)
|
||||
|
||||
response = await create_record(
|
||||
credentials, base_id, table_id, fields={"test_field": "test_value"}
|
||||
)
|
||||
assert response is not None, f"Checking create record response: {response}"
|
||||
assert (
|
||||
response.get("id") is not None
|
||||
), f"Checking create record response id: {response}"
|
||||
fields = response.get("fields")
|
||||
assert fields is not None, f"Checking create record response fields: {response}"
|
||||
assert (
|
||||
fields.get("test_field") == "test_value"
|
||||
), f"Checking create record response fields test_field: {response}"
|
||||
|
||||
response = await list_webhook_payloads(credentials, base_id, webhook_id)
|
||||
assert response is not None, f"Checking list webhook payloads response: {response}"
|
||||
|
||||
response = await delete_webhook(credentials, base_id, webhook_id)
|
||||
@@ -4,7 +4,6 @@ Shared configuration for all Airtable blocks using the SDK pattern.

from backend.sdk import BlockCostType, ProviderBuilder

from ._oauth import AirtableOAuthHandler, AirtableScope
from ._webhook import AirtableWebhookManager

# Configure the Airtable provider with API key authentication
@@ -13,20 +12,5 @@ airtable = (
    .with_api_key("AIRTABLE_API_KEY", "Airtable Personal Access Token")
    .with_webhook_manager(AirtableWebhookManager)
    .with_base_cost(1, BlockCostType.RUN)
    .with_oauth(
        AirtableOAuthHandler,
        scopes=[
            v.value
            for v in [
                AirtableScope.DATA_RECORDS_READ,
                AirtableScope.DATA_RECORDS_WRITE,
                AirtableScope.SCHEMA_BASES_READ,
                AirtableScope.SCHEMA_BASES_WRITE,
                AirtableScope.WEBHOOK_MANAGE,
            ]
        ],
        client_id_env_var="AIRTABLE_CLIENT_ID",
        client_secret_env_var="AIRTABLE_CLIENT_SECRET",
    )
    .build()
)

@@ -1,185 +0,0 @@
|
||||
"""
|
||||
Airtable OAuth handler implementation.
|
||||
"""
|
||||
|
||||
import time
|
||||
from enum import Enum
|
||||
from logging import getLogger
|
||||
from typing import Optional
|
||||
|
||||
from backend.sdk import BaseOAuthHandler, OAuth2Credentials, ProviderName, SecretStr
|
||||
|
||||
from ._api import (
|
||||
OAuthTokenResponse,
|
||||
make_oauth_authorize_url,
|
||||
oauth_exchange_code_for_tokens,
|
||||
oauth_refresh_tokens,
|
||||
)
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class AirtableScope(str, Enum):
|
||||
# Basic scopes
|
||||
DATA_RECORDS_READ = "data.records:read"
|
||||
DATA_RECORDS_WRITE = "data.records:write"
|
||||
DATA_RECORD_COMMENTS_READ = "data.recordComments:read"
|
||||
DATA_RECORD_COMMENTS_WRITE = "data.recordComments:write"
|
||||
SCHEMA_BASES_READ = "schema.bases:read"
|
||||
SCHEMA_BASES_WRITE = "schema.bases:write"
|
||||
WEBHOOK_MANAGE = "webhook:manage"
|
||||
BLOCK_MANAGE = "block:manage"
|
||||
USER_EMAIL_READ = "user.email:read"
|
||||
|
||||
# Enterprise member scopes
|
||||
ENTERPRISE_GROUPS_READ = "enterprise.groups:read"
|
||||
WORKSPACES_AND_BASES_READ = "workspacesAndBases:read"
|
||||
WORKSPACES_AND_BASES_WRITE = "workspacesAndBases:write"
|
||||
WORKSPACES_AND_BASES_SHARES_MANAGE = "workspacesAndBases.shares:manage"
|
||||
|
||||
# Enterprise admin scopes
|
||||
ENTERPRISE_SCIM_USERS_AND_GROUPS_MANAGE = "enterprise.scim.usersAndGroups:manage"
|
||||
ENTERPRISE_AUDIT_LOGS_READ = "enterprise.auditLogs:read"
|
||||
ENTERPRISE_CHANGE_EVENTS_READ = "enterprise.changeEvents:read"
|
||||
ENTERPRISE_EXPORTS_MANAGE = "enterprise.exports:manage"
|
||||
ENTERPRISE_ACCOUNT_READ = "enterprise.account:read"
|
||||
ENTERPRISE_ACCOUNT_WRITE = "enterprise.account:write"
|
||||
ENTERPRISE_USER_READ = "enterprise.user:read"
|
||||
ENTERPRISE_USER_WRITE = "enterprise.user:write"
|
||||
ENTERPRISE_GROUPS_MANAGE = "enterprise.groups:manage"
|
||||
WORKSPACES_AND_BASES_MANAGE = "workspacesAndBases:manage"
|
||||
HYPERDB_RECORDS_READ = "hyperDB.records:read"
|
||||
HYPERDB_RECORDS_WRITE = "hyperDB.records:write"
|
||||
|
||||
|
||||
class AirtableOAuthHandler(BaseOAuthHandler):
|
||||
"""
|
||||
OAuth2 handler for Airtable with PKCE support.
|
||||
"""
|
||||
|
||||
PROVIDER_NAME = ProviderName("airtable")
|
||||
DEFAULT_SCOPES = [
|
||||
v.value
|
||||
for v in [
|
||||
AirtableScope.DATA_RECORDS_READ,
|
||||
AirtableScope.DATA_RECORDS_WRITE,
|
||||
AirtableScope.SCHEMA_BASES_READ,
|
||||
AirtableScope.SCHEMA_BASES_WRITE,
|
||||
AirtableScope.WEBHOOK_MANAGE,
|
||||
]
|
||||
]
|
||||
|
||||
def __init__(self, client_id: str, client_secret: Optional[str], redirect_uri: str):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
self.scopes = self.DEFAULT_SCOPES
|
||||
self.auth_base_url = "https://airtable.com/oauth2/v1/authorize"
|
||||
self.token_url = "https://airtable.com/oauth2/v1/token"
|
||||
|
||||
def get_login_url(
|
||||
self, scopes: list[str], state: str, code_challenge: Optional[str]
|
||||
) -> str:
|
||||
logger.debug("Generating Airtable OAuth login URL")
|
||||
# Generate code_challenge if not provided (PKCE is required)
|
||||
if not scopes:
|
||||
logger.debug("No scopes provided, using default scopes")
|
||||
scopes = self.scopes
|
||||
|
||||
logger.debug(f"Using scopes: {scopes}")
|
||||
logger.debug(f"State: {state}")
|
||||
logger.debug(f"Code challenge: {code_challenge}")
|
||||
if not code_challenge:
|
||||
logger.error("Code challenge is required but none was provided")
|
||||
raise ValueError("No code challenge provided")
|
||||
|
||||
try:
|
||||
url = make_oauth_authorize_url(
|
||||
self.client_id, self.redirect_uri, scopes, state, code_challenge
|
||||
)
|
||||
logger.debug(f"Generated OAuth URL: {url}")
|
||||
return url
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate OAuth URL: {str(e)}")
|
||||
raise
|
||||
|
||||
async def exchange_code_for_tokens(
|
||||
self, code: str, scopes: list[str], code_verifier: Optional[str]
|
||||
) -> OAuth2Credentials:
|
||||
logger.debug("Exchanging authorization code for tokens")
|
||||
logger.debug(f"Code: {code[:4]}...") # Log first 4 chars only for security
|
||||
logger.debug(f"Scopes: {scopes}")
|
||||
if not code_verifier:
|
||||
logger.error("Code verifier is required but none was provided")
|
||||
raise ValueError("No code verifier provided")
|
||||
|
||||
try:
|
||||
response: OAuthTokenResponse = await oauth_exchange_code_for_tokens(
|
||||
client_id=self.client_id,
|
||||
code=code,
|
||||
code_verifier=code_verifier.encode("utf-8"),
|
||||
redirect_uri=self.redirect_uri,
|
||||
client_secret=self.client_secret,
|
||||
)
|
||||
logger.info("Successfully exchanged code for tokens")
|
||||
|
||||
credentials = OAuth2Credentials(
|
||||
access_token=SecretStr(response.access_token),
|
||||
refresh_token=SecretStr(response.refresh_token),
|
||||
access_token_expires_at=int(time.time()) + response.expires_in,
|
||||
refresh_token_expires_at=int(time.time()) + response.refresh_expires_in,
|
||||
provider=self.PROVIDER_NAME,
|
||||
scopes=scopes,
|
||||
)
|
||||
logger.debug(f"Access token expires in {response.expires_in} seconds")
|
||||
logger.debug(
|
||||
f"Refresh token expires in {response.refresh_expires_in} seconds"
|
||||
)
|
||||
return credentials
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to exchange code for tokens: {str(e)}")
|
||||
raise
|
||||
|
||||
async def _refresh_tokens(
|
||||
self, credentials: OAuth2Credentials
|
||||
) -> OAuth2Credentials:
|
||||
logger.debug("Attempting to refresh OAuth tokens")
|
||||
|
||||
if credentials.refresh_token is None:
|
||||
logger.error("Cannot refresh tokens - no refresh token available")
|
||||
raise ValueError("No refresh token available")
|
||||
|
||||
try:
|
||||
response: OAuthTokenResponse = await oauth_refresh_tokens(
|
||||
client_id=self.client_id,
|
||||
refresh_token=credentials.refresh_token.get_secret_value(),
|
||||
client_secret=self.client_secret,
|
||||
)
|
||||
logger.info("Successfully refreshed tokens")
|
||||
|
||||
new_credentials = OAuth2Credentials(
|
||||
id=credentials.id,
|
||||
access_token=SecretStr(response.access_token),
|
||||
refresh_token=SecretStr(response.refresh_token),
|
||||
access_token_expires_at=int(time.time()) + response.expires_in,
|
||||
refresh_token_expires_at=int(time.time()) + response.refresh_expires_in,
|
||||
provider=self.PROVIDER_NAME,
|
||||
scopes=self.scopes,
|
||||
)
|
||||
logger.debug(f"New access token expires in {response.expires_in} seconds")
|
||||
logger.debug(
|
||||
f"New refresh token expires in {response.refresh_expires_in} seconds"
|
||||
)
|
||||
return new_credentials
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to refresh tokens: {str(e)}")
|
||||
raise
|
||||
|
||||
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
|
||||
logger.debug("Token revocation requested")
|
||||
logger.info(
|
||||
"Airtable doesn't provide a token revocation endpoint - tokens will expire naturally after 60 minutes"
|
||||
)
|
||||
return False
|
||||
@@ -4,48 +4,30 @@ Webhook management for Airtable blocks.
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Tuple
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
BaseWebhooksManager,
|
||||
Credentials,
|
||||
ProviderName,
|
||||
Requests,
|
||||
Webhook,
|
||||
update_webhook,
|
||||
)
|
||||
|
||||
from ._api import (
|
||||
WebhookFilters,
|
||||
WebhookSpecification,
|
||||
create_webhook,
|
||||
delete_webhook,
|
||||
list_webhook_payloads,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AirtableWebhookEvent(str, Enum):
|
||||
TABLE_DATA = "tableData"
|
||||
TABLE_FIELDS = "tableFields"
|
||||
TABLE_METADATA = "tableMetadata"
|
||||
|
||||
|
||||
class AirtableWebhookManager(BaseWebhooksManager):
|
||||
"""Webhook manager for Airtable API."""
|
||||
|
||||
PROVIDER_NAME = ProviderName("airtable")
|
||||
|
||||
class WebhookType(str, Enum):
|
||||
TABLE_CHANGE = "table_change"
|
||||
|
||||
@classmethod
|
||||
async def validate_payload(
|
||||
cls, webhook: Webhook, request, credentials: Credentials | None
|
||||
) -> tuple[dict, str]:
|
||||
async def validate_payload(cls, webhook: Webhook, request) -> Tuple[dict, str]:
|
||||
"""Validate incoming webhook payload and signature."""
|
||||
|
||||
if not credentials:
|
||||
raise ValueError("Missing credentials in webhook metadata")
|
||||
|
||||
payload = await request.json()
|
||||
|
||||
# Verify webhook signature using HMAC-SHA256
|
||||
@@ -56,9 +38,9 @@ class AirtableWebhookManager(BaseWebhooksManager):
|
||||
body = await request.body()
|
||||
|
||||
# Calculate expected signature
|
||||
mac_secret_decoded = mac_secret.encode()
|
||||
hmac_obj = hmac.new(mac_secret_decoded, body, hashlib.sha256)
|
||||
expected_mac = f"hmac-sha256={hmac_obj.hexdigest()}"
|
||||
expected_mac = hmac.new(
|
||||
mac_secret.encode(), body, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
# Get signature from headers
|
||||
signature = request.headers.get("X-Airtable-Content-MAC")
|
||||
@@ -66,29 +48,9 @@ class AirtableWebhookManager(BaseWebhooksManager):
|
||||
if signature and not hmac.compare_digest(signature, expected_mac):
|
||||
raise ValueError("Invalid webhook signature")
|
||||
|
||||
# Validate payload structure
|
||||
required_fields = ["base", "webhook", "timestamp"]
|
||||
if not all(field in payload for field in required_fields):
|
||||
raise ValueError("Invalid webhook payload structure")
|
||||
|
||||
if "id" not in payload["base"] or "id" not in payload["webhook"]:
|
||||
raise ValueError("Missing required IDs in webhook payload")
|
||||
base_id = payload["base"]["id"]
|
||||
webhook_id = payload["webhook"]["id"]
|
||||
|
||||
# get payload request parameters
|
||||
cursor = webhook.config.get("cursor", 1)
|
||||
|
||||
response = await list_webhook_payloads(credentials, base_id, webhook_id, cursor)
|
||||
|
||||
# update webhook config
|
||||
await update_webhook(
|
||||
webhook.id,
|
||||
config={"base_id": base_id, "cursor": response.cursor},
|
||||
)
|
||||
|
||||
# Airtable sends the cursor in the payload
|
||||
event_type = "notification"
|
||||
return response.model_dump(), event_type
|
||||
return payload, event_type
|
||||
|
||||
async def _register_webhook(
|
||||
self,
|
||||
@@ -98,8 +60,12 @@ class AirtableWebhookManager(BaseWebhooksManager):
|
||||
events: list[str],
|
||||
ingress_url: str,
|
||||
secret: str,
|
||||
) -> tuple[str, dict]:
|
||||
) -> Tuple[str, dict]:
|
||||
"""Register webhook with Airtable API."""
|
||||
if not isinstance(credentials, APIKeyCredentials):
|
||||
raise ValueError("Airtable webhooks require API key credentials")
|
||||
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Parse resource to get base_id and table_id/name
|
||||
# Resource format: "{base_id}/{table_id_or_name}"
|
||||
@@ -110,30 +76,33 @@ class AirtableWebhookManager(BaseWebhooksManager):
|
||||
base_id, table_id_or_name = parts
|
||||
|
||||
# Prepare webhook specification
|
||||
webhook_specification = WebhookSpecification(
|
||||
filters=WebhookFilters(
|
||||
dataTypes=events,
|
||||
)
|
||||
)
|
||||
specification = {
|
||||
"filters": {
|
||||
"dataTypes": events or ["tableData", "tableFields", "tableMetadata"]
|
||||
}
|
||||
}
|
||||
|
||||
# If specific table is provided, add to specification
|
||||
if table_id_or_name and table_id_or_name != "*":
|
||||
specification["filters"]["recordChangeScope"] = [table_id_or_name]
|
||||
|
||||
# Create webhook
|
||||
webhook_data = await create_webhook(
|
||||
credentials=credentials,
|
||||
base_id=base_id,
|
||||
webhook_specification=webhook_specification,
|
||||
notification_url=ingress_url,
|
||||
response = await Requests().post(
|
||||
f"https://api.airtable.com/v0/bases/{base_id}/webhooks",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
json={"notificationUrl": ingress_url, "specification": specification},
|
||||
)
|
||||
|
||||
webhook_data = response.json()
|
||||
webhook_id = webhook_data["id"]
|
||||
mac_secret = webhook_data.get("macSecretBase64")
|
||||
|
||||
return webhook_id, {
|
||||
"webhook_id": webhook_id,
|
||||
"base_id": base_id,
|
||||
"table_id_or_name": table_id_or_name,
|
||||
"events": events,
|
||||
"mac_secret": mac_secret,
|
||||
"cursor": 1,
|
||||
"cursor": 1, # Start from cursor 1
|
||||
"expiration_time": webhook_data.get("expirationTime"),
|
||||
}
|
||||
|
||||
@@ -141,14 +110,16 @@ class AirtableWebhookManager(BaseWebhooksManager):
|
||||
self, webhook: Webhook, credentials: Credentials
|
||||
) -> None:
|
||||
"""Deregister webhook from Airtable API."""
|
||||
if not isinstance(credentials, APIKeyCredentials):
|
||||
raise ValueError("Airtable webhooks require API key credentials")
|
||||
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
base_id = webhook.config.get("base_id")
|
||||
webhook_id = webhook.config.get("webhook_id")
|
||||
|
||||
if not base_id:
|
||||
raise ValueError("Missing base_id in webhook metadata")
|
||||
|
||||
if not webhook_id:
|
||||
raise ValueError("Missing webhook_id in webhook metadata")
|
||||
|
||||
await delete_webhook(credentials, base_id, webhook_id)
|
||||
await Requests().delete(
|
||||
f"https://api.airtable.com/v0/bases/{base_id}/webhooks/{webhook.provider_webhook_id}",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
)
|
||||
|
||||
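The webhook manager diff above verifies incoming payloads by computing an HMAC-SHA256 digest of the raw request body with the webhook's MAC secret and comparing it against the X-Airtable-Content-MAC header. A minimal standalone sketch of that check; the two sides of the diff disagree on whether the header carries a bare hex digest or an "hmac-sha256=" prefix, so accepting either form is an assumption here:

import hashlib
import hmac


def verify_airtable_signature(mac_secret: str, body: bytes, header_value: str) -> bool:
    """Return True if header_value matches the HMAC-SHA256 of body under mac_secret."""
    digest = hmac.new(mac_secret.encode(), body, hashlib.sha256).hexdigest()
    # Accept either a bare hex digest or an "hmac-sha256=<hex>" value (assumption).
    return any(
        hmac.compare_digest(header_value, candidate)
        for candidate in (digest, f"hmac-sha256={digest}")
    )
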
@@ -0,0 +1,98 @@
"""
Airtable attachment blocks.
"""

import base64
from typing import Union

from backend.sdk import (
    APIKeyCredentials,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchema,
    CredentialsMetaInput,
    Requests,
    SchemaField,
)

from ._config import airtable


class AirtableUploadAttachmentBlock(Block):
    """
    Uploads a file to Airtable for use as an attachment.

    Files can be uploaded directly (up to 5MB) or via URL.
    The returned attachment ID can be used when creating or updating records.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
        base_id: str = SchemaField(description="The Airtable base ID", default="")
        filename: str = SchemaField(description="Name of the file")
        file: Union[bytes, str] = SchemaField(
            description="File content (binary data or base64 string)"
        )
        content_type: str = SchemaField(
            description="MIME type of the file", default="application/octet-stream"
        )

    class Output(BlockSchema):
        attachment: dict = SchemaField(
            description="Attachment object with id, url, size, and type"
        )
        attachment_id: str = SchemaField(description="ID of the uploaded attachment")
        url: str = SchemaField(description="URL of the uploaded attachment")
        size: int = SchemaField(description="Size of the file in bytes")

    def __init__(self):
        super().__init__(
            id="962e801b-5a6f-4c56-a929-83e816343a41",
            description="Upload a file to Airtable for use as an attachment",
            categories={BlockCategory.DATA},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        api_key = credentials.api_key.get_secret_value()

        # Convert file to base64 if it's bytes
        if isinstance(input_data.file, bytes):
            file_data = base64.b64encode(input_data.file).decode("utf-8")
        else:
            # Assume it's already base64 encoded
            file_data = input_data.file

        # Check file size (5MB limit)
        file_bytes = base64.b64decode(file_data)
        if len(file_bytes) > 5 * 1024 * 1024:
            raise ValueError(
                "File size exceeds 5MB limit. Use URL upload for larger files."
            )

        # Upload the attachment
        response = await Requests().post(
            f"https://api.airtable.com/v0/bases/{input_data.base_id}/attachments/upload",
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            },
            json={
                "content": file_data,
                "filename": input_data.filename,
                "type": input_data.content_type,
            },
        )

        attachment_data = response.json()

        yield "attachment", attachment_data
        yield "attachment_id", attachment_data.get("id", "")
        yield "url", attachment_data.get("url", "")
        yield "size", attachment_data.get("size", 0)
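The file input of AirtableUploadAttachmentBlock accepts either raw bytes or an already base64-encoded string, as the branch at the top of run() shows. A minimal sketch of preparing both forms from a local file (the file name here is a placeholder):

import base64
import mimetypes
from pathlib import Path

path = Path("report.pdf")  # placeholder file name
raw_bytes = path.read_bytes()  # bytes input: the block base64-encodes this itself
b64_string = base64.b64encode(raw_bytes).decode("utf-8")  # string input: passed through as-is
content_type = mimetypes.guess_type(path.name)[0] or "application/octet-stream"
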
@@ -1,122 +0,0 @@
|
||||
"""
|
||||
Airtable base operation blocks.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._api import create_base, list_bases
|
||||
from ._config import airtable
|
||||
|
||||
|
||||
class AirtableCreateBaseBlock(Block):
|
||||
"""
|
||||
Creates a new base in an Airtable workspace.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
workspace_id: str = SchemaField(
|
||||
description="The workspace ID where the base will be created"
|
||||
)
|
||||
name: str = SchemaField(description="The name of the new base")
|
||||
tables: list[dict] = SchemaField(
|
||||
description="At least one table and field must be specified. Array of table objects to create in the base. Each table should have 'name' and 'fields' properties",
|
||||
default=[
|
||||
{
|
||||
"description": "Default table",
|
||||
"name": "Default table",
|
||||
"fields": [
|
||||
{
|
||||
"name": "ID",
|
||||
"type": "number",
|
||||
"description": "Auto-incrementing ID field",
|
||||
"options": {"precision": 0},
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
base_id: str = SchemaField(description="The ID of the created base")
|
||||
tables: list[dict] = SchemaField(description="Array of table objects")
|
||||
table: dict = SchemaField(description="A single table object")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f59b88a8-54ce-4676-a508-fd614b4e8dce",
|
||||
description="Create a new base in Airtable",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
data = await create_base(
|
||||
credentials,
|
||||
input_data.workspace_id,
|
||||
input_data.name,
|
||||
input_data.tables,
|
||||
)
|
||||
|
||||
yield "base_id", data.get("id", None)
|
||||
yield "tables", data.get("tables", [])
|
||||
for table in data.get("tables", []):
|
||||
yield "table", table
|
||||
|
||||
|
||||
class AirtableListBasesBlock(Block):
|
||||
"""
|
||||
Lists all bases in an Airtable workspace that the user has access to.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
trigger: str = SchemaField(
|
||||
description="Trigger the block to run - value is ignored", default="manual"
|
||||
)
|
||||
offset: str = SchemaField(
|
||||
description="Pagination offset from previous request", default=""
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
bases: list[dict] = SchemaField(description="Array of base objects")
|
||||
offset: Optional[str] = SchemaField(
|
||||
description="Offset for next page (null if no more bases)", default=None
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="4bd8d466-ed5d-4e44-8083-97f25a8044e7",
|
||||
description="List all bases in Airtable",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
data = await list_bases(
|
||||
credentials,
|
||||
offset=input_data.offset if input_data.offset else None,
|
||||
)
|
||||
|
||||
yield "bases", data.get("bases", [])
|
||||
yield "offset", data.get("offset", None)
|
||||
autogpt_platform/backend/backend/blocks/airtable/metadata.py (new file, 145 lines added)
@@ -0,0 +1,145 @@
"""
Airtable metadata blocks for bases and views.
"""

from backend.sdk import (
    APIKeyCredentials,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchema,
    CredentialsMetaInput,
    Requests,
    SchemaField,
)

from ._config import airtable


class AirtableListBasesBlock(Block):
    """
    Lists all Airtable bases accessible by the API token.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )

    class Output(BlockSchema):
        bases: list[dict] = SchemaField(
            description="Array of base objects with id and name"
        )

    def __init__(self):
        super().__init__(
            id="613f9907-bef8-468a-be6d-2dd7a53f96e7",
            description="List all accessible Airtable bases",
            categories={BlockCategory.SEARCH},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        api_key = credentials.api_key.get_secret_value()

        # List bases
        response = await Requests().get(
            "https://api.airtable.com/v0/meta/bases",
            headers={"Authorization": f"Bearer {api_key}"},
        )

        data = response.json()

        yield "bases", data.get("bases", [])


class AirtableListViewsBlock(Block):
    """
    Lists all views in an Airtable base with their associated tables.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
        base_id: str = SchemaField(description="The Airtable base ID", default="")

    class Output(BlockSchema):
        views: list[dict] = SchemaField(
            description="Array of view objects with tableId"
        )

    def __init__(self):
        super().__init__(
            id="3878cf82-d384-40c2-aace-097042233f6a",
            description="List all views in an Airtable base",
            categories={BlockCategory.SEARCH},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        api_key = credentials.api_key.get_secret_value()

        # Get base schema which includes views
        response = await Requests().get(
            f"https://api.airtable.com/v0/meta/bases/{input_data.base_id}/tables",
            headers={"Authorization": f"Bearer {api_key}"},
        )

        data = response.json()

        # Extract all views from all tables
        all_views = []
        for table in data.get("tables", []):
            table_id = table.get("id")
            for view in table.get("views", []):
                view_with_table = {**view, "tableId": table_id}
                all_views.append(view_with_table)

        yield "views", all_views


class AirtableGetViewBlock(Block):
    """
    Gets detailed information about a specific view in an Airtable base.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
        base_id: str = SchemaField(description="The Airtable base ID", default="")
        view_id: str = SchemaField(description="The view ID to retrieve")

    class Output(BlockSchema):
        view: dict = SchemaField(description="Full view object with configuration")

    def __init__(self):
        super().__init__(
            id="ad0dd9f3-b3f4-446b-8142-e81a566797c4",
            description="Get details of a specific Airtable view",
            categories={BlockCategory.SEARCH},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        api_key = credentials.api_key.get_secret_value()

        # Get specific view
        response = await Requests().get(
            f"https://api.airtable.com/v0/meta/bases/{input_data.base_id}/views/{input_data.view_id}",
            headers={"Authorization": f"Bearer {api_key}"},
        )

        view_data = response.json()

        yield "view", view_data
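AirtableListViewsBlock flattens the per-table view lists returned by the base schema endpoint into a single list, tagging each view with its tableId. A small illustration of that reshaping on made-up data:

tables = [
    {"id": "tblA", "views": [{"id": "viwGrid", "name": "Grid view"}]},
    {"id": "tblB", "views": [{"id": "viwCal", "name": "Calendar"}]},
]
all_views = [
    {**view, "tableId": table["id"]}
    for table in tables
    for view in table.get("views", [])
]
# -> [{"id": "viwGrid", "name": "Grid view", "tableId": "tblA"},
#     {"id": "viwCal", "name": "Calendar", "tableId": "tblB"}]
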
@@ -11,16 +11,10 @@ from backend.sdk import (
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._api import (
|
||||
create_record,
|
||||
delete_multiple_records,
|
||||
get_record,
|
||||
list_records,
|
||||
update_multiple_records,
|
||||
)
|
||||
from ._config import airtable
|
||||
|
||||
|
||||
@@ -33,8 +27,8 @@ class AirtableListRecordsBlock(Block):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
table_id_or_name: str = SchemaField(description="Table ID or name")
|
||||
base_id: str = SchemaField(description="The Airtable base ID", default="")
|
||||
table_id_or_name: str = SchemaField(description="Table ID or name", default="")
|
||||
filter_formula: str = SchemaField(
|
||||
description="Airtable formula to filter records", default=""
|
||||
)
|
||||
@@ -65,7 +59,7 @@ class AirtableListRecordsBlock(Block):
|
||||
super().__init__(
|
||||
id="588a9fde-5733-4da7-b03c-35f5671e960f",
|
||||
description="List records from an Airtable table",
|
||||
categories={BlockCategory.DATA},
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
@@ -73,21 +67,37 @@ class AirtableListRecordsBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
data = await list_records(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
input_data.table_id_or_name,
|
||||
filter_by_formula=(
|
||||
input_data.filter_formula if input_data.filter_formula else None
|
||||
),
|
||||
view=input_data.view if input_data.view else None,
|
||||
sort=input_data.sort if input_data.sort else None,
|
||||
max_records=input_data.max_records if input_data.max_records else None,
|
||||
page_size=min(input_data.page_size, 100) if input_data.page_size else None,
|
||||
offset=input_data.offset if input_data.offset else None,
|
||||
fields=input_data.return_fields if input_data.return_fields else None,
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Build query parameters
|
||||
params = {}
|
||||
if input_data.filter_formula:
|
||||
params["filterByFormula"] = input_data.filter_formula
|
||||
if input_data.view:
|
||||
params["view"] = input_data.view
|
||||
if input_data.sort:
|
||||
for i, sort_config in enumerate(input_data.sort):
|
||||
params[f"sort[{i}][field]"] = sort_config.get("field", "")
|
||||
params[f"sort[{i}][direction]"] = sort_config.get("direction", "asc")
|
||||
if input_data.max_records:
|
||||
params["maxRecords"] = input_data.max_records
|
||||
if input_data.page_size:
|
||||
params["pageSize"] = min(input_data.page_size, 100)
|
||||
if input_data.offset:
|
||||
params["offset"] = input_data.offset
|
||||
if input_data.return_fields:
|
||||
for i, field in enumerate(input_data.return_fields):
|
||||
params[f"fields[{i}]"] = field
|
||||
|
||||
# Make request
|
||||
response = await Requests().get(
|
||||
f"https://api.airtable.com/v0/{input_data.base_id}/{input_data.table_id_or_name}",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
params=params,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
yield "records", data.get("records", [])
|
||||
yield "offset", data.get("offset", None)
|
||||
|
||||
@@ -101,20 +111,21 @@ class AirtableGetRecordBlock(Block):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
table_id_or_name: str = SchemaField(description="Table ID or name")
|
||||
base_id: str = SchemaField(description="The Airtable base ID", default="")
|
||||
table_id_or_name: str = SchemaField(description="Table ID or name", default="")
|
||||
record_id: str = SchemaField(description="The record ID to retrieve")
|
||||
return_fields: list[str] = SchemaField(
|
||||
description="Specific fields to return", default=[]
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
id: str = SchemaField(description="The record ID")
|
||||
fields: dict = SchemaField(description="The record fields")
|
||||
created_time: str = SchemaField(description="The record created time")
|
||||
record: dict = SchemaField(description="The record object")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c29c5cbf-0aff-40f9-bbb5-f26061792d2b",
|
||||
description="Get a single record from Airtable",
|
||||
categories={BlockCategory.DATA},
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
@@ -122,16 +133,24 @@ class AirtableGetRecordBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
record = await get_record(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
input_data.table_id_or_name,
|
||||
input_data.record_id,
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Build query parameters
|
||||
params = {}
|
||||
if input_data.return_fields:
|
||||
for i, field in enumerate(input_data.return_fields):
|
||||
params[f"fields[{i}]"] = field
|
||||
|
||||
# Make request
|
||||
response = await Requests().get(
|
||||
f"https://api.airtable.com/v0/{input_data.base_id}/{input_data.table_id_or_name}/{input_data.record_id}",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
params=params,
|
||||
)
|
||||
|
||||
yield "id", record.get("id", None)
|
||||
yield "fields", record.get("fields", None)
|
||||
yield "created_time", record.get("createdTime", None)
|
||||
record = response.json()
|
||||
|
||||
yield "record", record
|
||||
|
||||
|
||||
class AirtableCreateRecordsBlock(Block):
|
||||
@@ -143,8 +162,8 @@ class AirtableCreateRecordsBlock(Block):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
table_id_or_name: str = SchemaField(description="Table ID or name")
|
||||
base_id: str = SchemaField(description="The Airtable base ID", default="")
|
||||
table_id_or_name: str = SchemaField(description="Table ID or name", default="")
|
||||
records: list[dict] = SchemaField(
|
||||
description="Array of records to create (each with 'fields' object)"
|
||||
)
|
||||
@@ -152,14 +171,12 @@ class AirtableCreateRecordsBlock(Block):
|
||||
description="Automatically convert string values to appropriate types",
|
||||
default=False,
|
||||
)
|
||||
return_fields_by_field_id: bool | None = SchemaField(
|
||||
description="Return fields by field ID",
|
||||
default=None,
|
||||
return_fields: list[str] = SchemaField(
|
||||
description="Specific fields to return in created records", default=[]
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
records: list[dict] = SchemaField(description="Array of created record objects")
|
||||
details: dict = SchemaField(description="Details of the created records")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -173,20 +190,28 @@ class AirtableCreateRecordsBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# The create_record API expects records in a specific format
|
||||
data = await create_record(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
input_data.table_id_or_name,
|
||||
records=[{"fields": record} for record in input_data.records],
|
||||
typecast=input_data.typecast if input_data.typecast else None,
|
||||
return_fields_by_field_id=input_data.return_fields_by_field_id,
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Build request body
|
||||
body = {"records": input_data.records, "typecast": input_data.typecast}
|
||||
|
||||
# Build query parameters for return fields
|
||||
params = {}
|
||||
if input_data.return_fields:
|
||||
for i, field in enumerate(input_data.return_fields):
|
||||
params[f"fields[{i}]"] = field
|
||||
|
||||
# Make request
|
||||
response = await Requests().post(
|
||||
f"https://api.airtable.com/v0/{input_data.base_id}/{input_data.table_id_or_name}",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
json=body,
|
||||
params=params,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
yield "records", data.get("records", [])
|
||||
details = data.get("details", None)
|
||||
if details:
|
||||
yield "details", details
|
||||
|
||||
|
||||
class AirtableUpdateRecordsBlock(Block):
|
||||
@@ -198,16 +223,17 @@ class AirtableUpdateRecordsBlock(Block):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
table_id_or_name: str = SchemaField(
|
||||
description="Table ID or name - It's better to use the table ID instead of the name"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID", default="")
|
||||
table_id_or_name: str = SchemaField(description="Table ID or name", default="")
|
||||
records: list[dict] = SchemaField(
|
||||
description="Array of records to update (each with 'id' and 'fields')"
|
||||
)
|
||||
typecast: bool | None = SchemaField(
|
||||
typecast: bool = SchemaField(
|
||||
description="Automatically convert string values to appropriate types",
|
||||
default=None,
|
||||
default=False,
|
||||
)
|
||||
return_fields: list[str] = SchemaField(
|
||||
description="Specific fields to return in updated records", default=[]
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
@@ -225,16 +251,100 @@ class AirtableUpdateRecordsBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# The update_multiple_records API expects records with id and fields
|
||||
data = await update_multiple_records(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
input_data.table_id_or_name,
|
||||
records=input_data.records,
|
||||
typecast=input_data.typecast if input_data.typecast else None,
|
||||
return_fields_by_field_id=False, # Use field names, not IDs
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Build request body
|
||||
body = {"records": input_data.records, "typecast": input_data.typecast}
|
||||
|
||||
# Build query parameters for return fields
|
||||
params = {}
|
||||
if input_data.return_fields:
|
||||
for i, field in enumerate(input_data.return_fields):
|
||||
params[f"fields[{i}]"] = field
|
||||
|
||||
# Make request
|
||||
response = await Requests().patch(
|
||||
f"https://api.airtable.com/v0/{input_data.base_id}/{input_data.table_id_or_name}",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
json=body,
|
||||
params=params,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
yield "records", data.get("records", [])
|
||||
|
||||
|
||||
class AirtableUpsertRecordsBlock(Block):
|
||||
"""
|
||||
Creates or updates records in an Airtable table based on a merge field.
|
||||
|
||||
If a record with the same value in the merge field exists, it will be updated.
|
||||
Otherwise, a new record will be created.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID", default="")
|
||||
table_id_or_name: str = SchemaField(description="Table ID or name", default="")
|
||||
records: list[dict] = SchemaField(
|
||||
description="Array of records to upsert (each with 'fields' object)"
|
||||
)
|
||||
merge_field: str = SchemaField(
|
||||
description="Field to use for matching existing records"
|
||||
)
|
||||
typecast: bool = SchemaField(
|
||||
description="Automatically convert string values to appropriate types",
|
||||
default=False,
|
||||
)
|
||||
return_fields: list[str] = SchemaField(
|
||||
description="Specific fields to return in upserted records", default=[]
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
records: list[dict] = SchemaField(
|
||||
description="Array of created/updated record objects"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="99f78a9d-3418-429f-a6fb-9d2166638e99",
|
||||
description="Create or update records based on a merge field",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Build request body
|
||||
body = {
|
||||
"performUpsert": {"fieldsToMergeOn": [input_data.merge_field]},
|
||||
"records": input_data.records,
|
||||
"typecast": input_data.typecast,
|
||||
}
|
||||
|
||||
# Build query parameters for return fields
|
||||
params = {}
|
||||
if input_data.return_fields:
|
||||
for i, field in enumerate(input_data.return_fields):
|
||||
params[f"fields[{i}]"] = field
|
||||
|
||||
# Make request
|
||||
response = await Requests().post(
|
||||
f"https://api.airtable.com/v0/{input_data.base_id}/{input_data.table_id_or_name}",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
json=body,
|
||||
params=params,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
yield "records", data.get("records", [])
|
||||
|
||||
|
||||
@@ -247,13 +357,9 @@ class AirtableDeleteRecordsBlock(Block):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
table_id_or_name: str = SchemaField(
|
||||
description="Table ID or name - It's better to use the table ID instead of the name"
|
||||
)
|
||||
record_ids: list[str] = SchemaField(
|
||||
description="Array of upto 10 record IDs to delete"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID", default="")
|
||||
table_id_or_name: str = SchemaField(description="Table ID or name", default="")
|
||||
record_ids: list[str] = SchemaField(description="Array of record IDs to delete")
|
||||
|
||||
class Output(BlockSchema):
|
||||
records: list[dict] = SchemaField(description="Array of deletion results")
|
||||
@@ -270,14 +376,20 @@ class AirtableDeleteRecordsBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
if len(input_data.record_ids) > 10:
|
||||
yield "error", "Only upto 10 record IDs can be deleted at a time"
|
||||
else:
|
||||
data = await delete_multiple_records(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
input_data.table_id_or_name,
|
||||
input_data.record_ids,
|
||||
)
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
yield "records", data.get("records", [])
|
||||
# Build query parameters
|
||||
params = {}
|
||||
for i, record_id in enumerate(input_data.record_ids):
|
||||
params[f"records[{i}]"] = record_id
|
||||
|
||||
# Make request
|
||||
response = await Requests().delete(
|
||||
f"https://api.airtable.com/v0/{input_data.base_id}/{input_data.table_id_or_name}",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
params=params,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
yield "records", data.get("records", [])
|
||||
|
||||
@@ -13,7 +13,6 @@ from backend.sdk import (
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._api import TableFieldType, create_field, create_table, update_field, update_table
|
||||
from ._config import airtable
|
||||
|
||||
|
||||
@@ -27,7 +26,7 @@ class AirtableListSchemaBlock(Block):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
base_id: str = SchemaField(description="The Airtable base ID", default="")
|
||||
|
||||
class Output(BlockSchema):
|
||||
base_schema: dict = SchemaField(
|
||||
@@ -39,7 +38,7 @@ class AirtableListSchemaBlock(Block):
|
||||
super().__init__(
|
||||
id="64291d3c-99b5-47b7-a976-6d94293cdb2d",
|
||||
description="Get the complete schema of an Airtable base",
|
||||
categories={BlockCategory.DATA},
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
@@ -70,11 +69,13 @@ class AirtableCreateTableBlock(Block):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
table_name: str = SchemaField(description="The name of the table to create")
|
||||
table_fields: list[dict] = SchemaField(
|
||||
description="Table fields with name, type, and options",
|
||||
default=[{"name": "Name", "type": "singleLineText"}],
|
||||
base_id: str = SchemaField(description="The Airtable base ID", default="")
|
||||
table_definition: dict = SchemaField(
|
||||
description="Table definition with name, description, fields, and views",
|
||||
default={
|
||||
"name": "New Table",
|
||||
"fields": [{"name": "Name", "type": "singleLineText"}],
|
||||
},
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
@@ -93,13 +94,17 @@ class AirtableCreateTableBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
table_data = await create_table(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
input_data.table_name,
|
||||
input_data.table_fields,
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Create table
|
||||
response = await Requests().post(
|
||||
f"https://api.airtable.com/v0/meta/bases/{input_data.base_id}/tables",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
json=input_data.table_definition,
|
||||
)
|
||||
|
||||
table_data = response.json()
|
||||
|
||||
yield "table", table_data
|
||||
yield "table_id", table_data.get("id", "")
|
||||
|
||||
@@ -113,16 +118,10 @@ class AirtableUpdateTableBlock(Block):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
        base_id: str = SchemaField(description="The Airtable base ID")
        base_id: str = SchemaField(description="The Airtable base ID", default="")
        table_id: str = SchemaField(description="The table ID to update")
        table_name: str | None = SchemaField(
            description="The name of the table to update", default=None
        )
        table_description: str | None = SchemaField(
            description="The description of the table to update", default=None
        )
        date_dependency: dict | None = SchemaField(
            description="The date dependency of the table to update", default=None
        patch: dict = SchemaField(
            description="Properties to update (name, description)", default={}
        )

    class Output(BlockSchema):
@@ -140,19 +139,63 @@ class AirtableUpdateTableBlock(Block):
    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        table_data = await update_table(
            credentials,
            input_data.base_id,
            input_data.table_id,
            input_data.table_name,
            input_data.table_description,
            input_data.date_dependency,
        api_key = credentials.api_key.get_secret_value()

        # Update table
        response = await Requests().patch(
            f"https://api.airtable.com/v0/meta/bases/{input_data.base_id}/tables/{input_data.table_id}",
            headers={"Authorization": f"Bearer {api_key}"},
            json=input_data.patch,
        )

        table_data = response.json()

        yield "table", table_data


class AirtableCreateFieldBlock(Block):
class AirtableDeleteTableBlock(Block):
    """
    Deletes a table from an Airtable base.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
        base_id: str = SchemaField(description="The Airtable base ID", default="")
        table_id: str = SchemaField(description="The table ID to delete")

    class Output(BlockSchema):
        deleted: bool = SchemaField(
            description="Confirmation that the table was deleted"
        )

    def __init__(self):
        super().__init__(
            id="6b96c196-d0ad-4fb2-981f-7a330549bc22",
            description="Delete a table from an Airtable base",
            categories={BlockCategory.DATA},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        api_key = credentials.api_key.get_secret_value()

        # Delete table
        response = await Requests().delete(
            f"https://api.airtable.com/v0/meta/bases/{input_data.base_id}/tables/{input_data.table_id}",
            headers={"Authorization": f"Bearer {api_key}"},
        )

        deleted = response.status in [200, 204]

        yield "deleted", deleted

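The delete blocks above treat any 200 or 204 response as success and otherwise just yield deleted = False, without surfacing the error body. If stricter handling is wanted, a thin wrapper along these lines could raise instead; this is a hedged sketch that reuses the same Requests helper and endpoint as the block, and the error-payload shape is an assumption.

from backend.sdk import Requests  # same helper the blocks above use

# Hedged sketch: stricter handling around the same table-delete call.
async def delete_table_or_raise(api_key: str, base_id: str, table_id: str) -> None:
    response = await Requests().delete(
        f"https://api.airtable.com/v0/meta/bases/{base_id}/tables/{table_id}",
        headers={"Authorization": f"Bearer {api_key}"},
    )
    if response.status not in (200, 204):
        # Fall back to the raw status code rather than assuming a specific
        # error-body schema from Airtable.
        raise RuntimeError(f"Airtable delete failed with HTTP {response.status}")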
class AirtableAddFieldBlock(Block):
    """
    Adds a new field (column) to an existing Airtable table.
    """
@@ -161,19 +204,11 @@ class AirtableCreateFieldBlock(Block):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
        base_id: str = SchemaField(description="The Airtable base ID")
        base_id: str = SchemaField(description="The Airtable base ID", default="")
        table_id: str = SchemaField(description="The table ID to add field to")
        field_type: TableFieldType = SchemaField(
            description="The type of the field to create",
            default=TableFieldType.SINGLE_LINE_TEXT,
            advanced=False,
        )
        name: str = SchemaField(description="The name of the field to create")
        description: str | None = SchemaField(
            description="The description of the field to create", default=None
        )
        options: dict[str, str] | None = SchemaField(
            description="The options of the field to create", default=None
        field_definition: dict = SchemaField(
            description="Field definition with name, type, and options",
            default={"name": "New Field", "type": "singleLineText"},
        )

    class Output(BlockSchema):
@@ -192,14 +227,17 @@ class AirtableCreateFieldBlock(Block):
    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        field_data = await create_field(
            credentials,
            input_data.base_id,
            input_data.table_id,
            input_data.field_type,
            input_data.name,
        api_key = credentials.api_key.get_secret_value()

        # Add field
        response = await Requests().post(
            f"https://api.airtable.com/v0/meta/bases/{input_data.base_id}/tables/{input_data.table_id}/fields",
            headers={"Authorization": f"Bearer {api_key}"},
            json=input_data.field_definition,
        )

        field_data = response.json()

        yield "field", field_data
        yield "field_id", field_data.get("id", "")

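As with table creation, field_definition is sent through unchanged as the request body, so option-bearing field types need their options object spelled out. A hedged example for a single-select field (the choices layout is assumed from Airtable's field model, not from this PR); a plain text field needs only name and type.

# Illustrative field_definition values for AirtableAddFieldBlock.
single_select_field = {
    "name": "Priority",
    "type": "singleSelect",
    # Assumption: singleSelect options take a "choices" list of named entries.
    "options": {"choices": [{"name": "Low"}, {"name": "Medium"}, {"name": "High"}]},
}
notes_field = {"name": "Notes", "type": "multilineText"}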
@@ -213,17 +251,10 @@ class AirtableUpdateFieldBlock(Block):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
        base_id: str = SchemaField(description="The Airtable base ID")
        base_id: str = SchemaField(description="The Airtable base ID", default="")
        table_id: str = SchemaField(description="The table ID containing the field")
        field_id: str = SchemaField(description="The field ID to update")
        name: str | None = SchemaField(
            description="The name of the field to update", default=None, advanced=False
        )
        description: str | None = SchemaField(
            description="The description of the field to update",
            default=None,
            advanced=False,
        )
        patch: dict = SchemaField(description="Field properties to update", default={})

    class Output(BlockSchema):
        field: dict = SchemaField(description="Updated field object")
@@ -240,13 +271,58 @@ class AirtableUpdateFieldBlock(Block):
    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        field_data = await update_field(
            credentials,
            input_data.base_id,
            input_data.table_id,
            input_data.field_id,
            input_data.name,
            input_data.description,
        api_key = credentials.api_key.get_secret_value()

        # Update field
        response = await Requests().patch(
            f"https://api.airtable.com/v0/meta/bases/{input_data.base_id}/tables/{input_data.table_id}/fields/{input_data.field_id}",
            headers={"Authorization": f"Bearer {api_key}"},
            json=input_data.patch,
        )

        field_data = response.json()

        yield "field", field_data

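The patch input replaces the previous name/description/date-dependency fields: whatever dict is supplied becomes the PATCH body, and only the keys present are changed. For both the table-update and field-update blocks that is typically a rename or a description change, for example:

# Example patch payloads; only the keys you want to change are sent.
rename_field_patch = {"name": "Customer email"}
update_table_patch = {
    "name": "Tickets (archived)",
    "description": "Read-only archive of closed tickets",
}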
class AirtableDeleteFieldBlock(Block):
    """
    Deletes a field from an Airtable table.
    """

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = airtable.credentials_field(
            description="Airtable API credentials"
        )
        base_id: str = SchemaField(description="The Airtable base ID", default="")
        table_id: str = SchemaField(description="The table ID containing the field")
        field_id: str = SchemaField(description="The field ID to delete")

    class Output(BlockSchema):
        deleted: bool = SchemaField(
            description="Confirmation that the field was deleted"
        )

    def __init__(self):
        super().__init__(
            id="ca6ebacb-be8b-4c54-80a3-1fb519ad51c6",
            description="Delete a field from an Airtable table",
            categories={BlockCategory.DATA},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        api_key = credentials.api_key.get_secret_value()

        # Delete field
        response = await Requests().delete(
            f"https://api.airtable.com/v0/meta/bases/{input_data.base_id}/tables/{input_data.table_id}/fields/{input_data.field_id}",
            headers={"Authorization": f"Bearer {api_key}"},
        )

        deleted = response.status in [200, 204]

        yield "deleted", deleted

@@ -1,5 +1,8 @@
|
||||
"""
|
||||
Airtable webhook trigger blocks.
|
||||
"""
|
||||
|
||||
from backend.sdk import (
|
||||
BaseModel,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
@@ -11,103 +14,136 @@ from backend.sdk import (
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._api import WebhookPayload
|
||||
from ._config import airtable
|
||||
|
||||
|
||||
class AirtableEventSelector(BaseModel):
|
||||
"""
|
||||
Selects the Airtable webhook event to trigger on.
|
||||
"""
|
||||
|
||||
tableData: bool = True
|
||||
tableFields: bool = True
|
||||
tableMetadata: bool = True
|
||||
|
||||
|
||||
class AirtableWebhookTriggerBlock(Block):
|
||||
"""
|
||||
Starts a flow whenever Airtable emits a webhook event.
|
||||
Starts a flow whenever Airtable pings your webhook URL.
|
||||
|
||||
Thin wrapper just forwards the payloads one at a time to the next block.
|
||||
If auto-fetch is enabled, it automatically fetches the full payloads
|
||||
after receiving the notification.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="Airtable base ID")
|
||||
table_id_or_name: str = SchemaField(description="Airtable table ID or name")
|
||||
payload: dict = SchemaField(hidden=True, default_factory=dict)
|
||||
events: AirtableEventSelector = SchemaField(
|
||||
description="Airtable webhook event filter"
|
||||
webhook_url: str = SchemaField(
|
||||
description="URL to receive webhooks (auto-generated)",
|
||||
default="",
|
||||
hidden=True,
|
||||
)
|
||||
base_id: str = SchemaField(
|
||||
description="The Airtable base ID to monitor",
|
||||
default="",
|
||||
)
|
||||
table_id_or_name: str = SchemaField(
|
||||
description="Table ID or name to monitor (leave empty for all tables)",
|
||||
default="",
|
||||
)
|
||||
event_types: list[str] = SchemaField(
|
||||
description="Event types to listen for",
|
||||
default=["tableData", "tableFields", "tableMetadata"],
|
||||
)
|
||||
auto_fetch: bool = SchemaField(
|
||||
description="Automatically fetch full payloads after notification",
|
||||
default=True,
|
||||
)
|
||||
payload: dict = SchemaField(
|
||||
description="Webhook payload data",
|
||||
default={},
|
||||
hidden=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
payload: WebhookPayload = SchemaField(description="Airtable webhook payload")
|
||||
ping: dict = SchemaField(description="Raw webhook notification body")
|
||||
headers: dict = SchemaField(description="Webhook request headers")
|
||||
verified: bool = SchemaField(
|
||||
description="Whether the webhook signature was verified"
|
||||
)
|
||||
# Fields populated when auto_fetch is True
|
||||
payloads: list[dict] = SchemaField(
|
||||
description="Array of change payloads (when auto-fetch is enabled)",
|
||||
default=[],
|
||||
)
|
||||
next_cursor: int = SchemaField(
|
||||
description="Next cursor for pagination (when auto-fetch is enabled)",
|
||||
default=0,
|
||||
)
|
||||
might_have_more: bool = SchemaField(
|
||||
description="Whether there might be more payloads (when auto-fetch is enabled)",
|
||||
default=False,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
example_payload = {
|
||||
"payloads": [
|
||||
{
|
||||
"timestamp": "2022-02-01T21:25:05.663Z",
|
||||
"baseTransactionNumber": 4,
|
||||
"actionMetadata": {
|
||||
"source": "client",
|
||||
"sourceMetadata": {
|
||||
"user": {
|
||||
"id": "usr00000000000000",
|
||||
"email": "foo@bar.com",
|
||||
"permissionLevel": "create",
|
||||
}
|
||||
},
|
||||
},
|
||||
"payloadFormat": "v0",
|
||||
}
|
||||
],
|
||||
"cursor": 5,
|
||||
"mightHaveMore": False,
|
||||
}
|
||||
|
||||
super().__init__(
|
||||
# NOTE: This is disabled whilst the webhook system is finalised.
|
||||
disabled=False,
|
||||
id="d0180ce6-ccb9-48c7-8256-b39e93e62801",
|
||||
description="Starts a flow whenever Airtable emits a webhook event",
|
||||
categories={BlockCategory.INPUT, BlockCategory.DATA},
|
||||
description="Starts a flow whenever Airtable pings your webhook URL",
|
||||
categories={BlockCategory.INPUT},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
block_type=BlockType.WEBHOOK,
|
||||
webhook_config=BlockWebhookConfig(
|
||||
provider=ProviderName("airtable"),
|
||||
webhook_type="not-used",
|
||||
event_filter_input="events",
|
||||
event_format="{event}",
|
||||
webhook_type="table_change",
|
||||
# event_filter_input="event_types",
|
||||
resource_format="{base_id}/{table_id_or_name}",
|
||||
),
|
||||
test_input={
|
||||
"credentials": airtable.get_test_credentials().model_dump(),
|
||||
"base_id": "app1234567890",
|
||||
"table_id_or_name": "table1234567890",
|
||||
"events": AirtableEventSelector(
|
||||
tableData=True,
|
||||
tableFields=True,
|
||||
tableMetadata=False,
|
||||
).model_dump(),
|
||||
"payload": example_payload,
|
||||
},
|
||||
test_credentials=airtable.get_test_credentials(),
|
||||
test_output=[
|
||||
(
|
||||
"payload",
|
||||
WebhookPayload.model_validate(example_payload["payloads"][0]),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
if len(input_data.payload["payloads"]) > 0:
|
||||
for item in input_data.payload["payloads"]:
|
||||
yield "payload", WebhookPayload.model_validate(item)
|
||||
payload = input_data.payload
|
||||
|
||||
# Extract headers from the webhook request (passed through kwargs)
|
||||
headers = kwargs.get("webhook_headers", {})
|
||||
|
||||
# Check if signature was verified (handled by webhook manager)
|
||||
verified = True # Webhook manager raises error if verification fails
|
||||
|
||||
# Output basic webhook data
|
||||
yield "ping", payload
|
||||
yield "headers", headers
|
||||
yield "verified", verified
|
||||
|
||||
# If auto-fetch is enabled and we have a cursor, fetch the full payloads
|
||||
if input_data.auto_fetch and payload.get("base", {}).get("id"):
|
||||
base_id = payload["base"]["id"]
|
||||
webhook_id = payload.get("webhook", {}).get("id", "")
|
||||
cursor = payload.get("cursor", 1)
|
||||
|
||||
if webhook_id and cursor:
|
||||
# Get credentials from kwargs
|
||||
credentials = kwargs.get("credentials")
|
||||
if credentials:
|
||||
# Fetch payloads using the Airtable API
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
from backend.sdk import Requests
|
||||
|
||||
response = await Requests().get(
|
||||
f"https://api.airtable.com/v0/bases/{base_id}/webhooks/{webhook_id}/payloads",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
params={"cursor": cursor},
|
||||
)
|
||||
|
||||
if response.status == 200:
|
||||
data = response.json()
|
||||
yield "payloads", data.get("payloads", [])
|
||||
yield "next_cursor", data.get("cursor", cursor)
|
||||
yield "might_have_more", data.get("mightHaveMore", False)
|
||||
else:
|
||||
# On error, still output empty payloads
|
||||
yield "payloads", []
|
||||
yield "next_cursor", cursor
|
||||
yield "might_have_more", False
|
||||
else:
|
||||
# No credentials, can't fetch
|
||||
yield "payloads", []
|
||||
yield "next_cursor", cursor
|
||||
yield "might_have_more", False
|
||||
else:
|
||||
yield "error", "No valid payloads found in webhook payload"
|
||||
# Auto-fetch disabled or missing data
|
||||
yield "payloads", []
|
||||
yield "next_cursor", 0
|
||||
yield "might_have_more", False
|
||||
|
||||
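For context on the auto-fetch branch in the trigger's run() above: Airtable's webhook notification ping is deliberately thin, and the block only reads base.id, webhook.id, and an optional cursor before calling the payloads endpoint. A hedged sketch of the kind of body it expects follows; any key beyond those three is an assumption drawn from Airtable's webhook docs, not from this PR.

# Roughly what input_data.payload is expected to look like when Airtable pings
# the trigger; only "base", "webhook" and (optionally) "cursor" are read.
example_ping = {
    "base": {"id": "app1234567890"},
    "webhook": {"id": "ach0000000000000"},   # assumption: Airtable webhook id format
    "timestamp": "2024-01-01T00:00:00.000Z",  # assumption: present in real pings
    "cursor": 1,
}
# run() then GETs /v0/bases/{base_id}/webhooks/{webhook_id}/payloads?cursor=1
# and yields "payloads", "next_cursor" and "might_have_more" from the response.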
229
autogpt_platform/backend/backend/blocks/airtable/webhooks.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""
|
||||
Airtable webhook management blocks.
|
||||
"""
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import airtable
|
||||
|
||||
|
||||
class AirtableFetchWebhookPayloadsBlock(Block):
|
||||
"""
|
||||
Fetches accumulated event payloads for a webhook.
|
||||
|
||||
Use this to pull the full change details after receiving a webhook notification,
|
||||
or run on a schedule to poll for changes.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
webhook_id: str = SchemaField(
|
||||
description="The webhook ID to fetch payloads for"
|
||||
)
|
||||
cursor: int = SchemaField(
|
||||
description="Cursor position (0 = all payloads)", default=0
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
payloads: list[dict] = SchemaField(description="Array of webhook payloads")
|
||||
next_cursor: int = SchemaField(description="Next cursor for pagination")
|
||||
might_have_more: bool = SchemaField(
|
||||
description="Whether there might be more payloads"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="7172db38-e338-4561-836f-9fa282c99949",
|
||||
description="Fetch webhook payloads from Airtable",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Fetch payloads from Airtable
|
||||
params = {}
|
||||
if input_data.cursor > 0:
|
||||
params["cursor"] = input_data.cursor
|
||||
|
||||
response = await Requests().get(
|
||||
f"https://api.airtable.com/v0/bases/{input_data.base_id}/webhooks/{input_data.webhook_id}/payloads",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
params=params,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
yield "payloads", data.get("payloads", [])
|
||||
yield "next_cursor", data.get("cursor", input_data.cursor)
|
||||
yield "might_have_more", data.get("mightHaveMore", False)
|
||||
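Because mightHaveMore can be true, a single call to this block is not guaranteed to drain the queue; callers that want everything have to loop on the returned cursor. A minimal sketch of that loop, assuming the same Requests helper and payloads endpoint used by the block above:

from backend.sdk import Requests

# Hedged sketch: drain all accumulated payloads by following the cursor.
async def fetch_all_payloads(api_key: str, base_id: str, webhook_id: str) -> list[dict]:
    url = f"https://api.airtable.com/v0/bases/{base_id}/webhooks/{webhook_id}/payloads"
    headers = {"Authorization": f"Bearer {api_key}"}
    payloads: list[dict] = []
    cursor: int | None = None
    while True:
        params = {"cursor": cursor} if cursor else {}
        response = await Requests().get(url, headers=headers, params=params)
        data = response.json()
        payloads.extend(data.get("payloads", []))
        cursor = data.get("cursor", cursor)
        if not data.get("mightHaveMore", False):
            return payloads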
|
||||
|
||||
class AirtableRefreshWebhookBlock(Block):
|
||||
"""
|
||||
Refreshes a webhook to extend its expiration by another 7 days.
|
||||
|
||||
Webhooks expire after 7 days of inactivity. Use this block in a daily
|
||||
cron job to keep long-lived webhooks active.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
webhook_id: str = SchemaField(description="The webhook ID to refresh")
|
||||
|
||||
class Output(BlockSchema):
|
||||
expiration_time: str = SchemaField(
|
||||
description="New expiration time (ISO format)"
|
||||
)
|
||||
webhook: dict = SchemaField(description="Full webhook object")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="5e82d957-02b8-47eb-8974-7bdaf8caff78",
|
||||
description="Refresh a webhook to extend its expiration",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Refresh the webhook
|
||||
response = await Requests().post(
|
||||
f"https://api.airtable.com/v0/bases/{input_data.base_id}/webhooks/{input_data.webhook_id}/refresh",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
)
|
||||
|
||||
webhook_data = response.json()
|
||||
|
||||
yield "expiration_time", webhook_data.get("expirationTime", "")
|
||||
yield "webhook", webhook_data
|
||||
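Since a refresh only extends the expiry by another seven days, the intended pattern is a small recurring agent: a schedule trigger feeding this block once a day with the stored base_id and webhook_id. Outside the platform the same keep-alive is a single POST; this is a hedged sketch of the call the block itself makes.

from backend.sdk import Requests

# Hedged sketch of the keep-alive call, suitable for a daily cron job.
async def refresh_webhook(api_key: str, base_id: str, webhook_id: str) -> str:
    response = await Requests().post(
        f"https://api.airtable.com/v0/bases/{base_id}/webhooks/{webhook_id}/refresh",
        headers={"Authorization": f"Bearer {api_key}"},
    )
    # expirationTime is returned as an ISO timestamp roughly seven days out.
    return response.json().get("expirationTime", "")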
|
||||
|
||||
class AirtableCreateWebhookBlock(Block):
|
||||
"""
|
||||
Creates a new webhook for monitoring changes in an Airtable base.
|
||||
|
||||
The webhook will send notifications to the specified URL when changes occur.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID to monitor")
|
||||
notification_url: str = SchemaField(
|
||||
description="URL to receive webhook notifications"
|
||||
)
|
||||
specification: dict = SchemaField(
|
||||
description="Webhook specification (filters, options)",
|
||||
default={
|
||||
"filters": {"dataTypes": ["tableData", "tableFields", "tableMetadata"]}
|
||||
},
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
webhook: dict = SchemaField(description="Created webhook object")
|
||||
webhook_id: str = SchemaField(description="ID of the created webhook")
|
||||
mac_secret: str = SchemaField(
|
||||
description="MAC secret for signature verification"
|
||||
)
|
||||
expiration_time: str = SchemaField(description="Webhook expiration time")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b9f1f4ec-f4d1-4fbd-ab0b-b219c0e4da9a",
|
||||
description="Create a new Airtable webhook",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Create the webhook
|
||||
response = await Requests().post(
|
||||
f"https://api.airtable.com/v0/bases/{input_data.base_id}/webhooks",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
json={
|
||||
"notificationUrl": input_data.notification_url,
|
||||
"specification": input_data.specification,
|
||||
},
|
||||
)
|
||||
|
||||
webhook_data = response.json()
|
||||
|
||||
yield "webhook", webhook_data
|
||||
yield "webhook_id", webhook_data.get("id", "")
|
||||
yield "mac_secret", webhook_data.get("macSecretBase64", "")
|
||||
yield "expiration_time", webhook_data.get("expirationTime", "")
|
||||
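The mac_secret output is what a receiver needs to check that a notification really came from Airtable: each ping is signed with HMAC-SHA256 over the raw request body using the base64-decoded secret. The X-Airtable-Content-MAC header name and the hmac-sha256= prefix come from Airtable's webhook documentation rather than from this PR, so treat the sketch below as an assumption to verify.

import base64
import hashlib
import hmac

# Hedged sketch: verify an incoming notification against macSecretBase64.
def verify_airtable_signature(mac_secret_b64: str, raw_body: bytes, header_value: str) -> bool:
    secret = base64.b64decode(mac_secret_b64)
    expected = "hmac-sha256=" + hmac.new(secret, raw_body, hashlib.sha256).hexdigest()
    # Constant-time comparison to avoid leaking information via timing.
    return hmac.compare_digest(expected, header_value)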
|
||||
|
||||
class AirtableDeleteWebhookBlock(Block):
|
||||
"""
|
||||
Deletes a webhook from an Airtable base.
|
||||
|
||||
This will stop all notifications from the webhook.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
webhook_id: str = SchemaField(description="The webhook ID to delete")
|
||||
|
||||
class Output(BlockSchema):
|
||||
deleted: bool = SchemaField(
|
||||
description="Whether the webhook was successfully deleted"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="e4ded448-1515-4fe2-b93e-3e4db527df83",
|
||||
description="Delete an Airtable webhook",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Delete the webhook
|
||||
response = await Requests().delete(
|
||||
f"https://api.airtable.com/v0/bases/{input_data.base_id}/webhooks/{input_data.webhook_id}",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
)
|
||||
|
||||
# Check if deletion was successful
|
||||
deleted = response.status in [200, 204]
|
||||
|
||||
yield "deleted", deleted
|
||||
@@ -1,15 +0,0 @@
|
||||
AYRSHARE_BLOCK_IDS = [
|
||||
"cbd52c2a-06d2-43ed-9560-6576cc163283", # PostToBlueskyBlock
|
||||
"3352f512-3524-49ed-a08f-003042da2fc1", # PostToFacebookBlock
|
||||
"9e8f844e-b4a5-4b25-80f2-9e1dd7d67625", # PostToXBlock
|
||||
"589af4e4-507f-42fd-b9ac-a67ecef25811", # PostToLinkedInBlock
|
||||
"89b02b96-a7cb-46f4-9900-c48b32fe1552", # PostToInstagramBlock
|
||||
"0082d712-ff1b-4c3d-8a8d-6c7721883b83", # PostToYouTubeBlock
|
||||
"c7733580-3c72-483e-8e47-a8d58754d853", # PostToRedditBlock
|
||||
"47bc74eb-4af2-452c-b933-af377c7287df", # PostToTelegramBlock
|
||||
"2c38c783-c484-4503-9280-ef5d1d345a7e", # PostToGMBBlock
|
||||
"3ca46e05-dbaa-4afb-9e95-5a429c4177e6", # PostToPinterestBlock
|
||||
"7faf4b27-96b0-4f05-bf64-e0de54ae74e1", # PostToTikTokBlock
|
||||
"f8c3b2e1-9d4a-4e5f-8c7b-6a9e8d2f1c3b", # PostToThreadsBlock
|
||||
"a9d7f854-2c83-4e96-b3a1-7f2e9c5d4b8e", # PostToSnapchatBlock
|
||||
]
|
||||
@@ -1,152 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.block import BlockSchema
|
||||
from backend.data.model import SchemaField, UserIntegrations
|
||||
from backend.integrations.ayrshare import AyrshareClient
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
from backend.util.exceptions import MissingConfigError
|
||||
|
||||
|
||||
async def get_profile_key(user_id: str):
|
||||
user_integrations: UserIntegrations = (
|
||||
await get_database_manager_async_client().get_user_integrations(user_id)
|
||||
)
|
||||
return user_integrations.managed_credentials.ayrshare_profile_key
|
||||
|
||||
|
||||
class BaseAyrshareInput(BlockSchema):
|
||||
"""Base input model for Ayrshare social media posts with common fields."""
|
||||
|
||||
post: str = SchemaField(
|
||||
description="The post text to be published", default="", advanced=False
|
||||
)
|
||||
media_urls: list[str] = SchemaField(
|
||||
description="Optional list of media URLs to include. Set is_video in advanced settings to true if you want to upload videos.",
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
is_video: bool = SchemaField(
|
||||
description="Whether the media is a video", default=False, advanced=True
|
||||
)
|
||||
schedule_date: Optional[datetime] = SchemaField(
|
||||
description="UTC datetime for scheduling (YYYY-MM-DDThh:mm:ssZ)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
disable_comments: bool = SchemaField(
|
||||
description="Whether to disable comments", default=False, advanced=True
|
||||
)
|
||||
shorten_links: bool = SchemaField(
|
||||
description="Whether to shorten links", default=False, advanced=True
|
||||
)
|
||||
unsplash: Optional[str] = SchemaField(
|
||||
description="Unsplash image configuration", default=None, advanced=True
|
||||
)
|
||||
requires_approval: bool = SchemaField(
|
||||
description="Whether to enable approval workflow",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
random_post: bool = SchemaField(
|
||||
description="Whether to generate random post text",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
random_media_url: bool = SchemaField(
|
||||
description="Whether to generate random media", default=False, advanced=True
|
||||
)
|
||||
notes: Optional[str] = SchemaField(
|
||||
description="Additional notes for the post", default=None, advanced=True
|
||||
)
|
||||
|
||||
|
||||
class CarouselItem(BaseModel):
|
||||
"""Model for Facebook carousel items."""
|
||||
|
||||
name: str = Field(..., description="The name of the item")
|
||||
link: str = Field(..., description="The link of the item")
|
||||
picture: str = Field(..., description="The picture URL of the item")
|
||||
|
||||
|
||||
class CallToAction(BaseModel):
|
||||
"""Model for Google My Business Call to Action."""
|
||||
|
||||
action_type: str = Field(
|
||||
..., description="Type of action (book, order, shop, learn_more, sign_up, call)"
|
||||
)
|
||||
url: Optional[str] = Field(
|
||||
description="URL for the action (not required for 'call' action)"
|
||||
)
|
||||
|
||||
|
||||
class EventDetails(BaseModel):
|
||||
"""Model for Google My Business Event details."""
|
||||
|
||||
title: str = Field(..., description="Event title")
|
||||
start_date: str = Field(..., description="Event start date (ISO format)")
|
||||
end_date: str = Field(..., description="Event end date (ISO format)")
|
||||
|
||||
|
||||
class OfferDetails(BaseModel):
|
||||
"""Model for Google My Business Offer details."""
|
||||
|
||||
title: str = Field(..., description="Offer title")
|
||||
start_date: str = Field(..., description="Offer start date (ISO format)")
|
||||
end_date: str = Field(..., description="Offer end date (ISO format)")
|
||||
coupon_code: str = Field(..., description="Coupon code (max 58 characters)")
|
||||
redeem_online_url: str = Field(..., description="URL to redeem the offer")
|
||||
terms_conditions: str = Field(..., description="Terms and conditions")
|
||||
|
||||
|
||||
class InstagramUserTag(BaseModel):
|
||||
"""Model for Instagram user tags."""
|
||||
|
||||
username: str = Field(..., description="Instagram username (without @)")
|
||||
x: Optional[float] = Field(description="X coordinate (0.0-1.0) for image posts")
|
||||
y: Optional[float] = Field(description="Y coordinate (0.0-1.0) for image posts")
|
||||
|
||||
|
||||
class LinkedInTargeting(BaseModel):
|
||||
"""Model for LinkedIn audience targeting."""
|
||||
|
||||
countries: Optional[list[str]] = Field(
|
||||
description="Country codes (e.g., ['US', 'IN', 'DE', 'GB'])"
|
||||
)
|
||||
seniorities: Optional[list[str]] = Field(
|
||||
description="Seniority levels (e.g., ['Senior', 'VP'])"
|
||||
)
|
||||
degrees: Optional[list[str]] = Field(description="Education degrees")
|
||||
fields_of_study: Optional[list[str]] = Field(description="Fields of study")
|
||||
industries: Optional[list[str]] = Field(description="Industry categories")
|
||||
job_functions: Optional[list[str]] = Field(description="Job function categories")
|
||||
staff_count_ranges: Optional[list[str]] = Field(description="Company size ranges")
|
||||
|
||||
|
||||
class PinterestCarouselOption(BaseModel):
|
||||
"""Model for Pinterest carousel image options."""
|
||||
|
||||
title: Optional[str] = Field(description="Image title")
|
||||
link: Optional[str] = Field(description="External destination link for the image")
|
||||
description: Optional[str] = Field(description="Image description")
|
||||
|
||||
|
||||
class YouTubeTargeting(BaseModel):
|
||||
"""Model for YouTube country targeting."""
|
||||
|
||||
block: Optional[list[str]] = Field(
|
||||
description="Country codes to block (e.g., ['US', 'CA'])"
|
||||
)
|
||||
allow: Optional[list[str]] = Field(
|
||||
description="Country codes to allow (e.g., ['GB', 'AU'])"
|
||||
)
|
||||
|
||||
|
||||
def create_ayrshare_client():
|
||||
"""Create an Ayrshare client instance."""
|
||||
try:
|
||||
return AyrshareClient()
|
||||
except MissingConfigError:
|
||||
return None
|
||||
@@ -1,114 +0,0 @@
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
class PostToBlueskyBlock(Block):
|
||||
"""Block for posting to Bluesky with Bluesky-specific options."""
|
||||
|
||||
class Input(BaseAyrshareInput):
|
||||
"""Input schema for Bluesky posts."""
|
||||
|
||||
# Override post field to include character limit information
|
||||
post: str = SchemaField(
|
||||
description="The post text to be published (max 300 characters for Bluesky)",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# Override media_urls to include Bluesky-specific constraints
|
||||
media_urls: list[str] = SchemaField(
|
||||
description="Optional list of media URLs to include. Bluesky supports up to 4 images or 1 video.",
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# Bluesky-specific options
|
||||
alt_text: list[str] = SchemaField(
|
||||
description="Alt text for each media item (accessibility)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
disabled=True,
|
||||
id="cbd52c2a-06d2-43ed-9560-6576cc163283",
|
||||
description="Post to Bluesky using Ayrshare",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
block_type=BlockType.AYRSHARE,
|
||||
input_schema=PostToBlueskyBlock.Input,
|
||||
output_schema=PostToBlueskyBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: "PostToBlueskyBlock.Input",
|
||||
*,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Bluesky with Bluesky-specific options."""
|
||||
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
return
|
||||
|
||||
# Validate character limit for Bluesky
|
||||
if len(input_data.post) > 300:
|
||||
yield "error", f"Post text exceeds Bluesky's 300 character limit ({len(input_data.post)} characters)"
|
||||
return
|
||||
|
||||
# Validate media constraints for Bluesky
|
||||
if len(input_data.media_urls) > 4:
|
||||
yield "error", "Bluesky supports a maximum of 4 images or 1 video"
|
||||
return
|
||||
|
||||
# Convert datetime to ISO format if provided
|
||||
iso_date = (
|
||||
input_data.schedule_date.isoformat() if input_data.schedule_date else None
|
||||
)
|
||||
|
||||
# Build Bluesky-specific options
|
||||
bluesky_options = {}
|
||||
if input_data.alt_text:
|
||||
bluesky_options["altText"] = input_data.alt_text
|
||||
|
||||
response = await client.create_post(
|
||||
post=input_data.post,
|
||||
platforms=[SocialPlatform.BLUESKY],
|
||||
media_urls=input_data.media_urls,
|
||||
is_video=input_data.is_video,
|
||||
schedule_date=iso_date,
|
||||
disable_comments=input_data.disable_comments,
|
||||
shorten_links=input_data.shorten_links,
|
||||
unsplash=input_data.unsplash,
|
||||
requires_approval=input_data.requires_approval,
|
||||
random_post=input_data.random_post,
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
bluesky_options=bluesky_options if bluesky_options else None,
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
for p in response.postIds:
|
||||
yield "post", p
|
||||
@@ -1,212 +0,0 @@
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._util import (
|
||||
BaseAyrshareInput,
|
||||
CarouselItem,
|
||||
create_ayrshare_client,
|
||||
get_profile_key,
|
||||
)
|
||||
|
||||
|
||||
class PostToFacebookBlock(Block):
|
||||
"""Block for posting to Facebook with Facebook-specific options."""
|
||||
|
||||
class Input(BaseAyrshareInput):
|
||||
"""Input schema for Facebook posts."""
|
||||
|
||||
# Facebook-specific options
|
||||
is_carousel: bool = SchemaField(
|
||||
description="Whether to post a carousel", default=False, advanced=True
|
||||
)
|
||||
carousel_link: str = SchemaField(
|
||||
description="The URL for the 'See More At' button in the carousel",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
carousel_items: list[CarouselItem] = SchemaField(
|
||||
description="List of carousel items with name, link and picture URLs. Min 2, max 10 items.",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
is_reels: bool = SchemaField(
|
||||
description="Whether to post to Facebook Reels",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
reels_title: str = SchemaField(
|
||||
description="Title for the Reels video (max 255 chars)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
reels_thumbnail: str = SchemaField(
|
||||
description="Thumbnail URL for Reels video (JPEG/PNG, <10MB)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
is_story: bool = SchemaField(
|
||||
description="Whether to post as a Facebook Story",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
media_captions: list[str] = SchemaField(
|
||||
description="Captions for each media item",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
location_id: str = SchemaField(
|
||||
description="Facebook Page ID or name for location tagging",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
age_min: int = SchemaField(
|
||||
description="Minimum age for audience targeting (13,15,18,21,25)",
|
||||
default=0,
|
||||
advanced=True,
|
||||
)
|
||||
target_countries: list[str] = SchemaField(
|
||||
description="List of country codes to target (max 25)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
alt_text: list[str] = SchemaField(
|
||||
description="Alt text for each media item",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
video_title: str = SchemaField(
|
||||
description="Title for video post", default="", advanced=True
|
||||
)
|
||||
video_thumbnail: str = SchemaField(
|
||||
description="Thumbnail URL for video post", default="", advanced=True
|
||||
)
|
||||
is_draft: bool = SchemaField(
|
||||
description="Save as draft in Meta Business Suite",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
scheduled_publish_date: str = SchemaField(
|
||||
description="Schedule publish time in Meta Business Suite (UTC)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
preview_link: str = SchemaField(
|
||||
description="URL for custom link preview", default="", advanced=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
disabled=True,
|
||||
id="3352f512-3524-49ed-a08f-003042da2fc1",
|
||||
description="Post to Facebook using Ayrshare",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
block_type=BlockType.AYRSHARE,
|
||||
input_schema=PostToFacebookBlock.Input,
|
||||
output_schema=PostToFacebookBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: "PostToFacebookBlock.Input",
|
||||
*,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Facebook with Facebook-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
return
|
||||
|
||||
# Convert datetime to ISO format if provided
|
||||
iso_date = (
|
||||
input_data.schedule_date.isoformat() if input_data.schedule_date else None
|
||||
)
|
||||
|
||||
# Build Facebook-specific options
|
||||
facebook_options = {}
|
||||
if input_data.is_carousel:
|
||||
facebook_options["isCarousel"] = True
|
||||
if input_data.carousel_link:
|
||||
facebook_options["carouselLink"] = input_data.carousel_link
|
||||
if input_data.carousel_items:
|
||||
facebook_options["carouselItems"] = [
|
||||
item.dict() for item in input_data.carousel_items
|
||||
]
|
||||
|
||||
if input_data.is_reels:
|
||||
facebook_options["isReels"] = True
|
||||
if input_data.reels_title:
|
||||
facebook_options["reelsTitle"] = input_data.reels_title
|
||||
if input_data.reels_thumbnail:
|
||||
facebook_options["reelsThumbnail"] = input_data.reels_thumbnail
|
||||
|
||||
if input_data.is_story:
|
||||
facebook_options["isStory"] = True
|
||||
|
||||
if input_data.media_captions:
|
||||
facebook_options["mediaCaptions"] = input_data.media_captions
|
||||
|
||||
if input_data.location_id:
|
||||
facebook_options["locationId"] = input_data.location_id
|
||||
|
||||
if input_data.age_min > 0:
|
||||
facebook_options["ageMin"] = input_data.age_min
|
||||
|
||||
if input_data.target_countries:
|
||||
facebook_options["targetCountries"] = input_data.target_countries
|
||||
|
||||
if input_data.alt_text:
|
||||
facebook_options["altText"] = input_data.alt_text
|
||||
|
||||
if input_data.video_title:
|
||||
facebook_options["videoTitle"] = input_data.video_title
|
||||
|
||||
if input_data.video_thumbnail:
|
||||
facebook_options["videoThumbnail"] = input_data.video_thumbnail
|
||||
|
||||
if input_data.is_draft:
|
||||
facebook_options["isDraft"] = True
|
||||
|
||||
if input_data.scheduled_publish_date:
|
||||
facebook_options["scheduledPublishDate"] = input_data.scheduled_publish_date
|
||||
|
||||
if input_data.preview_link:
|
||||
facebook_options["previewLink"] = input_data.preview_link
|
||||
|
||||
response = await client.create_post(
|
||||
post=input_data.post,
|
||||
platforms=[SocialPlatform.FACEBOOK],
|
||||
media_urls=input_data.media_urls,
|
||||
is_video=input_data.is_video,
|
||||
schedule_date=iso_date,
|
||||
disable_comments=input_data.disable_comments,
|
||||
shorten_links=input_data.shorten_links,
|
||||
unsplash=input_data.unsplash,
|
||||
requires_approval=input_data.requires_approval,
|
||||
random_post=input_data.random_post,
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
facebook_options=facebook_options if facebook_options else None,
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
for p in response.postIds:
|
||||
yield "post", p
|
||||
@@ -1,210 +0,0 @@
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
class PostToGMBBlock(Block):
|
||||
"""Block for posting to Google My Business with GMB-specific options."""
|
||||
|
||||
class Input(BaseAyrshareInput):
|
||||
"""Input schema for Google My Business posts."""
|
||||
|
||||
# Override media_urls to include GMB-specific constraints
|
||||
media_urls: list[str] = SchemaField(
|
||||
description="Optional list of media URLs. GMB supports only one image or video per post.",
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# GMB-specific options
|
||||
is_photo_video: bool = SchemaField(
|
||||
description="Whether this is a photo/video post (appears in Photos section)",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
photo_category: str = SchemaField(
|
||||
description="Category for photo/video: cover, profile, logo, exterior, interior, product, at_work, food_and_drink, menu, common_area, rooms, teams",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
# Call to action options (flattened from CallToAction object)
|
||||
call_to_action_type: str = SchemaField(
|
||||
description="Type of action button: 'book', 'order', 'shop', 'learn_more', 'sign_up', or 'call'",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
call_to_action_url: str = SchemaField(
|
||||
description="URL for the action button (not required for 'call' action)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
# Event details options (flattened from EventDetails object)
|
||||
event_title: str = SchemaField(
|
||||
description="Event title for event posts",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
event_start_date: str = SchemaField(
|
||||
description="Event start date in ISO format (e.g., '2024-03-15T09:00:00Z')",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
event_end_date: str = SchemaField(
|
||||
description="Event end date in ISO format (e.g., '2024-03-15T17:00:00Z')",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
# Offer details options (flattened from OfferDetails object)
|
||||
offer_title: str = SchemaField(
|
||||
description="Offer title for promotional posts",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
offer_start_date: str = SchemaField(
|
||||
description="Offer start date in ISO format (e.g., '2024-03-15T00:00:00Z')",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
offer_end_date: str = SchemaField(
|
||||
description="Offer end date in ISO format (e.g., '2024-04-15T23:59:59Z')",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
offer_coupon_code: str = SchemaField(
|
||||
description="Coupon code for the offer (max 58 characters)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
offer_redeem_online_url: str = SchemaField(
|
||||
description="URL where customers can redeem the offer online",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
offer_terms_conditions: str = SchemaField(
|
||||
description="Terms and conditions for the offer",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
disabled=True,
|
||||
id="2c38c783-c484-4503-9280-ef5d1d345a7e",
|
||||
description="Post to Google My Business using Ayrshare",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
block_type=BlockType.AYRSHARE,
|
||||
input_schema=PostToGMBBlock.Input,
|
||||
output_schema=PostToGMBBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: "PostToGMBBlock.Input", *, user_id: str, **kwargs
|
||||
) -> BlockOutput:
|
||||
"""Post to Google My Business with GMB-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
return
|
||||
|
||||
# Validate GMB constraints
|
||||
if len(input_data.media_urls) > 1:
|
||||
yield "error", "Google My Business supports only one image or video per post"
|
||||
return
|
||||
|
||||
# Validate offer coupon code length
|
||||
if input_data.offer_coupon_code and len(input_data.offer_coupon_code) > 58:
|
||||
yield "error", "GMB offer coupon code cannot exceed 58 characters"
|
||||
return
|
||||
|
||||
# Convert datetime to ISO format if provided
|
||||
iso_date = (
|
||||
input_data.schedule_date.isoformat() if input_data.schedule_date else None
|
||||
)
|
||||
|
||||
# Build GMB-specific options
|
||||
gmb_options = {}
|
||||
|
||||
# Photo/Video post options
|
||||
if input_data.is_photo_video:
|
||||
gmb_options["isPhotoVideo"] = True
|
||||
if input_data.photo_category:
|
||||
gmb_options["category"] = input_data.photo_category
|
||||
|
||||
# Call to Action (from flattened fields)
|
||||
if input_data.call_to_action_type:
|
||||
cta_dict = {"actionType": input_data.call_to_action_type}
|
||||
# URL not required for 'call' action type
|
||||
if (
|
||||
input_data.call_to_action_type != "call"
|
||||
and input_data.call_to_action_url
|
||||
):
|
||||
cta_dict["url"] = input_data.call_to_action_url
|
||||
gmb_options["callToAction"] = cta_dict
|
||||
|
||||
# Event details (from flattened fields)
|
||||
if (
|
||||
input_data.event_title
|
||||
and input_data.event_start_date
|
||||
and input_data.event_end_date
|
||||
):
|
||||
gmb_options["event"] = {
|
||||
"title": input_data.event_title,
|
||||
"startDate": input_data.event_start_date,
|
||||
"endDate": input_data.event_end_date,
|
||||
}
|
||||
|
||||
# Offer details (from flattened fields)
|
||||
if (
|
||||
input_data.offer_title
|
||||
and input_data.offer_start_date
|
||||
and input_data.offer_end_date
|
||||
and input_data.offer_coupon_code
|
||||
and input_data.offer_redeem_online_url
|
||||
and input_data.offer_terms_conditions
|
||||
):
|
||||
gmb_options["offer"] = {
|
||||
"title": input_data.offer_title,
|
||||
"startDate": input_data.offer_start_date,
|
||||
"endDate": input_data.offer_end_date,
|
||||
"couponCode": input_data.offer_coupon_code,
|
||||
"redeemOnlineUrl": input_data.offer_redeem_online_url,
|
||||
"termsConditions": input_data.offer_terms_conditions,
|
||||
}
|
||||
|
||||
response = await client.create_post(
|
||||
post=input_data.post,
|
||||
platforms=[SocialPlatform.GOOGLE_MY_BUSINESS],
|
||||
media_urls=input_data.media_urls,
|
||||
is_video=input_data.is_video,
|
||||
schedule_date=iso_date,
|
||||
disable_comments=input_data.disable_comments,
|
||||
shorten_links=input_data.shorten_links,
|
||||
unsplash=input_data.unsplash,
|
||||
requires_approval=input_data.requires_approval,
|
||||
random_post=input_data.random_post,
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
gmb_options=gmb_options if gmb_options else None,
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
for p in response.postIds:
|
||||
yield "post", p
|
||||
@@ -1,249 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._util import (
|
||||
BaseAyrshareInput,
|
||||
InstagramUserTag,
|
||||
create_ayrshare_client,
|
||||
get_profile_key,
|
||||
)
|
||||
|
||||
|
||||
class PostToInstagramBlock(Block):
|
||||
"""Block for posting to Instagram with Instagram-specific options."""
|
||||
|
||||
class Input(BaseAyrshareInput):
|
||||
"""Input schema for Instagram posts."""
|
||||
|
||||
# Override post field to include Instagram-specific information
|
||||
post: str = SchemaField(
|
||||
description="The post text (max 2,200 chars, up to 30 hashtags, 3 @mentions)",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# Override media_urls to include Instagram-specific constraints
|
||||
media_urls: list[str] = SchemaField(
|
||||
description="Optional list of media URLs. Instagram supports up to 10 images/videos in a carousel.",
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# Instagram-specific options
|
||||
is_story: bool | None = SchemaField(
|
||||
description="Whether to post as Instagram Story (24-hour expiration)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# ------- REELS OPTIONS -------
|
||||
share_reels_feed: bool | None = SchemaField(
|
||||
description="Whether Reel should appear in both Feed and Reels tabs",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
audio_name: str | None = SchemaField(
|
||||
description="Audio name for Reels (e.g., 'The Weeknd - Blinding Lights')",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
thumbnail: str | None = SchemaField(
|
||||
description="Thumbnail URL for Reel video", default=None, advanced=True
|
||||
)
|
||||
thumbnail_offset: int | None = SchemaField(
|
||||
description="Thumbnail frame offset in milliseconds (default: 0)",
|
||||
default=0,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# ------- POST OPTIONS -------
|
||||
|
||||
alt_text: list[str] = SchemaField(
|
||||
description="Alt text for each media item (up to 1,000 chars each, accessibility feature), each item in the list corresponds to a media item in the media_urls list",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
location_id: str | None = SchemaField(
|
||||
description="Facebook Page ID or name for location tagging (e.g., '7640348500' or '@guggenheimmuseum')",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
user_tags: list[dict[str, Any]] = SchemaField(
|
||||
description="List of users to tag with coordinates for images",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
collaborators: list[str] = SchemaField(
|
||||
description="Instagram usernames to invite as collaborators (max 3, public accounts only)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
auto_resize: bool | None = SchemaField(
|
||||
description="Auto-resize images to 1080x1080px for Instagram",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="89b02b96-a7cb-46f4-9900-c48b32fe1552",
|
||||
description="Post to Instagram using Ayrshare. Requires a Business or Creator Instagram Account connected with a Facebook Page",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
block_type=BlockType.AYRSHARE,
|
||||
input_schema=PostToInstagramBlock.Input,
|
||||
output_schema=PostToInstagramBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: "PostToInstagramBlock.Input",
|
||||
*,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Instagram with Instagram-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
return
|
||||
|
||||
# Validate Instagram constraints
|
||||
if len(input_data.post) > 2200:
|
||||
yield "error", f"Instagram post text exceeds 2,200 character limit ({len(input_data.post)} characters)"
|
||||
return
|
||||
|
||||
if len(input_data.media_urls) > 10:
|
||||
yield "error", "Instagram supports a maximum of 10 images/videos in a carousel"
|
||||
return
|
||||
|
||||
if len(input_data.collaborators) > 3:
|
||||
yield "error", "Instagram supports a maximum of 3 collaborators"
|
||||
return
|
||||
|
||||
# Validate that if any reel option is set, all required reel options are set
|
||||
reel_options = [
|
||||
input_data.share_reels_feed,
|
||||
input_data.audio_name,
|
||||
input_data.thumbnail,
|
||||
]
|
||||
|
||||
if any(reel_options) and not all(reel_options):
|
||||
yield "error", "When posting a reel, all reel options must be set: share_reels_feed, audio_name, and either thumbnail or thumbnail_offset"
|
||||
return
|
||||
|
||||
# Count hashtags and mentions
|
||||
hashtag_count = input_data.post.count("#")
|
||||
mention_count = input_data.post.count("@")
|
||||
|
||||
if hashtag_count > 30:
|
||||
yield "error", f"Instagram allows maximum 30 hashtags ({hashtag_count} found)"
|
||||
return
|
||||
|
||||
if mention_count > 3:
|
||||
yield "error", f"Instagram allows maximum 3 @mentions ({mention_count} found)"
|
||||
return
|
||||
|
||||
# Convert datetime to ISO format if provided
|
||||
iso_date = (
|
||||
input_data.schedule_date.isoformat() if input_data.schedule_date else None
|
||||
)
|
||||
|
||||
# Build Instagram-specific options
|
||||
instagram_options = {}
|
||||
|
||||
# Stories
|
||||
if input_data.is_story:
|
||||
instagram_options["stories"] = True
|
||||
|
||||
# Reels options
|
||||
if input_data.share_reels_feed is not None:
|
||||
instagram_options["shareReelsFeed"] = input_data.share_reels_feed
|
||||
|
||||
if input_data.audio_name:
|
||||
instagram_options["audioName"] = input_data.audio_name
|
||||
|
||||
if input_data.thumbnail:
|
||||
instagram_options["thumbNail"] = input_data.thumbnail
|
||||
elif input_data.thumbnail_offset and input_data.thumbnail_offset > 0:
|
||||
instagram_options["thumbNailOffset"] = input_data.thumbnail_offset
|
||||
|
||||
# Alt text
|
||||
if input_data.alt_text:
|
||||
# Validate alt text length
|
||||
for i, alt in enumerate(input_data.alt_text):
|
||||
if len(alt) > 1000:
|
||||
yield "error", f"Alt text {i+1} exceeds 1,000 character limit ({len(alt)} characters)"
|
||||
return
|
||||
instagram_options["altText"] = input_data.alt_text
|
||||
|
||||
# Location
|
||||
if input_data.location_id:
|
||||
instagram_options["locationId"] = input_data.location_id
|
||||
|
||||
# User tags
|
||||
if input_data.user_tags:
|
||||
user_tags_list = []
|
||||
for tag in input_data.user_tags:
|
||||
try:
|
||||
tag_obj = InstagramUserTag(**tag)
|
||||
except Exception as e:
|
||||
yield "error", f"Invalid user tag: {e}; tags need to be a dictionary with 3 items: username (str), x (float) and y (float)"
|
||||
                    return
                tag_dict: dict[str, float | str] = {"username": tag_obj.username}
                if tag_obj.x is not None and tag_obj.y is not None:
                    # Validate coordinates
                    if not (0.0 <= tag_obj.x <= 1.0) or not (0.0 <= tag_obj.y <= 1.0):
                        yield "error", f"User tag coordinates must be between 0.0 and 1.0 (user: {tag_obj.username})"
                        return
                    tag_dict["x"] = tag_obj.x
                    tag_dict["y"] = tag_obj.y
                user_tags_list.append(tag_dict)
            instagram_options["userTags"] = user_tags_list

        # Collaborators
        if input_data.collaborators:
            instagram_options["collaborators"] = input_data.collaborators

        # Auto resize
        if input_data.auto_resize:
            instagram_options["autoResize"] = True

        response = await client.create_post(
            post=input_data.post,
            platforms=[SocialPlatform.INSTAGRAM],
            media_urls=input_data.media_urls,
            is_video=input_data.is_video,
            schedule_date=iso_date,
            disable_comments=input_data.disable_comments,
            shorten_links=input_data.shorten_links,
            unsplash=input_data.unsplash,
            requires_approval=input_data.requires_approval,
            random_post=input_data.random_post,
            random_media_url=input_data.random_media_url,
            notes=input_data.notes,
            instagram_options=instagram_options if instagram_options else None,
            profile_key=profile_key.get_secret_value(),
        )
        yield "post_result", response
        if response.postIds:
            for p in response.postIds:
                yield "post", p
@@ -1,222 +0,0 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchema,
    BlockType,
    SchemaField,
)

from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key


class PostToLinkedInBlock(Block):
    """Block for posting to LinkedIn with LinkedIn-specific options."""

    class Input(BaseAyrshareInput):
        """Input schema for LinkedIn posts."""

        # Override post field to include LinkedIn-specific information
        post: str = SchemaField(
            description="The post text (max 3,000 chars, hashtags supported with #)",
            default="",
            advanced=False,
        )

        # Override media_urls to include LinkedIn-specific constraints
        media_urls: list[str] = SchemaField(
            description="Optional list of media URLs. LinkedIn supports up to 9 images, videos, or documents (PPT, PPTX, DOC, DOCX, PDF <100MB, <300 pages).",
            default_factory=list,
            advanced=False,
        )

        # LinkedIn-specific options
        visibility: str = SchemaField(
            description="Post visibility: 'public' (default), 'connections' (personal only), 'loggedin'",
            default="public",
            advanced=True,
        )
        alt_text: list[str] = SchemaField(
            description="Alt text for each image (accessibility feature, not supported for videos/documents)",
            default_factory=list,
            advanced=True,
        )
        titles: list[str] = SchemaField(
            description="Title/caption for each image or video",
            default_factory=list,
            advanced=True,
        )
        document_title: str = SchemaField(
            description="Title for document posts (max 400 chars, uses filename if not specified)",
            default="",
            advanced=True,
        )
        thumbnail: str = SchemaField(
            description="Thumbnail URL for video (PNG/JPG, same dimensions as video, <10MB)",
            default="",
            advanced=True,
        )
        # LinkedIn targeting options (flattened from LinkedInTargeting object)
        targeting_countries: list[str] | None = SchemaField(
            description="Country codes for targeting (e.g., ['US', 'IN', 'DE', 'GB']). Requires 300+ followers in target audience.",
            default=None,
            advanced=True,
        )
        targeting_seniorities: list[str] | None = SchemaField(
            description="Seniority levels for targeting (e.g., ['Senior', 'VP']). Requires 300+ followers in target audience.",
            default=None,
            advanced=True,
        )
        targeting_degrees: list[str] | None = SchemaField(
            description="Education degrees for targeting. Requires 300+ followers in target audience.",
            default=None,
            advanced=True,
        )
        targeting_fields_of_study: list[str] | None = SchemaField(
            description="Fields of study for targeting. Requires 300+ followers in target audience.",
            default=None,
            advanced=True,
        )
        targeting_industries: list[str] | None = SchemaField(
            description="Industry categories for targeting. Requires 300+ followers in target audience.",
            default=None,
            advanced=True,
        )
        targeting_job_functions: list[str] | None = SchemaField(
            description="Job function categories for targeting. Requires 300+ followers in target audience.",
            default=None,
            advanced=True,
        )
        targeting_staff_count_ranges: list[str] | None = SchemaField(
            description="Company size ranges for targeting. Requires 300+ followers in target audience.",
            default=None,
            advanced=True,
        )

    class Output(BlockSchema):
        post_result: PostResponse = SchemaField(description="The result of the post")
        post: PostIds = SchemaField(description="The result of the post")

    def __init__(self):
        super().__init__(
            id="589af4e4-507f-42fd-b9ac-a67ecef25811",
            description="Post to LinkedIn using Ayrshare",
            categories={BlockCategory.SOCIAL},
            block_type=BlockType.AYRSHARE,
            input_schema=PostToLinkedInBlock.Input,
            output_schema=PostToLinkedInBlock.Output,
        )

    async def run(
        self,
        input_data: "PostToLinkedInBlock.Input",
        *,
        user_id: str,
        **kwargs,
    ) -> BlockOutput:
        """Post to LinkedIn with LinkedIn-specific options."""
        profile_key = await get_profile_key(user_id)
        if not profile_key:
            yield "error", "Please link a social account via Ayrshare"
            return

        client = create_ayrshare_client()
        if not client:
            yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
            return

        # Validate LinkedIn constraints
        if len(input_data.post) > 3000:
            yield "error", f"LinkedIn post text exceeds 3,000 character limit ({len(input_data.post)} characters)"
            return

        if len(input_data.media_urls) > 9:
            yield "error", "LinkedIn supports a maximum of 9 images/videos/documents"
            return

        if input_data.document_title and len(input_data.document_title) > 400:
            yield "error", f"LinkedIn document title exceeds 400 character limit ({len(input_data.document_title)} characters)"
            return

        # Validate visibility option
        valid_visibility = ["public", "connections", "loggedin"]
        if input_data.visibility not in valid_visibility:
            yield "error", f"LinkedIn visibility must be one of: {', '.join(valid_visibility)}"
            return

        # Check for document extensions
        document_extensions = [".ppt", ".pptx", ".doc", ".docx", ".pdf"]
        has_documents = any(
            any(url.lower().endswith(ext) for ext in document_extensions)
            for url in input_data.media_urls
        )

        # Convert datetime to ISO format if provided
        iso_date = (
            input_data.schedule_date.isoformat() if input_data.schedule_date else None
        )

        # Build LinkedIn-specific options
        linkedin_options = {}

        # Visibility
        if input_data.visibility != "public":
            linkedin_options["visibility"] = input_data.visibility

        # Alt text (not supported for videos or documents)
        if input_data.alt_text and not has_documents:
            linkedin_options["altText"] = input_data.alt_text

        # Titles/captions
        if input_data.titles:
            linkedin_options["titles"] = input_data.titles

        # Document title
        if input_data.document_title and has_documents:
            linkedin_options["title"] = input_data.document_title

        # Video thumbnail
        if input_data.thumbnail:
            linkedin_options["thumbNail"] = input_data.thumbnail

        # Audience targeting (from flattened fields)
        targeting_dict = {}
        if input_data.targeting_countries:
            targeting_dict["countries"] = input_data.targeting_countries
        if input_data.targeting_seniorities:
            targeting_dict["seniorities"] = input_data.targeting_seniorities
        if input_data.targeting_degrees:
            targeting_dict["degrees"] = input_data.targeting_degrees
        if input_data.targeting_fields_of_study:
            targeting_dict["fieldsOfStudy"] = input_data.targeting_fields_of_study
        if input_data.targeting_industries:
            targeting_dict["industries"] = input_data.targeting_industries
        if input_data.targeting_job_functions:
            targeting_dict["jobFunctions"] = input_data.targeting_job_functions
        if input_data.targeting_staff_count_ranges:
            targeting_dict["staffCountRanges"] = input_data.targeting_staff_count_ranges

        if targeting_dict:
            linkedin_options["targeting"] = targeting_dict

        response = await client.create_post(
            post=input_data.post,
            platforms=[SocialPlatform.LINKEDIN],
            media_urls=input_data.media_urls,
            is_video=input_data.is_video,
            schedule_date=iso_date,
            disable_comments=input_data.disable_comments,
            shorten_links=input_data.shorten_links,
            unsplash=input_data.unsplash,
            requires_approval=input_data.requires_approval,
            random_post=input_data.random_post,
            random_media_url=input_data.random_media_url,
            notes=input_data.notes,
            linkedin_options=linkedin_options if linkedin_options else None,
            profile_key=profile_key.get_secret_value(),
        )
        yield "post_result", response
        if response.postIds:
            for p in response.postIds:
                yield "post", p
@@ -1,214 +0,0 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchema,
    BlockType,
    SchemaField,
)

from ._util import (
    BaseAyrshareInput,
    PinterestCarouselOption,
    create_ayrshare_client,
    get_profile_key,
)


class PostToPinterestBlock(Block):
    """Block for posting to Pinterest with Pinterest-specific options."""

    class Input(BaseAyrshareInput):
        """Input schema for Pinterest posts."""

        # Override post field to include Pinterest-specific information
        post: str = SchemaField(
            description="Pin description (max 500 chars, links not clickable - use link field instead)",
            default="",
            advanced=False,
        )

        # Override media_urls to include Pinterest-specific constraints
        media_urls: list[str] = SchemaField(
            description="Required image/video URLs. Pinterest requires at least one image. Videos need thumbnail. Up to 5 images for carousel.",
            default_factory=list,
            advanced=False,
        )

        # Pinterest-specific options
        pin_title: str = SchemaField(
            description="Pin title displayed in 'Add your title' section (max 100 chars)",
            default="",
            advanced=True,
        )
        link: str = SchemaField(
            description="Clickable destination URL when users click the pin (max 2048 chars)",
            default="",
            advanced=True,
        )
        board_id: str = SchemaField(
            description="Pinterest Board ID to post to (from /user/details endpoint, uses default board if not specified)",
            default="",
            advanced=True,
        )
        note: str = SchemaField(
            description="Private note for the pin (only visible to you and board collaborators)",
            default="",
            advanced=True,
        )
        thumbnail: str = SchemaField(
            description="Required thumbnail URL for video pins (must have valid image Content-Type)",
            default="",
            advanced=True,
        )
        carousel_options: list[PinterestCarouselOption] = SchemaField(
            description="Options for each image in carousel (title, link, description per image)",
            default_factory=list,
            advanced=True,
        )
        alt_text: list[str] = SchemaField(
            description="Alt text for each image/video (max 500 chars each, accessibility feature)",
            default_factory=list,
            advanced=True,
        )

    class Output(BlockSchema):
        post_result: PostResponse = SchemaField(description="The result of the post")
        post: PostIds = SchemaField(description="The result of the post")

    def __init__(self):
        super().__init__(
            disabled=True,
            id="3ca46e05-dbaa-4afb-9e95-5a429c4177e6",
            description="Post to Pinterest using Ayrshare",
            categories={BlockCategory.SOCIAL},
            block_type=BlockType.AYRSHARE,
            input_schema=PostToPinterestBlock.Input,
            output_schema=PostToPinterestBlock.Output,
        )

    async def run(
        self,
        input_data: "PostToPinterestBlock.Input",
        *,
        user_id: str,
        **kwargs,
    ) -> BlockOutput:
        """Post to Pinterest with Pinterest-specific options."""
        profile_key = await get_profile_key(user_id)
        if not profile_key:
            yield "error", "Please link a social account via Ayrshare"
            return

        client = create_ayrshare_client()
        if not client:
            yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
            return

        # Validate Pinterest constraints
        if len(input_data.post) > 500:
            yield "error", f"Pinterest pin description exceeds 500 character limit ({len(input_data.post)} characters)"
            return

        if len(input_data.pin_title) > 100:
            yield "error", f"Pinterest pin title exceeds 100 character limit ({len(input_data.pin_title)} characters)"
            return

        if len(input_data.link) > 2048:
            yield "error", f"Pinterest link URL exceeds 2048 character limit ({len(input_data.link)} characters)"
            return

        if len(input_data.media_urls) == 0:
            yield "error", "Pinterest requires at least one image or video"
            return

        if len(input_data.media_urls) > 5:
            yield "error", "Pinterest supports a maximum of 5 images in a carousel"
            return

        # Check if video is included and thumbnail is provided
        video_extensions = [".mp4", ".mov", ".avi", ".mkv", ".wmv", ".flv", ".webm"]
        has_video = any(
            any(url.lower().endswith(ext) for ext in video_extensions)
            for url in input_data.media_urls
        )

        if (has_video or input_data.is_video) and not input_data.thumbnail:
            yield "error", "Pinterest video pins require a thumbnail URL"
            return

        # Validate alt text length
        for i, alt in enumerate(input_data.alt_text):
            if len(alt) > 500:
                yield "error", f"Pinterest alt text {i+1} exceeds 500 character limit ({len(alt)} characters)"
                return

        # Convert datetime to ISO format if provided
        iso_date = (
            input_data.schedule_date.isoformat() if input_data.schedule_date else None
        )

        # Build Pinterest-specific options
        pinterest_options = {}

        # Pin title
        if input_data.pin_title:
            pinterest_options["title"] = input_data.pin_title

        # Clickable link
        if input_data.link:
            pinterest_options["link"] = input_data.link

        # Board ID
        if input_data.board_id:
            pinterest_options["boardId"] = input_data.board_id

        # Private note
        if input_data.note:
            pinterest_options["note"] = input_data.note

        # Video thumbnail
        if input_data.thumbnail:
            pinterest_options["thumbNail"] = input_data.thumbnail

        # Carousel options
        if input_data.carousel_options:
            carousel_list = []
            for option in input_data.carousel_options:
                carousel_dict = {}
                if option.title:
                    carousel_dict["title"] = option.title
                if option.link:
                    carousel_dict["link"] = option.link
                if option.description:
                    carousel_dict["description"] = option.description
                if carousel_dict:  # Only add if not empty
                    carousel_list.append(carousel_dict)
            if carousel_list:
                pinterest_options["carouselOptions"] = carousel_list

        # Alt text
        if input_data.alt_text:
            pinterest_options["altText"] = input_data.alt_text

        response = await client.create_post(
            post=input_data.post,
            platforms=[SocialPlatform.PINTEREST],
            media_urls=input_data.media_urls,
            is_video=input_data.is_video,
            schedule_date=iso_date,
            disable_comments=input_data.disable_comments,
            shorten_links=input_data.shorten_links,
            unsplash=input_data.unsplash,
            requires_approval=input_data.requires_approval,
            random_post=input_data.random_post,
            random_media_url=input_data.random_media_url,
            notes=input_data.notes,
            pinterest_options=pinterest_options if pinterest_options else None,
            profile_key=profile_key.get_secret_value(),
        )
        yield "post_result", response
        if response.postIds:
            for p in response.postIds:
                yield "post", p
@@ -1,69 +0,0 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchema,
    BlockType,
    SchemaField,
)

from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key


class PostToRedditBlock(Block):
    """Block for posting to Reddit."""

    class Input(BaseAyrshareInput):
        """Input schema for Reddit posts."""

        pass  # Uses all base fields

    class Output(BlockSchema):
        post_result: PostResponse = SchemaField(description="The result of the post")
        post: PostIds = SchemaField(description="The result of the post")

    def __init__(self):
        super().__init__(
            disabled=True,
            id="c7733580-3c72-483e-8e47-a8d58754d853",
            description="Post to Reddit using Ayrshare",
            categories={BlockCategory.SOCIAL},
            block_type=BlockType.AYRSHARE,
            input_schema=PostToRedditBlock.Input,
            output_schema=PostToRedditBlock.Output,
        )

    async def run(
        self, input_data: "PostToRedditBlock.Input", *, user_id: str, **kwargs
    ) -> BlockOutput:
        profile_key = await get_profile_key(user_id)
        if not profile_key:
            yield "error", "Please link a social account via Ayrshare"
            return
        client = create_ayrshare_client()
        if not client:
            yield "error", "Ayrshare integration is not configured."
            return
        iso_date = (
            input_data.schedule_date.isoformat() if input_data.schedule_date else None
        )
        response = await client.create_post(
            post=input_data.post,
            platforms=[SocialPlatform.REDDIT],
            media_urls=input_data.media_urls,
            is_video=input_data.is_video,
            schedule_date=iso_date,
            disable_comments=input_data.disable_comments,
            shorten_links=input_data.shorten_links,
            unsplash=input_data.unsplash,
            requires_approval=input_data.requires_approval,
            random_post=input_data.random_post,
            random_media_url=input_data.random_media_url,
            notes=input_data.notes,
            profile_key=profile_key.get_secret_value(),
        )
        yield "post_result", response
        if response.postIds:
            for p in response.postIds:
                yield "post", p
@@ -1,129 +0,0 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchema,
    BlockType,
    SchemaField,
)

from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key


class PostToSnapchatBlock(Block):
    """Block for posting to Snapchat with Snapchat-specific options."""

    class Input(BaseAyrshareInput):
        """Input schema for Snapchat posts."""

        # Override post field to include Snapchat-specific information
        post: str = SchemaField(
            description="The post text (optional for video-only content)",
            default="",
            advanced=False,
        )

        # Override media_urls to include Snapchat-specific constraints
        media_urls: list[str] = SchemaField(
            description="Required video URL for Snapchat posts. Snapchat only supports video content.",
            default_factory=list,
            advanced=False,
        )

        # Snapchat-specific options
        story_type: str = SchemaField(
            description="Type of Snapchat content: 'story' (24-hour Stories), 'saved_story' (Saved Stories), or 'spotlight' (Spotlight posts)",
            default="story",
            advanced=True,
        )
        video_thumbnail: str = SchemaField(
            description="Thumbnail URL for video content (optional, auto-generated if not provided)",
            default="",
            advanced=True,
        )

    class Output(BlockSchema):
        post_result: PostResponse = SchemaField(description="The result of the post")
        post: PostIds = SchemaField(description="The result of the post")

    def __init__(self):
        super().__init__(
            disabled=True,
            id="a9d7f854-2c83-4e96-b3a1-7f2e9c5d4b8e",
            description="Post to Snapchat using Ayrshare",
            categories={BlockCategory.SOCIAL},
            block_type=BlockType.AYRSHARE,
            input_schema=PostToSnapchatBlock.Input,
            output_schema=PostToSnapchatBlock.Output,
        )

    async def run(
        self,
        input_data: "PostToSnapchatBlock.Input",
        *,
        user_id: str,
        **kwargs,
    ) -> BlockOutput:
        """Post to Snapchat with Snapchat-specific options."""
        profile_key = await get_profile_key(user_id)
        if not profile_key:
            yield "error", "Please link a social account via Ayrshare"
            return

        client = create_ayrshare_client()
        if not client:
            yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
            return

        # Validate Snapchat constraints
        if not input_data.media_urls:
            yield "error", "Snapchat requires at least one video URL"
            return

        if len(input_data.media_urls) > 1:
            yield "error", "Snapchat supports only one video per post"
            return

        # Validate story type
        valid_story_types = ["story", "saved_story", "spotlight"]
        if input_data.story_type not in valid_story_types:
            yield "error", f"Snapchat story type must be one of: {', '.join(valid_story_types)}"
            return

        # Convert datetime to ISO format if provided
        iso_date = (
            input_data.schedule_date.isoformat() if input_data.schedule_date else None
        )

        # Build Snapchat-specific options
        snapchat_options = {}

        # Story type
        if input_data.story_type != "story":
            snapchat_options["storyType"] = input_data.story_type

        # Video thumbnail
        if input_data.video_thumbnail:
            snapchat_options["videoThumbnail"] = input_data.video_thumbnail

        response = await client.create_post(
            post=input_data.post,
            platforms=[SocialPlatform.SNAPCHAT],
            media_urls=input_data.media_urls,
            is_video=True,  # Snapchat only supports video
            schedule_date=iso_date,
            disable_comments=input_data.disable_comments,
            shorten_links=input_data.shorten_links,
            unsplash=input_data.unsplash,
            requires_approval=input_data.requires_approval,
            random_post=input_data.random_post,
            random_media_url=input_data.random_media_url,
            notes=input_data.notes,
            snapchat_options=snapchat_options if snapchat_options else None,
            profile_key=profile_key.get_secret_value(),
        )
        yield "post_result", response
        if response.postIds:
            for p in response.postIds:
                yield "post", p
@@ -1,116 +0,0 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchema,
    BlockType,
    SchemaField,
)

from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key


class PostToTelegramBlock(Block):
    """Block for posting to Telegram with Telegram-specific options."""

    class Input(BaseAyrshareInput):
        """Input schema for Telegram posts."""

        # Override post field to include Telegram-specific information
        post: str = SchemaField(
            description="The post text (empty string allowed). Use @handle to mention other Telegram users.",
            default="",
            advanced=False,
        )

        # Override media_urls to include Telegram-specific constraints
        media_urls: list[str] = SchemaField(
            description="Optional list of media URLs. For animated GIFs, only one URL is allowed. Telegram will auto-preview links unless image/video is included.",
            default_factory=list,
            advanced=False,
        )

        # Override is_video to include GIF-specific information
        is_video: bool = SchemaField(
            description="Whether the media is a video. Set to true for animated GIFs that don't end in .gif/.GIF extension.",
            default=False,
            advanced=True,
        )

    class Output(BlockSchema):
        post_result: PostResponse = SchemaField(description="The result of the post")
        post: PostIds = SchemaField(description="The result of the post")

    def __init__(self):
        super().__init__(
            disabled=True,
            id="47bc74eb-4af2-452c-b933-af377c7287df",
            description="Post to Telegram using Ayrshare",
            categories={BlockCategory.SOCIAL},
            block_type=BlockType.AYRSHARE,
            input_schema=PostToTelegramBlock.Input,
            output_schema=PostToTelegramBlock.Output,
        )

    async def run(
        self,
        input_data: "PostToTelegramBlock.Input",
        *,
        user_id: str,
        **kwargs,
    ) -> BlockOutput:
        """Post to Telegram with Telegram-specific validation."""
        profile_key = await get_profile_key(user_id)
        if not profile_key:
            yield "error", "Please link a social account via Ayrshare"
            return

        client = create_ayrshare_client()
        if not client:
            yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
            return

        # Validate Telegram constraints
        # Check for animated GIFs - only one URL allowed
        gif_extensions = [".gif", ".GIF"]
        has_gif = any(
            any(url.endswith(ext) for ext in gif_extensions)
            for url in input_data.media_urls
        )

        if has_gif and len(input_data.media_urls) > 1:
            yield "error", "Telegram animated GIFs support only one URL per post"
            return

        # Auto-detect if we need to set is_video for GIFs without proper extension
        detected_is_video = input_data.is_video
        if input_data.media_urls and not has_gif and not input_data.is_video:
            # Check if this might be a GIF without proper extension
            # This is just informational - user needs to set is_video manually
            pass

        # Convert datetime to ISO format if provided
        iso_date = (
            input_data.schedule_date.isoformat() if input_data.schedule_date else None
        )

        response = await client.create_post(
            post=input_data.post,
            platforms=[SocialPlatform.TELEGRAM],
            media_urls=input_data.media_urls,
            is_video=detected_is_video,
            schedule_date=iso_date,
            disable_comments=input_data.disable_comments,
            shorten_links=input_data.shorten_links,
            unsplash=input_data.unsplash,
            requires_approval=input_data.requires_approval,
            random_post=input_data.random_post,
            random_media_url=input_data.random_media_url,
            notes=input_data.notes,
            profile_key=profile_key.get_secret_value(),
        )
        yield "post_result", response
        if response.postIds:
            for p in response.postIds:
                yield "post", p
@@ -1,111 +0,0 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchema,
    BlockType,
    SchemaField,
)

from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key


class PostToThreadsBlock(Block):
    """Block for posting to Threads with Threads-specific options."""

    class Input(BaseAyrshareInput):
        """Input schema for Threads posts."""

        # Override post field to include Threads-specific information
        post: str = SchemaField(
            description="The post text (max 500 chars, empty string allowed). Only 1 hashtag allowed. Use @handle to mention users.",
            default="",
            advanced=False,
        )

        # Override media_urls to include Threads-specific constraints
        media_urls: list[str] = SchemaField(
            description="Optional list of media URLs. Supports up to 20 images/videos in a carousel. Auto-preview links unless media is included.",
            default_factory=list,
            advanced=False,
        )

    class Output(BlockSchema):
        post_result: PostResponse = SchemaField(description="The result of the post")
        post: PostIds = SchemaField(description="The result of the post")

    def __init__(self):
        super().__init__(
            disabled=True,
            id="f8c3b2e1-9d4a-4e5f-8c7b-6a9e8d2f1c3b",
            description="Post to Threads using Ayrshare",
            categories={BlockCategory.SOCIAL},
            block_type=BlockType.AYRSHARE,
            input_schema=PostToThreadsBlock.Input,
            output_schema=PostToThreadsBlock.Output,
        )

    async def run(
        self,
        input_data: "PostToThreadsBlock.Input",
        *,
        user_id: str,
        **kwargs,
    ) -> BlockOutput:
        """Post to Threads with Threads-specific validation."""
        profile_key = await get_profile_key(user_id)
        if not profile_key:
            yield "error", "Please link a social account via Ayrshare"
            return

        client = create_ayrshare_client()
        if not client:
            yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
            return

        # Validate Threads constraints
        if len(input_data.post) > 500:
            yield "error", f"Threads post text exceeds 500 character limit ({len(input_data.post)} characters)"
            return

        if len(input_data.media_urls) > 20:
            yield "error", "Threads supports a maximum of 20 images/videos in a carousel"
            return

        # Count hashtags (only 1 allowed)
        hashtag_count = input_data.post.count("#")
        if hashtag_count > 1:
            yield "error", f"Threads allows only 1 hashtag per post ({hashtag_count} found)"
            return

        # Convert datetime to ISO format if provided
        iso_date = (
            input_data.schedule_date.isoformat() if input_data.schedule_date else None
        )

        # Build Threads-specific options
        threads_options = {}
        # Note: Based on the documentation, Threads doesn't seem to have specific options
        # beyond the standard ones. The main constraints are validation-based.

        response = await client.create_post(
            post=input_data.post,
            platforms=[SocialPlatform.THREADS],
            media_urls=input_data.media_urls,
            is_video=input_data.is_video,
            schedule_date=iso_date,
            disable_comments=input_data.disable_comments,
            shorten_links=input_data.shorten_links,
            unsplash=input_data.unsplash,
            requires_approval=input_data.requires_approval,
            random_post=input_data.random_post,
            random_media_url=input_data.random_media_url,
            notes=input_data.notes,
            threads_options=threads_options if threads_options else None,
            profile_key=profile_key.get_secret_value(),
        )
        yield "post_result", response
        if response.postIds:
            for p in response.postIds:
                yield "post", p
@@ -1,243 +0,0 @@
from enum import Enum

from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchema,
    BlockType,
    SchemaField,
)

from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key


class TikTokVisibility(str, Enum):
    PUBLIC = "public"
    PRIVATE = "private"
    FOLLOWERS = "followers"


class PostToTikTokBlock(Block):
    """Block for posting to TikTok with TikTok-specific options."""

    class Input(BaseAyrshareInput):
        """Input schema for TikTok posts."""

        # Override post field to include TikTok-specific information
        post: str = SchemaField(
            description="The post text (max 2,200 chars, empty string allowed). Use @handle to mention users. Line breaks will be ignored.",
            advanced=False,
        )

        # Override media_urls to include TikTok-specific constraints
        media_urls: list[str] = SchemaField(
            description="Required media URLs. Either 1 video OR up to 35 images (JPG/JPEG/WEBP only). Cannot mix video and images.",
            default_factory=list,
            advanced=False,
        )

        # TikTok-specific options
        auto_add_music: bool = SchemaField(
            description="Whether to automatically add recommended music to the post. If you set this field to true, you can change the music later in the TikTok app.",
            default=False,
            advanced=True,
        )
        disable_comments: bool = SchemaField(
            description="Disable comments on the published post",
            default=False,
            advanced=True,
        )
        disable_duet: bool = SchemaField(
            description="Disable duets on published video (video only)",
            default=False,
            advanced=True,
        )
        disable_stitch: bool = SchemaField(
            description="Disable stitch on published video (video only)",
            default=False,
            advanced=True,
        )
        is_ai_generated: bool = SchemaField(
            description="If you enable the toggle, your video will be labeled as “Creator labeled as AI-generated” once posted and can’t be changed. The “Creator labeled as AI-generated” label indicates that the content was completely AI-generated or significantly edited with AI.",
            default=False,
            advanced=True,
        )
        is_branded_content: bool = SchemaField(
            description="Whether to enable the Branded Content toggle. If this field is set to true, the video will be labeled as Branded Content, indicating you are in a paid partnership with a brand. A “Paid partnership” label will be attached to the video.",
            default=False,
            advanced=True,
        )
        is_brand_organic: bool = SchemaField(
            description="Whether to enable the Brand Organic Content toggle. If this field is set to true, the video will be labeled as Brand Organic Content, indicating you are promoting yourself or your own business. A “Promotional content” label will be attached to the video.",
            default=False,
            advanced=True,
        )
        image_cover_index: int = SchemaField(
            description="Index of image to use as cover (0-based, image posts only)",
            default=0,
            advanced=True,
        )
        title: str = SchemaField(
            description="Title for image posts", default="", advanced=True
        )
        thumbnail_offset: int = SchemaField(
            description="Video thumbnail frame offset in milliseconds (video only)",
            default=0,
            advanced=True,
        )
        visibility: TikTokVisibility = SchemaField(
            description="Post visibility: 'public', 'private', 'followers', or 'friends'",
            default=TikTokVisibility.PUBLIC,
            advanced=True,
        )
        draft: bool = SchemaField(
            description="Create as draft post (video only)",
            default=False,
            advanced=True,
        )

    class Output(BlockSchema):
        post_result: PostResponse = SchemaField(description="The result of the post")
        post: PostIds = SchemaField(description="The result of the post")

    def __init__(self):
        super().__init__(
            id="7faf4b27-96b0-4f05-bf64-e0de54ae74e1",
            description="Post to TikTok using Ayrshare",
            categories={BlockCategory.SOCIAL},
            block_type=BlockType.AYRSHARE,
            input_schema=PostToTikTokBlock.Input,
            output_schema=PostToTikTokBlock.Output,
        )

    async def run(
        self, input_data: "PostToTikTokBlock.Input", *, user_id: str, **kwargs
    ) -> BlockOutput:
        """Post to TikTok with TikTok-specific validation and options."""
        profile_key = await get_profile_key(user_id)
        if not profile_key:
            yield "error", "Please link a social account via Ayrshare"
            return

        client = create_ayrshare_client()
        if not client:
            yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
            return

        # Validate TikTok constraints
        if len(input_data.post) > 2200:
            yield "error", f"TikTok post text exceeds 2,200 character limit ({len(input_data.post)} characters)"
            return

        if not input_data.media_urls:
            yield "error", "TikTok requires at least one media URL (either 1 video or up to 35 images)"
            return

        # Check for video vs image constraints
        video_extensions = [".mp4", ".mov", ".avi", ".mkv", ".wmv", ".flv", ".webm"]
        image_extensions = [".jpg", ".jpeg", ".webp"]

        has_video = input_data.is_video or any(
            any(url.lower().endswith(ext) for ext in video_extensions)
            for url in input_data.media_urls
        )

        has_images = any(
            any(url.lower().endswith(ext) for ext in image_extensions)
            for url in input_data.media_urls
        )

        if has_video and has_images:
            yield "error", "TikTok does not support mixing video and images in the same post"
            return

        if has_video and len(input_data.media_urls) > 1:
            yield "error", "TikTok supports only 1 video per post"
            return

        if has_images and len(input_data.media_urls) > 35:
            yield "error", "TikTok supports a maximum of 35 images per post"
            return

        # Validate image cover index
        if has_images and input_data.image_cover_index >= len(input_data.media_urls):
            yield "error", f"Image cover index {input_data.image_cover_index} is out of range (max: {len(input_data.media_urls) - 1})"
            return

        # Check for PNG files (not supported)
        has_png = any(url.lower().endswith(".png") for url in input_data.media_urls)
        if has_png:
            yield "error", "TikTok does not support PNG files. Please use JPG, JPEG, or WEBP for images."
            return

        # Convert datetime to ISO format if provided
        iso_date = (
            input_data.schedule_date.isoformat() if input_data.schedule_date else None
        )

        # Build TikTok-specific options
        tiktok_options = {}

        # Common options
        if input_data.auto_add_music and has_images:
            tiktok_options["autoAddMusic"] = True

        if input_data.disable_comments:
            tiktok_options["disableComments"] = True

        if input_data.is_branded_content:
            tiktok_options["isBrandedContent"] = True

        if input_data.is_brand_organic:
            tiktok_options["isBrandOrganic"] = True

        # Video-specific options
        if has_video:
            if input_data.disable_duet:
                tiktok_options["disableDuet"] = True

            if input_data.disable_stitch:
                tiktok_options["disableStitch"] = True

            if input_data.is_ai_generated:
                tiktok_options["isAIGenerated"] = True

            if input_data.thumbnail_offset > 0:
                tiktok_options["thumbNailOffset"] = input_data.thumbnail_offset

            if input_data.draft:
                tiktok_options["draft"] = True

        # Image-specific options
        if has_images:
            if input_data.image_cover_index > 0:
                tiktok_options["imageCoverIndex"] = input_data.image_cover_index

            if input_data.title:
                tiktok_options["title"] = input_data.title

        if input_data.visibility != TikTokVisibility.PUBLIC:
            tiktok_options["visibility"] = input_data.visibility.value

        response = await client.create_post(
            post=input_data.post,
            platforms=[SocialPlatform.TIKTOK],
            media_urls=input_data.media_urls,
            is_video=has_video,
            schedule_date=iso_date,
            disable_comments=input_data.disable_comments,
            shorten_links=input_data.shorten_links,
            unsplash=input_data.unsplash,
            requires_approval=input_data.requires_approval,
            random_post=input_data.random_post,
            random_media_url=input_data.random_media_url,
            notes=input_data.notes,
            tiktok_options=tiktok_options if tiktok_options else None,
            profile_key=profile_key.get_secret_value(),
        )
        yield "post_result", response
        if response.postIds:
            for p in response.postIds:
                yield "post", p
@@ -1,241 +0,0 @@
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
from backend.sdk import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchema,
    BlockType,
    SchemaField,
)

from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key


class PostToXBlock(Block):
    """Block for posting to X / Twitter with Twitter-specific options."""

    class Input(BaseAyrshareInput):
        """Input schema for X / Twitter posts."""

        # Override post field to include X-specific information
        post: str = SchemaField(
            description="The post text (max 280 chars, up to 25,000 for Premium users). Use @handle to mention users. Use \\n\\n for thread breaks.",
            advanced=False,
        )

        # Override media_urls to include X-specific constraints
        media_urls: list[str] = SchemaField(
            description="Optional list of media URLs. X supports up to 4 images or videos per tweet. Auto-preview links unless media is included.",
            default_factory=list,
            advanced=False,
        )

        # X-specific options
        reply_to_id: str | None = SchemaField(
            description="ID of the tweet to reply to",
            default=None,
            advanced=True,
        )
        quote_tweet_id: str | None = SchemaField(
            description="ID of the tweet to quote (low-level Tweet ID)",
            default=None,
            advanced=True,
        )
        poll_options: list[str] = SchemaField(
            description="Poll options (2-4 choices)",
            default_factory=list,
            advanced=True,
        )
        poll_duration: int = SchemaField(
            description="Poll duration in minutes (1-10080)",
            default=1440,
            advanced=True,
        )
        alt_text: list[str] = SchemaField(
            description="Alt text for each image (max 1,000 chars each, not supported for videos)",
            default_factory=list,
            advanced=True,
        )
        is_thread: bool = SchemaField(
            description="Whether to automatically break post into thread based on line breaks",
            default=False,
            advanced=True,
        )
        thread_number: bool = SchemaField(
            description="Add thread numbers (1/n format) to each thread post",
            default=False,
            advanced=True,
        )
        thread_media_urls: list[str] = SchemaField(
            description="Media URLs for thread posts (one per thread, use 'null' to skip)",
            default_factory=list,
            advanced=True,
        )
        long_post: bool = SchemaField(
            description="Force long form post (requires Premium X account)",
            default=False,
            advanced=True,
        )
        long_video: bool = SchemaField(
            description="Enable long video upload (requires approval and Business/Enterprise plan)",
            default=False,
            advanced=True,
        )
        subtitle_url: str = SchemaField(
            description="URL to SRT subtitle file for videos (must be HTTPS and end in .srt)",
            default="",
            advanced=True,
        )
        subtitle_language: str = SchemaField(
            description="Language code for subtitles (default: 'en')",
            default="en",
            advanced=True,
        )
        subtitle_name: str = SchemaField(
            description="Name of caption track (max 150 chars, default: 'English')",
            default="English",
            advanced=True,
        )

    class Output(BlockSchema):
        post_result: PostResponse = SchemaField(description="The result of the post")
        post: PostIds = SchemaField(description="The result of the post")

    def __init__(self):
        super().__init__(
            id="9e8f844e-b4a5-4b25-80f2-9e1dd7d67625",
            description="Post to X / Twitter using Ayrshare",
            categories={BlockCategory.SOCIAL},
            block_type=BlockType.AYRSHARE,
            input_schema=PostToXBlock.Input,
            output_schema=PostToXBlock.Output,
        )

    async def run(
        self,
        input_data: "PostToXBlock.Input",
        *,
        user_id: str,
        **kwargs,
    ) -> BlockOutput:
        """Post to X / Twitter with enhanced X-specific options."""
        profile_key = await get_profile_key(user_id)
        if not profile_key:
            yield "error", "Please link a social account via Ayrshare"
            return

        client = create_ayrshare_client()
        if not client:
            yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
            return

        # Validate X constraints
        if not input_data.long_post and len(input_data.post) > 280:
            yield "error", f"X post text exceeds 280 character limit ({len(input_data.post)} characters). Enable 'long_post' for Premium accounts."
            return

        if input_data.long_post and len(input_data.post) > 25000:
            yield "error", f"X long post text exceeds 25,000 character limit ({len(input_data.post)} characters)"
            return

        if len(input_data.media_urls) > 4:
            yield "error", "X supports a maximum of 4 images or videos per tweet"
            return

        # Validate poll options
        if input_data.poll_options:
            if len(input_data.poll_options) < 2 or len(input_data.poll_options) > 4:
                yield "error", "X polls require 2-4 options"
                return

            if input_data.poll_duration < 1 or input_data.poll_duration > 10080:
                yield "error", "X poll duration must be between 1 and 10,080 minutes (7 days)"
                return

        # Validate alt text
        if input_data.alt_text:
            for i, alt in enumerate(input_data.alt_text):
                if len(alt) > 1000:
                    yield "error", f"X alt text {i+1} exceeds 1,000 character limit ({len(alt)} characters)"
                    return

        # Validate subtitle settings
        if input_data.subtitle_url:
            if not input_data.subtitle_url.startswith(
                "https://"
            ) or not input_data.subtitle_url.endswith(".srt"):
                yield "error", "Subtitle URL must start with https:// and end with .srt"
                return

            if len(input_data.subtitle_name) > 150:
                yield "error", f"Subtitle name exceeds 150 character limit ({len(input_data.subtitle_name)} characters)"
                return

        # Convert datetime to ISO format if provided
        iso_date = (
            input_data.schedule_date.isoformat() if input_data.schedule_date else None
        )

        # Build X-specific options
        twitter_options = {}

        # Basic options
        if input_data.reply_to_id:
            twitter_options["replyToId"] = input_data.reply_to_id

        if input_data.quote_tweet_id:
            twitter_options["quoteTweetId"] = input_data.quote_tweet_id

        if input_data.long_post:
            twitter_options["longPost"] = True

        if input_data.long_video:
            twitter_options["longVideo"] = True

        # Poll options
        if input_data.poll_options:
            twitter_options["poll"] = {
                "duration": input_data.poll_duration,
                "options": input_data.poll_options,
            }

        # Alt text for images
        if input_data.alt_text:
            twitter_options["altText"] = input_data.alt_text

        # Thread options
        if input_data.is_thread:
            twitter_options["thread"] = True

        if input_data.thread_number:
            twitter_options["threadNumber"] = True

        if input_data.thread_media_urls:
            twitter_options["mediaUrls"] = input_data.thread_media_urls

        # Video subtitle options
        if input_data.subtitle_url:
            twitter_options["subTitleUrl"] = input_data.subtitle_url
            twitter_options["subTitleLanguage"] = input_data.subtitle_language
            twitter_options["subTitleName"] = input_data.subtitle_name

        response = await client.create_post(
            post=input_data.post,
            platforms=[SocialPlatform.TWITTER],
            media_urls=input_data.media_urls,
            is_video=input_data.is_video,
            schedule_date=iso_date,
            disable_comments=input_data.disable_comments,
            shorten_links=input_data.shorten_links,
            unsplash=input_data.unsplash,
            requires_approval=input_data.requires_approval,
            random_post=input_data.random_post,
            random_media_url=input_data.random_media_url,
            notes=input_data.notes,
            twitter_options=twitter_options if twitter_options else None,
            profile_key=profile_key.get_secret_value(),
        )
        yield "post_result", response
        if response.postIds:
            for p in response.postIds:
                yield "post", p
@@ -1,310 +0,0 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
class YouTubeVisibility(str, Enum):
|
||||
PRIVATE = "private"
|
||||
PUBLIC = "public"
|
||||
UNLISTED = "unlisted"
|
||||
|
||||
|
||||
class PostToYouTubeBlock(Block):
|
||||
"""Block for posting to YouTube with YouTube-specific options."""
|
||||
|
||||
class Input(BaseAyrshareInput):
|
||||
"""Input schema for YouTube posts."""
|
||||
|
||||
# Override post field to include YouTube-specific information
|
||||
post: str = SchemaField(
|
||||
description="Video description (max 5,000 chars, empty string allowed). Cannot contain < or > characters.",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# Override media_urls to include YouTube-specific constraints
|
||||
media_urls: list[str] = SchemaField(
|
||||
description="Required video URL. YouTube only supports 1 video per post.",
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# YouTube-specific required options
|
||||
title: str = SchemaField(
|
||||
description="Video title (max 100 chars, required). Cannot contain < or > characters.",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# YouTube-specific optional options
|
||||
visibility: YouTubeVisibility = SchemaField(
|
||||
description="Video visibility: 'private' (default), 'public' , or 'unlisted'",
|
||||
default=YouTubeVisibility.PRIVATE,
|
||||
advanced=False,
|
||||
)
|
||||
thumbnail: str | None = SchemaField(
|
||||
description="Thumbnail URL (JPEG/PNG under 2MB, must end in .png/.jpg/.jpeg). Requires phone verification.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
playlist_id: str | None = SchemaField(
|
||||
description="Playlist ID to add video (user must own playlist)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
tags: list[str] | None = SchemaField(
|
||||
description="Video tags (min 2 chars each, max 500 chars total)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
made_for_kids: bool | None = SchemaField(
|
||||
description="Self-declared kids content", default=None, advanced=True
|
||||
)
|
||||
is_shorts: bool | None = SchemaField(
|
||||
description="Post as YouTube Short (max 3 minutes, adds #shorts)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
notify_subscribers: bool | None = SchemaField(
|
||||
description="Send notification to subscribers", default=None, advanced=True
|
||||
)
|
||||
category_id: int | None = SchemaField(
|
||||
description="Video category ID (e.g., 24 = Entertainment)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
contains_synthetic_media: bool | None = SchemaField(
|
||||
description="Disclose realistic AI/synthetic content",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
publish_at: str | None = SchemaField(
|
||||
description="UTC publish time (YouTube controlled, format: 2022-10-08T21:18:36Z)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
# YouTube targeting options (flattened from YouTubeTargeting object)
|
||||
targeting_block_countries: list[str] | None = SchemaField(
|
||||
description="Country codes to block from viewing (e.g., ['US', 'CA'])",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
targeting_allow_countries: list[str] | None = SchemaField(
|
||||
description="Country codes to allow viewing (e.g., ['GB', 'AU'])",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
subtitle_url: str | None = SchemaField(
|
||||
description="URL to SRT or SBV subtitle file (must be HTTPS and end in .srt/.sbv, under 100MB)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
subtitle_language: str | None = SchemaField(
|
||||
description="Language code for subtitles (default: 'en')",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
subtitle_name: str | None = SchemaField(
|
||||
description="Name of caption track (max 150 chars, default: 'English')",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0082d712-ff1b-4c3d-8a8d-6c7721883b83",
|
||||
description="Post to YouTube using Ayrshare",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
block_type=BlockType.AYRSHARE,
|
||||
input_schema=PostToYouTubeBlock.Input,
|
||||
output_schema=PostToYouTubeBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: "PostToYouTubeBlock.Input",
|
||||
*,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to YouTube with YouTube-specific validation and options."""
|
||||
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
return
|
||||
|
||||
# Validate YouTube constraints
|
||||
if not input_data.title:
|
||||
yield "error", "YouTube requires a video title"
|
||||
return
|
||||
|
||||
if len(input_data.title) > 100:
|
||||
yield "error", f"YouTube title exceeds 100 character limit ({len(input_data.title)} characters)"
|
||||
return
|
||||
|
||||
if len(input_data.post) > 5000:
|
||||
yield "error", f"YouTube description exceeds 5,000 character limit ({len(input_data.post)} characters)"
|
||||
return
|
||||
|
||||
# Check for forbidden characters
|
||||
forbidden_chars = ["<", ">"]
|
||||
for char in forbidden_chars:
|
||||
if char in input_data.title:
|
||||
yield "error", f"YouTube title cannot contain '{char}' character"
|
||||
return
|
||||
if char in input_data.post:
|
||||
yield "error", f"YouTube description cannot contain '{char}' character"
|
||||
return
|
||||
|
||||
if not input_data.media_urls:
|
||||
yield "error", "YouTube requires exactly one video URL"
|
||||
return
|
||||
|
||||
if len(input_data.media_urls) > 1:
|
||||
yield "error", "YouTube supports only 1 video per post"
|
||||
return
|
||||
|
||||
# Validate visibility option
|
||||
valid_visibility = ["private", "public", "unlisted"]
|
||||
if input_data.visibility not in valid_visibility:
|
||||
yield "error", f"YouTube visibility must be one of: {', '.join(valid_visibility)}"
|
||||
return
|
||||
|
||||
# Validate thumbnail URL format
|
||||
if input_data.thumbnail:
|
||||
valid_extensions = [".png", ".jpg", ".jpeg"]
|
||||
if not any(
|
||||
input_data.thumbnail.lower().endswith(ext) for ext in valid_extensions
|
||||
):
|
||||
yield "error", "YouTube thumbnail must end in .png, .jpg, or .jpeg"
|
||||
return
|
||||
|
||||
# Validate tags
|
||||
if input_data.tags:
|
||||
total_tag_length = sum(len(tag) for tag in input_data.tags)
|
||||
if total_tag_length > 500:
|
||||
yield "error", f"YouTube tags total length exceeds 500 characters ({total_tag_length} characters)"
|
||||
return
|
||||
|
||||
for tag in input_data.tags:
|
||||
if len(tag) < 2:
|
||||
yield "error", f"YouTube tag '{tag}' is too short (minimum 2 characters)"
|
||||
return
|
||||
|
||||
# Validate subtitle URL
|
||||
if input_data.subtitle_url:
|
||||
if not input_data.subtitle_url.startswith("https://"):
|
||||
yield "error", "YouTube subtitle URL must start with https://"
|
||||
return
|
||||
|
||||
valid_subtitle_extensions = [".srt", ".sbv"]
|
||||
if not any(
|
||||
input_data.subtitle_url.lower().endswith(ext)
|
||||
for ext in valid_subtitle_extensions
|
||||
):
|
||||
yield "error", "YouTube subtitle URL must end in .srt or .sbv"
|
||||
return
|
||||
|
||||
if input_data.subtitle_name and len(input_data.subtitle_name) > 150:
|
||||
yield "error", f"YouTube subtitle name exceeds 150 character limit ({len(input_data.subtitle_name)} characters)"
|
||||
return
|
||||
|
||||
# Disallow combining publish_at with schedule_date
|
||||
if input_data.publish_at and input_data.schedule_date:
|
||||
yield "error", "Cannot use both 'publish_at' and 'schedule_date'. Use 'publish_at' for YouTube-controlled publishing."
|
||||
return
|
||||
|
||||
# Convert datetime to ISO format if provided (only if not using publish_at)
|
||||
iso_date = None
|
||||
if not input_data.publish_at and input_data.schedule_date:
|
||||
iso_date = input_data.schedule_date.isoformat()
|
||||
|
||||
# Build YouTube-specific options
|
||||
youtube_options: dict[str, Any] = {"title": input_data.title}
|
||||
|
||||
# Basic options
|
||||
if input_data.visibility != "private":
|
||||
youtube_options["visibility"] = input_data.visibility
|
||||
|
||||
if input_data.thumbnail:
|
||||
youtube_options["thumbNail"] = input_data.thumbnail
|
||||
|
||||
if input_data.playlist_id:
|
||||
youtube_options["playListId"] = input_data.playlist_id
|
||||
|
||||
if input_data.tags:
|
||||
youtube_options["tags"] = input_data.tags
|
||||
|
||||
if input_data.made_for_kids:
|
||||
youtube_options["madeForKids"] = True
|
||||
|
||||
if input_data.is_shorts:
|
||||
youtube_options["shorts"] = True
|
||||
|
||||
if input_data.notify_subscribers is False:
    youtube_options["notifySubscribers"] = False
|
||||
|
||||
if input_data.category_id and input_data.category_id > 0:
|
||||
youtube_options["categoryId"] = input_data.category_id
|
||||
|
||||
if input_data.contains_synthetic_media:
|
||||
youtube_options["containsSyntheticMedia"] = True
|
||||
|
||||
if input_data.publish_at:
|
||||
youtube_options["publishAt"] = input_data.publish_at
|
||||
|
||||
# Country targeting (from flattened fields)
|
||||
targeting_dict = {}
|
||||
if input_data.targeting_block_countries:
|
||||
targeting_dict["block"] = input_data.targeting_block_countries
|
||||
if input_data.targeting_allow_countries:
|
||||
targeting_dict["allow"] = input_data.targeting_allow_countries
|
||||
|
||||
if targeting_dict:
|
||||
youtube_options["targeting"] = targeting_dict
|
||||
|
||||
# Subtitle options
|
||||
if input_data.subtitle_url:
|
||||
youtube_options["subTitleUrl"] = input_data.subtitle_url
|
||||
youtube_options["subTitleLanguage"] = input_data.subtitle_language
|
||||
youtube_options["subTitleName"] = input_data.subtitle_name
|
||||
|
||||
response = await client.create_post(
|
||||
post=input_data.post,
|
||||
platforms=[SocialPlatform.YOUTUBE],
|
||||
media_urls=input_data.media_urls,
|
||||
is_video=True, # YouTube only supports videos
|
||||
schedule_date=iso_date,
|
||||
disable_comments=input_data.disable_comments,
|
||||
shorten_links=input_data.shorten_links,
|
||||
unsplash=input_data.unsplash,
|
||||
requires_approval=input_data.requires_approval,
|
||||
random_post=input_data.random_post,
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
youtube_options=youtube_options,
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
for p in response.postIds:
|
||||
yield "post", p
|
||||
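For orientation, a rough sketch of the youtube_options dict that run() assembles for a typical public Short; the keys are the Ayrshare camelCase options used above, and every value is an illustrative placeholder:

# Illustrative only: what PostToYouTubeBlock.run() would send as youtube_options
# for a public Short with a thumbnail and two tags (values are made up).
youtube_options = {
    "title": "Launch day recap",
    "visibility": "public",                        # only added because it differs from "private"
    "thumbNail": "https://example.com/thumb.jpg",  # must end in .png/.jpg/.jpeg
    "tags": ["launch", "recap"],
    "shorts": True,                                # is_shorts=True
    "notifySubscribers": False,                    # only sent when explicitly disabled
}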
autogpt_platform/backend/backend/blocks/baas/__init__.py (new file, 66 lines)
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
Meeting BaaS integration for AutoGPT Platform.
|
||||
|
||||
This integration provides comprehensive access to the Meeting BaaS API,
|
||||
including:
|
||||
- Bot management for meeting recordings
|
||||
- Calendar integration (Google/Microsoft)
|
||||
- Event management and scheduling
|
||||
- Webhook triggers for real-time events
|
||||
"""
|
||||
|
||||
# Bot (Recording) Blocks
|
||||
from .bots import (
|
||||
BaasBotDeleteRecordingBlock,
|
||||
BaasBotFetchMeetingDataBlock,
|
||||
BaasBotFetchScreenshotsBlock,
|
||||
BaasBotJoinMeetingBlock,
|
||||
BaasBotLeaveMeetingBlock,
|
||||
BaasBotRetranscribeBlock,
|
||||
)
|
||||
|
||||
# Calendar Blocks
|
||||
from .calendars import (
|
||||
BaasCalendarConnectBlock,
|
||||
BaasCalendarDeleteBlock,
|
||||
BaasCalendarListAllBlock,
|
||||
BaasCalendarResyncAllBlock,
|
||||
BaasCalendarUpdateCredsBlock,
|
||||
)
|
||||
|
||||
# Event Blocks
|
||||
from .events import (
|
||||
BaasEventGetDetailsBlock,
|
||||
BaasEventListBlock,
|
||||
BaasEventPatchBotBlock,
|
||||
BaasEventScheduleBotBlock,
|
||||
BaasEventUnscheduleBotBlock,
|
||||
)
|
||||
|
||||
# Webhook Triggers
|
||||
from .triggers import BaasOnCalendarEventBlock, BaasOnMeetingEventBlock
|
||||
|
||||
__all__ = [
|
||||
# Bot (Recording) Blocks
|
||||
"BaasBotJoinMeetingBlock",
|
||||
"BaasBotLeaveMeetingBlock",
|
||||
"BaasBotFetchMeetingDataBlock",
|
||||
"BaasBotFetchScreenshotsBlock",
|
||||
"BaasBotDeleteRecordingBlock",
|
||||
"BaasBotRetranscribeBlock",
|
||||
# Calendar Blocks
|
||||
"BaasCalendarConnectBlock",
|
||||
"BaasCalendarListAllBlock",
|
||||
"BaasCalendarUpdateCredsBlock",
|
||||
"BaasCalendarDeleteBlock",
|
||||
"BaasCalendarResyncAllBlock",
|
||||
# Event Blocks
|
||||
"BaasEventListBlock",
|
||||
"BaasEventGetDetailsBlock",
|
||||
"BaasEventScheduleBotBlock",
|
||||
"BaasEventUnscheduleBotBlock",
|
||||
"BaasEventPatchBotBlock",
|
||||
# Webhook Triggers
|
||||
"BaasOnMeetingEventBlock",
|
||||
"BaasOnCalendarEventBlock",
|
||||
]
|
||||
autogpt_platform/backend/backend/blocks/baas/_config.py (new file, 16 lines)
@@ -0,0 +1,16 @@
"""
Shared configuration for all Meeting BaaS blocks using the SDK pattern.
"""

from backend.sdk import BlockCostType, ProviderBuilder

from ._webhook import BaasWebhookManager

# Configure the Meeting BaaS provider with API key authentication
baas = (
    ProviderBuilder("baas")
    .with_api_key("MEETING_BAAS_API_KEY", "Meeting BaaS API Key")
    .with_webhook_manager(BaasWebhookManager)
    .with_base_cost(5, BlockCostType.RUN)  # Higher cost for meeting recording service
    .build()
)
autogpt_platform/backend/backend/blocks/baas/_webhook.py (new file, 83 lines)
@@ -0,0 +1,83 @@
|
||||
"""
|
||||
Webhook management for Meeting BaaS blocks.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Tuple
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
BaseWebhooksManager,
|
||||
Credentials,
|
||||
ProviderName,
|
||||
Webhook,
|
||||
)
|
||||
|
||||
|
||||
class BaasWebhookManager(BaseWebhooksManager):
|
||||
"""Webhook manager for Meeting BaaS API."""
|
||||
|
||||
PROVIDER_NAME = ProviderName("baas")
|
||||
|
||||
class WebhookType(str, Enum):
|
||||
MEETING_EVENT = "meeting_event"
|
||||
CALENDAR_EVENT = "calendar_event"
|
||||
|
||||
@classmethod
|
||||
async def validate_payload(cls, webhook: Webhook, request) -> Tuple[dict, str]:
|
||||
"""Validate incoming webhook payload."""
|
||||
payload = await request.json()
|
||||
|
||||
# Verify API key in header
|
||||
api_key_header = request.headers.get("x-meeting-baas-api-key")
|
||||
if webhook.secret and api_key_header != webhook.secret:
|
||||
raise ValueError("Invalid webhook API key")
|
||||
|
||||
# Extract event type from payload
|
||||
event_type = payload.get("event", "unknown")
|
||||
|
||||
return payload, event_type
|
||||
|
||||
async def _register_webhook(
|
||||
self,
|
||||
credentials: Credentials,
|
||||
webhook_type: str,
|
||||
resource: str,
|
||||
events: list[str],
|
||||
ingress_url: str,
|
||||
secret: str,
|
||||
) -> Tuple[str, dict]:
|
||||
"""
|
||||
Register webhook with Meeting BaaS.
|
||||
|
||||
Note: Meeting BaaS doesn't have a webhook registration API.
|
||||
Webhooks are configured per-bot or as account defaults.
|
||||
This returns a synthetic webhook ID.
|
||||
"""
|
||||
if not isinstance(credentials, APIKeyCredentials):
|
||||
raise ValueError("Meeting BaaS webhooks require API key credentials")
|
||||
|
||||
# Generate a synthetic webhook ID since BaaS doesn't provide one
|
||||
import uuid
|
||||
|
||||
webhook_id = str(uuid.uuid4())
|
||||
|
||||
return webhook_id, {
|
||||
"webhook_type": webhook_type,
|
||||
"resource": resource,
|
||||
"events": events,
|
||||
"ingress_url": ingress_url,
|
||||
"api_key": credentials.api_key.get_secret_value(),
|
||||
}
|
||||
|
||||
async def _deregister_webhook(
|
||||
self, webhook: Webhook, credentials: Credentials
|
||||
) -> None:
|
||||
"""
|
||||
Deregister webhook from Meeting BaaS.
|
||||
|
||||
Note: Meeting BaaS doesn't have a webhook deregistration API.
|
||||
Webhooks are removed by updating bot/calendar configurations.
|
||||
"""
|
||||
# No-op since BaaS doesn't have webhook deregistration
|
||||
pass
|
||||
autogpt_platform/backend/backend/blocks/baas/bots.py (new file, 367 lines)
@@ -0,0 +1,367 @@
|
||||
"""
|
||||
Meeting BaaS bot (recording) blocks.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import baas
|
||||
|
||||
|
||||
class BaasBotJoinMeetingBlock(Block):
|
||||
"""
|
||||
Deploy a bot immediately or at a scheduled start_time to join and record a meeting.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
meeting_url: str = SchemaField(
|
||||
description="The URL of the meeting the bot should join"
|
||||
)
|
||||
bot_name: str = SchemaField(
|
||||
description="Display name for the bot in the meeting"
|
||||
)
|
||||
bot_image: str = SchemaField(
|
||||
description="URL to an image for the bot's avatar (16:9 ratio recommended)",
|
||||
default="",
|
||||
)
|
||||
entry_message: str = SchemaField(
|
||||
description="Chat message the bot will post upon entry", default=""
|
||||
)
|
||||
reserved: bool = SchemaField(
|
||||
description="Use a reserved bot slot (joins 4 min before meeting)",
|
||||
default=False,
|
||||
)
|
||||
start_time: Optional[int] = SchemaField(
|
||||
description="Unix timestamp (ms) when bot should join", default=None
|
||||
)
|
||||
speech_to_text: dict = SchemaField(
|
||||
description="Speech-to-text configuration", default={"provider": "Gladia"}
|
||||
)
|
||||
webhook_url: str = SchemaField(
|
||||
description="URL to receive webhook events for this bot", default=""
|
||||
)
|
||||
timeouts: dict = SchemaField(
|
||||
description="Automatic leave timeouts configuration", default={}
|
||||
)
|
||||
extra: dict = SchemaField(
|
||||
description="Custom metadata to attach to the bot", default={}
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
bot_id: str = SchemaField(description="UUID of the deployed bot")
|
||||
join_response: dict = SchemaField(
|
||||
description="Full response from join operation"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="7f8e9d0c-1b2a-3c4d-5e6f-7a8b9c0d1e2f",
|
||||
description="Deploy a bot to join and record a meeting",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Build request body
|
||||
body = {
|
||||
"meeting_url": input_data.meeting_url,
|
||||
"bot_name": input_data.bot_name,
|
||||
"reserved": input_data.reserved,
|
||||
"speech_to_text": input_data.speech_to_text,
|
||||
}
|
||||
|
||||
# Add optional fields
|
||||
if input_data.bot_image:
|
||||
body["bot_image"] = input_data.bot_image
|
||||
if input_data.entry_message:
|
||||
body["entry_message"] = input_data.entry_message
|
||||
if input_data.start_time is not None:
|
||||
body["start_time"] = input_data.start_time
|
||||
if input_data.webhook_url:
|
||||
body["webhook_url"] = input_data.webhook_url
|
||||
if input_data.timeouts:
|
||||
body["automatic_leave"] = input_data.timeouts
|
||||
if input_data.extra:
|
||||
body["extra"] = input_data.extra
|
||||
|
||||
# Join meeting
|
||||
response = await Requests().post(
|
||||
"https://api.meetingbaas.com/bots",
|
||||
headers={"x-meeting-baas-api-key": api_key},
|
||||
json=body,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
yield "bot_id", data.get("bot_id", "")
|
||||
yield "join_response", data
|
||||
|
||||
|
||||
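For reference, a sketch of the JSON body this join block would POST to https://api.meetingbaas.com/bots; the meeting link, timestamp, and webhook address are placeholders, and the optional keys only appear when the corresponding inputs are set:

body = {
    "meeting_url": "https://meet.google.com/abc-defg-hij",    # placeholder meeting link
    "bot_name": "Notetaker",
    "reserved": False,
    "speech_to_text": {"provider": "Gladia"},                 # the block's default
    "entry_message": "Hi, I'm here to record this meeting.",  # optional
    "start_time": 1767225600000,                              # optional, Unix ms timestamp
    "webhook_url": "https://example.com/hooks/baas",          # optional, placeholder
}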
class BaasBotLeaveMeetingBlock(Block):
|
||||
"""
|
||||
Force the bot to exit the call.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
bot_id: str = SchemaField(description="UUID of the bot to remove from meeting")
|
||||
|
||||
class Output(BlockSchema):
|
||||
left: bool = SchemaField(description="Whether the bot successfully left")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8a9b0c1d-2e3f-4a5b-6c7d-8e9f0a1b2c3d",
|
||||
description="Remove a bot from an ongoing meeting",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Leave meeting
|
||||
response = await Requests().delete(
|
||||
f"https://api.meetingbaas.com/bots/{input_data.bot_id}",
|
||||
headers={"x-meeting-baas-api-key": api_key},
|
||||
)
|
||||
|
||||
# Check if successful
|
||||
left = response.status in [200, 204]
|
||||
|
||||
yield "left", left
|
||||
|
||||
|
||||
class BaasBotFetchMeetingDataBlock(Block):
|
||||
"""
|
||||
Pull MP4 URL, transcript & metadata for a completed meeting.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
bot_id: str = SchemaField(description="UUID of the bot whose data to fetch")
|
||||
include_transcripts: bool = SchemaField(
|
||||
description="Include transcript data in response", default=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
mp4_url: str = SchemaField(
|
||||
description="URL to download the meeting recording (time-limited)"
|
||||
)
|
||||
transcript: list = SchemaField(description="Meeting transcript data")
|
||||
metadata: dict = SchemaField(description="Meeting metadata and bot information")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="9b0c1d2e-3f4a-5b6c-7d8e-9f0a1b2c3d4e",
|
||||
description="Retrieve recorded meeting data",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Build query parameters
|
||||
params = {
|
||||
"bot_id": input_data.bot_id,
|
||||
"include_transcripts": str(input_data.include_transcripts).lower(),
|
||||
}
|
||||
|
||||
# Fetch meeting data
|
||||
response = await Requests().get(
|
||||
"https://api.meetingbaas.com/bots/meeting_data",
|
||||
headers={"x-meeting-baas-api-key": api_key},
|
||||
params=params,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
yield "mp4_url", data.get("mp4", "")
|
||||
yield "transcript", data.get("bot_data", {}).get("transcripts", [])
|
||||
yield "metadata", data.get("bot_data", {}).get("bot", {})
|
||||
|
||||
|
||||
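The three outputs above assume a response of roughly this shape; the field names follow the .get() chains in the code, while the values are illustrative:

data = {
    "mp4": "https://storage.example.com/rec.mp4",   # yielded as "mp4_url"
    "bot_data": {
        "transcripts": [{"speaker": "Alice"}],      # yielded as "transcript"
        "bot": {"bot_name": "Notetaker"},           # yielded as "metadata"
    },
}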
class BaasBotFetchScreenshotsBlock(Block):
|
||||
"""
|
||||
List screenshots captured during the call.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
bot_id: str = SchemaField(
|
||||
description="UUID of the bot whose screenshots to fetch"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
screenshots: list[dict] = SchemaField(
|
||||
description="Array of screenshot objects with date and url"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0c1d2e3f-4a5b-6c7d-8e9f-0a1b2c3d4e5f",
|
||||
description="Retrieve screenshots captured during a meeting",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Fetch screenshots
|
||||
response = await Requests().get(
|
||||
f"https://api.meetingbaas.com/bots/{input_data.bot_id}/screenshots",
|
||||
headers={"x-meeting-baas-api-key": api_key},
|
||||
)
|
||||
|
||||
screenshots = response.json()
|
||||
|
||||
yield "screenshots", screenshots
|
||||
|
||||
|
||||
class BaasBotDeleteRecordingBlock(Block):
|
||||
"""
|
||||
Purge MP4 + transcript data for privacy or storage management.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
bot_id: str = SchemaField(description="UUID of the bot whose data to delete")
|
||||
|
||||
class Output(BlockSchema):
|
||||
deleted: bool = SchemaField(
|
||||
description="Whether the data was successfully deleted"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="1d2e3f4a-5b6c-7d8e-9f0a-1b2c3d4e5f6a",
|
||||
description="Permanently delete a meeting's recorded data",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Delete recording data
|
||||
response = await Requests().post(
|
||||
f"https://api.meetingbaas.com/bots/{input_data.bot_id}/delete_data",
|
||||
headers={"x-meeting-baas-api-key": api_key},
|
||||
)
|
||||
|
||||
# Check if successful
|
||||
deleted = response.status == 200
|
||||
|
||||
yield "deleted", deleted
|
||||
|
||||
|
||||
class BaasBotRetranscribeBlock(Block):
|
||||
"""
|
||||
Re-run STT on past audio with a different provider or settings.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
bot_id: str = SchemaField(
|
||||
description="UUID of the bot whose audio to retranscribe"
|
||||
)
|
||||
provider: str = SchemaField(
|
||||
description="Speech-to-text provider to use (e.g., Gladia, Deepgram)"
|
||||
)
|
||||
webhook_url: str = SchemaField(
|
||||
description="URL to receive transcription complete event", default=""
|
||||
)
|
||||
custom_options: dict = SchemaField(
|
||||
description="Provider-specific options", default={}
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
job_id: Optional[str] = SchemaField(
|
||||
description="Transcription job ID if available"
|
||||
)
|
||||
accepted: bool = SchemaField(
|
||||
description="Whether the retranscription request was accepted"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="2e3f4a5b-6c7d-8e9f-0a1b-2c3d4e5f6a7b",
|
||||
description="Re-run transcription on a meeting's audio",
|
||||
categories={BlockCategory.AI},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Build request body
|
||||
body = {"bot_uuid": input_data.bot_id, "provider": input_data.provider}
|
||||
|
||||
if input_data.webhook_url:
|
||||
body["webhook_url"] = input_data.webhook_url
|
||||
|
||||
if input_data.custom_options:
|
||||
body.update(input_data.custom_options)
|
||||
|
||||
# Start retranscription
|
||||
response = await Requests().post(
|
||||
"https://api.meetingbaas.com/bots/retranscribe",
|
||||
headers={"x-meeting-baas-api-key": api_key},
|
||||
json=body,
|
||||
)
|
||||
|
||||
# Check if accepted
|
||||
accepted = response.status in [200, 202]
|
||||
job_id = None
|
||||
|
||||
if accepted and response.status == 200:
|
||||
data = response.json()
|
||||
job_id = data.get("job_id")
|
||||
|
||||
yield "job_id", job_id
|
||||
yield "accepted", accepted
|
||||
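And a sketch of the retranscribe payload, where custom_options are merged into the top level of the body via body.update(); the provider option shown is hypothetical:

body = {
    "bot_uuid": "9b0c1d2e-3f4a-5b6c-7d8e-9f0a1b2c3d4e",
    "provider": "Deepgram",
    "webhook_url": "https://example.com/hooks/baas",  # optional, placeholder
    "model": "nova-2",                                # hypothetical option merged from custom_options
}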
autogpt_platform/backend/backend/blocks/baas/calendars.py (new file, 265 lines)
@@ -0,0 +1,265 @@
|
||||
"""
|
||||
Meeting BaaS calendar blocks.
|
||||
"""
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import baas
|
||||
|
||||
|
||||
class BaasCalendarConnectBlock(Block):
|
||||
"""
|
||||
One-time integration of a Google or Microsoft calendar.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
oauth_client_id: str = SchemaField(description="OAuth client ID from provider")
|
||||
oauth_client_secret: str = SchemaField(description="OAuth client secret")
|
||||
oauth_refresh_token: str = SchemaField(
|
||||
description="OAuth refresh token with calendar access"
|
||||
)
|
||||
platform: str = SchemaField(
|
||||
description="Calendar platform (Google or Microsoft)"
|
||||
)
|
||||
calendar_email_or_id: str = SchemaField(
|
||||
description="Specific calendar email/ID to connect", default=""
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
calendar_id: str = SchemaField(description="UUID of the connected calendar")
|
||||
calendar_obj: dict = SchemaField(description="Full calendar object")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3f4a5b6c-7d8e-9f0a-1b2c-3d4e5f6a7b8c",
|
||||
description="Connect a Google or Microsoft calendar for integration",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Build request body
|
||||
body = {
|
||||
"oauth_client_id": input_data.oauth_client_id,
|
||||
"oauth_client_secret": input_data.oauth_client_secret,
|
||||
"oauth_refresh_token": input_data.oauth_refresh_token,
|
||||
"platform": input_data.platform,
|
||||
}
|
||||
|
||||
if input_data.calendar_email_or_id:
|
||||
body["calendar_email"] = input_data.calendar_email_or_id
|
||||
|
||||
# Connect calendar
|
||||
response = await Requests().post(
|
||||
"https://api.meetingbaas.com/calendars",
|
||||
headers={"x-meeting-baas-api-key": api_key},
|
||||
json=body,
|
||||
)
|
||||
|
||||
calendar = response.json()
|
||||
|
||||
yield "calendar_id", calendar.get("uuid", "")
|
||||
yield "calendar_obj", calendar
|
||||
|
||||
|
||||
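A sketch of the connect request for a Google calendar; every credential value is a placeholder, and calendar_email is only included when a specific calendar is requested:

body = {
    "oauth_client_id": "1234567890-abc.apps.googleusercontent.com",  # placeholder
    "oauth_client_secret": "GOCSPX-placeholder",
    "oauth_refresh_token": "1//0g-placeholder",
    "platform": "Google",
    "calendar_email": "team@example.com",  # optional
}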
class BaasCalendarListAllBlock(Block):
|
||||
"""
|
||||
Enumerate connected calendars.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
calendars: list[dict] = SchemaField(
|
||||
description="Array of connected calendar objects"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="4a5b6c7d-8e9f-0a1b-2c3d-4e5f6a7b8c9d",
|
||||
description="List all integrated calendars",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# List calendars
|
||||
response = await Requests().get(
|
||||
"https://api.meetingbaas.com/calendars",
|
||||
headers={"x-meeting-baas-api-key": api_key},
|
||||
)
|
||||
|
||||
calendars = response.json()
|
||||
|
||||
yield "calendars", calendars
|
||||
|
||||
|
||||
class BaasCalendarUpdateCredsBlock(Block):
|
||||
"""
|
||||
Refresh OAuth or switch provider for an existing calendar.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
calendar_id: str = SchemaField(description="UUID of the calendar to update")
|
||||
oauth_client_id: str = SchemaField(
|
||||
description="New OAuth client ID", default=""
|
||||
)
|
||||
oauth_client_secret: str = SchemaField(
|
||||
description="New OAuth client secret", default=""
|
||||
)
|
||||
oauth_refresh_token: str = SchemaField(
|
||||
description="New OAuth refresh token", default=""
|
||||
)
|
||||
platform: str = SchemaField(description="New platform if switching", default="")
|
||||
|
||||
class Output(BlockSchema):
|
||||
calendar_obj: dict = SchemaField(description="Updated calendar object")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="5b6c7d8e-9f0a-1b2c-3d4e-5f6a7b8c9d0e",
|
||||
description="Update calendar credentials or platform",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Build request body with only provided fields
|
||||
body = {}
|
||||
if input_data.oauth_client_id:
|
||||
body["oauth_client_id"] = input_data.oauth_client_id
|
||||
if input_data.oauth_client_secret:
|
||||
body["oauth_client_secret"] = input_data.oauth_client_secret
|
||||
if input_data.oauth_refresh_token:
|
||||
body["oauth_refresh_token"] = input_data.oauth_refresh_token
|
||||
if input_data.platform:
|
||||
body["platform"] = input_data.platform
|
||||
|
||||
# Update calendar
|
||||
response = await Requests().patch(
|
||||
f"https://api.meetingbaas.com/calendars/{input_data.calendar_id}",
|
||||
headers={"x-meeting-baas-api-key": api_key},
|
||||
json=body,
|
||||
)
|
||||
|
||||
calendar = response.json()
|
||||
|
||||
yield "calendar_obj", calendar
|
||||
|
||||
|
||||
class BaasCalendarDeleteBlock(Block):
|
||||
"""
|
||||
Disconnect calendar & unschedule future bots.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
calendar_id: str = SchemaField(description="UUID of the calendar to delete")
|
||||
|
||||
class Output(BlockSchema):
|
||||
deleted: bool = SchemaField(
|
||||
description="Whether the calendar was successfully deleted"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="6c7d8e9f-0a1b-2c3d-4e5f-6a7b8c9d0e1f",
|
||||
description="Remove a calendar integration",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Delete calendar
|
||||
response = await Requests().delete(
|
||||
f"https://api.meetingbaas.com/calendars/{input_data.calendar_id}",
|
||||
headers={"x-meeting-baas-api-key": api_key},
|
||||
)
|
||||
|
||||
deleted = response.status in [200, 204]
|
||||
|
||||
yield "deleted", deleted
|
||||
|
||||
|
||||
class BaasCalendarResyncAllBlock(Block):
|
||||
"""
|
||||
Force full sync now (maintenance).
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
synced_ids: list[str] = SchemaField(
|
||||
description="Calendar UUIDs that synced successfully"
|
||||
)
|
||||
errors: list[list] = SchemaField(
|
||||
description="Array of [calendar_id, error_message] tuples"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="7d8e9f0a-1b2c-3d4e-5f6a-7b8c9d0e1f2a",
|
||||
description="Force immediate re-sync of all connected calendars",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Resync all calendars
|
||||
response = await Requests().post(
|
||||
"https://api.meetingbaas.com/internal/calendar/resync_all",
|
||||
headers={"x-meeting-baas-api-key": api_key},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
yield "synced_ids", data.get("synced_calendars", [])
|
||||
yield "errors", data.get("errors", [])
|
||||
autogpt_platform/backend/backend/blocks/baas/events.py (new file, 276 lines)
@@ -0,0 +1,276 @@
|
||||
"""
|
||||
Meeting BaaS calendar event blocks.
|
||||
"""
|
||||
|
||||
from typing import Union
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import baas
|
||||
|
||||
|
||||
class BaasEventListBlock(Block):
|
||||
"""
|
||||
Get events for a calendar & date range.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
calendar_id: str = SchemaField(
|
||||
description="UUID of the calendar to list events from"
|
||||
)
|
||||
start_date_gte: str = SchemaField(
|
||||
description="ISO date string for start date (greater than or equal)",
|
||||
default="",
|
||||
)
|
||||
start_date_lte: str = SchemaField(
|
||||
description="ISO date string for start date (less than or equal)",
|
||||
default="",
|
||||
)
|
||||
cursor: str = SchemaField(
|
||||
description="Pagination cursor from previous request", default=""
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
events: list[dict] = SchemaField(description="Array of calendar events")
|
||||
next_cursor: str = SchemaField(description="Cursor for next page of results")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8e9f0a1b-2c3d-4e5f-6a7b-8c9d0e1f2a3b",
|
||||
description="List calendar events with optional date filtering",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Build query parameters
|
||||
params = {"calendar_id": input_data.calendar_id}
|
||||
|
||||
if input_data.start_date_gte:
|
||||
params["start_date_gte"] = input_data.start_date_gte
|
||||
if input_data.start_date_lte:
|
||||
params["start_date_lte"] = input_data.start_date_lte
|
||||
if input_data.cursor:
|
||||
params["cursor"] = input_data.cursor
|
||||
|
||||
# List events
|
||||
response = await Requests().get(
|
||||
"https://api.meetingbaas.com/calendar_events",
|
||||
headers={"x-meeting-baas-api-key": api_key},
|
||||
params=params,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
yield "events", data.get("events", [])
|
||||
yield "next_cursor", data.get("next", "")
|
||||
|
||||
|
||||
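Callers that need every event can loop on next_cursor until it comes back empty; a minimal sketch, with fetch_page standing in for the GET request made above:

def fetch_page(cursor: str) -> dict:
    """Hypothetical stand-in for GET /calendar_events with the cursor parameter."""
    return {"events": [], "next": ""}

events: list[dict] = []
cursor = ""
while True:
    page = fetch_page(cursor)
    events.extend(page.get("events", []))
    cursor = page.get("next", "")
    if not cursor:
        break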
class BaasEventGetDetailsBlock(Block):
|
||||
"""
|
||||
Fetch full object for one event.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
event_id: str = SchemaField(description="UUID of the event to retrieve")
|
||||
|
||||
class Output(BlockSchema):
|
||||
event: dict = SchemaField(description="Full event object with all details")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="9f0a1b2c-3d4e-5f6a-7b8c-9d0e1f2a3b4c",
|
||||
description="Get detailed information for a specific calendar event",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Get event details
|
||||
response = await Requests().get(
|
||||
f"https://api.meetingbaas.com/calendar_events/{input_data.event_id}",
|
||||
headers={"x-meeting-baas-api-key": api_key},
|
||||
)
|
||||
|
||||
event = response.json()
|
||||
|
||||
yield "event", event
|
||||
|
||||
|
||||
class BaasEventScheduleBotBlock(Block):
|
||||
"""
|
||||
Attach bot config to the event for automatic recording.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
event_id: str = SchemaField(description="UUID of the event to schedule bot for")
|
||||
all_occurrences: bool = SchemaField(
|
||||
description="Apply to all occurrences of recurring event", default=False
|
||||
)
|
||||
bot_config: dict = SchemaField(
|
||||
description="Bot configuration (same as Bot → Join Meeting)"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
events: Union[dict, list[dict]] = SchemaField(
|
||||
description="Updated event(s) with bot scheduled"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0a1b2c3d-4e5f-6a7b-8c9d-0e1f2a3b4c5d",
|
||||
description="Schedule a recording bot for a calendar event",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Build query parameters
|
||||
params = {"all_occurrences": str(input_data.all_occurrences).lower()}
|
||||
|
||||
# Schedule bot
|
||||
response = await Requests().post(
|
||||
f"https://api.meetingbaas.com/calendar_events/{input_data.event_id}/bot",
|
||||
headers={"x-meeting-baas-api-key": api_key},
|
||||
params=params,
|
||||
json=input_data.bot_config,
|
||||
)
|
||||
|
||||
events = response.json()
|
||||
|
||||
yield "events", events
|
||||
|
||||
|
||||
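The schedule call takes the bot configuration in the request body and the recurrence flag as a query parameter; a sketch with made-up values:

params = {"all_occurrences": "true"}
bot_config = {
    "bot_name": "Notetaker",
    "speech_to_text": {"provider": "Gladia"},
    "extra": {"team": "sales"},  # free-form metadata, illustrative
}
# POST https://api.meetingbaas.com/calendar_events/<event_id>/bot with params and bot_config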
class BaasEventUnscheduleBotBlock(Block):
|
||||
"""
|
||||
Remove bot from event/series.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
event_id: str = SchemaField(
|
||||
description="UUID of the event to unschedule bot from"
|
||||
)
|
||||
all_occurrences: bool = SchemaField(
|
||||
description="Apply to all occurrences of recurring event", default=False
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
events: Union[dict, list[dict]] = SchemaField(
|
||||
description="Updated event(s) with bot removed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="1b2c3d4e-5f6a-7b8c-9d0e-1f2a3b4c5d6e",
|
||||
description="Cancel a scheduled recording for an event",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Build query parameters
|
||||
params = {"all_occurrences": str(input_data.all_occurrences).lower()}
|
||||
|
||||
# Unschedule bot
|
||||
response = await Requests().delete(
|
||||
f"https://api.meetingbaas.com/calendar_events/{input_data.event_id}/bot",
|
||||
headers={"x-meeting-baas-api-key": api_key},
|
||||
params=params,
|
||||
)
|
||||
|
||||
events = response.json()
|
||||
|
||||
yield "events", events
|
||||
|
||||
|
||||
class BaasEventPatchBotBlock(Block):
|
||||
"""
|
||||
Modify an already-scheduled bot configuration.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
event_id: str = SchemaField(description="UUID of the event with scheduled bot")
|
||||
all_occurrences: bool = SchemaField(
|
||||
description="Apply to all occurrences of recurring event", default=False
|
||||
)
|
||||
bot_patch: dict = SchemaField(description="Bot configuration fields to update")
|
||||
|
||||
class Output(BlockSchema):
|
||||
events: Union[dict, list[dict]] = SchemaField(
|
||||
description="Updated event(s) with modified bot config"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="2c3d4e5f-6a7b-8c9d-0e1f-2a3b4c5d6e7f",
|
||||
description="Update configuration of a scheduled bot",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Build query parameters
|
||||
params = {}
|
||||
if input_data.all_occurrences is not None:
|
||||
params["all_occurrences"] = str(input_data.all_occurrences).lower()
|
||||
|
||||
# Patch bot
|
||||
response = await Requests().patch(
|
||||
f"https://api.meetingbaas.com/calendar_events/{input_data.event_id}/bot",
|
||||
headers={"x-meeting-baas-api-key": api_key},
|
||||
params=params,
|
||||
json=input_data.bot_patch,
|
||||
)
|
||||
|
||||
events = response.json()
|
||||
|
||||
yield "events", events
|
||||
autogpt_platform/backend/backend/blocks/baas/triggers.py (new file, 185 lines)
@@ -0,0 +1,185 @@
|
||||
"""
|
||||
Meeting BaaS webhook trigger blocks.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
BlockWebhookConfig,
|
||||
CredentialsMetaInput,
|
||||
ProviderName,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import baas
|
||||
|
||||
|
||||
class BaasOnMeetingEventBlock(Block):
|
||||
"""
|
||||
Trigger when Meeting BaaS sends meeting-related events:
|
||||
bot.status_change, complete, failed, transcription_complete
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
webhook_url: str = SchemaField(
|
||||
description="URL to receive webhooks (auto-generated)",
|
||||
default="",
|
||||
hidden=True,
|
||||
)
|
||||
|
||||
class EventsFilter(BaseModel):
|
||||
"""Meeting event types to subscribe to"""
|
||||
|
||||
bot_status_change: bool = SchemaField(
|
||||
description="Bot status changes", default=True
|
||||
)
|
||||
complete: bool = SchemaField(description="Meeting completed", default=True)
|
||||
failed: bool = SchemaField(description="Meeting failed", default=True)
|
||||
transcription_complete: bool = SchemaField(
|
||||
description="Transcription completed", default=True
|
||||
)
|
||||
|
||||
events: EventsFilter = SchemaField(
|
||||
title="Events", description="The events to subscribe to"
|
||||
)
|
||||
|
||||
payload: dict = SchemaField(
|
||||
description="Webhook payload data",
|
||||
default={},
|
||||
hidden=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
event_type: str = SchemaField(description="Type of event received")
|
||||
data: dict = SchemaField(description="Event data payload")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3d4e5f6a-7b8c-9d0e-1f2a-3b4c5d6e7f8a",
|
||||
description="Receive meeting events from Meeting BaaS webhooks",
|
||||
categories={BlockCategory.INPUT},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
block_type=BlockType.WEBHOOK,
|
||||
webhook_config=BlockWebhookConfig(
|
||||
provider=ProviderName("baas"),
|
||||
webhook_type="meeting_event",
|
||||
event_filter_input="events",
|
||||
resource_format="meeting",
|
||||
),
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
payload = input_data.payload
|
||||
|
||||
# Extract event type and data
|
||||
event_type = payload.get("event", "unknown")
|
||||
data = payload.get("data", {})
|
||||
|
||||
# Map event types to filter fields
|
||||
event_filter_map = {
|
||||
"bot.status_change": input_data.events.bot_status_change,
|
||||
"complete": input_data.events.complete,
|
||||
"failed": input_data.events.failed,
|
||||
"transcription_complete": input_data.events.transcription_complete,
|
||||
}
|
||||
|
||||
# Filter events if needed
|
||||
if not event_filter_map.get(event_type, False):
|
||||
return # Skip unwanted events
|
||||
|
||||
yield "event_type", event_type
|
||||
yield "data", data
|
||||
|
||||
|
||||
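A hypothetical delivery and how this trigger routes it: the payload's "event" key is looked up in the filter map, and unchecked event types are dropped without yielding anything:

payload = {"event": "transcription_complete", "data": {"bot_id": "9b0c1d2e"}}  # illustrative
event_type = payload.get("event", "unknown")
# With events.transcription_complete enabled, the block yields:
#   ("event_type", "transcription_complete")
#   ("data", {"bot_id": "9b0c1d2e"})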
class BaasOnCalendarEventBlock(Block):
|
||||
"""
|
||||
Trigger when Meeting BaaS sends calendar-related events:
|
||||
event.added, event.updated, event.deleted, calendar.synced
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = baas.credentials_field(
|
||||
description="Meeting BaaS API credentials"
|
||||
)
|
||||
webhook_url: str = SchemaField(
|
||||
description="URL to receive webhooks (auto-generated)",
|
||||
default="",
|
||||
hidden=True,
|
||||
)
|
||||
|
||||
class EventsFilter(BaseModel):
|
||||
"""Calendar event types to subscribe to"""
|
||||
|
||||
event_added: bool = SchemaField(
|
||||
description="Calendar event added", default=True
|
||||
)
|
||||
event_updated: bool = SchemaField(
|
||||
description="Calendar event updated", default=True
|
||||
)
|
||||
event_deleted: bool = SchemaField(
|
||||
description="Calendar event deleted", default=True
|
||||
)
|
||||
calendar_synced: bool = SchemaField(
|
||||
description="Calendar synced", default=True
|
||||
)
|
||||
|
||||
events: EventsFilter = SchemaField(
|
||||
title="Events", description="The events to subscribe to"
|
||||
)
|
||||
|
||||
payload: dict = SchemaField(
|
||||
description="Webhook payload data",
|
||||
default={},
|
||||
hidden=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
event_type: str = SchemaField(description="Type of event received")
|
||||
data: dict = SchemaField(description="Event data payload")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="4e5f6a7b-8c9d-0e1f-2a3b-4c5d6e7f8a9b",
|
||||
description="Receive calendar events from Meeting BaaS webhooks",
|
||||
categories={BlockCategory.INPUT},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
block_type=BlockType.WEBHOOK,
|
||||
webhook_config=BlockWebhookConfig(
|
||||
provider=ProviderName("baas"),
|
||||
webhook_type="calendar_event",
|
||||
event_filter_input="events",
|
||||
resource_format="calendar",
|
||||
),
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
payload = input_data.payload
|
||||
|
||||
# Extract event type and data
|
||||
event_type = payload.get("event", "unknown")
|
||||
data = payload.get("data", {})
|
||||
|
||||
# Map event types to filter fields
|
||||
event_filter_map = {
|
||||
"event.added": input_data.events.event_added,
|
||||
"event.updated": input_data.events.event_updated,
|
||||
"event.deleted": input_data.events.event_deleted,
|
||||
"calendar.synced": input_data.events.calendar_synced,
|
||||
}
|
||||
|
||||
# Filter events if needed
|
||||
if not event_filter_map.get(event_type, False):
|
||||
return # Skip unwanted events
|
||||
|
||||
yield "event_type", event_type
|
||||
yield "data", data
|
||||
@@ -39,13 +39,11 @@ class FileStoreBlock(Block):
|
||||
input_data: Input,
|
||||
*,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
yield "file_out", await store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.file_in,
|
||||
user_id=user_id,
|
||||
return_content=input_data.base_64,
|
||||
)
|
||||
|
||||
|
||||
autogpt_platform/backend/backend/blocks/csv.py (new file, 109 lines)
@@ -0,0 +1,109 @@
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import ContributorDetails, SchemaField
|
||||
|
||||
|
||||
class ReadCsvBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
contents: str = SchemaField(
|
||||
description="The contents of the CSV file to read",
|
||||
placeholder="a, b, c\n1,2,3\n4,5,6",
|
||||
)
|
||||
delimiter: str = SchemaField(
|
||||
description="The delimiter used in the CSV file",
|
||||
default=",",
|
||||
)
|
||||
quotechar: str = SchemaField(
|
||||
description="The character used to quote fields",
|
||||
default='"',
|
||||
)
|
||||
escapechar: str = SchemaField(
|
||||
description="The character used to escape the delimiter",
|
||||
default="\\",
|
||||
)
|
||||
has_header: bool = SchemaField(
|
||||
description="Whether the CSV file has a header row",
|
||||
default=True,
|
||||
)
|
||||
skip_rows: int = SchemaField(
|
||||
description="The number of rows to skip from the start of the file",
|
||||
default=0,
|
||||
)
|
||||
strip: bool = SchemaField(
|
||||
description="Whether to strip whitespace from the values",
|
||||
default=True,
|
||||
)
|
||||
skip_columns: list[str] = SchemaField(
|
||||
description="The columns to skip from the start of the row",
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
row: dict[str, str] = SchemaField(
|
||||
description="The data produced from each row in the CSV file"
|
||||
)
|
||||
all_data: list[dict[str, str]] = SchemaField(
|
||||
description="All the data in the CSV file as a list of rows"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="acf7625e-d2cb-4941-bfeb-2819fc6fc015",
|
||||
input_schema=ReadCsvBlock.Input,
|
||||
output_schema=ReadCsvBlock.Output,
|
||||
description="Reads a CSV file and outputs the data as a list of dictionaries and individual rows via rows.",
|
||||
contributors=[ContributorDetails(name="Nicholas Tindle")],
|
||||
categories={BlockCategory.TEXT, BlockCategory.DATA},
|
||||
test_input={
|
||||
"contents": "a, b, c\n1,2,3\n4,5,6",
|
||||
},
|
||||
test_output=[
|
||||
("row", {"a": "1", "b": "2", "c": "3"}),
|
||||
("row", {"a": "4", "b": "5", "c": "6"}),
|
||||
(
|
||||
"all_data",
|
||||
[
|
||||
{"a": "1", "b": "2", "c": "3"},
|
||||
{"a": "4", "b": "5", "c": "6"},
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
import csv
|
||||
from io import StringIO
|
||||
|
||||
csv_file = StringIO(input_data.contents)
|
||||
reader = csv.reader(
|
||||
csv_file,
|
||||
delimiter=input_data.delimiter,
|
||||
quotechar=input_data.quotechar,
|
||||
escapechar=input_data.escapechar,
|
||||
)
|
||||
|
||||
header = None
|
||||
if input_data.has_header:
|
||||
header = next(reader)
|
||||
if input_data.strip:
|
||||
header = [h.strip() for h in header]
|
||||
|
||||
for _ in range(input_data.skip_rows):
|
||||
next(reader)
|
||||
|
||||
def process_row(row):
    data = {}
    for i, value in enumerate(row):
        # Resolve the column key: the header name when available, else the positional index
        key = header[i] if (input_data.has_header and header) else str(i)
        # skip_columns holds strings, so compare against the resolved key rather than the index
        if key in input_data.skip_columns:
            continue
        data[key] = value.strip() if input_data.strip else value
    return data
|
||||
|
||||
all_data = []
|
||||
for row in reader:
|
||||
processed_row = process_row(row)
|
||||
all_data.append(processed_row)
|
||||
yield "row", processed_row
|
||||
|
||||
yield "all_data", all_data
|
||||
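The parsing path can be checked standalone with the same sample used in test_input; a minimal sketch outside the block class:

import csv
from io import StringIO

reader = csv.reader(StringIO("a, b, c\n1,2,3\n4,5,6"), delimiter=",", quotechar='"', escapechar="\\")
header = [h.strip() for h in next(reader)]
rows = [dict(zip(header, (v.strip() for v in row))) for row in reader]
print(rows)  # [{'a': '1', 'b': '2', 'c': '3'}, {'a': '4', 'b': '5', 'c': '6'}]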
File diff suppressed because it is too large
@@ -0,0 +1,48 @@
|
||||
"""
|
||||
ElevenLabs integration blocks for AutoGPT Platform.
|
||||
"""
|
||||
|
||||
# Speech generation blocks
|
||||
from .speech import (
|
||||
ElevenLabsGenerateSpeechBlock,
|
||||
ElevenLabsGenerateSpeechWithTimestampsBlock,
|
||||
)
|
||||
|
||||
# Speech-to-text blocks
|
||||
from .transcription import (
|
||||
ElevenLabsTranscribeAudioAsyncBlock,
|
||||
ElevenLabsTranscribeAudioSyncBlock,
|
||||
)
|
||||
|
||||
# Webhook trigger blocks
|
||||
from .triggers import ElevenLabsWebhookTriggerBlock
|
||||
|
||||
# Utility blocks
|
||||
from .utility import ElevenLabsGetUsageStatsBlock, ElevenLabsListModelsBlock
|
||||
|
||||
# Voice management blocks
|
||||
from .voices import (
|
||||
ElevenLabsCreateVoiceCloneBlock,
|
||||
ElevenLabsDeleteVoiceBlock,
|
||||
ElevenLabsGetVoiceDetailsBlock,
|
||||
ElevenLabsListVoicesBlock,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Voice management
|
||||
"ElevenLabsListVoicesBlock",
|
||||
"ElevenLabsGetVoiceDetailsBlock",
|
||||
"ElevenLabsCreateVoiceCloneBlock",
|
||||
"ElevenLabsDeleteVoiceBlock",
|
||||
# Speech generation
|
||||
"ElevenLabsGenerateSpeechBlock",
|
||||
"ElevenLabsGenerateSpeechWithTimestampsBlock",
|
||||
# Speech-to-text
|
||||
"ElevenLabsTranscribeAudioSyncBlock",
|
||||
"ElevenLabsTranscribeAudioAsyncBlock",
|
||||
# Utility
|
||||
"ElevenLabsListModelsBlock",
|
||||
"ElevenLabsGetUsageStatsBlock",
|
||||
# Webhook triggers
|
||||
"ElevenLabsWebhookTriggerBlock",
|
||||
]
|
||||
@@ -0,0 +1,16 @@
"""
Shared configuration for all ElevenLabs blocks using the SDK pattern.
"""

from backend.sdk import BlockCostType, ProviderBuilder

from ._webhook import ElevenLabsWebhookManager

# Configure the ElevenLabs provider with API key authentication
elevenlabs = (
    ProviderBuilder("elevenlabs")
    .with_api_key("ELEVENLABS_API_KEY", "ElevenLabs API Key")
    .with_webhook_manager(ElevenLabsWebhookManager)
    .with_base_cost(2, BlockCostType.RUN)  # Base cost for API calls
    .build()
)
@@ -0,0 +1,82 @@
|
||||
"""
|
||||
ElevenLabs webhook manager for handling webhook events.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
from typing import Tuple
|
||||
|
||||
from backend.data.model import Credentials
|
||||
from backend.sdk import BaseWebhooksManager, ProviderName, Webhook
|
||||
|
||||
|
||||
class ElevenLabsWebhookManager(BaseWebhooksManager):
|
||||
"""Manages ElevenLabs webhook events."""
|
||||
|
||||
PROVIDER_NAME = ProviderName("elevenlabs")
|
||||
|
||||
@classmethod
|
||||
async def validate_payload(cls, webhook: Webhook, request) -> Tuple[dict, str]:
|
||||
"""
|
||||
Validate incoming webhook payload and signature.
|
||||
|
||||
ElevenLabs supports HMAC authentication for webhooks.
|
||||
"""
|
||||
payload = await request.json()
|
||||
|
||||
# Verify webhook signature if configured
|
||||
if webhook.secret:
|
||||
webhook_secret = webhook.config.get("webhook_secret")
|
||||
if webhook_secret:
|
||||
# Get the raw body for signature verification
|
||||
body = await request.body()
|
||||
|
||||
# Calculate expected signature
|
||||
expected_signature = hmac.new(
|
||||
webhook_secret.encode(), body, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
# Get signature from headers
|
||||
signature = request.headers.get("x-elevenlabs-signature")
|
||||
|
||||
if signature and not hmac.compare_digest(signature, expected_signature):
|
||||
raise ValueError("Invalid webhook signature")
|
||||
|
||||
# Extract event type from payload
|
||||
event_type = payload.get("type", "unknown")
|
||||
return payload, event_type
|
||||
|
||||
async def _register_webhook(
|
||||
self,
|
||||
credentials: Credentials,
|
||||
webhook_type: str,
|
||||
resource: str,
|
||||
events: list[str],
|
||||
ingress_url: str,
|
||||
secret: str,
|
||||
) -> tuple[str, dict]:
|
||||
"""
|
||||
Register a webhook with ElevenLabs.
|
||||
|
||||
Note: ElevenLabs webhook registration is done through their dashboard,
|
||||
not via API. This is a placeholder implementation.
|
||||
"""
|
||||
# ElevenLabs requires manual webhook setup through dashboard
|
||||
# Return empty webhook ID and config with instructions
|
||||
config = {
|
||||
"manual_setup_required": True,
|
||||
"webhook_secret": secret,
|
||||
"instructions": "Please configure webhook URL in ElevenLabs dashboard",
|
||||
}
|
||||
return "", config
|
||||
|
||||
async def _deregister_webhook(
|
||||
self, webhook: Webhook, credentials: Credentials
|
||||
) -> None:
|
||||
"""
|
||||
Deregister a webhook with ElevenLabs.
|
||||
|
||||
Note: ElevenLabs webhook removal is done through their dashboard.
|
||||
"""
|
||||
# ElevenLabs requires manual webhook removal through dashboard
|
||||
pass
|
||||
autogpt_platform/backend/backend/blocks/elevenlabs/speech.py (new file, 179 lines)
@@ -0,0 +1,179 @@
|
||||
"""
|
||||
ElevenLabs speech generation (text-to-speech) blocks.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import elevenlabs
|
||||
|
||||
|
||||
class ElevenLabsGenerateSpeechBlock(Block):
|
||||
"""
|
||||
Turn text into audio (binary).
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = elevenlabs.credentials_field(
|
||||
description="ElevenLabs API credentials"
|
||||
)
|
||||
voice_id: str = SchemaField(description="ID of the voice to use")
|
||||
text: str = SchemaField(description="Text to convert to speech")
|
||||
model_id: str = SchemaField(
|
||||
description="Model ID to use for generation",
|
||||
default="eleven_multilingual_v2",
|
||||
)
|
||||
output_format: str = SchemaField(
|
||||
description="Audio format (e.g., mp3_44100_128)",
|
||||
default="mp3_44100_128",
|
||||
)
|
||||
voice_settings: Optional[dict] = SchemaField(
|
||||
description="Override voice settings (stability, similarity_boost, etc.)",
|
||||
default=None,
|
||||
)
|
||||
language_code: Optional[str] = SchemaField(
|
||||
description="Language code to enforce output language", default=None
|
||||
)
|
||||
seed: Optional[int] = SchemaField(
|
||||
description="Seed for reproducible output", default=None
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
audio: str = SchemaField(description="Base64-encoded audio data")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c5d6e7f8-a9b0-c1d2-e3f4-a5b6c7d8e9f0",
|
||||
description="Generate speech audio from text using a specified voice",
|
||||
categories={BlockCategory.AI},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
import base64
|
||||
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Build request body
|
||||
body: dict[str, str | int | dict] = {
|
||||
"text": input_data.text,
|
||||
"model_id": input_data.model_id,
|
||||
}
|
||||
|
||||
# Add optional fields
|
||||
if input_data.voice_settings:
|
||||
body["voice_settings"] = input_data.voice_settings
|
||||
if input_data.language_code:
|
||||
body["language_code"] = input_data.language_code
|
||||
if input_data.seed is not None:
|
||||
body["seed"] = input_data.seed
|
||||
|
||||
# Generate speech
|
||||
response = await Requests().post(
|
||||
f"https://api.elevenlabs.io/v1/text-to-speech/{input_data.voice_id}",
|
||||
headers={
|
||||
"xi-api-key": api_key,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json=body,
|
||||
params={"output_format": input_data.output_format},
|
||||
)
|
||||
|
||||
# Get audio data and encode to base64
|
||||
audio_data = response.content
|
||||
audio_base64 = base64.b64encode(audio_data).decode("utf-8")
|
||||
|
||||
yield "audio", audio_base64
|
||||
|
||||
|
||||
class ElevenLabsGenerateSpeechWithTimestampsBlock(Block):
|
||||
"""
|
||||
Text to audio AND per-character timing data.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = elevenlabs.credentials_field(
|
||||
description="ElevenLabs API credentials"
|
||||
)
|
||||
voice_id: str = SchemaField(description="ID of the voice to use")
|
||||
text: str = SchemaField(description="Text to convert to speech")
|
||||
model_id: str = SchemaField(
|
||||
description="Model ID to use for generation",
|
||||
default="eleven_multilingual_v2",
|
||||
)
|
||||
output_format: str = SchemaField(
|
||||
description="Audio format (e.g., mp3_44100_128)",
|
||||
default="mp3_44100_128",
|
||||
)
|
||||
voice_settings: Optional[dict] = SchemaField(
|
||||
description="Override voice settings (stability, similarity_boost, etc.)",
|
||||
default=None,
|
||||
)
|
||||
language_code: Optional[str] = SchemaField(
|
||||
description="Language code to enforce output language", default=None
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
audio_base64: str = SchemaField(description="Base64-encoded audio data")
|
||||
alignment: dict = SchemaField(
|
||||
description="Character-level timing alignment data"
|
||||
)
|
||||
normalized_alignment: dict = SchemaField(
|
||||
description="Normalized text alignment data"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d6e7f8a9-b0c1-d2e3-f4a5-b6c7d8e9f0a1",
|
||||
description="Generate speech with character-level timestamp information",
|
||||
categories={BlockCategory.AI},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Build request body
|
||||
body: dict[str, str | dict] = {
|
||||
"text": input_data.text,
|
||||
"model_id": input_data.model_id,
|
||||
}
|
||||
|
||||
# Add optional fields
|
||||
if input_data.voice_settings:
|
||||
body["voice_settings"] = input_data.voice_settings
|
||||
if input_data.language_code:
|
||||
body["language_code"] = input_data.language_code
|
||||
|
||||
# Generate speech with timestamps
|
||||
response = await Requests().post(
|
||||
f"https://api.elevenlabs.io/v1/text-to-speech/{input_data.voice_id}/with-timestamps",
|
||||
headers={
|
||||
"xi-api-key": api_key,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json=body,
|
||||
params={"output_format": input_data.output_format},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
yield "audio_base64", data.get("audio_base64", "")
|
||||
yield "alignment", data.get("alignment", {})
|
||||
yield "normalized_alignment", data.get("normalized_alignment", {})
|
||||
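The speech blocks above return audio as a base64 string rather than raw bytes. A minimal sketch of consuming that output (the input bytes here are a stand-in, not real audio):

import base64

audio_base64 = base64.b64encode(b"fake-mp3-bytes").decode("utf-8")  # stand-in for the block's "audio" output
with open("speech.mp3", "wb") as f:
    f.write(base64.b64decode(audio_base64))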
@@ -0,0 +1,232 @@
|
||||
"""
|
||||
ElevenLabs speech-to-text (transcription) blocks.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import elevenlabs
|
||||
|
||||
|
||||
class ElevenLabsTranscribeAudioSyncBlock(Block):
|
||||
"""
|
||||
Synchronously convert audio to text (+ word timestamps, diarization).
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = elevenlabs.credentials_field(
|
||||
description="ElevenLabs API credentials"
|
||||
)
|
||||
model_id: str = SchemaField(
|
||||
description="Model ID for transcription", default="scribe_v1"
|
||||
)
|
||||
file: Optional[str] = SchemaField(
|
||||
description="Base64-encoded audio file", default=None
|
||||
)
|
||||
cloud_storage_url: Optional[str] = SchemaField(
|
||||
description="URL to audio file in cloud storage", default=None
|
||||
)
|
||||
language_code: Optional[str] = SchemaField(
|
||||
description="Language code (ISO 639-1 or -3) to improve accuracy",
|
||||
default=None,
|
||||
)
|
||||
diarize: bool = SchemaField(
|
||||
description="Enable speaker diarization", default=False
|
||||
)
|
||||
num_speakers: Optional[int] = SchemaField(
|
||||
description="Expected number of speakers (max 32)", default=None
|
||||
)
|
||||
timestamps_granularity: str = SchemaField(
|
||||
description="Timestamp detail level: word, character, or none",
|
||||
default="word",
|
||||
)
|
||||
tag_audio_events: bool = SchemaField(
|
||||
description="Tag non-speech sounds (laughter, noise)", default=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
text: str = SchemaField(description="Full transcribed text")
|
||||
words: list[dict] = SchemaField(
|
||||
description="Array with word timing and speaker info"
|
||||
)
|
||||
language_code: str = SchemaField(description="Detected language code")
|
||||
language_probability: float = SchemaField(
|
||||
description="Confidence in language detection"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="e7f8a9b0-c1d2-e3f4-a5b6-c7d8e9f0a1b2",
|
||||
description="Transcribe audio to text with timing and speaker information",
|
||||
categories={BlockCategory.AI},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
import base64
|
||||
from io import BytesIO
|
||||
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Validate input - must have either file or URL
|
||||
if not input_data.file and not input_data.cloud_storage_url:
|
||||
raise ValueError("Either 'file' or 'cloud_storage_url' must be provided")
|
||||
if input_data.file and input_data.cloud_storage_url:
|
||||
raise ValueError(
|
||||
"Only one of 'file' or 'cloud_storage_url' should be provided"
|
||||
)
|
||||
|
||||
# Build form data
|
||||
form_data = {
|
||||
"model_id": input_data.model_id,
|
||||
"diarize": str(input_data.diarize).lower(),
|
||||
"timestamps_granularity": input_data.timestamps_granularity,
|
||||
"tag_audio_events": str(input_data.tag_audio_events).lower(),
|
||||
}
|
||||
|
||||
if input_data.language_code:
|
||||
form_data["language_code"] = input_data.language_code
|
||||
if input_data.num_speakers is not None:
|
||||
form_data["num_speakers"] = str(input_data.num_speakers)
|
||||
|
||||
# Handle file or URL
|
||||
files = None
|
||||
if input_data.file:
|
||||
# Decode base64 file
|
||||
file_data = base64.b64decode(input_data.file)
|
||||
files = [("file", ("audio.wav", BytesIO(file_data), "audio/wav"))]
|
||||
elif input_data.cloud_storage_url:
|
||||
form_data["cloud_storage_url"] = input_data.cloud_storage_url
|
||||
|
||||
# Transcribe audio
|
||||
response = await Requests().post(
|
||||
"https://api.elevenlabs.io/v1/speech-to-text",
|
||||
headers={"xi-api-key": api_key},
|
||||
data=form_data,
|
||||
files=files,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
yield "text", data.get("text", "")
|
||||
yield "words", data.get("words", [])
|
||||
yield "language_code", data.get("language_code", "")
|
||||
yield "language_probability", data.get("language_probability", 0.0)
|
||||
|
||||
|
||||
class ElevenLabsTranscribeAudioAsyncBlock(Block):
|
||||
"""
|
||||
Kick off transcription that returns quickly; result arrives via webhook.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = elevenlabs.credentials_field(
|
||||
description="ElevenLabs API credentials"
|
||||
)
|
||||
model_id: str = SchemaField(
|
||||
description="Model ID for transcription", default="scribe_v1"
|
||||
)
|
||||
file: Optional[str] = SchemaField(
|
||||
description="Base64-encoded audio file", default=None
|
||||
)
|
||||
cloud_storage_url: Optional[str] = SchemaField(
|
||||
description="URL to audio file in cloud storage", default=None
|
||||
)
|
||||
language_code: Optional[str] = SchemaField(
|
||||
description="Language code (ISO 639-1 or -3) to improve accuracy",
|
||||
default=None,
|
||||
)
|
||||
diarize: bool = SchemaField(
|
||||
description="Enable speaker diarization", default=False
|
||||
)
|
||||
num_speakers: Optional[int] = SchemaField(
|
||||
description="Expected number of speakers (max 32)", default=None
|
||||
)
|
||||
timestamps_granularity: str = SchemaField(
|
||||
description="Timestamp detail level: word, character, or none",
|
||||
default="word",
|
||||
)
|
||||
webhook_url: str = SchemaField(
|
||||
description="URL to receive transcription result",
|
||||
default="",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
tracking_id: str = SchemaField(description="ID to track the transcription job")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f8a9b0c1-d2e3-f4a5-b6c7-d8e9f0a1b2c3",
|
||||
description="Start async transcription with webhook callback",
|
||||
categories={BlockCategory.AI},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
import base64
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Validate input
|
||||
if not input_data.file and not input_data.cloud_storage_url:
|
||||
raise ValueError("Either 'file' or 'cloud_storage_url' must be provided")
|
||||
if input_data.file and input_data.cloud_storage_url:
|
||||
raise ValueError(
|
||||
"Only one of 'file' or 'cloud_storage_url' should be provided"
|
||||
)
|
||||
|
||||
# Build form data
|
||||
form_data = {
|
||||
"model_id": input_data.model_id,
|
||||
"diarize": str(input_data.diarize).lower(),
|
||||
"timestamps_granularity": input_data.timestamps_granularity,
|
||||
"webhook": "true", # Enable async mode
|
||||
}
|
||||
|
||||
if input_data.language_code:
|
||||
form_data["language_code"] = input_data.language_code
|
||||
if input_data.num_speakers is not None:
|
||||
form_data["num_speakers"] = str(input_data.num_speakers)
|
||||
if input_data.webhook_url:
|
||||
form_data["webhook_url"] = input_data.webhook_url
|
||||
|
||||
# Handle file or URL
|
||||
files = None
|
||||
if input_data.file:
|
||||
# Decode base64 file
|
||||
file_data = base64.b64decode(input_data.file)
|
||||
files = [("file", ("audio.wav", BytesIO(file_data), "audio/wav"))]
|
||||
elif input_data.cloud_storage_url:
|
||||
form_data["cloud_storage_url"] = input_data.cloud_storage_url
|
||||
|
||||
# Start async transcription
|
||||
response = await Requests().post(
|
||||
"https://api.elevenlabs.io/v1/speech-to-text",
|
||||
headers={"xi-api-key": api_key},
|
||||
data=form_data,
|
||||
files=files,
|
||||
)
|
||||
|
||||
# Generate tracking ID (API might return one)
|
||||
data = response.json()
|
||||
tracking_id = data.get("tracking_id", str(uuid.uuid4()))
|
||||
|
||||
yield "tracking_id", tracking_id
|
||||
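The transcription blocks above accept either a base64-encoded file or a cloud storage URL, but not both. A minimal sketch of preparing the base64 input from a local file (the path is hypothetical):

import base64

with open("meeting.wav", "rb") as f:  # hypothetical local recording
    file_b64 = base64.b64encode(f.read()).decode("utf-8")
# file_b64 can then be supplied as the 'file' input of ElevenLabsTranscribeAudioSyncBlock.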
autogpt_platform/backend/backend/blocks/elevenlabs/triggers.py (new file, 160 lines)
@@ -0,0 +1,160 @@
|
||||
"""
|
||||
ElevenLabs webhook trigger blocks.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
BlockWebhookConfig,
|
||||
CredentialsMetaInput,
|
||||
ProviderName,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import elevenlabs
|
||||
|
||||
|
||||
class ElevenLabsWebhookTriggerBlock(Block):
|
||||
"""
|
||||
Starts a flow when ElevenLabs POSTs an event (STT finished, voice removal, etc.).
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = elevenlabs.credentials_field(
|
||||
description="ElevenLabs API credentials"
|
||||
)
|
||||
webhook_url: str = SchemaField(
|
||||
description="URL to receive webhooks (auto-generated)",
|
||||
default="",
|
||||
hidden=True,
|
||||
)
|
||||
|
||||
class EventsFilter(BaseModel):
|
||||
"""ElevenLabs event types to subscribe to"""
|
||||
|
||||
speech_to_text_completed: bool = SchemaField(
|
||||
description="Speech-to-text transcription completed", default=True
|
||||
)
|
||||
post_call_transcription: bool = SchemaField(
|
||||
description="Conversational AI call transcription completed",
|
||||
default=True,
|
||||
)
|
||||
voice_removal_notice: bool = SchemaField(
|
||||
description="Voice scheduled for removal", default=True
|
||||
)
|
||||
voice_removed: bool = SchemaField(
|
||||
description="Voice has been removed", default=True
|
||||
)
|
||||
voice_removal_notice_withdrawn: bool = SchemaField(
|
||||
description="Voice removal cancelled", default=True
|
||||
)
|
||||
|
||||
events: EventsFilter = SchemaField(
|
||||
title="Events", description="The events to subscribe to"
|
||||
)
|
||||
|
||||
# Webhook payload - populated by the system
|
||||
payload: dict = SchemaField(
|
||||
description="Webhook payload data",
|
||||
default={},
|
||||
hidden=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
type: str = SchemaField(description="Event type")
|
||||
event_timestamp: int = SchemaField(description="Unix timestamp of the event")
|
||||
data: dict = SchemaField(description="Event-specific data payload")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c1d2e3f4-a5b6-c7d8-e9f0-a1b2c3d4e5f6",
|
||||
description="Receive webhook events from ElevenLabs",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
block_type=BlockType.WEBHOOK,
|
||||
webhook_config=BlockWebhookConfig(
|
||||
provider=ProviderName("elevenlabs"),
|
||||
webhook_type="notification",
|
||||
event_filter_input="events",
|
||||
resource_format="",
|
||||
),
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
# Extract webhook data
|
||||
payload = input_data.payload
|
||||
|
||||
# Extract event type
|
||||
event_type = payload.get("type", "unknown")
|
||||
|
||||
# Map event types to filter fields
|
||||
event_filter_map = {
|
||||
"speech_to_text_completed": input_data.events.speech_to_text_completed,
|
||||
"post_call_transcription": input_data.events.post_call_transcription,
|
||||
"voice_removal_notice": input_data.events.voice_removal_notice,
|
||||
"voice_removed": input_data.events.voice_removed,
|
||||
"voice_removal_notice_withdrawn": input_data.events.voice_removal_notice_withdrawn,
|
||||
}
|
||||
|
||||
# Check if this event type is enabled
|
||||
if not event_filter_map.get(event_type, False):
|
||||
# Skip this event
|
||||
return
|
||||
|
||||
# Extract common fields
|
||||
yield "type", event_type
|
||||
yield "event_timestamp", payload.get("event_timestamp", 0)
|
||||
|
||||
# Extract event-specific data
|
||||
data = payload.get("data", {})
|
||||
|
||||
# Process based on event type
|
||||
if event_type == "speech_to_text_completed":
|
||||
# STT transcription completed
|
||||
processed_data = {
|
||||
"transcription_id": data.get("transcription_id"),
|
||||
"text": data.get("text"),
|
||||
"words": data.get("words", []),
|
||||
"language_code": data.get("language_code"),
|
||||
"language_probability": data.get("language_probability"),
|
||||
}
|
||||
elif event_type == "post_call_transcription":
|
||||
# Conversational AI call transcription
|
||||
processed_data = {
|
||||
"agent_id": data.get("agent_id"),
|
||||
"conversation_id": data.get("conversation_id"),
|
||||
"transcript": data.get("transcript"),
|
||||
"metadata": data.get("metadata", {}),
|
||||
}
|
||||
elif event_type == "voice_removal_notice":
|
||||
# Voice scheduled for removal
|
||||
processed_data = {
|
||||
"voice_id": data.get("voice_id"),
|
||||
"voice_name": data.get("voice_name"),
|
||||
"removal_date": data.get("removal_date"),
|
||||
"reason": data.get("reason"),
|
||||
}
|
||||
elif event_type == "voice_removal_notice_withdrawn":
|
||||
# Voice removal cancelled
|
||||
processed_data = {
|
||||
"voice_id": data.get("voice_id"),
|
||||
"voice_name": data.get("voice_name"),
|
||||
}
|
||||
elif event_type == "voice_removed":
|
||||
# Voice has been removed
|
||||
processed_data = {
|
||||
"voice_id": data.get("voice_id"),
|
||||
"voice_name": data.get("voice_name"),
|
||||
"removed_at": data.get("removed_at"),
|
||||
}
|
||||
else:
|
||||
# Unknown event type, pass through raw data
|
||||
processed_data = data
|
||||
|
||||
yield "data", processed_data
|
||||
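To illustrate how the trigger block above routes events, here is a hand-written payload in the shape the block reads; the exact ElevenLabs webhook schema may differ:

payload = {
    "type": "speech_to_text_completed",
    "event_timestamp": 1700000000,
    "data": {
        "transcription_id": "abc123",
        "text": "hello world",
        "words": [],
        "language_code": "en",
        "language_probability": 0.98,
    },
}
event_type = payload.get("type", "unknown")  # -> "speech_to_text_completed"
# With the default EventsFilter, this event type is enabled, so the block would
# yield "type", "event_timestamp", and the processed "data" dict.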
autogpt_platform/backend/backend/blocks/elevenlabs/utility.py (new file, 116 lines)
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
ElevenLabs utility blocks for models and usage stats.
|
||||
"""
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import elevenlabs
|
||||
|
||||
|
||||
class ElevenLabsListModelsBlock(Block):
|
||||
"""
|
||||
Get all available model IDs & capabilities.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = elevenlabs.credentials_field(
|
||||
description="ElevenLabs API credentials"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
models: list[dict] = SchemaField(
|
||||
description="Array of model objects with capabilities"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a9b0c1d2-e3f4-a5b6-c7d8-e9f0a1b2c3d4",
|
||||
description="List all available voice models and their capabilities",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Fetch models
|
||||
response = await Requests().get(
|
||||
"https://api.elevenlabs.io/v1/models",
|
||||
headers={"xi-api-key": api_key},
|
||||
)
|
||||
|
||||
models = response.json()
|
||||
|
||||
yield "models", models
|
||||
|
||||
|
||||
class ElevenLabsGetUsageStatsBlock(Block):
|
||||
"""
|
||||
Character / credit usage for billing dashboards.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = elevenlabs.credentials_field(
|
||||
description="ElevenLabs API credentials"
|
||||
)
|
||||
start_unix: int = SchemaField(
|
||||
description="Start timestamp in Unix epoch seconds"
|
||||
)
|
||||
end_unix: int = SchemaField(description="End timestamp in Unix epoch seconds")
|
||||
aggregation_interval: str = SchemaField(
|
||||
description="Aggregation interval: daily or monthly",
|
||||
default="daily",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
usage: list[dict] = SchemaField(description="Array of usage data per interval")
|
||||
total_character_count: int = SchemaField(
|
||||
description="Total characters used in period"
|
||||
)
|
||||
total_requests: int = SchemaField(description="Total API requests in period")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b0c1d2e3-f4a5-b6c7-d8e9-f0a1b2c3d4e5",
|
||||
description="Get character and credit usage statistics",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Build query parameters
|
||||
params = {
|
||||
"start_unix": input_data.start_unix,
|
||||
"end_unix": input_data.end_unix,
|
||||
"aggregation_interval": input_data.aggregation_interval,
|
||||
}
|
||||
|
||||
# Fetch usage stats
|
||||
response = await Requests().get(
|
||||
"https://api.elevenlabs.io/v1/usage/character-stats",
|
||||
headers={"xi-api-key": api_key},
|
||||
params=params,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
yield "usage", data.get("usage", [])
|
||||
yield "total_character_count", data.get("total_character_count", 0)
|
||||
yield "total_requests", data.get("total_requests", 0)
|
||||
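The usage-stats block above expects Unix-epoch timestamps. A minimal sketch of building its inputs for the last 30 days:

import datetime

end = datetime.datetime.now(datetime.timezone.utc)
start = end - datetime.timedelta(days=30)
usage_inputs = {
    "start_unix": int(start.timestamp()),
    "end_unix": int(end.timestamp()),
    "aggregation_interval": "daily",
}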
autogpt_platform/backend/backend/blocks/elevenlabs/voices.py (new file, 249 lines)
@@ -0,0 +1,249 @@
|
||||
"""
|
||||
ElevenLabs voice management blocks.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import elevenlabs
|
||||
|
||||
|
||||
class ElevenLabsListVoicesBlock(Block):
|
||||
"""
|
||||
Fetch all voices the account can use (for pick-lists, UI menus, etc.).
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = elevenlabs.credentials_field(
|
||||
description="ElevenLabs API credentials"
|
||||
)
|
||||
search: str = SchemaField(
|
||||
description="Search term to filter voices", default=""
|
||||
)
|
||||
voice_type: Optional[str] = SchemaField(
|
||||
description="Filter by voice type: premade, cloned, or professional",
|
||||
default=None,
|
||||
)
|
||||
page_size: int = SchemaField(
|
||||
description="Number of voices per page (max 100)", default=10
|
||||
)
|
||||
next_page_token: str = SchemaField(
|
||||
description="Token for fetching next page", default=""
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
voices: list[dict] = SchemaField(
|
||||
description="Array of voice objects with id, name, category, etc."
|
||||
)
|
||||
next_page_token: Optional[str] = SchemaField(
|
||||
description="Token for fetching next page, null if no more pages"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="e1a2b3c4-d5e6-f7a8-b9c0-d1e2f3a4b5c6",
|
||||
description="List all available voices with filtering and pagination",
|
||||
categories={BlockCategory.AI},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Build query parameters
|
||||
params: dict[str, str | int] = {"page_size": input_data.page_size}
|
||||
|
||||
if input_data.search:
|
||||
params["search"] = input_data.search
|
||||
if input_data.voice_type:
|
||||
params["voice_type"] = input_data.voice_type
|
||||
if input_data.next_page_token:
|
||||
params["next_page_token"] = input_data.next_page_token
|
||||
|
||||
# Fetch voices
|
||||
response = await Requests().get(
|
||||
"https://api.elevenlabs.io/v2/voices",
|
||||
headers={"xi-api-key": api_key},
|
||||
params=params,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
yield "voices", data.get("voices", [])
|
||||
yield "next_page_token", data.get("next_page_token")
|
||||
|
||||
|
||||
class ElevenLabsGetVoiceDetailsBlock(Block):
|
||||
"""
|
||||
Retrieve metadata/settings for a single voice.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = elevenlabs.credentials_field(
|
||||
description="ElevenLabs API credentials"
|
||||
)
|
||||
voice_id: str = SchemaField(description="The ID of the voice to retrieve")
|
||||
|
||||
class Output(BlockSchema):
|
||||
voice: dict = SchemaField(
|
||||
description="Voice object with name, labels, settings, etc."
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f2a3b4c5-d6e7-f8a9-b0c1-d2e3f4a5b6c7",
|
||||
description="Get detailed information about a specific voice",
|
||||
categories={BlockCategory.AI},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Fetch voice details
|
||||
response = await Requests().get(
|
||||
f"https://api.elevenlabs.io/v1/voices/{input_data.voice_id}",
|
||||
headers={"xi-api-key": api_key},
|
||||
)
|
||||
|
||||
voice = response.json()
|
||||
|
||||
yield "voice", voice
|
||||
|
||||
|
||||
class ElevenLabsCreateVoiceCloneBlock(Block):
|
||||
"""
|
||||
Upload sample clips to create a custom (IVC) voice.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = elevenlabs.credentials_field(
|
||||
description="ElevenLabs API credentials"
|
||||
)
|
||||
name: str = SchemaField(description="Name for the new voice")
|
||||
files: list[str] = SchemaField(
|
||||
description="Base64-encoded audio files (1-10 files, max 25MB each)"
|
||||
)
|
||||
description: str = SchemaField(
|
||||
description="Description of the voice", default=""
|
||||
)
|
||||
labels: dict = SchemaField(
|
||||
description="Metadata labels (e.g., accent, age)", default={}
|
||||
)
|
||||
remove_background_noise: bool = SchemaField(
|
||||
description="Whether to remove background noise from samples", default=False
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
voice_id: str = SchemaField(description="ID of the newly created voice")
|
||||
requires_verification: bool = SchemaField(
|
||||
description="Whether the voice requires verification"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a3b4c5d6-e7f8-a9b0-c1d2-e3f4a5b6c7d8",
|
||||
description="Create a new voice clone from audio samples",
|
||||
categories={BlockCategory.AI},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
import base64
|
||||
import json
|
||||
from io import BytesIO
|
||||
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Prepare multipart form data
|
||||
form_data = {
|
||||
"name": input_data.name,
|
||||
}
|
||||
|
||||
if input_data.description:
|
||||
form_data["description"] = input_data.description
|
||||
if input_data.labels:
|
||||
form_data["labels"] = json.dumps(input_data.labels)
|
||||
if input_data.remove_background_noise:
|
||||
form_data["remove_background_noise"] = "true"
|
||||
|
||||
# Prepare files
|
||||
files = []
|
||||
for i, file_b64 in enumerate(input_data.files):
|
||||
file_data = base64.b64decode(file_b64)
|
||||
files.append(
|
||||
("files", (f"sample_{i}.mp3", BytesIO(file_data), "audio/mpeg"))
|
||||
)
|
||||
|
||||
# Create voice
|
||||
response = await Requests().post(
|
||||
"https://api.elevenlabs.io/v1/voices/add",
|
||||
headers={"xi-api-key": api_key},
|
||||
data=form_data,
|
||||
files=files,
|
||||
)
|
||||
|
||||
result = response.json()
|
||||
|
||||
yield "voice_id", result.get("voice_id", "")
|
||||
yield "requires_verification", result.get("requires_verification", False)
|
||||
|
||||
|
||||
class ElevenLabsDeleteVoiceBlock(Block):
|
||||
"""
|
||||
Permanently remove a custom voice.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = elevenlabs.credentials_field(
|
||||
description="ElevenLabs API credentials"
|
||||
)
|
||||
voice_id: str = SchemaField(description="The ID of the voice to delete")
|
||||
|
||||
class Output(BlockSchema):
|
||||
status: str = SchemaField(description="Deletion status (ok or error)")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b4c5d6e7-f8a9-b0c1-d2e3-f4a5b6c7d8e9",
|
||||
description="Delete a custom voice from your account",
|
||||
categories={BlockCategory.AI},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Delete voice
|
||||
response = await Requests().delete(
|
||||
f"https://api.elevenlabs.io/v1/voices/{input_data.voice_id}",
|
||||
headers={"xi-api-key": api_key},
|
||||
)
|
||||
|
||||
# Check if successful
|
||||
if response.status in [200, 204]:
|
||||
yield "status", "ok"
|
||||
else:
|
||||
yield "status", "error"
|
||||
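ElevenLabsCreateVoiceCloneBlock above takes its audio samples as base64 strings. A minimal sketch of preparing that input (file paths are hypothetical):

import base64

sample_paths = ["sample_1.mp3", "sample_2.mp3"]  # hypothetical local clips, 1-10 files
files_b64 = []
for path in sample_paths:
    with open(path, "rb") as f:
        files_b64.append(base64.b64encode(f.read()).decode("utf-8"))
# files_b64 is then passed as the 'files' input of ElevenLabsCreateVoiceCloneBlock.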
@@ -1,408 +0,0 @@
|
||||
"""
|
||||
API module for Enrichlayer integration.
|
||||
|
||||
This module provides a client for interacting with the Enrichlayer API,
|
||||
which allows fetching LinkedIn profile data and related information.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import enum
|
||||
import logging
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, Optional, TypeVar
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.model import APIKeyCredentials
|
||||
from backend.util.request import Requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class EnrichlayerAPIException(Exception):
|
||||
"""Exception raised for Enrichlayer API errors."""
|
||||
|
||||
def __init__(self, message: str, status_code: int):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class FallbackToCache(enum.Enum):
|
||||
ON_ERROR = "on-error"
|
||||
NEVER = "never"
|
||||
|
||||
|
||||
class UseCache(enum.Enum):
|
||||
IF_PRESENT = "if-present"
|
||||
NEVER = "never"
|
||||
|
||||
|
||||
class SocialMediaProfiles(BaseModel):
|
||||
"""Social media profiles model."""
|
||||
|
||||
twitter: Optional[str] = None
|
||||
facebook: Optional[str] = None
|
||||
github: Optional[str] = None
|
||||
|
||||
|
||||
class Experience(BaseModel):
|
||||
"""Experience model for LinkedIn profiles."""
|
||||
|
||||
company: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
location: Optional[str] = None
|
||||
starts_at: Optional[dict[str, int]] = None
|
||||
ends_at: Optional[dict[str, int]] = None
|
||||
company_linkedin_profile_url: Optional[str] = None
|
||||
|
||||
|
||||
class Education(BaseModel):
|
||||
"""Education model for LinkedIn profiles."""
|
||||
|
||||
school: Optional[str] = None
|
||||
degree_name: Optional[str] = None
|
||||
field_of_study: Optional[str] = None
|
||||
starts_at: Optional[dict[str, int]] = None
|
||||
ends_at: Optional[dict[str, int]] = None
|
||||
school_linkedin_profile_url: Optional[str] = None
|
||||
|
||||
|
||||
class PersonProfileResponse(BaseModel):
|
||||
"""Response model for LinkedIn person profile.
|
||||
|
||||
This model represents the response from Enrichlayer's LinkedIn profile API.
|
||||
The API returns comprehensive profile data including work experience,
|
||||
education, skills, and contact information (when available).
|
||||
|
||||
Example API Response:
|
||||
{
|
||||
"public_identifier": "johnsmith",
|
||||
"full_name": "John Smith",
|
||||
"occupation": "Software Engineer at Tech Corp",
|
||||
"experiences": [
|
||||
{
|
||||
"company": "Tech Corp",
|
||||
"title": "Software Engineer",
|
||||
"starts_at": {"year": 2020, "month": 1}
|
||||
}
|
||||
],
|
||||
"education": [...],
|
||||
"skills": ["Python", "JavaScript", ...]
|
||||
}
|
||||
"""
|
||||
|
||||
public_identifier: Optional[str] = None
|
||||
profile_pic_url: Optional[str] = None
|
||||
full_name: Optional[str] = None
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
occupation: Optional[str] = None
|
||||
headline: Optional[str] = None
|
||||
summary: Optional[str] = None
|
||||
country: Optional[str] = None
|
||||
country_full_name: Optional[str] = None
|
||||
city: Optional[str] = None
|
||||
state: Optional[str] = None
|
||||
experiences: Optional[list[Experience]] = None
|
||||
education: Optional[list[Education]] = None
|
||||
languages: Optional[list[str]] = None
|
||||
skills: Optional[list[str]] = None
|
||||
inferred_salary: Optional[dict[str, Any]] = None
|
||||
personal_email: Optional[str] = None
|
||||
personal_contact_number: Optional[str] = None
|
||||
social_media_profiles: Optional[SocialMediaProfiles] = None
|
||||
extra: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class SimilarProfile(BaseModel):
|
||||
"""Similar profile model for LinkedIn person lookup."""
|
||||
|
||||
similarity: float
|
||||
linkedin_profile_url: str
|
||||
|
||||
|
||||
class PersonLookupResponse(BaseModel):
|
||||
"""Response model for LinkedIn person lookup.
|
||||
|
||||
This model represents the response from Enrichlayer's person lookup API.
|
||||
The API returns a LinkedIn profile URL and similarity scores when
|
||||
searching for a person by name and company.
|
||||
|
||||
Example API Response:
|
||||
{
|
||||
"url": "https://www.linkedin.com/in/johnsmith/",
|
||||
"name_similarity_score": 0.95,
|
||||
"company_similarity_score": 0.88,
|
||||
"title_similarity_score": 0.75,
|
||||
"location_similarity_score": 0.60
|
||||
}
|
||||
"""
|
||||
|
||||
url: str | None = None
|
||||
name_similarity_score: float | None
|
||||
company_similarity_score: float | None
|
||||
title_similarity_score: float | None
|
||||
location_similarity_score: float | None
|
||||
last_updated: datetime.datetime | None = None
|
||||
profile: PersonProfileResponse | None = None
|
||||
|
||||
|
||||
class RoleLookupResponse(BaseModel):
|
||||
"""Response model for LinkedIn role lookup.
|
||||
|
||||
This model represents the response from Enrichlayer's role lookup API.
|
||||
The API returns LinkedIn profile data for a specific role at a company.
|
||||
|
||||
Example API Response:
|
||||
{
|
||||
"linkedin_profile_url": "https://www.linkedin.com/in/johnsmith/",
|
||||
"profile_data": {...} // Full PersonProfileResponse data when enrich_profile=True
|
||||
}
|
||||
"""
|
||||
|
||||
linkedin_profile_url: Optional[str] = None
|
||||
profile_data: Optional[PersonProfileResponse] = None
|
||||
|
||||
|
||||
class ProfilePictureResponse(BaseModel):
|
||||
"""Response model for LinkedIn profile picture.
|
||||
|
||||
This model represents the response from Enrichlayer's profile picture API.
|
||||
The API returns a URL to the person's LinkedIn profile picture.
|
||||
|
||||
Example API Response:
|
||||
{
|
||||
"tmp_profile_pic_url": "https://media.licdn.com/dms/image/..."
|
||||
}
|
||||
"""
|
||||
|
||||
tmp_profile_pic_url: str = Field(
|
||||
..., description="URL of the profile picture", alias="tmp_profile_pic_url"
|
||||
)
|
||||
|
||||
@property
|
||||
def profile_picture_url(self) -> str:
|
||||
"""Backward compatibility property for profile_picture_url."""
|
||||
return self.tmp_profile_pic_url
|
||||
|
||||
|
||||
class EnrichlayerClient:
|
||||
"""Client for interacting with the Enrichlayer API."""
|
||||
|
||||
API_BASE_URL = "https://enrichlayer.com/api/v2"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
credentials: Optional[APIKeyCredentials] = None,
|
||||
custom_requests: Optional[Requests] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the Enrichlayer client.
|
||||
|
||||
Args:
|
||||
credentials: The credentials to use for authentication.
|
||||
custom_requests: Custom Requests instance for testing.
|
||||
"""
|
||||
if custom_requests:
|
||||
self._requests = custom_requests
|
||||
else:
|
||||
headers: dict[str, str] = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if credentials:
|
||||
headers["Authorization"] = (
|
||||
f"Bearer {credentials.api_key.get_secret_value()}"
|
||||
)
|
||||
|
||||
self._requests = Requests(
|
||||
extra_headers=headers,
|
||||
raise_for_status=False,
|
||||
)
|
||||
|
||||
async def _handle_response(self, response) -> Any:
|
||||
"""
|
||||
Handle API response and check for errors.
|
||||
|
||||
Args:
|
||||
response: The response object from the request.
|
||||
|
||||
Returns:
|
||||
The response data.
|
||||
|
||||
Raises:
|
||||
EnrichlayerAPIException: If the API request fails.
|
||||
"""
|
||||
if not response.ok:
|
||||
try:
|
||||
error_data = response.json()
|
||||
error_message = error_data.get("message", "")
|
||||
except JSONDecodeError:
|
||||
error_message = response.text
|
||||
|
||||
raise EnrichlayerAPIException(
|
||||
f"Enrichlayer API request failed ({response.status_code}): {error_message}",
|
||||
response.status_code,
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
async def fetch_profile(
|
||||
self,
|
||||
linkedin_url: str,
|
||||
fallback_to_cache: FallbackToCache = FallbackToCache.ON_ERROR,
|
||||
use_cache: UseCache = UseCache.IF_PRESENT,
|
||||
include_skills: bool = False,
|
||||
include_inferred_salary: bool = False,
|
||||
include_personal_email: bool = False,
|
||||
include_personal_contact_number: bool = False,
|
||||
include_social_media: bool = False,
|
||||
include_extra: bool = False,
|
||||
) -> PersonProfileResponse:
|
||||
"""
|
||||
Fetch a LinkedIn profile with optional parameters.
|
||||
|
||||
Args:
|
||||
linkedin_url: The LinkedIn profile URL to fetch.
|
||||
fallback_to_cache: Cache usage if live fetch fails ('on-error' or 'never').
|
||||
use_cache: Cache utilization ('if-present' or 'never').
|
||||
include_skills: Whether to include skills data.
|
||||
include_inferred_salary: Whether to include inferred salary data.
|
||||
include_personal_email: Whether to include personal email.
|
||||
include_personal_contact_number: Whether to include personal contact number.
|
||||
include_social_media: Whether to include social media profiles.
|
||||
include_extra: Whether to include additional data.
|
||||
|
||||
Returns:
|
||||
The LinkedIn profile data.
|
||||
|
||||
Raises:
|
||||
EnrichlayerAPIException: If the API request fails.
|
||||
"""
|
||||
params = {
|
||||
"url": linkedin_url,
|
||||
"fallback_to_cache": fallback_to_cache.value.lower(),
|
||||
"use_cache": use_cache.value.lower(),
|
||||
}
|
||||
|
||||
if include_skills:
|
||||
params["skills"] = "include"
|
||||
if include_inferred_salary:
|
||||
params["inferred_salary"] = "include"
|
||||
if include_personal_email:
|
||||
params["personal_email"] = "include"
|
||||
if include_personal_contact_number:
|
||||
params["personal_contact_number"] = "include"
|
||||
if include_social_media:
|
||||
params["twitter_profile_id"] = "include"
|
||||
params["facebook_profile_id"] = "include"
|
||||
params["github_profile_id"] = "include"
|
||||
if include_extra:
|
||||
params["extra"] = "include"
|
||||
|
||||
response = await self._requests.get(
|
||||
f"{self.API_BASE_URL}/profile", params=params
|
||||
)
|
||||
return PersonProfileResponse(**await self._handle_response(response))
|
||||
|
||||
async def lookup_person(
|
||||
self,
|
||||
first_name: str,
|
||||
company_domain: str,
|
||||
last_name: str | None = None,
|
||||
location: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
include_similarity_checks: bool = False,
|
||||
enrich_profile: bool = False,
|
||||
) -> PersonLookupResponse:
|
||||
"""
|
||||
Look up a LinkedIn profile by person's information.
|
||||
|
||||
Args:
|
||||
first_name: The person's first name.
|
||||
last_name: The person's last name.
|
||||
company_domain: The domain of the company they work for.
|
||||
location: The person's location.
|
||||
title: The person's job title.
|
||||
include_similarity_checks: Whether to include similarity checks.
|
||||
enrich_profile: Whether to enrich the profile.
|
||||
|
||||
Returns:
|
||||
The LinkedIn profile lookup result.
|
||||
|
||||
Raises:
|
||||
EnrichlayerAPIException: If the API request fails.
|
||||
"""
|
||||
params = {"first_name": first_name, "company_domain": company_domain}
|
||||
|
||||
if last_name:
|
||||
params["last_name"] = last_name
|
||||
if location:
|
||||
params["location"] = location
|
||||
if title:
|
||||
params["title"] = title
|
||||
if include_similarity_checks:
|
||||
params["similarity_checks"] = "include"
|
||||
if enrich_profile:
|
||||
params["enrich_profile"] = "enrich"
|
||||
|
||||
response = await self._requests.get(
|
||||
f"{self.API_BASE_URL}/profile/resolve", params=params
|
||||
)
|
||||
return PersonLookupResponse(**await self._handle_response(response))
|
||||
|
||||
async def lookup_role(
|
||||
self, role: str, company_name: str, enrich_profile: bool = False
|
||||
) -> RoleLookupResponse:
|
||||
"""
|
||||
Look up a LinkedIn profile by role in a company.
|
||||
|
||||
Args:
|
||||
role: The role title (e.g., CEO, CTO).
|
||||
company_name: The name of the company.
|
||||
enrich_profile: Whether to enrich the profile.
|
||||
|
||||
Returns:
|
||||
The LinkedIn profile lookup result.
|
||||
|
||||
Raises:
|
||||
EnrichlayerAPIException: If the API request fails.
|
||||
"""
|
||||
params = {
|
||||
"role": role,
|
||||
"company_name": company_name,
|
||||
}
|
||||
|
||||
if enrich_profile:
|
||||
params["enrich_profile"] = "enrich"
|
||||
|
||||
response = await self._requests.get(
|
||||
f"{self.API_BASE_URL}/find/company/role", params=params
|
||||
)
|
||||
return RoleLookupResponse(**await self._handle_response(response))
|
||||
|
||||
async def get_profile_picture(
|
||||
self, linkedin_profile_url: str
|
||||
) -> ProfilePictureResponse:
|
||||
"""
|
||||
Get a LinkedIn profile picture URL.
|
||||
|
||||
Args:
|
||||
linkedin_profile_url: The LinkedIn profile URL.
|
||||
|
||||
Returns:
|
||||
The profile picture URL.
|
||||
|
||||
Raises:
|
||||
EnrichlayerAPIException: If the API request fails.
|
||||
"""
|
||||
params = {
|
||||
"linkedin_person_profile_url": linkedin_profile_url,
|
||||
}
|
||||
|
||||
response = await self._requests.get(
|
||||
f"{self.API_BASE_URL}/person/profile-picture", params=params
|
||||
)
|
||||
return ProfilePictureResponse(**await self._handle_response(response))
|
||||
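For context on what this removed client did, a minimal usage sketch of EnrichlayerClient.fetch_profile as it appears in the deleted code (credentials are a placeholder; names come from the module above):

# assumes the imports from the deleted module above (EnrichlayerClient, APIKeyCredentials)
async def fetch_gates_profile(credentials: APIKeyCredentials) -> str | None:
    client = EnrichlayerClient(credentials=credentials)
    profile = await client.fetch_profile(
        linkedin_url="https://www.linkedin.com/in/williamhgates/",
        include_skills=True,
    )
    return profile.full_name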
@@ -1,34 +0,0 @@
"""
Authentication module for Enrichlayer API integration.

This module provides credential types and test credentials for the Enrichlayer API.
"""

from typing import Literal

from pydantic import SecretStr

from backend.data.model import APIKeyCredentials, CredentialsMetaInput
from backend.integrations.providers import ProviderName

# Define the type of credentials input expected for Enrichlayer API
EnrichlayerCredentialsInput = CredentialsMetaInput[
    Literal[ProviderName.ENRICHLAYER], Literal["api_key"]
]

# Mock credentials for testing Enrichlayer API integration
TEST_CREDENTIALS = APIKeyCredentials(
    id="1234a567-89bc-4def-ab12-3456cdef7890",
    provider="enrichlayer",
    api_key=SecretStr("mock-enrichlayer-api-key"),
    title="Mock Enrichlayer API key",
    expires_at=None,
)

# Dictionary representation of test credentials for input fields
TEST_CREDENTIALS_INPUT = {
    "provider": TEST_CREDENTIALS.provider,
    "id": TEST_CREDENTIALS.id,
    "type": TEST_CREDENTIALS.type,
    "title": TEST_CREDENTIALS.title,
}
@@ -1,527 +0,0 @@
|
||||
"""
|
||||
Block definitions for Enrichlayer API integration.
|
||||
|
||||
This module implements blocks for interacting with the Enrichlayer API,
|
||||
which provides access to LinkedIn profile data and related information.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
from ._api import (
|
||||
EnrichlayerClient,
|
||||
Experience,
|
||||
FallbackToCache,
|
||||
PersonLookupResponse,
|
||||
PersonProfileResponse,
|
||||
RoleLookupResponse,
|
||||
UseCache,
|
||||
)
|
||||
from ._auth import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT, EnrichlayerCredentialsInput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GetLinkedinProfileBlock(Block):
|
||||
"""Block to fetch LinkedIn profile data using Enrichlayer API."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
"""Input schema for GetLinkedinProfileBlock."""
|
||||
|
||||
linkedin_url: str = SchemaField(
|
||||
description="LinkedIn profile URL to fetch data from",
|
||||
placeholder="https://www.linkedin.com/in/username/",
|
||||
)
|
||||
fallback_to_cache: FallbackToCache = SchemaField(
|
||||
description="Cache usage if live fetch fails",
|
||||
default=FallbackToCache.ON_ERROR,
|
||||
advanced=True,
|
||||
)
|
||||
use_cache: UseCache = SchemaField(
|
||||
description="Cache utilization strategy",
|
||||
default=UseCache.IF_PRESENT,
|
||||
advanced=True,
|
||||
)
|
||||
include_skills: bool = SchemaField(
|
||||
description="Include skills data",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
include_inferred_salary: bool = SchemaField(
|
||||
description="Include inferred salary data",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
include_personal_email: bool = SchemaField(
|
||||
description="Include personal email",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
include_personal_contact_number: bool = SchemaField(
|
||||
description="Include personal contact number",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
include_social_media: bool = SchemaField(
|
||||
description="Include social media profiles",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
include_extra: bool = SchemaField(
|
||||
description="Include additional data",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
credentials: EnrichlayerCredentialsInput = CredentialsField(
|
||||
description="Enrichlayer API credentials"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
"""Output schema for GetLinkedinProfileBlock."""
|
||||
|
||||
profile: PersonProfileResponse = SchemaField(
|
||||
description="LinkedIn profile data"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize GetLinkedinProfileBlock."""
|
||||
super().__init__(
|
||||
id="f6e0ac73-4f1d-4acb-b4b7-b67066c5984e",
|
||||
description="Fetch LinkedIn profile data using Enrichlayer",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=GetLinkedinProfileBlock.Input,
|
||||
output_schema=GetLinkedinProfileBlock.Output,
|
||||
test_input={
|
||||
"linkedin_url": "https://www.linkedin.com/in/williamhgates/",
|
||||
"include_skills": True,
|
||||
"include_social_media": True,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"profile",
|
||||
PersonProfileResponse(
|
||||
public_identifier="williamhgates",
|
||||
full_name="Bill Gates",
|
||||
occupation="Co-chair at Bill & Melinda Gates Foundation",
|
||||
experiences=[
|
||||
Experience(
|
||||
company="Bill & Melinda Gates Foundation",
|
||||
title="Co-chair",
|
||||
starts_at={"year": 2000},
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
],
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_mock={
|
||||
"_fetch_profile": lambda *args, **kwargs: PersonProfileResponse(
|
||||
public_identifier="williamhgates",
|
||||
full_name="Bill Gates",
|
||||
occupation="Co-chair at Bill & Melinda Gates Foundation",
|
||||
experiences=[
|
||||
Experience(
|
||||
company="Bill & Melinda Gates Foundation",
|
||||
title="Co-chair",
|
||||
starts_at={"year": 2000},
|
||||
)
|
||||
],
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _fetch_profile(
|
||||
credentials: APIKeyCredentials,
|
||||
linkedin_url: str,
|
||||
fallback_to_cache: FallbackToCache = FallbackToCache.ON_ERROR,
|
||||
use_cache: UseCache = UseCache.IF_PRESENT,
|
||||
include_skills: bool = False,
|
||||
include_inferred_salary: bool = False,
|
||||
include_personal_email: bool = False,
|
||||
include_personal_contact_number: bool = False,
|
||||
include_social_media: bool = False,
|
||||
include_extra: bool = False,
|
||||
):
|
||||
client = EnrichlayerClient(credentials)
|
||||
profile = await client.fetch_profile(
|
||||
linkedin_url=linkedin_url,
|
||||
fallback_to_cache=fallback_to_cache,
|
||||
use_cache=use_cache,
|
||||
include_skills=include_skills,
|
||||
include_inferred_salary=include_inferred_salary,
|
||||
include_personal_email=include_personal_email,
|
||||
include_personal_contact_number=include_personal_contact_number,
|
||||
include_social_media=include_social_media,
|
||||
include_extra=include_extra,
|
||||
)
|
||||
return profile
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
"""
|
||||
Run the block to fetch LinkedIn profile data.
|
||||
|
||||
Args:
|
||||
input_data: Input parameters for the block
|
||||
credentials: API key credentials for Enrichlayer
|
||||
**kwargs: Additional keyword arguments
|
||||
|
||||
Yields:
|
||||
Tuples of (output_name, output_value)
|
||||
"""
|
||||
try:
|
||||
profile = await self._fetch_profile(
|
||||
credentials=credentials,
|
||||
linkedin_url=input_data.linkedin_url,
|
||||
fallback_to_cache=input_data.fallback_to_cache,
|
||||
use_cache=input_data.use_cache,
|
||||
include_skills=input_data.include_skills,
|
||||
include_inferred_salary=input_data.include_inferred_salary,
|
||||
include_personal_email=input_data.include_personal_email,
|
||||
include_personal_contact_number=input_data.include_personal_contact_number,
|
||||
include_social_media=input_data.include_social_media,
|
||||
include_extra=input_data.include_extra,
|
||||
)
|
||||
yield "profile", profile
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching LinkedIn profile: {str(e)}")
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class LinkedinPersonLookupBlock(Block):
|
||||
"""Block to look up LinkedIn profiles by person's information using Enrichlayer API."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
"""Input schema for LinkedinPersonLookupBlock."""
|
||||
|
||||
first_name: str = SchemaField(
|
||||
description="Person's first name",
|
||||
placeholder="John",
|
||||
advanced=False,
|
||||
)
|
||||
last_name: str | None = SchemaField(
|
||||
description="Person's last name",
|
||||
placeholder="Doe",
|
||||
default=None,
|
||||
advanced=False,
|
||||
)
|
||||
company_domain: str = SchemaField(
|
||||
description="Domain of the company they work for (optional)",
|
||||
placeholder="example.com",
|
||||
advanced=False,
|
||||
)
|
||||
location: Optional[str] = SchemaField(
|
||||
description="Person's location (optional)",
|
||||
placeholder="San Francisco",
|
||||
default=None,
|
||||
)
|
||||
title: Optional[str] = SchemaField(
|
||||
description="Person's job title (optional)",
|
||||
placeholder="CEO",
|
||||
default=None,
|
||||
)
|
||||
include_similarity_checks: bool = SchemaField(
|
||||
description="Include similarity checks",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
enrich_profile: bool = SchemaField(
|
||||
description="Enrich the profile with additional data",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
credentials: EnrichlayerCredentialsInput = CredentialsField(
|
||||
description="Enrichlayer API credentials"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
"""Output schema for LinkedinPersonLookupBlock."""
|
||||
|
||||
lookup_result: PersonLookupResponse = SchemaField(
|
||||
description="LinkedIn profile lookup result"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize LinkedinPersonLookupBlock."""
|
||||
super().__init__(
|
||||
id="d237a98a-5c4b-4a1c-b9e3-e6f9a6c81df7",
|
||||
description="Look up LinkedIn profiles by person information using Enrichlayer",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=LinkedinPersonLookupBlock.Input,
|
||||
output_schema=LinkedinPersonLookupBlock.Output,
|
||||
test_input={
|
||||
"first_name": "Bill",
|
||||
"last_name": "Gates",
|
||||
"company_domain": "gatesfoundation.org",
|
||||
"include_similarity_checks": True,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"lookup_result",
|
||||
PersonLookupResponse(
|
||||
url="https://www.linkedin.com/in/williamhgates/",
|
||||
name_similarity_score=0.93,
|
||||
company_similarity_score=0.83,
|
||||
title_similarity_score=0.3,
|
||||
location_similarity_score=0.20,
|
||||
),
|
||||
)
|
||||
],
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_mock={
|
||||
"_lookup_person": lambda *args, **kwargs: PersonLookupResponse(
|
||||
url="https://www.linkedin.com/in/williamhgates/",
|
||||
name_similarity_score=0.93,
|
||||
company_similarity_score=0.83,
|
||||
title_similarity_score=0.3,
|
||||
location_similarity_score=0.20,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _lookup_person(
|
||||
credentials: APIKeyCredentials,
|
||||
first_name: str,
|
||||
company_domain: str,
|
||||
last_name: str | None = None,
|
||||
location: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
include_similarity_checks: bool = False,
|
||||
enrich_profile: bool = False,
|
||||
):
|
||||
client = EnrichlayerClient(credentials=credentials)
|
||||
lookup_result = await client.lookup_person(
|
||||
first_name=first_name,
|
||||
last_name=last_name,
|
||||
company_domain=company_domain,
|
||||
location=location,
|
||||
title=title,
|
||||
include_similarity_checks=include_similarity_checks,
|
||||
enrich_profile=enrich_profile,
|
||||
)
|
||||
return lookup_result
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
"""
|
||||
Run the block to look up LinkedIn profiles.
|
||||
|
||||
Args:
|
||||
input_data: Input parameters for the block
|
||||
credentials: API key credentials for Enrichlayer
|
||||
**kwargs: Additional keyword arguments
|
||||
|
||||
Yields:
|
||||
Tuples of (output_name, output_value)
|
||||
"""
|
||||
try:
|
||||
lookup_result = await self._lookup_person(
|
||||
credentials=credentials,
|
||||
first_name=input_data.first_name,
|
||||
last_name=input_data.last_name,
|
||||
company_domain=input_data.company_domain,
|
||||
location=input_data.location,
|
||||
title=input_data.title,
|
||||
include_similarity_checks=input_data.include_similarity_checks,
|
||||
enrich_profile=input_data.enrich_profile,
|
||||
)
|
||||
yield "lookup_result", lookup_result
|
||||
except Exception as e:
|
||||
logger.error(f"Error looking up LinkedIn profile: {str(e)}")
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class LinkedinRoleLookupBlock(Block):
|
||||
"""Block to look up LinkedIn profiles by role in a company using Enrichlayer API."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
"""Input schema for LinkedinRoleLookupBlock."""
|
||||
|
||||
role: str = SchemaField(
|
||||
description="Role title (e.g., CEO, CTO)",
|
||||
placeholder="CEO",
|
||||
)
|
||||
company_name: str = SchemaField(
|
||||
description="Name of the company",
|
||||
placeholder="Microsoft",
|
||||
)
|
||||
enrich_profile: bool = SchemaField(
|
||||
description="Enrich the profile with additional data",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
credentials: EnrichlayerCredentialsInput = CredentialsField(
|
||||
description="Enrichlayer API credentials"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
"""Output schema for LinkedinRoleLookupBlock."""
|
||||
|
||||
role_lookup_result: RoleLookupResponse = SchemaField(
|
||||
description="LinkedIn role lookup result"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize LinkedinRoleLookupBlock."""
|
||||
super().__init__(
|
||||
id="3b9fc742-06d4-49c7-b5ce-7e302dd7c8a7",
|
||||
description="Look up LinkedIn profiles by role in a company using Enrichlayer",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=LinkedinRoleLookupBlock.Input,
|
||||
output_schema=LinkedinRoleLookupBlock.Output,
|
||||
test_input={
|
||||
"role": "Co-chair",
|
||||
"company_name": "Gates Foundation",
|
||||
"enrich_profile": True,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"role_lookup_result",
|
||||
RoleLookupResponse(
|
||||
linkedin_profile_url="https://www.linkedin.com/in/williamhgates/",
|
||||
),
|
||||
)
|
||||
],
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_mock={
|
||||
"_lookup_role": lambda *args, **kwargs: RoleLookupResponse(
|
||||
linkedin_profile_url="https://www.linkedin.com/in/williamhgates/",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _lookup_role(
|
||||
credentials: APIKeyCredentials,
|
||||
role: str,
|
||||
company_name: str,
|
||||
enrich_profile: bool = False,
|
||||
):
|
||||
client = EnrichlayerClient(credentials=credentials)
|
||||
role_lookup_result = await client.lookup_role(
|
||||
role=role,
|
||||
company_name=company_name,
|
||||
enrich_profile=enrich_profile,
|
||||
)
|
||||
return role_lookup_result
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
"""
|
||||
Run the block to look up LinkedIn profiles by role.
|
||||
|
||||
Args:
|
||||
input_data: Input parameters for the block
|
||||
credentials: API key credentials for Enrichlayer
|
||||
**kwargs: Additional keyword arguments
|
||||
|
||||
Yields:
|
||||
Tuples of (output_name, output_value)
|
||||
"""
|
||||
try:
|
||||
role_lookup_result = await self._lookup_role(
|
||||
credentials=credentials,
|
||||
role=input_data.role,
|
||||
company_name=input_data.company_name,
|
||||
enrich_profile=input_data.enrich_profile,
|
||||
)
|
||||
yield "role_lookup_result", role_lookup_result
|
||||
except Exception as e:
|
||||
logger.error(f"Error looking up role in company: {str(e)}")
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GetLinkedinProfilePictureBlock(Block):
|
||||
"""Block to get LinkedIn profile pictures using Enrichlayer API."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
"""Input schema for GetLinkedinProfilePictureBlock."""
|
||||
|
||||
linkedin_profile_url: str = SchemaField(
|
||||
description="LinkedIn profile URL",
|
||||
placeholder="https://www.linkedin.com/in/username/",
|
||||
)
|
||||
credentials: EnrichlayerCredentialsInput = CredentialsField(
|
||||
description="Enrichlayer API credentials"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
"""Output schema for GetLinkedinProfilePictureBlock."""
|
||||
|
||||
profile_picture_url: MediaFileType = SchemaField(
|
||||
description="LinkedIn profile picture URL"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize GetLinkedinProfilePictureBlock."""
|
||||
super().__init__(
|
||||
id="68d5a942-9b3f-4e9a-b7c1-d96ea4321f0d",
|
||||
description="Get LinkedIn profile pictures using Enrichlayer",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=GetLinkedinProfilePictureBlock.Input,
|
||||
output_schema=GetLinkedinProfilePictureBlock.Output,
|
||||
test_input={
|
||||
"linkedin_profile_url": "https://www.linkedin.com/in/williamhgates/",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"profile_picture_url",
|
||||
"https://media.licdn.com/dms/image/C4D03AQFj-xjuXrLFSQ/profile-displayphoto-shrink_800_800/0/1576881858598?e=1686787200&v=beta&t=zrQC76QwsfQQIWthfOnrKRBMZ5D-qIAvzLXLmWgYvTk",
|
||||
)
|
||||
],
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_mock={
|
||||
"_get_profile_picture": lambda *args, **kwargs: "https://media.licdn.com/dms/image/C4D03AQFj-xjuXrLFSQ/profile-displayphoto-shrink_800_800/0/1576881858598?e=1686787200&v=beta&t=zrQC76QwsfQQIWthfOnrKRBMZ5D-qIAvzLXLmWgYvTk",
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _get_profile_picture(
|
||||
credentials: APIKeyCredentials, linkedin_profile_url: str
|
||||
):
|
||||
client = EnrichlayerClient(credentials=credentials)
|
||||
profile_picture_response = await client.get_profile_picture(
|
||||
linkedin_profile_url=linkedin_profile_url,
|
||||
)
|
||||
return profile_picture_response.profile_picture_url
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
"""
|
||||
Run the block to get LinkedIn profile pictures.
|
||||
|
||||
Args:
|
||||
input_data: Input parameters for the block
|
||||
credentials: API key credentials for Enrichlayer
|
||||
**kwargs: Additional keyword arguments
|
||||
|
||||
Yields:
|
||||
Tuples of (output_name, output_value)
|
||||
"""
|
||||
try:
|
||||
profile_picture = await self._get_profile_picture(
|
||||
credentials=credentials,
|
||||
linkedin_profile_url=input_data.linkedin_profile_url,
|
||||
)
|
||||
yield "profile_picture_url", profile_picture
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting profile picture: {str(e)}")
|
||||
yield "error", str(e)
|
||||
@@ -6,10 +6,10 @@ import hashlib
|
||||
import hmac
|
||||
from enum import Enum
|
||||
|
||||
from backend.data.model import Credentials
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
BaseWebhooksManager,
|
||||
Credentials,
|
||||
ProviderName,
|
||||
Requests,
|
||||
Webhook,
|
||||
@@ -51,9 +51,7 @@ class ExaWebhookManager(BaseWebhooksManager):
|
||||
WEBSET = "webset"
|
||||
|
||||
@classmethod
|
||||
async def validate_payload(
|
||||
cls, webhook: Webhook, request, credentials: Credentials | None
|
||||
) -> tuple[dict, str]:
|
||||
async def validate_payload(cls, webhook: Webhook, request) -> tuple[dict, str]:
|
||||
"""Validate incoming webhook payload and signature."""
|
||||
payload = await request.json()
|
||||
|
||||
|
||||
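The hunk above changes the `validate_payload` signature while keeping the `hmac` and `hashlib` imports used for signature checking. For readers unfamiliar with the pattern, a minimal, generic HMAC verification of a webhook payload looks roughly like the sketch below; the header value, secret source, and hash algorithm are illustrative assumptions, not the Exa-specific values:

```python
import hashlib
import hmac


def verify_webhook_signature(raw_body: bytes, signature_header: str, secret: str) -> bool:
    """Return True if the signature header matches the HMAC-SHA256 of the body.

    Header format, secret handling, and encoding are illustrative assumptions.
    """
    expected = hmac.new(secret.encode(), raw_body, hashlib.sha256).hexdigest()
    # compare_digest avoids timing side channels when comparing signatures
    return hmac.compare_digest(expected, signature_header)
```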
190 autogpt_platform/backend/backend/blocks/exa/answers.md Normal file
@@ -0,0 +1,190 @@
|
||||
API Reference
Answer
Get an LLM answer to a question informed by Exa search results. Fully compatible with OpenAI’s chat completions endpoint - docs here. /answer performs an Exa search and uses an LLM to generate either:

A direct answer for specific queries (e.g. “What is the capital of France?” would return “Paris”)
A detailed summary with citations for open-ended queries (e.g. “What is the state of AI in healthcare?” would return a summary with citations to relevant sources)

The response includes both the generated answer and the sources used to create it. The endpoint also supports streaming (stream=True), which returns tokens as they are generated.
|
||||
POST /answer
|
||||
|
||||
|
||||
|
||||
Authorizations
|
||||
|
||||
x-api-key
|
||||
string, header, required
|
||||
API key can be provided either via x-api-key header or Authorization header with Bearer scheme
|
||||
Body
|
||||
application/json
|
||||
|
||||
query
|
||||
string, required
|
||||
The question or query to answer.
|
||||
Minimum length: 1
|
||||
Example:
|
||||
"What is the latest valuation of SpaceX?"
|
||||
|
||||
stream
|
||||
boolean, default: false
|
||||
If true, the response is returned as a server-sent events (SSE) stream.
|
||||
|
||||
text
|
||||
boolean, default: false
|
||||
If true, the response includes full text content in the search results
|
||||
|
||||
model
|
||||
enum<string>, default: exa
|
||||
The search model to use for the answer. exa passes a single query to the search model, while exa-pro also passes two expanded queries.
|
||||
Available options: exa, exa-pro
|
||||
Response
|
||||
200
|
||||
application/json
|
||||
|
||||
OK
|
||||
|
||||
answer
|
||||
string
|
||||
The generated answer based on search results.
|
||||
Example:
|
||||
"$350 billion."
|
||||
|
||||
citations
|
||||
object[]
|
||||
Search results used to generate the answer.
|
||||
|
||||
|
||||
|
||||
costDollars
|
||||
object
|
||||
|
||||
|
||||
|
||||
|
||||
Python
|
||||
# pip install exa-py
|
||||
from exa_py import Exa
|
||||
exa = Exa('YOUR_EXA_API_KEY')
|
||||
|
||||
result = exa.answer(
|
||||
"What is the latest valuation of SpaceX?",
|
||||
text=True
|
||||
)
|
||||
|
||||
print(result)
|
||||
|
||||
200
|
||||
{
|
||||
"answer": "$350 billion.",
|
||||
"citations": [
|
||||
{
|
||||
"id": "https://www.theguardian.com/science/2024/dec/11/spacex-valued-at-350bn-as-company-agrees-to-buy-shares-from-employees",
|
||||
"url": "https://www.theguardian.com/science/2024/dec/11/spacex-valued-at-350bn-as-company-agrees-to-buy-shares-from-employees",
|
||||
"title": "SpaceX valued at $350bn as company agrees to buy shares from ...",
|
||||
"author": "Dan Milmon",
|
||||
"publishedDate": "2023-11-16T01:36:32.547Z",
|
||||
"text": "SpaceX valued at $350bn as company agrees to buy shares from ...",
|
||||
"image": "https://i.guim.co.uk/img/media/7cfee7e84b24b73c97a079c402642a333ad31e77/0_380_6176_3706/master/6176.jpg?width=1200&height=630&quality=85&auto=format&fit=crop&overlay-align=bottom%2Cleft&overlay-width=100p&overlay-base64=L2ltZy9zdGF0aWMvb3ZlcmxheXMvdGctZGVmYXVsdC5wbmc&enable=upscale&s=71ebb2fbf458c185229d02d380c01530",
|
||||
"favicon": "https://assets.guim.co.uk/static/frontend/icons/homescreen/apple-touch-icon.svg"
|
||||
}
|
||||
],
|
||||
"costDollars": {
|
||||
"total": 0.005,
|
||||
"breakDown": [
|
||||
{
|
||||
"search": 0.005,
|
||||
"contents": 0,
|
||||
"breakdown": {
|
||||
"keywordSearch": 0,
|
||||
"neuralSearch": 0.005,
|
||||
"contentText": 0,
|
||||
"contentHighlight": 0,
|
||||
"contentSummary": 0
|
||||
}
|
||||
}
|
||||
],
|
||||
"perRequestPrices": {
|
||||
"neuralSearch_1_25_results": 0.005,
|
||||
"neuralSearch_26_100_results": 0.025,
|
||||
"neuralSearch_100_plus_results": 1,
|
||||
"keywordSearch_1_100_results": 0.0025,
|
||||
"keywordSearch_100_plus_results": 3
|
||||
},
|
||||
"perPagePrices": {
|
||||
"contentText": 0.001,
|
||||
"contentHighlight": 0.001,
|
||||
"contentSummary": 0.001
|
||||
}
|
||||
}
|
||||
}
|
||||
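As a cross-check of the request and response shapes documented above, here is a minimal sketch that calls the answer endpoint directly over HTTP with the documented x-api-key header and body fields. The `requests` dependency, the placeholder API key, and the `api.exa.ai` host (taken from the websets URL used elsewhere in this diff) are assumptions; error handling is omitted:

```python
import requests  # assumed available; any HTTP client works

API_KEY = "YOUR_EXA_API_KEY"  # placeholder

response = requests.post(
    "https://api.exa.ai/answer",
    headers={"x-api-key": API_KEY, "Content-Type": "application/json"},
    json={
        "query": "What is the latest valuation of SpaceX?",
        "text": True,      # include full text content in the citations
        "model": "exa",    # or "exa-pro" to also pass expanded queries
        "stream": False,   # set True for a server-sent events (SSE) stream
    },
)
data = response.json()
print(data["answer"])
for citation in data.get("citations", []):
    print(citation["url"])
```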
@@ -119,3 +119,6 @@ class ExaAnswerBlock(Block):
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
yield "answer", ""
|
||||
yield "citations", []
|
||||
yield "cost_dollars", {}
|
||||
|
||||
@@ -1,247 +0,0 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# Enum definitions based on available options
|
||||
class WebsetStatus(str, Enum):
|
||||
IDLE = "idle"
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
PAUSED = "paused"
|
||||
|
||||
|
||||
class WebsetSearchStatus(str, Enum):
|
||||
CREATED = "created"
|
||||
# Add more if known, based on example it's "created"
|
||||
|
||||
|
||||
class ImportStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class ImportFormat(str, Enum):
|
||||
CSV = "csv"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class EnrichmentStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class EnrichmentFormat(str, Enum):
|
||||
TEXT = "text"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class MonitorStatus(str, Enum):
|
||||
ENABLED = "enabled"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class MonitorBehaviorType(str, Enum):
|
||||
SEARCH = "search"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class MonitorRunStatus(str, Enum):
|
||||
CREATED = "created"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class CanceledReason(str, Enum):
|
||||
WEBSET_DELETED = "webset_deleted"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class FailedReason(str, Enum):
|
||||
INVALID_FORMAT = "invalid_format"
|
||||
# Add more if known
|
||||
|
||||
|
||||
class Confidence(str, Enum):
|
||||
HIGH = "high"
|
||||
# Add more if known
|
||||
|
||||
|
||||
# Nested models
|
||||
|
||||
|
||||
class Entity(BaseModel):
|
||||
type: str
|
||||
|
||||
|
||||
class Criterion(BaseModel):
|
||||
description: str
|
||||
successRate: Optional[int] = None
|
||||
|
||||
|
||||
class ExcludeItem(BaseModel):
|
||||
source: str = Field(default="import")
|
||||
id: str
|
||||
|
||||
|
||||
class Relationship(BaseModel):
|
||||
definition: str
|
||||
limit: Optional[float] = None
|
||||
|
||||
|
||||
class ScopeItem(BaseModel):
|
||||
source: str = Field(default="import")
|
||||
id: str
|
||||
relationship: Optional[Relationship] = None
|
||||
|
||||
|
||||
class Progress(BaseModel):
|
||||
found: int
|
||||
analyzed: int
|
||||
completion: int
|
||||
timeLeft: int
|
||||
|
||||
|
||||
class Bounds(BaseModel):
|
||||
min: int
|
||||
max: int
|
||||
|
||||
|
||||
class Expected(BaseModel):
|
||||
total: int
|
||||
confidence: str = Field(default="high") # Use str or Confidence enum
|
||||
bounds: Bounds
|
||||
|
||||
|
||||
class Recall(BaseModel):
|
||||
expected: Expected
|
||||
reasoning: str
|
||||
|
||||
|
||||
class WebsetSearch(BaseModel):
|
||||
id: str
|
||||
object: str = Field(default="webset_search")
|
||||
status: str = Field(default="created") # Or use WebsetSearchStatus
|
||||
websetId: str
|
||||
query: str
|
||||
entity: Entity
|
||||
criteria: List[Criterion]
|
||||
count: int
|
||||
behavior: str = Field(default="override")
|
||||
exclude: List[ExcludeItem]
|
||||
scope: List[ScopeItem]
|
||||
progress: Progress
|
||||
recall: Recall
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
canceledAt: Optional[datetime] = None
|
||||
canceledReason: Optional[str] = Field(default=None) # Or use CanceledReason
|
||||
createdAt: datetime
|
||||
updatedAt: datetime
|
||||
|
||||
|
||||
class ImportEntity(BaseModel):
|
||||
type: str
|
||||
|
||||
|
||||
class Import(BaseModel):
|
||||
id: str
|
||||
object: str = Field(default="import")
|
||||
status: str = Field(default="pending") # Or use ImportStatus
|
||||
format: str = Field(default="csv") # Or use ImportFormat
|
||||
entity: ImportEntity
|
||||
title: str
|
||||
count: int
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
failedReason: Optional[str] = Field(default=None) # Or use FailedReason
|
||||
failedAt: Optional[datetime] = None
|
||||
failedMessage: Optional[str] = None
|
||||
createdAt: datetime
|
||||
updatedAt: datetime
|
||||
|
||||
|
||||
class Option(BaseModel):
|
||||
label: str
|
||||
|
||||
|
||||
class WebsetEnrichment(BaseModel):
|
||||
id: str
|
||||
object: str = Field(default="webset_enrichment")
|
||||
status: str = Field(default="pending") # Or use EnrichmentStatus
|
||||
websetId: str
|
||||
title: str
|
||||
description: str
|
||||
format: str = Field(default="text") # Or use EnrichmentFormat
|
||||
options: List[Option]
|
||||
instructions: str
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
createdAt: datetime
|
||||
updatedAt: datetime
|
||||
|
||||
|
||||
class Cadence(BaseModel):
|
||||
cron: str
|
||||
timezone: str = Field(default="Etc/UTC")
|
||||
|
||||
|
||||
class BehaviorConfig(BaseModel):
|
||||
query: Optional[str] = None
|
||||
criteria: Optional[List[Criterion]] = None
|
||||
entity: Optional[Entity] = None
|
||||
count: Optional[int] = None
|
||||
behavior: Optional[str] = Field(default=None)
|
||||
|
||||
|
||||
class Behavior(BaseModel):
|
||||
type: str = Field(default="search") # Or use MonitorBehaviorType
|
||||
config: BehaviorConfig
|
||||
|
||||
|
||||
class MonitorRun(BaseModel):
|
||||
id: str
|
||||
object: str = Field(default="monitor_run")
|
||||
status: str = Field(default="created") # Or use MonitorRunStatus
|
||||
monitorId: str
|
||||
type: str = Field(default="search")
|
||||
completedAt: Optional[datetime] = None
|
||||
failedAt: Optional[datetime] = None
|
||||
failedReason: Optional[str] = None
|
||||
canceledAt: Optional[datetime] = None
|
||||
createdAt: datetime
|
||||
updatedAt: datetime
|
||||
|
||||
|
||||
class Monitor(BaseModel):
|
||||
id: str
|
||||
object: str = Field(default="monitor")
|
||||
status: str = Field(default="enabled") # Or use MonitorStatus
|
||||
websetId: str
|
||||
cadence: Cadence
|
||||
behavior: Behavior
|
||||
lastRun: Optional[MonitorRun] = None
|
||||
nextRunAt: Optional[datetime] = None
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
createdAt: datetime
|
||||
updatedAt: datetime
|
||||
|
||||
|
||||
class Webset(BaseModel):
|
||||
id: str
|
||||
object: str = Field(default="webset")
|
||||
status: WebsetStatus
|
||||
externalId: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
searches: List[WebsetSearch]
|
||||
imports: List[Import]
|
||||
enrichments: List[WebsetEnrichment]
|
||||
monitors: List[Monitor]
|
||||
streams: List[Any]
|
||||
createdAt: datetime
|
||||
updatedAt: datetime
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ListWebsets(BaseModel):
|
||||
data: List[Webset]
|
||||
hasMore: bool
|
||||
nextCursor: Optional[str] = None
|
||||
@@ -114,7 +114,6 @@ class ExaWebsetWebhookBlock(Block):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
disabled=True,
|
||||
id="d0204ed8-8b81-408d-8b8d-ed087a546228",
|
||||
description="Receive webhook notifications for Exa webset events",
|
||||
categories={BlockCategory.INPUT},
|
||||
|
||||
1004 autogpt_platform/backend/backend/blocks/exa/webset_webhook.md Normal file
File diff suppressed because it is too large
@@ -1,33 +1,7 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Dict, List, Optional
|
||||
|
||||
from exa_py import Exa
|
||||
from exa_py.websets.types import (
|
||||
CreateCriterionParameters,
|
||||
CreateEnrichmentParameters,
|
||||
CreateWebsetParameters,
|
||||
CreateWebsetParametersSearch,
|
||||
ExcludeItem,
|
||||
Format,
|
||||
ImportItem,
|
||||
ImportSource,
|
||||
Option,
|
||||
ScopeItem,
|
||||
ScopeRelationship,
|
||||
ScopeSourceType,
|
||||
WebsetArticleEntity,
|
||||
WebsetCompanyEntity,
|
||||
WebsetCustomEntity,
|
||||
WebsetPersonEntity,
|
||||
WebsetResearchPaperEntity,
|
||||
WebsetStatus,
|
||||
)
|
||||
from pydantic import Field
|
||||
from typing import Any, Optional
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
BaseModel,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
@@ -38,69 +12,7 @@ from backend.sdk import (
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
|
||||
|
||||
class SearchEntityType(str, Enum):
|
||||
COMPANY = "company"
|
||||
PERSON = "person"
|
||||
ARTICLE = "article"
|
||||
RESEARCH_PAPER = "research_paper"
|
||||
CUSTOM = "custom"
|
||||
AUTO = "auto"
|
||||
|
||||
|
||||
class SearchType(str, Enum):
|
||||
IMPORT = "import"
|
||||
WEBSET = "webset"
|
||||
|
||||
|
||||
class EnrichmentFormat(str, Enum):
|
||||
TEXT = "text"
|
||||
DATE = "date"
|
||||
NUMBER = "number"
|
||||
OPTIONS = "options"
|
||||
EMAIL = "email"
|
||||
PHONE = "phone"
|
||||
|
||||
|
||||
class Webset(BaseModel):
|
||||
id: str
|
||||
status: WebsetStatus | None = Field(..., title="WebsetStatus")
|
||||
"""
|
||||
The status of the webset
|
||||
"""
|
||||
external_id: Annotated[Optional[str], Field(alias="externalId")] = None
|
||||
"""
|
||||
The external identifier for the webset
|
||||
NOTE: Returning dict to avoid ui crashing due to nested objects
|
||||
"""
|
||||
searches: List[dict[str, Any]] | None = None
|
||||
"""
|
||||
The searches that have been performed on the webset.
|
||||
NOTE: Returning dict to avoid ui crashing due to nested objects
|
||||
"""
|
||||
enrichments: List[dict[str, Any]] | None = None
|
||||
"""
|
||||
The Enrichments to apply to the Webset Items.
|
||||
NOTE: Returning dict to avoid ui crashing due to nested objects
|
||||
"""
|
||||
monitors: List[dict[str, Any]] | None = None
|
||||
"""
|
||||
The Monitors for the Webset.
|
||||
NOTE: Returning dict to avoid ui crashing due to nested objects
|
||||
"""
|
||||
metadata: Optional[Dict[str, Any]] = {}
|
||||
"""
|
||||
Set of key-value pairs you want to associate with this object.
|
||||
"""
|
||||
created_at: Annotated[datetime, Field(alias="createdAt")] | None = None
|
||||
"""
|
||||
The date and time the webset was created
|
||||
"""
|
||||
updated_at: Annotated[datetime, Field(alias="updatedAt")] | None = None
|
||||
"""
|
||||
The date and time the webset was last updated
|
||||
"""
|
||||
from .helpers import WebsetEnrichmentConfig, WebsetSearchConfig
|
||||
|
||||
|
||||
class ExaCreateWebsetBlock(Block):
|
||||
@@ -108,121 +20,40 @@ class ExaCreateWebsetBlock(Block):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
|
||||
# Search parameters (flattened)
|
||||
search_query: str = SchemaField(
|
||||
description="Your search query. Use this to describe what you are looking for. Any URL provided will be crawled and used as context for the search.",
|
||||
placeholder="Marketing agencies based in the US, that focus on consumer products",
|
||||
search: WebsetSearchConfig = SchemaField(
|
||||
description="Initial search configuration for the Webset"
|
||||
)
|
||||
search_count: Optional[int] = SchemaField(
|
||||
default=10,
|
||||
description="Number of items the search will attempt to find. The actual number of items found may be less than this number depending on the search complexity.",
|
||||
ge=1,
|
||||
le=1000,
|
||||
)
|
||||
search_entity_type: SearchEntityType = SchemaField(
|
||||
default=SearchEntityType.AUTO,
|
||||
description="Entity type: 'company', 'person', 'article', 'research_paper', or 'custom'. If not provided, we automatically detect the entity from the query.",
|
||||
advanced=True,
|
||||
)
|
||||
search_entity_description: Optional[str] = SchemaField(
|
||||
enrichments: Optional[list[WebsetEnrichmentConfig]] = SchemaField(
|
||||
default=None,
|
||||
description="Description for custom entity type (required when search_entity_type is 'custom')",
|
||||
description="Enrichments to apply to Webset items",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# Search criteria (flattened)
|
||||
search_criteria: list[str] = SchemaField(
|
||||
default_factory=list,
|
||||
description="List of criteria descriptions that every item will be evaluated against. If not provided, we automatically detect the criteria from the query.",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# Search exclude sources (flattened)
|
||||
search_exclude_sources: list[str] = SchemaField(
|
||||
default_factory=list,
|
||||
description="List of source IDs (imports or websets) to exclude from search results",
|
||||
advanced=True,
|
||||
)
|
||||
search_exclude_types: list[SearchType] = SchemaField(
|
||||
default_factory=list,
|
||||
description="List of source types corresponding to exclude sources ('import' or 'webset')",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# Search scope sources (flattened)
|
||||
search_scope_sources: list[str] = SchemaField(
|
||||
default_factory=list,
|
||||
description="List of source IDs (imports or websets) to limit search scope to",
|
||||
advanced=True,
|
||||
)
|
||||
search_scope_types: list[SearchType] = SchemaField(
|
||||
default_factory=list,
|
||||
description="List of source types corresponding to scope sources ('import' or 'webset')",
|
||||
advanced=True,
|
||||
)
|
||||
search_scope_relationships: list[str] = SchemaField(
|
||||
default_factory=list,
|
||||
description="List of relationship definitions for hop searches (optional, one per scope source)",
|
||||
advanced=True,
|
||||
)
|
||||
search_scope_relationship_limits: list[int] = SchemaField(
|
||||
default_factory=list,
|
||||
description="List of limits on the number of related entities to find (optional, one per scope relationship)",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# Import parameters (flattened)
|
||||
import_sources: list[str] = SchemaField(
|
||||
default_factory=list,
|
||||
description="List of source IDs to import from",
|
||||
advanced=True,
|
||||
)
|
||||
import_types: list[SearchType] = SchemaField(
|
||||
default_factory=list,
|
||||
description="List of source types corresponding to import sources ('import' or 'webset')",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# Enrichment parameters (flattened)
|
||||
enrichment_descriptions: list[str] = SchemaField(
|
||||
default_factory=list,
|
||||
description="List of enrichment task descriptions to perform on each webset item",
|
||||
advanced=True,
|
||||
)
|
||||
enrichment_formats: list[EnrichmentFormat] = SchemaField(
|
||||
default_factory=list,
|
||||
description="List of formats for enrichment responses ('text', 'date', 'number', 'options', 'email', 'phone'). If not specified, we automatically select the best format.",
|
||||
advanced=True,
|
||||
)
|
||||
enrichment_options: list[list[str]] = SchemaField(
|
||||
default_factory=list,
|
||||
description="List of option lists for enrichments with 'options' format. Each inner list contains the option labels.",
|
||||
advanced=True,
|
||||
)
|
||||
enrichment_metadata: list[dict] = SchemaField(
|
||||
default_factory=list,
|
||||
description="List of metadata dictionaries for enrichments",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# Webset metadata
|
||||
external_id: Optional[str] = SchemaField(
|
||||
default=None,
|
||||
description="External identifier for the webset. You can use this to reference the webset by your own internal identifiers.",
|
||||
description="External identifier for the webset",
|
||||
placeholder="my-webset-123",
|
||||
advanced=True,
|
||||
)
|
||||
metadata: Optional[dict] = SchemaField(
|
||||
default_factory=dict,
|
||||
default=None,
|
||||
description="Key-value pairs to associate with this webset",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
webset: Webset = SchemaField(
|
||||
webset_id: str = SchemaField(
|
||||
description="The unique identifier for the created webset"
|
||||
)
|
||||
status: str = SchemaField(description="The status of the webset")
|
||||
external_id: Optional[str] = SchemaField(
|
||||
description="The external identifier for the webset", default=None
|
||||
)
|
||||
created_at: str = SchemaField(
|
||||
description="The date and time the webset was created"
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the request failed", default=""
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -236,171 +67,44 @@ class ExaCreateWebsetBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
url = "https://api.exa.ai/websets/v0/websets"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": credentials.api_key.get_secret_value(),
|
||||
}
|
||||
|
||||
exa = Exa(credentials.api_key.get_secret_value())
|
||||
# Build the payload
|
||||
payload: dict[str, Any] = {
|
||||
"search": input_data.search.model_dump(exclude_none=True),
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Build entity (if explicitly provided)
|
||||
# ------------------------------------------------------------
|
||||
entity = None
|
||||
if input_data.search_entity_type == SearchEntityType.COMPANY:
|
||||
entity = WebsetCompanyEntity(type="company")
|
||||
elif input_data.search_entity_type == SearchEntityType.PERSON:
|
||||
entity = WebsetPersonEntity(type="person")
|
||||
elif input_data.search_entity_type == SearchEntityType.ARTICLE:
|
||||
entity = WebsetArticleEntity(type="article")
|
||||
elif input_data.search_entity_type == SearchEntityType.RESEARCH_PAPER:
|
||||
entity = WebsetResearchPaperEntity(type="research_paper")
|
||||
elif (
|
||||
input_data.search_entity_type == SearchEntityType.CUSTOM
|
||||
and input_data.search_entity_description
|
||||
):
|
||||
entity = WebsetCustomEntity(
|
||||
type="custom", description=input_data.search_entity_description
|
||||
)
|
||||
# Convert enrichments to API format
|
||||
if input_data.enrichments:
|
||||
enrichments_data = []
|
||||
for enrichment in input_data.enrichments:
|
||||
enrichments_data.append(enrichment.model_dump(exclude_none=True))
|
||||
payload["enrichments"] = enrichments_data
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Build criteria list
|
||||
# ------------------------------------------------------------
|
||||
criteria = None
|
||||
if input_data.search_criteria:
|
||||
criteria = [
|
||||
CreateCriterionParameters(description=item)
|
||||
for item in input_data.search_criteria
|
||||
]
|
||||
if input_data.external_id:
|
||||
payload["externalId"] = input_data.external_id
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Build exclude sources list
|
||||
# ------------------------------------------------------------
|
||||
exclude_items = None
|
||||
if input_data.search_exclude_sources:
|
||||
exclude_items = []
|
||||
for idx, src_id in enumerate(input_data.search_exclude_sources):
|
||||
src_type = None
|
||||
if input_data.search_exclude_types and idx < len(
|
||||
input_data.search_exclude_types
|
||||
):
|
||||
src_type = input_data.search_exclude_types[idx]
|
||||
# Default to IMPORT if type missing
|
||||
if src_type == SearchType.WEBSET:
|
||||
source_enum = ImportSource.webset
|
||||
else:
|
||||
source_enum = ImportSource.import_
|
||||
exclude_items.append(ExcludeItem(source=source_enum, id=src_id))
|
||||
if input_data.metadata:
|
||||
payload["metadata"] = input_data.metadata
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Build scope list
|
||||
# ------------------------------------------------------------
|
||||
scope_items = None
|
||||
if input_data.search_scope_sources:
|
||||
scope_items = []
|
||||
for idx, src_id in enumerate(input_data.search_scope_sources):
|
||||
src_type = None
|
||||
if input_data.search_scope_types and idx < len(
|
||||
input_data.search_scope_types
|
||||
):
|
||||
src_type = input_data.search_scope_types[idx]
|
||||
relationship = None
|
||||
if input_data.search_scope_relationships and idx < len(
|
||||
input_data.search_scope_relationships
|
||||
):
|
||||
rel_def = input_data.search_scope_relationships[idx]
|
||||
lim = None
|
||||
if input_data.search_scope_relationship_limits and idx < len(
|
||||
input_data.search_scope_relationship_limits
|
||||
):
|
||||
lim = input_data.search_scope_relationship_limits[idx]
|
||||
relationship = ScopeRelationship(definition=rel_def, limit=lim)
|
||||
if src_type == SearchType.WEBSET:
|
||||
src_enum = ScopeSourceType.webset
|
||||
else:
|
||||
src_enum = ScopeSourceType.import_
|
||||
scope_items.append(
|
||||
ScopeItem(source=src_enum, id=src_id, relationship=relationship)
|
||||
)
|
||||
try:
|
||||
response = await Requests().post(url, headers=headers, json=payload)
|
||||
data = response.json()
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Assemble search parameters (only if a query is provided)
|
||||
# ------------------------------------------------------------
|
||||
search_params = None
|
||||
if input_data.search_query:
|
||||
search_params = CreateWebsetParametersSearch(
|
||||
query=input_data.search_query,
|
||||
count=input_data.search_count,
|
||||
entity=entity,
|
||||
criteria=criteria,
|
||||
exclude=exclude_items,
|
||||
scope=scope_items,
|
||||
)
|
||||
yield "webset_id", data.get("id", "")
|
||||
yield "status", data.get("status", "")
|
||||
yield "external_id", data.get("externalId")
|
||||
yield "created_at", data.get("createdAt", "")
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Build imports list
|
||||
# ------------------------------------------------------------
|
||||
imports_params = None
|
||||
if input_data.import_sources:
|
||||
imports_params = []
|
||||
for idx, src_id in enumerate(input_data.import_sources):
|
||||
src_type = None
|
||||
if input_data.import_types and idx < len(input_data.import_types):
|
||||
src_type = input_data.import_types[idx]
|
||||
if src_type == SearchType.WEBSET:
|
||||
source_enum = ImportSource.webset
|
||||
else:
|
||||
source_enum = ImportSource.import_
|
||||
imports_params.append(ImportItem(source=source_enum, id=src_id))
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Build enrichment list
|
||||
# ------------------------------------------------------------
|
||||
enrichments_params = None
|
||||
if input_data.enrichment_descriptions:
|
||||
enrichments_params = []
|
||||
for idx, desc in enumerate(input_data.enrichment_descriptions):
|
||||
fmt = None
|
||||
if input_data.enrichment_formats and idx < len(
|
||||
input_data.enrichment_formats
|
||||
):
|
||||
fmt_enum = input_data.enrichment_formats[idx]
|
||||
if fmt_enum is not None:
|
||||
fmt = Format(
|
||||
fmt_enum.value if isinstance(fmt_enum, Enum) else fmt_enum
|
||||
)
|
||||
options_list = None
|
||||
if input_data.enrichment_options and idx < len(
|
||||
input_data.enrichment_options
|
||||
):
|
||||
raw_opts = input_data.enrichment_options[idx]
|
||||
if raw_opts:
|
||||
options_list = [Option(label=o) for o in raw_opts]
|
||||
metadata_obj = None
|
||||
if input_data.enrichment_metadata and idx < len(
|
||||
input_data.enrichment_metadata
|
||||
):
|
||||
metadata_obj = input_data.enrichment_metadata[idx]
|
||||
enrichments_params.append(
|
||||
CreateEnrichmentParameters(
|
||||
description=desc,
|
||||
format=fmt,
|
||||
options=options_list,
|
||||
metadata=metadata_obj,
|
||||
)
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Create the webset
|
||||
# ------------------------------------------------------------
|
||||
webset = exa.websets.create(
|
||||
params=CreateWebsetParameters(
|
||||
search=search_params,
|
||||
imports=imports_params,
|
||||
enrichments=enrichments_params,
|
||||
external_id=input_data.external_id,
|
||||
metadata=input_data.metadata,
|
||||
)
|
||||
)
|
||||
|
||||
# Use alias field names returned from Exa SDK so that nested models validate correctly
|
||||
yield "webset", Webset.model_validate(webset.model_dump(by_alias=True))
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
yield "webset_id", ""
|
||||
yield "status", ""
|
||||
yield "created_at", ""
|
||||
|
||||
|
||||
class ExaUpdateWebsetBlock(Block):
|
||||
@@ -479,11 +183,6 @@ class ExaListWebsetsBlock(Block):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="The Exa integration requires an API Key."
|
||||
)
|
||||
trigger: Any | None = SchemaField(
|
||||
default=None,
|
||||
description="Trigger for the webset, value is ignored!",
|
||||
advanced=False,
|
||||
)
|
||||
cursor: Optional[str] = SchemaField(
|
||||
default=None,
|
||||
description="Cursor for pagination through results",
|
||||
@@ -498,9 +197,7 @@ class ExaListWebsetsBlock(Block):
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
websets: list[Webset] = SchemaField(
|
||||
description="List of websets", default_factory=list
|
||||
)
|
||||
websets: list = SchemaField(description="List of websets", default_factory=list)
|
||||
has_more: bool = SchemaField(
|
||||
description="Whether there are more results to paginate through",
|
||||
default=False,
|
||||
@@ -558,6 +255,9 @@ class ExaGetWebsetBlock(Block):
|
||||
description="The ID or external ID of the Webset to retrieve",
|
||||
placeholder="webset-id-or-external-id",
|
||||
)
|
||||
expand_items: bool = SchemaField(
|
||||
default=False, description="Include items in the response", advanced=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
webset_id: str = SchemaField(description="The unique identifier for the webset")
|
||||
@@ -609,8 +309,12 @@ class ExaGetWebsetBlock(Block):
|
||||
"x-api-key": credentials.api_key.get_secret_value(),
|
||||
}
|
||||
|
||||
params = {}
|
||||
if input_data.expand_items:
|
||||
params["expand[]"] = "items"
|
||||
|
||||
try:
|
||||
response = await Requests().get(url, headers=headers)
|
||||
response = await Requests().get(url, headers=headers, params=params)
|
||||
data = response.json()
|
||||
|
||||
yield "webset_id", data.get("id", "")
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
# Example Blocks Deployment Guide
|
||||
|
||||
## Overview
|
||||
|
||||
Example blocks are disabled by default in production environments to keep the production block list clean and focused on real functionality. This guide explains how to control the visibility of example blocks.
|
||||
|
||||
## Configuration
|
||||
|
||||
Example blocks are controlled by the `ENABLE_EXAMPLE_BLOCKS` setting:
|
||||
|
||||
- **Default**: `false` (example blocks are hidden)
|
||||
- **Development**: Set to `true` to show example blocks
|
||||
|
||||
## How to Enable/Disable
|
||||
|
||||
### Method 1: Environment Variable (Recommended)
|
||||
|
||||
Add to your `.env` file:
|
||||
|
||||
```bash
|
||||
# Enable example blocks in development
|
||||
ENABLE_EXAMPLE_BLOCKS=true
|
||||
|
||||
# Disable example blocks in production (default)
|
||||
ENABLE_EXAMPLE_BLOCKS=false
|
||||
```
|
||||
|
||||
### Method 2: Configuration File
|
||||
|
||||
If you're using a `config.json` file:
|
||||
|
||||
```json
|
||||
{
|
||||
"enable_example_blocks": true
|
||||
}
|
||||
```
|
||||
|
||||
## Implementation Details
|
||||
|
||||
The setting is checked in `backend/blocks/__init__.py` during the block loading process:
|
||||
|
||||
1. The `load_all_blocks()` function reads the `enable_example_blocks` setting from `Config`
|
||||
2. If disabled (default), any Python files in the `examples/` directory are skipped
|
||||
3. If enabled, example blocks are loaded normally
|
||||
|
||||
## Production Deployment
|
||||
|
||||
For production deployments:
|
||||
|
||||
1. **Do not set** `ENABLE_EXAMPLE_BLOCKS` in your production `.env` file (it defaults to `false`)
|
||||
2. Or explicitly set `ENABLE_EXAMPLE_BLOCKS=false` for clarity
|
||||
3. Example blocks will not appear in the block list or be available for use
|
||||
|
||||
## Development Environment
|
||||
|
||||
For local development:
|
||||
|
||||
1. Set `ENABLE_EXAMPLE_BLOCKS=true` in your `.env` file
|
||||
2. Restart your backend server
|
||||
3. Example blocks will be available for testing and demonstration
|
||||
|
||||
## Verification
|
||||
|
||||
To verify the setting is working:
|
||||
|
||||
```python
|
||||
# Check current setting
|
||||
from backend.util.settings import Config
|
||||
config = Config()
|
||||
print(f"Example blocks enabled: {config.enable_example_blocks}")
|
||||
|
||||
# Check loaded blocks
|
||||
from backend.blocks import load_all_blocks
|
||||
blocks = load_all_blocks()
|
||||
example_blocks = [b for b in blocks.values() if 'examples' in b.__module__]
|
||||
print(f"Example blocks loaded: {len(example_blocks)}")
|
||||
```
|
||||
|
||||
## Security Note
|
||||
|
||||
Example blocks are for demonstration purposes only and may not follow production security standards. Always keep them disabled in production environments.
|
||||
@@ -1,8 +0,0 @@
|
||||
from backend.sdk import BlockCostType, ProviderBuilder
|
||||
|
||||
firecrawl = (
|
||||
ProviderBuilder("firecrawl")
|
||||
.with_api_key("FIRECRAWL_API_KEY", "Firecrawl API Key")
|
||||
.with_base_cost(1, BlockCostType.RUN)
|
||||
.build()
|
||||
)
|
||||
@@ -1,114 +0,0 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from firecrawl import FirecrawlApp, ScrapeOptions
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import firecrawl
|
||||
|
||||
|
||||
class ScrapeFormat(Enum):
|
||||
MARKDOWN = "markdown"
|
||||
HTML = "html"
|
||||
RAW_HTML = "rawHtml"
|
||||
LINKS = "links"
|
||||
SCREENSHOT = "screenshot"
|
||||
SCREENSHOT_FULL_PAGE = "screenshot@fullPage"
|
||||
JSON = "json"
|
||||
CHANGE_TRACKING = "changeTracking"
|
||||
|
||||
|
||||
class FirecrawlCrawlBlock(Block):
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = firecrawl.credentials_field()
|
||||
url: str = SchemaField(description="The URL to crawl")
|
||||
limit: int = SchemaField(description="The number of pages to crawl", default=10)
|
||||
only_main_content: bool = SchemaField(
|
||||
description="Only return the main content of the page excluding headers, navs, footers, etc.",
|
||||
default=True,
|
||||
)
|
||||
max_age: int = SchemaField(
|
||||
description="The maximum age of the page in milliseconds - default is 1 hour",
|
||||
default=3600000,
|
||||
)
|
||||
wait_for: int = SchemaField(
|
||||
description="Specify a delay in milliseconds before fetching the content, allowing the page sufficient time to load.",
|
||||
default=0,
|
||||
)
|
||||
formats: list[ScrapeFormat] = SchemaField(
|
||||
description="The format of the crawl", default=[ScrapeFormat.MARKDOWN]
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
data: list[dict[str, Any]] = SchemaField(description="The result of the crawl")
|
||||
markdown: str = SchemaField(description="The markdown of the crawl")
|
||||
html: str = SchemaField(description="The html of the crawl")
|
||||
raw_html: str = SchemaField(description="The raw html of the crawl")
|
||||
links: list[str] = SchemaField(description="The links of the crawl")
|
||||
screenshot: str = SchemaField(description="The screenshot of the crawl")
|
||||
screenshot_full_page: str = SchemaField(
|
||||
description="The screenshot full page of the crawl"
|
||||
)
|
||||
json_data: dict[str, Any] = SchemaField(
|
||||
description="The json data of the crawl"
|
||||
)
|
||||
change_tracking: dict[str, Any] = SchemaField(
|
||||
description="The change tracking of the crawl"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="bdbbaba0-03b7-4971-970e-699e2de6015e",
|
||||
description="Firecrawl crawls websites to extract comprehensive data while bypassing blockers.",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
# Sync call
|
||||
crawl_result = app.crawl_url(
|
||||
input_data.url,
|
||||
limit=input_data.limit,
|
||||
scrape_options=ScrapeOptions(
|
||||
formats=[format.value for format in input_data.formats],
|
||||
onlyMainContent=input_data.only_main_content,
|
||||
maxAge=input_data.max_age,
|
||||
waitFor=input_data.wait_for,
|
||||
),
|
||||
)
|
||||
yield "data", crawl_result.data
|
||||
|
||||
for data in crawl_result.data:
|
||||
for f in input_data.formats:
|
||||
if f == ScrapeFormat.MARKDOWN:
|
||||
yield "markdown", data.markdown
|
||||
elif f == ScrapeFormat.HTML:
|
||||
yield "html", data.html
|
||||
elif f == ScrapeFormat.RAW_HTML:
|
||||
yield "raw_html", data.rawHtml
|
||||
elif f == ScrapeFormat.LINKS:
|
||||
yield "links", data.links
|
||||
elif f == ScrapeFormat.SCREENSHOT:
|
||||
yield "screenshot", data.screenshot
|
||||
elif f == ScrapeFormat.SCREENSHOT_FULL_PAGE:
|
||||
yield "screenshot_full_page", data.screenshot
|
||||
elif f == ScrapeFormat.CHANGE_TRACKING:
|
||||
yield "change_tracking", data.changeTracking
|
||||
elif f == ScrapeFormat.JSON:
|
||||
yield "json", data.json
|
||||
@@ -1,66 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from firecrawl import FirecrawlApp
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockCost,
|
||||
BlockCostType,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
cost,
|
||||
)
|
||||
|
||||
from ._config import firecrawl
|
||||
|
||||
|
||||
@cost(BlockCost(2, BlockCostType.RUN))
|
||||
class FirecrawlExtractBlock(Block):
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = firecrawl.credentials_field()
|
||||
urls: list[str] = SchemaField(
|
||||
description="The URLs to crawl - at least one is required. Wildcards are supported. (/*)"
|
||||
)
|
||||
prompt: str | None = SchemaField(
|
||||
description="The prompt to use for the crawl", default=None, advanced=False
|
||||
)
|
||||
output_schema: dict | None = SchemaField(
|
||||
description="A Json Schema describing the output structure if more rigid structure is desired.",
|
||||
default=None,
|
||||
)
|
||||
enable_web_search: bool = SchemaField(
|
||||
description="When true, extraction can follow links outside the specified domain.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
data: dict[str, Any] = SchemaField(description="The result of the crawl")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d1774756-4d9e-40e6-bab1-47ec0ccd81b2",
|
||||
description="Firecrawl crawls websites to extract comprehensive data while bypassing blockers.",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
extract_result = app.extract(
|
||||
urls=input_data.urls,
|
||||
prompt=input_data.prompt,
|
||||
schema=input_data.output_schema,
|
||||
enable_web_search=input_data.enable_web_search,
|
||||
)
|
||||
|
||||
yield "data", extract_result.data
|
||||
@@ -1,46 +0,0 @@
|
||||
from firecrawl import FirecrawlApp
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import firecrawl
|
||||
|
||||
|
||||
class FirecrawlMapWebsiteBlock(Block):
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = firecrawl.credentials_field()
|
||||
|
||||
url: str = SchemaField(description="The website url to map")
|
||||
|
||||
class Output(BlockSchema):
|
||||
links: list[str] = SchemaField(description="The links of the website")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f0f43e2b-c943-48a0-a7f1-40136ca4d3b9",
|
||||
description="Firecrawl maps a website to extract all the links.",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
# Sync call
|
||||
map_result = app.map_url(
|
||||
url=input_data.url,
|
||||
)
|
||||
|
||||
yield "links", map_result.links
|
||||
@@ -1,109 +0,0 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from firecrawl import FirecrawlApp
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import firecrawl
|
||||
|
||||
|
||||
class ScrapeFormat(Enum):
|
||||
MARKDOWN = "markdown"
|
||||
HTML = "html"
|
||||
RAW_HTML = "rawHtml"
|
||||
LINKS = "links"
|
||||
SCREENSHOT = "screenshot"
|
||||
SCREENSHOT_FULL_PAGE = "screenshot@fullPage"
|
||||
JSON = "json"
|
||||
CHANGE_TRACKING = "changeTracking"
|
||||
|
||||
|
||||
class FirecrawlScrapeBlock(Block):
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = firecrawl.credentials_field()
|
||||
url: str = SchemaField(description="The URL to crawl")
|
||||
limit: int = SchemaField(description="The number of pages to crawl", default=10)
|
||||
only_main_content: bool = SchemaField(
|
||||
description="Only return the main content of the page excluding headers, navs, footers, etc.",
|
||||
default=True,
|
||||
)
|
||||
max_age: int = SchemaField(
|
||||
description="The maximum age of the page in milliseconds - default is 1 hour",
|
||||
default=3600000,
|
||||
)
|
||||
wait_for: int = SchemaField(
|
||||
description="Specify a delay in milliseconds before fetching the content, allowing the page sufficient time to load.",
|
||||
default=200,
|
||||
)
|
||||
formats: list[ScrapeFormat] = SchemaField(
|
||||
description="The format of the crawl", default=[ScrapeFormat.MARKDOWN]
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
data: dict[str, Any] = SchemaField(description="The result of the crawl")
|
||||
markdown: str = SchemaField(description="The markdown of the crawl")
|
||||
html: str = SchemaField(description="The html of the crawl")
|
||||
raw_html: str = SchemaField(description="The raw html of the crawl")
|
||||
links: list[str] = SchemaField(description="The links of the crawl")
|
||||
screenshot: str = SchemaField(description="The screenshot of the crawl")
|
||||
screenshot_full_page: str = SchemaField(
|
||||
description="The screenshot full page of the crawl"
|
||||
)
|
||||
json_data: dict[str, Any] = SchemaField(
|
||||
description="The json data of the crawl"
|
||||
)
|
||||
change_tracking: dict[str, Any] = SchemaField(
|
||||
description="The change tracking of the crawl"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="ac444320-cf5e-4697-b586-2604c17a3e75",
|
||||
description="Firecrawl scrapes a website to extract comprehensive data while bypassing blockers.",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
scrape_result = app.scrape_url(
|
||||
input_data.url,
|
||||
formats=[format.value for format in input_data.formats],
|
||||
only_main_content=input_data.only_main_content,
|
||||
max_age=input_data.max_age,
|
||||
wait_for=input_data.wait_for,
|
||||
)
|
||||
yield "data", scrape_result
|
||||
|
||||
for f in input_data.formats:
|
||||
if f == ScrapeFormat.MARKDOWN:
|
||||
yield "markdown", scrape_result.markdown
|
||||
elif f == ScrapeFormat.HTML:
|
||||
yield "html", scrape_result.html
|
||||
elif f == ScrapeFormat.RAW_HTML:
|
||||
yield "raw_html", scrape_result.rawHtml
|
||||
elif f == ScrapeFormat.LINKS:
|
||||
yield "links", scrape_result.links
|
||||
elif f == ScrapeFormat.SCREENSHOT:
|
||||
yield "screenshot", scrape_result.screenshot
|
||||
elif f == ScrapeFormat.SCREENSHOT_FULL_PAGE:
|
||||
yield "screenshot_full_page", scrape_result.screenshot
|
||||
elif f == ScrapeFormat.CHANGE_TRACKING:
|
||||
yield "change_tracking", scrape_result.changeTracking
|
||||
elif f == ScrapeFormat.JSON:
|
||||
yield "json", scrape_result.json
|
||||
@@ -1,79 +0,0 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from firecrawl import FirecrawlApp, ScrapeOptions
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import firecrawl
|
||||
|
||||
|
||||
class ScrapeFormat(Enum):
|
||||
MARKDOWN = "markdown"
|
||||
HTML = "html"
|
||||
RAW_HTML = "rawHtml"
|
||||
LINKS = "links"
|
||||
SCREENSHOT = "screenshot"
|
||||
SCREENSHOT_FULL_PAGE = "screenshot@fullPage"
|
||||
JSON = "json"
|
||||
CHANGE_TRACKING = "changeTracking"
|
||||
|
||||
|
||||
class FirecrawlSearchBlock(Block):
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = firecrawl.credentials_field()
|
||||
query: str = SchemaField(description="The query to search for")
|
||||
limit: int = SchemaField(description="The number of pages to crawl", default=10)
|
||||
max_age: int = SchemaField(
|
||||
description="The maximum age of the page in milliseconds - default is 1 hour",
|
||||
default=3600000,
|
||||
)
|
||||
wait_for: int = SchemaField(
|
||||
description="Specify a delay in milliseconds before fetching the content, allowing the page sufficient time to load.",
|
||||
default=200,
|
||||
)
|
||||
formats: list[ScrapeFormat] = SchemaField(
|
||||
description="Returns the content of the search if specified", default=[]
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
data: dict[str, Any] = SchemaField(description="The result of the search")
|
||||
site: dict[str, Any] = SchemaField(description="The site of the search")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f8d2f28d-b3a1-405b-804e-418c087d288b",
|
||||
description="Firecrawl searches the web for the given query.",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
# Sync call
|
||||
scrape_result = app.search(
|
||||
input_data.query,
|
||||
limit=input_data.limit,
|
||||
scrape_options=ScrapeOptions(
|
||||
formats=[format.value for format in input_data.formats],
|
||||
maxAge=input_data.max_age,
|
||||
waitFor=input_data.wait_for,
|
||||
),
|
||||
)
|
||||
yield "data", scrape_result
|
||||
for site in scrape_result.data:
|
||||
yield "site", site
|
||||
@@ -129,7 +129,6 @@ class AIImageEditorBlock(Block):
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
result = await self.run_model(
|
||||
@@ -140,7 +139,6 @@ class AIImageEditorBlock(Block):
|
||||
await store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.input_image,
|
||||
user_id=user_id,
|
||||
return_content=True,
|
||||
)
|
||||
if input_data.input_image
|
||||
|
||||
13 autogpt_platform/backend/backend/blocks/gem/_config.py Normal file
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
Shared configuration for all GEM blocks using the new SDK pattern.
|
||||
"""
|
||||
|
||||
from backend.sdk import BlockCostType, ProviderBuilder
|
||||
|
||||
# Configure the GEM provider once for all blocks
|
||||
gem = (
|
||||
ProviderBuilder("gem")
|
||||
.with_api_key("GEM_API_KEY", "GEM API Key")
|
||||
.with_base_cost(1, BlockCostType.RUN)
|
||||
.build()
|
||||
)
|
||||
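Following the same SDK pattern used by the Exa and Firecrawl blocks elsewhere in this diff, a GEM block would reference this shared provider roughly as sketched below; the schema name and fields are hypothetical:

```python
from backend.sdk import BlockSchema, CredentialsMetaInput, SchemaField

from ._config import gem


class GemBlockInput(BlockSchema):
    """Hypothetical input schema showing how the shared `gem` provider is wired in."""

    # Injects a GEM API key credential configured via GEM_API_KEY above
    credentials: CredentialsMetaInput = gem.credentials_field(
        description="The GEM integration requires an API Key."
    )
    query: str = SchemaField(description="Example input field")
```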
1617 autogpt_platform/backend/backend/blocks/gem/blocks.py Normal file
File diff suppressed because it is too large
2751 autogpt_platform/backend/backend/blocks/gem/gem.md Normal file
File diff suppressed because it is too large
@@ -3,7 +3,7 @@ import logging
|
||||
from fastapi import Request
|
||||
from strenum import StrEnum
|
||||
|
||||
from backend.sdk import Credentials, ManualWebhookManagerBase, Webhook
|
||||
from backend.sdk import ManualWebhookManagerBase, Webhook
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -17,7 +17,7 @@ class GenericWebhooksManager(ManualWebhookManagerBase):
|
||||
|
||||
@classmethod
|
||||
async def validate_payload(
|
||||
cls, webhook: Webhook, request: Request, credentials: Credentials | None = None
|
||||
cls, webhook: Webhook, request: Request
|
||||
) -> tuple[dict, str]:
|
||||
payload = await request.json()
|
||||
event_type = GenericWebhookType.PLAIN
|
||||
|
||||
Some files were not shown because too many files have changed in this diff.