mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-12 00:28:31 -05:00
Compare commits
4 Commits
update-ins
...
abhi/fix-g
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d5b8e2dae0 | ||
|
|
0be6ff45c7 | ||
|
|
c8b17c6d03 | ||
|
|
5a13db8054 |
@@ -9,13 +9,11 @@
|
||||
|
||||
# Platform - Backend
|
||||
!autogpt_platform/backend/backend/
|
||||
!autogpt_platform/backend/test/e2e_test_data.py
|
||||
!autogpt_platform/backend/migrations/
|
||||
!autogpt_platform/backend/schema.prisma
|
||||
!autogpt_platform/backend/pyproject.toml
|
||||
!autogpt_platform/backend/poetry.lock
|
||||
!autogpt_platform/backend/README.md
|
||||
!autogpt_platform/backend/.env
|
||||
|
||||
# Platform - Market
|
||||
!autogpt_platform/market/market/
|
||||
@@ -28,15 +26,13 @@
|
||||
# Platform - Frontend
|
||||
!autogpt_platform/frontend/src/
|
||||
!autogpt_platform/frontend/public/
|
||||
!autogpt_platform/frontend/scripts/
|
||||
!autogpt_platform/frontend/package.json
|
||||
!autogpt_platform/frontend/pnpm-lock.yaml
|
||||
!autogpt_platform/frontend/yarn.lock
|
||||
!autogpt_platform/frontend/tsconfig.json
|
||||
!autogpt_platform/frontend/README.md
|
||||
## config
|
||||
!autogpt_platform/frontend/*.config.*
|
||||
!autogpt_platform/frontend/.env.*
|
||||
!autogpt_platform/frontend/.env
|
||||
|
||||
# Classic - AutoGPT
|
||||
!classic/original_autogpt/autogpt/
|
||||
|
||||
3
.github/PULL_REQUEST_TEMPLATE.md
vendored
3
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -24,8 +24,7 @@
|
||||
</details>
|
||||
|
||||
#### For configuration changes:
|
||||
|
||||
- [ ] `.env.default` is updated or already compatible with my changes
|
||||
- [ ] `.env.example` is updated or already compatible with my changes
|
||||
- [ ] `docker-compose.yml` is updated or already compatible with my changes
|
||||
- [ ] I have included a list of my configuration changes in the PR description (under **Changes**)
|
||||
|
||||
|
||||
88
.github/dependabot.yml
vendored
88
.github/dependabot.yml
vendored
@@ -10,19 +10,17 @@ updates:
|
||||
commit-message:
|
||||
prefix: "chore(libs/deps)"
|
||||
prefix-development: "chore(libs/deps-dev)"
|
||||
ignore:
|
||||
- dependency-name: "poetry"
|
||||
groups:
|
||||
production-dependencies:
|
||||
dependency-type: "production"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
- "minor"
|
||||
- "patch"
|
||||
development-dependencies:
|
||||
dependency-type: "development"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
- "minor"
|
||||
- "patch"
|
||||
|
||||
# backend (Poetry project)
|
||||
- package-ecosystem: "pip"
|
||||
@@ -34,19 +32,17 @@ updates:
|
||||
commit-message:
|
||||
prefix: "chore(backend/deps)"
|
||||
prefix-development: "chore(backend/deps-dev)"
|
||||
ignore:
|
||||
- dependency-name: "poetry"
|
||||
groups:
|
||||
production-dependencies:
|
||||
dependency-type: "production"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
- "minor"
|
||||
- "patch"
|
||||
development-dependencies:
|
||||
dependency-type: "development"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
- "minor"
|
||||
- "patch"
|
||||
|
||||
# frontend (Next.js project)
|
||||
- package-ecosystem: "npm"
|
||||
@@ -62,13 +58,13 @@ updates:
|
||||
production-dependencies:
|
||||
dependency-type: "production"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
- "minor"
|
||||
- "patch"
|
||||
development-dependencies:
|
||||
dependency-type: "development"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
- "minor"
|
||||
- "patch"
|
||||
|
||||
# infra (Terraform)
|
||||
- package-ecosystem: "terraform"
|
||||
@@ -85,13 +81,14 @@ updates:
|
||||
production-dependencies:
|
||||
dependency-type: "production"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
- "minor"
|
||||
- "patch"
|
||||
development-dependencies:
|
||||
dependency-type: "development"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
- "minor"
|
||||
- "patch"
|
||||
|
||||
|
||||
# GitHub Actions
|
||||
- package-ecosystem: "github-actions"
|
||||
@@ -104,13 +101,14 @@ updates:
|
||||
production-dependencies:
|
||||
dependency-type: "production"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
- "minor"
|
||||
- "patch"
|
||||
development-dependencies:
|
||||
dependency-type: "development"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
- "minor"
|
||||
- "patch"
|
||||
|
||||
|
||||
# Docker
|
||||
- package-ecosystem: "docker"
|
||||
@@ -123,16 +121,40 @@ updates:
|
||||
production-dependencies:
|
||||
dependency-type: "production"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
- "minor"
|
||||
- "patch"
|
||||
development-dependencies:
|
||||
dependency-type: "development"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
- "minor"
|
||||
- "patch"
|
||||
|
||||
|
||||
# Submodules
|
||||
- package-ecosystem: "gitsubmodule"
|
||||
directory: "autogpt_platform/supabase"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
open-pull-requests-limit: 1
|
||||
target-branch: "dev"
|
||||
commit-message:
|
||||
prefix: "chore(platform/deps)"
|
||||
prefix-development: "chore(platform/deps-dev)"
|
||||
groups:
|
||||
production-dependencies:
|
||||
dependency-type: "production"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
development-dependencies:
|
||||
dependency-type: "development"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
|
||||
|
||||
# Docs
|
||||
- package-ecosystem: "pip"
|
||||
- package-ecosystem: 'pip'
|
||||
directory: "docs/"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
@@ -144,10 +166,10 @@ updates:
|
||||
production-dependencies:
|
||||
dependency-type: "production"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
- "minor"
|
||||
- "patch"
|
||||
development-dependencies:
|
||||
dependency-type: "development"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
- "minor"
|
||||
- "patch"
|
||||
|
||||
5
.github/labeler.yml
vendored
5
.github/labeler.yml
vendored
@@ -24,9 +24,8 @@ platform/frontend:
|
||||
|
||||
platform/backend:
|
||||
- changed-files:
|
||||
- all-globs-to-any-file:
|
||||
- autogpt_platform/backend/**
|
||||
- '!autogpt_platform/backend/backend/blocks/**'
|
||||
- any-glob-to-any-file: autogpt_platform/backend/**
|
||||
- all-globs-to-all-files: '!autogpt_platform/backend/backend/blocks/**'
|
||||
|
||||
platform/blocks:
|
||||
- changed-files:
|
||||
|
||||
47
.github/workflows/claude.yml
vendored
47
.github/workflows/claude.yml
vendored
@@ -1,47 +0,0 @@
|
||||
name: Claude Code
|
||||
|
||||
on:
|
||||
issue_comment:
|
||||
types: [created]
|
||||
pull_request_review_comment:
|
||||
types: [created]
|
||||
issues:
|
||||
types: [opened, assigned]
|
||||
pull_request_review:
|
||||
types: [submitted]
|
||||
|
||||
jobs:
|
||||
claude:
|
||||
if: |
|
||||
(
|
||||
(github.event_name == 'issue_comment' && contains(github.event.comment.body, '@claude')) ||
|
||||
(github.event_name == 'pull_request_review_comment' && contains(github.event.comment.body, '@claude')) ||
|
||||
(github.event_name == 'pull_request_review' && contains(github.event.review.body, '@claude')) ||
|
||||
(github.event_name == 'issues' && (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude')))
|
||||
) && (
|
||||
github.event.comment.author_association == 'OWNER' ||
|
||||
github.event.comment.author_association == 'MEMBER' ||
|
||||
github.event.comment.author_association == 'COLLABORATOR' ||
|
||||
github.event.review.author_association == 'OWNER' ||
|
||||
github.event.review.author_association == 'MEMBER' ||
|
||||
github.event.review.author_association == 'COLLABORATOR' ||
|
||||
github.event.issue.author_association == 'OWNER' ||
|
||||
github.event.issue.author_association == 'MEMBER' ||
|
||||
github.event.issue.author_association == 'COLLABORATOR'
|
||||
)
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: read
|
||||
issues: read
|
||||
id-token: write
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
- name: Run Claude Code
|
||||
id: claude
|
||||
uses: anthropics/claude-code-action@beta
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
@@ -34,7 +34,6 @@ jobs:
|
||||
python -m prisma migrate deploy
|
||||
env:
|
||||
DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
|
||||
DIRECT_URL: ${{ secrets.BACKEND_DATABASE_URL }}
|
||||
|
||||
|
||||
trigger:
|
||||
|
||||
@@ -36,7 +36,6 @@ jobs:
|
||||
python -m prisma migrate deploy
|
||||
env:
|
||||
DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
|
||||
DIRECT_URL: ${{ secrets.BACKEND_DATABASE_URL }}
|
||||
|
||||
trigger:
|
||||
needs: migrate
|
||||
|
||||
88
.github/workflows/platform-backend-ci.yml
vendored
88
.github/workflows/platform-backend-ci.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.11"]
|
||||
python-version: ["3.10"]
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
services:
|
||||
@@ -50,23 +50,6 @@ jobs:
|
||||
env:
|
||||
RABBITMQ_DEFAULT_USER: ${{ env.RABBITMQ_DEFAULT_USER }}
|
||||
RABBITMQ_DEFAULT_PASS: ${{ env.RABBITMQ_DEFAULT_PASS }}
|
||||
clamav:
|
||||
image: clamav/clamav-debian:latest
|
||||
ports:
|
||||
- 3310:3310
|
||||
env:
|
||||
CLAMAV_NO_FRESHCLAMD: false
|
||||
CLAMD_CONF_StreamMaxLength: 50M
|
||||
CLAMD_CONF_MaxFileSize: 100M
|
||||
CLAMD_CONF_MaxScanSize: 100M
|
||||
CLAMD_CONF_MaxThreads: 4
|
||||
CLAMD_CONF_ReadTimeout: 300
|
||||
options: >-
|
||||
--health-cmd "clamdscan --version || exit 1"
|
||||
--health-interval 30s
|
||||
--health-timeout 10s
|
||||
--health-retries 5
|
||||
--health-start-period 180s
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -83,7 +66,7 @@ jobs:
|
||||
- name: Setup Supabase
|
||||
uses: supabase/setup-cli@v1
|
||||
with:
|
||||
version: 1.178.1
|
||||
version: latest
|
||||
|
||||
- id: get_date
|
||||
name: Get date
|
||||
@@ -97,35 +80,18 @@ jobs:
|
||||
|
||||
- name: Install Poetry (Unix)
|
||||
run: |
|
||||
# Extract Poetry version from backend/poetry.lock
|
||||
HEAD_POETRY_VERSION=$(python ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
|
||||
echo "Found Poetry version ${HEAD_POETRY_VERSION} in backend/poetry.lock"
|
||||
|
||||
if [ -n "$BASE_REF" ]; then
|
||||
BASE_BRANCH=${BASE_REF/refs\/heads\//}
|
||||
BASE_POETRY_VERSION=$((git show "origin/$BASE_BRANCH":./poetry.lock; true) | python ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry -)
|
||||
echo "Found Poetry version ${BASE_POETRY_VERSION} in backend/poetry.lock on ${BASE_REF}"
|
||||
POETRY_VERSION=$(printf '%s\n' "$HEAD_POETRY_VERSION" "$BASE_POETRY_VERSION" | sort -V | tail -n1)
|
||||
else
|
||||
POETRY_VERSION=$HEAD_POETRY_VERSION
|
||||
fi
|
||||
echo "Using Poetry version ${POETRY_VERSION}"
|
||||
|
||||
# Install Poetry
|
||||
curl -sSL https://install.python-poetry.org | POETRY_VERSION=$POETRY_VERSION python3 -
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
if [ "${{ runner.os }}" = "macOS" ]; then
|
||||
PATH="$HOME/.local/bin:$PATH"
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
fi
|
||||
env:
|
||||
BASE_REF: ${{ github.base_ref || github.event.merge_group.base_ref }}
|
||||
|
||||
- name: Check poetry.lock
|
||||
run: |
|
||||
poetry lock
|
||||
|
||||
if ! git diff --quiet --ignore-matching-lines="^# " poetry.lock; then
|
||||
if ! git diff --quiet poetry.lock; then
|
||||
echo "Error: poetry.lock not up to date."
|
||||
echo
|
||||
git diff poetry.lock
|
||||
@@ -148,40 +114,10 @@ jobs:
|
||||
# outputs:
|
||||
# DB_URL, API_URL, GRAPHQL_URL, ANON_KEY, SERVICE_ROLE_KEY, JWT_SECRET
|
||||
|
||||
- name: Wait for ClamAV to be ready
|
||||
run: |
|
||||
echo "Waiting for ClamAV daemon to start..."
|
||||
max_attempts=60
|
||||
attempt=0
|
||||
|
||||
until nc -z localhost 3310 || [ $attempt -eq $max_attempts ]; do
|
||||
echo "ClamAV is unavailable - sleeping (attempt $((attempt+1))/$max_attempts)"
|
||||
sleep 5
|
||||
attempt=$((attempt+1))
|
||||
done
|
||||
|
||||
if [ $attempt -eq $max_attempts ]; then
|
||||
echo "ClamAV failed to start after $((max_attempts*5)) seconds"
|
||||
echo "Checking ClamAV service logs..."
|
||||
docker logs $(docker ps -q --filter "ancestor=clamav/clamav-debian:latest") 2>&1 | tail -50 || echo "No ClamAV container found"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "ClamAV is ready!"
|
||||
|
||||
# Verify ClamAV is responsive
|
||||
echo "Testing ClamAV connection..."
|
||||
timeout 10 bash -c 'echo "PING" | nc localhost 3310' || {
|
||||
echo "ClamAV is not responding to PING"
|
||||
docker logs $(docker ps -q --filter "ancestor=clamav/clamav-debian:latest") 2>&1 | tail -50 || echo "No ClamAV container found"
|
||||
exit 1
|
||||
}
|
||||
|
||||
- name: Run Database Migrations
|
||||
run: poetry run prisma migrate dev --name updates
|
||||
env:
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
|
||||
- id: lint
|
||||
name: Run Linter
|
||||
@@ -190,22 +126,20 @@ jobs:
|
||||
- name: Run pytest with coverage
|
||||
run: |
|
||||
if [[ "${{ runner.debug }}" == "1" ]]; then
|
||||
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG
|
||||
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG test
|
||||
else
|
||||
poetry run pytest -s -vv
|
||||
poetry run pytest -s -vv test
|
||||
fi
|
||||
if: success() || (failure() && steps.lint.outcome == 'failure')
|
||||
env:
|
||||
LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
SUPABASE_URL: ${{ steps.supabase.outputs.API_URL }}
|
||||
SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
|
||||
SUPABASE_JWT_SECRET: ${{ steps.supabase.outputs.JWT_SECRET }}
|
||||
REDIS_HOST: "localhost"
|
||||
REDIS_PORT: "6379"
|
||||
REDIS_PASSWORD: "testpassword"
|
||||
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
|
||||
REDIS_HOST: 'localhost'
|
||||
REDIS_PORT: '6379'
|
||||
REDIS_PASSWORD: 'testpassword'
|
||||
|
||||
env:
|
||||
CI: true
|
||||
@@ -218,8 +152,8 @@ jobs:
|
||||
# If you want to replace this, you can do so by making our entire system generate
|
||||
# new credentials for each local user and update the environment variables in
|
||||
# the backend service, docker composes, and examples
|
||||
RABBITMQ_DEFAULT_USER: "rabbitmq_user_default"
|
||||
RABBITMQ_DEFAULT_PASS: "k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7"
|
||||
RABBITMQ_DEFAULT_USER: 'rabbitmq_user_default'
|
||||
RABBITMQ_DEFAULT_PASS: 'k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7'
|
||||
|
||||
# - name: Upload coverage reports to Codecov
|
||||
# uses: codecov/codecov-action@v4
|
||||
|
||||
@@ -1,198 +0,0 @@
|
||||
name: AutoGPT Platform - Dev Deploy PR Event Dispatcher
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [closed]
|
||||
issue_comment:
|
||||
types: [created]
|
||||
|
||||
permissions:
|
||||
issues: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
dispatch:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check comment permissions and deployment status
|
||||
id: check_status
|
||||
if: github.event_name == 'issue_comment' && github.event.issue.pull_request
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const commentBody = context.payload.comment.body.trim();
|
||||
const commentUser = context.payload.comment.user.login;
|
||||
const prAuthor = context.payload.issue.user.login;
|
||||
const authorAssociation = context.payload.comment.author_association;
|
||||
|
||||
// Check permissions
|
||||
const hasPermission = (
|
||||
authorAssociation === 'OWNER' ||
|
||||
authorAssociation === 'MEMBER' ||
|
||||
authorAssociation === 'COLLABORATOR'
|
||||
);
|
||||
|
||||
core.setOutput('comment_body', commentBody);
|
||||
core.setOutput('has_permission', hasPermission);
|
||||
|
||||
if (!hasPermission && (commentBody === '!deploy' || commentBody === '!undeploy')) {
|
||||
core.setOutput('permission_denied', 'true');
|
||||
return;
|
||||
}
|
||||
|
||||
if (commentBody !== '!deploy' && commentBody !== '!undeploy') {
|
||||
return;
|
||||
}
|
||||
|
||||
// Process deploy command
|
||||
if (commentBody === '!deploy') {
|
||||
core.setOutput('should_deploy', 'true');
|
||||
}
|
||||
// Process undeploy command
|
||||
else if (commentBody === '!undeploy') {
|
||||
core.setOutput('should_undeploy', 'true');
|
||||
}
|
||||
|
||||
- name: Post permission denied comment
|
||||
if: steps.check_status.outputs.permission_denied == 'true'
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.issue.number,
|
||||
body: `❌ **Permission denied**: Only the repository owners, members, or collaborators can use deployment commands.`
|
||||
});
|
||||
|
||||
- name: Get PR details for deployment
|
||||
id: pr_details
|
||||
if: steps.check_status.outputs.should_deploy == 'true' || steps.check_status.outputs.should_undeploy == 'true'
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const pr = await github.rest.pulls.get({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: context.issue.number
|
||||
});
|
||||
core.setOutput('pr_number', pr.data.number);
|
||||
core.setOutput('pr_title', pr.data.title);
|
||||
core.setOutput('pr_state', pr.data.state);
|
||||
|
||||
- name: Dispatch Deploy Event
|
||||
if: steps.check_status.outputs.should_deploy == 'true'
|
||||
uses: peter-evans/repository-dispatch@v3
|
||||
with:
|
||||
token: ${{ secrets.DISPATCH_TOKEN }}
|
||||
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
||||
event-type: pr-event
|
||||
client-payload: |
|
||||
{
|
||||
"action": "deploy",
|
||||
"pr_number": "${{ steps.pr_details.outputs.pr_number }}",
|
||||
"pr_title": "${{ steps.pr_details.outputs.pr_title }}",
|
||||
"pr_state": "${{ steps.pr_details.outputs.pr_state }}",
|
||||
"repo": "${{ github.repository }}"
|
||||
}
|
||||
|
||||
- name: Post deploy success comment
|
||||
if: steps.check_status.outputs.should_deploy == 'true'
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.issue.number,
|
||||
body: `🚀 **Deploying PR #${{ steps.pr_details.outputs.pr_number }}** to development environment...`
|
||||
});
|
||||
|
||||
- name: Dispatch Undeploy Event (from comment)
|
||||
if: steps.check_status.outputs.should_undeploy == 'true'
|
||||
uses: peter-evans/repository-dispatch@v3
|
||||
with:
|
||||
token: ${{ secrets.DISPATCH_TOKEN }}
|
||||
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
||||
event-type: pr-event
|
||||
client-payload: |
|
||||
{
|
||||
"action": "undeploy",
|
||||
"pr_number": "${{ steps.pr_details.outputs.pr_number }}",
|
||||
"pr_title": "${{ steps.pr_details.outputs.pr_title }}",
|
||||
"pr_state": "${{ steps.pr_details.outputs.pr_state }}",
|
||||
"repo": "${{ github.repository }}"
|
||||
}
|
||||
|
||||
- name: Post undeploy success comment
|
||||
if: steps.check_status.outputs.should_undeploy == 'true'
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.issue.number,
|
||||
body: `🗑️ **Undeploying PR #${{ steps.pr_details.outputs.pr_number }}** from development environment...`
|
||||
});
|
||||
|
||||
- name: Check deployment status on PR close
|
||||
id: check_pr_close
|
||||
if: github.event_name == 'pull_request' && github.event.action == 'closed'
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const comments = await github.rest.issues.listComments({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.issue.number
|
||||
});
|
||||
|
||||
let lastDeployIndex = -1;
|
||||
let lastUndeployIndex = -1;
|
||||
|
||||
comments.data.forEach((comment, index) => {
|
||||
if (comment.body.trim() === '!deploy') {
|
||||
lastDeployIndex = index;
|
||||
} else if (comment.body.trim() === '!undeploy') {
|
||||
lastUndeployIndex = index;
|
||||
}
|
||||
});
|
||||
|
||||
// Should undeploy if there's a !deploy without a subsequent !undeploy
|
||||
const shouldUndeploy = lastDeployIndex !== -1 && lastDeployIndex > lastUndeployIndex;
|
||||
core.setOutput('should_undeploy', shouldUndeploy);
|
||||
|
||||
- name: Dispatch Undeploy Event (PR closed with active deployment)
|
||||
if: >-
|
||||
github.event_name == 'pull_request' &&
|
||||
github.event.action == 'closed' &&
|
||||
steps.check_pr_close.outputs.should_undeploy == 'true'
|
||||
uses: peter-evans/repository-dispatch@v3
|
||||
with:
|
||||
token: ${{ secrets.DISPATCH_TOKEN }}
|
||||
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
||||
event-type: pr-event
|
||||
client-payload: |
|
||||
{
|
||||
"action": "undeploy",
|
||||
"pr_number": "${{ github.event.pull_request.number }}",
|
||||
"pr_title": "${{ github.event.pull_request.title }}",
|
||||
"pr_state": "${{ github.event.pull_request.state }}",
|
||||
"repo": "${{ github.repository }}"
|
||||
}
|
||||
|
||||
- name: Post PR close undeploy comment
|
||||
if: >-
|
||||
github.event_name == 'pull_request' &&
|
||||
github.event.action == 'closed' &&
|
||||
steps.check_pr_close.outputs.should_undeploy == 'true'
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.issue.number,
|
||||
body: `🧹 **Auto-undeploying**: PR closed with active deployment. Cleaning up development environment for PR #${{ github.event.pull_request.number }}.`
|
||||
});
|
||||
202
.github/workflows/platform-frontend-ci.yml
vendored
202
.github/workflows/platform-frontend-ci.yml
vendored
@@ -18,116 +18,50 @@ defaults:
|
||||
working-directory: autogpt_platform/frontend
|
||||
|
||||
jobs:
|
||||
setup:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
cache-key: ${{ steps.cache-key.outputs.key }}
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Generate cache key
|
||||
id: cache-key
|
||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ steps.cache-key.outputs.key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
run: |
|
||||
yarn install --frozen-lockfile
|
||||
|
||||
- name: Run lint
|
||||
run: pnpm lint
|
||||
run: |
|
||||
yarn lint
|
||||
|
||||
chromatic:
|
||||
type-check:
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
# Only run on dev branch pushes or PRs targeting dev
|
||||
if: github.ref == 'refs/heads/dev' || github.base_ref == 'dev'
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
run: |
|
||||
yarn install --frozen-lockfile
|
||||
|
||||
- name: Run Chromatic
|
||||
uses: chromaui/action@latest
|
||||
with:
|
||||
projectToken: chpt_9e7c1a76478c9c8
|
||||
onlyChanged: true
|
||||
workingDir: autogpt_platform/frontend
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
exitOnceUploaded: true
|
||||
- name: Run tsc check
|
||||
run: |
|
||||
yarn type-check
|
||||
|
||||
test:
|
||||
runs-on: big-boi
|
||||
needs: setup
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
browser: [chromium, webkit]
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -140,96 +74,48 @@ jobs:
|
||||
with:
|
||||
node-version: "21"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
- name: Free Disk Space (Ubuntu)
|
||||
uses: jlumbroso/free-disk-space@main
|
||||
with:
|
||||
large-packages: false # slow
|
||||
docker-images: false # limited benefit
|
||||
|
||||
- name: Copy default supabase .env
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
cp ../supabase/docker/.env.example ../.env
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Cache Docker layers
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: /tmp/.buildx-cache
|
||||
key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-buildx-frontend-test-
|
||||
- name: Copy backend .env
|
||||
run: |
|
||||
cp ../backend/.env.example ../backend/.env
|
||||
|
||||
- name: Run docker compose
|
||||
run: |
|
||||
docker compose -f ../docker-compose.yml up -d
|
||||
env:
|
||||
DOCKER_BUILDKIT: 1
|
||||
BUILDX_CACHE_FROM: type=local,src=/tmp/.buildx-cache
|
||||
BUILDX_CACHE_TO: type=local,dest=/tmp/.buildx-cache-new,mode=max
|
||||
|
||||
- name: Move cache
|
||||
run: |
|
||||
rm -rf /tmp/.buildx-cache
|
||||
if [ -d "/tmp/.buildx-cache-new" ]; then
|
||||
mv /tmp/.buildx-cache-new /tmp/.buildx-cache
|
||||
fi
|
||||
|
||||
- name: Wait for services to be ready
|
||||
run: |
|
||||
echo "Waiting for rest_server to be ready..."
|
||||
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
||||
echo "Waiting for database to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
|
||||
|
||||
- name: Create E2E test data
|
||||
run: |
|
||||
echo "Creating E2E test data..."
|
||||
# First try to run the script from inside the container
|
||||
if docker compose -f ../docker-compose.yml exec -T rest_server test -f /app/autogpt_platform/backend/test/e2e_test_data.py; then
|
||||
echo "✅ Found e2e_test_data.py in container, running it..."
|
||||
docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python backend/test/e2e_test_data.py" || {
|
||||
echo "❌ E2E test data creation failed!"
|
||||
docker compose -f ../docker-compose.yml logs --tail=50 rest_server
|
||||
exit 1
|
||||
}
|
||||
else
|
||||
echo "⚠️ e2e_test_data.py not found in container, copying and running..."
|
||||
# Copy the script into the container and run it
|
||||
docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.yml ps -q rest_server):/tmp/e2e_test_data.py || {
|
||||
echo "❌ Failed to copy script to container"
|
||||
exit 1
|
||||
}
|
||||
docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
|
||||
echo "❌ E2E test data creation failed!"
|
||||
docker compose -f ../docker-compose.yml logs --tail=50 rest_server
|
||||
exit 1
|
||||
}
|
||||
fi
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
run: |
|
||||
yarn install --frozen-lockfile
|
||||
|
||||
- name: Install Browser 'chromium'
|
||||
run: pnpm playwright install --with-deps chromium
|
||||
- name: Setup Builder .env
|
||||
run: |
|
||||
cp .env.example .env
|
||||
|
||||
- name: Run Playwright tests
|
||||
run: pnpm test:no-build
|
||||
- name: Install Browser '${{ matrix.browser }}'
|
||||
run: yarn playwright install --with-deps ${{ matrix.browser }}
|
||||
|
||||
- name: Upload Playwright artifacts
|
||||
if: failure()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-report
|
||||
path: playwright-report
|
||||
- name: Run tests
|
||||
timeout-minutes: 20
|
||||
run: |
|
||||
yarn test --project=${{ matrix.browser }}
|
||||
|
||||
- name: Print Final Docker Compose logs
|
||||
if: always()
|
||||
run: docker compose -f ../docker-compose.yml logs
|
||||
run: |
|
||||
docker compose -f ../docker-compose.yml logs
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: ${{ !cancelled() }}
|
||||
with:
|
||||
name: playwright-report-${{ matrix.browser }}
|
||||
path: playwright-report/
|
||||
retention-days: 30
|
||||
|
||||
132
.github/workflows/platform-fullstack-ci.yml
vendored
132
.github/workflows/platform-fullstack-ci.yml
vendored
@@ -1,132 +0,0 @@
|
||||
name: AutoGPT Platform - Frontend CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [master, dev]
|
||||
paths:
|
||||
- ".github/workflows/platform-fullstack-ci.yml"
|
||||
- "autogpt_platform/**"
|
||||
pull_request:
|
||||
paths:
|
||||
- ".github/workflows/platform-fullstack-ci.yml"
|
||||
- "autogpt_platform/**"
|
||||
merge_group:
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: autogpt_platform/frontend
|
||||
|
||||
jobs:
|
||||
setup:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
cache-key: ${{ steps.cache-key.outputs.key }}
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Generate cache key
|
||||
id: cache-key
|
||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ steps.cache-key.outputs.key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
types:
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Copy default supabase .env
|
||||
run: |
|
||||
cp ../.env.default ../.env
|
||||
|
||||
- name: Copy backend .env
|
||||
run: |
|
||||
cp ../backend/.env.default ../backend/.env
|
||||
|
||||
- name: Run docker compose
|
||||
run: |
|
||||
docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Setup .env
|
||||
run: cp .env.default .env
|
||||
|
||||
- name: Wait for services to be ready
|
||||
run: |
|
||||
echo "Waiting for rest_server to be ready..."
|
||||
timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
|
||||
echo "Waiting for database to be ready..."
|
||||
timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
|
||||
|
||||
- name: Generate API queries
|
||||
run: pnpm generate:api:force
|
||||
|
||||
- name: Check for API schema changes
|
||||
run: |
|
||||
if ! git diff --exit-code src/app/api/openapi.json; then
|
||||
echo "❌ API schema changes detected in src/app/api/openapi.json"
|
||||
echo ""
|
||||
echo "The openapi.json file has been modified after running 'pnpm generate:api-all'."
|
||||
echo "This usually means changes have been made in the BE endpoints without updating the Frontend."
|
||||
echo "The API schema is now out of sync with the Front-end queries."
|
||||
echo ""
|
||||
echo "To fix this:"
|
||||
echo "1. Pull the backend 'docker compose pull && docker compose up -d --build --force-recreate'"
|
||||
echo "2. Run 'pnpm generate:api' locally"
|
||||
echo "3. Run 'pnpm types' locally"
|
||||
echo "4. Fix any TypeScript errors that may have been introduced"
|
||||
echo "5. Commit and push your changes"
|
||||
echo ""
|
||||
exit 1
|
||||
else
|
||||
echo "✅ No API schema changes detected"
|
||||
fi
|
||||
|
||||
- name: Run Typescript checks
|
||||
run: pnpm types
|
||||
@@ -16,7 +16,7 @@ jobs:
|
||||
# operations-per-run: 5000
|
||||
stale-issue-message: >
|
||||
This issue has automatically been marked as _stale_ because it has not had
|
||||
any activity in the last 170 days. You can _unstale_ it by commenting or
|
||||
any activity in the last 50 days. You can _unstale_ it by commenting or
|
||||
removing the label. Otherwise, this issue will be closed in 10 days.
|
||||
stale-pr-message: >
|
||||
This pull request has automatically been marked as _stale_ because it has
|
||||
@@ -25,7 +25,7 @@ jobs:
|
||||
close-issue-message: >
|
||||
This issue was closed automatically because it has been stale for 10 days
|
||||
with no activity.
|
||||
days-before-stale: 170
|
||||
days-before-stale: 100
|
||||
days-before-close: 10
|
||||
# Do not touch meta issues:
|
||||
exempt-issue-labels: meta,fridge,project management
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
import sys
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
print("Python version 3.11 or higher required")
|
||||
sys.exit(1)
|
||||
|
||||
import tomllib
|
||||
|
||||
|
||||
def get_package_version(package_name: str, lockfile_path: str) -> str | None:
|
||||
"""Extract package version from poetry.lock file."""
|
||||
try:
|
||||
if lockfile_path == "-":
|
||||
data = tomllib.load(sys.stdin.buffer)
|
||||
else:
|
||||
with open(lockfile_path, "rb") as f:
|
||||
data = tomllib.load(f)
|
||||
except FileNotFoundError:
|
||||
print(f"Error: File '{lockfile_path}' not found", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
except tomllib.TOMLDecodeError as e:
|
||||
print(f"Error parsing TOML file: {e}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"Error reading file: {e}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Look for the package in the packages list
|
||||
packages = data.get("package", [])
|
||||
for package in packages:
|
||||
if package.get("name", "").lower() == package_name.lower():
|
||||
return package.get("version")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def main():
|
||||
if len(sys.argv) not in (2, 3):
|
||||
print(
|
||||
"Usages: python get_package_version_from_lockfile.py <package name> [poetry.lock path]\n"
|
||||
" cat poetry.lock | python get_package_version_from_lockfile.py <package name> -",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
package_name = sys.argv[1]
|
||||
lockfile_path = sys.argv[2] if len(sys.argv) == 3 else "poetry.lock"
|
||||
|
||||
version = get_package_version(package_name, lockfile_path)
|
||||
|
||||
if version:
|
||||
print(version)
|
||||
else:
|
||||
print(f"Package '{package_name}' not found in {lockfile_path}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
6
.gitignore
vendored
6
.gitignore
vendored
@@ -5,8 +5,6 @@ classic/original_autogpt/*.json
|
||||
auto_gpt_workspace/*
|
||||
*.mpeg
|
||||
.env
|
||||
# Root .env files
|
||||
/.env
|
||||
azure.yaml
|
||||
.vscode
|
||||
.idea/*
|
||||
@@ -123,6 +121,7 @@ celerybeat.pid
|
||||
|
||||
# Environments
|
||||
.direnv/
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv*/
|
||||
@@ -166,7 +165,7 @@ package-lock.json
|
||||
|
||||
# Allow for locally private items
|
||||
# private
|
||||
pri*
|
||||
pri*
|
||||
# ignore
|
||||
ig*
|
||||
.github_access_token
|
||||
@@ -177,4 +176,3 @@ autogpt_platform/backend/settings.py
|
||||
|
||||
*.ign.*
|
||||
.test-contents
|
||||
.claude/settings.local.json
|
||||
|
||||
3
.gitmodules
vendored
3
.gitmodules
vendored
@@ -1,3 +1,6 @@
|
||||
[submodule "classic/forge/tests/vcr_cassettes"]
|
||||
path = classic/forge/tests/vcr_cassettes
|
||||
url = https://github.com/Significant-Gravitas/Auto-GPT-test-cassettes
|
||||
[submodule "autogpt_platform/supabase"]
|
||||
path = autogpt_platform/supabase
|
||||
url = https://github.com/supabase/supabase.git
|
||||
|
||||
@@ -17,7 +17,7 @@ repos:
|
||||
name: Detect secrets
|
||||
description: Detects high entropy strings that are likely to be passwords.
|
||||
files: ^autogpt_platform/
|
||||
stages: [pre-push]
|
||||
stages: [push]
|
||||
|
||||
- repo: local
|
||||
# For proper type checking, all dependencies need to be up-to-date.
|
||||
@@ -140,7 +140,7 @@ repos:
|
||||
language: system
|
||||
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 24.10.0
|
||||
rev: 23.12.1
|
||||
# Black has sensible defaults, doesn't need package context, and ignores
|
||||
# everything in .gitignore, so it works fine without any config or arguments.
|
||||
hooks:
|
||||
@@ -235,44 +235,44 @@ repos:
|
||||
hooks:
|
||||
- id: tsc
|
||||
name: Typecheck - AutoGPT Platform - Frontend
|
||||
entry: bash -c 'cd autogpt_platform/frontend && pnpm types'
|
||||
entry: bash -c 'cd autogpt_platform/frontend && npm run type-check'
|
||||
files: ^autogpt_platform/frontend/
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
|
||||
# - repo: local
|
||||
# hooks:
|
||||
# - id: pytest
|
||||
# name: Run tests - AutoGPT Platform - Backend
|
||||
# alias: pytest-platform-backend
|
||||
# entry: bash -c 'cd autogpt_platform/backend && poetry run pytest'
|
||||
# # include autogpt_libs source (since it's a path dependency) but exclude *_test.py files:
|
||||
# files: ^autogpt_platform/(backend/((backend|test)/|poetry\.lock$)|autogpt_libs/(autogpt_libs/.*(?<!_test)\.py|poetry\.lock)$)
|
||||
# language: system
|
||||
# pass_filenames: false
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: pytest
|
||||
name: Run tests - AutoGPT Platform - Backend
|
||||
alias: pytest-platform-backend
|
||||
entry: bash -c 'cd autogpt_platform/backend && poetry run pytest'
|
||||
# include autogpt_libs source (since it's a path dependency) but exclude *_test.py files:
|
||||
files: ^autogpt_platform/(backend/((backend|test)/|poetry\.lock$)|autogpt_libs/(autogpt_libs/.*(?<!_test)\.py|poetry\.lock)$)
|
||||
language: system
|
||||
pass_filenames: false
|
||||
|
||||
# - id: pytest
|
||||
# name: Run tests - Classic - AutoGPT (excl. slow tests)
|
||||
# alias: pytest-classic-autogpt
|
||||
# entry: bash -c 'cd classic/original_autogpt && poetry run pytest --cov=autogpt -m "not slow" tests/unit tests/integration'
|
||||
# # include forge source (since it's a path dependency) but exclude *_test.py files:
|
||||
# files: ^(classic/original_autogpt/((autogpt|tests)/|poetry\.lock$)|classic/forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
|
||||
# language: system
|
||||
# pass_filenames: false
|
||||
- id: pytest
|
||||
name: Run tests - Classic - AutoGPT (excl. slow tests)
|
||||
alias: pytest-classic-autogpt
|
||||
entry: bash -c 'cd classic/original_autogpt && poetry run pytest --cov=autogpt -m "not slow" tests/unit tests/integration'
|
||||
# include forge source (since it's a path dependency) but exclude *_test.py files:
|
||||
files: ^(classic/original_autogpt/((autogpt|tests)/|poetry\.lock$)|classic/forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
|
||||
language: system
|
||||
pass_filenames: false
|
||||
|
||||
# - id: pytest
|
||||
# name: Run tests - Classic - Forge (excl. slow tests)
|
||||
# alias: pytest-classic-forge
|
||||
# entry: bash -c 'cd classic/forge && poetry run pytest --cov=forge -m "not slow"'
|
||||
# files: ^classic/forge/(forge/|tests/|poetry\.lock$)
|
||||
# language: system
|
||||
# pass_filenames: false
|
||||
- id: pytest
|
||||
name: Run tests - Classic - Forge (excl. slow tests)
|
||||
alias: pytest-classic-forge
|
||||
entry: bash -c 'cd classic/forge && poetry run pytest --cov=forge -m "not slow"'
|
||||
files: ^classic/forge/(forge/|tests/|poetry\.lock$)
|
||||
language: system
|
||||
pass_filenames: false
|
||||
|
||||
# - id: pytest
|
||||
# name: Run tests - Classic - Benchmark
|
||||
# alias: pytest-classic-benchmark
|
||||
# entry: bash -c 'cd classic/benchmark && poetry run pytest --cov=benchmark'
|
||||
# files: ^classic/benchmark/(agbenchmark/|tests/|poetry\.lock$)
|
||||
# language: system
|
||||
# pass_filenames: false
|
||||
- id: pytest
|
||||
name: Run tests - Classic - Benchmark
|
||||
alias: pytest-classic-benchmark
|
||||
entry: bash -c 'cd classic/benchmark && poetry run pytest --cov=benchmark'
|
||||
files: ^classic/benchmark/(agbenchmark/|tests/|poetry\.lock$)
|
||||
language: system
|
||||
pass_filenames: false
|
||||
|
||||
12
.vscode/launch.json
vendored
12
.vscode/launch.json
vendored
@@ -6,7 +6,7 @@
|
||||
"type": "node-terminal",
|
||||
"request": "launch",
|
||||
"cwd": "${workspaceFolder}/autogpt_platform/frontend",
|
||||
"command": "pnpm dev"
|
||||
"command": "yarn dev"
|
||||
},
|
||||
{
|
||||
"name": "Frontend: Client Side",
|
||||
@@ -19,12 +19,12 @@
|
||||
"type": "node-terminal",
|
||||
|
||||
"request": "launch",
|
||||
"command": "pnpm dev",
|
||||
"command": "yarn dev",
|
||||
"cwd": "${workspaceFolder}/autogpt_platform/frontend",
|
||||
"serverReadyAction": {
|
||||
"pattern": "- Local:.+(https?://.+)",
|
||||
"uriFormat": "%s",
|
||||
"action": "debugWithChrome"
|
||||
"action": "debugWithEdge"
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -32,9 +32,9 @@
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "backend.app",
|
||||
"env": {
|
||||
"OBJC_DISABLE_INITIALIZE_FORK_SAFETY": "YES"
|
||||
},
|
||||
// "env": {
|
||||
// "ENV": "dev"
|
||||
// },
|
||||
"envFile": "${workspaceFolder}/backend/.env",
|
||||
"justMyCode": false,
|
||||
"cwd": "${workspaceFolder}/autogpt_platform/backend"
|
||||
|
||||
53
AGENTS.md
53
AGENTS.md
@@ -1,53 +0,0 @@
|
||||
# AutoGPT Platform Contribution Guide
|
||||
|
||||
This guide provides context for Codex when updating the **autogpt_platform** folder.
|
||||
|
||||
## Directory overview
|
||||
|
||||
- `autogpt_platform/backend` – FastAPI based backend service.
|
||||
- `autogpt_platform/autogpt_libs` – Shared Python libraries.
|
||||
- `autogpt_platform/frontend` – Next.js + Typescript frontend.
|
||||
- `autogpt_platform/docker-compose.yml` – development stack.
|
||||
|
||||
See `docs/content/platform/getting-started.md` for setup instructions.
|
||||
|
||||
## Code style
|
||||
|
||||
- Format Python code with `poetry run format`.
|
||||
- Format frontend code using `pnpm format`.
|
||||
|
||||
## Testing
|
||||
|
||||
- Backend: `poetry run test` (runs pytest with a docker based postgres + prisma).
|
||||
- Frontend: `pnpm test` or `pnpm test-ui` for Playwright tests. See `docs/content/platform/contributing/tests.md` for tips.
|
||||
|
||||
Always run the relevant linters and tests before committing.
|
||||
Use conventional commit messages for all commits (e.g. `feat(backend): add API`).
|
||||
Types:
|
||||
- feat
|
||||
- fix
|
||||
- refactor
|
||||
- ci
|
||||
- dx (developer experience)
|
||||
Scopes:
|
||||
- platform
|
||||
- platform/library
|
||||
- platform/marketplace
|
||||
- backend
|
||||
- backend/executor
|
||||
- frontend
|
||||
- frontend/library
|
||||
- frontend/marketplace
|
||||
- blocks
|
||||
|
||||
## Pull requests
|
||||
|
||||
- Use the template in `.github/PULL_REQUEST_TEMPLATE.md`.
|
||||
- Rely on the pre-commit checks for linting and formatting
|
||||
- Fill out the **Changes** section and the checklist.
|
||||
- Use conventional commit titles with a scope (e.g. `feat(frontend): add feature`).
|
||||
- Keep out-of-scope changes under 20% of the PR.
|
||||
- Ensure PR descriptions are complete.
|
||||
- For changes touching `data/*.py`, validate user ID checks or explain why not needed.
|
||||
- If adding protected frontend routes, update `frontend/lib/supabase/middleware.ts`.
|
||||
- Use the linear ticket branch structure if given codex/open-1668-resume-dropped-runs
|
||||
@@ -2,6 +2,9 @@
|
||||
If you are reading this, you are probably looking for the full **[contribution guide]**,
|
||||
which is part of our [wiki].
|
||||
|
||||
Also check out our [🚀 Roadmap][roadmap] for information about our priorities and associated tasks.
|
||||
<!-- You can find our immediate priorities and their progress on our public [kanban board]. -->
|
||||
|
||||
[contribution guide]: https://github.com/Significant-Gravitas/AutoGPT/wiki/Contributing
|
||||
[wiki]: https://github.com/Significant-Gravitas/AutoGPT/wiki
|
||||
[roadmap]: https://github.com/Significant-Gravitas/AutoGPT/discussions/6971
|
||||
|
||||
195
LICENSE
195
LICENSE
@@ -1,197 +1,6 @@
|
||||
All portions of this repository are under one of two licenses.
|
||||
All portions of this repository are under one of two licenses. The majority of the AutoGPT repository is under the MIT License below. The autogpt_platform folder is under the
|
||||
Polyform Shield License.
|
||||
|
||||
- Everything inside the autogpt_platform folder is under the Polyform Shield License.
|
||||
- Everything outside the autogpt_platform folder is under the MIT License.
|
||||
|
||||
More info:
|
||||
|
||||
**Polyform Shield License:**
|
||||
Code and content within the `autogpt_platform` folder is licensed under the Polyform Shield License. This new project is our in-developlemt platform for building, deploying and managing agents.
|
||||
Read more about this effort here: https://agpt.co/blog/introducing-the-autogpt-platform
|
||||
|
||||
**MIT License:**
|
||||
All other portions of the AutoGPT repository (i.e., everything outside the `autogpt_platform` folder) are licensed under the MIT License. This includes:
|
||||
- The Original, stand-alone AutoGPT Agent
|
||||
- Forge: https://github.com/Significant-Gravitas/AutoGPT/tree/master/classic/forge
|
||||
- AG Benchmark: https://github.com/Significant-Gravitas/AutoGPT/tree/master/classic/benchmark
|
||||
- AutoGPT Classic GUI: https://github.com/Significant-Gravitas/AutoGPT/tree/master/classic/frontend.
|
||||
|
||||
We also publish additional work under the MIT Licence in other repositories, such as GravitasML (https://github.com/Significant-Gravitas/gravitasml) which is developed for and used in the AutoGPT Platform, and our [Code Ability](https://github.com/Significant-Gravitas/AutoGPT-Code-Ability) project.
|
||||
|
||||
Both licences are available to read below:
|
||||
|
||||
=====================================================
|
||||
-----------------------------------------------------
|
||||
=====================================================
|
||||
|
||||
# PolyForm Shield License 1.0.0
|
||||
|
||||
<https://polyformproject.org/licenses/shield/1.0.0>
|
||||
|
||||
## Acceptance
|
||||
|
||||
In order to get any license under these terms, you must agree
|
||||
to them as both strict obligations and conditions to all
|
||||
your licenses.
|
||||
|
||||
## Copyright License
|
||||
|
||||
The licensor grants you a copyright license for the
|
||||
software to do everything you might do with the software
|
||||
that would otherwise infringe the licensor's copyright
|
||||
in it for any permitted purpose. However, you may
|
||||
only distribute the software according to [Distribution
|
||||
License](#distribution-license) and make changes or new works
|
||||
based on the software according to [Changes and New Works
|
||||
License](#changes-and-new-works-license).
|
||||
|
||||
## Distribution License
|
||||
|
||||
The licensor grants you an additional copyright license
|
||||
to distribute copies of the software. Your license
|
||||
to distribute covers distributing the software with
|
||||
changes and new works permitted by [Changes and New Works
|
||||
License](#changes-and-new-works-license).
|
||||
|
||||
## Notices
|
||||
|
||||
You must ensure that anyone who gets a copy of any part of
|
||||
the software from you also gets a copy of these terms or the
|
||||
URL for them above, as well as copies of any plain-text lines
|
||||
beginning with `Required Notice:` that the licensor provided
|
||||
with the software. For example:
|
||||
|
||||
> Required Notice: Copyright Yoyodyne, Inc. (http://example.com)
|
||||
|
||||
## Changes and New Works License
|
||||
|
||||
The licensor grants you an additional copyright license to
|
||||
make changes and new works based on the software for any
|
||||
permitted purpose.
|
||||
|
||||
## Patent License
|
||||
|
||||
The licensor grants you a patent license for the software that
|
||||
covers patent claims the licensor can license, or becomes able
|
||||
to license, that you would infringe by using the software.
|
||||
|
||||
## Noncompete
|
||||
|
||||
Any purpose is a permitted purpose, except for providing any
|
||||
product that competes with the software or any product the
|
||||
licensor or any of its affiliates provides using the software.
|
||||
|
||||
## Competition
|
||||
|
||||
Goods and services compete even when they provide functionality
|
||||
through different kinds of interfaces or for different technical
|
||||
platforms. Applications can compete with services, libraries
|
||||
with plugins, frameworks with development tools, and so on,
|
||||
even if they're written in different programming languages
|
||||
or for different computer architectures. Goods and services
|
||||
compete even when provided free of charge. If you market a
|
||||
product as a practical substitute for the software or another
|
||||
product, it definitely competes.
|
||||
|
||||
## New Products
|
||||
|
||||
If you are using the software to provide a product that does
|
||||
not compete, but the licensor or any of its affiliates brings
|
||||
your product into competition by providing a new version of
|
||||
the software or another product using the software, you may
|
||||
continue using versions of the software available under these
|
||||
terms beforehand to provide your competing product, but not
|
||||
any later versions.
|
||||
|
||||
## Discontinued Products
|
||||
|
||||
You may begin using the software to compete with a product
|
||||
or service that the licensor or any of its affiliates has
|
||||
stopped providing, unless the licensor includes a plain-text
|
||||
line beginning with `Licensor Line of Business:` with the
|
||||
software that mentions that line of business. For example:
|
||||
|
||||
> Licensor Line of Business: YoyodyneCMS Content Management
|
||||
System (http://example.com/cms)
|
||||
|
||||
## Sales of Business
|
||||
|
||||
If the licensor or any of its affiliates sells a line of
|
||||
business developing the software or using the software
|
||||
to provide a product, the buyer can also enforce
|
||||
[Noncompete](#noncompete) for that product.
|
||||
|
||||
## Fair Use
|
||||
|
||||
You may have "fair use" rights for the software under the
|
||||
law. These terms do not limit them.
|
||||
|
||||
## No Other Rights
|
||||
|
||||
These terms do not allow you to sublicense or transfer any of
|
||||
your licenses to anyone else, or prevent the licensor from
|
||||
granting licenses to anyone else. These terms do not imply
|
||||
any other licenses.
|
||||
|
||||
## Patent Defense
|
||||
|
||||
If you make any written claim that the software infringes or
|
||||
contributes to infringement of any patent, your patent license
|
||||
for the software granted under these terms ends immediately. If
|
||||
your company makes such a claim, your patent license ends
|
||||
immediately for work on behalf of your company.
|
||||
|
||||
## Violations
|
||||
|
||||
The first time you are notified in writing that you have
|
||||
violated any of these terms, or done anything with the software
|
||||
not covered by your licenses, your licenses can nonetheless
|
||||
continue if you come into full compliance with these terms,
|
||||
and take practical steps to correct past violations, within
|
||||
32 days of receiving notice. Otherwise, all your licenses
|
||||
end immediately.
|
||||
|
||||
## No Liability
|
||||
|
||||
***As far as the law allows, the software comes as is, without
|
||||
any warranty or condition, and the licensor will not be liable
|
||||
to you for any damages arising out of these terms or the use
|
||||
or nature of the software, under any kind of legal claim.***
|
||||
|
||||
## Definitions
|
||||
|
||||
The **licensor** is the individual or entity offering these
|
||||
terms, and the **software** is the software the licensor makes
|
||||
available under these terms.
|
||||
|
||||
A **product** can be a good or service, or a combination
|
||||
of them.
|
||||
|
||||
**You** refers to the individual or entity agreeing to these
|
||||
terms.
|
||||
|
||||
**Your company** is any legal entity, sole proprietorship,
|
||||
or other kind of organization that you work for, plus all
|
||||
its affiliates.
|
||||
|
||||
**Affiliates** means the other organizations than an
|
||||
organization has control over, is under the control of, or is
|
||||
under common control with.
|
||||
|
||||
**Control** means ownership of substantially all the assets of
|
||||
an entity, or the power to direct its management and policies
|
||||
by vote, contract, or otherwise. Control can be direct or
|
||||
indirect.
|
||||
|
||||
**Your licenses** are all the licenses granted to you for the
|
||||
software under these terms.
|
||||
|
||||
**Use** means anything you do with the software requiring one
|
||||
of your licenses.
|
||||
|
||||
=====================================================
|
||||
-----------------------------------------------------
|
||||
=====================================================
|
||||
|
||||
MIT License
|
||||
|
||||
|
||||
92
README.md
92
README.md
@@ -1,82 +1,24 @@
|
||||
# AutoGPT: Build, Deploy, and Run AI Agents
|
||||
|
||||
[](https://discord.gg/autogpt)  
|
||||
[](https://discord.gg/autogpt)  
|
||||
[](https://twitter.com/Auto_GPT)  
|
||||
|
||||
<!-- Keep these links. Translations will automatically update with the README. -->
|
||||
[Deutsch](https://zdoc.app/de/Significant-Gravitas/AutoGPT) |
|
||||
[Español](https://zdoc.app/es/Significant-Gravitas/AutoGPT) |
|
||||
[français](https://zdoc.app/fr/Significant-Gravitas/AutoGPT) |
|
||||
[日本語](https://zdoc.app/ja/Significant-Gravitas/AutoGPT) |
|
||||
[한국어](https://zdoc.app/ko/Significant-Gravitas/AutoGPT) |
|
||||
[Português](https://zdoc.app/pt/Significant-Gravitas/AutoGPT) |
|
||||
[Русский](https://zdoc.app/ru/Significant-Gravitas/AutoGPT) |
|
||||
[中文](https://zdoc.app/zh/Significant-Gravitas/AutoGPT)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
|
||||
**AutoGPT** is a powerful platform that allows you to create, deploy, and manage continuous AI agents that automate complex workflows.
|
||||
|
||||
## Hosting Options
|
||||
- Download to self-host (Free!)
|
||||
- [Join the Waitlist](https://bit.ly/3ZDijAI) for the cloud-hosted beta (Closed Beta - Public release Coming Soon!)
|
||||
- Download to self-host
|
||||
- [Join the Waitlist](https://bit.ly/3ZDijAI) for the cloud-hosted beta
|
||||
|
||||
## How to Self-Host the AutoGPT Platform
|
||||
## How to Setup for Self-Hosting
|
||||
> [!NOTE]
|
||||
> Setting up and hosting the AutoGPT Platform yourself is a technical process.
|
||||
> If you'd rather something that just works, we recommend [joining the waitlist](https://bit.ly/3ZDijAI) for the cloud-hosted beta.
|
||||
|
||||
### System Requirements
|
||||
|
||||
Before proceeding with the installation, ensure your system meets the following requirements:
|
||||
|
||||
#### Hardware Requirements
|
||||
- CPU: 4+ cores recommended
|
||||
- RAM: Minimum 8GB, 16GB recommended
|
||||
- Storage: At least 10GB of free space
|
||||
|
||||
#### Software Requirements
|
||||
- Operating Systems:
|
||||
- Linux (Ubuntu 20.04 or newer recommended)
|
||||
- macOS (10.15 or newer)
|
||||
- Windows 10/11 with WSL2
|
||||
- Required Software (with minimum versions):
|
||||
- Docker Engine (20.10.0 or newer)
|
||||
- Docker Compose (2.0.0 or newer)
|
||||
- Git (2.30 or newer)
|
||||
- Node.js (16.x or newer)
|
||||
- npm (8.x or newer)
|
||||
- VSCode (1.60 or newer) or any modern code editor
|
||||
|
||||
#### Network Requirements
|
||||
- Stable internet connection
|
||||
- Access to required ports (will be configured in Docker)
|
||||
- Ability to make outbound HTTPS connections
|
||||
|
||||
### Updated Setup Instructions:
|
||||
We've moved to a fully maintained and regularly updated documentation site.
|
||||
|
||||
👉 [Follow the official self-hosting guide here](https://docs.agpt.co/platform/getting-started/)
|
||||
|
||||
https://github.com/user-attachments/assets/d04273a5-b36a-4a37-818e-f631ce72d603
|
||||
|
||||
This tutorial assumes you have Docker, VSCode, git and npm installed.
|
||||
|
||||
---
|
||||
|
||||
#### ⚡ Quick Setup with One-Line Script (Recommended for Local Hosting)
|
||||
|
||||
Skip the manual steps and get started in minutes using our automatic setup script.
|
||||
|
||||
For macOS/Linux:
|
||||
```
|
||||
curl -fsSL https://setup.agpt.co/install.sh -o install.sh && bash install.sh
|
||||
```
|
||||
|
||||
For Windows (PowerShell):
|
||||
```
|
||||
powershell -c "iwr https://setup.agpt.co/install.bat -o install.bat; ./install.bat"
|
||||
```
|
||||
|
||||
This will install dependencies, configure Docker, and launch your local instance — all in one go.
|
||||
|
||||
### 🧱 AutoGPT Frontend
|
||||
|
||||
The AutoGPT frontend is where users interact with our powerful AI automation platform. It offers multiple ways to engage with and leverage our AI agents. This is the interface where you'll bring your AI automation ideas to life:
|
||||
@@ -123,17 +65,7 @@ Here are two examples of what you can do with AutoGPT:
|
||||
These examples show just a glimpse of what you can achieve with AutoGPT! You can create customized workflows to build agents for any use case.
|
||||
|
||||
---
|
||||
|
||||
### **License Overview:**
|
||||
|
||||
🛡️ **Polyform Shield License:**
|
||||
All code and content within the `autogpt_platform` folder is licensed under the Polyform Shield License. This new project is our in-developlemt platform for building, deploying and managing agents.</br>_[Read more about this effort](https://agpt.co/blog/introducing-the-autogpt-platform)_
|
||||
|
||||
🦉 **MIT License:**
|
||||
All other portions of the AutoGPT repository (i.e., everything outside the `autogpt_platform` folder) are licensed under the MIT License. This includes the original stand-alone AutoGPT Agent, along with projects such as [Forge](https://github.com/Significant-Gravitas/AutoGPT/tree/master/classic/forge), [agbenchmark](https://github.com/Significant-Gravitas/AutoGPT/tree/master/classic/benchmark) and the [AutoGPT Classic GUI](https://github.com/Significant-Gravitas/AutoGPT/tree/master/classic/frontend).</br>We also publish additional work under the MIT Licence in other repositories, such as [GravitasML](https://github.com/Significant-Gravitas/gravitasml) which is developed for and used in the AutoGPT Platform. See also our MIT Licenced [Code Ability](https://github.com/Significant-Gravitas/AutoGPT-Code-Ability) project.
|
||||
|
||||
---
|
||||
### Mission
|
||||
### Mission and Licencing
|
||||
Our mission is to provide the tools, so that you can focus on what matters:
|
||||
|
||||
- 🏗️ **Building** - Lay the foundation for something amazing.
|
||||
@@ -146,6 +78,14 @@ Be part of the revolution! **AutoGPT** is here to stay, at the forefront of AI i
|
||||
 | 
|
||||
**🚀 [Contributing](CONTRIBUTING.md)**
|
||||
|
||||
**Licensing:**
|
||||
|
||||
MIT License: The majority of the AutoGPT repository is under the MIT License.
|
||||
|
||||
Polyform Shield License: This license applies to the autogpt_platform folder.
|
||||
|
||||
For more information, see https://agpt.co/blog/introducing-the-autogpt-platform
|
||||
|
||||
---
|
||||
## 🤖 AutoGPT Classic
|
||||
> Below is information about the classic version of AutoGPT.
|
||||
@@ -208,7 +148,7 @@ Just clone the repo, install dependencies with `./run setup`, and you should be
|
||||
|
||||
[](https://discord.gg/autogpt)
|
||||
|
||||
To report a bug or request a feature, create a [GitHub Issue](https://github.com/Significant-Gravitas/AutoGPT/issues/new/choose). Please ensure someone else hasn't created an issue for the same topic.
|
||||
To report a bug or request a feature, create a [GitHub Issue](https://github.com/Significant-Gravitas/AutoGPT/issues/new/choose). Please ensure someone else hasn’t created an issue for the same topic.
|
||||
|
||||
## 🤝 Sister projects
|
||||
|
||||
|
||||
@@ -1,123 +0,0 @@
|
||||
############
|
||||
# Secrets
|
||||
# YOU MUST CHANGE THESE BEFORE GOING INTO PRODUCTION
|
||||
############
|
||||
|
||||
POSTGRES_PASSWORD=your-super-secret-and-long-postgres-password
|
||||
JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
ANON_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJhbm9uIiwKICAgICJpc3MiOiAic3VwYWJhc2UtZGVtbyIsCiAgICAiaWF0IjogMTY0MTc2OTIwMCwKICAgICJleHAiOiAxNzk5NTM1NjAwCn0.dc_X5iR_VP_qT0zsiyj_I_OZ2T9FtRU2BBNWN8Bu4GE
|
||||
SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
|
||||
DASHBOARD_USERNAME=supabase
|
||||
DASHBOARD_PASSWORD=this_password_is_insecure_and_should_be_updated
|
||||
SECRET_KEY_BASE=UpNVntn3cDxHJpq99YMc1T1AQgQpc8kfYTuRgBiYa15BLrx8etQoXz3gZv1/u2oq
|
||||
VAULT_ENC_KEY=your-encryption-key-32-chars-min
|
||||
|
||||
|
||||
############
|
||||
# Database - You can change these to any PostgreSQL database that has logical replication enabled.
|
||||
############
|
||||
|
||||
POSTGRES_HOST=db
|
||||
POSTGRES_DB=postgres
|
||||
POSTGRES_PORT=5432
|
||||
# default user is postgres
|
||||
|
||||
|
||||
############
|
||||
# Supavisor -- Database pooler
|
||||
############
|
||||
POOLER_PROXY_PORT_TRANSACTION=6543
|
||||
POOLER_DEFAULT_POOL_SIZE=20
|
||||
POOLER_MAX_CLIENT_CONN=100
|
||||
POOLER_TENANT_ID=your-tenant-id
|
||||
|
||||
|
||||
############
|
||||
# API Proxy - Configuration for the Kong Reverse proxy.
|
||||
############
|
||||
|
||||
KONG_HTTP_PORT=8000
|
||||
KONG_HTTPS_PORT=8443
|
||||
|
||||
|
||||
############
|
||||
# API - Configuration for PostgREST.
|
||||
############
|
||||
|
||||
PGRST_DB_SCHEMAS=public,storage,graphql_public
|
||||
|
||||
|
||||
############
|
||||
# Auth - Configuration for the GoTrue authentication server.
|
||||
############
|
||||
|
||||
## General
|
||||
SITE_URL=http://localhost:3000
|
||||
ADDITIONAL_REDIRECT_URLS=
|
||||
JWT_EXPIRY=3600
|
||||
DISABLE_SIGNUP=false
|
||||
API_EXTERNAL_URL=http://localhost:8000
|
||||
|
||||
## Mailer Config
|
||||
MAILER_URLPATHS_CONFIRMATION="/auth/v1/verify"
|
||||
MAILER_URLPATHS_INVITE="/auth/v1/verify"
|
||||
MAILER_URLPATHS_RECOVERY="/auth/v1/verify"
|
||||
MAILER_URLPATHS_EMAIL_CHANGE="/auth/v1/verify"
|
||||
|
||||
## Email auth
|
||||
ENABLE_EMAIL_SIGNUP=true
|
||||
ENABLE_EMAIL_AUTOCONFIRM=false
|
||||
SMTP_ADMIN_EMAIL=admin@example.com
|
||||
SMTP_HOST=supabase-mail
|
||||
SMTP_PORT=2500
|
||||
SMTP_USER=fake_mail_user
|
||||
SMTP_PASS=fake_mail_password
|
||||
SMTP_SENDER_NAME=fake_sender
|
||||
ENABLE_ANONYMOUS_USERS=false
|
||||
|
||||
## Phone auth
|
||||
ENABLE_PHONE_SIGNUP=true
|
||||
ENABLE_PHONE_AUTOCONFIRM=true
|
||||
|
||||
|
||||
############
|
||||
# Studio - Configuration for the Dashboard
|
||||
############
|
||||
|
||||
STUDIO_DEFAULT_ORGANIZATION=Default Organization
|
||||
STUDIO_DEFAULT_PROJECT=Default Project
|
||||
|
||||
STUDIO_PORT=3000
|
||||
# replace if you intend to use Studio outside of localhost
|
||||
SUPABASE_PUBLIC_URL=http://localhost:8000
|
||||
|
||||
# Enable webp support
|
||||
IMGPROXY_ENABLE_WEBP_DETECTION=true
|
||||
|
||||
# Add your OpenAI API key to enable SQL Editor Assistant
|
||||
OPENAI_API_KEY=
|
||||
|
||||
|
||||
############
|
||||
# Functions - Configuration for Functions
|
||||
############
|
||||
# NOTE: VERIFY_JWT applies to all functions. Per-function VERIFY_JWT is not supported yet.
|
||||
FUNCTIONS_VERIFY_JWT=false
|
||||
|
||||
|
||||
############
|
||||
# Logs - Configuration for Logflare
|
||||
# Please refer to https://supabase.com/docs/reference/self-hosting-analytics/introduction
|
||||
############
|
||||
|
||||
LOGFLARE_LOGGER_BACKEND_API_KEY=your-super-secret-and-long-logflare-key
|
||||
|
||||
# Change vector.toml sinks to reflect this change
|
||||
LOGFLARE_API_KEY=your-super-secret-and-long-logflare-key
|
||||
|
||||
# Docker socket location - this value will differ depending on your OS
|
||||
DOCKER_SOCKET_LOCATION=/var/run/docker.sock
|
||||
|
||||
# Google Cloud Project details
|
||||
GOOGLE_PROJECT_ID=GOOGLE_PROJECT_ID
|
||||
GOOGLE_PROJECT_NUMBER=GOOGLE_PROJECT_NUMBER
|
||||
@@ -1,230 +0,0 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Repository Overview
|
||||
|
||||
AutoGPT Platform is a monorepo containing:
|
||||
|
||||
- **Backend** (`/backend`): Python FastAPI server with async support
|
||||
- **Frontend** (`/frontend`): Next.js React application
|
||||
- **Shared Libraries** (`/autogpt_libs`): Common Python utilities
|
||||
|
||||
## Essential Commands
|
||||
|
||||
### Backend Development
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
cd backend && poetry install
|
||||
|
||||
# Run database migrations
|
||||
poetry run prisma migrate dev
|
||||
|
||||
# Start all services (database, redis, rabbitmq, clamav)
|
||||
docker compose up -d
|
||||
|
||||
# Run the backend server
|
||||
poetry run serve
|
||||
|
||||
# Run tests
|
||||
poetry run test
|
||||
|
||||
# Run specific test
|
||||
poetry run pytest path/to/test_file.py::test_function_name
|
||||
|
||||
# Run block tests (tests that validate all blocks work correctly)
|
||||
poetry run pytest backend/blocks/test/test_block.py -xvs
|
||||
|
||||
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
|
||||
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
|
||||
|
||||
# Lint and format
|
||||
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
|
||||
poetry run format # Black + isort
|
||||
poetry run lint # ruff
|
||||
```
|
||||
|
||||
More details can be found in TESTING.md
|
||||
|
||||
#### Creating/Updating Snapshots
|
||||
|
||||
When you first write a test or when the expected output changes:
|
||||
|
||||
```bash
|
||||
poetry run pytest path/to/test.py --snapshot-update
|
||||
```
|
||||
|
||||
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
|
||||
|
||||
### Frontend Development
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
cd frontend && npm install
|
||||
|
||||
# Start development server
|
||||
npm run dev
|
||||
|
||||
# Run E2E tests
|
||||
npm run test
|
||||
|
||||
# Run Storybook for component development
|
||||
npm run storybook
|
||||
|
||||
# Build production
|
||||
npm run build
|
||||
|
||||
# Type checking
|
||||
npm run types
|
||||
```
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
### Backend Architecture
|
||||
|
||||
- **API Layer**: FastAPI with REST and WebSocket endpoints
|
||||
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
|
||||
- **Queue System**: RabbitMQ for async task processing
|
||||
- **Execution Engine**: Separate executor service processes agent workflows
|
||||
- **Authentication**: JWT-based with Supabase integration
|
||||
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
||||
|
||||
### Frontend Architecture
|
||||
|
||||
- **Framework**: Next.js App Router with React Server Components
|
||||
- **State Management**: React hooks + Supabase client for real-time updates
|
||||
- **Workflow Builder**: Visual graph editor using @xyflow/react
|
||||
- **UI Components**: Radix UI primitives with Tailwind CSS styling
|
||||
- **Feature Flags**: LaunchDarkly integration
|
||||
|
||||
### Key Concepts
|
||||
|
||||
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
|
||||
2. **Blocks**: Reusable components in `/backend/blocks/` that perform specific tasks
|
||||
3. **Integrations**: OAuth and API connections stored per user
|
||||
4. **Store**: Marketplace for sharing agent templates
|
||||
5. **Virus Scanning**: ClamAV integration for file upload security
|
||||
|
||||
### Testing Approach
|
||||
|
||||
- Backend uses pytest with snapshot testing for API responses
|
||||
- Test files are colocated with source files (`*_test.py`)
|
||||
- Frontend uses Playwright for E2E tests
|
||||
- Component testing via Storybook
|
||||
|
||||
### Database Schema
|
||||
|
||||
Key models (defined in `/backend/schema.prisma`):
|
||||
|
||||
- `User`: Authentication and profile data
|
||||
- `AgentGraph`: Workflow definitions with version control
|
||||
- `AgentGraphExecution`: Execution history and results
|
||||
- `AgentNode`: Individual nodes in a workflow
|
||||
- `StoreListing`: Marketplace listings for sharing agents
|
||||
|
||||
### Environment Configuration
|
||||
|
||||
#### Configuration Files
|
||||
|
||||
- **Backend**: `/backend/.env.default` (defaults) → `/backend/.env` (user overrides)
|
||||
- **Frontend**: `/frontend/.env.default` (defaults) → `/frontend/.env` (user overrides)
|
||||
- **Platform**: `/.env.default` (Supabase/shared defaults) → `/.env` (user overrides)
|
||||
|
||||
#### Docker Environment Loading Order
|
||||
|
||||
1. `.env.default` files provide base configuration (tracked in git)
|
||||
2. `.env` files provide user-specific overrides (gitignored)
|
||||
3. Docker Compose `environment:` sections provide service-specific overrides
|
||||
4. Shell environment variables have highest precedence
|
||||
|
||||
#### Key Points
|
||||
|
||||
- All services use hardcoded defaults in docker-compose files (no `${VARIABLE}` substitutions)
|
||||
- The `env_file` directive loads variables INTO containers at runtime
|
||||
- Backend/Frontend services use YAML anchors for consistent configuration
|
||||
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
|
||||
|
||||
### Common Development Tasks
|
||||
|
||||
**Adding a new block:**
|
||||
|
||||
1. Create new file in `/backend/backend/blocks/`
|
||||
2. Inherit from `Block` base class
|
||||
3. Define input/output schemas
|
||||
4. Implement `run` method
|
||||
5. Register in block registry
|
||||
6. Generate the block uuid using `uuid.uuid4()`
|
||||
|
||||
Note: when making many new blocks analyze the interfaces for each of these blcoks and picture if they would go well together in a graph based editor or would they struggle to connect productively?
|
||||
ex: do the inputs and outputs tie well together?
|
||||
|
||||
**Modifying the API:**
|
||||
|
||||
1. Update route in `/backend/backend/server/routers/`
|
||||
2. Add/update Pydantic models in same directory
|
||||
3. Write tests alongside the route file
|
||||
4. Run `poetry run test` to verify
|
||||
|
||||
**Frontend feature development:**
|
||||
|
||||
1. Components go in `/frontend/src/components/`
|
||||
2. Use existing UI components from `/frontend/src/components/ui/`
|
||||
3. Add Storybook stories for new components
|
||||
4. Test with Playwright if user-facing
|
||||
|
||||
### Security Implementation
|
||||
|
||||
**Cache Protection Middleware:**
|
||||
|
||||
- Located in `/backend/backend/server/middleware/security.py`
|
||||
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
||||
- Uses an allow list approach - only explicitly permitted paths can be cached
|
||||
- Cacheable paths include: static assets (`/static/*`, `/_next/static/*`), health checks, public store pages, documentation
|
||||
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
|
||||
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
|
||||
- Applied to both main API server and external API applications
|
||||
|
||||
### Creating Pull Requests
|
||||
|
||||
- Create the PR aginst the `dev` branch of the repository.
|
||||
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)/
|
||||
- Use conventional commit messages (see below)/
|
||||
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description/
|
||||
- Run the github pre-commit hooks to ensure code quality.
|
||||
|
||||
### Reviewing/Revising Pull Requests
|
||||
|
||||
- When the user runs /pr-comments or tries to fetch them, also run gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews to get the reviews
|
||||
- Use gh api /repos/Significant-Gravitas/AutoGPT/pulls/[issuenum]/reviews/[review_id]/comments to get the review contents
|
||||
- Use gh api /repos/Significant-Gravitas/AutoGPT/issues/9924/comments to get the pr specific comments
|
||||
|
||||
### Conventional Commits
|
||||
|
||||
Use this format for commit messages and Pull Request titles:
|
||||
|
||||
**Conventional Commit Types:**
|
||||
|
||||
- `feat`: Introduces a new feature to the codebase
|
||||
- `fix`: Patches a bug in the codebase
|
||||
- `refactor`: Code change that neither fixes a bug nor adds a feature; also applies to removing features
|
||||
- `ci`: Changes to CI configuration
|
||||
- `docs`: Documentation-only changes
|
||||
- `dx`: Improvements to the developer experience
|
||||
|
||||
**Recommended Base Scopes:**
|
||||
|
||||
- `platform`: Changes affecting both frontend and backend
|
||||
- `frontend`
|
||||
- `backend`
|
||||
- `infra`
|
||||
- `blocks`: Modifications/additions of individual blocks
|
||||
|
||||
**Subscope Examples:**
|
||||
|
||||
- `backend/executor`
|
||||
- `backend/db`
|
||||
- `frontend/builder` (includes changes to the block UI component)
|
||||
- `infra/prod`
|
||||
|
||||
Use these scopes and subscopes for clarity and consistency in commit messages.
|
||||
@@ -8,35 +8,60 @@ Welcome to the AutoGPT Platform - a powerful system for creating and running AI
|
||||
|
||||
- Docker
|
||||
- Docker Compose V2 (comes with Docker Desktop, or can be installed separately)
|
||||
- Node.js & NPM (for running the frontend application)
|
||||
|
||||
### Running the System
|
||||
|
||||
To run the AutoGPT Platform, follow these steps:
|
||||
|
||||
1. Clone this repository to your local machine and navigate to the `autogpt_platform` directory within the repository:
|
||||
|
||||
```
|
||||
git clone <https://github.com/Significant-Gravitas/AutoGPT.git | git@github.com:Significant-Gravitas/AutoGPT.git>
|
||||
cd AutoGPT/autogpt_platform
|
||||
```
|
||||
|
||||
2. Run the following command:
|
||||
|
||||
```
|
||||
cp .env.default .env
|
||||
git submodule update --init --recursive --progress
|
||||
```
|
||||
|
||||
This command will copy the `.env.default` file to `.env`. You can modify the `.env` file to add your own environment variables.
|
||||
This command will initialize and update the submodules in the repository. The `supabase` folder will be cloned to the root directory.
|
||||
|
||||
3. Run the following command:
|
||||
```
|
||||
cp supabase/docker/.env.example .env
|
||||
```
|
||||
This command will copy the `.env.example` file to `.env` in the `supabase/docker` directory. You can modify the `.env` file to add your own environment variables.
|
||||
|
||||
4. Run the following command:
|
||||
```
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
This command will start all the necessary backend services defined in the `docker-compose.yml` file in detached mode.
|
||||
|
||||
4. After all the services are in ready state, open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
|
||||
5. Navigate to `frontend` within the `autogpt_platform` directory:
|
||||
```
|
||||
cd frontend
|
||||
```
|
||||
You will need to run your frontend application separately on your local machine.
|
||||
|
||||
6. Run the following command:
|
||||
```
|
||||
cp .env.example .env.local
|
||||
```
|
||||
This command will copy the `.env.example` file to `.env.local` in the `frontend` directory. You can modify the `.env.local` within this folder to add your own environment variables for the frontend application.
|
||||
|
||||
7. Run the following command:
|
||||
```
|
||||
npm install
|
||||
npm run dev
|
||||
```
|
||||
This command will install the necessary dependencies and start the frontend application in development mode.
|
||||
If you are using Yarn, you can run the following commands instead:
|
||||
```
|
||||
yarn install && yarn dev
|
||||
```
|
||||
|
||||
8. Open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
|
||||
|
||||
### Docker Compose Commands
|
||||
|
||||
@@ -49,52 +74,43 @@ Here are some useful Docker Compose commands for managing your AutoGPT Platform:
|
||||
- `docker compose down`: Stop and remove containers, networks, and volumes.
|
||||
- `docker compose watch`: Watch for changes in your services and automatically update them.
|
||||
|
||||
|
||||
### Sample Scenarios
|
||||
|
||||
Here are some common scenarios where you might use multiple Docker Compose commands:
|
||||
|
||||
1. Updating and restarting a specific service:
|
||||
|
||||
```
|
||||
docker compose build api_srv
|
||||
docker compose up -d --no-deps api_srv
|
||||
```
|
||||
|
||||
This rebuilds the `api_srv` service and restarts it without affecting other services.
|
||||
|
||||
2. Viewing logs for troubleshooting:
|
||||
|
||||
```
|
||||
docker compose logs -f api_srv ws_srv
|
||||
```
|
||||
|
||||
This shows and follows the logs for both `api_srv` and `ws_srv` services.
|
||||
|
||||
3. Scaling a service for increased load:
|
||||
|
||||
```
|
||||
docker compose up -d --scale executor=3
|
||||
```
|
||||
|
||||
This scales the `executor` service to 3 instances to handle increased load.
|
||||
|
||||
4. Stopping the entire system for maintenance:
|
||||
|
||||
```
|
||||
docker compose stop
|
||||
docker compose rm -f
|
||||
docker compose pull
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
This stops all services, removes containers, pulls the latest images, and restarts the system.
|
||||
|
||||
5. Developing with live updates:
|
||||
|
||||
```
|
||||
docker compose watch
|
||||
```
|
||||
|
||||
This watches for changes in your code and automatically updates the relevant services.
|
||||
|
||||
6. Checking the status of services:
|
||||
@@ -105,6 +121,7 @@ Here are some common scenarios where you might use multiple Docker Compose comma
|
||||
|
||||
These scenarios demonstrate how to use Docker Compose commands in combination to manage your AutoGPT Platform effectively.
|
||||
|
||||
|
||||
### Persisting Data
|
||||
|
||||
To persist data for PostgreSQL and Redis, you can modify the `docker-compose.yml` file to add volumes. Here's how:
|
||||
@@ -132,28 +149,3 @@ To persist data for PostgreSQL and Redis, you can modify the `docker-compose.yml
|
||||
3. Save the file and run `docker compose up -d` to apply the changes.
|
||||
|
||||
This configuration will create named volumes for PostgreSQL and Redis, ensuring that your data persists across container restarts.
|
||||
|
||||
### API Client Generation
|
||||
|
||||
The platform includes scripts for generating and managing the API client:
|
||||
|
||||
- `pnpm fetch:openapi`: Fetches the OpenAPI specification from the backend service (requires backend to be running on port 8006)
|
||||
- `pnpm generate:api-client`: Generates the TypeScript API client from the OpenAPI specification using Orval
|
||||
- `pnpm generate:api`: Runs both fetch and generate commands in sequence
|
||||
|
||||
#### Manual API Client Updates
|
||||
|
||||
If you need to update the API client after making changes to the backend API:
|
||||
|
||||
1. Ensure the backend services are running:
|
||||
|
||||
```
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
2. Generate the updated API client:
|
||||
```
|
||||
pnpm generate:api
|
||||
```
|
||||
|
||||
This will fetch the latest OpenAPI specification and regenerate the TypeScript client code.
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# AutoGPT Libs
|
||||
|
||||
This is a new project to store shared functionality across different services in the AutoGPT Platform (e.g. authentication)
|
||||
This is a new project to store shared functionality across different services in NextGen AutoGPT (e.g. authentication)
|
||||
|
||||
@@ -31,5 +31,4 @@ class APIKeyManager:
|
||||
"""Verify if a provided API key matches the stored hash."""
|
||||
if not provided_key.startswith(self.PREFIX):
|
||||
return False
|
||||
provided_hash = hashlib.sha256(provided_key.encode()).hexdigest()
|
||||
return secrets.compare_digest(provided_hash, stored_hash)
|
||||
return hashlib.sha256(provided_key.encode()).hexdigest() == stored_hash
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from .config import Settings
|
||||
from .depends import requires_admin_user, requires_user
|
||||
from .jwt_utils import parse_jwt_token
|
||||
from .middleware import APIKeyValidator, auth_middleware
|
||||
from .models import User
|
||||
|
||||
__all__ = [
|
||||
"Settings",
|
||||
"parse_jwt_token",
|
||||
"requires_user",
|
||||
"requires_admin_user",
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class Settings:
|
||||
def __init__(self):
|
||||
self.JWT_SECRET_KEY: str = os.getenv("SUPABASE_JWT_SECRET", "")
|
||||
self.ENABLE_AUTH: bool = os.getenv("ENABLE_AUTH", "false").lower() == "true"
|
||||
self.JWT_ALGORITHM: str = "HS256"
|
||||
JWT_SECRET_KEY: str = os.getenv("SUPABASE_JWT_SECRET", "")
|
||||
ENABLE_AUTH: bool = os.getenv("ENABLE_AUTH", "false").lower() == "true"
|
||||
JWT_ALGORITHM: str = "HS256"
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
return bool(self.JWT_SECRET_KEY)
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import fastapi
|
||||
|
||||
from .config import settings
|
||||
from .config import Settings
|
||||
from .middleware import auth_middleware
|
||||
from .models import DEFAULT_USER_ID, User
|
||||
|
||||
@@ -17,7 +17,7 @@ def requires_admin_user(
|
||||
|
||||
def verify_user(payload: dict | None, admin_only: bool) -> User:
|
||||
if not payload:
|
||||
if settings.ENABLE_AUTH:
|
||||
if Settings.ENABLE_AUTH:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=401, detail="Authorization header is missing"
|
||||
)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import inspect
|
||||
import logging
|
||||
import secrets
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from fastapi import HTTPException, Request, Security
|
||||
@@ -17,7 +16,7 @@ logger = logging.getLogger(__name__)
|
||||
async def auth_middleware(request: Request):
|
||||
if not settings.ENABLE_AUTH:
|
||||
# If authentication is disabled, allow the request to proceed
|
||||
logger.warning("Auth disabled")
|
||||
logger.warn("Auth disabled")
|
||||
return {}
|
||||
|
||||
security = HTTPBearer()
|
||||
@@ -94,11 +93,7 @@ class APIKeyValidator:
|
||||
self.error_message = error_message
|
||||
|
||||
async def default_validator(self, api_key: str) -> bool:
|
||||
if not self.expected_token:
|
||||
raise ValueError(
|
||||
"Expected Token Required to be set when uisng API Key Validator default validation"
|
||||
)
|
||||
return secrets.compare_digest(api_key, self.expected_token)
|
||||
return api_key == self.expected_token
|
||||
|
||||
async def __call__(
|
||||
self, request: Request, api_key: str = Security(APIKeyHeader)
|
||||
|
||||
@@ -0,0 +1,166 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
from functools import wraps
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar, Union, cast
|
||||
|
||||
import ldclient
|
||||
from fastapi import HTTPException
|
||||
from ldclient import Context, LDClient
|
||||
from ldclient.config import Config
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from .config import SETTINGS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def get_client() -> LDClient:
|
||||
"""Get the LaunchDarkly client singleton."""
|
||||
return ldclient.get()
|
||||
|
||||
|
||||
def initialize_launchdarkly() -> None:
|
||||
sdk_key = SETTINGS.launch_darkly_sdk_key
|
||||
logger.debug(
|
||||
f"Initializing LaunchDarkly with SDK key: {'present' if sdk_key else 'missing'}"
|
||||
)
|
||||
|
||||
if not sdk_key:
|
||||
logger.warning("LaunchDarkly SDK key not configured")
|
||||
return
|
||||
|
||||
config = Config(sdk_key)
|
||||
ldclient.set_config(config)
|
||||
|
||||
if ldclient.get().is_initialized():
|
||||
logger.info("LaunchDarkly client initialized successfully")
|
||||
else:
|
||||
logger.error("LaunchDarkly client failed to initialize")
|
||||
|
||||
|
||||
def shutdown_launchdarkly() -> None:
|
||||
"""Shutdown the LaunchDarkly client."""
|
||||
if ldclient.get().is_initialized():
|
||||
ldclient.get().close()
|
||||
logger.info("LaunchDarkly client closed successfully")
|
||||
|
||||
|
||||
def create_context(
|
||||
user_id: str, additional_attributes: Optional[Dict[str, Any]] = None
|
||||
) -> Context:
|
||||
"""Create LaunchDarkly context with optional additional attributes."""
|
||||
builder = Context.builder(str(user_id)).kind("user")
|
||||
if additional_attributes:
|
||||
for key, value in additional_attributes.items():
|
||||
builder.set(key, value)
|
||||
return builder.build()
|
||||
|
||||
|
||||
def feature_flag(
|
||||
flag_key: str,
|
||||
default: bool = False,
|
||||
) -> Callable[
|
||||
[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
|
||||
]:
|
||||
"""
|
||||
Decorator for feature flag protected endpoints.
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[P, Union[T, Awaitable[T]]],
|
||||
) -> Callable[P, Union[T, Awaitable[T]]]:
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
try:
|
||||
user_id = kwargs.get("user_id")
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required")
|
||||
|
||||
if not get_client().is_initialized():
|
||||
logger.warning(
|
||||
f"LaunchDarkly not initialized, using default={default}"
|
||||
)
|
||||
is_enabled = default
|
||||
else:
|
||||
context = create_context(str(user_id))
|
||||
is_enabled = get_client().variation(flag_key, context, default)
|
||||
|
||||
if not is_enabled:
|
||||
raise HTTPException(status_code=404, detail="Feature not available")
|
||||
|
||||
result = func(*args, **kwargs)
|
||||
if asyncio.iscoroutine(result):
|
||||
return await result
|
||||
return cast(T, result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error evaluating feature flag {flag_key}: {e}")
|
||||
raise
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
try:
|
||||
user_id = kwargs.get("user_id")
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required")
|
||||
|
||||
if not get_client().is_initialized():
|
||||
logger.warning(
|
||||
f"LaunchDarkly not initialized, using default={default}"
|
||||
)
|
||||
is_enabled = default
|
||||
else:
|
||||
context = create_context(str(user_id))
|
||||
is_enabled = get_client().variation(flag_key, context, default)
|
||||
|
||||
if not is_enabled:
|
||||
raise HTTPException(status_code=404, detail="Feature not available")
|
||||
|
||||
return cast(T, func(*args, **kwargs))
|
||||
except Exception as e:
|
||||
logger.error(f"Error evaluating feature flag {flag_key}: {e}")
|
||||
raise
|
||||
|
||||
return cast(
|
||||
Callable[P, Union[T, Awaitable[T]]],
|
||||
async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper,
|
||||
)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def percentage_rollout(
|
||||
flag_key: str,
|
||||
default: bool = False,
|
||||
) -> Callable[
|
||||
[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
|
||||
]:
|
||||
"""Decorator for percentage-based rollouts."""
|
||||
return feature_flag(flag_key, default)
|
||||
|
||||
|
||||
def beta_feature(
|
||||
flag_key: Optional[str] = None,
|
||||
unauthorized_response: Any = {"message": "Not available in beta"},
|
||||
) -> Callable[
|
||||
[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
|
||||
]:
|
||||
"""Decorator for beta features."""
|
||||
actual_key = f"beta-{flag_key}" if flag_key else "beta"
|
||||
return feature_flag(actual_key, False)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def mock_flag_variation(flag_key: str, return_value: Any):
|
||||
"""Context manager for testing feature flags."""
|
||||
original_variation = get_client().variation
|
||||
get_client().variation = lambda key, context, default: (
|
||||
return_value if key == flag_key else original_variation(key, context, default)
|
||||
)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
get_client().variation = original_variation
|
||||
@@ -0,0 +1,45 @@
|
||||
import pytest
|
||||
from ldclient import LDClient
|
||||
|
||||
from autogpt_libs.feature_flag.client import feature_flag, mock_flag_variation
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ld_client(mocker):
|
||||
client = mocker.Mock(spec=LDClient)
|
||||
mocker.patch("ldclient.get", return_value=client)
|
||||
client.is_initialized.return_value = True
|
||||
return client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_feature_flag_enabled(ld_client):
|
||||
ld_client.variation.return_value = True
|
||||
|
||||
@feature_flag("test-flag")
|
||||
async def test_function(user_id: str):
|
||||
return "success"
|
||||
|
||||
result = test_function(user_id="test-user")
|
||||
assert result == "success"
|
||||
ld_client.variation.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_feature_flag_unauthorized_response(ld_client):
|
||||
ld_client.variation.return_value = False
|
||||
|
||||
@feature_flag("test-flag")
|
||||
async def test_function(user_id: str):
|
||||
return "success"
|
||||
|
||||
result = test_function(user_id="test-user")
|
||||
assert result == {"error": "disabled"}
|
||||
|
||||
|
||||
def test_mock_flag_variation(ld_client):
|
||||
with mock_flag_variation("test-flag", True):
|
||||
assert ld_client.variation("test-flag", None, False)
|
||||
|
||||
with mock_flag_variation("test-flag", False):
|
||||
assert ld_client.variation("test-flag", None, False)
|
||||
@@ -0,0 +1,15 @@
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
launch_darkly_sdk_key: str = Field(
|
||||
default="",
|
||||
description="The Launch Darkly SDK key",
|
||||
validation_alias="LAUNCH_DARKLY_SDK_KEY",
|
||||
)
|
||||
|
||||
model_config = SettingsConfigDict(case_sensitive=True, extra="ignore")
|
||||
|
||||
|
||||
SETTINGS = Settings()
|
||||
@@ -1,8 +1,6 @@
|
||||
"""Logging module for Auto-GPT."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
@@ -10,16 +8,7 @@ from pydantic import Field, field_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from .filters import BelowLevelFilter
|
||||
from .formatters import AGPTFormatter
|
||||
|
||||
# Configure global socket timeout and gRPC keepalive to prevent deadlocks
|
||||
# This must be done at import time before any gRPC connections are established
|
||||
socket.setdefaulttimeout(30) # 30-second socket timeout
|
||||
|
||||
# Enable gRPC keepalive to detect dead connections faster
|
||||
os.environ.setdefault("GRPC_KEEPALIVE_TIME_MS", "30000") # 30 seconds
|
||||
os.environ.setdefault("GRPC_KEEPALIVE_TIMEOUT_MS", "5000") # 5 seconds
|
||||
os.environ.setdefault("GRPC_KEEPALIVE_PERMIT_WITHOUT_CALLS", "true")
|
||||
from .formatters import AGPTFormatter, StructuredLoggingFormatter
|
||||
|
||||
LOG_DIR = Path(__file__).parent.parent.parent.parent / "logs"
|
||||
LOG_FILE = "activity.log"
|
||||
@@ -90,45 +79,46 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
|
||||
Note: This function is typically called at the start of the application
|
||||
to set up the logging infrastructure.
|
||||
"""
|
||||
|
||||
config = LoggingConfig()
|
||||
|
||||
log_handlers: list[logging.Handler] = []
|
||||
|
||||
# Console output handlers
|
||||
stdout = logging.StreamHandler(stream=sys.stdout)
|
||||
stdout.setLevel(config.level)
|
||||
stdout.addFilter(BelowLevelFilter(logging.WARNING))
|
||||
if config.level == logging.DEBUG:
|
||||
stdout.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
|
||||
else:
|
||||
stdout.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
|
||||
|
||||
stderr = logging.StreamHandler()
|
||||
stderr.setLevel(logging.WARNING)
|
||||
if config.level == logging.DEBUG:
|
||||
stderr.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
|
||||
else:
|
||||
stderr.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
|
||||
|
||||
log_handlers += [stdout, stderr]
|
||||
|
||||
# Cloud logging setup
|
||||
if config.enable_cloud_logging or force_cloud_logging:
|
||||
import google.cloud.logging
|
||||
from google.cloud.logging.handlers import CloudLoggingHandler
|
||||
from google.cloud.logging_v2.handlers.transports import (
|
||||
BackgroundThreadTransport,
|
||||
)
|
||||
from google.cloud.logging_v2.handlers.transports.sync import SyncTransport
|
||||
|
||||
client = google.cloud.logging.Client()
|
||||
# Use BackgroundThreadTransport to prevent blocking the main thread
|
||||
# and deadlocks when gRPC calls to Google Cloud Logging hang
|
||||
cloud_handler = CloudLoggingHandler(
|
||||
client,
|
||||
name="autogpt_logs",
|
||||
transport=BackgroundThreadTransport,
|
||||
transport=SyncTransport,
|
||||
)
|
||||
cloud_handler.setLevel(config.level)
|
||||
cloud_handler.setFormatter(StructuredLoggingFormatter())
|
||||
log_handlers.append(cloud_handler)
|
||||
print("Cloud logging enabled")
|
||||
else:
|
||||
# Console output handlers
|
||||
stdout = logging.StreamHandler(stream=sys.stdout)
|
||||
stdout.setLevel(config.level)
|
||||
stdout.addFilter(BelowLevelFilter(logging.WARNING))
|
||||
if config.level == logging.DEBUG:
|
||||
stdout.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
|
||||
else:
|
||||
stdout.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
|
||||
|
||||
stderr = logging.StreamHandler()
|
||||
stderr.setLevel(logging.WARNING)
|
||||
if config.level == logging.DEBUG:
|
||||
stderr.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
|
||||
else:
|
||||
stderr.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
|
||||
|
||||
log_handlers += [stdout, stderr]
|
||||
print("Console logging enabled")
|
||||
|
||||
# File logging setup
|
||||
if config.enable_file_logging:
|
||||
@@ -166,6 +156,7 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
|
||||
error_log_handler.setLevel(logging.ERROR)
|
||||
error_log_handler.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT, no_color=True))
|
||||
log_handlers.append(error_log_handler)
|
||||
print("File logging enabled")
|
||||
|
||||
# Configure the root logger
|
||||
logging.basicConfig(
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
|
||||
from colorama import Fore, Style
|
||||
from google.cloud.logging_v2.handlers import CloudLoggingFilter, StructuredLogHandler
|
||||
|
||||
from .utils import remove_color_codes
|
||||
|
||||
@@ -79,3 +80,16 @@ class AGPTFormatter(FancyConsoleFormatter):
|
||||
return remove_color_codes(super().format(record))
|
||||
else:
|
||||
return super().format(record)
|
||||
|
||||
|
||||
class StructuredLoggingFormatter(StructuredLogHandler, logging.Formatter):
|
||||
def __init__(self):
|
||||
# Set up CloudLoggingFilter to add diagnostic info to the log records
|
||||
self.cloud_logging_filter = CloudLoggingFilter()
|
||||
|
||||
# Init StructuredLogHandler
|
||||
super().__init__()
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
self.cloud_logging_filter.filter(record)
|
||||
return super().format(record)
|
||||
|
||||
@@ -1,5 +1,27 @@
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from colorama import Fore
|
||||
|
||||
|
||||
def remove_color_codes(s: str) -> str:
|
||||
return re.sub(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])", "", s)
|
||||
|
||||
|
||||
def fmt_kwargs(kwargs: dict) -> str:
|
||||
return ", ".join(f"{n}={repr(v)}" for n, v in kwargs.items())
|
||||
|
||||
|
||||
def print_attribute(
|
||||
title: str, value: Any, title_color: str = Fore.GREEN, value_color: str = ""
|
||||
) -> None:
|
||||
logger = logging.getLogger()
|
||||
logger.info(
|
||||
str(value),
|
||||
extra={
|
||||
"title": f"{title.rstrip(':')}:",
|
||||
"title_color": title_color,
|
||||
"color": value_color,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -1,266 +1,20 @@
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from functools import wraps
|
||||
from typing import (
|
||||
Awaitable,
|
||||
Callable,
|
||||
ParamSpec,
|
||||
Protocol,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
cast,
|
||||
overload,
|
||||
runtime_checkable,
|
||||
)
|
||||
from typing import Callable, ParamSpec, TypeVar
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@overload
|
||||
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
||||
pass
|
||||
|
||||
|
||||
@overload
|
||||
def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
|
||||
pass
|
||||
|
||||
|
||||
def thread_cached(
|
||||
func: Callable[P, R] | Callable[P, Awaitable[R]],
|
||||
) -> Callable[P, R] | Callable[P, Awaitable[R]]:
|
||||
thread_local = threading.local()
|
||||
|
||||
def _clear():
|
||||
if hasattr(thread_local, "cache"):
|
||||
del thread_local.cache
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
cache = getattr(thread_local, "cache", None)
|
||||
if cache is None:
|
||||
cache = thread_local.cache = {}
|
||||
key = (args, tuple(sorted(kwargs.items())))
|
||||
if key not in cache:
|
||||
cache[key] = func(*args, **kwargs)
|
||||
return cache[key]
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
|
||||
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
cache = getattr(thread_local, "cache", None)
|
||||
if cache is None:
|
||||
cache = thread_local.cache = {}
|
||||
key = (args, tuple(sorted(kwargs.items())))
|
||||
if key not in cache:
|
||||
cache[key] = await cast(Callable[P, Awaitable[R]], func)(
|
||||
*args, **kwargs
|
||||
)
|
||||
return cache[key]
|
||||
|
||||
setattr(async_wrapper, "clear_cache", _clear)
|
||||
return async_wrapper
|
||||
|
||||
else:
|
||||
|
||||
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
cache = getattr(thread_local, "cache", None)
|
||||
if cache is None:
|
||||
cache = thread_local.cache = {}
|
||||
key = (args, tuple(sorted(kwargs.items())))
|
||||
if key not in cache:
|
||||
cache[key] = func(*args, **kwargs)
|
||||
return cache[key]
|
||||
|
||||
setattr(sync_wrapper, "clear_cache", _clear)
|
||||
return sync_wrapper
|
||||
|
||||
|
||||
def clear_thread_cache(func: Callable) -> None:
|
||||
if clear := getattr(func, "clear_cache", None):
|
||||
clear()
|
||||
|
||||
|
||||
FuncT = TypeVar("FuncT")
|
||||
|
||||
|
||||
R_co = TypeVar("R_co", covariant=True)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class AsyncCachedFunction(Protocol[P, R_co]):
|
||||
"""Protocol for async functions with cache management methods."""
|
||||
|
||||
def cache_clear(self) -> None:
|
||||
"""Clear all cached entries."""
|
||||
return None
|
||||
|
||||
def cache_info(self) -> dict[str, int | None]:
|
||||
"""Get cache statistics."""
|
||||
return {}
|
||||
|
||||
async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co:
|
||||
"""Call the cached function."""
|
||||
return None # type: ignore
|
||||
|
||||
|
||||
def async_ttl_cache(
|
||||
maxsize: int = 128, ttl_seconds: int | None = None
|
||||
) -> Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]:
|
||||
"""
|
||||
TTL (Time To Live) cache decorator for async functions.
|
||||
|
||||
Similar to functools.lru_cache but works with async functions and includes optional TTL.
|
||||
|
||||
Args:
|
||||
maxsize: Maximum number of cached entries
|
||||
ttl_seconds: Time to live in seconds. If None, entries never expire (like lru_cache)
|
||||
|
||||
Returns:
|
||||
Decorator function
|
||||
|
||||
Example:
|
||||
# With TTL
|
||||
@async_ttl_cache(maxsize=1000, ttl_seconds=300)
|
||||
async def api_call(param: str) -> dict:
|
||||
return {"result": param}
|
||||
|
||||
# Without TTL (permanent cache like lru_cache)
|
||||
@async_ttl_cache(maxsize=1000)
|
||||
async def expensive_computation(param: str) -> dict:
|
||||
return {"result": param}
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
async_func: Callable[P, Awaitable[R]],
|
||||
) -> AsyncCachedFunction[P, R]:
|
||||
# Cache storage - use union type to handle both cases
|
||||
cache_storage: dict[tuple, R | Tuple[R, float]] = {}
|
||||
|
||||
@wraps(async_func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
# Create cache key from arguments
|
||||
key = (args, tuple(sorted(kwargs.items())))
|
||||
current_time = time.time()
|
||||
|
||||
# Check if we have a valid cached entry
|
||||
if key in cache_storage:
|
||||
if ttl_seconds is None:
|
||||
# No TTL - return cached result directly
|
||||
logger.debug(
|
||||
f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
|
||||
)
|
||||
return cast(R, cache_storage[key])
|
||||
else:
|
||||
# With TTL - check expiration
|
||||
cached_data = cache_storage[key]
|
||||
if isinstance(cached_data, tuple):
|
||||
result, timestamp = cached_data
|
||||
if current_time - timestamp < ttl_seconds:
|
||||
logger.debug(
|
||||
f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
|
||||
)
|
||||
return cast(R, result)
|
||||
else:
|
||||
# Expired entry
|
||||
del cache_storage[key]
|
||||
logger.debug(
|
||||
f"Cache entry expired for {async_func.__name__}"
|
||||
)
|
||||
|
||||
# Cache miss or expired - fetch fresh data
|
||||
logger.debug(
|
||||
f"Cache miss for {async_func.__name__} with key: {str(key)[:50]}"
|
||||
)
|
||||
result = await async_func(*args, **kwargs)
|
||||
|
||||
# Store in cache
|
||||
if ttl_seconds is None:
|
||||
cache_storage[key] = result
|
||||
else:
|
||||
cache_storage[key] = (result, current_time)
|
||||
|
||||
# Simple cleanup when cache gets too large
|
||||
if len(cache_storage) > maxsize:
|
||||
# Remove oldest entries (simple FIFO cleanup)
|
||||
cutoff = maxsize // 2
|
||||
oldest_keys = list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
|
||||
for old_key in oldest_keys:
|
||||
cache_storage.pop(old_key, None)
|
||||
logger.debug(
|
||||
f"Cache cleanup: removed {len(oldest_keys)} entries for {async_func.__name__}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
# Add cache management methods (similar to functools.lru_cache)
|
||||
def cache_clear() -> None:
|
||||
cache_storage.clear()
|
||||
|
||||
def cache_info() -> dict[str, int | None]:
|
||||
return {
|
||||
"size": len(cache_storage),
|
||||
"maxsize": maxsize,
|
||||
"ttl_seconds": ttl_seconds,
|
||||
}
|
||||
|
||||
# Attach methods to wrapper
|
||||
setattr(wrapper, "cache_clear", cache_clear)
|
||||
setattr(wrapper, "cache_info", cache_info)
|
||||
|
||||
return cast(AsyncCachedFunction[P, R], wrapper)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@overload
|
||||
def async_cache(
|
||||
func: Callable[P, Awaitable[R]],
|
||||
) -> AsyncCachedFunction[P, R]:
|
||||
pass
|
||||
|
||||
|
||||
@overload
|
||||
def async_cache(
|
||||
func: None = None,
|
||||
*,
|
||||
maxsize: int = 128,
|
||||
) -> Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]:
|
||||
pass
|
||||
|
||||
|
||||
def async_cache(
|
||||
func: Callable[P, Awaitable[R]] | None = None,
|
||||
*,
|
||||
maxsize: int = 128,
|
||||
) -> (
|
||||
AsyncCachedFunction[P, R]
|
||||
| Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]
|
||||
):
|
||||
"""
|
||||
Process-level cache decorator for async functions (no TTL).
|
||||
|
||||
Similar to functools.lru_cache but works with async functions.
|
||||
This is a convenience wrapper around async_ttl_cache with ttl_seconds=None.
|
||||
|
||||
Args:
|
||||
func: The async function to cache (when used without parentheses)
|
||||
maxsize: Maximum number of cached entries
|
||||
|
||||
Returns:
|
||||
Decorated function or decorator
|
||||
|
||||
Example:
|
||||
# Without parentheses (uses default maxsize=128)
|
||||
@async_cache
|
||||
async def get_data(param: str) -> dict:
|
||||
return {"result": param}
|
||||
|
||||
# With parentheses and custom maxsize
|
||||
@async_cache(maxsize=1000)
|
||||
async def expensive_computation(param: str) -> dict:
|
||||
# Expensive computation here
|
||||
return {"result": param}
|
||||
"""
|
||||
if func is None:
|
||||
# Called with parentheses @async_cache() or @async_cache(maxsize=...)
|
||||
return async_ttl_cache(maxsize=maxsize, ttl_seconds=None)
|
||||
else:
|
||||
# Called without parentheses @async_cache
|
||||
decorator = async_ttl_cache(maxsize=maxsize, ttl_seconds=None)
|
||||
return decorator(func)
|
||||
return wrapper
|
||||
|
||||
@@ -1,705 +0,0 @@
|
||||
"""Tests for the @thread_cached decorator.
|
||||
|
||||
This module tests the thread-local caching functionality including:
|
||||
- Basic caching for sync and async functions
|
||||
- Thread isolation (each thread has its own cache)
|
||||
- Cache clearing functionality
|
||||
- Exception handling (exceptions are not cached)
|
||||
- Argument handling (positional vs keyword arguments)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from autogpt_libs.utils.cache import (
|
||||
async_cache,
|
||||
async_ttl_cache,
|
||||
clear_thread_cache,
|
||||
thread_cached,
|
||||
)
|
||||
|
||||
|
||||
class TestThreadCached:
|
||||
def test_sync_function_caching(self):
|
||||
call_count = 0
|
||||
|
||||
@thread_cached
|
||||
def expensive_function(x: int, y: int = 0) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x + y
|
||||
|
||||
assert expensive_function(1, 2) == 3
|
||||
assert call_count == 1
|
||||
|
||||
assert expensive_function(1, 2) == 3
|
||||
assert call_count == 1
|
||||
|
||||
assert expensive_function(1, y=2) == 3
|
||||
assert call_count == 2
|
||||
|
||||
assert expensive_function(2, 3) == 5
|
||||
assert call_count == 3
|
||||
|
||||
assert expensive_function(1) == 1
|
||||
assert call_count == 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_function_caching(self):
|
||||
call_count = 0
|
||||
|
||||
@thread_cached
|
||||
async def expensive_async_function(x: int, y: int = 0) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
return x + y
|
||||
|
||||
assert await expensive_async_function(1, 2) == 3
|
||||
assert call_count == 1
|
||||
|
||||
assert await expensive_async_function(1, 2) == 3
|
||||
assert call_count == 1
|
||||
|
||||
assert await expensive_async_function(1, y=2) == 3
|
||||
assert call_count == 2
|
||||
|
||||
assert await expensive_async_function(2, 3) == 5
|
||||
assert call_count == 3
|
||||
|
||||
def test_thread_isolation(self):
|
||||
call_count = 0
|
||||
results = {}
|
||||
|
||||
@thread_cached
|
||||
def thread_specific_function(x: int) -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return f"{threading.current_thread().name}-{x}"
|
||||
|
||||
def worker(thread_id: int):
|
||||
result1 = thread_specific_function(1)
|
||||
result2 = thread_specific_function(1)
|
||||
result3 = thread_specific_function(2)
|
||||
results[thread_id] = (result1, result2, result3)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=3) as executor:
|
||||
futures = [executor.submit(worker, i) for i in range(3)]
|
||||
for future in futures:
|
||||
future.result()
|
||||
|
||||
assert call_count >= 2
|
||||
|
||||
for thread_id, (r1, r2, r3) in results.items():
|
||||
assert r1 == r2
|
||||
assert r1 != r3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_thread_isolation(self):
|
||||
call_count = 0
|
||||
results = {}
|
||||
|
||||
@thread_cached
|
||||
async def async_thread_specific_function(x: int) -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
return f"{threading.current_thread().name}-{x}"
|
||||
|
||||
async def async_worker(worker_id: int):
|
||||
result1 = await async_thread_specific_function(1)
|
||||
result2 = await async_thread_specific_function(1)
|
||||
result3 = await async_thread_specific_function(2)
|
||||
results[worker_id] = (result1, result2, result3)
|
||||
|
||||
tasks = [async_worker(i) for i in range(3)]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
for worker_id, (r1, r2, r3) in results.items():
|
||||
assert r1 == r2
|
||||
assert r1 != r3
|
||||
|
||||
def test_clear_cache_sync(self):
|
||||
call_count = 0
|
||||
|
||||
@thread_cached
|
||||
def clearable_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x * 2
|
||||
|
||||
assert clearable_function(5) == 10
|
||||
assert call_count == 1
|
||||
|
||||
assert clearable_function(5) == 10
|
||||
assert call_count == 1
|
||||
|
||||
clear_thread_cache(clearable_function)
|
||||
|
||||
assert clearable_function(5) == 10
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_cache_async(self):
|
||||
call_count = 0
|
||||
|
||||
@thread_cached
|
||||
async def clearable_async_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
return x * 2
|
||||
|
||||
assert await clearable_async_function(5) == 10
|
||||
assert call_count == 1
|
||||
|
||||
assert await clearable_async_function(5) == 10
|
||||
assert call_count == 1
|
||||
|
||||
clear_thread_cache(clearable_async_function)
|
||||
|
||||
assert await clearable_async_function(5) == 10
|
||||
assert call_count == 2
|
||||
|
||||
def test_simple_arguments(self):
|
||||
call_count = 0
|
||||
|
||||
@thread_cached
|
||||
def simple_function(a: str, b: int, c: str = "default") -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return f"{a}-{b}-{c}"
|
||||
|
||||
# First call with all positional args
|
||||
result1 = simple_function("test", 42, "custom")
|
||||
assert call_count == 1
|
||||
|
||||
# Same args, all positional - should hit cache
|
||||
result2 = simple_function("test", 42, "custom")
|
||||
assert call_count == 1
|
||||
assert result1 == result2
|
||||
|
||||
# Same values but last arg as keyword - creates different cache key
|
||||
result3 = simple_function("test", 42, c="custom")
|
||||
assert call_count == 2
|
||||
assert result1 == result3 # Same result, different cache entry
|
||||
|
||||
# Different value - new cache entry
|
||||
result4 = simple_function("test", 43, "custom")
|
||||
assert call_count == 3
|
||||
assert result1 != result4
|
||||
|
||||
def test_positional_vs_keyword_args(self):
|
||||
"""Test that positional and keyword arguments create different cache entries."""
|
||||
call_count = 0
|
||||
|
||||
@thread_cached
|
||||
def func(a: int, b: int = 10) -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return f"result-{a}-{b}"
|
||||
|
||||
# All positional
|
||||
result1 = func(1, 2)
|
||||
assert call_count == 1
|
||||
assert result1 == "result-1-2"
|
||||
|
||||
# Same values, but second arg as keyword
|
||||
result2 = func(1, b=2)
|
||||
assert call_count == 2 # Different cache key!
|
||||
assert result2 == "result-1-2" # Same result
|
||||
|
||||
# Verify both are cached separately
|
||||
func(1, 2) # Uses first cache entry
|
||||
assert call_count == 2
|
||||
|
||||
func(1, b=2) # Uses second cache entry
|
||||
assert call_count == 2
|
||||
|
||||
def test_exception_handling(self):
|
||||
call_count = 0
|
||||
|
||||
@thread_cached
|
||||
def failing_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if x < 0:
|
||||
raise ValueError("Negative value")
|
||||
return x * 2
|
||||
|
||||
assert failing_function(5) == 10
|
||||
assert call_count == 1
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
failing_function(-1)
|
||||
assert call_count == 2
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
failing_function(-1)
|
||||
assert call_count == 3
|
||||
|
||||
assert failing_function(5) == 10
|
||||
assert call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_exception_handling(self):
|
||||
call_count = 0
|
||||
|
||||
@thread_cached
|
||||
async def async_failing_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01)
|
||||
if x < 0:
|
||||
raise ValueError("Negative value")
|
||||
return x * 2
|
||||
|
||||
assert await async_failing_function(5) == 10
|
||||
assert call_count == 1
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await async_failing_function(-1)
|
||||
assert call_count == 2
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await async_failing_function(-1)
|
||||
assert call_count == 3
|
||||
|
||||
def test_sync_caching_performance(self):
|
||||
@thread_cached
|
||||
def slow_function(x: int) -> int:
|
||||
print(f"slow_function called with x={x}")
|
||||
time.sleep(0.1)
|
||||
return x * 2
|
||||
|
||||
start = time.time()
|
||||
result1 = slow_function(5)
|
||||
first_call_time = time.time() - start
|
||||
print(f"First call took {first_call_time:.4f} seconds")
|
||||
|
||||
start = time.time()
|
||||
result2 = slow_function(5)
|
||||
second_call_time = time.time() - start
|
||||
print(f"Second call took {second_call_time:.4f} seconds")
|
||||
|
||||
assert result1 == result2 == 10
|
||||
assert first_call_time > 0.09
|
||||
assert second_call_time < 0.01
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_caching_performance(self):
|
||||
@thread_cached
|
||||
async def slow_async_function(x: int) -> int:
|
||||
print(f"slow_async_function called with x={x}")
|
||||
await asyncio.sleep(0.1)
|
||||
return x * 2
|
||||
|
||||
start = time.time()
|
||||
result1 = await slow_async_function(5)
|
||||
first_call_time = time.time() - start
|
||||
print(f"First async call took {first_call_time:.4f} seconds")
|
||||
|
||||
start = time.time()
|
||||
result2 = await slow_async_function(5)
|
||||
second_call_time = time.time() - start
|
||||
print(f"Second async call took {second_call_time:.4f} seconds")
|
||||
|
||||
assert result1 == result2 == 10
|
||||
assert first_call_time > 0.09
|
||||
assert second_call_time < 0.01
|
||||
|
||||
def test_with_mock_objects(self):
|
||||
mock = Mock(return_value=42)
|
||||
|
||||
@thread_cached
|
||||
def function_using_mock(x: int) -> int:
|
||||
return mock(x)
|
||||
|
||||
assert function_using_mock(1) == 42
|
||||
assert mock.call_count == 1
|
||||
|
||||
assert function_using_mock(1) == 42
|
||||
assert mock.call_count == 1
|
||||
|
||||
assert function_using_mock(2) == 42
|
||||
assert mock.call_count == 2
|
||||
|
||||
|
||||
class TestAsyncTTLCache:
|
||||
"""Tests for the @async_ttl_cache decorator."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_caching(self):
|
||||
"""Test basic caching functionality."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def cached_function(x: int, y: int = 0) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01) # Simulate async work
|
||||
return x + y
|
||||
|
||||
# First call
|
||||
result1 = await cached_function(1, 2)
|
||||
assert result1 == 3
|
||||
assert call_count == 1
|
||||
|
||||
# Second call with same args - should use cache
|
||||
result2 = await cached_function(1, 2)
|
||||
assert result2 == 3
|
||||
assert call_count == 1 # No additional call
|
||||
|
||||
# Different args - should call function again
|
||||
result3 = await cached_function(2, 3)
|
||||
assert result3 == 5
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ttl_expiration(self):
|
||||
"""Test that cache entries expire after TTL."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=1) # Short TTL
|
||||
async def short_lived_cache(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x * 2
|
||||
|
||||
# First call
|
||||
result1 = await short_lived_cache(5)
|
||||
assert result1 == 10
|
||||
assert call_count == 1
|
||||
|
||||
# Second call immediately - should use cache
|
||||
result2 = await short_lived_cache(5)
|
||||
assert result2 == 10
|
||||
assert call_count == 1
|
||||
|
||||
# Wait for TTL to expire
|
||||
await asyncio.sleep(1.1)
|
||||
|
||||
# Third call after expiration - should call function again
|
||||
result3 = await short_lived_cache(5)
|
||||
assert result3 == 10
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_info(self):
|
||||
"""Test cache info functionality."""
|
||||
|
||||
@async_ttl_cache(maxsize=5, ttl_seconds=300)
|
||||
async def info_test_function(x: int) -> int:
|
||||
return x * 3
|
||||
|
||||
# Check initial cache info
|
||||
info = info_test_function.cache_info()
|
||||
assert info["size"] == 0
|
||||
assert info["maxsize"] == 5
|
||||
assert info["ttl_seconds"] == 300
|
||||
|
||||
# Add an entry
|
||||
await info_test_function(1)
|
||||
info = info_test_function.cache_info()
|
||||
assert info["size"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_clear(self):
|
||||
"""Test cache clearing functionality."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def clearable_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x * 4
|
||||
|
||||
# First call
|
||||
result1 = await clearable_function(2)
|
||||
assert result1 == 8
|
||||
assert call_count == 1
|
||||
|
||||
# Second call - should use cache
|
||||
result2 = await clearable_function(2)
|
||||
assert result2 == 8
|
||||
assert call_count == 1
|
||||
|
||||
# Clear cache
|
||||
clearable_function.cache_clear()
|
||||
|
||||
# Third call after clear - should call function again
|
||||
result3 = await clearable_function(2)
|
||||
assert result3 == 8
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maxsize_cleanup(self):
|
||||
"""Test that cache cleans up when maxsize is exceeded."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=3, ttl_seconds=60)
|
||||
async def size_limited_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x**2
|
||||
|
||||
# Fill cache to maxsize
|
||||
await size_limited_function(1) # call_count: 1
|
||||
await size_limited_function(2) # call_count: 2
|
||||
await size_limited_function(3) # call_count: 3
|
||||
|
||||
info = size_limited_function.cache_info()
|
||||
assert info["size"] == 3
|
||||
|
||||
# Add one more entry - should trigger cleanup
|
||||
await size_limited_function(4) # call_count: 4
|
||||
|
||||
# Cache size should be reduced (cleanup removes oldest entries)
|
||||
info = size_limited_function.cache_info()
|
||||
assert info["size"] is not None and info["size"] <= 3 # Should be cleaned up
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_argument_variations(self):
|
||||
"""Test caching with different argument patterns."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def arg_test_function(a: int, b: str = "default", *, c: int = 100) -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return f"{a}-{b}-{c}"
|
||||
|
||||
# Different ways to call with same logical arguments
|
||||
result1 = await arg_test_function(1, "test", c=200)
|
||||
assert call_count == 1
|
||||
|
||||
# Same arguments, same order - should use cache
|
||||
result2 = await arg_test_function(1, "test", c=200)
|
||||
assert call_count == 1
|
||||
assert result1 == result2
|
||||
|
||||
# Different arguments - should call function
|
||||
result3 = await arg_test_function(2, "test", c=200)
|
||||
assert call_count == 2
|
||||
assert result1 != result3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_handling(self):
|
||||
"""Test that exceptions are not cached."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def exception_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if x < 0:
|
||||
raise ValueError("Negative value not allowed")
|
||||
return x * 2
|
||||
|
||||
# Successful call - should be cached
|
||||
result1 = await exception_function(5)
|
||||
assert result1 == 10
|
||||
assert call_count == 1
|
||||
|
||||
# Same successful call - should use cache
|
||||
result2 = await exception_function(5)
|
||||
assert result2 == 10
|
||||
assert call_count == 1
|
||||
|
||||
# Exception call - should not be cached
|
||||
with pytest.raises(ValueError):
|
||||
await exception_function(-1)
|
||||
assert call_count == 2
|
||||
|
||||
# Same exception call - should call again (not cached)
|
||||
with pytest.raises(ValueError):
|
||||
await exception_function(-1)
|
||||
assert call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_calls(self):
|
||||
"""Test caching behavior with concurrent calls."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def concurrent_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.05) # Simulate work
|
||||
return x * x
|
||||
|
||||
# Launch concurrent calls with same arguments
|
||||
tasks = [concurrent_function(3) for _ in range(5)]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# All results should be the same
|
||||
assert all(result == 9 for result in results)
|
||||
|
||||
# Note: Due to race conditions, call_count might be up to 5 for concurrent calls
|
||||
# This tests that the cache doesn't break under concurrent access
|
||||
assert 1 <= call_count <= 5
|
||||
|
||||
|
||||
class TestAsyncCache:
|
||||
"""Tests for the @async_cache decorator (no TTL)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_caching_no_ttl(self):
|
||||
"""Test basic caching functionality without TTL."""
|
||||
call_count = 0
|
||||
|
||||
@async_cache(maxsize=10)
|
||||
async def cached_function(x: int, y: int = 0) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01) # Simulate async work
|
||||
return x + y
|
||||
|
||||
# First call
|
||||
result1 = await cached_function(1, 2)
|
||||
assert result1 == 3
|
||||
assert call_count == 1
|
||||
|
||||
# Second call with same args - should use cache
|
||||
result2 = await cached_function(1, 2)
|
||||
assert result2 == 3
|
||||
assert call_count == 1 # No additional call
|
||||
|
||||
# Third call after some time - should still use cache (no TTL)
|
||||
await asyncio.sleep(0.05)
|
||||
result3 = await cached_function(1, 2)
|
||||
assert result3 == 3
|
||||
assert call_count == 1 # Still no additional call
|
||||
|
||||
# Different args - should call function again
|
||||
result4 = await cached_function(2, 3)
|
||||
assert result4 == 5
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_ttl_vs_ttl_behavior(self):
|
||||
"""Test the difference between TTL and no-TTL caching."""
|
||||
ttl_call_count = 0
|
||||
no_ttl_call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=1) # Short TTL
|
||||
async def ttl_function(x: int) -> int:
|
||||
nonlocal ttl_call_count
|
||||
ttl_call_count += 1
|
||||
return x * 2
|
||||
|
||||
@async_cache(maxsize=10) # No TTL
|
||||
async def no_ttl_function(x: int) -> int:
|
||||
nonlocal no_ttl_call_count
|
||||
no_ttl_call_count += 1
|
||||
return x * 2
|
||||
|
||||
# First calls
|
||||
await ttl_function(5)
|
||||
await no_ttl_function(5)
|
||||
assert ttl_call_count == 1
|
||||
assert no_ttl_call_count == 1
|
||||
|
||||
# Wait for TTL to expire
|
||||
await asyncio.sleep(1.1)
|
||||
|
||||
# Second calls after TTL expiry
|
||||
await ttl_function(5) # Should call function again (TTL expired)
|
||||
await no_ttl_function(5) # Should use cache (no TTL)
|
||||
assert ttl_call_count == 2 # TTL function called again
|
||||
assert no_ttl_call_count == 1 # No-TTL function still cached
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_cache_info(self):
|
||||
"""Test cache info for no-TTL cache."""
|
||||
|
||||
@async_cache(maxsize=5)
|
||||
async def info_test_function(x: int) -> int:
|
||||
return x * 3
|
||||
|
||||
# Check initial cache info
|
||||
info = info_test_function.cache_info()
|
||||
assert info["size"] == 0
|
||||
assert info["maxsize"] == 5
|
||||
assert info["ttl_seconds"] is None # No TTL
|
||||
|
||||
# Add an entry
|
||||
await info_test_function(1)
|
||||
info = info_test_function.cache_info()
|
||||
assert info["size"] == 1
|
||||
|
||||
|
||||
class TestTTLOptional:
|
||||
"""Tests for optional TTL functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ttl_none_behavior(self):
|
||||
"""Test that ttl_seconds=None works like no TTL."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=None)
|
||||
async def no_ttl_via_none(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x**2
|
||||
|
||||
# First call
|
||||
result1 = await no_ttl_via_none(3)
|
||||
assert result1 == 9
|
||||
assert call_count == 1
|
||||
|
||||
# Wait (would expire if there was TTL)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Second call - should still use cache
|
||||
result2 = await no_ttl_via_none(3)
|
||||
assert result2 == 9
|
||||
assert call_count == 1 # No additional call
|
||||
|
||||
# Check cache info
|
||||
info = no_ttl_via_none.cache_info()
|
||||
assert info["ttl_seconds"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_options_comparison(self):
|
||||
"""Test different cache options work as expected."""
|
||||
ttl_calls = 0
|
||||
no_ttl_calls = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=1) # With TTL
|
||||
async def ttl_function(x: int) -> int:
|
||||
nonlocal ttl_calls
|
||||
ttl_calls += 1
|
||||
return x * 10
|
||||
|
||||
@async_cache(maxsize=10) # Process-level cache (no TTL)
|
||||
async def process_function(x: int) -> int:
|
||||
nonlocal no_ttl_calls
|
||||
no_ttl_calls += 1
|
||||
return x * 10
|
||||
|
||||
# Both should cache initially
|
||||
await ttl_function(3)
|
||||
await process_function(3)
|
||||
assert ttl_calls == 1
|
||||
assert no_ttl_calls == 1
|
||||
|
||||
# Immediate second calls - both should use cache
|
||||
await ttl_function(3)
|
||||
await process_function(3)
|
||||
assert ttl_calls == 1
|
||||
assert no_ttl_calls == 1
|
||||
|
||||
# Wait for TTL to expire
|
||||
await asyncio.sleep(1.1)
|
||||
|
||||
# After TTL expiry
|
||||
await ttl_function(3) # Should call function again
|
||||
await process_function(3) # Should still use cache
|
||||
assert ttl_calls == 2 # TTL cache expired, called again
|
||||
assert no_ttl_calls == 1 # Process cache never expires
|
||||
@@ -1,15 +1,15 @@
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from contextlib import contextmanager
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from expiringdict import ExpiringDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.asyncio import Redis as AsyncRedis
|
||||
from redis.asyncio.lock import Lock as AsyncRedisLock
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
|
||||
class AsyncRedisKeyedMutex:
|
||||
class RedisKeyedMutex:
|
||||
"""
|
||||
This class provides a mutex that can be locked and unlocked by a specific key,
|
||||
using Redis as a distributed locking provider.
|
||||
@@ -17,45 +17,41 @@ class AsyncRedisKeyedMutex:
|
||||
in case the key is not unlocked for a specified duration, to prevent memory leaks.
|
||||
"""
|
||||
|
||||
def __init__(self, redis: "AsyncRedis", timeout: int | None = 60):
|
||||
def __init__(self, redis: "Redis", timeout: int | None = 60):
|
||||
self.redis = redis
|
||||
self.timeout = timeout
|
||||
self.locks: dict[Any, "AsyncRedisLock"] = ExpiringDict(
|
||||
self.locks: dict[Any, "RedisLock"] = ExpiringDict(
|
||||
max_len=6000, max_age_seconds=self.timeout
|
||||
)
|
||||
self.locks_lock = asyncio.Lock()
|
||||
self.locks_lock = Lock()
|
||||
|
||||
@asynccontextmanager
|
||||
async def locked(self, key: Any):
|
||||
lock = await self.acquire(key)
|
||||
@contextmanager
|
||||
def locked(self, key: Any):
|
||||
lock = self.acquire(key)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if (await lock.locked()) and (await lock.owned()):
|
||||
await lock.release()
|
||||
if lock.locked():
|
||||
lock.release()
|
||||
|
||||
async def acquire(self, key: Any) -> "AsyncRedisLock":
|
||||
def acquire(self, key: Any) -> "RedisLock":
|
||||
"""Acquires and returns a lock with the given key"""
|
||||
async with self.locks_lock:
|
||||
with self.locks_lock:
|
||||
if key not in self.locks:
|
||||
self.locks[key] = self.redis.lock(
|
||||
str(key), self.timeout, thread_local=False
|
||||
)
|
||||
lock = self.locks[key]
|
||||
await lock.acquire()
|
||||
lock.acquire()
|
||||
return lock
|
||||
|
||||
async def release(self, key: Any):
|
||||
if (
|
||||
(lock := self.locks.get(key))
|
||||
and (await lock.locked())
|
||||
and (await lock.owned())
|
||||
):
|
||||
await lock.release()
|
||||
def release(self, key: Any):
|
||||
if (lock := self.locks.get(key)) and lock.locked() and lock.owned():
|
||||
lock.release()
|
||||
|
||||
async def release_all_locks(self):
|
||||
def release_all_locks(self):
|
||||
"""Call this on process termination to ensure all locks are released"""
|
||||
async with self.locks_lock:
|
||||
for lock in self.locks.values():
|
||||
if (await lock.locked()) and (await lock.owned()):
|
||||
await lock.release()
|
||||
self.locks_lock.acquire(blocking=False)
|
||||
for lock in self.locks.values():
|
||||
if lock.locked() and lock.owned():
|
||||
lock.release()
|
||||
|
||||
2021
autogpt_platform/autogpt_libs/poetry.lock
generated
2021
autogpt_platform/autogpt_libs/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -1,29 +1,27 @@
|
||||
[tool.poetry]
|
||||
name = "autogpt-libs"
|
||||
version = "0.2.0"
|
||||
description = "Shared libraries across AutoGPT Platform"
|
||||
authors = ["AutoGPT team <info@agpt.co>"]
|
||||
description = "Shared libraries across NextGen AutoGPT"
|
||||
authors = ["Aarushi <aarushik93@gmail.com>"]
|
||||
readme = "README.md"
|
||||
packages = [{ include = "autogpt_libs" }]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<4.0"
|
||||
colorama = "^0.4.6"
|
||||
expiringdict = "^1.2.2"
|
||||
fastapi = "^0.116.1"
|
||||
google-cloud-logging = "^3.12.1"
|
||||
launchdarkly-server-sdk = "^9.12.0"
|
||||
pydantic = "^2.11.7"
|
||||
pydantic-settings = "^2.10.1"
|
||||
google-cloud-logging = "^3.11.4"
|
||||
pydantic = "^2.10.6"
|
||||
pydantic-settings = "^2.7.1"
|
||||
pyjwt = "^2.10.1"
|
||||
pytest-asyncio = "^1.1.0"
|
||||
pytest-mock = "^3.14.1"
|
||||
redis = "^6.2.0"
|
||||
supabase = "^2.16.0"
|
||||
uvicorn = "^0.35.0"
|
||||
pytest-asyncio = "^0.25.3"
|
||||
pytest-mock = "^3.14.0"
|
||||
python = ">=3.10,<4.0"
|
||||
python-dotenv = "^1.0.1"
|
||||
supabase = "^2.13.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
ruff = "^0.12.3"
|
||||
redis = "^5.2.1"
|
||||
ruff = "^0.9.6"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
# Development and testing files
|
||||
**/__pycache__
|
||||
**/*.pyc
|
||||
**/*.pyo
|
||||
**/*.pyd
|
||||
**/.Python
|
||||
**/env/
|
||||
**/venv/
|
||||
**/.venv/
|
||||
**/pip-log.txt
|
||||
**/.pytest_cache/
|
||||
**/test-results/
|
||||
**/snapshots/
|
||||
**/test/
|
||||
|
||||
# IDE and editor files
|
||||
**/.vscode/
|
||||
**/.idea/
|
||||
**/*.swp
|
||||
**/*.swo
|
||||
*~
|
||||
|
||||
# OS files
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Logs
|
||||
**/*.log
|
||||
**/logs/
|
||||
|
||||
# Git
|
||||
.git/
|
||||
.gitignore
|
||||
|
||||
# Documentation
|
||||
**/*.md
|
||||
!README.md
|
||||
|
||||
# Local development files
|
||||
.env
|
||||
.env.local
|
||||
**/.env.test
|
||||
|
||||
# Build artifacts
|
||||
**/dist/
|
||||
**/build/
|
||||
**/target/
|
||||
|
||||
# Docker files (avoid recursion)
|
||||
Dockerfile*
|
||||
docker-compose*
|
||||
.dockerignore
|
||||
@@ -1,65 +1,64 @@
|
||||
# Backend Configuration
|
||||
# This file contains environment variables that MUST be set for the AutoGPT platform
|
||||
# Variables with working defaults in settings.py are not included here
|
||||
|
||||
## ===== REQUIRED DATABASE CONFIGURATION ===== ##
|
||||
# PostgreSQL Database Connection
|
||||
DB_USER=postgres
|
||||
DB_PASS=your-super-secret-and-long-postgres-password
|
||||
DB_NAME=postgres
|
||||
DB_PORT=5432
|
||||
DB_HOST=localhost
|
||||
DB_CONNECTION_LIMIT=12
|
||||
DB_CONNECT_TIMEOUT=60
|
||||
DB_POOL_TIMEOUT=300
|
||||
DB_SCHEMA=platform
|
||||
DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
|
||||
DIRECT_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
|
||||
DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@localhost:${DB_PORT}/${DB_NAME}?connect_timeout=60&schema=platform"
|
||||
PRISMA_SCHEMA="postgres/schema.prisma"
|
||||
ENABLE_AUTH=true
|
||||
|
||||
## ===== REQUIRED SERVICE CREDENTIALS ===== ##
|
||||
# Redis Configuration
|
||||
BACKEND_CORS_ALLOW_ORIGINS=["http://localhost:3000"]
|
||||
|
||||
# generate using `from cryptography.fernet import Fernet;Fernet.generate_key().decode()`
|
||||
ENCRYPTION_KEY='dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw='
|
||||
|
||||
REDIS_HOST=localhost
|
||||
REDIS_PORT=6379
|
||||
REDIS_PASSWORD=password
|
||||
|
||||
# RabbitMQ Credentials
|
||||
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
|
||||
RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
|
||||
ENABLE_CREDIT=false
|
||||
STRIPE_API_KEY=
|
||||
STRIPE_WEBHOOK_SECRET=
|
||||
|
||||
# Supabase Authentication
|
||||
# What environment things should be logged under: local dev or prod
|
||||
APP_ENV=local
|
||||
# What environment to behave as: "local" or "cloud"
|
||||
BEHAVE_AS=local
|
||||
PYRO_HOST=localhost
|
||||
SENTRY_DSN=
|
||||
|
||||
# Email For Postmark so we can send emails
|
||||
POSTMARK_SERVER_API_TOKEN=
|
||||
POSTMARK_SENDER_EMAIL=invalid@invalid.com
|
||||
POSTMARK_WEBHOOK_TOKEN=
|
||||
|
||||
## User auth with Supabase is required for any of the 3rd party integrations with auth to work.
|
||||
ENABLE_AUTH=true
|
||||
SUPABASE_URL=http://localhost:8000
|
||||
SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
|
||||
SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
|
||||
## ===== REQUIRED SECURITY KEYS ===== ##
|
||||
# Generate using: from cryptography.fernet import Fernet;Fernet.generate_key().decode()
|
||||
ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=
|
||||
UNSUBSCRIBE_SECRET_KEY=HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio=
|
||||
# RabbitMQ credentials -- Used for communication between services
|
||||
RABBITMQ_HOST=localhost
|
||||
RABBITMQ_PORT=5672
|
||||
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
|
||||
RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
|
||||
|
||||
## ===== IMPORTANT OPTIONAL CONFIGURATION ===== ##
|
||||
# Platform URLs (set these for webhooks and OAuth to work)
|
||||
PLATFORM_BASE_URL=http://localhost:8000
|
||||
FRONTEND_BASE_URL=http://localhost:3000
|
||||
|
||||
# Media Storage (required for marketplace and library functionality)
|
||||
## GCS bucket is required for marketplace and library functionality
|
||||
MEDIA_GCS_BUCKET_NAME=
|
||||
|
||||
## ===== API KEYS AND OAUTH CREDENTIALS ===== ##
|
||||
# All API keys below are optional - only add what you need
|
||||
## For local development, you may need to set FRONTEND_BASE_URL for the OAuth flow
|
||||
## for integrations to work. Defaults to the value of PLATFORM_BASE_URL if not set.
|
||||
# FRONTEND_BASE_URL=http://localhost:3000
|
||||
|
||||
# AI/LLM Services
|
||||
OPENAI_API_KEY=
|
||||
ANTHROPIC_API_KEY=
|
||||
GROQ_API_KEY=
|
||||
LLAMA_API_KEY=
|
||||
AIML_API_KEY=
|
||||
V0_API_KEY=
|
||||
OPEN_ROUTER_API_KEY=
|
||||
NVIDIA_API_KEY=
|
||||
## PLATFORM_BASE_URL must be set to a *publicly accessible* URL pointing to your backend
|
||||
## to use the platform's webhook-related functionality.
|
||||
## If you are developing locally, you can use something like ngrok to get a publc URL
|
||||
## and tunnel it to your locally running backend.
|
||||
PLATFORM_BASE_URL=http://localhost:3000
|
||||
|
||||
## == INTEGRATION CREDENTIALS == ##
|
||||
# Each set of server side credentials is required for the corresponding 3rd party
|
||||
# integration to work.
|
||||
|
||||
# OAuth Credentials
|
||||
# For the OAuth callback URL, use <your_frontend_url>/auth/integrations/oauth_callback,
|
||||
# e.g. http://localhost:3000/auth/integrations/oauth_callback
|
||||
|
||||
@@ -69,6 +68,7 @@ GITHUB_CLIENT_SECRET=
|
||||
|
||||
# Google OAuth App server credentials - https://console.cloud.google.com/apis/credentials, and enable gmail api and set scopes
|
||||
# https://console.cloud.google.com/apis/credentials/consent ?project=<your_project_id>
|
||||
|
||||
# You'll need to add/enable the following scopes (minimum):
|
||||
# https://console.developers.google.com/apis/api/gmail.googleapis.com/overview ?project=<your_project_id>
|
||||
# https://console.cloud.google.com/apis/library/sheets.googleapis.com/ ?project=<your_project_id>
|
||||
@@ -104,66 +104,83 @@ LINEAR_CLIENT_SECRET=
|
||||
TODOIST_CLIENT_ID=
|
||||
TODOIST_CLIENT_SECRET=
|
||||
|
||||
NOTION_CLIENT_ID=
|
||||
NOTION_CLIENT_SECRET=
|
||||
## ===== OPTIONAL API KEYS ===== ##
|
||||
|
||||
# LLM
|
||||
OPENAI_API_KEY=
|
||||
ANTHROPIC_API_KEY=
|
||||
GROQ_API_KEY=
|
||||
OPEN_ROUTER_API_KEY=
|
||||
|
||||
# Reddit
|
||||
# Go to https://www.reddit.com/prefs/apps and create a new app
|
||||
# Choose "script" for the type
|
||||
# Fill in the redirect uri as <your_frontend_url>/auth/integrations/oauth_callback, e.g. http://localhost:3000/auth/integrations/oauth_callback
|
||||
REDDIT_CLIENT_ID=
|
||||
REDDIT_CLIENT_SECRET=
|
||||
REDDIT_USER_AGENT="AutoGPT:1.0 (by /u/autogpt)"
|
||||
|
||||
# Payment Processing
|
||||
STRIPE_API_KEY=
|
||||
STRIPE_WEBHOOK_SECRET=
|
||||
|
||||
# Email Service (for sending notifications and confirmations)
|
||||
POSTMARK_SERVER_API_TOKEN=
|
||||
POSTMARK_SENDER_EMAIL=invalid@invalid.com
|
||||
POSTMARK_WEBHOOK_TOKEN=
|
||||
|
||||
# Error Tracking
|
||||
SENTRY_DSN=
|
||||
|
||||
# Cloudflare Turnstile (CAPTCHA) Configuration
|
||||
# Get these from the Cloudflare Turnstile dashboard: https://dash.cloudflare.com/?to=/:account/turnstile
|
||||
# This is the backend secret key
|
||||
TURNSTILE_SECRET_KEY=
|
||||
# This is the verify URL
|
||||
TURNSTILE_VERIFY_URL=https://challenges.cloudflare.com/turnstile/v0/siteverify
|
||||
|
||||
# Feature Flags
|
||||
LAUNCH_DARKLY_SDK_KEY=
|
||||
|
||||
# Content Generation & Media
|
||||
DID_API_KEY=
|
||||
FAL_API_KEY=
|
||||
IDEOGRAM_API_KEY=
|
||||
REPLICATE_API_KEY=
|
||||
REVID_API_KEY=
|
||||
SCREENSHOTONE_API_KEY=
|
||||
UNREAL_SPEECH_API_KEY=
|
||||
|
||||
# Data & Search Services
|
||||
E2B_API_KEY=
|
||||
EXA_API_KEY=
|
||||
JINA_API_KEY=
|
||||
MEM0_API_KEY=
|
||||
OPENWEATHERMAP_API_KEY=
|
||||
GOOGLE_MAPS_API_KEY=
|
||||
|
||||
# Communication Services
|
||||
# Discord
|
||||
DISCORD_BOT_TOKEN=
|
||||
MEDIUM_API_KEY=
|
||||
MEDIUM_AUTHOR_ID=
|
||||
|
||||
# SMTP/Email
|
||||
SMTP_SERVER=
|
||||
SMTP_PORT=
|
||||
SMTP_USERNAME=
|
||||
SMTP_PASSWORD=
|
||||
|
||||
# Business & Marketing Tools
|
||||
# D-ID
|
||||
DID_API_KEY=
|
||||
|
||||
# Open Weather Map
|
||||
OPENWEATHERMAP_API_KEY=
|
||||
|
||||
# SMTP
|
||||
SMTP_SERVER=
|
||||
SMTP_PORT=
|
||||
SMTP_USERNAME=
|
||||
SMTP_PASSWORD=
|
||||
|
||||
# Medium
|
||||
MEDIUM_API_KEY=
|
||||
MEDIUM_AUTHOR_ID=
|
||||
|
||||
# Google Maps
|
||||
GOOGLE_MAPS_API_KEY=
|
||||
|
||||
# Replicate
|
||||
REPLICATE_API_KEY=
|
||||
|
||||
# Ideogram
|
||||
IDEOGRAM_API_KEY=
|
||||
|
||||
# Fal
|
||||
FAL_API_KEY=
|
||||
|
||||
# Exa
|
||||
EXA_API_KEY=
|
||||
|
||||
# E2B
|
||||
E2B_API_KEY=
|
||||
|
||||
# Mem0
|
||||
MEM0_API_KEY=
|
||||
|
||||
# Nvidia
|
||||
NVIDIA_API_KEY=
|
||||
|
||||
# Apollo
|
||||
APOLLO_API_KEY=
|
||||
ENRICHLAYER_API_KEY=
|
||||
AYRSHARE_API_KEY=
|
||||
AYRSHARE_JWT_KEY=
|
||||
|
||||
# SmartLead
|
||||
SMARTLEAD_API_KEY=
|
||||
|
||||
# ZeroBounce
|
||||
ZEROBOUNCE_API_KEY=
|
||||
|
||||
# Other Services
|
||||
AUTOMOD_API_KEY=
|
||||
# Logging Configuration
|
||||
LOG_LEVEL=INFO
|
||||
ENABLE_CLOUD_LOGGING=false
|
||||
ENABLE_FILE_LOGGING=false
|
||||
# Use to manually set the log directory
|
||||
# LOG_DIR=./logs
|
||||
1
autogpt_platform/backend/.gitignore
vendored
1
autogpt_platform/backend/.gitignore
vendored
@@ -1,4 +1,3 @@
|
||||
.env
|
||||
database.db
|
||||
database.db-journal
|
||||
dev.db
|
||||
|
||||
@@ -8,14 +8,14 @@ WORKDIR /app
|
||||
|
||||
RUN echo 'Acquire::http::Pipeline-Depth 0;\nAcquire::http::No-Cache true;\nAcquire::BrokenProxy true;\n' > /etc/apt/apt.conf.d/99fixbadproxy
|
||||
|
||||
# Update package list and install build dependencies in a single layer
|
||||
RUN apt-get update --allow-releaseinfo-change --fix-missing \
|
||||
&& apt-get install -y \
|
||||
build-essential \
|
||||
libpq5 \
|
||||
libz-dev \
|
||||
libssl-dev \
|
||||
postgresql-client
|
||||
RUN apt-get update --allow-releaseinfo-change --fix-missing
|
||||
|
||||
# Install build dependencies
|
||||
RUN apt-get install -y build-essential
|
||||
RUN apt-get install -y libpq5
|
||||
RUN apt-get install -y libz-dev
|
||||
RUN apt-get install -y libssl-dev
|
||||
RUN apt-get install -y postgresql-client
|
||||
|
||||
ENV POETRY_HOME=/opt/poetry
|
||||
ENV POETRY_NO_INTERACTION=1
|
||||
@@ -68,17 +68,12 @@ COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.tom
|
||||
|
||||
WORKDIR /app/autogpt_platform/backend
|
||||
|
||||
FROM server_dependencies AS migrate
|
||||
|
||||
# Migration stage only needs schema and migrations - much lighter than full backend
|
||||
COPY autogpt_platform/backend/schema.prisma /app/autogpt_platform/backend/
|
||||
COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migrations
|
||||
|
||||
FROM server_dependencies AS server
|
||||
|
||||
COPY autogpt_platform/backend /app/autogpt_platform/backend
|
||||
RUN poetry install --no-ansi --only-root
|
||||
|
||||
ENV DATABASE_URL=""
|
||||
ENV PORT=8000
|
||||
|
||||
CMD ["poetry", "run", "rest"]
|
||||
|
||||
@@ -1 +1,75 @@
|
||||
[Advanced Setup (Dev Branch)](https://dev-docs.agpt.co/platform/advanced_setup/#autogpt_agent_server_advanced_set_up)
|
||||
# AutoGPT Agent Server Advanced set up
|
||||
|
||||
This guide walks you through a dockerized set up, with an external DB (postgres)
|
||||
|
||||
## Setup
|
||||
|
||||
We use the Poetry to manage the dependencies. To set up the project, follow these steps inside this directory:
|
||||
|
||||
0. Install Poetry
|
||||
```sh
|
||||
pip install poetry
|
||||
```
|
||||
|
||||
1. Configure Poetry to use .venv in your project directory
|
||||
```sh
|
||||
poetry config virtualenvs.in-project true
|
||||
```
|
||||
|
||||
2. Enter the poetry shell
|
||||
|
||||
```sh
|
||||
poetry shell
|
||||
```
|
||||
|
||||
3. Install dependencies
|
||||
|
||||
```sh
|
||||
poetry install
|
||||
```
|
||||
|
||||
4. Copy .env.example to .env
|
||||
|
||||
```sh
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
5. Generate the Prisma client
|
||||
|
||||
```sh
|
||||
poetry run prisma generate
|
||||
```
|
||||
|
||||
|
||||
> In case Prisma generates the client for the global Python installation instead of the virtual environment, the current mitigation is to just uninstall the global Prisma package:
|
||||
>
|
||||
> ```sh
|
||||
> pip uninstall prisma
|
||||
> ```
|
||||
>
|
||||
> Then run the generation again. The path *should* look something like this:
|
||||
> `<some path>/pypoetry/virtualenvs/backend-TQIRSwR6-py3.12/bin/prisma`
|
||||
|
||||
6. Run the postgres database from the /rnd folder
|
||||
|
||||
```sh
|
||||
cd autogpt_platform/
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
7. Run the migrations (from the backend folder)
|
||||
|
||||
```sh
|
||||
cd ../backend
|
||||
prisma migrate deploy
|
||||
```
|
||||
|
||||
## Running The Server
|
||||
|
||||
### Starting the server directly
|
||||
|
||||
Run the following command:
|
||||
|
||||
```sh
|
||||
poetry run app
|
||||
```
|
||||
|
||||
@@ -1 +1,210 @@
|
||||
[Getting Started (Released)](https://docs.agpt.co/platform/getting-started/#autogpt_agent_server)
|
||||
# AutoGPT Agent Server
|
||||
|
||||
This is an initial project for creating the next generation of agent execution, which is an AutoGPT agent server.
|
||||
The agent server will enable the creation of composite multi-agent systems that utilize AutoGPT agents and other non-agent components as its primitives.
|
||||
|
||||
## Docs
|
||||
|
||||
You can access the docs for the [AutoGPT Agent Server here](https://docs.agpt.co/server/setup).
|
||||
|
||||
## Setup
|
||||
|
||||
We use the Poetry to manage the dependencies. To set up the project, follow these steps inside this directory:
|
||||
|
||||
0. Install Poetry
|
||||
```sh
|
||||
pip install poetry
|
||||
```
|
||||
|
||||
1. Configure Poetry to use .venv in your project directory
|
||||
```sh
|
||||
poetry config virtualenvs.in-project true
|
||||
```
|
||||
|
||||
2. Enter the poetry shell
|
||||
|
||||
```sh
|
||||
poetry shell
|
||||
```
|
||||
|
||||
3. Install dependencies
|
||||
|
||||
```sh
|
||||
poetry install
|
||||
```
|
||||
|
||||
4. Copy .env.example to .env
|
||||
|
||||
```sh
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
5. Generate the Prisma client
|
||||
|
||||
```sh
|
||||
poetry run prisma generate
|
||||
```
|
||||
|
||||
|
||||
> In case Prisma generates the client for the global Python installation instead of the virtual environment, the current mitigation is to just uninstall the global Prisma package:
|
||||
>
|
||||
> ```sh
|
||||
> pip uninstall prisma
|
||||
> ```
|
||||
>
|
||||
> Then run the generation again. The path *should* look something like this:
|
||||
> `<some path>/pypoetry/virtualenvs/backend-TQIRSwR6-py3.12/bin/prisma`
|
||||
|
||||
6. Migrate the database. Be careful because this deletes current data in the database.
|
||||
|
||||
```sh
|
||||
docker compose up db -d
|
||||
poetry run prisma migrate deploy
|
||||
```
|
||||
|
||||
## Running The Server
|
||||
|
||||
### Starting the server without Docker
|
||||
|
||||
To run the server locally, start in the autogpt_platform folder:
|
||||
|
||||
```sh
|
||||
cd ..
|
||||
```
|
||||
|
||||
Run the following command to run database in docker but the application locally:
|
||||
|
||||
```sh
|
||||
docker compose --profile local up deps --build --detach
|
||||
cd backend
|
||||
poetry run app
|
||||
```
|
||||
|
||||
### Starting the server with Docker
|
||||
|
||||
Run the following command to build the dockerfiles:
|
||||
|
||||
```sh
|
||||
docker compose build
|
||||
```
|
||||
|
||||
Run the following command to run the app:
|
||||
|
||||
```sh
|
||||
docker compose up
|
||||
```
|
||||
|
||||
Run the following to automatically rebuild when code changes, in another terminal:
|
||||
|
||||
```sh
|
||||
docker compose watch
|
||||
```
|
||||
|
||||
Run the following command to shut down:
|
||||
|
||||
```sh
|
||||
docker compose down
|
||||
```
|
||||
|
||||
If you run into issues with dangling orphans, try:
|
||||
|
||||
```sh
|
||||
docker compose down --volumes --remove-orphans && docker-compose up --force-recreate --renew-anon-volumes --remove-orphans
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
To run the tests:
|
||||
|
||||
```sh
|
||||
poetry run test
|
||||
```
|
||||
|
||||
## Development
|
||||
|
||||
### Formatting & Linting
|
||||
Auto formatter and linter are set up in the project. To run them:
|
||||
|
||||
Install:
|
||||
```sh
|
||||
poetry install --with dev
|
||||
```
|
||||
|
||||
Format the code:
|
||||
```sh
|
||||
poetry run format
|
||||
```
|
||||
|
||||
Lint the code:
|
||||
```sh
|
||||
poetry run lint
|
||||
```
|
||||
|
||||
## Project Outline
|
||||
|
||||
The current project has the following main modules:
|
||||
|
||||
### **blocks**
|
||||
|
||||
This module stores all the Agent Blocks, which are reusable components to build a graph that represents the agent's behavior.
|
||||
|
||||
### **data**
|
||||
|
||||
This module stores the logical model that is persisted in the database.
|
||||
It abstracts the database operations into functions that can be called by the service layer.
|
||||
Any code that interacts with Prisma objects or the database should reside in this module.
|
||||
The main models are:
|
||||
* `block`: anything related to the block used in the graph
|
||||
* `execution`: anything related to the execution graph execution
|
||||
* `graph`: anything related to the graph, node, and its relations
|
||||
|
||||
### **execution**
|
||||
|
||||
This module stores the business logic of executing the graph.
|
||||
It currently has the following main modules:
|
||||
* `manager`: A service that consumes the queue of the graph execution and executes the graph. It contains both pieces of logic.
|
||||
* `scheduler`: A service that triggers scheduled graph execution based on a cron expression. It pushes an execution request to the manager.
|
||||
|
||||
### **server**
|
||||
|
||||
This module stores the logic for the server API.
|
||||
It contains all the logic used for the API that allows the client to create, execute, and monitor the graph and its execution.
|
||||
This API service interacts with other services like those defined in `manager` and `scheduler`.
|
||||
|
||||
### **utils**
|
||||
|
||||
This module stores utility functions that are used across the project.
|
||||
Currently, it has two main modules:
|
||||
* `process`: A module that contains the logic to spawn a new process.
|
||||
* `service`: A module that serves as a parent class for all the services in the project.
|
||||
|
||||
## Service Communication
|
||||
|
||||
Currently, there are only 3 active services:
|
||||
|
||||
- AgentServer (the API, defined in `server.py`)
|
||||
- ExecutionManager (the executor, defined in `manager.py`)
|
||||
- ExecutionScheduler (the scheduler, defined in `scheduler.py`)
|
||||
|
||||
The services run in independent Python processes and communicate through an IPC.
|
||||
A communication layer (`service.py`) is created to decouple the communication library from the implementation.
|
||||
|
||||
Currently, the IPC is done using Pyro5 and abstracted in a way that allows a function decorated with `@expose` to be called from a different process.
|
||||
|
||||
|
||||
By default the daemons run on the following ports:
|
||||
|
||||
Execution Manager Daemon: 8002
|
||||
Execution Scheduler Daemon: 8003
|
||||
Rest Server Daemon: 8004
|
||||
|
||||
## Adding a New Agent Block
|
||||
|
||||
To add a new agent block, you need to create a new class that inherits from `Block` and provides the following information:
|
||||
* All the block code should live in the `blocks` (`backend.blocks`) module.
|
||||
* `input_schema`: the schema of the input data, represented by a Pydantic object.
|
||||
* `output_schema`: the schema of the output data, represented by a Pydantic object.
|
||||
* `run` method: the main logic of the block.
|
||||
* `test_input` & `test_output`: the sample input and output data for the block, which will be used to auto-test the block.
|
||||
* You can mock the functions declared in the block using the `test_mock` field for your unit tests.
|
||||
* Once you finish creating the block, you can test it by running `poetry run pytest -s test/block/test_block.py`.
|
||||
|
||||
@@ -1,237 +0,0 @@
|
||||
# Backend Testing Guide
|
||||
|
||||
This guide covers testing practices for the AutoGPT Platform backend, with a focus on snapshot testing for API endpoints.
|
||||
|
||||
## Table of Contents
|
||||
- [Overview](#overview)
|
||||
- [Running Tests](#running-tests)
|
||||
- [Snapshot Testing](#snapshot-testing)
|
||||
- [Writing Tests for API Routes](#writing-tests-for-api-routes)
|
||||
- [Best Practices](#best-practices)
|
||||
|
||||
## Overview
|
||||
|
||||
The backend uses pytest for testing with the following key libraries:
|
||||
- `pytest` - Test framework
|
||||
- `pytest-asyncio` - Async test support
|
||||
- `pytest-mock` - Mocking support
|
||||
- `pytest-snapshot` - Snapshot testing for API responses
|
||||
|
||||
## Running Tests
|
||||
|
||||
### Run all tests
|
||||
```bash
|
||||
poetry run test
|
||||
```
|
||||
|
||||
### Run specific test file
|
||||
```bash
|
||||
poetry run pytest path/to/test_file.py
|
||||
```
|
||||
|
||||
### Run with verbose output
|
||||
```bash
|
||||
poetry run pytest -v
|
||||
```
|
||||
|
||||
### Run with coverage
|
||||
```bash
|
||||
poetry run pytest --cov=backend
|
||||
```
|
||||
|
||||
## Snapshot Testing
|
||||
|
||||
Snapshot testing captures the output of your code and compares it against previously saved snapshots. This is particularly useful for testing API responses.
|
||||
|
||||
### How Snapshot Testing Works
|
||||
|
||||
1. First run: Creates snapshot files in `snapshots/` directories
|
||||
2. Subsequent runs: Compares output against saved snapshots
|
||||
3. Changes detected: Test fails if output differs from snapshot
|
||||
|
||||
### Creating/Updating Snapshots
|
||||
|
||||
When you first write a test or when the expected output changes:
|
||||
|
||||
```bash
|
||||
poetry run pytest path/to/test.py --snapshot-update
|
||||
```
|
||||
|
||||
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
|
||||
|
||||
### Snapshot Test Example
|
||||
|
||||
```python
|
||||
import json
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
def test_api_endpoint(snapshot: Snapshot):
|
||||
response = client.get("/api/endpoint")
|
||||
|
||||
# Snapshot the response
|
||||
snapshot.snapshot_dir = "snapshots"
|
||||
snapshot.assert_match(
|
||||
json.dumps(response.json(), indent=2, sort_keys=True),
|
||||
"endpoint_response"
|
||||
)
|
||||
```
|
||||
|
||||
### Best Practices for Snapshots
|
||||
|
||||
1. **Use descriptive names**: `"user_list_response"` not `"response1"`
|
||||
2. **Sort JSON keys**: Ensures consistent snapshots
|
||||
3. **Format JSON**: Use `indent=2` for readable diffs
|
||||
4. **Exclude dynamic data**: Remove timestamps, IDs, etc. that change between runs
|
||||
|
||||
Example of excluding dynamic data:
|
||||
```python
|
||||
response_data = response.json()
|
||||
# Remove dynamic fields for snapshot
|
||||
response_data.pop("created_at", None)
|
||||
response_data.pop("id", None)
|
||||
|
||||
snapshot.snapshot_dir = "snapshots"
|
||||
snapshot.assert_match(
|
||||
json.dumps(response_data, indent=2, sort_keys=True),
|
||||
"static_response_data"
|
||||
)
|
||||
```
|
||||
|
||||
## Writing Tests for API Routes
|
||||
|
||||
### Basic Structure
|
||||
|
||||
```python
|
||||
import json
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import pytest
|
||||
from pytest_snapshot.plugin import Snapshot
|
||||
|
||||
from backend.server.v2.myroute import router
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(router)
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
def test_endpoint_success(snapshot: Snapshot):
|
||||
response = client.get("/endpoint")
|
||||
assert response.status_code == 200
|
||||
|
||||
# Test specific fields
|
||||
data = response.json()
|
||||
assert data["status"] == "success"
|
||||
|
||||
# Snapshot the full response
|
||||
snapshot.snapshot_dir = "snapshots"
|
||||
snapshot.assert_match(
|
||||
json.dumps(data, indent=2, sort_keys=True),
|
||||
"endpoint_success_response"
|
||||
)
|
||||
```
|
||||
|
||||
### Testing with Authentication
|
||||
|
||||
```python
|
||||
def override_auth_middleware():
|
||||
return {"sub": "test-user-id"}
|
||||
|
||||
def override_get_user_id():
|
||||
return "test-user-id"
|
||||
|
||||
app.dependency_overrides[auth_middleware] = override_auth_middleware
|
||||
app.dependency_overrides[get_user_id] = override_get_user_id
|
||||
```
|
||||
|
||||
### Mocking External Services
|
||||
|
||||
```python
|
||||
def test_external_api_call(mocker, snapshot):
|
||||
# Mock external service
|
||||
mock_response = {"external": "data"}
|
||||
mocker.patch(
|
||||
"backend.services.external_api.call",
|
||||
return_value=mock_response
|
||||
)
|
||||
|
||||
response = client.post("/api/process")
|
||||
assert response.status_code == 200
|
||||
|
||||
snapshot.snapshot_dir = "snapshots"
|
||||
snapshot.assert_match(
|
||||
json.dumps(response.json(), indent=2, sort_keys=True),
|
||||
"process_with_external_response"
|
||||
)
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Test Organization
|
||||
- Place tests next to the code: `routes.py` → `routes_test.py`
|
||||
- Use descriptive test names: `test_create_user_with_invalid_email`
|
||||
- Group related tests in classes when appropriate
|
||||
|
||||
### 2. Test Coverage
|
||||
- Test happy path and error cases
|
||||
- Test edge cases (empty data, invalid formats)
|
||||
- Test authentication and authorization
|
||||
|
||||
### 3. Snapshot Testing Guidelines
|
||||
- Review all snapshot changes carefully
|
||||
- Don't snapshot sensitive data
|
||||
- Keep snapshots focused and minimal
|
||||
- Update snapshots intentionally, not accidentally
|
||||
|
||||
### 4. Async Testing
|
||||
- Use regular `def` for FastAPI TestClient tests
|
||||
- Use `async def` with `@pytest.mark.asyncio` for testing async functions directly
|
||||
|
||||
### 5. Fixtures
|
||||
Create reusable fixtures for common test data:
|
||||
|
||||
```python
|
||||
@pytest.fixture
|
||||
def sample_user():
|
||||
return {
|
||||
"email": "test@example.com",
|
||||
"name": "Test User"
|
||||
}
|
||||
|
||||
def test_create_user(sample_user, snapshot):
|
||||
response = client.post("/users", json=sample_user)
|
||||
# ... test implementation
|
||||
```
|
||||
|
||||
## CI/CD Integration
|
||||
|
||||
The GitHub Actions workflow automatically runs tests on:
|
||||
- Pull requests
|
||||
- Pushes to main branch
|
||||
|
||||
Snapshot tests work in CI by:
|
||||
1. Committing snapshot files to the repository
|
||||
2. CI compares against committed snapshots
|
||||
3. Fails if snapshots don't match
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Snapshot Mismatches
|
||||
- Review the diff carefully
|
||||
- If changes are expected: `poetry run pytest --snapshot-update`
|
||||
- If changes are unexpected: Fix the code causing the difference
|
||||
|
||||
### Async Test Issues
|
||||
- Ensure async functions use `@pytest.mark.asyncio`
|
||||
- Use `AsyncMock` for mocking async functions
|
||||
- FastAPI TestClient handles async automatically
|
||||
|
||||
### Import Errors
|
||||
- Check that all dependencies are in `pyproject.toml`
|
||||
- Run `poetry install` to ensure dependencies are installed
|
||||
- Verify import paths are correct
|
||||
|
||||
## Summary
|
||||
|
||||
Snapshot testing provides a powerful way to ensure API responses remain consistent. Combined with traditional assertions, it creates a robust test suite that catches regressions while remaining maintainable.
|
||||
|
||||
Remember: Good tests are as important as good code!
|
||||
@@ -1,150 +0,0 @@
|
||||
# Test Data Scripts
|
||||
|
||||
This directory contains scripts for creating and updating test data in the AutoGPT Platform database, specifically designed to test the materialized views for the store functionality.
|
||||
|
||||
## Scripts
|
||||
|
||||
### test_data_creator.py
|
||||
Creates a comprehensive set of test data including:
|
||||
- Users with profiles
|
||||
- Agent graphs, nodes, and executions
|
||||
- Store listings with multiple versions
|
||||
- Reviews and ratings
|
||||
- Library agents
|
||||
- Integration webhooks
|
||||
- Onboarding data
|
||||
- Credit transactions
|
||||
|
||||
**Image/Video Domains Used:**
|
||||
- Images: `picsum.photos` (for all image URLs)
|
||||
- Videos: `youtube.com` (for store listing videos)
|
||||
|
||||
### test_data_updater.py
|
||||
Updates existing test data to simulate real-world changes:
|
||||
- Adds new agent graph executions
|
||||
- Creates new store listing reviews
|
||||
- Updates store listing versions
|
||||
- Adds credit transactions
|
||||
- Refreshes materialized views
|
||||
|
||||
### check_db.py
|
||||
Tests and verifies materialized views functionality:
|
||||
- Checks pg_cron job status (for automatic refresh)
|
||||
- Displays current materialized view counts
|
||||
- Adds test data (executions and reviews)
|
||||
- Creates store listings if none exist
|
||||
- Manually refreshes materialized views
|
||||
- Compares before/after counts to verify updates
|
||||
- Provides a summary of test results
|
||||
|
||||
## Materialized Views
|
||||
|
||||
The scripts test three key database views:
|
||||
|
||||
1. **mv_agent_run_counts**: Tracks execution counts by agent
|
||||
2. **mv_review_stats**: Tracks review statistics (count, average rating) by store listing
|
||||
3. **StoreAgent**: A view that combines store listing data with execution counts and ratings for display
|
||||
|
||||
The materialized views (mv_agent_run_counts and mv_review_stats) are automatically refreshed every 15 minutes via pg_cron, or can be manually refreshed using the `refresh_store_materialized_views()` function.
|
||||
|
||||
## Usage
|
||||
|
||||
### Prerequisites
|
||||
|
||||
1. Ensure the database is running:
|
||||
```bash
|
||||
docker compose up -d
|
||||
# or for test database:
|
||||
docker compose -f docker-compose.test.yaml --env-file ../.env up -d
|
||||
```
|
||||
|
||||
2. Run database migrations:
|
||||
```bash
|
||||
poetry run prisma migrate deploy
|
||||
```
|
||||
|
||||
### Running the Scripts
|
||||
|
||||
#### Option 1: Use the helper script (from backend directory)
|
||||
```bash
|
||||
poetry run python run_test_data.py
|
||||
```
|
||||
|
||||
#### Option 2: Run individually
|
||||
```bash
|
||||
# From backend/test directory:
|
||||
# Create initial test data
|
||||
poetry run python test_data_creator.py
|
||||
|
||||
# Update data to test materialized view changes
|
||||
poetry run python test_data_updater.py
|
||||
|
||||
# From backend directory:
|
||||
# Test materialized views functionality
|
||||
poetry run python check_db.py
|
||||
|
||||
# Check store data status
|
||||
poetry run python check_store_data.py
|
||||
```
|
||||
|
||||
#### Option 3: Use the shell script (from backend directory)
|
||||
```bash
|
||||
./run_test_data_scripts.sh
|
||||
```
|
||||
|
||||
### Manual Materialized View Refresh
|
||||
|
||||
To manually refresh the materialized views:
|
||||
```sql
|
||||
SELECT refresh_store_materialized_views();
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
The scripts use the database configuration from your `.env` file:
|
||||
- `DATABASE_URL`: PostgreSQL connection string
|
||||
- Database should have the platform schema
|
||||
|
||||
## Data Generation Limits
|
||||
|
||||
Configured in `test_data_creator.py`:
|
||||
- 100 users
|
||||
- 100 agent blocks
|
||||
- 1-5 graphs per user
|
||||
- 2-5 nodes per graph
|
||||
- 1-5 presets per user
|
||||
- 1-10 library agents per user
|
||||
- 1-20 executions per graph
|
||||
- 1-5 reviews per store listing version
|
||||
|
||||
## Notes
|
||||
|
||||
- All image URLs use `picsum.photos` for consistency with Next.js image configuration
|
||||
- The scripts create realistic relationships between entities
|
||||
- Materialized views are refreshed at the end of each script
|
||||
- Data is designed to test both happy paths and edge cases
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Reviews and StoreAgent view showing 0
|
||||
|
||||
If `check_db.py` shows that reviews remain at 0 and StoreAgent view shows 0 store agents:
|
||||
|
||||
1. **No store listings exist**: The script will automatically create test store listings if none exist
|
||||
2. **No approved versions**: Store listings need approved versions to appear in the StoreAgent view
|
||||
3. **Check with `check_store_data.py`**: This script provides detailed information about:
|
||||
- Total store listings
|
||||
- Store listing versions by status
|
||||
- Existing reviews
|
||||
- StoreAgent view contents
|
||||
- Agent graph executions
|
||||
|
||||
### pg_cron not installed
|
||||
|
||||
The warning "pg_cron extension is not installed" is normal in local development environments. The materialized views can still be refreshed manually using the `refresh_store_materialized_views()` function, which all scripts do automatically.
|
||||
|
||||
### Common Issues
|
||||
|
||||
- **Type errors with None values**: Fixed in the latest version of check_db.py by using `or 0` for nullable numeric fields
|
||||
- **Missing relations**: Ensure you're using the correct field names (e.g., `StoreListing` not `storeListing` in includes)
|
||||
- **Column name mismatches**: The database uses camelCase for column names (e.g., `agentGraphId` not `agent_graph_id`)
|
||||
@@ -1,10 +1,6 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.util.process import AppProcess
|
||||
|
||||
@@ -36,18 +32,18 @@ def main(**kwargs):
|
||||
Run all the processes required for the AutoGPT-server (REST and WebSocket APIs).
|
||||
"""
|
||||
|
||||
from backend.executor import DatabaseManager, ExecutionManager, Scheduler
|
||||
from backend.executor import DatabaseManager, ExecutionManager, ExecutionScheduler
|
||||
from backend.notifications import NotificationManager
|
||||
from backend.server.rest_api import AgentServer
|
||||
from backend.server.ws_api import WebsocketServer
|
||||
|
||||
run_processes(
|
||||
DatabaseManager().set_log_level("warning"),
|
||||
Scheduler(),
|
||||
DatabaseManager(),
|
||||
ExecutionManager(),
|
||||
ExecutionScheduler(),
|
||||
NotificationManager(),
|
||||
WebsocketServer(),
|
||||
AgentServer(),
|
||||
ExecutionManager(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,124 +1,89 @@
|
||||
import functools
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
from typing import Type, TypeVar
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from backend.data.block import Block
|
||||
|
||||
# Dynamically load all modules under backend.blocks
|
||||
AVAILABLE_MODULES = []
|
||||
current_dir = Path(__file__).parent
|
||||
modules = [
|
||||
str(f.relative_to(current_dir))[:-3].replace(os.path.sep, ".")
|
||||
for f in current_dir.rglob("*.py")
|
||||
if f.is_file() and f.name != "__init__.py"
|
||||
]
|
||||
for module in modules:
|
||||
if not re.match("^[a-z0-9_.]+$", module):
|
||||
raise ValueError(
|
||||
f"Block module {module} error: module name must be lowercase, "
|
||||
"and contain only alphanumeric characters and underscores."
|
||||
)
|
||||
|
||||
importlib.import_module(f".{module}", package=__name__)
|
||||
AVAILABLE_MODULES.append(module)
|
||||
|
||||
# Load all Block instances from the available modules
|
||||
AVAILABLE_BLOCKS: dict[str, Type[Block]] = {}
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.block import Block
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@functools.cache
|
||||
def load_all_blocks() -> dict[str, type["Block"]]:
|
||||
from backend.data.block import Block
|
||||
from backend.util.settings import Config
|
||||
|
||||
# Check if example blocks should be loaded from settings
|
||||
config = Config()
|
||||
load_examples = config.enable_example_blocks
|
||||
|
||||
# Dynamically load all modules under backend.blocks
|
||||
current_dir = Path(__file__).parent
|
||||
modules = []
|
||||
for f in current_dir.rglob("*.py"):
|
||||
if not f.is_file() or f.name == "__init__.py" or f.name.startswith("test_"):
|
||||
continue
|
||||
|
||||
# Skip examples directory if not enabled
|
||||
relative_path = f.relative_to(current_dir)
|
||||
if not load_examples and relative_path.parts[0] == "examples":
|
||||
continue
|
||||
|
||||
module_path = str(relative_path)[:-3].replace(os.path.sep, ".")
|
||||
modules.append(module_path)
|
||||
|
||||
for module in modules:
|
||||
if not re.match("^[a-z0-9_.]+$", module):
|
||||
raise ValueError(
|
||||
f"Block module {module} error: module name must be lowercase, "
|
||||
"and contain only alphanumeric characters and underscores."
|
||||
)
|
||||
|
||||
importlib.import_module(f".{module}", package=__name__)
|
||||
|
||||
# Load all Block instances from the available modules
|
||||
available_blocks: dict[str, type["Block"]] = {}
|
||||
for block_cls in all_subclasses(Block):
|
||||
class_name = block_cls.__name__
|
||||
|
||||
if class_name.endswith("Base"):
|
||||
continue
|
||||
|
||||
if not class_name.endswith("Block"):
|
||||
raise ValueError(
|
||||
f"Block class {class_name} does not end with 'Block'. "
|
||||
"If you are creating an abstract class, "
|
||||
"please name the class with 'Base' at the end"
|
||||
)
|
||||
|
||||
block = block_cls.create()
|
||||
|
||||
if not isinstance(block.id, str) or len(block.id) != 36:
|
||||
raise ValueError(
|
||||
f"Block ID {block.name} error: {block.id} is not a valid UUID"
|
||||
)
|
||||
|
||||
if block.id in available_blocks:
|
||||
raise ValueError(
|
||||
f"Block ID {block.name} error: {block.id} is already in use"
|
||||
)
|
||||
|
||||
input_schema = block.input_schema.model_fields
|
||||
output_schema = block.output_schema.model_fields
|
||||
|
||||
# Make sure `error` field is a string in the output schema
|
||||
if "error" in output_schema and output_schema["error"].annotation is not str:
|
||||
raise ValueError(
|
||||
f"{block.name} `error` field in output_schema must be a string"
|
||||
)
|
||||
|
||||
# Ensure all fields in input_schema and output_schema are annotated SchemaFields
|
||||
for field_name, field in [*input_schema.items(), *output_schema.items()]:
|
||||
if field.annotation is None:
|
||||
raise ValueError(
|
||||
f"{block.name} has a field {field_name} that is not annotated"
|
||||
)
|
||||
if field.json_schema_extra is None:
|
||||
raise ValueError(
|
||||
f"{block.name} has a field {field_name} not defined as SchemaField"
|
||||
)
|
||||
|
||||
for field in block.input_schema.model_fields.values():
|
||||
if field.annotation is bool and field.default not in (True, False):
|
||||
raise ValueError(
|
||||
f"{block.name} has a boolean field with no default value"
|
||||
)
|
||||
|
||||
available_blocks[block.id] = block_cls
|
||||
|
||||
# Filter out blocks with incomplete auth configs, e.g. missing OAuth server secrets
|
||||
from backend.data.block import is_block_auth_configured
|
||||
|
||||
filtered_blocks = {}
|
||||
for block_id, block_cls in available_blocks.items():
|
||||
if is_block_auth_configured(block_cls):
|
||||
filtered_blocks[block_id] = block_cls
|
||||
|
||||
return filtered_blocks
|
||||
|
||||
|
||||
__all__ = ["load_all_blocks"]
|
||||
|
||||
|
||||
def all_subclasses(cls: type[T]) -> list[type[T]]:
|
||||
def all_subclasses(cls: Type[T]) -> list[Type[T]]:
|
||||
subclasses = cls.__subclasses__()
|
||||
for subclass in subclasses:
|
||||
subclasses += all_subclasses(subclass)
|
||||
return subclasses
|
||||
|
||||
|
||||
for block_cls in all_subclasses(Block):
|
||||
name = block_cls.__name__
|
||||
|
||||
if block_cls.__name__.endswith("Base"):
|
||||
continue
|
||||
|
||||
if not block_cls.__name__.endswith("Block"):
|
||||
raise ValueError(
|
||||
f"Block class {block_cls.__name__} does not end with 'Block', If you are creating an abstract class, please name the class with 'Base' at the end"
|
||||
)
|
||||
|
||||
block = block_cls.create()
|
||||
|
||||
if not isinstance(block.id, str) or len(block.id) != 36:
|
||||
raise ValueError(f"Block ID {block.name} error: {block.id} is not a valid UUID")
|
||||
|
||||
if block.id in AVAILABLE_BLOCKS:
|
||||
raise ValueError(f"Block ID {block.name} error: {block.id} is already in use")
|
||||
|
||||
input_schema = block.input_schema.model_fields
|
||||
output_schema = block.output_schema.model_fields
|
||||
|
||||
# Make sure `error` field is a string in the output schema
|
||||
if "error" in output_schema and output_schema["error"].annotation is not str:
|
||||
raise ValueError(
|
||||
f"{block.name} `error` field in output_schema must be a string"
|
||||
)
|
||||
|
||||
# Make sure all fields in input_schema and output_schema are annotated and has a value
|
||||
for field_name, field in [*input_schema.items(), *output_schema.items()]:
|
||||
if field.annotation is None:
|
||||
raise ValueError(
|
||||
f"{block.name} has a field {field_name} that is not annotated"
|
||||
)
|
||||
if field.json_schema_extra is None:
|
||||
raise ValueError(
|
||||
f"{block.name} has a field {field_name} not defined as SchemaField"
|
||||
)
|
||||
|
||||
for field in block.input_schema.model_fields.values():
|
||||
if field.annotation is bool and field.default not in (True, False):
|
||||
raise ValueError(f"{block.name} has a boolean field with no default value")
|
||||
|
||||
if block.disabled:
|
||||
continue
|
||||
|
||||
AVAILABLE_BLOCKS[block.id] = block_cls
|
||||
|
||||
__all__ = ["AVAILABLE_MODULES", "AVAILABLE_BLOCKS"]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import JsonValue
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
@@ -13,11 +13,25 @@ from backend.data.block import (
|
||||
get_block,
|
||||
)
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.data.model import NodeExecutionStats, SchemaField
|
||||
from backend.util.json import validate_with_jsonschema
|
||||
from backend.util.retry import func_retry
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util import json
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_executor_manager_client():
|
||||
from backend.executor import ExecutionManager
|
||||
from backend.util.service import get_service_client
|
||||
|
||||
return get_service_client(ExecutionManager)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_event_bus():
|
||||
from backend.data.execution import RedisExecutionEventBus
|
||||
|
||||
return RedisExecutionEventBus()
|
||||
|
||||
|
||||
class AgentExecutorBlock(Block):
|
||||
@@ -26,21 +40,17 @@ class AgentExecutorBlock(Block):
|
||||
graph_id: str = SchemaField(description="Graph ID")
|
||||
graph_version: int = SchemaField(description="Graph Version")
|
||||
|
||||
inputs: BlockInput = SchemaField(description="Input data for the graph")
|
||||
data: BlockInput = SchemaField(description="Input data for the graph")
|
||||
input_schema: dict = SchemaField(description="Input schema for the graph")
|
||||
output_schema: dict = SchemaField(description="Output schema for the graph")
|
||||
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = SchemaField(
|
||||
default=None, hidden=True
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
|
||||
return data.get("input_schema", {})
|
||||
|
||||
@classmethod
|
||||
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
|
||||
return data.get("inputs", {})
|
||||
return data.get("data", {})
|
||||
|
||||
@classmethod
|
||||
def get_missing_input(cls, data: BlockInput) -> set[str]:
|
||||
@@ -49,7 +59,7 @@ class AgentExecutorBlock(Block):
|
||||
|
||||
@classmethod
|
||||
def get_mismatch_error(cls, data: BlockInput) -> str | None:
|
||||
return validate_with_jsonschema(cls.get_input_schema(data), data)
|
||||
return json.validate_with_jsonschema(cls.get_input_schema(data), data)
|
||||
|
||||
class Output(BlockSchema):
|
||||
pass
|
||||
@@ -64,103 +74,36 @@ class AgentExecutorBlock(Block):
|
||||
categories={BlockCategory.AGENT},
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
executor_manager = get_executor_manager_client()
|
||||
event_bus = get_event_bus()
|
||||
|
||||
from backend.executor import utils as execution_utils
|
||||
|
||||
graph_exec = await execution_utils.add_graph_execution(
|
||||
graph_exec = executor_manager.add_execution(
|
||||
graph_id=input_data.graph_id,
|
||||
graph_version=input_data.graph_version,
|
||||
user_id=input_data.user_id,
|
||||
inputs=input_data.inputs,
|
||||
nodes_input_masks=input_data.nodes_input_masks,
|
||||
data=input_data.data,
|
||||
)
|
||||
|
||||
logger = execution_utils.LogMetadata(
|
||||
logger=_logger,
|
||||
user_id=input_data.user_id,
|
||||
graph_eid=graph_exec.id,
|
||||
graph_id=input_data.graph_id,
|
||||
node_eid="*",
|
||||
node_id="*",
|
||||
block_name=self.name,
|
||||
)
|
||||
|
||||
try:
|
||||
async for name, data in self._run(
|
||||
graph_id=input_data.graph_id,
|
||||
graph_version=input_data.graph_version,
|
||||
graph_exec_id=graph_exec.id,
|
||||
user_id=input_data.user_id,
|
||||
logger=logger,
|
||||
):
|
||||
yield name, data
|
||||
except BaseException as e:
|
||||
await self._stop(
|
||||
graph_exec_id=graph_exec.id,
|
||||
user_id=input_data.user_id,
|
||||
logger=logger,
|
||||
)
|
||||
logger.warning(
|
||||
f"Execution of graph {input_data.graph_id}v{input_data.graph_version} failed: {e.__class__.__name__} {str(e)}; execution is stopped."
|
||||
)
|
||||
raise
|
||||
|
||||
async def _run(
|
||||
self,
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
logger,
|
||||
) -> BlockOutput:
|
||||
|
||||
from backend.data.execution import ExecutionEventType
|
||||
from backend.executor import utils as execution_utils
|
||||
|
||||
event_bus = execution_utils.get_async_execution_event_bus()
|
||||
|
||||
log_id = f"Graph #{graph_id}-V{graph_version}, exec-id: {graph_exec_id}"
|
||||
log_id = f"Graph #{input_data.graph_id}-V{input_data.graph_version}, exec-id: {graph_exec.graph_exec_id}"
|
||||
logger.info(f"Starting execution of {log_id}")
|
||||
yielded_node_exec_ids = set()
|
||||
|
||||
async for event in event_bus.listen(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
for event in event_bus.listen(
|
||||
graph_id=graph_exec.graph_id, graph_exec_id=graph_exec.graph_exec_id
|
||||
):
|
||||
if event.status not in [
|
||||
ExecutionStatus.COMPLETED,
|
||||
ExecutionStatus.TERMINATED,
|
||||
ExecutionStatus.FAILED,
|
||||
]:
|
||||
logger.debug(
|
||||
f"Execution {log_id} received event {event.event_type} with status {event.status}"
|
||||
)
|
||||
continue
|
||||
|
||||
if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
|
||||
# If the graph execution is COMPLETED, TERMINATED, or FAILED,
|
||||
# we can stop listening for further events.
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
extra_cost=event.stats.cost if event.stats else 0,
|
||||
extra_steps=event.stats.node_exec_count if event.stats else 0,
|
||||
)
|
||||
)
|
||||
break
|
||||
|
||||
logger.debug(
|
||||
logger.info(
|
||||
f"Execution {log_id} produced input {event.input_data} output {event.output_data}"
|
||||
)
|
||||
|
||||
if event.node_exec_id in yielded_node_exec_ids:
|
||||
logger.warning(
|
||||
f"{log_id} received duplicate event for node execution {event.node_exec_id}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
yielded_node_exec_ids.add(event.node_exec_id)
|
||||
if not event.node_id:
|
||||
if event.status in [
|
||||
ExecutionStatus.COMPLETED,
|
||||
ExecutionStatus.TERMINATED,
|
||||
ExecutionStatus.FAILED,
|
||||
]:
|
||||
logger.info(f"Execution {log_id} ended with status {event.status}")
|
||||
break
|
||||
else:
|
||||
continue
|
||||
|
||||
if not event.block_id:
|
||||
logger.warning(f"{log_id} received event without block_id {event}")
|
||||
@@ -176,29 +119,5 @@ class AgentExecutorBlock(Block):
|
||||
continue
|
||||
|
||||
for output_data in event.output_data.get("output", []):
|
||||
logger.debug(
|
||||
f"Execution {log_id} produced {output_name}: {output_data}"
|
||||
)
|
||||
logger.info(f"Execution {log_id} produced {output_name}: {output_data}")
|
||||
yield output_name, output_data
|
||||
|
||||
@func_retry
|
||||
async def _stop(
|
||||
self,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
logger,
|
||||
) -> None:
|
||||
from backend.executor import utils as execution_utils
|
||||
|
||||
log_id = f"Graph exec-id: {graph_exec_id}"
|
||||
logger.info(f"Stopping execution of {log_id}")
|
||||
|
||||
try:
|
||||
await execution_utils.stop_graph_execution(
|
||||
graph_exec_id=graph_exec_id,
|
||||
user_id=user_id,
|
||||
wait_timeout=3600,
|
||||
)
|
||||
logger.info(f"Execution {log_id} stopped successfully.")
|
||||
except TimeoutError as e:
|
||||
logger.error(f"Execution {log_id} stop timed out: {e}")
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
import replicate
|
||||
from pydantic import SecretStr
|
||||
from replicate.client import Client as ReplicateClient
|
||||
from replicate.helpers import FileOutput
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockSchema
|
||||
@@ -165,15 +165,15 @@ class AIImageGeneratorBlock(Block):
|
||||
},
|
||||
)
|
||||
|
||||
async def _run_client(
|
||||
def _run_client(
|
||||
self, credentials: APIKeyCredentials, model_name: str, input_params: dict
|
||||
):
|
||||
try:
|
||||
# Initialize Replicate client
|
||||
client = ReplicateClient(api_token=credentials.api_key.get_secret_value())
|
||||
client = replicate.Client(api_token=credentials.api_key.get_secret_value())
|
||||
|
||||
# Run the model with input parameters
|
||||
output = await client.async_run(model_name, input=input_params, wait=False)
|
||||
output = client.run(model_name, input=input_params, wait=False)
|
||||
|
||||
# Process output
|
||||
if isinstance(output, list) and len(output) > 0:
|
||||
@@ -195,7 +195,7 @@ class AIImageGeneratorBlock(Block):
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Unexpected error during model execution: {e}")
|
||||
|
||||
async def generate_image(self, input_data: Input, credentials: APIKeyCredentials):
|
||||
def generate_image(self, input_data: Input, credentials: APIKeyCredentials):
|
||||
try:
|
||||
# Handle style-based prompt modification for models without native style support
|
||||
modified_prompt = input_data.prompt
|
||||
@@ -213,7 +213,7 @@ class AIImageGeneratorBlock(Block):
|
||||
"steps": 40,
|
||||
"cfg_scale": 7.0,
|
||||
}
|
||||
output = await self._run_client(
|
||||
output = self._run_client(
|
||||
credentials,
|
||||
"stability-ai/stable-diffusion-3.5-medium",
|
||||
input_params,
|
||||
@@ -231,7 +231,7 @@ class AIImageGeneratorBlock(Block):
|
||||
"output_format": "jpg", # Set to jpg for Flux models
|
||||
"output_quality": 90,
|
||||
}
|
||||
output = await self._run_client(
|
||||
output = self._run_client(
|
||||
credentials, "black-forest-labs/flux-1.1-pro", input_params
|
||||
)
|
||||
return output
|
||||
@@ -246,7 +246,7 @@ class AIImageGeneratorBlock(Block):
|
||||
"output_format": "jpg",
|
||||
"output_quality": 90,
|
||||
}
|
||||
output = await self._run_client(
|
||||
output = self._run_client(
|
||||
credentials, "black-forest-labs/flux-1.1-pro-ultra", input_params
|
||||
)
|
||||
return output
|
||||
@@ -257,7 +257,7 @@ class AIImageGeneratorBlock(Block):
|
||||
"size": SIZE_TO_RECRAFT_DIMENSIONS[input_data.size],
|
||||
"style": input_data.style.value,
|
||||
}
|
||||
output = await self._run_client(
|
||||
output = self._run_client(
|
||||
credentials, "recraft-ai/recraft-v3", input_params
|
||||
)
|
||||
return output
|
||||
@@ -296,9 +296,9 @@ class AIImageGeneratorBlock(Block):
|
||||
style_text = style_map.get(style, "")
|
||||
return f"{style_text} of" if style_text else ""
|
||||
|
||||
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
|
||||
def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
|
||||
try:
|
||||
url = await self.generate_image(input_data, credentials)
|
||||
url = self.generate_image(input_data, credentials)
|
||||
if url:
|
||||
yield "image_url", url
|
||||
else:
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
import replicate
|
||||
from pydantic import SecretStr
|
||||
from replicate.client import Client as ReplicateClient
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import (
|
||||
@@ -142,7 +142,7 @@ class AIMusicGeneratorBlock(Block):
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def run(
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
max_retries = 3
|
||||
@@ -154,7 +154,7 @@ class AIMusicGeneratorBlock(Block):
|
||||
logger.debug(
|
||||
f"[AIMusicGeneratorBlock] - Running model (attempt {attempt + 1})"
|
||||
)
|
||||
result = await self.run_model(
|
||||
result = self.run_model(
|
||||
api_key=credentials.api_key,
|
||||
music_gen_model_version=input_data.music_gen_model_version,
|
||||
prompt=input_data.prompt,
|
||||
@@ -176,13 +176,13 @@ class AIMusicGeneratorBlock(Block):
|
||||
last_error = f"Unexpected error: {str(e)}"
|
||||
logger.error(f"[AIMusicGeneratorBlock] - Error: {last_error}")
|
||||
if attempt < max_retries - 1:
|
||||
await asyncio.sleep(retry_delay)
|
||||
time.sleep(retry_delay)
|
||||
continue
|
||||
|
||||
# If we've exhausted all retries, yield the error
|
||||
yield "error", f"Failed after {max_retries} attempts. Last error: {last_error}"
|
||||
|
||||
async def run_model(
|
||||
def run_model(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
music_gen_model_version: MusicGenModelVersion,
|
||||
@@ -196,10 +196,10 @@ class AIMusicGeneratorBlock(Block):
|
||||
normalization_strategy: NormalizationStrategy,
|
||||
):
|
||||
# Initialize Replicate client with the API key
|
||||
client = ReplicateClient(api_token=api_key.get_secret_value())
|
||||
client = replicate.Client(api_token=api_key.get_secret_value())
|
||||
|
||||
# Run the model with parameters
|
||||
output = await client.async_run(
|
||||
output = client.run(
|
||||
"meta/musicgen:671ac645ce5e552cc63a54a2bbff63fcf798043055d2dac5fc9e36a837eedcfb",
|
||||
input={
|
||||
"prompt": prompt,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from enum import Enum
|
||||
@@ -14,7 +13,7 @@ from backend.data.model import (
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.request import Requests
|
||||
from backend.util.request import requests
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
@@ -53,7 +52,6 @@ class AudioTrack(str, Enum):
|
||||
REFRESHER = ("Refresher",)
|
||||
TOURIST = ("Tourist",)
|
||||
TWIN_TYCHES = ("Twin Tyches",)
|
||||
DONT_STOP_ME_ABSTRACT_FUTURE_BASS = ("Dont Stop Me Abstract Future Bass",)
|
||||
|
||||
@property
|
||||
def audio_url(self):
|
||||
@@ -79,7 +77,6 @@ class AudioTrack(str, Enum):
|
||||
AudioTrack.REFRESHER: "https://cdn.tfrv.xyz/audio/refresher.mp3",
|
||||
AudioTrack.TOURIST: "https://cdn.tfrv.xyz/audio/tourist.mp3",
|
||||
AudioTrack.TWIN_TYCHES: "https://cdn.tfrv.xyz/audio/twin-tynches.mp3",
|
||||
AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS: "https://cdn.revid.ai/audio/_dont-stop-me-abstract-future-bass.mp3",
|
||||
}
|
||||
return audio_urls[self]
|
||||
|
||||
@@ -107,7 +104,6 @@ class GenerationPreset(str, Enum):
|
||||
MOVIE = ("Movie",)
|
||||
STYLIZED_ILLUSTRATION = ("Stylized Illustration",)
|
||||
MANGA = ("Manga",)
|
||||
DEFAULT = ("DEFAULT",)
|
||||
|
||||
|
||||
class Voice(str, Enum):
|
||||
@@ -117,7 +113,6 @@ class Voice(str, Enum):
|
||||
JESSICA = "Jessica"
|
||||
CHARLOTTE = "Charlotte"
|
||||
CALLUM = "Callum"
|
||||
EVA = "Eva"
|
||||
|
||||
@property
|
||||
def voice_id(self):
|
||||
@@ -128,7 +123,6 @@ class Voice(str, Enum):
|
||||
Voice.JESSICA: "cgSgspJ2msm6clMCkdW9",
|
||||
Voice.CHARLOTTE: "XB0fDUnXU5powFXDhCwa",
|
||||
Voice.CALLUM: "N2lVS1w4EtoT3dr4eOWO",
|
||||
Voice.EVA: "FGY2WhTYpPnrIDTdsKH5",
|
||||
}
|
||||
return voice_id_map[self]
|
||||
|
||||
@@ -146,8 +140,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AIShortformVideoCreatorBlock(Block):
|
||||
"""Creates a short‑form text‑to‑video clip using stock or AI imagery."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REVID], Literal["api_key"]
|
||||
@@ -191,58 +183,6 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
video_url: str = SchemaField(description="The URL of the created video")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
async def create_webhook(self) -> tuple[str, str]:
|
||||
"""Create a new webhook URL for receiving notifications."""
|
||||
url = "https://webhook.site/token"
|
||||
headers = {"Accept": "application/json", "Content-Type": "application/json"}
|
||||
response = await Requests().post(url, headers=headers)
|
||||
webhook_data = response.json()
|
||||
return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"
|
||||
|
||||
async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
|
||||
"""Create a video using the Revid API."""
|
||||
url = "https://www.revid.ai/api/public/v2/render"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = await Requests().post(url, json=payload, headers=headers)
|
||||
logger.debug(
|
||||
f"API Response Status Code: {response.status}, Content: {response.text}"
|
||||
)
|
||||
return response.json()
|
||||
|
||||
async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
|
||||
"""Check the status of a video creation job."""
|
||||
url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = await Requests().get(url, headers=headers)
|
||||
return response.json()
|
||||
|
||||
async def wait_for_video(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
pid: str,
|
||||
max_wait_time: int = 1000,
|
||||
) -> str:
|
||||
"""Wait for video creation to complete and return the video URL."""
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < max_wait_time:
|
||||
status = await self.check_video_status(api_key, pid)
|
||||
logger.debug(f"Video status: {status}")
|
||||
|
||||
if status.get("status") == "ready" and "videoUrl" in status:
|
||||
return status["videoUrl"]
|
||||
elif status.get("status") == "error":
|
||||
error_message = status.get("error", "Unknown error occurred")
|
||||
logger.error(f"Video creation failed: {error_message}")
|
||||
raise ValueError(f"Video creation failed: {error_message}")
|
||||
elif status.get("status") in ["FAILED", "CANCELED"]:
|
||||
logger.error(f"Video creation failed: {status.get('message')}")
|
||||
raise ValueError(f"Video creation failed: {status.get('message')}")
|
||||
|
||||
await asyncio.sleep(10)
|
||||
|
||||
logger.error("Video creation timed out")
|
||||
raise TimeoutError("Video creation timed out")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="361697fb-0c4f-4feb-aed3-8320c88c771b",
|
||||
@@ -261,41 +201,91 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
"voice": Voice.LILY,
|
||||
"video_style": VisualMediaType.STOCK_VIDEOS,
|
||||
},
|
||||
test_output=("video_url", "https://example.com/video.mp4"),
|
||||
test_output=(
|
||||
"video_url",
|
||||
"https://example.com/video.mp4",
|
||||
),
|
||||
test_mock={
|
||||
"create_webhook": lambda *args, **kwargs: (
|
||||
"create_webhook": lambda: (
|
||||
"test_uuid",
|
||||
"https://webhook.site/test_uuid",
|
||||
),
|
||||
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
||||
"check_video_status": lambda *args, **kwargs: {
|
||||
"status": "ready",
|
||||
"videoUrl": "https://example.com/video.mp4",
|
||||
},
|
||||
"wait_for_video": lambda *args, **kwargs: "https://example.com/video.mp4",
|
||||
"create_video": lambda api_key, payload: {"pid": "test_pid"},
|
||||
"wait_for_video": lambda api_key, pid, webhook_token, max_wait_time=1000: "https://example.com/video.mp4",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def run(
|
||||
def create_webhook(self):
|
||||
url = "https://webhook.site/token"
|
||||
headers = {"Accept": "application/json", "Content-Type": "application/json"}
|
||||
response = requests.post(url, headers=headers)
|
||||
webhook_data = response.json()
|
||||
return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"
|
||||
|
||||
def create_video(self, api_key: SecretStr, payload: dict) -> dict:
|
||||
url = "https://www.revid.ai/api/public/v2/render"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = requests.post(url, json=payload, headers=headers)
|
||||
logger.debug(
|
||||
f"API Response Status Code: {response.status_code}, Content: {response.text}"
|
||||
)
|
||||
return response.json()
|
||||
|
||||
def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
|
||||
url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = requests.get(url, headers=headers)
|
||||
return response.json()
|
||||
|
||||
def wait_for_video(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
pid: str,
|
||||
webhook_token: str,
|
||||
max_wait_time: int = 1000,
|
||||
) -> str:
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < max_wait_time:
|
||||
status = self.check_video_status(api_key, pid)
|
||||
logger.debug(f"Video status: {status}")
|
||||
|
||||
if status.get("status") == "ready" and "videoUrl" in status:
|
||||
return status["videoUrl"]
|
||||
elif status.get("status") == "error":
|
||||
error_message = status.get("error", "Unknown error occurred")
|
||||
logger.error(f"Video creation failed: {error_message}")
|
||||
raise ValueError(f"Video creation failed: {error_message}")
|
||||
elif status.get("status") in ["FAILED", "CANCELED"]:
|
||||
logger.error(f"Video creation failed: {status.get('message')}")
|
||||
raise ValueError(f"Video creation failed: {status.get('message')}")
|
||||
|
||||
time.sleep(10)
|
||||
|
||||
logger.error("Video creation timed out")
|
||||
raise TimeoutError("Video creation timed out")
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Create a new Webhook.site URL
|
||||
webhook_token, webhook_url = await self.create_webhook()
|
||||
webhook_token, webhook_url = self.create_webhook()
|
||||
logger.debug(f"Webhook URL: {webhook_url}")
|
||||
|
||||
audio_url = input_data.background_music.audio_url
|
||||
|
||||
payload = {
|
||||
"frameRate": input_data.frame_rate,
|
||||
"resolution": input_data.resolution,
|
||||
"frameDurationMultiplier": 18,
|
||||
"webhook": None,
|
||||
"webhook": webhook_url,
|
||||
"creationParams": {
|
||||
"mediaType": input_data.video_style,
|
||||
"captionPresetName": "Wrap 1",
|
||||
"selectedVoice": input_data.voice.voice_id,
|
||||
"hasEnhancedGeneration": True,
|
||||
"generationPreset": input_data.generation_preset.name,
|
||||
"selectedAudio": input_data.background_music.value,
|
||||
"selectedAudio": input_data.background_music,
|
||||
"origin": "/create",
|
||||
"inputText": input_data.script,
|
||||
"flowType": "text-to-video",
|
||||
@@ -311,12 +301,12 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
"selectedStoryStyle": {"value": "custom", "label": "Custom"},
|
||||
"hasToGenerateVideos": input_data.video_style
|
||||
!= VisualMediaType.STOCK_VIDEOS,
|
||||
"audioUrl": input_data.background_music.audio_url,
|
||||
"audioUrl": audio_url,
|
||||
},
|
||||
}
|
||||
|
||||
logger.debug("Creating video...")
|
||||
response = await self.create_video(credentials.api_key, payload)
|
||||
response = self.create_video(credentials.api_key, payload)
|
||||
pid = response.get("pid")
|
||||
|
||||
if not pid:
|
||||
@@ -328,370 +318,6 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
logger.debug(
|
||||
f"Video created with project ID: {pid}. Waiting for completion..."
|
||||
)
|
||||
video_url = await self.wait_for_video(credentials.api_key, pid)
|
||||
video_url = self.wait_for_video(credentials.api_key, pid, webhook_token)
|
||||
logger.debug(f"Video ready: {video_url}")
|
||||
yield "video_url", video_url
|
||||
|
||||
|
||||
class AIAdMakerVideoCreatorBlock(Block):
|
||||
"""Generates a 30‑second vertical AI advert using optional user‑supplied imagery."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REVID], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description="Credentials for Revid.ai API access.",
|
||||
)
|
||||
script: str = SchemaField(
|
||||
description="Short advertising copy. Line breaks create new scenes.",
|
||||
placeholder="Introducing Foobar – [show product photo] the gadget that does it all.",
|
||||
)
|
||||
ratio: str = SchemaField(description="Aspect ratio", default="9 / 16")
|
||||
target_duration: int = SchemaField(
|
||||
description="Desired length of the ad in seconds.", default=30
|
||||
)
|
||||
voice: Voice = SchemaField(
|
||||
description="Narration voice", default=Voice.EVA, placeholder=Voice.EVA
|
||||
)
|
||||
background_music: AudioTrack = SchemaField(
|
||||
description="Background track",
|
||||
default=AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS,
|
||||
)
|
||||
input_media_urls: list[str] = SchemaField(
|
||||
description="List of image URLs to feature in the advert.", default=[]
|
||||
)
|
||||
use_only_provided_media: bool = SchemaField(
|
||||
description="Restrict visuals to supplied images only.", default=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
video_url: str = SchemaField(description="URL of the finished advert")
|
||||
error: str = SchemaField(description="Error message on failure")
|
||||
|
||||
async def create_webhook(self) -> tuple[str, str]:
|
||||
"""Create a new webhook URL for receiving notifications."""
|
||||
url = "https://webhook.site/token"
|
||||
headers = {"Accept": "application/json", "Content-Type": "application/json"}
|
||||
response = await Requests().post(url, headers=headers)
|
||||
webhook_data = response.json()
|
||||
return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"
|
||||
|
||||
async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
|
||||
"""Create a video using the Revid API."""
|
||||
url = "https://www.revid.ai/api/public/v2/render"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = await Requests().post(url, json=payload, headers=headers)
|
||||
logger.debug(
|
||||
f"API Response Status Code: {response.status}, Content: {response.text}"
|
||||
)
|
||||
return response.json()
|
||||
|
||||
async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
|
||||
"""Check the status of a video creation job."""
|
||||
url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = await Requests().get(url, headers=headers)
|
||||
return response.json()
|
||||
|
||||
async def wait_for_video(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
pid: str,
|
||||
max_wait_time: int = 1000,
|
||||
) -> str:
|
||||
"""Wait for video creation to complete and return the video URL."""
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < max_wait_time:
|
||||
status = await self.check_video_status(api_key, pid)
|
||||
logger.debug(f"Video status: {status}")
|
||||
|
||||
if status.get("status") == "ready" and "videoUrl" in status:
|
||||
return status["videoUrl"]
|
||||
elif status.get("status") == "error":
|
||||
error_message = status.get("error", "Unknown error occurred")
|
||||
logger.error(f"Video creation failed: {error_message}")
|
||||
raise ValueError(f"Video creation failed: {error_message}")
|
||||
elif status.get("status") in ["FAILED", "CANCELED"]:
|
||||
logger.error(f"Video creation failed: {status.get('message')}")
|
||||
raise ValueError(f"Video creation failed: {status.get('message')}")
|
||||
|
||||
await asyncio.sleep(10)
|
||||
|
||||
logger.error("Video creation timed out")
|
||||
raise TimeoutError("Video creation timed out")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="58bd2a19-115d-4fd1-8ca4-13b9e37fa6a0",
|
||||
description="Creates an AI‑generated 30‑second advert (text + images)",
|
||||
categories={BlockCategory.MARKETING, BlockCategory.AI},
|
||||
input_schema=AIAdMakerVideoCreatorBlock.Input,
|
||||
output_schema=AIAdMakerVideoCreatorBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"script": "Test product launch!",
|
||||
"input_media_urls": [
|
||||
"https://cdn.revid.ai/uploads/1747076315114-image.png",
|
||||
],
|
||||
},
|
||||
test_output=("video_url", "https://example.com/ad.mp4"),
|
||||
test_mock={
|
||||
"create_webhook": lambda *args, **kwargs: (
|
||||
"test_uuid",
|
||||
"https://webhook.site/test_uuid",
|
||||
),
|
||||
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
||||
"check_video_status": lambda *args, **kwargs: {
|
||||
"status": "ready",
|
||||
"videoUrl": "https://example.com/ad.mp4",
|
||||
},
|
||||
"wait_for_video": lambda *args, **kwargs: "https://example.com/ad.mp4",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
|
||||
webhook_token, webhook_url = await self.create_webhook()
|
||||
|
||||
payload = {
|
||||
"webhook": webhook_url,
|
||||
"creationParams": {
|
||||
"targetDuration": input_data.target_duration,
|
||||
"ratio": input_data.ratio,
|
||||
"mediaType": "aiVideo",
|
||||
"inputText": input_data.script,
|
||||
"flowType": "text-to-video",
|
||||
"slug": "ai-ad-generator",
|
||||
"slugNew": "",
|
||||
"isCopiedFrom": False,
|
||||
"hasToGenerateVoice": True,
|
||||
"hasToTranscript": False,
|
||||
"hasToSearchMedia": True,
|
||||
"hasAvatar": False,
|
||||
"hasWebsiteRecorder": False,
|
||||
"hasTextSmallAtBottom": False,
|
||||
"selectedAudio": input_data.background_music.value,
|
||||
"selectedVoice": input_data.voice.voice_id,
|
||||
"selectedAvatar": "https://cdn.revid.ai/avatars/young-woman.mp4",
|
||||
"selectedAvatarType": "video/mp4",
|
||||
"websiteToRecord": "",
|
||||
"hasToGenerateCover": True,
|
||||
"nbGenerations": 1,
|
||||
"disableCaptions": False,
|
||||
"mediaMultiplier": "medium",
|
||||
"characters": [],
|
||||
"captionPresetName": "Revid",
|
||||
"sourceType": "contentScraping",
|
||||
"selectedStoryStyle": {"value": "custom", "label": "General"},
|
||||
"generationPreset": "DEFAULT",
|
||||
"hasToGenerateMusic": False,
|
||||
"isOptimizedForChinese": False,
|
||||
"generationUserPrompt": "",
|
||||
"enableNsfwFilter": False,
|
||||
"addStickers": False,
|
||||
"typeMovingImageAnim": "dynamic",
|
||||
"hasToGenerateSoundEffects": False,
|
||||
"forceModelType": "gpt-image-1",
|
||||
"selectedCharacters": [],
|
||||
"lang": "",
|
||||
"voiceSpeed": 1,
|
||||
"disableAudio": False,
|
||||
"disableVoice": False,
|
||||
"useOnlyProvidedMedia": input_data.use_only_provided_media,
|
||||
"imageGenerationModel": "ultra",
|
||||
"videoGenerationModel": "pro",
|
||||
"hasEnhancedGeneration": True,
|
||||
"hasEnhancedGenerationPro": True,
|
||||
"inputMedias": [
|
||||
{"url": url, "title": "", "type": "image"}
|
||||
for url in input_data.input_media_urls
|
||||
],
|
||||
"hasToGenerateVideos": True,
|
||||
"audioUrl": input_data.background_music.audio_url,
|
||||
"watermark": None,
|
||||
},
|
||||
}
|
||||
|
||||
response = await self.create_video(credentials.api_key, payload)
|
||||
pid = response.get("pid")
|
||||
if not pid:
|
||||
raise RuntimeError("Failed to create video: No project ID returned")
|
||||
|
||||
video_url = await self.wait_for_video(credentials.api_key, pid)
|
||||
yield "video_url", video_url
|
||||
|
||||
|
||||
class AIScreenshotToVideoAdBlock(Block):
|
||||
"""Creates an advert where the supplied screenshot is narrated by an AI avatar."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REVID], Literal["api_key"]
|
||||
] = CredentialsField(description="Revid.ai API key")
|
||||
script: str = SchemaField(
|
||||
description="Narration that will accompany the screenshot.",
|
||||
placeholder="Check out these amazing stats!",
|
||||
)
|
||||
screenshot_url: str = SchemaField(
|
||||
description="Screenshot or image URL to showcase."
|
||||
)
|
||||
ratio: str = SchemaField(default="9 / 16")
|
||||
target_duration: int = SchemaField(default=30)
|
||||
voice: Voice = SchemaField(default=Voice.EVA)
|
||||
background_music: AudioTrack = SchemaField(
|
||||
default=AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
video_url: str = SchemaField(description="Rendered video URL")
|
||||
error: str = SchemaField(description="Error, if encountered")
|
||||
|
||||
async def create_webhook(self) -> tuple[str, str]:
|
||||
"""Create a new webhook URL for receiving notifications."""
|
||||
url = "https://webhook.site/token"
|
||||
headers = {"Accept": "application/json", "Content-Type": "application/json"}
|
||||
response = await Requests().post(url, headers=headers)
|
||||
webhook_data = response.json()
|
||||
return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"
|
||||
|
||||
async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
|
||||
"""Create a video using the Revid API."""
|
||||
url = "https://www.revid.ai/api/public/v2/render"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = await Requests().post(url, json=payload, headers=headers)
|
||||
logger.debug(
|
||||
f"API Response Status Code: {response.status}, Content: {response.text}"
|
||||
)
|
||||
return response.json()
|
||||
|
||||
async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
|
||||
"""Check the status of a video creation job."""
|
||||
url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = await Requests().get(url, headers=headers)
|
||||
return response.json()
|
||||
|
||||
async def wait_for_video(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
pid: str,
|
||||
max_wait_time: int = 1000,
|
||||
) -> str:
|
||||
"""Wait for video creation to complete and return the video URL."""
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < max_wait_time:
|
||||
status = await self.check_video_status(api_key, pid)
|
||||
logger.debug(f"Video status: {status}")
|
||||
|
||||
if status.get("status") == "ready" and "videoUrl" in status:
|
||||
return status["videoUrl"]
|
||||
elif status.get("status") == "error":
|
||||
error_message = status.get("error", "Unknown error occurred")
|
||||
logger.error(f"Video creation failed: {error_message}")
|
||||
raise ValueError(f"Video creation failed: {error_message}")
|
||||
elif status.get("status") in ["FAILED", "CANCELED"]:
|
||||
logger.error(f"Video creation failed: {status.get('message')}")
|
||||
raise ValueError(f"Video creation failed: {status.get('message')}")
|
||||
|
||||
await asyncio.sleep(10)
|
||||
|
||||
logger.error("Video creation timed out")
|
||||
raise TimeoutError("Video creation timed out")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0f3e4635-e810-43d9-9e81-49e6f4e83b7c",
|
||||
description="Turns a screenshot into an engaging, avatar‑narrated video advert.",
|
||||
categories={BlockCategory.AI, BlockCategory.MARKETING},
|
||||
input_schema=AIScreenshotToVideoAdBlock.Input,
|
||||
output_schema=AIScreenshotToVideoAdBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"script": "Amazing numbers!",
|
||||
"screenshot_url": "https://cdn.revid.ai/uploads/1747080376028-image.png",
|
||||
},
|
||||
test_output=("video_url", "https://example.com/screenshot.mp4"),
|
||||
test_mock={
|
||||
"create_webhook": lambda *args, **kwargs: (
|
||||
"test_uuid",
|
||||
"https://webhook.site/test_uuid",
|
||||
),
|
||||
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
||||
"check_video_status": lambda *args, **kwargs: {
|
||||
"status": "ready",
|
||||
"videoUrl": "https://example.com/screenshot.mp4",
|
||||
},
|
||||
"wait_for_video": lambda *args, **kwargs: "https://example.com/screenshot.mp4",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
|
||||
webhook_token, webhook_url = await self.create_webhook()
|
||||
|
||||
payload = {
|
||||
"webhook": webhook_url,
|
||||
"creationParams": {
|
||||
"targetDuration": input_data.target_duration,
|
||||
"ratio": input_data.ratio,
|
||||
"mediaType": "aiVideo",
|
||||
"hasAvatar": True,
|
||||
"removeAvatarBackground": True,
|
||||
"inputText": input_data.script,
|
||||
"flowType": "text-to-video",
|
||||
"slug": "ai-ad-generator",
|
||||
"slugNew": "screenshot-to-video-ad",
|
||||
"isCopiedFrom": "ai-ad-generator",
|
||||
"hasToGenerateVoice": True,
|
||||
"hasToTranscript": False,
|
||||
"hasToSearchMedia": True,
|
||||
"hasWebsiteRecorder": False,
|
||||
"hasTextSmallAtBottom": False,
|
||||
"selectedAudio": input_data.background_music.value,
|
||||
"selectedVoice": input_data.voice.voice_id,
|
||||
"selectedAvatar": "https://cdn.revid.ai/avatars/young-woman.mp4",
|
||||
"selectedAvatarType": "video/mp4",
|
||||
"websiteToRecord": "",
|
||||
"hasToGenerateCover": True,
|
||||
"nbGenerations": 1,
|
||||
"disableCaptions": False,
|
||||
"mediaMultiplier": "medium",
|
||||
"characters": [],
|
||||
"captionPresetName": "Revid",
|
||||
"sourceType": "contentScraping",
|
||||
"selectedStoryStyle": {"value": "custom", "label": "General"},
|
||||
"generationPreset": "DEFAULT",
|
||||
"hasToGenerateMusic": False,
|
||||
"isOptimizedForChinese": False,
|
||||
"generationUserPrompt": "",
|
||||
"enableNsfwFilter": False,
|
||||
"addStickers": False,
|
||||
"typeMovingImageAnim": "dynamic",
|
||||
"hasToGenerateSoundEffects": False,
|
||||
"forceModelType": "gpt-image-1",
|
||||
"selectedCharacters": [],
|
||||
"lang": "",
|
||||
"voiceSpeed": 1,
|
||||
"disableAudio": False,
|
||||
"disableVoice": False,
|
||||
"useOnlyProvidedMedia": True,
|
||||
"imageGenerationModel": "ultra",
|
||||
"videoGenerationModel": "ultra",
|
||||
"hasEnhancedGeneration": True,
|
||||
"hasEnhancedGenerationPro": True,
|
||||
"inputMedias": [
|
||||
{"url": input_data.screenshot_url, "title": "", "type": "image"}
|
||||
],
|
||||
"hasToGenerateVideos": True,
|
||||
"audioUrl": input_data.background_music.audio_url,
|
||||
"watermark": None,
|
||||
},
|
||||
}
|
||||
|
||||
response = await self.create_video(credentials.api_key, payload)
|
||||
pid = response.get("pid")
|
||||
if not pid:
|
||||
raise RuntimeError("Failed to create video: No project ID returned")
|
||||
|
||||
video_url = await self.wait_for_video(credentials.api_key, pid)
|
||||
yield "video_url", video_url
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,323 +0,0 @@
|
||||
from os import getenv
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.sdk import APIKeyCredentials, SecretStr
|
||||
|
||||
from ._api import (
|
||||
TableFieldType,
|
||||
WebhookFilters,
|
||||
WebhookSpecification,
|
||||
create_base,
|
||||
create_field,
|
||||
create_record,
|
||||
create_table,
|
||||
create_webhook,
|
||||
delete_multiple_records,
|
||||
delete_record,
|
||||
delete_webhook,
|
||||
get_record,
|
||||
list_bases,
|
||||
list_records,
|
||||
list_webhook_payloads,
|
||||
update_field,
|
||||
update_multiple_records,
|
||||
update_record,
|
||||
update_table,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_update_table():
|
||||
|
||||
key = getenv("AIRTABLE_API_KEY")
|
||||
if not key:
|
||||
return pytest.skip("AIRTABLE_API_KEY is not set")
|
||||
|
||||
credentials = APIKeyCredentials(
|
||||
provider="airtable",
|
||||
api_key=SecretStr(key),
|
||||
)
|
||||
postfix = uuid4().hex[:4]
|
||||
workspace_id = "wsphuHmfllg7V3Brd"
|
||||
response = await create_base(credentials, workspace_id, "API Testing Base")
|
||||
assert response is not None, f"Checking create base response: {response}"
|
||||
assert (
|
||||
response.get("id") is not None
|
||||
), f"Checking create base response id: {response}"
|
||||
base_id = response.get("id")
|
||||
assert base_id is not None, f"Checking create base response id: {base_id}"
|
||||
|
||||
response = await list_bases(credentials)
|
||||
assert response is not None, f"Checking list bases response: {response}"
|
||||
assert "API Testing Base" in [
|
||||
base.get("name") for base in response.get("bases", [])
|
||||
], f"Checking list bases response bases: {response}"
|
||||
|
||||
table_name = f"test_table_{postfix}"
|
||||
table_fields = [{"name": "test_field", "type": "singleLineText"}]
|
||||
table = await create_table(credentials, base_id, table_name, table_fields)
|
||||
assert table.get("name") == table_name
|
||||
|
||||
table_id = table.get("id")
|
||||
|
||||
assert table_id is not None
|
||||
|
||||
table_name = f"test_table_updated_{postfix}"
|
||||
table_description = "test_description_updated"
|
||||
table = await update_table(
|
||||
credentials,
|
||||
base_id,
|
||||
table_id,
|
||||
table_name=table_name,
|
||||
table_description=table_description,
|
||||
)
|
||||
assert table.get("name") == table_name
|
||||
assert table.get("description") == table_description
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_field_type():
|
||||
|
||||
key = getenv("AIRTABLE_API_KEY")
|
||||
if not key:
|
||||
return pytest.skip("AIRTABLE_API_KEY is not set")
|
||||
|
||||
credentials = APIKeyCredentials(
|
||||
provider="airtable",
|
||||
api_key=SecretStr(key),
|
||||
)
|
||||
postfix = uuid4().hex[:4]
|
||||
base_id = "appZPxegHEU3kDc1S"
|
||||
table_name = f"test_table_{postfix}"
|
||||
table_fields = [{"name": "test_field", "type": "notValid"}]
|
||||
with pytest.raises(AssertionError):
|
||||
await create_table(credentials, base_id, table_name, table_fields)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_and_update_field():
|
||||
key = getenv("AIRTABLE_API_KEY")
|
||||
if not key:
|
||||
return pytest.skip("AIRTABLE_API_KEY is not set")
|
||||
|
||||
credentials = APIKeyCredentials(
|
||||
provider="airtable",
|
||||
api_key=SecretStr(key),
|
||||
)
|
||||
postfix = uuid4().hex[:4]
|
||||
base_id = "appZPxegHEU3kDc1S"
|
||||
table_name = f"test_table_{postfix}"
|
||||
table_fields = [{"name": "test_field", "type": "singleLineText"}]
|
||||
table = await create_table(credentials, base_id, table_name, table_fields)
|
||||
assert table.get("name") == table_name
|
||||
|
||||
table_id = table.get("id")
|
||||
|
||||
assert table_id is not None
|
||||
|
||||
field_name = f"test_field_{postfix}"
|
||||
field_type = TableFieldType.SINGLE_LINE_TEXT
|
||||
field = await create_field(credentials, base_id, table_id, field_type, field_name)
|
||||
assert field.get("name") == field_name
|
||||
|
||||
field_id = field.get("id")
|
||||
|
||||
assert field_id is not None
|
||||
assert isinstance(field_id, str)
|
||||
|
||||
field_name = f"test_field_updated_{postfix}"
|
||||
field = await update_field(credentials, base_id, table_id, field_id, field_name)
|
||||
assert field.get("name") == field_name
|
||||
|
||||
field_description = "test_description_updated"
|
||||
field = await update_field(
|
||||
credentials, base_id, table_id, field_id, description=field_description
|
||||
)
|
||||
assert field.get("description") == field_description
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_management():
|
||||
key = getenv("AIRTABLE_API_KEY")
|
||||
if not key:
|
||||
return pytest.skip("AIRTABLE_API_KEY is not set")
|
||||
|
||||
credentials = APIKeyCredentials(
|
||||
provider="airtable",
|
||||
api_key=SecretStr(key),
|
||||
)
|
||||
postfix = uuid4().hex[:4]
|
||||
base_id = "appZPxegHEU3kDc1S"
|
||||
table_name = f"test_table_{postfix}"
|
||||
table_fields = [{"name": "test_field", "type": "singleLineText"}]
|
||||
table = await create_table(credentials, base_id, table_name, table_fields)
|
||||
assert table.get("name") == table_name
|
||||
|
||||
table_id = table.get("id")
|
||||
assert table_id is not None
|
||||
|
||||
# Create a record
|
||||
record_fields = {"test_field": "test_value"}
|
||||
record = await create_record(credentials, base_id, table_id, fields=record_fields)
|
||||
fields = record.get("fields")
|
||||
assert fields is not None
|
||||
assert isinstance(fields, dict)
|
||||
assert fields.get("test_field") == "test_value"
|
||||
|
||||
record_id = record.get("id")
|
||||
|
||||
assert record_id is not None
|
||||
assert isinstance(record_id, str)
|
||||
|
||||
# Get a record
|
||||
record = await get_record(credentials, base_id, table_id, record_id)
|
||||
fields = record.get("fields")
|
||||
assert fields is not None
|
||||
assert isinstance(fields, dict)
|
||||
assert fields.get("test_field") == "test_value"
|
||||
|
||||
# Updata a record
|
||||
record_fields = {"test_field": "test_value_updated"}
|
||||
record = await update_record(
|
||||
credentials, base_id, table_id, record_id, fields=record_fields
|
||||
)
|
||||
fields = record.get("fields")
|
||||
assert fields is not None
|
||||
assert isinstance(fields, dict)
|
||||
assert fields.get("test_field") == "test_value_updated"
|
||||
|
||||
# Delete a record
|
||||
record = await delete_record(credentials, base_id, table_id, record_id)
|
||||
assert record is not None
|
||||
assert record.get("id") == record_id
|
||||
assert record.get("deleted")
|
||||
|
||||
# Create 2 records
|
||||
records = [
|
||||
{"fields": {"test_field": "test_value_1"}},
|
||||
{"fields": {"test_field": "test_value_2"}},
|
||||
]
|
||||
response = await create_record(credentials, base_id, table_id, records=records)
|
||||
created_records = response.get("records")
|
||||
assert created_records is not None
|
||||
assert isinstance(created_records, list)
|
||||
assert len(created_records) == 2, f"Created records: {created_records}"
|
||||
first_record = created_records[0] # type: ignore
|
||||
second_record = created_records[1] # type: ignore
|
||||
first_record_id = first_record.get("id")
|
||||
second_record_id = second_record.get("id")
|
||||
assert first_record_id is not None
|
||||
assert second_record_id is not None
|
||||
assert first_record_id != second_record_id
|
||||
first_fields = first_record.get("fields")
|
||||
second_fields = second_record.get("fields")
|
||||
assert first_fields is not None
|
||||
assert second_fields is not None
|
||||
assert first_fields.get("test_field") == "test_value_1" # type: ignore
|
||||
assert second_fields.get("test_field") == "test_value_2" # type: ignore
|
||||
|
||||
# List records
|
||||
response = await list_records(credentials, base_id, table_id)
|
||||
records = response.get("records")
|
||||
assert records is not None
|
||||
assert len(records) == 2, f"Records: {records}"
|
||||
assert isinstance(records, list), f"Type of records: {type(records)}"
|
||||
|
||||
# Update multiple records
|
||||
records = [
|
||||
{"id": first_record_id, "fields": {"test_field": "test_value_1_updated"}},
|
||||
{"id": second_record_id, "fields": {"test_field": "test_value_2_updated"}},
|
||||
]
|
||||
response = await update_multiple_records(
|
||||
credentials, base_id, table_id, records=records
|
||||
)
|
||||
updated_records = response.get("records")
|
||||
assert updated_records is not None
|
||||
assert len(updated_records) == 2, f"Updated records: {updated_records}"
|
||||
assert isinstance(
|
||||
updated_records, list
|
||||
), f"Type of updated records: {type(updated_records)}"
|
||||
first_updated = updated_records[0] # type: ignore
|
||||
second_updated = updated_records[1] # type: ignore
|
||||
first_updated_fields = first_updated.get("fields")
|
||||
second_updated_fields = second_updated.get("fields")
|
||||
assert first_updated_fields is not None
|
||||
assert second_updated_fields is not None
|
||||
assert first_updated_fields.get("test_field") == "test_value_1_updated" # type: ignore
|
||||
assert second_updated_fields.get("test_field") == "test_value_2_updated" # type: ignore
|
||||
|
||||
# Delete multiple records
|
||||
assert isinstance(first_record_id, str)
|
||||
assert isinstance(second_record_id, str)
|
||||
response = await delete_multiple_records(
|
||||
credentials, base_id, table_id, records=[first_record_id, second_record_id]
|
||||
)
|
||||
deleted_records = response.get("records")
|
||||
assert deleted_records is not None
|
||||
assert len(deleted_records) == 2, f"Deleted records: {deleted_records}"
|
||||
assert isinstance(
|
||||
deleted_records, list
|
||||
), f"Type of deleted records: {type(deleted_records)}"
|
||||
first_deleted = deleted_records[0] # type: ignore
|
||||
second_deleted = deleted_records[1] # type: ignore
|
||||
assert first_deleted.get("deleted")
|
||||
assert second_deleted.get("deleted")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_management():
|
||||
key = getenv("AIRTABLE_API_KEY")
|
||||
if not key:
|
||||
return pytest.skip("AIRTABLE_API_KEY is not set")
|
||||
|
||||
credentials = APIKeyCredentials(
|
||||
provider="airtable",
|
||||
api_key=SecretStr(key),
|
||||
)
|
||||
postfix = uuid4().hex[:4]
|
||||
base_id = "appZPxegHEU3kDc1S"
|
||||
table_name = f"test_table_{postfix}"
|
||||
table_fields = [{"name": "test_field", "type": "singleLineText"}]
|
||||
table = await create_table(credentials, base_id, table_name, table_fields)
|
||||
assert table.get("name") == table_name
|
||||
|
||||
table_id = table.get("id")
|
||||
assert table_id is not None
|
||||
webhook_specification = WebhookSpecification(
|
||||
filters=WebhookFilters(
|
||||
dataTypes=["tableData", "tableFields", "tableMetadata"],
|
||||
changeTypes=["add", "update", "remove"],
|
||||
)
|
||||
)
|
||||
response = await create_webhook(credentials, base_id, webhook_specification)
|
||||
assert response is not None, f"Checking create webhook response: {response}"
|
||||
assert (
|
||||
response.get("id") is not None
|
||||
), f"Checking create webhook response id: {response}"
|
||||
assert (
|
||||
response.get("macSecretBase64") is not None
|
||||
), f"Checking create webhook response macSecretBase64: {response}"
|
||||
|
||||
webhook_id = response.get("id")
|
||||
assert webhook_id is not None, f"Webhook ID: {webhook_id}"
|
||||
assert isinstance(webhook_id, str)
|
||||
|
||||
response = await create_record(
|
||||
credentials, base_id, table_id, fields={"test_field": "test_value"}
|
||||
)
|
||||
assert response is not None, f"Checking create record response: {response}"
|
||||
assert (
|
||||
response.get("id") is not None
|
||||
), f"Checking create record response id: {response}"
|
||||
fields = response.get("fields")
|
||||
assert fields is not None, f"Checking create record response fields: {response}"
|
||||
assert (
|
||||
fields.get("test_field") == "test_value"
|
||||
), f"Checking create record response fields test_field: {response}"
|
||||
|
||||
response = await list_webhook_payloads(credentials, base_id, webhook_id)
|
||||
assert response is not None, f"Checking list webhook payloads response: {response}"
|
||||
|
||||
response = await delete_webhook(credentials, base_id, webhook_id)
|
||||
@@ -1,32 +0,0 @@
|
||||
"""
|
||||
Shared configuration for all Airtable blocks using the SDK pattern.
|
||||
"""
|
||||
|
||||
from backend.sdk import BlockCostType, ProviderBuilder
|
||||
|
||||
from ._oauth import AirtableOAuthHandler, AirtableScope
|
||||
from ._webhook import AirtableWebhookManager
|
||||
|
||||
# Configure the Airtable provider with API key authentication
|
||||
airtable = (
|
||||
ProviderBuilder("airtable")
|
||||
.with_api_key("AIRTABLE_API_KEY", "Airtable Personal Access Token")
|
||||
.with_webhook_manager(AirtableWebhookManager)
|
||||
.with_base_cost(1, BlockCostType.RUN)
|
||||
.with_oauth(
|
||||
AirtableOAuthHandler,
|
||||
scopes=[
|
||||
v.value
|
||||
for v in [
|
||||
AirtableScope.DATA_RECORDS_READ,
|
||||
AirtableScope.DATA_RECORDS_WRITE,
|
||||
AirtableScope.SCHEMA_BASES_READ,
|
||||
AirtableScope.SCHEMA_BASES_WRITE,
|
||||
AirtableScope.WEBHOOK_MANAGE,
|
||||
]
|
||||
],
|
||||
client_id_env_var="AIRTABLE_CLIENT_ID",
|
||||
client_secret_env_var="AIRTABLE_CLIENT_SECRET",
|
||||
)
|
||||
.build()
|
||||
)
|
||||
@@ -1,185 +0,0 @@
|
||||
"""
|
||||
Airtable OAuth handler implementation.
|
||||
"""
|
||||
|
||||
import time
|
||||
from enum import Enum
|
||||
from logging import getLogger
|
||||
from typing import Optional
|
||||
|
||||
from backend.sdk import BaseOAuthHandler, OAuth2Credentials, ProviderName, SecretStr
|
||||
|
||||
from ._api import (
|
||||
OAuthTokenResponse,
|
||||
make_oauth_authorize_url,
|
||||
oauth_exchange_code_for_tokens,
|
||||
oauth_refresh_tokens,
|
||||
)
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class AirtableScope(str, Enum):
|
||||
# Basic scopes
|
||||
DATA_RECORDS_READ = "data.records:read"
|
||||
DATA_RECORDS_WRITE = "data.records:write"
|
||||
DATA_RECORD_COMMENTS_READ = "data.recordComments:read"
|
||||
DATA_RECORD_COMMENTS_WRITE = "data.recordComments:write"
|
||||
SCHEMA_BASES_READ = "schema.bases:read"
|
||||
SCHEMA_BASES_WRITE = "schema.bases:write"
|
||||
WEBHOOK_MANAGE = "webhook:manage"
|
||||
BLOCK_MANAGE = "block:manage"
|
||||
USER_EMAIL_READ = "user.email:read"
|
||||
|
||||
# Enterprise member scopes
|
||||
ENTERPRISE_GROUPS_READ = "enterprise.groups:read"
|
||||
WORKSPACES_AND_BASES_READ = "workspacesAndBases:read"
|
||||
WORKSPACES_AND_BASES_WRITE = "workspacesAndBases:write"
|
||||
WORKSPACES_AND_BASES_SHARES_MANAGE = "workspacesAndBases.shares:manage"
|
||||
|
||||
# Enterprise admin scopes
|
||||
ENTERPRISE_SCIM_USERS_AND_GROUPS_MANAGE = "enterprise.scim.usersAndGroups:manage"
|
||||
ENTERPRISE_AUDIT_LOGS_READ = "enterprise.auditLogs:read"
|
||||
ENTERPRISE_CHANGE_EVENTS_READ = "enterprise.changeEvents:read"
|
||||
ENTERPRISE_EXPORTS_MANAGE = "enterprise.exports:manage"
|
||||
ENTERPRISE_ACCOUNT_READ = "enterprise.account:read"
|
||||
ENTERPRISE_ACCOUNT_WRITE = "enterprise.account:write"
|
||||
ENTERPRISE_USER_READ = "enterprise.user:read"
|
||||
ENTERPRISE_USER_WRITE = "enterprise.user:write"
|
||||
ENTERPRISE_GROUPS_MANAGE = "enterprise.groups:manage"
|
||||
WORKSPACES_AND_BASES_MANAGE = "workspacesAndBases:manage"
|
||||
HYPERDB_RECORDS_READ = "hyperDB.records:read"
|
||||
HYPERDB_RECORDS_WRITE = "hyperDB.records:write"
|
||||
|
||||
|
||||
class AirtableOAuthHandler(BaseOAuthHandler):
|
||||
"""
|
||||
OAuth2 handler for Airtable with PKCE support.
|
||||
"""
|
||||
|
||||
PROVIDER_NAME = ProviderName("airtable")
|
||||
DEFAULT_SCOPES = [
|
||||
v.value
|
||||
for v in [
|
||||
AirtableScope.DATA_RECORDS_READ,
|
||||
AirtableScope.DATA_RECORDS_WRITE,
|
||||
AirtableScope.SCHEMA_BASES_READ,
|
||||
AirtableScope.SCHEMA_BASES_WRITE,
|
||||
AirtableScope.WEBHOOK_MANAGE,
|
||||
]
|
||||
]
|
||||
|
||||
def __init__(self, client_id: str, client_secret: Optional[str], redirect_uri: str):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
self.scopes = self.DEFAULT_SCOPES
|
||||
self.auth_base_url = "https://airtable.com/oauth2/v1/authorize"
|
||||
self.token_url = "https://airtable.com/oauth2/v1/token"
|
||||
|
||||
def get_login_url(
|
||||
self, scopes: list[str], state: str, code_challenge: Optional[str]
|
||||
) -> str:
|
||||
logger.debug("Generating Airtable OAuth login URL")
|
||||
# Generate code_challenge if not provided (PKCE is required)
|
||||
if not scopes:
|
||||
logger.debug("No scopes provided, using default scopes")
|
||||
scopes = self.scopes
|
||||
|
||||
logger.debug(f"Using scopes: {scopes}")
|
||||
logger.debug(f"State: {state}")
|
||||
logger.debug(f"Code challenge: {code_challenge}")
|
||||
if not code_challenge:
|
||||
logger.error("Code challenge is required but none was provided")
|
||||
raise ValueError("No code challenge provided")
|
||||
|
||||
try:
|
||||
url = make_oauth_authorize_url(
|
||||
self.client_id, self.redirect_uri, scopes, state, code_challenge
|
||||
)
|
||||
logger.debug(f"Generated OAuth URL: {url}")
|
||||
return url
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate OAuth URL: {str(e)}")
|
||||
raise
|
||||
|
||||
async def exchange_code_for_tokens(
|
||||
self, code: str, scopes: list[str], code_verifier: Optional[str]
|
||||
) -> OAuth2Credentials:
|
||||
logger.debug("Exchanging authorization code for tokens")
|
||||
logger.debug(f"Code: {code[:4]}...") # Log first 4 chars only for security
|
||||
logger.debug(f"Scopes: {scopes}")
|
||||
if not code_verifier:
|
||||
logger.error("Code verifier is required but none was provided")
|
||||
raise ValueError("No code verifier provided")
|
||||
|
||||
try:
|
||||
response: OAuthTokenResponse = await oauth_exchange_code_for_tokens(
|
||||
client_id=self.client_id,
|
||||
code=code,
|
||||
code_verifier=code_verifier.encode("utf-8"),
|
||||
redirect_uri=self.redirect_uri,
|
||||
client_secret=self.client_secret,
|
||||
)
|
||||
logger.info("Successfully exchanged code for tokens")
|
||||
|
||||
credentials = OAuth2Credentials(
|
||||
access_token=SecretStr(response.access_token),
|
||||
refresh_token=SecretStr(response.refresh_token),
|
||||
access_token_expires_at=int(time.time()) + response.expires_in,
|
||||
refresh_token_expires_at=int(time.time()) + response.refresh_expires_in,
|
||||
provider=self.PROVIDER_NAME,
|
||||
scopes=scopes,
|
||||
)
|
||||
logger.debug(f"Access token expires in {response.expires_in} seconds")
|
||||
logger.debug(
|
||||
f"Refresh token expires in {response.refresh_expires_in} seconds"
|
||||
)
|
||||
return credentials
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to exchange code for tokens: {str(e)}")
|
||||
raise
|
||||
|
||||
async def _refresh_tokens(
|
||||
self, credentials: OAuth2Credentials
|
||||
) -> OAuth2Credentials:
|
||||
logger.debug("Attempting to refresh OAuth tokens")
|
||||
|
||||
if credentials.refresh_token is None:
|
||||
logger.error("Cannot refresh tokens - no refresh token available")
|
||||
raise ValueError("No refresh token available")
|
||||
|
||||
try:
|
||||
response: OAuthTokenResponse = await oauth_refresh_tokens(
|
||||
client_id=self.client_id,
|
||||
refresh_token=credentials.refresh_token.get_secret_value(),
|
||||
client_secret=self.client_secret,
|
||||
)
|
||||
logger.info("Successfully refreshed tokens")
|
||||
|
||||
new_credentials = OAuth2Credentials(
|
||||
id=credentials.id,
|
||||
access_token=SecretStr(response.access_token),
|
||||
refresh_token=SecretStr(response.refresh_token),
|
||||
access_token_expires_at=int(time.time()) + response.expires_in,
|
||||
refresh_token_expires_at=int(time.time()) + response.refresh_expires_in,
|
||||
provider=self.PROVIDER_NAME,
|
||||
scopes=self.scopes,
|
||||
)
|
||||
logger.debug(f"New access token expires in {response.expires_in} seconds")
|
||||
logger.debug(
|
||||
f"New refresh token expires in {response.refresh_expires_in} seconds"
|
||||
)
|
||||
return new_credentials
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to refresh tokens: {str(e)}")
|
||||
raise
|
||||
|
||||
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
|
||||
logger.debug("Token revocation requested")
|
||||
logger.info(
|
||||
"Airtable doesn't provide a token revocation endpoint - tokens will expire naturally after 60 minutes"
|
||||
)
|
||||
return False
|
||||
@@ -1,154 +0,0 @@
|
||||
"""
|
||||
Webhook management for Airtable blocks.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
from enum import Enum
|
||||
|
||||
from backend.sdk import (
|
||||
BaseWebhooksManager,
|
||||
Credentials,
|
||||
ProviderName,
|
||||
Webhook,
|
||||
update_webhook,
|
||||
)
|
||||
|
||||
from ._api import (
|
||||
WebhookFilters,
|
||||
WebhookSpecification,
|
||||
create_webhook,
|
||||
delete_webhook,
|
||||
list_webhook_payloads,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AirtableWebhookEvent(str, Enum):
|
||||
TABLE_DATA = "tableData"
|
||||
TABLE_FIELDS = "tableFields"
|
||||
TABLE_METADATA = "tableMetadata"
|
||||
|
||||
|
||||
class AirtableWebhookManager(BaseWebhooksManager):
|
||||
"""Webhook manager for Airtable API."""
|
||||
|
||||
PROVIDER_NAME = ProviderName("airtable")
|
||||
|
||||
@classmethod
|
||||
async def validate_payload(
|
||||
cls, webhook: Webhook, request, credentials: Credentials | None
|
||||
) -> tuple[dict, str]:
|
||||
"""Validate incoming webhook payload and signature."""
|
||||
|
||||
if not credentials:
|
||||
raise ValueError("Missing credentials in webhook metadata")
|
||||
|
||||
payload = await request.json()
|
||||
|
||||
# Verify webhook signature using HMAC-SHA256
|
||||
if webhook.secret:
|
||||
mac_secret = webhook.config.get("mac_secret")
|
||||
if mac_secret:
|
||||
# Get the raw body for signature verification
|
||||
body = await request.body()
|
||||
|
||||
# Calculate expected signature
|
||||
mac_secret_decoded = mac_secret.encode()
|
||||
hmac_obj = hmac.new(mac_secret_decoded, body, hashlib.sha256)
|
||||
expected_mac = f"hmac-sha256={hmac_obj.hexdigest()}"
|
||||
|
||||
# Get signature from headers
|
||||
signature = request.headers.get("X-Airtable-Content-MAC")
|
||||
|
||||
if signature and not hmac.compare_digest(signature, expected_mac):
|
||||
raise ValueError("Invalid webhook signature")
|
||||
|
||||
# Validate payload structure
|
||||
required_fields = ["base", "webhook", "timestamp"]
|
||||
if not all(field in payload for field in required_fields):
|
||||
raise ValueError("Invalid webhook payload structure")
|
||||
|
||||
if "id" not in payload["base"] or "id" not in payload["webhook"]:
|
||||
raise ValueError("Missing required IDs in webhook payload")
|
||||
base_id = payload["base"]["id"]
|
||||
webhook_id = payload["webhook"]["id"]
|
||||
|
||||
# get payload request parameters
|
||||
cursor = webhook.config.get("cursor", 1)
|
||||
|
||||
response = await list_webhook_payloads(credentials, base_id, webhook_id, cursor)
|
||||
|
||||
# update webhook config
|
||||
await update_webhook(
|
||||
webhook.id,
|
||||
config={"base_id": base_id, "cursor": response.cursor},
|
||||
)
|
||||
|
||||
event_type = "notification"
|
||||
return response.model_dump(), event_type
|
||||
|
||||
async def _register_webhook(
|
||||
self,
|
||||
credentials: Credentials,
|
||||
webhook_type: str,
|
||||
resource: str,
|
||||
events: list[str],
|
||||
ingress_url: str,
|
||||
secret: str,
|
||||
) -> tuple[str, dict]:
|
||||
"""Register webhook with Airtable API."""
|
||||
|
||||
# Parse resource to get base_id and table_id/name
|
||||
# Resource format: "{base_id}/{table_id_or_name}"
|
||||
parts = resource.split("/", 1)
|
||||
if len(parts) != 2:
|
||||
raise ValueError("Resource must be in format: {base_id}/{table_id_or_name}")
|
||||
|
||||
base_id, table_id_or_name = parts
|
||||
|
||||
# Prepare webhook specification
|
||||
webhook_specification = WebhookSpecification(
|
||||
filters=WebhookFilters(
|
||||
dataTypes=events,
|
||||
)
|
||||
)
|
||||
|
||||
# Create webhook
|
||||
webhook_data = await create_webhook(
|
||||
credentials=credentials,
|
||||
base_id=base_id,
|
||||
webhook_specification=webhook_specification,
|
||||
notification_url=ingress_url,
|
||||
)
|
||||
|
||||
webhook_id = webhook_data["id"]
|
||||
mac_secret = webhook_data.get("macSecretBase64")
|
||||
|
||||
return webhook_id, {
|
||||
"webhook_id": webhook_id,
|
||||
"base_id": base_id,
|
||||
"table_id_or_name": table_id_or_name,
|
||||
"events": events,
|
||||
"mac_secret": mac_secret,
|
||||
"cursor": 1,
|
||||
"expiration_time": webhook_data.get("expirationTime"),
|
||||
}
|
||||
|
||||
async def _deregister_webhook(
|
||||
self, webhook: Webhook, credentials: Credentials
|
||||
) -> None:
|
||||
"""Deregister webhook from Airtable API."""
|
||||
|
||||
base_id = webhook.config.get("base_id")
|
||||
webhook_id = webhook.config.get("webhook_id")
|
||||
|
||||
if not base_id:
|
||||
raise ValueError("Missing base_id in webhook metadata")
|
||||
|
||||
if not webhook_id:
|
||||
raise ValueError("Missing webhook_id in webhook metadata")
|
||||
|
||||
await delete_webhook(credentials, base_id, webhook_id)
|
||||
@@ -1,122 +0,0 @@
|
||||
"""
|
||||
Airtable base operation blocks.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._api import create_base, list_bases
|
||||
from ._config import airtable
|
||||
|
||||
|
||||
class AirtableCreateBaseBlock(Block):
|
||||
"""
|
||||
Creates a new base in an Airtable workspace.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
workspace_id: str = SchemaField(
|
||||
description="The workspace ID where the base will be created"
|
||||
)
|
||||
name: str = SchemaField(description="The name of the new base")
|
||||
tables: list[dict] = SchemaField(
|
||||
description="At least one table and field must be specified. Array of table objects to create in the base. Each table should have 'name' and 'fields' properties",
|
||||
default=[
|
||||
{
|
||||
"description": "Default table",
|
||||
"name": "Default table",
|
||||
"fields": [
|
||||
{
|
||||
"name": "ID",
|
||||
"type": "number",
|
||||
"description": "Auto-incrementing ID field",
|
||||
"options": {"precision": 0},
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
base_id: str = SchemaField(description="The ID of the created base")
|
||||
tables: list[dict] = SchemaField(description="Array of table objects")
|
||||
table: dict = SchemaField(description="A single table object")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f59b88a8-54ce-4676-a508-fd614b4e8dce",
|
||||
description="Create a new base in Airtable",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
data = await create_base(
|
||||
credentials,
|
||||
input_data.workspace_id,
|
||||
input_data.name,
|
||||
input_data.tables,
|
||||
)
|
||||
|
||||
yield "base_id", data.get("id", None)
|
||||
yield "tables", data.get("tables", [])
|
||||
for table in data.get("tables", []):
|
||||
yield "table", table
|
||||
|
||||
|
||||
class AirtableListBasesBlock(Block):
|
||||
"""
|
||||
Lists all bases in an Airtable workspace that the user has access to.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
trigger: str = SchemaField(
|
||||
description="Trigger the block to run - value is ignored", default="manual"
|
||||
)
|
||||
offset: str = SchemaField(
|
||||
description="Pagination offset from previous request", default=""
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
bases: list[dict] = SchemaField(description="Array of base objects")
|
||||
offset: Optional[str] = SchemaField(
|
||||
description="Offset for next page (null if no more bases)", default=None
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="4bd8d466-ed5d-4e44-8083-97f25a8044e7",
|
||||
description="List all bases in Airtable",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
data = await list_bases(
|
||||
credentials,
|
||||
offset=input_data.offset if input_data.offset else None,
|
||||
)
|
||||
|
||||
yield "bases", data.get("bases", [])
|
||||
yield "offset", data.get("offset", None)
|
||||
@@ -1,283 +0,0 @@
|
||||
"""
|
||||
Airtable record operation blocks.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._api import (
|
||||
create_record,
|
||||
delete_multiple_records,
|
||||
get_record,
|
||||
list_records,
|
||||
update_multiple_records,
|
||||
)
|
||||
from ._config import airtable
|
||||
|
||||
|
||||
class AirtableListRecordsBlock(Block):
|
||||
"""
|
||||
Lists records from an Airtable table with optional filtering, sorting, and pagination.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
table_id_or_name: str = SchemaField(description="Table ID or name")
|
||||
filter_formula: str = SchemaField(
|
||||
description="Airtable formula to filter records", default=""
|
||||
)
|
||||
view: str = SchemaField(description="View ID or name to use", default="")
|
||||
sort: list[dict] = SchemaField(
|
||||
description="Sort configuration (array of {field, direction})", default=[]
|
||||
)
|
||||
max_records: int = SchemaField(
|
||||
description="Maximum number of records to return", default=100
|
||||
)
|
||||
page_size: int = SchemaField(
|
||||
description="Number of records per page (max 100)", default=100
|
||||
)
|
||||
offset: str = SchemaField(
|
||||
description="Pagination offset from previous request", default=""
|
||||
)
|
||||
return_fields: list[str] = SchemaField(
|
||||
description="Specific fields to return (comma-separated)", default=[]
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
records: list[dict] = SchemaField(description="Array of record objects")
|
||||
offset: Optional[str] = SchemaField(
|
||||
description="Offset for next page (null if no more records)", default=None
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="588a9fde-5733-4da7-b03c-35f5671e960f",
|
||||
description="List records from an Airtable table",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
data = await list_records(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
input_data.table_id_or_name,
|
||||
filter_by_formula=(
|
||||
input_data.filter_formula if input_data.filter_formula else None
|
||||
),
|
||||
view=input_data.view if input_data.view else None,
|
||||
sort=input_data.sort if input_data.sort else None,
|
||||
max_records=input_data.max_records if input_data.max_records else None,
|
||||
page_size=min(input_data.page_size, 100) if input_data.page_size else None,
|
||||
offset=input_data.offset if input_data.offset else None,
|
||||
fields=input_data.return_fields if input_data.return_fields else None,
|
||||
)
|
||||
|
||||
yield "records", data.get("records", [])
|
||||
yield "offset", data.get("offset", None)
|
||||
|
||||
|
||||
class AirtableGetRecordBlock(Block):
|
||||
"""
|
||||
Retrieves a single record from an Airtable table by its ID.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
table_id_or_name: str = SchemaField(description="Table ID or name")
|
||||
record_id: str = SchemaField(description="The record ID to retrieve")
|
||||
|
||||
class Output(BlockSchema):
|
||||
id: str = SchemaField(description="The record ID")
|
||||
fields: dict = SchemaField(description="The record fields")
|
||||
created_time: str = SchemaField(description="The record created time")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c29c5cbf-0aff-40f9-bbb5-f26061792d2b",
|
||||
description="Get a single record from Airtable",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
record = await get_record(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
input_data.table_id_or_name,
|
||||
input_data.record_id,
|
||||
)
|
||||
|
||||
yield "id", record.get("id", None)
|
||||
yield "fields", record.get("fields", None)
|
||||
yield "created_time", record.get("createdTime", None)
|
||||
|
||||
|
||||
class AirtableCreateRecordsBlock(Block):
|
||||
"""
|
||||
Creates one or more records in an Airtable table.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
table_id_or_name: str = SchemaField(description="Table ID or name")
|
||||
records: list[dict] = SchemaField(
|
||||
description="Array of records to create (each with 'fields' object)"
|
||||
)
|
||||
typecast: bool = SchemaField(
|
||||
description="Automatically convert string values to appropriate types",
|
||||
default=False,
|
||||
)
|
||||
return_fields_by_field_id: bool | None = SchemaField(
|
||||
description="Return fields by field ID",
|
||||
default=None,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
records: list[dict] = SchemaField(description="Array of created record objects")
|
||||
details: dict = SchemaField(description="Details of the created records")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="42527e98-47b6-44ce-ac0e-86b4883721d3",
|
||||
description="Create records in an Airtable table",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# The create_record API expects records in a specific format
|
||||
data = await create_record(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
input_data.table_id_or_name,
|
||||
records=[{"fields": record} for record in input_data.records],
|
||||
typecast=input_data.typecast if input_data.typecast else None,
|
||||
return_fields_by_field_id=input_data.return_fields_by_field_id,
|
||||
)
|
||||
|
||||
yield "records", data.get("records", [])
|
||||
details = data.get("details", None)
|
||||
if details:
|
||||
yield "details", details
|
||||
|
||||
|
||||
class AirtableUpdateRecordsBlock(Block):
|
||||
"""
|
||||
Updates one or more existing records in an Airtable table.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
table_id_or_name: str = SchemaField(
|
||||
description="Table ID or name - It's better to use the table ID instead of the name"
|
||||
)
|
||||
records: list[dict] = SchemaField(
|
||||
description="Array of records to update (each with 'id' and 'fields')"
|
||||
)
|
||||
typecast: bool | None = SchemaField(
|
||||
description="Automatically convert string values to appropriate types",
|
||||
default=None,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
records: list[dict] = SchemaField(description="Array of updated record objects")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="6e7d2590-ac2b-4b5d-b08c-fc039cd77e1f",
|
||||
description="Update records in an Airtable table",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# The update_multiple_records API expects records with id and fields
|
||||
data = await update_multiple_records(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
input_data.table_id_or_name,
|
||||
records=input_data.records,
|
||||
typecast=input_data.typecast if input_data.typecast else None,
|
||||
return_fields_by_field_id=False, # Use field names, not IDs
|
||||
)
|
||||
|
||||
yield "records", data.get("records", [])
|
||||
|
||||
|
||||
class AirtableDeleteRecordsBlock(Block):
|
||||
"""
|
||||
Deletes one or more records from an Airtable table.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
table_id_or_name: str = SchemaField(
|
||||
description="Table ID or name - It's better to use the table ID instead of the name"
|
||||
)
|
||||
record_ids: list[str] = SchemaField(
|
||||
description="Array of upto 10 record IDs to delete"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
records: list[dict] = SchemaField(description="Array of deletion results")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="93e22b8b-3642-4477-aefb-1c0929a4a3a6",
|
||||
description="Delete records from an Airtable table",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
if len(input_data.record_ids) > 10:
|
||||
yield "error", "Only upto 10 record IDs can be deleted at a time"
|
||||
else:
|
||||
data = await delete_multiple_records(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
input_data.table_id_or_name,
|
||||
input_data.record_ids,
|
||||
)
|
||||
|
||||
yield "records", data.get("records", [])
|
||||
@@ -1,252 +0,0 @@
|
||||
"""
|
||||
Airtable schema and table management blocks.
|
||||
"""
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
CredentialsMetaInput,
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._api import TableFieldType, create_field, create_table, update_field, update_table
|
||||
from ._config import airtable
|
||||
|
||||
|
||||
class AirtableListSchemaBlock(Block):
|
||||
"""
|
||||
Retrieves the complete schema of an Airtable base, including all tables,
|
||||
fields, and views.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
|
||||
class Output(BlockSchema):
|
||||
base_schema: dict = SchemaField(
|
||||
description="Complete base schema with tables, fields, and views"
|
||||
)
|
||||
tables: list[dict] = SchemaField(description="Array of table objects")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="64291d3c-99b5-47b7-a976-6d94293cdb2d",
|
||||
description="Get the complete schema of an Airtable base",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Get base schema
|
||||
response = await Requests().get(
|
||||
f"https://api.airtable.com/v0/meta/bases/{input_data.base_id}/tables",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
yield "base_schema", data
|
||||
yield "tables", data.get("tables", [])
|
||||
|
||||
|
||||
class AirtableCreateTableBlock(Block):
|
||||
"""
|
||||
Creates a new table in an Airtable base with specified fields and views.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
table_name: str = SchemaField(description="The name of the table to create")
|
||||
table_fields: list[dict] = SchemaField(
|
||||
description="Table fields with name, type, and options",
|
||||
default=[{"name": "Name", "type": "singleLineText"}],
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
table: dict = SchemaField(description="Created table object")
|
||||
table_id: str = SchemaField(description="ID of the created table")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="fcc20ced-d817-42ea-9b40-c35e7bf34b4f",
|
||||
description="Create a new table in an Airtable base",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
table_data = await create_table(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
input_data.table_name,
|
||||
input_data.table_fields,
|
||||
)
|
||||
|
||||
yield "table", table_data
|
||||
yield "table_id", table_data.get("id", "")
|
||||
|
||||
|
||||
class AirtableUpdateTableBlock(Block):
|
||||
"""
|
||||
Updates an existing table's properties such as name or description.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
table_id: str = SchemaField(description="The table ID to update")
|
||||
table_name: str | None = SchemaField(
|
||||
description="The name of the table to update", default=None
|
||||
)
|
||||
table_description: str | None = SchemaField(
|
||||
description="The description of the table to update", default=None
|
||||
)
|
||||
date_dependency: dict | None = SchemaField(
|
||||
description="The date dependency of the table to update", default=None
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
table: dict = SchemaField(description="Updated table object")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="34077c5f-f962-49f2-9ec6-97c67077013a",
|
||||
description="Update table properties",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
table_data = await update_table(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
input_data.table_id,
|
||||
input_data.table_name,
|
||||
input_data.table_description,
|
||||
input_data.date_dependency,
|
||||
)
|
||||
|
||||
yield "table", table_data
|
||||
|
||||
|
||||
class AirtableCreateFieldBlock(Block):
|
||||
"""
|
||||
Adds a new field (column) to an existing Airtable table.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
table_id: str = SchemaField(description="The table ID to add field to")
|
||||
field_type: TableFieldType = SchemaField(
|
||||
description="The type of the field to create",
|
||||
default=TableFieldType.SINGLE_LINE_TEXT,
|
||||
advanced=False,
|
||||
)
|
||||
name: str = SchemaField(description="The name of the field to create")
|
||||
description: str | None = SchemaField(
|
||||
description="The description of the field to create", default=None
|
||||
)
|
||||
options: dict[str, str] | None = SchemaField(
|
||||
description="The options of the field to create", default=None
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
field: dict = SchemaField(description="Created field object")
|
||||
field_id: str = SchemaField(description="ID of the created field")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="6c98a32f-dbf9-45d8-a2a8-5e97e8326351",
|
||||
description="Add a new field to an Airtable table",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
field_data = await create_field(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
input_data.table_id,
|
||||
input_data.field_type,
|
||||
input_data.name,
|
||||
)
|
||||
|
||||
yield "field", field_data
|
||||
yield "field_id", field_data.get("id", "")
|
||||
|
||||
|
||||
class AirtableUpdateFieldBlock(Block):
|
||||
"""
|
||||
Updates an existing field's properties in an Airtable table.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
table_id: str = SchemaField(description="The table ID containing the field")
|
||||
field_id: str = SchemaField(description="The field ID to update")
|
||||
name: str | None = SchemaField(
|
||||
description="The name of the field to update", default=None, advanced=False
|
||||
)
|
||||
description: str | None = SchemaField(
|
||||
description="The description of the field to update",
|
||||
default=None,
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
field: dict = SchemaField(description="Updated field object")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f46ac716-3b18-4da1-92e4-34ca9a464d48",
|
||||
description="Update field properties in an Airtable table",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
field_data = await update_field(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
input_data.table_id,
|
||||
input_data.field_id,
|
||||
input_data.name,
|
||||
input_data.description,
|
||||
)
|
||||
|
||||
yield "field", field_data
|
||||
@@ -1,113 +0,0 @@
|
||||
from backend.sdk import (
|
||||
BaseModel,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
BlockWebhookConfig,
|
||||
CredentialsMetaInput,
|
||||
ProviderName,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._api import WebhookPayload
|
||||
from ._config import airtable
|
||||
|
||||
|
||||
class AirtableEventSelector(BaseModel):
|
||||
"""
|
||||
Selects the Airtable webhook event to trigger on.
|
||||
"""
|
||||
|
||||
tableData: bool = True
|
||||
tableFields: bool = True
|
||||
tableMetadata: bool = True
|
||||
|
||||
|
||||
class AirtableWebhookTriggerBlock(Block):
|
||||
"""
|
||||
Starts a flow whenever Airtable emits a webhook event.
|
||||
|
||||
Thin wrapper just forwards the payloads one at a time to the next block.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = airtable.credentials_field(
|
||||
description="Airtable API credentials"
|
||||
)
|
||||
base_id: str = SchemaField(description="Airtable base ID")
|
||||
table_id_or_name: str = SchemaField(description="Airtable table ID or name")
|
||||
payload: dict = SchemaField(hidden=True, default_factory=dict)
|
||||
events: AirtableEventSelector = SchemaField(
|
||||
description="Airtable webhook event filter"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
payload: WebhookPayload = SchemaField(description="Airtable webhook payload")
|
||||
|
||||
def __init__(self):
|
||||
example_payload = {
|
||||
"payloads": [
|
||||
{
|
||||
"timestamp": "2022-02-01T21:25:05.663Z",
|
||||
"baseTransactionNumber": 4,
|
||||
"actionMetadata": {
|
||||
"source": "client",
|
||||
"sourceMetadata": {
|
||||
"user": {
|
||||
"id": "usr00000000000000",
|
||||
"email": "foo@bar.com",
|
||||
"permissionLevel": "create",
|
||||
}
|
||||
},
|
||||
},
|
||||
"payloadFormat": "v0",
|
||||
}
|
||||
],
|
||||
"cursor": 5,
|
||||
"mightHaveMore": False,
|
||||
}
|
||||
|
||||
super().__init__(
|
||||
# NOTE: This is disabled whilst the webhook system is finalised.
|
||||
disabled=False,
|
||||
id="d0180ce6-ccb9-48c7-8256-b39e93e62801",
|
||||
description="Starts a flow whenever Airtable emits a webhook event",
|
||||
categories={BlockCategory.INPUT, BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
block_type=BlockType.WEBHOOK,
|
||||
webhook_config=BlockWebhookConfig(
|
||||
provider=ProviderName("airtable"),
|
||||
webhook_type="not-used",
|
||||
event_filter_input="events",
|
||||
event_format="{event}",
|
||||
resource_format="{base_id}/{table_id_or_name}",
|
||||
),
|
||||
test_input={
|
||||
"credentials": airtable.get_test_credentials().model_dump(),
|
||||
"base_id": "app1234567890",
|
||||
"table_id_or_name": "table1234567890",
|
||||
"events": AirtableEventSelector(
|
||||
tableData=True,
|
||||
tableFields=True,
|
||||
tableMetadata=False,
|
||||
).model_dump(),
|
||||
"payload": example_payload,
|
||||
},
|
||||
test_credentials=airtable.get_test_credentials(),
|
||||
test_output=[
|
||||
(
|
||||
"payload",
|
||||
WebhookPayload.model_validate(example_payload["payloads"][0]),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
if len(input_data.payload["payloads"]) > 0:
|
||||
for item in input_data.payload["payloads"]:
|
||||
yield "payload", WebhookPayload.model_validate(item)
|
||||
else:
|
||||
yield "error", "No valid payloads found in webhook payload"
|
||||
@@ -4,7 +4,6 @@ from typing import List
|
||||
from backend.blocks.apollo._auth import ApolloCredentials
|
||||
from backend.blocks.apollo.models import (
|
||||
Contact,
|
||||
EnrichPersonRequest,
|
||||
Organization,
|
||||
SearchOrganizationsRequest,
|
||||
SearchOrganizationsResponse,
|
||||
@@ -28,15 +27,14 @@ class ApolloClient:
|
||||
def _get_headers(self) -> dict[str, str]:
|
||||
return {"x-api-key": self.credentials.api_key.get_secret_value()}
|
||||
|
||||
async def search_people(self, query: SearchPeopleRequest) -> List[Contact]:
|
||||
def search_people(self, query: SearchPeopleRequest) -> List[Contact]:
|
||||
"""Search for people in Apollo"""
|
||||
response = await self.requests.post(
|
||||
response = self.requests.get(
|
||||
f"{self.API_URL}/mixed_people/search",
|
||||
headers=self._get_headers(),
|
||||
json=query.model_dump(exclude={"max_results"}),
|
||||
params=query.model_dump(exclude={"credentials", "max_results"}),
|
||||
)
|
||||
data = response.json()
|
||||
parsed_response = SearchPeopleResponse(**data)
|
||||
parsed_response = SearchPeopleResponse(**response.json())
|
||||
if parsed_response.pagination.total_entries == 0:
|
||||
return []
|
||||
|
||||
@@ -54,29 +52,27 @@ class ApolloClient:
|
||||
and len(parsed_response.people) > 0
|
||||
):
|
||||
query.page += 1
|
||||
response = await self.requests.post(
|
||||
response = self.requests.get(
|
||||
f"{self.API_URL}/mixed_people/search",
|
||||
headers=self._get_headers(),
|
||||
json=query.model_dump(exclude={"max_results"}),
|
||||
params=query.model_dump(exclude={"credentials", "max_results"}),
|
||||
)
|
||||
data = response.json()
|
||||
parsed_response = SearchPeopleResponse(**data)
|
||||
parsed_response = SearchPeopleResponse(**response.json())
|
||||
people.extend(parsed_response.people[: query.max_results - len(people)])
|
||||
|
||||
logger.info(f"Found {len(people)} people")
|
||||
return people[: query.max_results] if query.max_results else people
|
||||
|
||||
async def search_organizations(
|
||||
def search_organizations(
|
||||
self, query: SearchOrganizationsRequest
|
||||
) -> List[Organization]:
|
||||
"""Search for organizations in Apollo"""
|
||||
response = await self.requests.post(
|
||||
response = self.requests.get(
|
||||
f"{self.API_URL}/mixed_companies/search",
|
||||
headers=self._get_headers(),
|
||||
json=query.model_dump(exclude={"max_results"}),
|
||||
params=query.model_dump(exclude={"credentials", "max_results"}),
|
||||
)
|
||||
data = response.json()
|
||||
parsed_response = SearchOrganizationsResponse(**data)
|
||||
parsed_response = SearchOrganizationsResponse(**response.json())
|
||||
if parsed_response.pagination.total_entries == 0:
|
||||
return []
|
||||
|
||||
@@ -94,13 +90,12 @@ class ApolloClient:
|
||||
and len(parsed_response.organizations) > 0
|
||||
):
|
||||
query.page += 1
|
||||
response = await self.requests.post(
|
||||
response = self.requests.get(
|
||||
f"{self.API_URL}/mixed_companies/search",
|
||||
headers=self._get_headers(),
|
||||
json=query.model_dump(exclude={"max_results"}),
|
||||
params=query.model_dump(exclude={"credentials", "max_results"}),
|
||||
)
|
||||
data = response.json()
|
||||
parsed_response = SearchOrganizationsResponse(**data)
|
||||
parsed_response = SearchOrganizationsResponse(**response.json())
|
||||
organizations.extend(
|
||||
parsed_response.organizations[
|
||||
: query.max_results - len(organizations)
|
||||
@@ -111,21 +106,3 @@ class ApolloClient:
|
||||
return (
|
||||
organizations[: query.max_results] if query.max_results else organizations
|
||||
)
|
||||
|
||||
async def enrich_person(self, query: EnrichPersonRequest) -> Contact:
|
||||
"""Enrich a person's data including email & phone reveal"""
|
||||
response = await self.requests.post(
|
||||
f"{self.API_URL}/people/match",
|
||||
headers=self._get_headers(),
|
||||
json=query.model_dump(),
|
||||
params={
|
||||
"reveal_personal_emails": "true",
|
||||
},
|
||||
)
|
||||
data = response.json()
|
||||
if "person" not in data:
|
||||
raise ValueError(f"Person not found or enrichment failed: {data}")
|
||||
|
||||
contact = Contact(**data["person"])
|
||||
contact.email = contact.email or "-"
|
||||
return contact
|
||||
|
||||
@@ -1,31 +1,17 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel as OriginalBaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class BaseModel(OriginalBaseModel):
|
||||
def model_dump(self, *args, exclude: set[str] | None = None, **kwargs):
|
||||
if exclude is None:
|
||||
exclude = set("credentials")
|
||||
else:
|
||||
exclude.add("credentials")
|
||||
|
||||
kwargs.setdefault("exclude_none", True)
|
||||
kwargs.setdefault("exclude_unset", True)
|
||||
kwargs.setdefault("exclude_defaults", True)
|
||||
return super().model_dump(*args, exclude=exclude, **kwargs)
|
||||
|
||||
|
||||
class PrimaryPhone(BaseModel):
|
||||
"""A primary phone in Apollo"""
|
||||
|
||||
number: Optional[str] = ""
|
||||
source: Optional[str] = ""
|
||||
sanitized_number: Optional[str] = ""
|
||||
number: str
|
||||
source: str
|
||||
sanitized_number: str
|
||||
|
||||
|
||||
class SenorityLevels(str, Enum):
|
||||
@@ -56,159 +42,157 @@ class ContactEmailStatuses(str, Enum):
|
||||
class RuleConfigStatus(BaseModel):
|
||||
"""A rule config status in Apollo"""
|
||||
|
||||
_id: Optional[str] = ""
|
||||
created_at: Optional[str] = ""
|
||||
rule_action_config_id: Optional[str] = ""
|
||||
rule_config_id: Optional[str] = ""
|
||||
status_cd: Optional[str] = ""
|
||||
updated_at: Optional[str] = ""
|
||||
id: Optional[str] = ""
|
||||
key: Optional[str] = ""
|
||||
_id: str
|
||||
created_at: str
|
||||
rule_action_config_id: str
|
||||
rule_config_id: str
|
||||
status_cd: str
|
||||
updated_at: str
|
||||
id: str
|
||||
key: str
|
||||
|
||||
|
||||
class ContactCampaignStatus(BaseModel):
|
||||
"""A contact campaign status in Apollo"""
|
||||
|
||||
id: Optional[str] = ""
|
||||
emailer_campaign_id: Optional[str] = ""
|
||||
send_email_from_user_id: Optional[str] = ""
|
||||
inactive_reason: Optional[str] = ""
|
||||
status: Optional[str] = ""
|
||||
added_at: Optional[str] = ""
|
||||
added_by_user_id: Optional[str] = ""
|
||||
finished_at: Optional[str] = ""
|
||||
paused_at: Optional[str] = ""
|
||||
auto_unpause_at: Optional[str] = ""
|
||||
send_email_from_email_address: Optional[str] = ""
|
||||
send_email_from_email_account_id: Optional[str] = ""
|
||||
manually_set_unpause: Optional[str] = ""
|
||||
failure_reason: Optional[str] = ""
|
||||
current_step_id: Optional[str] = ""
|
||||
in_response_to_emailer_message_id: Optional[str] = ""
|
||||
cc_emails: Optional[str] = ""
|
||||
bcc_emails: Optional[str] = ""
|
||||
to_emails: Optional[str] = ""
|
||||
id: str
|
||||
emailer_campaign_id: str
|
||||
send_email_from_user_id: str
|
||||
inactive_reason: str
|
||||
status: str
|
||||
added_at: str
|
||||
added_by_user_id: str
|
||||
finished_at: str
|
||||
paused_at: str
|
||||
auto_unpause_at: str
|
||||
send_email_from_email_address: str
|
||||
send_email_from_email_account_id: str
|
||||
manually_set_unpause: str
|
||||
failure_reason: str
|
||||
current_step_id: str
|
||||
in_response_to_emailer_message_id: str
|
||||
cc_emails: str
|
||||
bcc_emails: str
|
||||
to_emails: str
|
||||
|
||||
|
||||
class Account(BaseModel):
|
||||
"""An account in Apollo"""
|
||||
|
||||
id: Optional[str] = ""
|
||||
name: Optional[str] = ""
|
||||
website_url: Optional[str] = ""
|
||||
blog_url: Optional[str] = ""
|
||||
angellist_url: Optional[str] = ""
|
||||
linkedin_url: Optional[str] = ""
|
||||
twitter_url: Optional[str] = ""
|
||||
facebook_url: Optional[str] = ""
|
||||
primary_phone: Optional[PrimaryPhone] = PrimaryPhone()
|
||||
languages: Optional[list[str]] = []
|
||||
alexa_ranking: Optional[int] = 0
|
||||
phone: Optional[str] = ""
|
||||
linkedin_uid: Optional[str] = ""
|
||||
founded_year: Optional[int] = 0
|
||||
publicly_traded_symbol: Optional[str] = ""
|
||||
publicly_traded_exchange: Optional[str] = ""
|
||||
logo_url: Optional[str] = ""
|
||||
chrunchbase_url: Optional[str] = ""
|
||||
primary_domain: Optional[str] = ""
|
||||
domain: Optional[str] = ""
|
||||
team_id: Optional[str] = ""
|
||||
organization_id: Optional[str] = ""
|
||||
account_stage_id: Optional[str] = ""
|
||||
source: Optional[str] = ""
|
||||
original_source: Optional[str] = ""
|
||||
creator_id: Optional[str] = ""
|
||||
owner_id: Optional[str] = ""
|
||||
created_at: Optional[str] = ""
|
||||
phone_status: Optional[str] = ""
|
||||
hubspot_id: Optional[str] = ""
|
||||
salesforce_id: Optional[str] = ""
|
||||
crm_owner_id: Optional[str] = ""
|
||||
parent_account_id: Optional[str] = ""
|
||||
sanitized_phone: Optional[str] = ""
|
||||
id: str
|
||||
name: str
|
||||
website_url: str
|
||||
blog_url: str
|
||||
angellist_url: str
|
||||
linkedin_url: str
|
||||
twitter_url: str
|
||||
facebook_url: str
|
||||
primary_phone: PrimaryPhone
|
||||
languages: list[str]
|
||||
alexa_ranking: int
|
||||
phone: str
|
||||
linkedin_uid: str
|
||||
founded_year: int
|
||||
publicly_traded_symbol: str
|
||||
publicly_traded_exchange: str
|
||||
logo_url: str
|
||||
chrunchbase_url: str
|
||||
primary_domain: str
|
||||
domain: str
|
||||
team_id: str
|
||||
organization_id: str
|
||||
account_stage_id: str
|
||||
source: str
|
||||
original_source: str
|
||||
creator_id: str
|
||||
owner_id: str
|
||||
created_at: str
|
||||
phone_status: str
|
||||
hubspot_id: str
|
||||
salesforce_id: str
|
||||
crm_owner_id: str
|
||||
parent_account_id: str
|
||||
sanitized_phone: str
|
||||
# no listed type on the API docs
|
||||
account_playbook_statues: Optional[list[Any]] = []
|
||||
account_rule_config_statuses: Optional[list[RuleConfigStatus]] = []
|
||||
existence_level: Optional[str] = ""
|
||||
label_ids: Optional[list[str]] = []
|
||||
typed_custom_fields: Optional[Any] = {}
|
||||
custom_field_errors: Optional[Any] = {}
|
||||
modality: Optional[str] = ""
|
||||
source_display_name: Optional[str] = ""
|
||||
salesforce_record_id: Optional[str] = ""
|
||||
crm_record_url: Optional[str] = ""
|
||||
account_playbook_statues: list[Any]
|
||||
account_rule_config_statuses: list[RuleConfigStatus]
|
||||
existence_level: str
|
||||
label_ids: list[str]
|
||||
typed_custom_fields: Any
|
||||
custom_field_errors: Any
|
||||
modality: str
|
||||
source_display_name: str
|
||||
salesforce_record_id: str
|
||||
crm_record_url: str
|
||||
|
||||
|
||||
class ContactEmail(BaseModel):
|
||||
"""A contact email in Apollo"""
|
||||
|
||||
email: Optional[str] = ""
|
||||
email_md5: Optional[str] = ""
|
||||
email_sha256: Optional[str] = ""
|
||||
email_status: Optional[str] = ""
|
||||
email_source: Optional[str] = ""
|
||||
extrapolated_email_confidence: Optional[str] = ""
|
||||
position: Optional[int] = 0
|
||||
email_from_customer: Optional[str] = ""
|
||||
free_domain: Optional[bool] = True
|
||||
email: str = ""
|
||||
email_md5: str = ""
|
||||
email_sha256: str = ""
|
||||
email_status: str = ""
|
||||
email_source: str = ""
|
||||
extrapolated_email_confidence: str = ""
|
||||
position: int = 0
|
||||
email_from_customer: str = ""
|
||||
free_domain: bool = True
|
||||
|
||||
|
||||
class EmploymentHistory(BaseModel):
|
||||
"""An employment history in Apollo"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra="allow",
|
||||
arbitrary_types_allowed=True,
|
||||
from_attributes=True,
|
||||
populate_by_name=True,
|
||||
)
|
||||
class Config:
|
||||
extra = "allow"
|
||||
arbitrary_types_allowed = True
|
||||
from_attributes = True
|
||||
populate_by_name = True
|
||||
|
||||
_id: Optional[str] = ""
|
||||
created_at: Optional[str] = ""
|
||||
current: Optional[bool] = False
|
||||
degree: Optional[str] = ""
|
||||
description: Optional[str] = ""
|
||||
emails: Optional[str] = ""
|
||||
end_date: Optional[str] = ""
|
||||
grade_level: Optional[str] = ""
|
||||
kind: Optional[str] = ""
|
||||
major: Optional[str] = ""
|
||||
organization_id: Optional[str] = ""
|
||||
organization_name: Optional[str] = ""
|
||||
raw_address: Optional[str] = ""
|
||||
start_date: Optional[str] = ""
|
||||
title: Optional[str] = ""
|
||||
updated_at: Optional[str] = ""
|
||||
id: Optional[str] = ""
|
||||
key: Optional[str] = ""
|
||||
_id: Optional[str] = None
|
||||
created_at: Optional[str] = None
|
||||
current: Optional[bool] = None
|
||||
degree: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
emails: Optional[str] = None
|
||||
end_date: Optional[str] = None
|
||||
grade_level: Optional[str] = None
|
||||
kind: Optional[str] = None
|
||||
major: Optional[str] = None
|
||||
organization_id: Optional[str] = None
|
||||
organization_name: Optional[str] = None
|
||||
raw_address: Optional[str] = None
|
||||
start_date: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
updated_at: Optional[str] = None
|
||||
id: Optional[str] = None
|
||||
key: Optional[str] = None
|
||||
|
||||
|
||||
class Breadcrumb(BaseModel):
|
||||
"""A breadcrumb in Apollo"""
|
||||
|
||||
label: Optional[str] = ""
|
||||
signal_field_name: Optional[str] = ""
|
||||
value: str | list | None = ""
|
||||
display_name: Optional[str] = ""
|
||||
label: Optional[str] = "N/A"
|
||||
signal_field_name: Optional[str] = "N/A"
|
||||
value: str | list | None = "N/A"
|
||||
display_name: Optional[str] = "N/A"
|
||||
|
||||
|
||||
class TypedCustomField(BaseModel):
|
||||
"""A typed custom field in Apollo"""
|
||||
|
||||
id: Optional[str] = ""
|
||||
value: Optional[str] = ""
|
||||
id: Optional[str] = "N/A"
|
||||
value: Optional[str] = "N/A"
|
||||
|
||||
|
||||
class Pagination(BaseModel):
|
||||
"""Pagination in Apollo"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra="allow",
|
||||
arbitrary_types_allowed=True,
|
||||
from_attributes=True,
|
||||
populate_by_name=True,
|
||||
)
|
||||
class Config:
|
||||
extra = "allow" # Allow extra fields
|
||||
arbitrary_types_allowed = True # Allow any type
|
||||
from_attributes = True # Allow from_orm
|
||||
populate_by_name = True # Allow field aliases to work both ways
|
||||
|
||||
page: int = 0
|
||||
per_page: int = 0
|
||||
@@ -219,23 +203,23 @@ class Pagination(BaseModel):
|
||||
class DialerFlags(BaseModel):
|
||||
"""A dialer flags in Apollo"""
|
||||
|
||||
country_name: Optional[str] = ""
|
||||
country_enabled: Optional[bool] = True
|
||||
high_risk_calling_enabled: Optional[bool] = True
|
||||
potential_high_risk_number: Optional[bool] = True
|
||||
country_name: str
|
||||
country_enabled: bool
|
||||
high_risk_calling_enabled: bool
|
||||
potential_high_risk_number: bool
|
||||
|
||||
|
||||
class PhoneNumber(BaseModel):
|
||||
"""A phone number in Apollo"""
|
||||
|
||||
raw_number: Optional[str] = ""
|
||||
sanitized_number: Optional[str] = ""
|
||||
type: Optional[str] = ""
|
||||
position: Optional[int] = 0
|
||||
status: Optional[str] = ""
|
||||
dnc_status: Optional[str] = ""
|
||||
dnc_other_info: Optional[str] = ""
|
||||
dailer_flags: Optional[DialerFlags] = DialerFlags(
|
||||
raw_number: str = ""
|
||||
sanitized_number: str = ""
|
||||
type: str = ""
|
||||
position: int = 0
|
||||
status: str = ""
|
||||
dnc_status: str = ""
|
||||
dnc_other_info: str = ""
|
||||
dailer_flags: DialerFlags = DialerFlags(
|
||||
country_name="",
|
||||
country_enabled=True,
|
||||
high_risk_calling_enabled=True,
|
||||
@@ -246,171 +230,169 @@ class PhoneNumber(BaseModel):
|
||||
class Organization(BaseModel):
|
||||
"""An organization in Apollo"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra="allow",
|
||||
arbitrary_types_allowed=True,
|
||||
from_attributes=True,
|
||||
populate_by_name=True,
|
||||
)
|
||||
class Config:
|
||||
extra = "allow"
|
||||
arbitrary_types_allowed = True
|
||||
from_attributes = True
|
||||
populate_by_name = True
|
||||
|
||||
id: Optional[str] = ""
|
||||
name: Optional[str] = ""
|
||||
website_url: Optional[str] = ""
|
||||
blog_url: Optional[str] = ""
|
||||
angellist_url: Optional[str] = ""
|
||||
linkedin_url: Optional[str] = ""
|
||||
twitter_url: Optional[str] = ""
|
||||
facebook_url: Optional[str] = ""
|
||||
primary_phone: Optional[PrimaryPhone] = PrimaryPhone()
|
||||
languages: Optional[list[str]] = []
|
||||
id: Optional[str] = "N/A"
|
||||
name: Optional[str] = "N/A"
|
||||
website_url: Optional[str] = "N/A"
|
||||
blog_url: Optional[str] = "N/A"
|
||||
angellist_url: Optional[str] = "N/A"
|
||||
linkedin_url: Optional[str] = "N/A"
|
||||
twitter_url: Optional[str] = "N/A"
|
||||
facebook_url: Optional[str] = "N/A"
|
||||
primary_phone: Optional[PrimaryPhone] = PrimaryPhone(
|
||||
number="N/A", source="N/A", sanitized_number="N/A"
|
||||
)
|
||||
languages: list[str] = []
|
||||
alexa_ranking: Optional[int] = 0
|
||||
phone: Optional[str] = ""
|
||||
linkedin_uid: Optional[str] = ""
|
||||
phone: Optional[str] = "N/A"
|
||||
linkedin_uid: Optional[str] = "N/A"
|
||||
founded_year: Optional[int] = 0
|
||||
publicly_traded_symbol: Optional[str] = ""
|
||||
publicly_traded_exchange: Optional[str] = ""
|
||||
logo_url: Optional[str] = ""
|
||||
chrunchbase_url: Optional[str] = ""
|
||||
primary_domain: Optional[str] = ""
|
||||
sanitized_phone: Optional[str] = ""
|
||||
owned_by_organization_id: Optional[str] = ""
|
||||
intent_strength: Optional[str] = ""
|
||||
show_intent: Optional[bool] = True
|
||||
publicly_traded_symbol: Optional[str] = "N/A"
|
||||
publicly_traded_exchange: Optional[str] = "N/A"
|
||||
logo_url: Optional[str] = "N/A"
|
||||
chrunchbase_url: Optional[str] = "N/A"
|
||||
primary_domain: Optional[str] = "N/A"
|
||||
sanitized_phone: Optional[str] = "N/A"
|
||||
owned_by_organization_id: Optional[str] = "N/A"
|
||||
intent_strength: Optional[str] = "N/A"
|
||||
show_intent: bool = True
|
||||
has_intent_signal_account: Optional[bool] = True
|
||||
intent_signal_account: Optional[str] = ""
|
||||
intent_signal_account: Optional[str] = "N/A"
|
||||
|
||||
|
||||
class Contact(BaseModel):
|
||||
"""A contact in Apollo"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra="allow",
|
||||
arbitrary_types_allowed=True,
|
||||
from_attributes=True,
|
||||
populate_by_name=True,
|
||||
)
|
||||
class Config:
|
||||
extra = "allow"
|
||||
arbitrary_types_allowed = True
|
||||
from_attributes = True
|
||||
populate_by_name = True
|
||||
|
||||
contact_roles: Optional[list[Any]] = []
|
||||
id: Optional[str] = ""
|
||||
first_name: Optional[str] = ""
|
||||
last_name: Optional[str] = ""
|
||||
name: Optional[str] = ""
|
||||
linkedin_url: Optional[str] = ""
|
||||
title: Optional[str] = ""
|
||||
contact_stage_id: Optional[str] = ""
|
||||
owner_id: Optional[str] = ""
|
||||
creator_id: Optional[str] = ""
|
||||
person_id: Optional[str] = ""
|
||||
email_needs_tickling: Optional[bool] = True
|
||||
organization_name: Optional[str] = ""
|
||||
source: Optional[str] = ""
|
||||
original_source: Optional[str] = ""
|
||||
organization_id: Optional[str] = ""
|
||||
headline: Optional[str] = ""
|
||||
photo_url: Optional[str] = ""
|
||||
present_raw_address: Optional[str] = ""
|
||||
linkededin_uid: Optional[str] = ""
|
||||
extrapolated_email_confidence: Optional[float] = 0.0
|
||||
salesforce_id: Optional[str] = ""
|
||||
salesforce_lead_id: Optional[str] = ""
|
||||
salesforce_contact_id: Optional[str] = ""
|
||||
saleforce_account_id: Optional[str] = ""
|
||||
crm_owner_id: Optional[str] = ""
|
||||
created_at: Optional[str] = ""
|
||||
emailer_campaign_ids: Optional[list[str]] = []
|
||||
direct_dial_status: Optional[str] = ""
|
||||
direct_dial_enrichment_failed_at: Optional[str] = ""
|
||||
email_status: Optional[str] = ""
|
||||
email_source: Optional[str] = ""
|
||||
account_id: Optional[str] = ""
|
||||
last_activity_date: Optional[str] = ""
|
||||
hubspot_vid: Optional[str] = ""
|
||||
hubspot_company_id: Optional[str] = ""
|
||||
crm_id: Optional[str] = ""
|
||||
sanitized_phone: Optional[str] = ""
|
||||
merged_crm_ids: Optional[str] = ""
|
||||
updated_at: Optional[str] = ""
|
||||
queued_for_crm_push: Optional[bool] = True
|
||||
suggested_from_rule_engine_config_id: Optional[str] = ""
|
||||
email_unsubscribed: Optional[str] = ""
|
||||
label_ids: Optional[list[Any]] = []
|
||||
has_pending_email_arcgate_request: Optional[bool] = True
|
||||
has_email_arcgate_request: Optional[bool] = True
|
||||
existence_level: Optional[str] = ""
|
||||
email: Optional[str] = ""
|
||||
email_from_customer: Optional[str] = ""
|
||||
typed_custom_fields: Optional[list[TypedCustomField]] = []
|
||||
custom_field_errors: Optional[Any] = {}
|
||||
salesforce_record_id: Optional[str] = ""
|
||||
crm_record_url: Optional[str] = ""
|
||||
email_status_unavailable_reason: Optional[str] = ""
|
||||
email_true_status: Optional[str] = ""
|
||||
updated_email_true_status: Optional[bool] = True
|
||||
contact_rule_config_statuses: Optional[list[RuleConfigStatus]] = []
|
||||
source_display_name: Optional[str] = ""
|
||||
twitter_url: Optional[str] = ""
|
||||
contact_campaign_statuses: Optional[list[ContactCampaignStatus]] = []
|
||||
state: Optional[str] = ""
|
||||
city: Optional[str] = ""
|
||||
country: Optional[str] = ""
|
||||
account: Optional[Account] = Account()
|
||||
contact_emails: Optional[list[ContactEmail]] = []
|
||||
organization: Optional[Organization] = Organization()
|
||||
employment_history: Optional[list[EmploymentHistory]] = []
|
||||
time_zone: Optional[str] = ""
|
||||
intent_strength: Optional[str] = ""
|
||||
show_intent: Optional[bool] = True
|
||||
phone_numbers: Optional[list[PhoneNumber]] = []
|
||||
account_phone_note: Optional[str] = ""
|
||||
free_domain: Optional[bool] = True
|
||||
is_likely_to_engage: Optional[bool] = True
|
||||
email_domain_catchall: Optional[bool] = True
|
||||
contact_job_change_event: Optional[str] = ""
|
||||
contact_roles: list[Any] = []
|
||||
id: Optional[str] = None
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
linkedin_url: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
contact_stage_id: Optional[str] = None
|
||||
owner_id: Optional[str] = None
|
||||
creator_id: Optional[str] = None
|
||||
person_id: Optional[str] = None
|
||||
email_needs_tickling: bool = True
|
||||
organization_name: Optional[str] = None
|
||||
source: Optional[str] = None
|
||||
original_source: Optional[str] = None
|
||||
organization_id: Optional[str] = None
|
||||
headline: Optional[str] = None
|
||||
photo_url: Optional[str] = None
|
||||
present_raw_address: Optional[str] = None
|
||||
linkededin_uid: Optional[str] = None
|
||||
extrapolated_email_confidence: Optional[float] = None
|
||||
salesforce_id: Optional[str] = None
|
||||
salesforce_lead_id: Optional[str] = None
|
||||
salesforce_contact_id: Optional[str] = None
|
||||
saleforce_account_id: Optional[str] = None
|
||||
crm_owner_id: Optional[str] = None
|
||||
created_at: Optional[str] = None
|
||||
emailer_campaign_ids: list[str] = []
|
||||
direct_dial_status: Optional[str] = None
|
||||
direct_dial_enrichment_failed_at: Optional[str] = None
|
||||
email_status: Optional[str] = None
|
||||
email_source: Optional[str] = None
|
||||
account_id: Optional[str] = None
|
||||
last_activity_date: Optional[str] = None
|
||||
hubspot_vid: Optional[str] = None
|
||||
hubspot_company_id: Optional[str] = None
|
||||
crm_id: Optional[str] = None
|
||||
sanitized_phone: Optional[str] = None
|
||||
merged_crm_ids: Optional[str] = None
|
||||
updated_at: Optional[str] = None
|
||||
queued_for_crm_push: bool = True
|
||||
suggested_from_rule_engine_config_id: Optional[str] = None
|
||||
email_unsubscribed: Optional[str] = None
|
||||
label_ids: list[Any] = []
|
||||
has_pending_email_arcgate_request: bool = True
|
||||
has_email_arcgate_request: bool = True
|
||||
existence_level: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
email_from_customer: Optional[str] = None
|
||||
typed_custom_fields: list[TypedCustomField] = []
|
||||
custom_field_errors: Any = None
|
||||
salesforce_record_id: Optional[str] = None
|
||||
crm_record_url: Optional[str] = None
|
||||
email_status_unavailable_reason: Optional[str] = None
|
||||
email_true_status: Optional[str] = None
|
||||
updated_email_true_status: bool = True
|
||||
contact_rule_config_statuses: list[RuleConfigStatus] = []
|
||||
source_display_name: Optional[str] = None
|
||||
twitter_url: Optional[str] = None
|
||||
contact_campaign_statuses: list[ContactCampaignStatus] = []
|
||||
state: Optional[str] = None
|
||||
city: Optional[str] = None
|
||||
country: Optional[str] = None
|
||||
account: Optional[Account] = None
|
||||
contact_emails: list[ContactEmail] = []
|
||||
organization: Optional[Organization] = None
|
||||
employment_history: list[EmploymentHistory] = []
|
||||
time_zone: Optional[str] = None
|
||||
intent_strength: Optional[str] = None
|
||||
show_intent: bool = True
|
||||
phone_numbers: list[PhoneNumber] = []
|
||||
account_phone_note: Optional[str] = None
|
||||
free_domain: bool = True
|
||||
is_likely_to_engage: bool = True
|
||||
email_domain_catchall: bool = True
|
||||
contact_job_change_event: Optional[str] = None
|
||||
|
||||
|
||||
class SearchOrganizationsRequest(BaseModel):
|
||||
"""Request for Apollo's search organizations API"""
|
||||
|
||||
organization_num_employees_range: Optional[list[int]] = SchemaField(
|
||||
organization_num_empoloyees_range: list[int] = SchemaField(
|
||||
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
|
||||
|
||||
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
|
||||
default=[0, 1000000],
|
||||
)
|
||||
|
||||
organization_locations: Optional[list[str]] = SchemaField(
|
||||
organization_locations: list[str] = SchemaField(
|
||||
description="""The location of the company headquarters. You can search across cities, US states, and countries.
|
||||
|
||||
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, any Boston-based companies will not appearch in your search results, even if they match other parameters.
|
||||
|
||||
To exclude companies based on location, use the organization_not_locations parameter.
|
||||
""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
organizations_not_locations: Optional[list[str]] = SchemaField(
|
||||
organizations_not_locations: list[str] = SchemaField(
|
||||
description="""Exclude companies from search results based on the location of the company headquarters. You can use cities, US states, and countries as locations to exclude.
|
||||
|
||||
This parameter is useful for ensuring you do not prospect in an undesirable territory. For example, if you use ireland as a value, no Ireland-based companies will appear in your search results.
|
||||
""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
q_organization_keyword_tags: Optional[list[str]] = SchemaField(
|
||||
description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry.""",
|
||||
default_factory=list,
|
||||
q_organization_keyword_tags: list[str] = SchemaField(
|
||||
description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry."""
|
||||
)
|
||||
q_organization_name: Optional[str] = SchemaField(
|
||||
q_organization_name: str = SchemaField(
|
||||
description="""Filter search results to include a specific company name.
|
||||
|
||||
If the value you enter for this parameter does not match with a company's name, the company will not appear in search results, even if it matches other parameters. Partial matches are accepted. For example, if you filter by the value marketing, a company called NY Marketing Unlimited would still be eligible as a search result, but NY Market Analysis would not be eligible.""",
|
||||
default="",
|
||||
If the value you enter for this parameter does not match with a company's name, the company will not appear in search results, even if it matches other parameters. Partial matches are accepted. For example, if you filter by the value marketing, a company called NY Marketing Unlimited would still be eligible as a search result, but NY Market Analysis would not be eligible."""
|
||||
)
|
||||
organization_ids: Optional[list[str]] = SchemaField(
|
||||
organization_ids: list[str] = SchemaField(
|
||||
description="""The Apollo IDs for the companies you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
|
||||
|
||||
To find IDs, identify the values for organization_id when you call this endpoint.""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
max_results: Optional[int] = SchemaField(
|
||||
max_results: int = SchemaField(
|
||||
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
|
||||
default=100,
|
||||
ge=1,
|
||||
@@ -435,11 +417,11 @@ Use the page parameter to search the different pages of data.""",
|
||||
class SearchOrganizationsResponse(BaseModel):
|
||||
"""Response from Apollo's search organizations API"""
|
||||
|
||||
breadcrumbs: Optional[list[Breadcrumb]] = []
|
||||
partial_results_only: Optional[bool] = True
|
||||
has_join: Optional[bool] = True
|
||||
disable_eu_prospecting: Optional[bool] = True
|
||||
partial_results_limit: Optional[int] = 0
|
||||
breadcrumbs: list[Breadcrumb] = []
|
||||
partial_results_only: bool = True
|
||||
has_join: bool = True
|
||||
disable_eu_prospecting: bool = True
|
||||
partial_results_limit: int = 0
|
||||
pagination: Pagination = Pagination(
|
||||
page=0, per_page=0, total_entries=0, total_pages=0
|
||||
)
|
||||
@@ -447,30 +429,30 @@ class SearchOrganizationsResponse(BaseModel):
|
||||
accounts: list[Any] = []
|
||||
organizations: list[Organization] = []
|
||||
models_ids: list[str] = []
|
||||
num_fetch_result: Optional[str] = ""
|
||||
derived_params: Optional[str] = ""
|
||||
num_fetch_result: Optional[str] = "N/A"
|
||||
derived_params: Optional[str] = "N/A"
|
||||
|
||||
|
||||
class SearchPeopleRequest(BaseModel):
|
||||
"""Request for Apollo's search people API"""
|
||||
|
||||
person_titles: Optional[list[str]] = SchemaField(
|
||||
person_titles: list[str] = SchemaField(
|
||||
description="""Job titles held by the people you want to find. For a person to be included in search results, they only need to match 1 of the job titles you add. Adding more job titles expands your search results.
|
||||
|
||||
Results also include job titles with the same terms, even if they are not exact matches. For example, searching for marketing manager might return people with the job title content marketing manager.
|
||||
|
||||
Use this parameter in combination with the person_seniorities[] parameter to find people based on specific job functions and seniority levels.
|
||||
""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
placeholder="marketing manager",
|
||||
)
|
||||
person_locations: Optional[list[str]] = SchemaField(
|
||||
person_locations: list[str] = SchemaField(
|
||||
description="""The location where people live. You can search across cities, US states, and countries.
|
||||
|
||||
To find people based on the headquarters locations of their current employer, use the organization_locations parameter.""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
person_seniorities: Optional[list[SenorityLevels]] = SchemaField(
|
||||
person_seniorities: list[SenorityLevels] = SchemaField(
|
||||
description="""The job seniority that people hold within their current employer. This enables you to find people that currently hold positions at certain reporting levels, such as Director level or senior IC level.
|
||||
|
||||
For a person to be included in search results, they only need to match 1 of the seniorities you add. Adding more seniorities expands your search results.
|
||||
@@ -478,41 +460,41 @@ For a person to be included in search results, they only need to match 1 of the
|
||||
Searches only return results based on their current job title, so searching for Director-level employees only returns people that currently hold a Director-level title. If someone was previously a Director, but is currently a VP, they would not be included in your search results.
|
||||
|
||||
Use this parameter in combination with the person_titles[] parameter to find people based on specific job functions and seniority levels.""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
organization_locations: Optional[list[str]] = SchemaField(
|
||||
organization_locations: list[str] = SchemaField(
|
||||
description="""The location of the company headquarters for a person's current employer. You can search across cities, US states, and countries.
|
||||
|
||||
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, people that work for the Boston-based company will not appear in your results, even if they match other parameters.
|
||||
|
||||
To find people based on their personal location, use the person_locations parameter.""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
q_organization_domains: Optional[list[str]] = SchemaField(
|
||||
q_organization_domains: list[str] = SchemaField(
|
||||
description="""The domain name for the person's employer. This can be the current employer or a previous employer. Do not include www., the @ symbol, or similar.
|
||||
|
||||
You can add multiple domains to search across companies.
|
||||
|
||||
Examples: apollo.io and microsoft.com""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
contact_email_statuses: Optional[list[ContactEmailStatuses]] = SchemaField(
|
||||
contact_email_statuses: list[ContactEmailStatuses] = SchemaField(
|
||||
description="""The email statuses for the people you want to find. You can add multiple statuses to expand your search.""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
organization_ids: Optional[list[str]] = SchemaField(
|
||||
organization_ids: list[str] = SchemaField(
|
||||
description="""The Apollo IDs for the companies (employers) you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
|
||||
|
||||
To find IDs, call the Organization Search endpoint and identify the values for organization_id.""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
organization_num_employees_range: Optional[list[int]] = SchemaField(
|
||||
organization_num_empoloyees_range: list[int] = SchemaField(
|
||||
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
|
||||
|
||||
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
q_keywords: Optional[str] = SchemaField(
|
||||
q_keywords: str = SchemaField(
|
||||
description="""A string of words over which we want to filter the results""",
|
||||
default="",
|
||||
)
|
||||
@@ -528,7 +510,7 @@ Use this parameter in combination with the per_page parameter to make search res
|
||||
Use the page parameter to search the different pages of data.""",
|
||||
default=100,
|
||||
)
|
||||
max_results: Optional[int] = SchemaField(
|
||||
max_results: int = SchemaField(
|
||||
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
|
||||
default=100,
|
||||
ge=1,
|
||||
@@ -540,68 +522,22 @@ Use the page parameter to search the different pages of data.""",
|
||||
class SearchPeopleResponse(BaseModel):
|
||||
"""Response from Apollo's search people API"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra="allow",
|
||||
arbitrary_types_allowed=True,
|
||||
from_attributes=True,
|
||||
populate_by_name=True,
|
||||
)
|
||||
class Config:
|
||||
extra = "allow" # Allow extra fields
|
||||
arbitrary_types_allowed = True # Allow any type
|
||||
from_attributes = True # Allow from_orm
|
||||
populate_by_name = True # Allow field aliases to work both ways
|
||||
|
||||
breadcrumbs: Optional[list[Breadcrumb]] = []
|
||||
partial_results_only: Optional[bool] = True
|
||||
has_join: Optional[bool] = True
|
||||
disable_eu_prospecting: Optional[bool] = True
|
||||
partial_results_limit: Optional[int] = 0
|
||||
breadcrumbs: list[Breadcrumb] = []
|
||||
partial_results_only: bool = True
|
||||
has_join: bool = True
|
||||
disable_eu_prospecting: bool = True
|
||||
partial_results_limit: int = 0
|
||||
pagination: Pagination = Pagination(
|
||||
page=0, per_page=0, total_entries=0, total_pages=0
|
||||
)
|
||||
contacts: list[Contact] = []
|
||||
people: list[Contact] = []
|
||||
model_ids: list[str] = []
|
||||
num_fetch_result: Optional[str] = ""
|
||||
derived_params: Optional[str] = ""
|
||||
|
||||
|
||||
class EnrichPersonRequest(BaseModel):
|
||||
"""Request for Apollo's person enrichment API"""
|
||||
|
||||
person_id: Optional[str] = SchemaField(
|
||||
description="Apollo person ID to enrich (most accurate method)",
|
||||
default="",
|
||||
)
|
||||
first_name: Optional[str] = SchemaField(
|
||||
description="First name of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
last_name: Optional[str] = SchemaField(
|
||||
description="Last name of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
name: Optional[str] = SchemaField(
|
||||
description="Full name of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
email: Optional[str] = SchemaField(
|
||||
description="Email address of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
domain: Optional[str] = SchemaField(
|
||||
description="Company domain of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
company: Optional[str] = SchemaField(
|
||||
description="Company name of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
linkedin_url: Optional[str] = SchemaField(
|
||||
description="LinkedIn URL of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
organization_id: Optional[str] = SchemaField(
|
||||
description="Apollo organization ID of the person's company",
|
||||
default="",
|
||||
)
|
||||
title: Optional[str] = SchemaField(
|
||||
description="Job title of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
num_fetch_result: Optional[str] = "N/A"
|
||||
derived_params: Optional[str] = "N/A"
|
||||
|
||||
@@ -11,14 +11,14 @@ from backend.blocks.apollo.models import (
|
||||
SearchOrganizationsRequest,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class SearchOrganizationsBlock(Block):
|
||||
"""Search for organizations in Apollo"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
organization_num_employees_range: list[int] = SchemaField(
|
||||
organization_num_empoloyees_range: list[int] = SchemaField(
|
||||
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
|
||||
|
||||
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
|
||||
@@ -32,18 +32,18 @@ If a company has several office locations, results are still based on the headqu
|
||||
|
||||
To exclude companies based on location, use the organization_not_locations parameter.
|
||||
""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
organizations_not_locations: list[str] = SchemaField(
|
||||
description="""Exclude companies from search results based on the location of the company headquarters. You can use cities, US states, and countries as locations to exclude.
|
||||
|
||||
This parameter is useful for ensuring you do not prospect in an undesirable territory. For example, if you use ireland as a value, no Ireland-based companies will appear in your search results.
|
||||
""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
q_organization_keyword_tags: list[str] = SchemaField(
|
||||
description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry.""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
q_organization_name: str = SchemaField(
|
||||
description="""Filter search results to include a specific company name.
|
||||
@@ -56,7 +56,7 @@ If the value you enter for this parameter does not match with a company's name,
|
||||
description="""The Apollo IDs for the companies you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
|
||||
|
||||
To find IDs, identify the values for organization_id when you call this endpoint.""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
max_results: int = SchemaField(
|
||||
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
|
||||
@@ -65,14 +65,14 @@ To find IDs, identify the values for organization_id when you call this endpoint
|
||||
le=50000,
|
||||
advanced=True,
|
||||
)
|
||||
credentials: ApolloCredentialsInput = CredentialsField(
|
||||
credentials: ApolloCredentialsInput = SchemaField(
|
||||
description="Apollo credentials",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
organizations: list[Organization] = SchemaField(
|
||||
description="List of organizations found",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
organization: Organization = SchemaField(
|
||||
description="Each found organization, one at a time",
|
||||
@@ -201,17 +201,19 @@ To find IDs, identify the values for organization_id when you call this endpoint
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def search_organizations(
|
||||
def search_organizations(
|
||||
query: SearchOrganizationsRequest, credentials: ApolloCredentials
|
||||
) -> list[Organization]:
|
||||
client = ApolloClient(credentials)
|
||||
return await client.search_organizations(query)
|
||||
return client.search_organizations(query)
|
||||
|
||||
async def run(
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: ApolloCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
query = SearchOrganizationsRequest(**input_data.model_dump())
|
||||
organizations = await self.search_organizations(query, credentials)
|
||||
query = SearchOrganizationsRequest(
|
||||
**input_data.model_dump(exclude={"credentials"})
|
||||
)
|
||||
organizations = self.search_organizations(query, credentials)
|
||||
for organization in organizations:
|
||||
yield "organization", organization
|
||||
yield "organizations", organizations
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import asyncio
|
||||
|
||||
from backend.blocks.apollo._api import ApolloClient
|
||||
from backend.blocks.apollo._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
@@ -10,12 +8,11 @@ from backend.blocks.apollo._auth import (
|
||||
from backend.blocks.apollo.models import (
|
||||
Contact,
|
||||
ContactEmailStatuses,
|
||||
EnrichPersonRequest,
|
||||
SearchPeopleRequest,
|
||||
SenorityLevels,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class SearchPeopleBlock(Block):
|
||||
@@ -29,14 +26,14 @@ class SearchPeopleBlock(Block):
|
||||
|
||||
Use this parameter in combination with the person_seniorities[] parameter to find people based on specific job functions and seniority levels.
|
||||
""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
person_locations: list[str] = SchemaField(
|
||||
description="""The location where people live. You can search across cities, US states, and countries.
|
||||
|
||||
To find people based on the headquarters locations of their current employer, use the organization_locations parameter.""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
person_seniorities: list[SenorityLevels] = SchemaField(
|
||||
@@ -47,7 +44,7 @@ class SearchPeopleBlock(Block):
|
||||
Searches only return results based on their current job title, so searching for Director-level employees only returns people that currently hold a Director-level title. If someone was previously a Director, but is currently a VP, they would not be included in your search results.
|
||||
|
||||
Use this parameter in combination with the person_titles[] parameter to find people based on specific job functions and seniority levels.""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
organization_locations: list[str] = SchemaField(
|
||||
@@ -56,7 +53,7 @@ class SearchPeopleBlock(Block):
|
||||
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, people that work for the Boston-based company will not appear in your results, even if they match other parameters.
|
||||
|
||||
To find people based on their personal location, use the person_locations parameter.""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
q_organization_domains: list[str] = SchemaField(
|
||||
@@ -65,26 +62,26 @@ class SearchPeopleBlock(Block):
|
||||
You can add multiple domains to search across companies.
|
||||
|
||||
Examples: apollo.io and microsoft.com""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
contact_email_statuses: list[ContactEmailStatuses] = SchemaField(
|
||||
description="""The email statuses for the people you want to find. You can add multiple statuses to expand your search.""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
organization_ids: list[str] = SchemaField(
|
||||
description="""The Apollo IDs for the companies (employers) you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
|
||||
|
||||
To find IDs, call the Organization Search endpoint and identify the values for organization_id.""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
organization_num_employees_range: list[int] = SchemaField(
|
||||
organization_num_empoloyees_range: list[int] = SchemaField(
|
||||
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
|
||||
|
||||
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
q_keywords: str = SchemaField(
|
||||
@@ -93,26 +90,24 @@ class SearchPeopleBlock(Block):
|
||||
advanced=False,
|
||||
)
|
||||
max_results: int = SchemaField(
|
||||
description="""The maximum number of results to return. If you don't specify this parameter, the default is 25. Limited to 500 to prevent overspending.""",
|
||||
default=25,
|
||||
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
|
||||
default=100,
|
||||
ge=1,
|
||||
le=500,
|
||||
advanced=True,
|
||||
)
|
||||
enrich_info: bool = SchemaField(
|
||||
description="""Whether to enrich contacts with detailed information including real email addresses. This will double the search cost.""",
|
||||
default=False,
|
||||
le=50000,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
credentials: ApolloCredentialsInput = CredentialsField(
|
||||
credentials: ApolloCredentialsInput = SchemaField(
|
||||
description="Apollo credentials",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
people: list[Contact] = SchemaField(
|
||||
description="List of people found",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
)
|
||||
person: Contact = SchemaField(
|
||||
description="Each found person, one at a time",
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the search failed",
|
||||
@@ -129,6 +124,87 @@ class SearchPeopleBlock(Block):
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
(
|
||||
"person",
|
||||
Contact(
|
||||
contact_roles=[],
|
||||
id="1",
|
||||
name="John Doe",
|
||||
first_name="John",
|
||||
last_name="Doe",
|
||||
linkedin_url="https://www.linkedin.com/in/johndoe",
|
||||
title="Software Engineer",
|
||||
organization_name="Google",
|
||||
organization_id="123456",
|
||||
contact_stage_id="1",
|
||||
owner_id="1",
|
||||
creator_id="1",
|
||||
person_id="1",
|
||||
email_needs_tickling=True,
|
||||
source="apollo",
|
||||
original_source="apollo",
|
||||
headline="Software Engineer",
|
||||
photo_url="https://www.linkedin.com/in/johndoe",
|
||||
present_raw_address="123 Main St, Anytown, USA",
|
||||
linkededin_uid="123456",
|
||||
extrapolated_email_confidence=0.8,
|
||||
salesforce_id="123456",
|
||||
salesforce_lead_id="123456",
|
||||
salesforce_contact_id="123456",
|
||||
saleforce_account_id="123456",
|
||||
crm_owner_id="123456",
|
||||
created_at="2021-01-01",
|
||||
emailer_campaign_ids=[],
|
||||
direct_dial_status="active",
|
||||
direct_dial_enrichment_failed_at="2021-01-01",
|
||||
email_status="active",
|
||||
email_source="apollo",
|
||||
account_id="123456",
|
||||
last_activity_date="2021-01-01",
|
||||
hubspot_vid="123456",
|
||||
hubspot_company_id="123456",
|
||||
crm_id="123456",
|
||||
sanitized_phone="123456",
|
||||
merged_crm_ids="123456",
|
||||
updated_at="2021-01-01",
|
||||
queued_for_crm_push=True,
|
||||
suggested_from_rule_engine_config_id="123456",
|
||||
email_unsubscribed=None,
|
||||
label_ids=[],
|
||||
has_pending_email_arcgate_request=True,
|
||||
has_email_arcgate_request=True,
|
||||
existence_level=None,
|
||||
email=None,
|
||||
email_from_customer=None,
|
||||
typed_custom_fields=[],
|
||||
custom_field_errors=None,
|
||||
salesforce_record_id=None,
|
||||
crm_record_url=None,
|
||||
email_status_unavailable_reason=None,
|
||||
email_true_status=None,
|
||||
updated_email_true_status=True,
|
||||
contact_rule_config_statuses=[],
|
||||
source_display_name=None,
|
||||
twitter_url=None,
|
||||
contact_campaign_statuses=[],
|
||||
state=None,
|
||||
city=None,
|
||||
country=None,
|
||||
account=None,
|
||||
contact_emails=[],
|
||||
organization=None,
|
||||
employment_history=[],
|
||||
time_zone=None,
|
||||
intent_strength=None,
|
||||
show_intent=True,
|
||||
phone_numbers=[],
|
||||
account_phone_note=None,
|
||||
free_domain=True,
|
||||
is_likely_to_engage=True,
|
||||
email_domain_catchall=True,
|
||||
contact_job_change_event=None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"people",
|
||||
[
|
||||
@@ -297,41 +373,13 @@ class SearchPeopleBlock(Block):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def search_people(
|
||||
def search_people(
|
||||
query: SearchPeopleRequest, credentials: ApolloCredentials
|
||||
) -> list[Contact]:
|
||||
client = ApolloClient(credentials)
|
||||
return await client.search_people(query)
|
||||
return client.search_people(query)
|
||||
|
||||
@staticmethod
|
||||
async def enrich_person(
|
||||
query: EnrichPersonRequest, credentials: ApolloCredentials
|
||||
) -> Contact:
|
||||
client = ApolloClient(credentials)
|
||||
return await client.enrich_person(query)
|
||||
|
||||
@staticmethod
|
||||
def merge_contact_data(original: Contact, enriched: Contact) -> Contact:
|
||||
"""
|
||||
Merge contact data from original search with enriched data.
|
||||
Enriched data complements original data, only filling in missing values.
|
||||
"""
|
||||
merged_data = original.model_dump()
|
||||
enriched_data = enriched.model_dump()
|
||||
|
||||
# Only update fields that are None, empty string, empty list, or default values in original
|
||||
for key, enriched_value in enriched_data.items():
|
||||
# Skip if enriched value is None, empty string, or empty list
|
||||
if enriched_value is None or enriched_value == "" or enriched_value == []:
|
||||
continue
|
||||
|
||||
# Update if original value is None, empty string, empty list, or zero
|
||||
if enriched_value:
|
||||
merged_data[key] = enriched_value
|
||||
|
||||
return Contact(**merged_data)
|
||||
|
||||
async def run(
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
@@ -339,25 +387,8 @@ class SearchPeopleBlock(Block):
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
|
||||
query = SearchPeopleRequest(**input_data.model_dump())
|
||||
people = await self.search_people(query, credentials)
|
||||
|
||||
# Enrich with detailed info if requested
|
||||
if input_data.enrich_info:
|
||||
|
||||
async def enrich_or_fallback(person: Contact):
|
||||
try:
|
||||
enrich_query = EnrichPersonRequest(person_id=person.id)
|
||||
enriched_person = await self.enrich_person(
|
||||
enrich_query, credentials
|
||||
)
|
||||
# Merge enriched data with original data, complementing instead of replacing
|
||||
return self.merge_contact_data(person, enriched_person)
|
||||
except Exception:
|
||||
return person # If enrichment fails, use original person data
|
||||
|
||||
people = await asyncio.gather(
|
||||
*(enrich_or_fallback(person) for person in people)
|
||||
)
|
||||
|
||||
query = SearchPeopleRequest(**input_data.model_dump(exclude={"credentials"}))
|
||||
people = self.search_people(query, credentials)
|
||||
for person in people:
|
||||
yield "person", person
|
||||
yield "people", people
|
||||
|
||||
@@ -1,138 +0,0 @@
|
||||
from backend.blocks.apollo._api import ApolloClient
|
||||
from backend.blocks.apollo._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
ApolloCredentials,
|
||||
ApolloCredentialsInput,
|
||||
)
|
||||
from backend.blocks.apollo.models import Contact, EnrichPersonRequest
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
|
||||
|
||||
class GetPersonDetailBlock(Block):
|
||||
"""Get detailed person data with Apollo API, including email reveal"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
person_id: str = SchemaField(
|
||||
description="Apollo person ID to enrich (most accurate method)",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
first_name: str = SchemaField(
|
||||
description="First name of the person to enrich",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
last_name: str = SchemaField(
|
||||
description="Last name of the person to enrich",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
name: str = SchemaField(
|
||||
description="Full name of the person to enrich (alternative to first_name + last_name)",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
email: str = SchemaField(
|
||||
description="Known email address of the person (helps with matching)",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
domain: str = SchemaField(
|
||||
description="Company domain of the person (e.g., 'google.com')",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
company: str = SchemaField(
|
||||
description="Company name of the person",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
linkedin_url: str = SchemaField(
|
||||
description="LinkedIn URL of the person",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
organization_id: str = SchemaField(
|
||||
description="Apollo organization ID of the person's company",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
title: str = SchemaField(
|
||||
description="Job title of the person to enrich",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
credentials: ApolloCredentialsInput = CredentialsField(
|
||||
description="Apollo credentials",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
contact: Contact = SchemaField(
|
||||
description="Enriched contact information",
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if enrichment failed",
|
||||
default="",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3b18d46c-3db6-42ae-a228-0ba441bdd176",
|
||||
description="Get detailed person data with Apollo API, including email reveal",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=GetPersonDetailBlock.Input,
|
||||
output_schema=GetPersonDetailBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"first_name": "John",
|
||||
"last_name": "Doe",
|
||||
"company": "Google",
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"contact",
|
||||
Contact(
|
||||
id="1",
|
||||
name="John Doe",
|
||||
first_name="John",
|
||||
last_name="Doe",
|
||||
email="john.doe@gmail.com",
|
||||
title="Software Engineer",
|
||||
organization_name="Google",
|
||||
linkedin_url="https://www.linkedin.com/in/johndoe",
|
||||
),
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"enrich_person": lambda query, credentials: Contact(
|
||||
id="1",
|
||||
name="John Doe",
|
||||
first_name="John",
|
||||
last_name="Doe",
|
||||
email="john.doe@gmail.com",
|
||||
title="Software Engineer",
|
||||
organization_name="Google",
|
||||
linkedin_url="https://www.linkedin.com/in/johndoe",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def enrich_person(
|
||||
query: EnrichPersonRequest, credentials: ApolloCredentials
|
||||
) -> Contact:
|
||||
client = ApolloClient(credentials)
|
||||
return await client.enrich_person(query)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: ApolloCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
query = EnrichPersonRequest(**input_data.model_dump())
|
||||
yield "contact", await self.enrich_person(query, credentials)
|
||||
@@ -1,15 +0,0 @@
|
||||
AYRSHARE_BLOCK_IDS = [
|
||||
"cbd52c2a-06d2-43ed-9560-6576cc163283", # PostToBlueskyBlock
|
||||
"3352f512-3524-49ed-a08f-003042da2fc1", # PostToFacebookBlock
|
||||
"9e8f844e-b4a5-4b25-80f2-9e1dd7d67625", # PostToXBlock
|
||||
"589af4e4-507f-42fd-b9ac-a67ecef25811", # PostToLinkedInBlock
|
||||
"89b02b96-a7cb-46f4-9900-c48b32fe1552", # PostToInstagramBlock
|
||||
"0082d712-ff1b-4c3d-8a8d-6c7721883b83", # PostToYouTubeBlock
|
||||
"c7733580-3c72-483e-8e47-a8d58754d853", # PostToRedditBlock
|
||||
"47bc74eb-4af2-452c-b933-af377c7287df", # PostToTelegramBlock
|
||||
"2c38c783-c484-4503-9280-ef5d1d345a7e", # PostToGMBBlock
|
||||
"3ca46e05-dbaa-4afb-9e95-5a429c4177e6", # PostToPinterestBlock
|
||||
"7faf4b27-96b0-4f05-bf64-e0de54ae74e1", # PostToTikTokBlock
|
||||
"f8c3b2e1-9d4a-4e5f-8c7b-6a9e8d2f1c3b", # PostToThreadsBlock
|
||||
"a9d7f854-2c83-4e96-b3a1-7f2e9c5d4b8e", # PostToSnapchatBlock
|
||||
]
|
||||
@@ -1,152 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.block import BlockSchema
|
||||
from backend.data.model import SchemaField, UserIntegrations
|
||||
from backend.integrations.ayrshare import AyrshareClient
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
from backend.util.exceptions import MissingConfigError
|
||||
|
||||
|
||||
async def get_profile_key(user_id: str):
|
||||
user_integrations: UserIntegrations = (
|
||||
await get_database_manager_async_client().get_user_integrations(user_id)
|
||||
)
|
||||
return user_integrations.managed_credentials.ayrshare_profile_key
|
||||
|
||||
|
||||
class BaseAyrshareInput(BlockSchema):
|
||||
"""Base input model for Ayrshare social media posts with common fields."""
|
||||
|
||||
post: str = SchemaField(
|
||||
description="The post text to be published", default="", advanced=False
|
||||
)
|
||||
media_urls: list[str] = SchemaField(
|
||||
description="Optional list of media URLs to include. Set is_video in advanced settings to true if you want to upload videos.",
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
is_video: bool = SchemaField(
|
||||
description="Whether the media is a video", default=False, advanced=True
|
||||
)
|
||||
schedule_date: Optional[datetime] = SchemaField(
|
||||
description="UTC datetime for scheduling (YYYY-MM-DDThh:mm:ssZ)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
disable_comments: bool = SchemaField(
|
||||
description="Whether to disable comments", default=False, advanced=True
|
||||
)
|
||||
shorten_links: bool = SchemaField(
|
||||
description="Whether to shorten links", default=False, advanced=True
|
||||
)
|
||||
unsplash: Optional[str] = SchemaField(
|
||||
description="Unsplash image configuration", default=None, advanced=True
|
||||
)
|
||||
requires_approval: bool = SchemaField(
|
||||
description="Whether to enable approval workflow",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
random_post: bool = SchemaField(
|
||||
description="Whether to generate random post text",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
random_media_url: bool = SchemaField(
|
||||
description="Whether to generate random media", default=False, advanced=True
|
||||
)
|
||||
notes: Optional[str] = SchemaField(
|
||||
description="Additional notes for the post", default=None, advanced=True
|
||||
)
|
||||
|
||||
|
||||
class CarouselItem(BaseModel):
|
||||
"""Model for Facebook carousel items."""
|
||||
|
||||
name: str = Field(..., description="The name of the item")
|
||||
link: str = Field(..., description="The link of the item")
|
||||
picture: str = Field(..., description="The picture URL of the item")
|
||||
|
||||
|
||||
class CallToAction(BaseModel):
|
||||
"""Model for Google My Business Call to Action."""
|
||||
|
||||
action_type: str = Field(
|
||||
..., description="Type of action (book, order, shop, learn_more, sign_up, call)"
|
||||
)
|
||||
url: Optional[str] = Field(
|
||||
description="URL for the action (not required for 'call' action)"
|
||||
)
|
||||
|
||||
|
||||
class EventDetails(BaseModel):
|
||||
"""Model for Google My Business Event details."""
|
||||
|
||||
title: str = Field(..., description="Event title")
|
||||
start_date: str = Field(..., description="Event start date (ISO format)")
|
||||
end_date: str = Field(..., description="Event end date (ISO format)")
|
||||
|
||||
|
||||
class OfferDetails(BaseModel):
|
||||
"""Model for Google My Business Offer details."""
|
||||
|
||||
title: str = Field(..., description="Offer title")
|
||||
start_date: str = Field(..., description="Offer start date (ISO format)")
|
||||
end_date: str = Field(..., description="Offer end date (ISO format)")
|
||||
coupon_code: str = Field(..., description="Coupon code (max 58 characters)")
|
||||
redeem_online_url: str = Field(..., description="URL to redeem the offer")
|
||||
terms_conditions: str = Field(..., description="Terms and conditions")
|
||||
|
||||
|
||||
class InstagramUserTag(BaseModel):
|
||||
"""Model for Instagram user tags."""
|
||||
|
||||
username: str = Field(..., description="Instagram username (without @)")
|
||||
x: Optional[float] = Field(description="X coordinate (0.0-1.0) for image posts")
|
||||
y: Optional[float] = Field(description="Y coordinate (0.0-1.0) for image posts")
|
||||
|
||||
|
||||
class LinkedInTargeting(BaseModel):
|
||||
"""Model for LinkedIn audience targeting."""
|
||||
|
||||
countries: Optional[list[str]] = Field(
|
||||
description="Country codes (e.g., ['US', 'IN', 'DE', 'GB'])"
|
||||
)
|
||||
seniorities: Optional[list[str]] = Field(
|
||||
description="Seniority levels (e.g., ['Senior', 'VP'])"
|
||||
)
|
||||
degrees: Optional[list[str]] = Field(description="Education degrees")
|
||||
fields_of_study: Optional[list[str]] = Field(description="Fields of study")
|
||||
industries: Optional[list[str]] = Field(description="Industry categories")
|
||||
job_functions: Optional[list[str]] = Field(description="Job function categories")
|
||||
staff_count_ranges: Optional[list[str]] = Field(description="Company size ranges")
|
||||
|
||||
|
||||
class PinterestCarouselOption(BaseModel):
|
||||
"""Model for Pinterest carousel image options."""
|
||||
|
||||
title: Optional[str] = Field(description="Image title")
|
||||
link: Optional[str] = Field(description="External destination link for the image")
|
||||
description: Optional[str] = Field(description="Image description")
|
||||
|
||||
|
||||
class YouTubeTargeting(BaseModel):
|
||||
"""Model for YouTube country targeting."""
|
||||
|
||||
block: Optional[list[str]] = Field(
|
||||
description="Country codes to block (e.g., ['US', 'CA'])"
|
||||
)
|
||||
allow: Optional[list[str]] = Field(
|
||||
description="Country codes to allow (e.g., ['GB', 'AU'])"
|
||||
)
|
||||
|
||||
|
||||
def create_ayrshare_client():
|
||||
"""Create an Ayrshare client instance."""
|
||||
try:
|
||||
return AyrshareClient()
|
||||
except MissingConfigError:
|
||||
return None
|
||||
@@ -1,114 +0,0 @@
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
class PostToBlueskyBlock(Block):
|
||||
"""Block for posting to Bluesky with Bluesky-specific options."""
|
||||
|
||||
class Input(BaseAyrshareInput):
|
||||
"""Input schema for Bluesky posts."""
|
||||
|
||||
# Override post field to include character limit information
|
||||
post: str = SchemaField(
|
||||
description="The post text to be published (max 300 characters for Bluesky)",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# Override media_urls to include Bluesky-specific constraints
|
||||
media_urls: list[str] = SchemaField(
|
||||
description="Optional list of media URLs to include. Bluesky supports up to 4 images or 1 video.",
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# Bluesky-specific options
|
||||
alt_text: list[str] = SchemaField(
|
||||
description="Alt text for each media item (accessibility)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
disabled=True,
|
||||
id="cbd52c2a-06d2-43ed-9560-6576cc163283",
|
||||
description="Post to Bluesky using Ayrshare",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
block_type=BlockType.AYRSHARE,
|
||||
input_schema=PostToBlueskyBlock.Input,
|
||||
output_schema=PostToBlueskyBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: "PostToBlueskyBlock.Input",
|
||||
*,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Bluesky with Bluesky-specific options."""
|
||||
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
return
|
||||
|
||||
# Validate character limit for Bluesky
|
||||
if len(input_data.post) > 300:
|
||||
yield "error", f"Post text exceeds Bluesky's 300 character limit ({len(input_data.post)} characters)"
|
||||
return
|
||||
|
||||
# Validate media constraints for Bluesky
|
||||
if len(input_data.media_urls) > 4:
|
||||
yield "error", "Bluesky supports a maximum of 4 images or 1 video"
|
||||
return
|
||||
|
||||
# Convert datetime to ISO format if provided
|
||||
iso_date = (
|
||||
input_data.schedule_date.isoformat() if input_data.schedule_date else None
|
||||
)
|
||||
|
||||
# Build Bluesky-specific options
|
||||
bluesky_options = {}
|
||||
if input_data.alt_text:
|
||||
bluesky_options["altText"] = input_data.alt_text
|
||||
|
||||
response = await client.create_post(
|
||||
post=input_data.post,
|
||||
platforms=[SocialPlatform.BLUESKY],
|
||||
media_urls=input_data.media_urls,
|
||||
is_video=input_data.is_video,
|
||||
schedule_date=iso_date,
|
||||
disable_comments=input_data.disable_comments,
|
||||
shorten_links=input_data.shorten_links,
|
||||
unsplash=input_data.unsplash,
|
||||
requires_approval=input_data.requires_approval,
|
||||
random_post=input_data.random_post,
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
bluesky_options=bluesky_options if bluesky_options else None,
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
for p in response.postIds:
|
||||
yield "post", p
|
||||
@@ -1,212 +0,0 @@
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._util import (
|
||||
BaseAyrshareInput,
|
||||
CarouselItem,
|
||||
create_ayrshare_client,
|
||||
get_profile_key,
|
||||
)
|
||||
|
||||
|
||||
class PostToFacebookBlock(Block):
|
||||
"""Block for posting to Facebook with Facebook-specific options."""
|
||||
|
||||
class Input(BaseAyrshareInput):
|
||||
"""Input schema for Facebook posts."""
|
||||
|
||||
# Facebook-specific options
|
||||
is_carousel: bool = SchemaField(
|
||||
description="Whether to post a carousel", default=False, advanced=True
|
||||
)
|
||||
carousel_link: str = SchemaField(
|
||||
description="The URL for the 'See More At' button in the carousel",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
carousel_items: list[CarouselItem] = SchemaField(
|
||||
description="List of carousel items with name, link and picture URLs. Min 2, max 10 items.",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
is_reels: bool = SchemaField(
|
||||
description="Whether to post to Facebook Reels",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
reels_title: str = SchemaField(
|
||||
description="Title for the Reels video (max 255 chars)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
reels_thumbnail: str = SchemaField(
|
||||
description="Thumbnail URL for Reels video (JPEG/PNG, <10MB)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
is_story: bool = SchemaField(
|
||||
description="Whether to post as a Facebook Story",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
media_captions: list[str] = SchemaField(
|
||||
description="Captions for each media item",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
location_id: str = SchemaField(
|
||||
description="Facebook Page ID or name for location tagging",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
age_min: int = SchemaField(
|
||||
description="Minimum age for audience targeting (13,15,18,21,25)",
|
||||
default=0,
|
||||
advanced=True,
|
||||
)
|
||||
target_countries: list[str] = SchemaField(
|
||||
description="List of country codes to target (max 25)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
alt_text: list[str] = SchemaField(
|
||||
description="Alt text for each media item",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
video_title: str = SchemaField(
|
||||
description="Title for video post", default="", advanced=True
|
||||
)
|
||||
video_thumbnail: str = SchemaField(
|
||||
description="Thumbnail URL for video post", default="", advanced=True
|
||||
)
|
||||
is_draft: bool = SchemaField(
|
||||
description="Save as draft in Meta Business Suite",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
scheduled_publish_date: str = SchemaField(
|
||||
description="Schedule publish time in Meta Business Suite (UTC)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
preview_link: str = SchemaField(
|
||||
description="URL for custom link preview", default="", advanced=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
disabled=True,
|
||||
id="3352f512-3524-49ed-a08f-003042da2fc1",
|
||||
description="Post to Facebook using Ayrshare",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
block_type=BlockType.AYRSHARE,
|
||||
input_schema=PostToFacebookBlock.Input,
|
||||
output_schema=PostToFacebookBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: "PostToFacebookBlock.Input",
|
||||
*,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Facebook with Facebook-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
return
|
||||
|
||||
# Convert datetime to ISO format if provided
|
||||
iso_date = (
|
||||
input_data.schedule_date.isoformat() if input_data.schedule_date else None
|
||||
)
|
||||
|
||||
# Build Facebook-specific options
|
||||
facebook_options = {}
|
||||
if input_data.is_carousel:
|
||||
facebook_options["isCarousel"] = True
|
||||
if input_data.carousel_link:
|
||||
facebook_options["carouselLink"] = input_data.carousel_link
|
||||
if input_data.carousel_items:
|
||||
facebook_options["carouselItems"] = [
|
||||
item.dict() for item in input_data.carousel_items
|
||||
]
|
||||
|
||||
if input_data.is_reels:
|
||||
facebook_options["isReels"] = True
|
||||
if input_data.reels_title:
|
||||
facebook_options["reelsTitle"] = input_data.reels_title
|
||||
if input_data.reels_thumbnail:
|
||||
facebook_options["reelsThumbnail"] = input_data.reels_thumbnail
|
||||
|
||||
if input_data.is_story:
|
||||
facebook_options["isStory"] = True
|
||||
|
||||
if input_data.media_captions:
|
||||
facebook_options["mediaCaptions"] = input_data.media_captions
|
||||
|
||||
if input_data.location_id:
|
||||
facebook_options["locationId"] = input_data.location_id
|
||||
|
||||
if input_data.age_min > 0:
|
||||
facebook_options["ageMin"] = input_data.age_min
|
||||
|
||||
if input_data.target_countries:
|
||||
facebook_options["targetCountries"] = input_data.target_countries
|
||||
|
||||
if input_data.alt_text:
|
||||
facebook_options["altText"] = input_data.alt_text
|
||||
|
||||
if input_data.video_title:
|
||||
facebook_options["videoTitle"] = input_data.video_title
|
||||
|
||||
if input_data.video_thumbnail:
|
||||
facebook_options["videoThumbnail"] = input_data.video_thumbnail
|
||||
|
||||
if input_data.is_draft:
|
||||
facebook_options["isDraft"] = True
|
||||
|
||||
if input_data.scheduled_publish_date:
|
||||
facebook_options["scheduledPublishDate"] = input_data.scheduled_publish_date
|
||||
|
||||
if input_data.preview_link:
|
||||
facebook_options["previewLink"] = input_data.preview_link
|
||||
|
||||
response = await client.create_post(
|
||||
post=input_data.post,
|
||||
platforms=[SocialPlatform.FACEBOOK],
|
||||
media_urls=input_data.media_urls,
|
||||
is_video=input_data.is_video,
|
||||
schedule_date=iso_date,
|
||||
disable_comments=input_data.disable_comments,
|
||||
shorten_links=input_data.shorten_links,
|
||||
unsplash=input_data.unsplash,
|
||||
requires_approval=input_data.requires_approval,
|
||||
random_post=input_data.random_post,
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
facebook_options=facebook_options if facebook_options else None,
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
for p in response.postIds:
|
||||
yield "post", p
|
||||
@@ -1,210 +0,0 @@
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
class PostToGMBBlock(Block):
|
||||
"""Block for posting to Google My Business with GMB-specific options."""
|
||||
|
||||
class Input(BaseAyrshareInput):
|
||||
"""Input schema for Google My Business posts."""
|
||||
|
||||
# Override media_urls to include GMB-specific constraints
|
||||
media_urls: list[str] = SchemaField(
|
||||
description="Optional list of media URLs. GMB supports only one image or video per post.",
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# GMB-specific options
|
||||
is_photo_video: bool = SchemaField(
|
||||
description="Whether this is a photo/video post (appears in Photos section)",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
photo_category: str = SchemaField(
|
||||
description="Category for photo/video: cover, profile, logo, exterior, interior, product, at_work, food_and_drink, menu, common_area, rooms, teams",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
# Call to action options (flattened from CallToAction object)
|
||||
call_to_action_type: str = SchemaField(
|
||||
description="Type of action button: 'book', 'order', 'shop', 'learn_more', 'sign_up', or 'call'",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
call_to_action_url: str = SchemaField(
|
||||
description="URL for the action button (not required for 'call' action)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
# Event details options (flattened from EventDetails object)
|
||||
event_title: str = SchemaField(
|
||||
description="Event title for event posts",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
event_start_date: str = SchemaField(
|
||||
description="Event start date in ISO format (e.g., '2024-03-15T09:00:00Z')",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
event_end_date: str = SchemaField(
|
||||
description="Event end date in ISO format (e.g., '2024-03-15T17:00:00Z')",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
# Offer details options (flattened from OfferDetails object)
|
||||
offer_title: str = SchemaField(
|
||||
description="Offer title for promotional posts",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
offer_start_date: str = SchemaField(
|
||||
description="Offer start date in ISO format (e.g., '2024-03-15T00:00:00Z')",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
offer_end_date: str = SchemaField(
|
||||
description="Offer end date in ISO format (e.g., '2024-04-15T23:59:59Z')",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
offer_coupon_code: str = SchemaField(
|
||||
description="Coupon code for the offer (max 58 characters)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
offer_redeem_online_url: str = SchemaField(
|
||||
description="URL where customers can redeem the offer online",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
offer_terms_conditions: str = SchemaField(
|
||||
description="Terms and conditions for the offer",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
disabled=True,
|
||||
id="2c38c783-c484-4503-9280-ef5d1d345a7e",
|
||||
description="Post to Google My Business using Ayrshare",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
block_type=BlockType.AYRSHARE,
|
||||
input_schema=PostToGMBBlock.Input,
|
||||
output_schema=PostToGMBBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: "PostToGMBBlock.Input", *, user_id: str, **kwargs
|
||||
) -> BlockOutput:
|
||||
"""Post to Google My Business with GMB-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
return
|
||||
|
||||
# Validate GMB constraints
|
||||
if len(input_data.media_urls) > 1:
|
||||
yield "error", "Google My Business supports only one image or video per post"
|
||||
return
|
||||
|
||||
# Validate offer coupon code length
|
||||
if input_data.offer_coupon_code and len(input_data.offer_coupon_code) > 58:
|
||||
yield "error", "GMB offer coupon code cannot exceed 58 characters"
|
||||
return
|
||||
|
||||
# Convert datetime to ISO format if provided
|
||||
iso_date = (
|
||||
input_data.schedule_date.isoformat() if input_data.schedule_date else None
|
||||
)
|
||||
|
||||
# Build GMB-specific options
|
||||
gmb_options = {}
|
||||
|
||||
# Photo/Video post options
|
||||
if input_data.is_photo_video:
|
||||
gmb_options["isPhotoVideo"] = True
|
||||
if input_data.photo_category:
|
||||
gmb_options["category"] = input_data.photo_category
|
||||
|
||||
# Call to Action (from flattened fields)
|
||||
if input_data.call_to_action_type:
|
||||
cta_dict = {"actionType": input_data.call_to_action_type}
|
||||
# URL not required for 'call' action type
|
||||
if (
|
||||
input_data.call_to_action_type != "call"
|
||||
and input_data.call_to_action_url
|
||||
):
|
||||
cta_dict["url"] = input_data.call_to_action_url
|
||||
gmb_options["callToAction"] = cta_dict
|
||||
|
||||
# Event details (from flattened fields)
|
||||
if (
|
||||
input_data.event_title
|
||||
and input_data.event_start_date
|
||||
and input_data.event_end_date
|
||||
):
|
||||
gmb_options["event"] = {
|
||||
"title": input_data.event_title,
|
||||
"startDate": input_data.event_start_date,
|
||||
"endDate": input_data.event_end_date,
|
||||
}
|
||||
|
||||
# Offer details (from flattened fields)
|
||||
if (
|
||||
input_data.offer_title
|
||||
and input_data.offer_start_date
|
||||
and input_data.offer_end_date
|
||||
and input_data.offer_coupon_code
|
||||
and input_data.offer_redeem_online_url
|
||||
and input_data.offer_terms_conditions
|
||||
):
|
||||
gmb_options["offer"] = {
|
||||
"title": input_data.offer_title,
|
||||
"startDate": input_data.offer_start_date,
|
||||
"endDate": input_data.offer_end_date,
|
||||
"couponCode": input_data.offer_coupon_code,
|
||||
"redeemOnlineUrl": input_data.offer_redeem_online_url,
|
||||
"termsConditions": input_data.offer_terms_conditions,
|
||||
}
|
||||
|
||||
response = await client.create_post(
|
||||
post=input_data.post,
|
||||
platforms=[SocialPlatform.GOOGLE_MY_BUSINESS],
|
||||
media_urls=input_data.media_urls,
|
||||
is_video=input_data.is_video,
|
||||
schedule_date=iso_date,
|
||||
disable_comments=input_data.disable_comments,
|
||||
shorten_links=input_data.shorten_links,
|
||||
unsplash=input_data.unsplash,
|
||||
requires_approval=input_data.requires_approval,
|
||||
random_post=input_data.random_post,
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
gmb_options=gmb_options if gmb_options else None,
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
for p in response.postIds:
|
||||
yield "post", p
|
||||
@@ -1,249 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._util import (
|
||||
BaseAyrshareInput,
|
||||
InstagramUserTag,
|
||||
create_ayrshare_client,
|
||||
get_profile_key,
|
||||
)
|
||||
|
||||
|
||||
class PostToInstagramBlock(Block):
|
||||
"""Block for posting to Instagram with Instagram-specific options."""
|
||||
|
||||
class Input(BaseAyrshareInput):
|
||||
"""Input schema for Instagram posts."""
|
||||
|
||||
# Override post field to include Instagram-specific information
|
||||
post: str = SchemaField(
|
||||
description="The post text (max 2,200 chars, up to 30 hashtags, 3 @mentions)",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# Override media_urls to include Instagram-specific constraints
|
||||
media_urls: list[str] = SchemaField(
|
||||
description="Optional list of media URLs. Instagram supports up to 10 images/videos in a carousel.",
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# Instagram-specific options
|
||||
is_story: bool | None = SchemaField(
|
||||
description="Whether to post as Instagram Story (24-hour expiration)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# ------- REELS OPTIONS -------
|
||||
share_reels_feed: bool | None = SchemaField(
|
||||
description="Whether Reel should appear in both Feed and Reels tabs",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
audio_name: str | None = SchemaField(
|
||||
description="Audio name for Reels (e.g., 'The Weeknd - Blinding Lights')",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
thumbnail: str | None = SchemaField(
|
||||
description="Thumbnail URL for Reel video", default=None, advanced=True
|
||||
)
|
||||
thumbnail_offset: int | None = SchemaField(
|
||||
description="Thumbnail frame offset in milliseconds (default: 0)",
|
||||
default=0,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# ------- POST OPTIONS -------
|
||||
|
||||
alt_text: list[str] = SchemaField(
|
||||
description="Alt text for each media item (up to 1,000 chars each, accessibility feature), each item in the list corresponds to a media item in the media_urls list",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
location_id: str | None = SchemaField(
|
||||
description="Facebook Page ID or name for location tagging (e.g., '7640348500' or '@guggenheimmuseum')",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
user_tags: list[dict[str, Any]] = SchemaField(
|
||||
description="List of users to tag with coordinates for images",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
collaborators: list[str] = SchemaField(
|
||||
description="Instagram usernames to invite as collaborators (max 3, public accounts only)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
auto_resize: bool | None = SchemaField(
|
||||
description="Auto-resize images to 1080x1080px for Instagram",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="89b02b96-a7cb-46f4-9900-c48b32fe1552",
|
||||
description="Post to Instagram using Ayrshare. Requires a Business or Creator Instagram Account connected with a Facebook Page",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
block_type=BlockType.AYRSHARE,
|
||||
input_schema=PostToInstagramBlock.Input,
|
||||
output_schema=PostToInstagramBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: "PostToInstagramBlock.Input",
|
||||
*,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Instagram with Instagram-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
return
|
||||
|
||||
# Validate Instagram constraints
|
||||
if len(input_data.post) > 2200:
|
||||
yield "error", f"Instagram post text exceeds 2,200 character limit ({len(input_data.post)} characters)"
|
||||
return
|
||||
|
||||
if len(input_data.media_urls) > 10:
|
||||
yield "error", "Instagram supports a maximum of 10 images/videos in a carousel"
|
||||
return
|
||||
|
||||
if len(input_data.collaborators) > 3:
|
||||
yield "error", "Instagram supports a maximum of 3 collaborators"
|
||||
return
|
||||
|
||||
# Validate that if any reel option is set, all required reel options are set
|
||||
reel_options = [
|
||||
input_data.share_reels_feed,
|
||||
input_data.audio_name,
|
||||
input_data.thumbnail,
|
||||
]
|
||||
|
||||
if any(reel_options) and not all(reel_options):
|
||||
yield "error", "When posting a reel, all reel options must be set: share_reels_feed, audio_name, and either thumbnail or thumbnail_offset"
|
||||
return
|
||||
|
||||
# Count hashtags and mentions
|
||||
hashtag_count = input_data.post.count("#")
|
||||
mention_count = input_data.post.count("@")
|
||||
|
||||
if hashtag_count > 30:
|
||||
yield "error", f"Instagram allows maximum 30 hashtags ({hashtag_count} found)"
|
||||
return
|
||||
|
||||
if mention_count > 3:
|
||||
yield "error", f"Instagram allows maximum 3 @mentions ({mention_count} found)"
|
||||
return
|
||||
|
||||
# Convert datetime to ISO format if provided
|
||||
iso_date = (
|
||||
input_data.schedule_date.isoformat() if input_data.schedule_date else None
|
||||
)
|
||||
|
||||
# Build Instagram-specific options
|
||||
instagram_options = {}
|
||||
|
||||
# Stories
|
||||
if input_data.is_story:
|
||||
instagram_options["stories"] = True
|
||||
|
||||
# Reels options
|
||||
if input_data.share_reels_feed is not None:
|
||||
instagram_options["shareReelsFeed"] = input_data.share_reels_feed
|
||||
|
||||
if input_data.audio_name:
|
||||
instagram_options["audioName"] = input_data.audio_name
|
||||
|
||||
if input_data.thumbnail:
|
||||
instagram_options["thumbNail"] = input_data.thumbnail
|
||||
elif input_data.thumbnail_offset and input_data.thumbnail_offset > 0:
|
||||
instagram_options["thumbNailOffset"] = input_data.thumbnail_offset
|
||||
|
||||
# Alt text
|
||||
if input_data.alt_text:
|
||||
# Validate alt text length
|
||||
for i, alt in enumerate(input_data.alt_text):
|
||||
if len(alt) > 1000:
|
||||
yield "error", f"Alt text {i+1} exceeds 1,000 character limit ({len(alt)} characters)"
|
||||
return
|
||||
instagram_options["altText"] = input_data.alt_text
|
||||
|
||||
# Location
|
||||
if input_data.location_id:
|
||||
instagram_options["locationId"] = input_data.location_id
|
||||
|
||||
# User tags
|
||||
if input_data.user_tags:
|
||||
user_tags_list = []
|
||||
for tag in input_data.user_tags:
|
||||
try:
|
||||
tag_obj = InstagramUserTag(**tag)
|
||||
except Exception as e:
|
||||
yield "error", f"Invalid user tag: {e}, tages need to be a dictionary with a 3 items: username (str), x (float) and y (float)"
|
||||
return
|
||||
tag_dict: dict[str, float | str] = {"username": tag_obj.username}
|
||||
if tag_obj.x is not None and tag_obj.y is not None:
|
||||
# Validate coordinates
|
||||
if not (0.0 <= tag_obj.x <= 1.0) or not (0.0 <= tag_obj.y <= 1.0):
|
||||
yield "error", f"User tag coordinates must be between 0.0 and 1.0 (user: {tag_obj.username})"
|
||||
return
|
||||
tag_dict["x"] = tag_obj.x
|
||||
tag_dict["y"] = tag_obj.y
|
||||
user_tags_list.append(tag_dict)
|
||||
instagram_options["userTags"] = user_tags_list
|
||||
|
||||
# Collaborators
|
||||
if input_data.collaborators:
|
||||
instagram_options["collaborators"] = input_data.collaborators
|
||||
|
||||
# Auto resize
|
||||
if input_data.auto_resize:
|
||||
instagram_options["autoResize"] = True
|
||||
|
||||
response = await client.create_post(
|
||||
post=input_data.post,
|
||||
platforms=[SocialPlatform.INSTAGRAM],
|
||||
media_urls=input_data.media_urls,
|
||||
is_video=input_data.is_video,
|
||||
schedule_date=iso_date,
|
||||
disable_comments=input_data.disable_comments,
|
||||
shorten_links=input_data.shorten_links,
|
||||
unsplash=input_data.unsplash,
|
||||
requires_approval=input_data.requires_approval,
|
||||
random_post=input_data.random_post,
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
instagram_options=instagram_options if instagram_options else None,
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
for p in response.postIds:
|
||||
yield "post", p
|
||||
@@ -1,222 +0,0 @@
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
class PostToLinkedInBlock(Block):
|
||||
"""Block for posting to LinkedIn with LinkedIn-specific options."""
|
||||
|
||||
class Input(BaseAyrshareInput):
|
||||
"""Input schema for LinkedIn posts."""
|
||||
|
||||
# Override post field to include LinkedIn-specific information
|
||||
post: str = SchemaField(
|
||||
description="The post text (max 3,000 chars, hashtags supported with #)",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# Override media_urls to include LinkedIn-specific constraints
|
||||
media_urls: list[str] = SchemaField(
|
||||
description="Optional list of media URLs. LinkedIn supports up to 9 images, videos, or documents (PPT, PPTX, DOC, DOCX, PDF <100MB, <300 pages).",
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# LinkedIn-specific options
|
||||
visibility: str = SchemaField(
|
||||
description="Post visibility: 'public' (default), 'connections' (personal only), 'loggedin'",
|
||||
default="public",
|
||||
advanced=True,
|
||||
)
|
||||
alt_text: list[str] = SchemaField(
|
||||
description="Alt text for each image (accessibility feature, not supported for videos/documents)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
titles: list[str] = SchemaField(
|
||||
description="Title/caption for each image or video",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
document_title: str = SchemaField(
|
||||
description="Title for document posts (max 400 chars, uses filename if not specified)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
thumbnail: str = SchemaField(
|
||||
description="Thumbnail URL for video (PNG/JPG, same dimensions as video, <10MB)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
# LinkedIn targeting options (flattened from LinkedInTargeting object)
|
||||
targeting_countries: list[str] | None = SchemaField(
|
||||
description="Country codes for targeting (e.g., ['US', 'IN', 'DE', 'GB']). Requires 300+ followers in target audience.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
targeting_seniorities: list[str] | None = SchemaField(
|
||||
description="Seniority levels for targeting (e.g., ['Senior', 'VP']). Requires 300+ followers in target audience.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
targeting_degrees: list[str] | None = SchemaField(
|
||||
description="Education degrees for targeting. Requires 300+ followers in target audience.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
targeting_fields_of_study: list[str] | None = SchemaField(
|
||||
description="Fields of study for targeting. Requires 300+ followers in target audience.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
targeting_industries: list[str] | None = SchemaField(
|
||||
description="Industry categories for targeting. Requires 300+ followers in target audience.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
targeting_job_functions: list[str] | None = SchemaField(
|
||||
description="Job function categories for targeting. Requires 300+ followers in target audience.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
targeting_staff_count_ranges: list[str] | None = SchemaField(
|
||||
description="Company size ranges for targeting. Requires 300+ followers in target audience.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="589af4e4-507f-42fd-b9ac-a67ecef25811",
|
||||
description="Post to LinkedIn using Ayrshare",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
block_type=BlockType.AYRSHARE,
|
||||
input_schema=PostToLinkedInBlock.Input,
|
||||
output_schema=PostToLinkedInBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: "PostToLinkedInBlock.Input",
|
||||
*,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to LinkedIn with LinkedIn-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
return
|
||||
|
||||
# Validate LinkedIn constraints
|
||||
if len(input_data.post) > 3000:
|
||||
yield "error", f"LinkedIn post text exceeds 3,000 character limit ({len(input_data.post)} characters)"
|
||||
return
|
||||
|
||||
if len(input_data.media_urls) > 9:
|
||||
yield "error", "LinkedIn supports a maximum of 9 images/videos/documents"
|
||||
return
|
||||
|
||||
if input_data.document_title and len(input_data.document_title) > 400:
|
||||
yield "error", f"LinkedIn document title exceeds 400 character limit ({len(input_data.document_title)} characters)"
|
||||
return
|
||||
|
||||
# Validate visibility option
|
||||
valid_visibility = ["public", "connections", "loggedin"]
|
||||
if input_data.visibility not in valid_visibility:
|
||||
yield "error", f"LinkedIn visibility must be one of: {', '.join(valid_visibility)}"
|
||||
return
|
||||
|
||||
# Check for document extensions
|
||||
document_extensions = [".ppt", ".pptx", ".doc", ".docx", ".pdf"]
|
||||
has_documents = any(
|
||||
any(url.lower().endswith(ext) for ext in document_extensions)
|
||||
for url in input_data.media_urls
|
||||
)
|
||||
|
||||
# Convert datetime to ISO format if provided
|
||||
iso_date = (
|
||||
input_data.schedule_date.isoformat() if input_data.schedule_date else None
|
||||
)
|
||||
|
||||
# Build LinkedIn-specific options
|
||||
linkedin_options = {}
|
||||
|
||||
# Visibility
|
||||
if input_data.visibility != "public":
|
||||
linkedin_options["visibility"] = input_data.visibility
|
||||
|
||||
# Alt text (not supported for videos or documents)
|
||||
if input_data.alt_text and not has_documents:
|
||||
linkedin_options["altText"] = input_data.alt_text
|
||||
|
||||
# Titles/captions
|
||||
if input_data.titles:
|
||||
linkedin_options["titles"] = input_data.titles
|
||||
|
||||
# Document title
|
||||
if input_data.document_title and has_documents:
|
||||
linkedin_options["title"] = input_data.document_title
|
||||
|
||||
# Video thumbnail
|
||||
if input_data.thumbnail:
|
||||
linkedin_options["thumbNail"] = input_data.thumbnail
|
||||
|
||||
# Audience targeting (from flattened fields)
|
||||
targeting_dict = {}
|
||||
if input_data.targeting_countries:
|
||||
targeting_dict["countries"] = input_data.targeting_countries
|
||||
if input_data.targeting_seniorities:
|
||||
targeting_dict["seniorities"] = input_data.targeting_seniorities
|
||||
if input_data.targeting_degrees:
|
||||
targeting_dict["degrees"] = input_data.targeting_degrees
|
||||
if input_data.targeting_fields_of_study:
|
||||
targeting_dict["fieldsOfStudy"] = input_data.targeting_fields_of_study
|
||||
if input_data.targeting_industries:
|
||||
targeting_dict["industries"] = input_data.targeting_industries
|
||||
if input_data.targeting_job_functions:
|
||||
targeting_dict["jobFunctions"] = input_data.targeting_job_functions
|
||||
if input_data.targeting_staff_count_ranges:
|
||||
targeting_dict["staffCountRanges"] = input_data.targeting_staff_count_ranges
|
||||
|
||||
if targeting_dict:
|
||||
linkedin_options["targeting"] = targeting_dict
|
||||
|
||||
response = await client.create_post(
|
||||
post=input_data.post,
|
||||
platforms=[SocialPlatform.LINKEDIN],
|
||||
media_urls=input_data.media_urls,
|
||||
is_video=input_data.is_video,
|
||||
schedule_date=iso_date,
|
||||
disable_comments=input_data.disable_comments,
|
||||
shorten_links=input_data.shorten_links,
|
||||
unsplash=input_data.unsplash,
|
||||
requires_approval=input_data.requires_approval,
|
||||
random_post=input_data.random_post,
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
linkedin_options=linkedin_options if linkedin_options else None,
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
for p in response.postIds:
|
||||
yield "post", p
|
||||
@@ -1,214 +0,0 @@
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._util import (
|
||||
BaseAyrshareInput,
|
||||
PinterestCarouselOption,
|
||||
create_ayrshare_client,
|
||||
get_profile_key,
|
||||
)
|
||||
|
||||
|
||||
class PostToPinterestBlock(Block):
|
||||
"""Block for posting to Pinterest with Pinterest-specific options."""
|
||||
|
||||
class Input(BaseAyrshareInput):
|
||||
"""Input schema for Pinterest posts."""
|
||||
|
||||
# Override post field to include Pinterest-specific information
|
||||
post: str = SchemaField(
|
||||
description="Pin description (max 500 chars, links not clickable - use link field instead)",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# Override media_urls to include Pinterest-specific constraints
|
||||
media_urls: list[str] = SchemaField(
|
||||
description="Required image/video URLs. Pinterest requires at least one image. Videos need thumbnail. Up to 5 images for carousel.",
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# Pinterest-specific options
|
||||
pin_title: str = SchemaField(
|
||||
description="Pin title displayed in 'Add your title' section (max 100 chars)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
link: str = SchemaField(
|
||||
description="Clickable destination URL when users click the pin (max 2048 chars)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
board_id: str = SchemaField(
|
||||
description="Pinterest Board ID to post to (from /user/details endpoint, uses default board if not specified)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
note: str = SchemaField(
|
||||
description="Private note for the pin (only visible to you and board collaborators)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
thumbnail: str = SchemaField(
|
||||
description="Required thumbnail URL for video pins (must have valid image Content-Type)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
carousel_options: list[PinterestCarouselOption] = SchemaField(
|
||||
description="Options for each image in carousel (title, link, description per image)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
alt_text: list[str] = SchemaField(
|
||||
description="Alt text for each image/video (max 500 chars each, accessibility feature)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
disabled=True,
|
||||
id="3ca46e05-dbaa-4afb-9e95-5a429c4177e6",
|
||||
description="Post to Pinterest using Ayrshare",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
block_type=BlockType.AYRSHARE,
|
||||
input_schema=PostToPinterestBlock.Input,
|
||||
output_schema=PostToPinterestBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: "PostToPinterestBlock.Input",
|
||||
*,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Pinterest with Pinterest-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
return
|
||||
|
||||
# Validate Pinterest constraints
|
||||
if len(input_data.post) > 500:
|
||||
yield "error", f"Pinterest pin description exceeds 500 character limit ({len(input_data.post)} characters)"
|
||||
return
|
||||
|
||||
if len(input_data.pin_title) > 100:
|
||||
yield "error", f"Pinterest pin title exceeds 100 character limit ({len(input_data.pin_title)} characters)"
|
||||
return
|
||||
|
||||
if len(input_data.link) > 2048:
|
||||
yield "error", f"Pinterest link URL exceeds 2048 character limit ({len(input_data.link)} characters)"
|
||||
return
|
||||
|
||||
if len(input_data.media_urls) == 0:
|
||||
yield "error", "Pinterest requires at least one image or video"
|
||||
return
|
||||
|
||||
if len(input_data.media_urls) > 5:
|
||||
yield "error", "Pinterest supports a maximum of 5 images in a carousel"
|
||||
return
|
||||
|
||||
# Check if video is included and thumbnail is provided
|
||||
video_extensions = [".mp4", ".mov", ".avi", ".mkv", ".wmv", ".flv", ".webm"]
|
||||
has_video = any(
|
||||
any(url.lower().endswith(ext) for ext in video_extensions)
|
||||
for url in input_data.media_urls
|
||||
)
|
||||
|
||||
if (has_video or input_data.is_video) and not input_data.thumbnail:
|
||||
yield "error", "Pinterest video pins require a thumbnail URL"
|
||||
return
|
||||
|
||||
# Validate alt text length
|
||||
for i, alt in enumerate(input_data.alt_text):
|
||||
if len(alt) > 500:
|
||||
yield "error", f"Pinterest alt text {i+1} exceeds 500 character limit ({len(alt)} characters)"
|
||||
return
|
||||
|
||||
# Convert datetime to ISO format if provided
|
||||
iso_date = (
|
||||
input_data.schedule_date.isoformat() if input_data.schedule_date else None
|
||||
)
|
||||
|
||||
# Build Pinterest-specific options
|
||||
pinterest_options = {}
|
||||
|
||||
# Pin title
|
||||
if input_data.pin_title:
|
||||
pinterest_options["title"] = input_data.pin_title
|
||||
|
||||
# Clickable link
|
||||
if input_data.link:
|
||||
pinterest_options["link"] = input_data.link
|
||||
|
||||
# Board ID
|
||||
if input_data.board_id:
|
||||
pinterest_options["boardId"] = input_data.board_id
|
||||
|
||||
# Private note
|
||||
if input_data.note:
|
||||
pinterest_options["note"] = input_data.note
|
||||
|
||||
# Video thumbnail
|
||||
if input_data.thumbnail:
|
||||
pinterest_options["thumbNail"] = input_data.thumbnail
|
||||
|
||||
# Carousel options
|
||||
if input_data.carousel_options:
|
||||
carousel_list = []
|
||||
for option in input_data.carousel_options:
|
||||
carousel_dict = {}
|
||||
if option.title:
|
||||
carousel_dict["title"] = option.title
|
||||
if option.link:
|
||||
carousel_dict["link"] = option.link
|
||||
if option.description:
|
||||
carousel_dict["description"] = option.description
|
||||
if carousel_dict: # Only add if not empty
|
||||
carousel_list.append(carousel_dict)
|
||||
if carousel_list:
|
||||
pinterest_options["carouselOptions"] = carousel_list
|
||||
|
||||
# Alt text
|
||||
if input_data.alt_text:
|
||||
pinterest_options["altText"] = input_data.alt_text
|
||||
|
||||
response = await client.create_post(
|
||||
post=input_data.post,
|
||||
platforms=[SocialPlatform.PINTEREST],
|
||||
media_urls=input_data.media_urls,
|
||||
is_video=input_data.is_video,
|
||||
schedule_date=iso_date,
|
||||
disable_comments=input_data.disable_comments,
|
||||
shorten_links=input_data.shorten_links,
|
||||
unsplash=input_data.unsplash,
|
||||
requires_approval=input_data.requires_approval,
|
||||
random_post=input_data.random_post,
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
pinterest_options=pinterest_options if pinterest_options else None,
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
for p in response.postIds:
|
||||
yield "post", p
|
||||
@@ -1,69 +0,0 @@
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
class PostToRedditBlock(Block):
|
||||
"""Block for posting to Reddit."""
|
||||
|
||||
class Input(BaseAyrshareInput):
|
||||
"""Input schema for Reddit posts."""
|
||||
|
||||
pass # Uses all base fields
|
||||
|
||||
class Output(BlockSchema):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
disabled=True,
|
||||
id="c7733580-3c72-483e-8e47-a8d58754d853",
|
||||
description="Post to Reddit using Ayrshare",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
block_type=BlockType.AYRSHARE,
|
||||
input_schema=PostToRedditBlock.Input,
|
||||
output_schema=PostToRedditBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: "PostToRedditBlock.Input", *, user_id: str, **kwargs
|
||||
) -> BlockOutput:
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured."
|
||||
return
|
||||
iso_date = (
|
||||
input_data.schedule_date.isoformat() if input_data.schedule_date else None
|
||||
)
|
||||
response = await client.create_post(
|
||||
post=input_data.post,
|
||||
platforms=[SocialPlatform.REDDIT],
|
||||
media_urls=input_data.media_urls,
|
||||
is_video=input_data.is_video,
|
||||
schedule_date=iso_date,
|
||||
disable_comments=input_data.disable_comments,
|
||||
shorten_links=input_data.shorten_links,
|
||||
unsplash=input_data.unsplash,
|
||||
requires_approval=input_data.requires_approval,
|
||||
random_post=input_data.random_post,
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
for p in response.postIds:
|
||||
yield "post", p
|
||||
@@ -1,129 +0,0 @@
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
class PostToSnapchatBlock(Block):
|
||||
"""Block for posting to Snapchat with Snapchat-specific options."""
|
||||
|
||||
class Input(BaseAyrshareInput):
|
||||
"""Input schema for Snapchat posts."""
|
||||
|
||||
# Override post field to include Snapchat-specific information
|
||||
post: str = SchemaField(
|
||||
description="The post text (optional for video-only content)",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# Override media_urls to include Snapchat-specific constraints
|
||||
media_urls: list[str] = SchemaField(
|
||||
description="Required video URL for Snapchat posts. Snapchat only supports video content.",
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# Snapchat-specific options
|
||||
story_type: str = SchemaField(
|
||||
description="Type of Snapchat content: 'story' (24-hour Stories), 'saved_story' (Saved Stories), or 'spotlight' (Spotlight posts)",
|
||||
default="story",
|
||||
advanced=True,
|
||||
)
|
||||
video_thumbnail: str = SchemaField(
|
||||
description="Thumbnail URL for video content (optional, auto-generated if not provided)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
disabled=True,
|
||||
id="a9d7f854-2c83-4e96-b3a1-7f2e9c5d4b8e",
|
||||
description="Post to Snapchat using Ayrshare",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
block_type=BlockType.AYRSHARE,
|
||||
input_schema=PostToSnapchatBlock.Input,
|
||||
output_schema=PostToSnapchatBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: "PostToSnapchatBlock.Input",
|
||||
*,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Snapchat with Snapchat-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
return
|
||||
|
||||
# Validate Snapchat constraints
|
||||
if not input_data.media_urls:
|
||||
yield "error", "Snapchat requires at least one video URL"
|
||||
return
|
||||
|
||||
if len(input_data.media_urls) > 1:
|
||||
yield "error", "Snapchat supports only one video per post"
|
||||
return
|
||||
|
||||
# Validate story type
|
||||
valid_story_types = ["story", "saved_story", "spotlight"]
|
||||
if input_data.story_type not in valid_story_types:
|
||||
yield "error", f"Snapchat story type must be one of: {', '.join(valid_story_types)}"
|
||||
return
|
||||
|
||||
# Convert datetime to ISO format if provided
|
||||
iso_date = (
|
||||
input_data.schedule_date.isoformat() if input_data.schedule_date else None
|
||||
)
|
||||
|
||||
# Build Snapchat-specific options
|
||||
snapchat_options = {}
|
||||
|
||||
# Story type
|
||||
if input_data.story_type != "story":
|
||||
snapchat_options["storyType"] = input_data.story_type
|
||||
|
||||
# Video thumbnail
|
||||
if input_data.video_thumbnail:
|
||||
snapchat_options["videoThumbnail"] = input_data.video_thumbnail
|
||||
|
||||
response = await client.create_post(
|
||||
post=input_data.post,
|
||||
platforms=[SocialPlatform.SNAPCHAT],
|
||||
media_urls=input_data.media_urls,
|
||||
is_video=True, # Snapchat only supports video
|
||||
schedule_date=iso_date,
|
||||
disable_comments=input_data.disable_comments,
|
||||
shorten_links=input_data.shorten_links,
|
||||
unsplash=input_data.unsplash,
|
||||
requires_approval=input_data.requires_approval,
|
||||
random_post=input_data.random_post,
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
snapchat_options=snapchat_options if snapchat_options else None,
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
for p in response.postIds:
|
||||
yield "post", p
|
||||
@@ -1,116 +0,0 @@
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
class PostToTelegramBlock(Block):
|
||||
"""Block for posting to Telegram with Telegram-specific options."""
|
||||
|
||||
class Input(BaseAyrshareInput):
|
||||
"""Input schema for Telegram posts."""
|
||||
|
||||
# Override post field to include Telegram-specific information
|
||||
post: str = SchemaField(
|
||||
description="The post text (empty string allowed). Use @handle to mention other Telegram users.",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# Override media_urls to include Telegram-specific constraints
|
||||
media_urls: list[str] = SchemaField(
|
||||
description="Optional list of media URLs. For animated GIFs, only one URL is allowed. Telegram will auto-preview links unless image/video is included.",
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# Override is_video to include GIF-specific information
|
||||
is_video: bool = SchemaField(
|
||||
description="Whether the media is a video. Set to true for animated GIFs that don't end in .gif/.GIF extension.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
disabled=True,
|
||||
id="47bc74eb-4af2-452c-b933-af377c7287df",
|
||||
description="Post to Telegram using Ayrshare",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
block_type=BlockType.AYRSHARE,
|
||||
input_schema=PostToTelegramBlock.Input,
|
||||
output_schema=PostToTelegramBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: "PostToTelegramBlock.Input",
|
||||
*,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Telegram with Telegram-specific validation."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
return
|
||||
|
||||
# Validate Telegram constraints
|
||||
# Check for animated GIFs - only one URL allowed
|
||||
gif_extensions = [".gif", ".GIF"]
|
||||
has_gif = any(
|
||||
any(url.endswith(ext) for ext in gif_extensions)
|
||||
for url in input_data.media_urls
|
||||
)
|
||||
|
||||
if has_gif and len(input_data.media_urls) > 1:
|
||||
yield "error", "Telegram animated GIFs support only one URL per post"
|
||||
return
|
||||
|
||||
# Auto-detect if we need to set is_video for GIFs without proper extension
|
||||
detected_is_video = input_data.is_video
|
||||
if input_data.media_urls and not has_gif and not input_data.is_video:
|
||||
# Check if this might be a GIF without proper extension
|
||||
# This is just informational - user needs to set is_video manually
|
||||
pass
|
||||
|
||||
# Convert datetime to ISO format if provided
|
||||
iso_date = (
|
||||
input_data.schedule_date.isoformat() if input_data.schedule_date else None
|
||||
)
|
||||
|
||||
response = await client.create_post(
|
||||
post=input_data.post,
|
||||
platforms=[SocialPlatform.TELEGRAM],
|
||||
media_urls=input_data.media_urls,
|
||||
is_video=detected_is_video,
|
||||
schedule_date=iso_date,
|
||||
disable_comments=input_data.disable_comments,
|
||||
shorten_links=input_data.shorten_links,
|
||||
unsplash=input_data.unsplash,
|
||||
requires_approval=input_data.requires_approval,
|
||||
random_post=input_data.random_post,
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
for p in response.postIds:
|
||||
yield "post", p
|
||||
@@ -1,111 +0,0 @@
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
class PostToThreadsBlock(Block):
|
||||
"""Block for posting to Threads with Threads-specific options."""
|
||||
|
||||
class Input(BaseAyrshareInput):
|
||||
"""Input schema for Threads posts."""
|
||||
|
||||
# Override post field to include Threads-specific information
|
||||
post: str = SchemaField(
|
||||
description="The post text (max 500 chars, empty string allowed). Only 1 hashtag allowed. Use @handle to mention users.",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# Override media_urls to include Threads-specific constraints
|
||||
media_urls: list[str] = SchemaField(
|
||||
description="Optional list of media URLs. Supports up to 20 images/videos in a carousel. Auto-preview links unless media is included.",
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
disabled=True,
|
||||
id="f8c3b2e1-9d4a-4e5f-8c7b-6a9e8d2f1c3b",
|
||||
description="Post to Threads using Ayrshare",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
block_type=BlockType.AYRSHARE,
|
||||
input_schema=PostToThreadsBlock.Input,
|
||||
output_schema=PostToThreadsBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: "PostToThreadsBlock.Input",
|
||||
*,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to Threads with Threads-specific validation."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
return
|
||||
|
||||
# Validate Threads constraints
|
||||
if len(input_data.post) > 500:
|
||||
yield "error", f"Threads post text exceeds 500 character limit ({len(input_data.post)} characters)"
|
||||
return
|
||||
|
||||
if len(input_data.media_urls) > 20:
|
||||
yield "error", "Threads supports a maximum of 20 images/videos in a carousel"
|
||||
return
|
||||
|
||||
# Count hashtags (only 1 allowed)
|
||||
hashtag_count = input_data.post.count("#")
|
||||
if hashtag_count > 1:
|
||||
yield "error", f"Threads allows only 1 hashtag per post ({hashtag_count} found)"
|
||||
return
|
||||
|
||||
# Convert datetime to ISO format if provided
|
||||
iso_date = (
|
||||
input_data.schedule_date.isoformat() if input_data.schedule_date else None
|
||||
)
|
||||
|
||||
# Build Threads-specific options
|
||||
threads_options = {}
|
||||
# Note: Based on the documentation, Threads doesn't seem to have specific options
|
||||
# beyond the standard ones. The main constraints are validation-based.
|
||||
|
||||
response = await client.create_post(
|
||||
post=input_data.post,
|
||||
platforms=[SocialPlatform.THREADS],
|
||||
media_urls=input_data.media_urls,
|
||||
is_video=input_data.is_video,
|
||||
schedule_date=iso_date,
|
||||
disable_comments=input_data.disable_comments,
|
||||
shorten_links=input_data.shorten_links,
|
||||
unsplash=input_data.unsplash,
|
||||
requires_approval=input_data.requires_approval,
|
||||
random_post=input_data.random_post,
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
threads_options=threads_options if threads_options else None,
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
for p in response.postIds:
|
||||
yield "post", p
|
||||
@@ -1,243 +0,0 @@
|
||||
from enum import Enum
|
||||
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
class TikTokVisibility(str, Enum):
|
||||
PUBLIC = "public"
|
||||
PRIVATE = "private"
|
||||
FOLLOWERS = "followers"
|
||||
|
||||
|
||||
class PostToTikTokBlock(Block):
|
||||
"""Block for posting to TikTok with TikTok-specific options."""
|
||||
|
||||
class Input(BaseAyrshareInput):
|
||||
"""Input schema for TikTok posts."""
|
||||
|
||||
# Override post field to include TikTok-specific information
|
||||
post: str = SchemaField(
|
||||
description="The post text (max 2,200 chars, empty string allowed). Use @handle to mention users. Line breaks will be ignored.",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# Override media_urls to include TikTok-specific constraints
|
||||
media_urls: list[str] = SchemaField(
|
||||
description="Required media URLs. Either 1 video OR up to 35 images (JPG/JPEG/WEBP only). Cannot mix video and images.",
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# TikTok-specific options
|
||||
auto_add_music: bool = SchemaField(
|
||||
description="Whether to automatically add recommended music to the post. If you set this field to true, you can change the music later in the TikTok app.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
disable_comments: bool = SchemaField(
|
||||
description="Disable comments on the published post",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
disable_duet: bool = SchemaField(
|
||||
description="Disable duets on published video (video only)",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
disable_stitch: bool = SchemaField(
|
||||
description="Disable stitch on published video (video only)",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
is_ai_generated: bool = SchemaField(
|
||||
description="If you enable the toggle, your video will be labeled as “Creator labeled as AI-generated” once posted and can’t be changed. The “Creator labeled as AI-generated” label indicates that the content was completely AI-generated or significantly edited with AI.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
is_branded_content: bool = SchemaField(
|
||||
description="Whether to enable the Branded Content toggle. If this field is set to true, the video will be labeled as Branded Content, indicating you are in a paid partnership with a brand. A “Paid partnership” label will be attached to the video.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
is_brand_organic: bool = SchemaField(
|
||||
description="Whether to enable the Brand Organic Content toggle. If this field is set to true, the video will be labeled as Brand Organic Content, indicating you are promoting yourself or your own business. A “Promotional content” label will be attached to the video.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
image_cover_index: int = SchemaField(
|
||||
description="Index of image to use as cover (0-based, image posts only)",
|
||||
default=0,
|
||||
advanced=True,
|
||||
)
|
||||
title: str = SchemaField(
|
||||
description="Title for image posts", default="", advanced=True
|
||||
)
|
||||
thumbnail_offset: int = SchemaField(
|
||||
description="Video thumbnail frame offset in milliseconds (video only)",
|
||||
default=0,
|
||||
advanced=True,
|
||||
)
|
||||
visibility: TikTokVisibility = SchemaField(
|
||||
description="Post visibility: 'public', 'private', 'followers', or 'friends'",
|
||||
default=TikTokVisibility.PUBLIC,
|
||||
advanced=True,
|
||||
)
|
||||
draft: bool = SchemaField(
|
||||
description="Create as draft post (video only)",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="7faf4b27-96b0-4f05-bf64-e0de54ae74e1",
|
||||
description="Post to TikTok using Ayrshare",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
block_type=BlockType.AYRSHARE,
|
||||
input_schema=PostToTikTokBlock.Input,
|
||||
output_schema=PostToTikTokBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: "PostToTikTokBlock.Input", *, user_id: str, **kwargs
|
||||
) -> BlockOutput:
|
||||
"""Post to TikTok with TikTok-specific validation and options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
return
|
||||
|
||||
# Validate TikTok constraints
|
||||
if len(input_data.post) > 2200:
|
||||
yield "error", f"TikTok post text exceeds 2,200 character limit ({len(input_data.post)} characters)"
|
||||
return
|
||||
|
||||
if not input_data.media_urls:
|
||||
yield "error", "TikTok requires at least one media URL (either 1 video or up to 35 images)"
|
||||
return
|
||||
|
||||
# Check for video vs image constraints
|
||||
video_extensions = [".mp4", ".mov", ".avi", ".mkv", ".wmv", ".flv", ".webm"]
|
||||
image_extensions = [".jpg", ".jpeg", ".webp"]
|
||||
|
||||
has_video = input_data.is_video or any(
|
||||
any(url.lower().endswith(ext) for ext in video_extensions)
|
||||
for url in input_data.media_urls
|
||||
)
|
||||
|
||||
has_images = any(
|
||||
any(url.lower().endswith(ext) for ext in image_extensions)
|
||||
for url in input_data.media_urls
|
||||
)
|
||||
|
||||
if has_video and has_images:
|
||||
yield "error", "TikTok does not support mixing video and images in the same post"
|
||||
return
|
||||
|
||||
if has_video and len(input_data.media_urls) > 1:
|
||||
yield "error", "TikTok supports only 1 video per post"
|
||||
return
|
||||
|
||||
if has_images and len(input_data.media_urls) > 35:
|
||||
yield "error", "TikTok supports a maximum of 35 images per post"
|
||||
return
|
||||
|
||||
# Validate image cover index
|
||||
if has_images and input_data.image_cover_index >= len(input_data.media_urls):
|
||||
yield "error", f"Image cover index {input_data.image_cover_index} is out of range (max: {len(input_data.media_urls) - 1})"
|
||||
return
|
||||
|
||||
# Check for PNG files (not supported)
|
||||
has_png = any(url.lower().endswith(".png") for url in input_data.media_urls)
|
||||
if has_png:
|
||||
yield "error", "TikTok does not support PNG files. Please use JPG, JPEG, or WEBP for images."
|
||||
return
|
||||
|
||||
# Convert datetime to ISO format if provided
|
||||
iso_date = (
|
||||
input_data.schedule_date.isoformat() if input_data.schedule_date else None
|
||||
)
|
||||
|
||||
# Build TikTok-specific options
|
||||
tiktok_options = {}
|
||||
|
||||
# Common options
|
||||
if input_data.auto_add_music and has_images:
|
||||
tiktok_options["autoAddMusic"] = True
|
||||
|
||||
if input_data.disable_comments:
|
||||
tiktok_options["disableComments"] = True
|
||||
|
||||
if input_data.is_branded_content:
|
||||
tiktok_options["isBrandedContent"] = True
|
||||
|
||||
if input_data.is_brand_organic:
|
||||
tiktok_options["isBrandOrganic"] = True
|
||||
|
||||
# Video-specific options
|
||||
if has_video:
|
||||
if input_data.disable_duet:
|
||||
tiktok_options["disableDuet"] = True
|
||||
|
||||
if input_data.disable_stitch:
|
||||
tiktok_options["disableStitch"] = True
|
||||
|
||||
if input_data.is_ai_generated:
|
||||
tiktok_options["isAIGenerated"] = True
|
||||
|
||||
if input_data.thumbnail_offset > 0:
|
||||
tiktok_options["thumbNailOffset"] = input_data.thumbnail_offset
|
||||
|
||||
if input_data.draft:
|
||||
tiktok_options["draft"] = True
|
||||
|
||||
# Image-specific options
|
||||
if has_images:
|
||||
if input_data.image_cover_index > 0:
|
||||
tiktok_options["imageCoverIndex"] = input_data.image_cover_index
|
||||
|
||||
if input_data.title:
|
||||
tiktok_options["title"] = input_data.title
|
||||
|
||||
if input_data.visibility != TikTokVisibility.PUBLIC:
|
||||
tiktok_options["visibility"] = input_data.visibility.value
|
||||
|
||||
response = await client.create_post(
|
||||
post=input_data.post,
|
||||
platforms=[SocialPlatform.TIKTOK],
|
||||
media_urls=input_data.media_urls,
|
||||
is_video=has_video,
|
||||
schedule_date=iso_date,
|
||||
disable_comments=input_data.disable_comments,
|
||||
shorten_links=input_data.shorten_links,
|
||||
unsplash=input_data.unsplash,
|
||||
requires_approval=input_data.requires_approval,
|
||||
random_post=input_data.random_post,
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
tiktok_options=tiktok_options if tiktok_options else None,
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
for p in response.postIds:
|
||||
yield "post", p
|
||||
@@ -1,241 +0,0 @@
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
class PostToXBlock(Block):
|
||||
"""Block for posting to X / Twitter with Twitter-specific options."""
|
||||
|
||||
class Input(BaseAyrshareInput):
|
||||
"""Input schema for X / Twitter posts."""
|
||||
|
||||
# Override post field to include X-specific information
|
||||
post: str = SchemaField(
|
||||
description="The post text (max 280 chars, up to 25,000 for Premium users). Use @handle to mention users. Use \\n\\n for thread breaks.",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# Override media_urls to include X-specific constraints
|
||||
media_urls: list[str] = SchemaField(
|
||||
description="Optional list of media URLs. X supports up to 4 images or videos per tweet. Auto-preview links unless media is included.",
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# X-specific options
|
||||
reply_to_id: str | None = SchemaField(
|
||||
description="ID of the tweet to reply to",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
quote_tweet_id: str | None = SchemaField(
|
||||
description="ID of the tweet to quote (low-level Tweet ID)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
poll_options: list[str] = SchemaField(
|
||||
description="Poll options (2-4 choices)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
poll_duration: int = SchemaField(
|
||||
description="Poll duration in minutes (1-10080)",
|
||||
default=1440,
|
||||
advanced=True,
|
||||
)
|
||||
alt_text: list[str] = SchemaField(
|
||||
description="Alt text for each image (max 1,000 chars each, not supported for videos)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
is_thread: bool = SchemaField(
|
||||
description="Whether to automatically break post into thread based on line breaks",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
thread_number: bool = SchemaField(
|
||||
description="Add thread numbers (1/n format) to each thread post",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
thread_media_urls: list[str] = SchemaField(
|
||||
description="Media URLs for thread posts (one per thread, use 'null' to skip)",
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
long_post: bool = SchemaField(
|
||||
description="Force long form post (requires Premium X account)",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
long_video: bool = SchemaField(
|
||||
description="Enable long video upload (requires approval and Business/Enterprise plan)",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
subtitle_url: str = SchemaField(
|
||||
description="URL to SRT subtitle file for videos (must be HTTPS and end in .srt)",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
subtitle_language: str = SchemaField(
|
||||
description="Language code for subtitles (default: 'en')",
|
||||
default="en",
|
||||
advanced=True,
|
||||
)
|
||||
subtitle_name: str = SchemaField(
|
||||
description="Name of caption track (max 150 chars, default: 'English')",
|
||||
default="English",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="9e8f844e-b4a5-4b25-80f2-9e1dd7d67625",
|
||||
description="Post to X / Twitter using Ayrshare",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
block_type=BlockType.AYRSHARE,
|
||||
input_schema=PostToXBlock.Input,
|
||||
output_schema=PostToXBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: "PostToXBlock.Input",
|
||||
*,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to X / Twitter with enhanced X-specific options."""
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
return
|
||||
|
||||
# Validate X constraints
|
||||
if not input_data.long_post and len(input_data.post) > 280:
|
||||
yield "error", f"X post text exceeds 280 character limit ({len(input_data.post)} characters). Enable 'long_post' for Premium accounts."
|
||||
return
|
||||
|
||||
if input_data.long_post and len(input_data.post) > 25000:
|
||||
yield "error", f"X long post text exceeds 25,000 character limit ({len(input_data.post)} characters)"
|
||||
return
|
||||
|
||||
if len(input_data.media_urls) > 4:
|
||||
yield "error", "X supports a maximum of 4 images or videos per tweet"
|
||||
return
|
||||
|
||||
# Validate poll options
|
||||
if input_data.poll_options:
|
||||
if len(input_data.poll_options) < 2 or len(input_data.poll_options) > 4:
|
||||
yield "error", "X polls require 2-4 options"
|
||||
return
|
||||
|
||||
if input_data.poll_duration < 1 or input_data.poll_duration > 10080:
|
||||
yield "error", "X poll duration must be between 1 and 10,080 minutes (7 days)"
|
||||
return
|
||||
|
||||
# Validate alt text
|
||||
if input_data.alt_text:
|
||||
for i, alt in enumerate(input_data.alt_text):
|
||||
if len(alt) > 1000:
|
||||
yield "error", f"X alt text {i+1} exceeds 1,000 character limit ({len(alt)} characters)"
|
||||
return
|
||||
|
||||
# Validate subtitle settings
|
||||
if input_data.subtitle_url:
|
||||
if not input_data.subtitle_url.startswith(
|
||||
"https://"
|
||||
) or not input_data.subtitle_url.endswith(".srt"):
|
||||
yield "error", "Subtitle URL must start with https:// and end with .srt"
|
||||
return
|
||||
|
||||
if len(input_data.subtitle_name) > 150:
|
||||
yield "error", f"Subtitle name exceeds 150 character limit ({len(input_data.subtitle_name)} characters)"
|
||||
return
|
||||
|
||||
# Convert datetime to ISO format if provided
|
||||
iso_date = (
|
||||
input_data.schedule_date.isoformat() if input_data.schedule_date else None
|
||||
)
|
||||
|
||||
# Build X-specific options
|
||||
twitter_options = {}
|
||||
|
||||
# Basic options
|
||||
if input_data.reply_to_id:
|
||||
twitter_options["replyToId"] = input_data.reply_to_id
|
||||
|
||||
if input_data.quote_tweet_id:
|
||||
twitter_options["quoteTweetId"] = input_data.quote_tweet_id
|
||||
|
||||
if input_data.long_post:
|
||||
twitter_options["longPost"] = True
|
||||
|
||||
if input_data.long_video:
|
||||
twitter_options["longVideo"] = True
|
||||
|
||||
# Poll options
|
||||
if input_data.poll_options:
|
||||
twitter_options["poll"] = {
|
||||
"duration": input_data.poll_duration,
|
||||
"options": input_data.poll_options,
|
||||
}
|
||||
|
||||
# Alt text for images
|
||||
if input_data.alt_text:
|
||||
twitter_options["altText"] = input_data.alt_text
|
||||
|
||||
# Thread options
|
||||
if input_data.is_thread:
|
||||
twitter_options["thread"] = True
|
||||
|
||||
if input_data.thread_number:
|
||||
twitter_options["threadNumber"] = True
|
||||
|
||||
if input_data.thread_media_urls:
|
||||
twitter_options["mediaUrls"] = input_data.thread_media_urls
|
||||
|
||||
# Video subtitle options
|
||||
if input_data.subtitle_url:
|
||||
twitter_options["subTitleUrl"] = input_data.subtitle_url
|
||||
twitter_options["subTitleLanguage"] = input_data.subtitle_language
|
||||
twitter_options["subTitleName"] = input_data.subtitle_name
|
||||
|
||||
response = await client.create_post(
|
||||
post=input_data.post,
|
||||
platforms=[SocialPlatform.TWITTER],
|
||||
media_urls=input_data.media_urls,
|
||||
is_video=input_data.is_video,
|
||||
schedule_date=iso_date,
|
||||
disable_comments=input_data.disable_comments,
|
||||
shorten_links=input_data.shorten_links,
|
||||
unsplash=input_data.unsplash,
|
||||
requires_approval=input_data.requires_approval,
|
||||
random_post=input_data.random_post,
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
twitter_options=twitter_options if twitter_options else None,
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
for p in response.postIds:
|
||||
yield "post", p
|
||||
@@ -1,310 +0,0 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from backend.integrations.ayrshare import PostIds, PostResponse, SocialPlatform
|
||||
from backend.sdk import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._util import BaseAyrshareInput, create_ayrshare_client, get_profile_key
|
||||
|
||||
|
||||
class YouTubeVisibility(str, Enum):
|
||||
PRIVATE = "private"
|
||||
PUBLIC = "public"
|
||||
UNLISTED = "unlisted"
|
||||
|
||||
|
||||
class PostToYouTubeBlock(Block):
|
||||
"""Block for posting to YouTube with YouTube-specific options."""
|
||||
|
||||
class Input(BaseAyrshareInput):
|
||||
"""Input schema for YouTube posts."""
|
||||
|
||||
# Override post field to include YouTube-specific information
|
||||
post: str = SchemaField(
|
||||
description="Video description (max 5,000 chars, empty string allowed). Cannot contain < or > characters.",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# Override media_urls to include YouTube-specific constraints
|
||||
media_urls: list[str] = SchemaField(
|
||||
description="Required video URL. YouTube only supports 1 video per post.",
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# YouTube-specific required options
|
||||
title: str = SchemaField(
|
||||
description="Video title (max 100 chars, required). Cannot contain < or > characters.",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
# YouTube-specific optional options
|
||||
visibility: YouTubeVisibility = SchemaField(
|
||||
description="Video visibility: 'private' (default), 'public' , or 'unlisted'",
|
||||
default=YouTubeVisibility.PRIVATE,
|
||||
advanced=False,
|
||||
)
|
||||
thumbnail: str | None = SchemaField(
|
||||
description="Thumbnail URL (JPEG/PNG under 2MB, must end in .png/.jpg/.jpeg). Requires phone verification.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
playlist_id: str | None = SchemaField(
|
||||
description="Playlist ID to add video (user must own playlist)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
tags: list[str] | None = SchemaField(
|
||||
description="Video tags (min 2 chars each, max 500 chars total)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
made_for_kids: bool | None = SchemaField(
|
||||
description="Self-declared kids content", default=None, advanced=True
|
||||
)
|
||||
is_shorts: bool | None = SchemaField(
|
||||
description="Post as YouTube Short (max 3 minutes, adds #shorts)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
notify_subscribers: bool | None = SchemaField(
|
||||
description="Send notification to subscribers", default=None, advanced=True
|
||||
)
|
||||
category_id: int | None = SchemaField(
|
||||
description="Video category ID (e.g., 24 = Entertainment)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
contains_synthetic_media: bool | None = SchemaField(
|
||||
description="Disclose realistic AI/synthetic content",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
publish_at: str | None = SchemaField(
|
||||
description="UTC publish time (YouTube controlled, format: 2022-10-08T21:18:36Z)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
# YouTube targeting options (flattened from YouTubeTargeting object)
|
||||
targeting_block_countries: list[str] | None = SchemaField(
|
||||
description="Country codes to block from viewing (e.g., ['US', 'CA'])",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
targeting_allow_countries: list[str] | None = SchemaField(
|
||||
description="Country codes to allow viewing (e.g., ['GB', 'AU'])",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
subtitle_url: str | None = SchemaField(
|
||||
description="URL to SRT or SBV subtitle file (must be HTTPS and end in .srt/.sbv, under 100MB)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
subtitle_language: str | None = SchemaField(
|
||||
description="Language code for subtitles (default: 'en')",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
subtitle_name: str | None = SchemaField(
|
||||
description="Name of caption track (max 150 chars, default: 'English')",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
post_result: PostResponse = SchemaField(description="The result of the post")
|
||||
post: PostIds = SchemaField(description="The result of the post")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0082d712-ff1b-4c3d-8a8d-6c7721883b83",
|
||||
description="Post to YouTube using Ayrshare",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
block_type=BlockType.AYRSHARE,
|
||||
input_schema=PostToYouTubeBlock.Input,
|
||||
output_schema=PostToYouTubeBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: "PostToYouTubeBlock.Input",
|
||||
*,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
"""Post to YouTube with YouTube-specific validation and options."""
|
||||
|
||||
profile_key = await get_profile_key(user_id)
|
||||
if not profile_key:
|
||||
yield "error", "Please link a social account via Ayrshare"
|
||||
return
|
||||
|
||||
client = create_ayrshare_client()
|
||||
if not client:
|
||||
yield "error", "Ayrshare integration is not configured. Please set up the AYRSHARE_API_KEY."
|
||||
return
|
||||
|
||||
# Validate YouTube constraints
|
||||
if not input_data.title:
|
||||
yield "error", "YouTube requires a video title"
|
||||
return
|
||||
|
||||
if len(input_data.title) > 100:
|
||||
yield "error", f"YouTube title exceeds 100 character limit ({len(input_data.title)} characters)"
|
||||
return
|
||||
|
||||
if len(input_data.post) > 5000:
|
||||
yield "error", f"YouTube description exceeds 5,000 character limit ({len(input_data.post)} characters)"
|
||||
return
|
||||
|
||||
# Check for forbidden characters
|
||||
forbidden_chars = ["<", ">"]
|
||||
for char in forbidden_chars:
|
||||
if char in input_data.title:
|
||||
yield "error", f"YouTube title cannot contain '{char}' character"
|
||||
return
|
||||
if char in input_data.post:
|
||||
yield "error", f"YouTube description cannot contain '{char}' character"
|
||||
return
|
||||
|
||||
if not input_data.media_urls:
|
||||
yield "error", "YouTube requires exactly one video URL"
|
||||
return
|
||||
|
||||
if len(input_data.media_urls) > 1:
|
||||
yield "error", "YouTube supports only 1 video per post"
|
||||
return
|
||||
|
||||
# Validate visibility option
|
||||
valid_visibility = ["private", "public", "unlisted"]
|
||||
if input_data.visibility not in valid_visibility:
|
||||
yield "error", f"YouTube visibility must be one of: {', '.join(valid_visibility)}"
|
||||
return
|
||||
|
||||
# Validate thumbnail URL format
|
||||
if input_data.thumbnail:
|
||||
valid_extensions = [".png", ".jpg", ".jpeg"]
|
||||
if not any(
|
||||
input_data.thumbnail.lower().endswith(ext) for ext in valid_extensions
|
||||
):
|
||||
yield "error", "YouTube thumbnail must end in .png, .jpg, or .jpeg"
|
||||
return
|
||||
|
||||
# Validate tags
|
||||
if input_data.tags:
|
||||
total_tag_length = sum(len(tag) for tag in input_data.tags)
|
||||
if total_tag_length > 500:
|
||||
yield "error", f"YouTube tags total length exceeds 500 characters ({total_tag_length} characters)"
|
||||
return
|
||||
|
||||
for tag in input_data.tags:
|
||||
if len(tag) < 2:
|
||||
yield "error", f"YouTube tag '{tag}' is too short (minimum 2 characters)"
|
||||
return
|
||||
|
||||
# Validate subtitle URL
|
||||
if input_data.subtitle_url:
|
||||
if not input_data.subtitle_url.startswith("https://"):
|
||||
yield "error", "YouTube subtitle URL must start with https://"
|
||||
return
|
||||
|
||||
valid_subtitle_extensions = [".srt", ".sbv"]
|
||||
if not any(
|
||||
input_data.subtitle_url.lower().endswith(ext)
|
||||
for ext in valid_subtitle_extensions
|
||||
):
|
||||
yield "error", "YouTube subtitle URL must end in .srt or .sbv"
|
||||
return
|
||||
|
||||
if input_data.subtitle_name and len(input_data.subtitle_name) > 150:
|
||||
yield "error", f"YouTube subtitle name exceeds 150 character limit ({len(input_data.subtitle_name)} characters)"
|
||||
return
|
||||
|
||||
# Validate publish_at format if provided
|
||||
if input_data.publish_at and input_data.schedule_date:
|
||||
yield "error", "Cannot use both 'publish_at' and 'schedule_date'. Use 'publish_at' for YouTube-controlled publishing."
|
||||
return
|
||||
|
||||
# Convert datetime to ISO format if provided (only if not using publish_at)
|
||||
iso_date = None
|
||||
if not input_data.publish_at and input_data.schedule_date:
|
||||
iso_date = input_data.schedule_date.isoformat()
|
||||
|
||||
# Build YouTube-specific options
|
||||
youtube_options: dict[str, Any] = {"title": input_data.title}
|
||||
|
||||
# Basic options
|
||||
if input_data.visibility != "private":
|
||||
youtube_options["visibility"] = input_data.visibility
|
||||
|
||||
if input_data.thumbnail:
|
||||
youtube_options["thumbNail"] = input_data.thumbnail
|
||||
|
||||
if input_data.playlist_id:
|
||||
youtube_options["playListId"] = input_data.playlist_id
|
||||
|
||||
if input_data.tags:
|
||||
youtube_options["tags"] = input_data.tags
|
||||
|
||||
if input_data.made_for_kids:
|
||||
youtube_options["madeForKids"] = True
|
||||
|
||||
if input_data.is_shorts:
|
||||
youtube_options["shorts"] = True
|
||||
|
||||
if not input_data.notify_subscribers:
|
||||
youtube_options["notifySubscribers"] = False
|
||||
|
||||
if input_data.category_id and input_data.category_id > 0:
|
||||
youtube_options["categoryId"] = input_data.category_id
|
||||
|
||||
if input_data.contains_synthetic_media:
|
||||
youtube_options["containsSyntheticMedia"] = True
|
||||
|
||||
if input_data.publish_at:
|
||||
youtube_options["publishAt"] = input_data.publish_at
|
||||
|
||||
# Country targeting (from flattened fields)
|
||||
targeting_dict = {}
|
||||
if input_data.targeting_block_countries:
|
||||
targeting_dict["block"] = input_data.targeting_block_countries
|
||||
if input_data.targeting_allow_countries:
|
||||
targeting_dict["allow"] = input_data.targeting_allow_countries
|
||||
|
||||
if targeting_dict:
|
||||
youtube_options["targeting"] = targeting_dict
|
||||
|
||||
# Subtitle options
|
||||
if input_data.subtitle_url:
|
||||
youtube_options["subTitleUrl"] = input_data.subtitle_url
|
||||
youtube_options["subTitleLanguage"] = input_data.subtitle_language
|
||||
youtube_options["subTitleName"] = input_data.subtitle_name
|
||||
|
||||
response = await client.create_post(
|
||||
post=input_data.post,
|
||||
platforms=[SocialPlatform.YOUTUBE],
|
||||
media_urls=input_data.media_urls,
|
||||
is_video=True, # YouTube only supports videos
|
||||
schedule_date=iso_date,
|
||||
disable_comments=input_data.disable_comments,
|
||||
shorten_links=input_data.shorten_links,
|
||||
unsplash=input_data.unsplash,
|
||||
requires_approval=input_data.requires_approval,
|
||||
random_post=input_data.random_post,
|
||||
random_media_url=input_data.random_media_url,
|
||||
notes=input_data.notes,
|
||||
youtube_options=youtube_options,
|
||||
profile_key=profile_key.get_secret_value(),
|
||||
)
|
||||
yield "post_result", response
|
||||
if response.postIds:
|
||||
for p in response.postIds:
|
||||
yield "post", p
|
||||
@@ -1,26 +1,24 @@
|
||||
import enum
|
||||
from typing import Any
|
||||
from typing import Any, List
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import store_media_file
|
||||
from backend.util.type import MediaFileType, convert
|
||||
from backend.util.file import MediaFile, store_media_file
|
||||
from backend.util.mock import MockObject
|
||||
from backend.util.text import TextFormatter
|
||||
from backend.util.type import convert
|
||||
|
||||
formatter = TextFormatter()
|
||||
|
||||
|
||||
class FileStoreBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
file_in: MediaFileType = SchemaField(
|
||||
file_in: MediaFile = SchemaField(
|
||||
description="The file to store in the temporary directory, it can be a URL, data URI, or local path."
|
||||
)
|
||||
base_64: bool = SchemaField(
|
||||
description="Whether produce an output in base64 format (not recommended, you can pass the string path just fine accross blocks).",
|
||||
default=False,
|
||||
advanced=True,
|
||||
title="Produce Base64 Output",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
file_out: MediaFileType = SchemaField(
|
||||
file_out: MediaFile = SchemaField(
|
||||
description="The relative path to the stored file in the temporary directory."
|
||||
)
|
||||
|
||||
@@ -34,20 +32,19 @@ class FileStoreBlock(Block):
|
||||
static_output=True,
|
||||
)
|
||||
|
||||
async def run(
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
yield "file_out", await store_media_file(
|
||||
file_path = store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.file_in,
|
||||
user_id=user_id,
|
||||
return_content=input_data.base_64,
|
||||
return_content=False,
|
||||
)
|
||||
yield "file_out", file_path
|
||||
|
||||
|
||||
class StoreValueBlock(Block):
|
||||
@@ -89,16 +86,15 @@ class StoreValueBlock(Block):
|
||||
static_output=True,
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
yield "output", input_data.data or input_data.input
|
||||
|
||||
|
||||
class PrintToConsoleBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
text: Any = SchemaField(description="The data to print to the console.")
|
||||
text: str = SchemaField(description="The text to print to the console.")
|
||||
|
||||
class Output(BlockSchema):
|
||||
output: Any = SchemaField(description="The data printed to the console.")
|
||||
status: str = SchemaField(description="The status of the print operation.")
|
||||
|
||||
def __init__(self):
|
||||
@@ -109,15 +105,451 @@ class PrintToConsoleBlock(Block):
|
||||
input_schema=PrintToConsoleBlock.Input,
|
||||
output_schema=PrintToConsoleBlock.Output,
|
||||
test_input={"text": "Hello, World!"},
|
||||
test_output=("status", "printed"),
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
print(">>>>> Print: ", input_data.text)
|
||||
yield "status", "printed"
|
||||
|
||||
|
||||
class FindInDictionaryBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
input: Any = SchemaField(description="Dictionary to lookup from")
|
||||
key: str | int = SchemaField(description="Key to lookup in the dictionary")
|
||||
|
||||
class Output(BlockSchema):
|
||||
output: Any = SchemaField(description="Value found for the given key")
|
||||
missing: Any = SchemaField(
|
||||
description="Value of the input that missing the key"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0e50422c-6dee-4145-83d6-3a5a392f65de",
|
||||
description="Lookup the given key in the input dictionary/object/list and return the value.",
|
||||
input_schema=FindInDictionaryBlock.Input,
|
||||
output_schema=FindInDictionaryBlock.Output,
|
||||
test_input=[
|
||||
{"input": {"apple": 1, "banana": 2, "cherry": 3}, "key": "banana"},
|
||||
{"input": {"x": 10, "y": 20, "z": 30}, "key": "w"},
|
||||
{"input": [1, 2, 3], "key": 1},
|
||||
{"input": [1, 2, 3], "key": 3},
|
||||
{"input": MockObject(value="!!", key="key"), "key": "key"},
|
||||
{"input": [{"k1": "v1"}, {"k2": "v2"}, {"k1": "v3"}], "key": "k1"},
|
||||
],
|
||||
test_output=[
|
||||
("output", "Hello, World!"),
|
||||
("status", "printed"),
|
||||
("output", 2),
|
||||
("missing", {"x": 10, "y": 20, "z": 30}),
|
||||
("output", 2),
|
||||
("missing", [1, 2, 3]),
|
||||
("output", "key"),
|
||||
("output", ["v1", "v3"]),
|
||||
],
|
||||
categories={BlockCategory.BASIC},
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
obj = input_data.input
|
||||
key = input_data.key
|
||||
|
||||
if isinstance(obj, dict) and key in obj:
|
||||
yield "output", obj[key]
|
||||
elif isinstance(obj, list) and isinstance(key, int) and 0 <= key < len(obj):
|
||||
yield "output", obj[key]
|
||||
elif isinstance(obj, list) and isinstance(key, str):
|
||||
if len(obj) == 0:
|
||||
yield "output", []
|
||||
elif isinstance(obj[0], dict) and key in obj[0]:
|
||||
yield "output", [item[key] for item in obj if key in item]
|
||||
else:
|
||||
yield "output", [getattr(val, key) for val in obj if hasattr(val, key)]
|
||||
elif isinstance(obj, object) and isinstance(key, str) and hasattr(obj, key):
|
||||
yield "output", getattr(obj, key)
|
||||
else:
|
||||
yield "missing", input_data.input
|
||||
|
||||
|
||||
class AgentInputBlock(Block):
|
||||
"""
|
||||
This block is used to provide input to the graph.
|
||||
|
||||
It takes in a value, name, description, default values list and bool to limit selection to default values.
|
||||
|
||||
It Outputs the value passed as input.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
name: str = SchemaField(description="The name of the input.")
|
||||
value: Any = SchemaField(
|
||||
description="The value to be passed as input.",
|
||||
default=None,
|
||||
)
|
||||
title: str | None = SchemaField(
|
||||
description="The title of the input.", default=None, advanced=True
|
||||
)
|
||||
description: str | None = SchemaField(
|
||||
description="The description of the input.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
placeholder_values: List[Any] = SchemaField(
|
||||
description="The placeholder values to be passed as input.",
|
||||
default=[],
|
||||
advanced=True,
|
||||
)
|
||||
limit_to_placeholder_values: bool = SchemaField(
|
||||
description="Whether to limit the selection to placeholder values.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
advanced: bool = SchemaField(
|
||||
description="Whether to show the input in the advanced section, if the field is not required.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
secret: bool = SchemaField(
|
||||
description="Whether the input should be treated as a secret.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: Any = SchemaField(description="The value passed as input.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c0a8e994-ebf1-4a9c-a4d8-89d09c86741b",
|
||||
description="This block is used to provide input to the graph.",
|
||||
input_schema=AgentInputBlock.Input,
|
||||
output_schema=AgentInputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": "Hello, World!",
|
||||
"name": "input_1",
|
||||
"description": "This is a test input.",
|
||||
"placeholder_values": [],
|
||||
"limit_to_placeholder_values": False,
|
||||
},
|
||||
{
|
||||
"value": "Hello, World!",
|
||||
"name": "input_2",
|
||||
"description": "This is a test input.",
|
||||
"placeholder_values": ["Hello, World!"],
|
||||
"limit_to_placeholder_values": True,
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("result", "Hello, World!"),
|
||||
("result", "Hello, World!"),
|
||||
],
|
||||
categories={BlockCategory.INPUT, BlockCategory.BASIC},
|
||||
block_type=BlockType.INPUT,
|
||||
static_output=True,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
yield "result", input_data.value
|
||||
|
||||
|
||||
class AgentOutputBlock(Block):
|
||||
"""
|
||||
Records the output of the graph for users to see.
|
||||
|
||||
Behavior:
|
||||
If `format` is provided and the `value` is of a type that can be formatted,
|
||||
the block attempts to format the recorded_value using the `format`.
|
||||
If formatting fails or no `format` is provided, the raw `value` is output.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
value: Any = SchemaField(
|
||||
description="The value to be recorded as output.",
|
||||
default=None,
|
||||
advanced=False,
|
||||
)
|
||||
name: str = SchemaField(description="The name of the output.")
|
||||
title: str | None = SchemaField(
|
||||
description="The title of the output.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
description: str | None = SchemaField(
|
||||
description="The description of the output.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
format: str = SchemaField(
|
||||
description="The format string to be used to format the recorded_value. Use Jinja2 syntax.",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
advanced: bool = SchemaField(
|
||||
description="Whether to treat the output as advanced.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
secret: bool = SchemaField(
|
||||
description="Whether the output should be treated as a secret.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
output: Any = SchemaField(description="The value recorded as output.")
|
||||
name: Any = SchemaField(description="The name of the value recorded as output.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="363ae599-353e-4804-937e-b2ee3cef3da4",
|
||||
description="Stores the output of the graph for users to see.",
|
||||
input_schema=AgentOutputBlock.Input,
|
||||
output_schema=AgentOutputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": "Hello, World!",
|
||||
"name": "output_1",
|
||||
"description": "This is a test output.",
|
||||
"format": "{{ output_1 }}!!",
|
||||
},
|
||||
{
|
||||
"value": "42",
|
||||
"name": "output_2",
|
||||
"description": "This is another test output.",
|
||||
"format": "{{ output_2 }}",
|
||||
},
|
||||
{
|
||||
"value": MockObject(value="!!", key="key"),
|
||||
"name": "output_3",
|
||||
"description": "This is a test output with a mock object.",
|
||||
"format": "{{ output_3 }}",
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("output", "Hello, World!!!"),
|
||||
("output", "42"),
|
||||
("output", MockObject(value="!!", key="key")),
|
||||
],
|
||||
categories={BlockCategory.OUTPUT, BlockCategory.BASIC},
|
||||
block_type=BlockType.OUTPUT,
|
||||
static_output=True,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
"""
|
||||
Attempts to format the recorded_value using the fmt_string if provided.
|
||||
If formatting fails or no fmt_string is given, returns the original recorded_value.
|
||||
"""
|
||||
if input_data.format:
|
||||
try:
|
||||
yield "output", formatter.format_string(
|
||||
input_data.format, {input_data.name: input_data.value}
|
||||
)
|
||||
except Exception as e:
|
||||
yield "output", f"Error: {e}, {input_data.value}"
|
||||
else:
|
||||
yield "output", input_data.value
|
||||
yield "name", input_data.name
|
||||
|
||||
|
||||
class AddToDictionaryBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
dictionary: dict[Any, Any] = SchemaField(
|
||||
default={},
|
||||
description="The dictionary to add the entry to. If not provided, a new dictionary will be created.",
|
||||
)
|
||||
key: str = SchemaField(
|
||||
default="",
|
||||
description="The key for the new entry.",
|
||||
placeholder="new_key",
|
||||
advanced=False,
|
||||
)
|
||||
value: Any = SchemaField(
|
||||
default=None,
|
||||
description="The value for the new entry.",
|
||||
placeholder="new_value",
|
||||
advanced=False,
|
||||
)
|
||||
entries: dict[Any, Any] = SchemaField(
|
||||
default={},
|
||||
description="The entries to add to the dictionary. This is the batch version of the `key` and `value` fields.",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
updated_dictionary: dict = SchemaField(
|
||||
description="The dictionary with the new entry added."
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="31d1064e-7446-4693-a7d4-65e5ca1180d1",
|
||||
description="Adds a new key-value pair to a dictionary. If no dictionary is provided, a new one is created.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=AddToDictionaryBlock.Input,
|
||||
output_schema=AddToDictionaryBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"dictionary": {"existing_key": "existing_value"},
|
||||
"key": "new_key",
|
||||
"value": "new_value",
|
||||
},
|
||||
{"key": "first_key", "value": "first_value"},
|
||||
{
|
||||
"dictionary": {"existing_key": "existing_value"},
|
||||
"entries": {"new_key": "new_value", "first_key": "first_value"},
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
(
|
||||
"updated_dictionary",
|
||||
{"existing_key": "existing_value", "new_key": "new_value"},
|
||||
),
|
||||
("updated_dictionary", {"first_key": "first_value"}),
|
||||
(
|
||||
"updated_dictionary",
|
||||
{
|
||||
"existing_key": "existing_value",
|
||||
"new_key": "new_value",
|
||||
"first_key": "first_value",
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
yield "output", input_data.text
|
||||
yield "status", "printed"
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
updated_dict = input_data.dictionary.copy()
|
||||
|
||||
if input_data.value is not None and input_data.key:
|
||||
updated_dict[input_data.key] = input_data.value
|
||||
|
||||
for key, value in input_data.entries.items():
|
||||
updated_dict[key] = value
|
||||
|
||||
yield "updated_dictionary", updated_dict
|
||||
|
||||
|
||||
class AddToListBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
list: List[Any] = SchemaField(
|
||||
default=[],
|
||||
advanced=False,
|
||||
description="The list to add the entry to. If not provided, a new list will be created.",
|
||||
)
|
||||
entry: Any = SchemaField(
|
||||
description="The entry to add to the list. Can be of any type (string, int, dict, etc.).",
|
||||
advanced=False,
|
||||
default=None,
|
||||
)
|
||||
entries: List[Any] = SchemaField(
|
||||
default=[],
|
||||
description="The entries to add to the list. This is the batch version of the `entry` field.",
|
||||
advanced=True,
|
||||
)
|
||||
position: int | None = SchemaField(
|
||||
default=None,
|
||||
description="The position to insert the new entry. If not provided, the entry will be appended to the end of the list.",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
updated_list: List[Any] = SchemaField(
|
||||
description="The list with the new entry added."
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="aeb08fc1-2fc1-4141-bc8e-f758f183a822",
|
||||
description="Adds a new entry to a list. The entry can be of any type. If no list is provided, a new one is created.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=AddToListBlock.Input,
|
||||
output_schema=AddToListBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"list": [1, "string", {"existing_key": "existing_value"}],
|
||||
"entry": {"new_key": "new_value"},
|
||||
"position": 1,
|
||||
},
|
||||
{"entry": "first_entry"},
|
||||
{"list": ["a", "b", "c"], "entry": "d"},
|
||||
{
|
||||
"entry": "e",
|
||||
"entries": ["f", "g"],
|
||||
"list": ["a", "b"],
|
||||
"position": 1,
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
(
|
||||
"updated_list",
|
||||
[
|
||||
1,
|
||||
{"new_key": "new_value"},
|
||||
"string",
|
||||
{"existing_key": "existing_value"},
|
||||
],
|
||||
),
|
||||
("updated_list", ["first_entry"]),
|
||||
("updated_list", ["a", "b", "c", "d"]),
|
||||
("updated_list", ["a", "f", "g", "e", "b"]),
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
entries_added = input_data.entries.copy()
|
||||
if input_data.entry:
|
||||
entries_added.append(input_data.entry)
|
||||
|
||||
updated_list = input_data.list.copy()
|
||||
if (pos := input_data.position) is not None:
|
||||
updated_list = updated_list[:pos] + entries_added + updated_list[pos:]
|
||||
else:
|
||||
updated_list += entries_added
|
||||
|
||||
yield "updated_list", updated_list
|
||||
|
||||
|
||||
class FindInListBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
list: List[Any] = SchemaField(description="The list to search in.")
|
||||
value: Any = SchemaField(description="The value to search for.")
|
||||
|
||||
class Output(BlockSchema):
|
||||
index: int = SchemaField(description="The index of the value in the list.")
|
||||
found: bool = SchemaField(
|
||||
description="Whether the value was found in the list."
|
||||
)
|
||||
not_found_value: Any = SchemaField(
|
||||
description="The value that was not found in the list."
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="5e2c6d0a-1e37-489f-b1d0-8e1812b23333",
|
||||
description="Finds the index of the value in the list.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=FindInListBlock.Input,
|
||||
output_schema=FindInListBlock.Output,
|
||||
test_input=[
|
||||
{"list": [1, 2, 3, 4, 5], "value": 3},
|
||||
{"list": [1, 2, 3, 4, 5], "value": 6},
|
||||
],
|
||||
test_output=[
|
||||
("index", 2),
|
||||
("found", True),
|
||||
("found", False),
|
||||
("not_found_value", 6),
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
yield "index", input_data.list.index(input_data.value)
|
||||
yield "found", True
|
||||
except ValueError:
|
||||
yield "found", False
|
||||
yield "not_found_value", input_data.value
|
||||
|
||||
|
||||
class NoteBlock(Block):
|
||||
@@ -141,10 +573,108 @@ class NoteBlock(Block):
|
||||
block_type=BlockType.NOTE,
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
yield "output", input_data.text
|
||||
|
||||
|
||||
class CreateDictionaryBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
values: dict[str, Any] = SchemaField(
|
||||
description="Key-value pairs to create the dictionary with",
|
||||
placeholder="e.g., {'name': 'Alice', 'age': 25}",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
dictionary: dict[str, Any] = SchemaField(
|
||||
description="The created dictionary containing the specified key-value pairs"
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if dictionary creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b924ddf4-de4f-4b56-9a85-358930dcbc91",
|
||||
description="Creates a dictionary with the specified key-value pairs. Use this when you know all the values you want to add upfront.",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=CreateDictionaryBlock.Input,
|
||||
output_schema=CreateDictionaryBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"values": {"name": "Alice", "age": 25, "city": "New York"},
|
||||
},
|
||||
{
|
||||
"values": {"numbers": [1, 2, 3], "active": True, "score": 95.5},
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
(
|
||||
"dictionary",
|
||||
{"name": "Alice", "age": 25, "city": "New York"},
|
||||
),
|
||||
(
|
||||
"dictionary",
|
||||
{"numbers": [1, 2, 3], "active": True, "score": 95.5},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
# The values are already validated by Pydantic schema
|
||||
yield "dictionary", input_data.values
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to create dictionary: {str(e)}"
|
||||
|
||||
|
||||
class CreateListBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
values: List[Any] = SchemaField(
|
||||
description="A list of values to be combined into a new list.",
|
||||
placeholder="e.g., ['Alice', 25, True]",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
list: List[Any] = SchemaField(
|
||||
description="The created list containing the specified values."
|
||||
)
|
||||
error: str = SchemaField(description="Error message if list creation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a912d5c7-6e00-4542-b2a9-8034136930e4",
|
||||
description="Creates a list with the specified values. Use this when you know all the values you want to add upfront.",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=CreateListBlock.Input,
|
||||
output_schema=CreateListBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"values": ["Alice", 25, True],
|
||||
},
|
||||
{
|
||||
"values": [1, 2, 3, "four", {"key": "value"}],
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
(
|
||||
"list",
|
||||
["Alice", 25, True],
|
||||
),
|
||||
(
|
||||
"list",
|
||||
[1, 2, 3, "four", {"key": "value"}],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
# The values are already validated by Pydantic schema
|
||||
yield "list", input_data.values
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to create list: {str(e)}"
|
||||
|
||||
|
||||
class TypeOptions(enum.Enum):
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
@@ -162,7 +692,6 @@ class UniversalTypeConverterBlock(Block):
|
||||
|
||||
class Output(BlockSchema):
|
||||
value: Any = SchemaField(description="The converted value.")
|
||||
error: str = SchemaField(description="Error message if conversion failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -173,7 +702,7 @@ class UniversalTypeConverterBlock(Block):
|
||||
output_schema=UniversalTypeConverterBlock.Output,
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
converted_value = convert(
|
||||
input_data.value,
|
||||
@@ -188,31 +717,3 @@ class UniversalTypeConverterBlock(Block):
|
||||
yield "value", converted_value
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to convert value: {str(e)}"
|
||||
|
||||
|
||||
class ReverseListOrderBlock(Block):
|
||||
"""
|
||||
A block which takes in a list and returns it in the opposite order.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
input_list: list[Any] = SchemaField(description="The list to reverse")
|
||||
|
||||
class Output(BlockSchema):
|
||||
reversed_list: list[Any] = SchemaField(description="The list in reversed order")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="422cb708-3109-4277-bfe3-bc2ae5812777",
|
||||
description="Reverses the order of elements in a list",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=ReverseListOrderBlock.Input,
|
||||
output_schema=ReverseListOrderBlock.Output,
|
||||
test_input={"input_list": [1, 2, 3, 4, 5]},
|
||||
test_output=[("reversed_list", [5, 4, 3, 2, 1])],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
reversed_list = list(input_data.input_list)
|
||||
reversed_list.reverse()
|
||||
yield "reversed_list", reversed_list
|
||||
|
||||
@@ -38,7 +38,7 @@ class BlockInstallationBlock(Block):
|
||||
disabled=True,
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
code = input_data.code
|
||||
|
||||
if search := re.search(r"class (\w+)\(Block\):", code):
|
||||
@@ -64,7 +64,7 @@ class BlockInstallationBlock(Block):
|
||||
|
||||
from backend.util.test import execute_block_test
|
||||
|
||||
await execute_block_test(block)
|
||||
execute_block_test(block)
|
||||
yield "success", "Block installed successfully."
|
||||
except Exception as e:
|
||||
os.remove(file_path)
|
||||
|
||||
@@ -3,7 +3,6 @@ from typing import Any
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.type import convert
|
||||
|
||||
|
||||
class ComparisonOperator(Enum):
|
||||
@@ -71,7 +70,7 @@ class ConditionBlock(Block):
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
operator = input_data.operator
|
||||
|
||||
value1 = input_data.value1
|
||||
@@ -164,7 +163,7 @@ class IfInputMatchesBlock(Block):
|
||||
},
|
||||
{
|
||||
"input": 10,
|
||||
"value": "None",
|
||||
"value": None,
|
||||
"yes_value": "Yes",
|
||||
"no_value": "No",
|
||||
},
|
||||
@@ -181,24 +180,8 @@ class IfInputMatchesBlock(Block):
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
|
||||
# If input_data.value is not matching input_data.input, convert value to type of input
|
||||
if (
|
||||
input_data.input != input_data.value
|
||||
and input_data.input is not input_data.value
|
||||
):
|
||||
try:
|
||||
# Only attempt conversion if input is not None and value is not None
|
||||
if input_data.input is not None and input_data.value is not None:
|
||||
input_type = type(input_data.input)
|
||||
# Avoid converting if input_type is Any or object
|
||||
if input_type not in (Any, object):
|
||||
input_data.value = convert(input_data.value, input_type)
|
||||
except Exception:
|
||||
pass # If conversion fails, just leave value as is
|
||||
|
||||
if input_data.input == input_data.value:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
if input_data.input == input_data.value or input_data.input is input_data.value:
|
||||
yield "result", True
|
||||
yield "yes_output", input_data.yes_value
|
||||
else:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
from e2b_code_interpreter import AsyncSandbox
|
||||
from e2b_code_interpreter import Sandbox
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
@@ -55,7 +55,7 @@ class CodeExecutionBlock(Block):
|
||||
"These commands are executed with `sh`, in the foreground."
|
||||
),
|
||||
placeholder="pip install cowsay",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
@@ -123,7 +123,7 @@ class CodeExecutionBlock(Block):
|
||||
},
|
||||
)
|
||||
|
||||
async def execute_code(
|
||||
def execute_code(
|
||||
self,
|
||||
code: str,
|
||||
language: ProgrammingLanguage,
|
||||
@@ -135,21 +135,21 @@ class CodeExecutionBlock(Block):
|
||||
try:
|
||||
sandbox = None
|
||||
if template_id:
|
||||
sandbox = await AsyncSandbox.create(
|
||||
sandbox = Sandbox(
|
||||
template=template_id, api_key=api_key, timeout=timeout
|
||||
)
|
||||
else:
|
||||
sandbox = await AsyncSandbox.create(api_key=api_key, timeout=timeout)
|
||||
sandbox = Sandbox(api_key=api_key, timeout=timeout)
|
||||
|
||||
if not sandbox:
|
||||
raise Exception("Sandbox not created")
|
||||
|
||||
# Running setup commands
|
||||
for cmd in setup_commands:
|
||||
await sandbox.commands.run(cmd)
|
||||
sandbox.commands.run(cmd)
|
||||
|
||||
# Executing the code
|
||||
execution = await sandbox.run_code(
|
||||
execution = sandbox.run_code(
|
||||
code,
|
||||
language=language.value,
|
||||
on_error=lambda e: sandbox.kill(), # Kill the sandbox if there is an error
|
||||
@@ -167,11 +167,11 @@ class CodeExecutionBlock(Block):
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def run(
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
response, stdout_logs, stderr_logs = await self.execute_code(
|
||||
response, stdout_logs, stderr_logs = self.execute_code(
|
||||
input_data.code,
|
||||
input_data.language,
|
||||
input_data.setup_commands,
|
||||
@@ -207,7 +207,7 @@ class InstantiationBlock(Block):
|
||||
"These commands are executed with `sh`, in the foreground."
|
||||
),
|
||||
placeholder="pip install cowsay",
|
||||
default_factory=list,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
@@ -278,11 +278,11 @@ class InstantiationBlock(Block):
|
||||
},
|
||||
)
|
||||
|
||||
async def run(
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
sandbox_id, response, stdout_logs, stderr_logs = await self.execute_code(
|
||||
sandbox_id, response, stdout_logs, stderr_logs = self.execute_code(
|
||||
input_data.setup_code,
|
||||
input_data.language,
|
||||
input_data.setup_commands,
|
||||
@@ -303,7 +303,7 @@ class InstantiationBlock(Block):
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
async def execute_code(
|
||||
def execute_code(
|
||||
self,
|
||||
code: str,
|
||||
language: ProgrammingLanguage,
|
||||
@@ -315,21 +315,21 @@ class InstantiationBlock(Block):
|
||||
try:
|
||||
sandbox = None
|
||||
if template_id:
|
||||
sandbox = await AsyncSandbox.create(
|
||||
sandbox = Sandbox(
|
||||
template=template_id, api_key=api_key, timeout=timeout
|
||||
)
|
||||
else:
|
||||
sandbox = await AsyncSandbox.create(api_key=api_key, timeout=timeout)
|
||||
sandbox = Sandbox(api_key=api_key, timeout=timeout)
|
||||
|
||||
if not sandbox:
|
||||
raise Exception("Sandbox not created")
|
||||
|
||||
# Running setup commands
|
||||
for cmd in setup_commands:
|
||||
await sandbox.commands.run(cmd)
|
||||
sandbox.commands.run(cmd)
|
||||
|
||||
# Executing the code
|
||||
execution = await sandbox.run_code(
|
||||
execution = sandbox.run_code(
|
||||
code,
|
||||
language=language.value,
|
||||
on_error=lambda e: sandbox.kill(), # Kill the sandbox if there is an error
|
||||
@@ -409,7 +409,7 @@ class StepExecutionBlock(Block):
|
||||
},
|
||||
)
|
||||
|
||||
async def execute_step_code(
|
||||
def execute_step_code(
|
||||
self,
|
||||
sandbox_id: str,
|
||||
code: str,
|
||||
@@ -417,12 +417,12 @@ class StepExecutionBlock(Block):
|
||||
api_key: str,
|
||||
):
|
||||
try:
|
||||
sandbox = await AsyncSandbox.connect(sandbox_id=sandbox_id, api_key=api_key)
|
||||
sandbox = Sandbox.connect(sandbox_id=sandbox_id, api_key=api_key)
|
||||
if not sandbox:
|
||||
raise Exception("Sandbox not found")
|
||||
|
||||
# Executing the code
|
||||
execution = await sandbox.run_code(code, language=language.value)
|
||||
execution = sandbox.run_code(code, language=language.value)
|
||||
|
||||
if execution.error:
|
||||
raise Exception(execution.error)
|
||||
@@ -436,11 +436,11 @@ class StepExecutionBlock(Block):
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def run(
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
response, stdout_logs, stderr_logs = await self.execute_step_code(
|
||||
response, stdout_logs, stderr_logs = self.execute_step_code(
|
||||
input_data.sandbox_id,
|
||||
input_data.step_code,
|
||||
input_data.language,
|
||||
|
||||
@@ -49,7 +49,7 @@ class CodeExtractionBlock(Block):
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
# List of supported programming languages with mapped aliases
|
||||
language_aliases = {
|
||||
"html": ["html", "htm"],
|
||||
|
||||
@@ -8,7 +8,6 @@ from backend.data.block import (
|
||||
BlockSchema,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks.compass import CompassWebhookType
|
||||
|
||||
|
||||
@@ -43,7 +42,7 @@ class CompassAITriggerBlock(Block):
|
||||
input_schema=CompassAITriggerBlock.Input,
|
||||
output_schema=CompassAITriggerBlock.Output,
|
||||
webhook_config=BlockManualWebhookConfig(
|
||||
provider=ProviderName.COMPASS,
|
||||
provider="compass",
|
||||
webhook_type=CompassWebhookType.TRANSCRIPTION,
|
||||
),
|
||||
test_input=[
|
||||
@@ -56,5 +55,5 @@ class CompassAITriggerBlock(Block):
|
||||
# ],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
yield "transcription", input_data.payload.transcription
|
||||
|
||||
@@ -30,7 +30,7 @@ class WordCharacterCountBlock(Block):
|
||||
test_output=[("word_count", 4), ("character_count", 19)],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
text = input_data.text
|
||||
word_count = len(text.split())
|
||||
|
||||
109
autogpt_platform/backend/backend/blocks/csv.py
Normal file
109
autogpt_platform/backend/backend/blocks/csv.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import ContributorDetails, SchemaField
|
||||
|
||||
|
||||
class ReadCsvBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
contents: str = SchemaField(
|
||||
description="The contents of the CSV file to read",
|
||||
placeholder="a, b, c\n1,2,3\n4,5,6",
|
||||
)
|
||||
delimiter: str = SchemaField(
|
||||
description="The delimiter used in the CSV file",
|
||||
default=",",
|
||||
)
|
||||
quotechar: str = SchemaField(
|
||||
description="The character used to quote fields",
|
||||
default='"',
|
||||
)
|
||||
escapechar: str = SchemaField(
|
||||
description="The character used to escape the delimiter",
|
||||
default="\\",
|
||||
)
|
||||
has_header: bool = SchemaField(
|
||||
description="Whether the CSV file has a header row",
|
||||
default=True,
|
||||
)
|
||||
skip_rows: int = SchemaField(
|
||||
description="The number of rows to skip from the start of the file",
|
||||
default=0,
|
||||
)
|
||||
strip: bool = SchemaField(
|
||||
description="Whether to strip whitespace from the values",
|
||||
default=True,
|
||||
)
|
||||
skip_columns: list[str] = SchemaField(
|
||||
description="The columns to skip from the start of the row",
|
||||
default=[],
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
row: dict[str, str] = SchemaField(
|
||||
description="The data produced from each row in the CSV file"
|
||||
)
|
||||
all_data: list[dict[str, str]] = SchemaField(
|
||||
description="All the data in the CSV file as a list of rows"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="acf7625e-d2cb-4941-bfeb-2819fc6fc015",
|
||||
input_schema=ReadCsvBlock.Input,
|
||||
output_schema=ReadCsvBlock.Output,
|
||||
description="Reads a CSV file and outputs the data as a list of dictionaries and individual rows via rows.",
|
||||
contributors=[ContributorDetails(name="Nicholas Tindle")],
|
||||
categories={BlockCategory.TEXT, BlockCategory.DATA},
|
||||
test_input={
|
||||
"contents": "a, b, c\n1,2,3\n4,5,6",
|
||||
},
|
||||
test_output=[
|
||||
("row", {"a": "1", "b": "2", "c": "3"}),
|
||||
("row", {"a": "4", "b": "5", "c": "6"}),
|
||||
(
|
||||
"all_data",
|
||||
[
|
||||
{"a": "1", "b": "2", "c": "3"},
|
||||
{"a": "4", "b": "5", "c": "6"},
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
import csv
|
||||
from io import StringIO
|
||||
|
||||
csv_file = StringIO(input_data.contents)
|
||||
reader = csv.reader(
|
||||
csv_file,
|
||||
delimiter=input_data.delimiter,
|
||||
quotechar=input_data.quotechar,
|
||||
escapechar=input_data.escapechar,
|
||||
)
|
||||
|
||||
header = None
|
||||
if input_data.has_header:
|
||||
header = next(reader)
|
||||
if input_data.strip:
|
||||
header = [h.strip() for h in header]
|
||||
|
||||
for _ in range(input_data.skip_rows):
|
||||
next(reader)
|
||||
|
||||
def process_row(row):
|
||||
data = {}
|
||||
for i, value in enumerate(row):
|
||||
if i not in input_data.skip_columns:
|
||||
if input_data.has_header and header:
|
||||
data[header[i]] = value.strip() if input_data.strip else value
|
||||
else:
|
||||
data[str(i)] = value.strip() if input_data.strip else value
|
||||
return data
|
||||
|
||||
all_data = []
|
||||
for row in reader:
|
||||
processed_row = process_row(row)
|
||||
all_data.append(processed_row)
|
||||
yield "row", processed_row
|
||||
|
||||
yield "all_data", all_data
|
||||
@@ -1,683 +0,0 @@
|
||||
from typing import Any, List
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.json import loads
|
||||
from backend.util.mock import MockObject
|
||||
from backend.util.prompt import estimate_token_count_str
|
||||
|
||||
# =============================================================================
|
||||
# Dictionary Manipulation Blocks
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class CreateDictionaryBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
values: dict[str, Any] = SchemaField(
|
||||
description="Key-value pairs to create the dictionary with",
|
||||
placeholder="e.g., {'name': 'Alice', 'age': 25}",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
dictionary: dict[str, Any] = SchemaField(
|
||||
description="The created dictionary containing the specified key-value pairs"
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if dictionary creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b924ddf4-de4f-4b56-9a85-358930dcbc91",
|
||||
description="Creates a dictionary with the specified key-value pairs. Use this when you know all the values you want to add upfront.",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=CreateDictionaryBlock.Input,
|
||||
output_schema=CreateDictionaryBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"values": {"name": "Alice", "age": 25, "city": "New York"},
|
||||
},
|
||||
{
|
||||
"values": {"numbers": [1, 2, 3], "active": True, "score": 95.5},
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
(
|
||||
"dictionary",
|
||||
{"name": "Alice", "age": 25, "city": "New York"},
|
||||
),
|
||||
(
|
||||
"dictionary",
|
||||
{"numbers": [1, 2, 3], "active": True, "score": 95.5},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
# The values are already validated by Pydantic schema
|
||||
yield "dictionary", input_data.values
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to create dictionary: {str(e)}"
|
||||
|
||||
|
||||
class AddToDictionaryBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
dictionary: dict[Any, Any] = SchemaField(
|
||||
default_factory=dict,
|
||||
description="The dictionary to add the entry to. If not provided, a new dictionary will be created.",
|
||||
)
|
||||
key: str = SchemaField(
|
||||
default="",
|
||||
description="The key for the new entry.",
|
||||
placeholder="new_key",
|
||||
advanced=False,
|
||||
)
|
||||
value: Any = SchemaField(
|
||||
default=None,
|
||||
description="The value for the new entry.",
|
||||
placeholder="new_value",
|
||||
advanced=False,
|
||||
)
|
||||
entries: dict[Any, Any] = SchemaField(
|
||||
default_factory=dict,
|
||||
description="The entries to add to the dictionary. This is the batch version of the `key` and `value` fields.",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
updated_dictionary: dict = SchemaField(
|
||||
description="The dictionary with the new entry added."
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="31d1064e-7446-4693-a7d4-65e5ca1180d1",
|
||||
description="Adds a new key-value pair to a dictionary. If no dictionary is provided, a new one is created.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=AddToDictionaryBlock.Input,
|
||||
output_schema=AddToDictionaryBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"dictionary": {"existing_key": "existing_value"},
|
||||
"key": "new_key",
|
||||
"value": "new_value",
|
||||
},
|
||||
{"key": "first_key", "value": "first_value"},
|
||||
{
|
||||
"dictionary": {"existing_key": "existing_value"},
|
||||
"entries": {"new_key": "new_value", "first_key": "first_value"},
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
(
|
||||
"updated_dictionary",
|
||||
{"existing_key": "existing_value", "new_key": "new_value"},
|
||||
),
|
||||
("updated_dictionary", {"first_key": "first_value"}),
|
||||
(
|
||||
"updated_dictionary",
|
||||
{
|
||||
"existing_key": "existing_value",
|
||||
"new_key": "new_value",
|
||||
"first_key": "first_value",
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
updated_dict = input_data.dictionary.copy()
|
||||
|
||||
if input_data.value is not None and input_data.key:
|
||||
updated_dict[input_data.key] = input_data.value
|
||||
|
||||
for key, value in input_data.entries.items():
|
||||
updated_dict[key] = value
|
||||
|
||||
yield "updated_dictionary", updated_dict
|
||||
|
||||
|
||||
class FindInDictionaryBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
input: Any = SchemaField(description="Dictionary to lookup from")
|
||||
key: str | int = SchemaField(description="Key to lookup in the dictionary")
|
||||
|
||||
class Output(BlockSchema):
|
||||
output: Any = SchemaField(description="Value found for the given key")
|
||||
missing: Any = SchemaField(
|
||||
description="Value of the input that missing the key"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0e50422c-6dee-4145-83d6-3a5a392f65de",
|
||||
description="Lookup the given key in the input dictionary/object/list and return the value.",
|
||||
input_schema=FindInDictionaryBlock.Input,
|
||||
output_schema=FindInDictionaryBlock.Output,
|
||||
test_input=[
|
||||
{"input": {"apple": 1, "banana": 2, "cherry": 3}, "key": "banana"},
|
||||
{"input": {"x": 10, "y": 20, "z": 30}, "key": "w"},
|
||||
{"input": [1, 2, 3], "key": 1},
|
||||
{"input": [1, 2, 3], "key": 3},
|
||||
{"input": MockObject(value="!!", key="key"), "key": "key"},
|
||||
{"input": [{"k1": "v1"}, {"k2": "v2"}, {"k1": "v3"}], "key": "k1"},
|
||||
],
|
||||
test_output=[
|
||||
("output", 2),
|
||||
("missing", {"x": 10, "y": 20, "z": 30}),
|
||||
("output", 2),
|
||||
("missing", [1, 2, 3]),
|
||||
("output", "key"),
|
||||
("output", ["v1", "v3"]),
|
||||
],
|
||||
categories={BlockCategory.BASIC},
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
obj = input_data.input
|
||||
key = input_data.key
|
||||
|
||||
if isinstance(obj, str):
|
||||
obj = loads(obj)
|
||||
|
||||
if isinstance(obj, dict) and key in obj:
|
||||
yield "output", obj[key]
|
||||
elif isinstance(obj, list) and isinstance(key, int) and 0 <= key < len(obj):
|
||||
yield "output", obj[key]
|
||||
elif isinstance(obj, list) and isinstance(key, str):
|
||||
if len(obj) == 0:
|
||||
yield "output", []
|
||||
elif isinstance(obj[0], dict) and key in obj[0]:
|
||||
yield "output", [item[key] for item in obj if key in item]
|
||||
else:
|
||||
yield "output", [getattr(val, key) for val in obj if hasattr(val, key)]
|
||||
elif isinstance(obj, object) and isinstance(key, str) and hasattr(obj, key):
|
||||
yield "output", getattr(obj, key)
|
||||
else:
|
||||
yield "missing", input_data.input
|
||||
|
||||
|
||||
class RemoveFromDictionaryBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
dictionary: dict[Any, Any] = SchemaField(
|
||||
description="The dictionary to modify."
|
||||
)
|
||||
key: str | int = SchemaField(description="Key to remove from the dictionary.")
|
||||
return_value: bool = SchemaField(
|
||||
default=False, description="Whether to return the removed value."
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
updated_dictionary: dict[Any, Any] = SchemaField(
|
||||
description="The dictionary after removal."
|
||||
)
|
||||
removed_value: Any = SchemaField(description="The removed value if requested.")
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="46afe2ea-c613-43f8-95ff-6692c3ef6876",
|
||||
description="Removes a key-value pair from a dictionary.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=RemoveFromDictionaryBlock.Input,
|
||||
output_schema=RemoveFromDictionaryBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"dictionary": {"a": 1, "b": 2, "c": 3},
|
||||
"key": "b",
|
||||
"return_value": True,
|
||||
},
|
||||
{"dictionary": {"x": "hello", "y": "world"}, "key": "x"},
|
||||
],
|
||||
test_output=[
|
||||
("updated_dictionary", {"a": 1, "c": 3}),
|
||||
("removed_value", 2),
|
||||
("updated_dictionary", {"y": "world"}),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
updated_dict = input_data.dictionary.copy()
|
||||
try:
|
||||
removed_value = updated_dict.pop(input_data.key)
|
||||
yield "updated_dictionary", updated_dict
|
||||
if input_data.return_value:
|
||||
yield "removed_value", removed_value
|
||||
except KeyError:
|
||||
yield "error", f"Key '{input_data.key}' not found in dictionary"
|
||||
|
||||
|
||||
class ReplaceDictionaryValueBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
dictionary: dict[Any, Any] = SchemaField(
|
||||
description="The dictionary to modify."
|
||||
)
|
||||
key: str | int = SchemaField(description="Key to replace the value for.")
|
||||
value: Any = SchemaField(description="The new value for the given key.")
|
||||
|
||||
class Output(BlockSchema):
|
||||
updated_dictionary: dict[Any, Any] = SchemaField(
|
||||
description="The dictionary after replacement."
|
||||
)
|
||||
old_value: Any = SchemaField(description="The value that was replaced.")
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="27e31876-18b6-44f3-ab97-f6226d8b3889",
|
||||
description="Replaces the value for a specified key in a dictionary.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=ReplaceDictionaryValueBlock.Input,
|
||||
output_schema=ReplaceDictionaryValueBlock.Output,
|
||||
test_input=[
|
||||
{"dictionary": {"a": 1, "b": 2, "c": 3}, "key": "b", "value": 99},
|
||||
{
|
||||
"dictionary": {"x": "hello", "y": "world"},
|
||||
"key": "y",
|
||||
"value": "universe",
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("updated_dictionary", {"a": 1, "b": 99, "c": 3}),
|
||||
("old_value", 2),
|
||||
("updated_dictionary", {"x": "hello", "y": "universe"}),
|
||||
("old_value", "world"),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
updated_dict = input_data.dictionary.copy()
|
||||
try:
|
||||
old_value = updated_dict[input_data.key]
|
||||
updated_dict[input_data.key] = input_data.value
|
||||
yield "updated_dictionary", updated_dict
|
||||
yield "old_value", old_value
|
||||
except KeyError:
|
||||
yield "error", f"Key '{input_data.key}' not found in dictionary"
|
||||
|
||||
|
||||
class DictionaryIsEmptyBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
dictionary: dict[Any, Any] = SchemaField(description="The dictionary to check.")
|
||||
|
||||
class Output(BlockSchema):
|
||||
is_empty: bool = SchemaField(description="True if the dictionary is empty.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a3cf3f64-6bb9-4cc6-9900-608a0b3359b0",
|
||||
description="Checks if a dictionary is empty.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=DictionaryIsEmptyBlock.Input,
|
||||
output_schema=DictionaryIsEmptyBlock.Output,
|
||||
test_input=[{"dictionary": {}}, {"dictionary": {"a": 1}}],
|
||||
test_output=[("is_empty", True), ("is_empty", False)],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
yield "is_empty", len(input_data.dictionary) == 0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# List Manipulation Blocks
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class CreateListBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
values: List[Any] = SchemaField(
|
||||
description="A list of values to be combined into a new list.",
|
||||
placeholder="e.g., ['Alice', 25, True]",
|
||||
)
|
||||
max_size: int | None = SchemaField(
|
||||
default=None,
|
||||
description="Maximum size of the list. If provided, the list will be yielded in chunks of this size.",
|
||||
advanced=True,
|
||||
)
|
||||
max_tokens: int | None = SchemaField(
|
||||
default=None,
|
||||
description="Maximum tokens for the list. If provided, the list will be yielded in chunks that fit within this token limit.",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
list: List[Any] = SchemaField(
|
||||
description="The created list containing the specified values."
|
||||
)
|
||||
error: str = SchemaField(description="Error message if list creation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a912d5c7-6e00-4542-b2a9-8034136930e4",
|
||||
description="Creates a list with the specified values. Use this when you know all the values you want to add upfront. This block can also yield the list in batches based on a maximum size or token limit.",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=CreateListBlock.Input,
|
||||
output_schema=CreateListBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"values": ["Alice", 25, True],
|
||||
},
|
||||
{
|
||||
"values": [1, 2, 3, "four", {"key": "value"}],
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
(
|
||||
"list",
|
||||
["Alice", 25, True],
|
||||
),
|
||||
(
|
||||
"list",
|
||||
[1, 2, 3, "four", {"key": "value"}],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
chunk = []
|
||||
cur_tokens, max_tokens = 0, input_data.max_tokens
|
||||
cur_size, max_size = 0, input_data.max_size
|
||||
|
||||
for value in input_data.values:
|
||||
if max_tokens:
|
||||
tokens = estimate_token_count_str(value)
|
||||
else:
|
||||
tokens = 0
|
||||
|
||||
# Check if adding this value would exceed either limit
|
||||
if (max_tokens and (cur_tokens + tokens > max_tokens)) or (
|
||||
max_size and (cur_size + 1 > max_size)
|
||||
):
|
||||
yield "list", chunk
|
||||
chunk = [value]
|
||||
cur_size, cur_tokens = 1, tokens
|
||||
else:
|
||||
chunk.append(value)
|
||||
cur_size, cur_tokens = cur_size + 1, cur_tokens + tokens
|
||||
|
||||
# Yield final chunk if any
|
||||
if chunk or not input_data.values:
|
||||
yield "list", chunk
|
||||
|
||||
|
||||
class AddToListBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
list: List[Any] = SchemaField(
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
description="The list to add the entry to. If not provided, a new list will be created.",
|
||||
)
|
||||
entry: Any = SchemaField(
|
||||
description="The entry to add to the list. Can be of any type (string, int, dict, etc.).",
|
||||
advanced=False,
|
||||
default=None,
|
||||
)
|
||||
entries: List[Any] = SchemaField(
|
||||
default_factory=lambda: list(),
|
||||
description="The entries to add to the list. This is the batch version of the `entry` field.",
|
||||
advanced=True,
|
||||
)
|
||||
position: int | None = SchemaField(
|
||||
default=None,
|
||||
description="The position to insert the new entry. If not provided, the entry will be appended to the end of the list.",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
updated_list: List[Any] = SchemaField(
|
||||
description="The list with the new entry added."
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="aeb08fc1-2fc1-4141-bc8e-f758f183a822",
|
||||
description="Adds a new entry to a list. The entry can be of any type. If no list is provided, a new one is created.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=AddToListBlock.Input,
|
||||
output_schema=AddToListBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"list": [1, "string", {"existing_key": "existing_value"}],
|
||||
"entry": {"new_key": "new_value"},
|
||||
"position": 1,
|
||||
},
|
||||
{"entry": "first_entry"},
|
||||
{"list": ["a", "b", "c"], "entry": "d"},
|
||||
{
|
||||
"entry": "e",
|
||||
"entries": ["f", "g"],
|
||||
"list": ["a", "b"],
|
||||
"position": 1,
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
(
|
||||
"updated_list",
|
||||
[
|
||||
1,
|
||||
{"new_key": "new_value"},
|
||||
"string",
|
||||
{"existing_key": "existing_value"},
|
||||
],
|
||||
),
|
||||
("updated_list", ["first_entry"]),
|
||||
("updated_list", ["a", "b", "c", "d"]),
|
||||
("updated_list", ["a", "f", "g", "e", "b"]),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
entries_added = input_data.entries.copy()
|
||||
if input_data.entry:
|
||||
entries_added.append(input_data.entry)
|
||||
|
||||
updated_list = input_data.list.copy()
|
||||
if (pos := input_data.position) is not None:
|
||||
updated_list = updated_list[:pos] + entries_added + updated_list[pos:]
|
||||
else:
|
||||
updated_list += entries_added
|
||||
|
||||
yield "updated_list", updated_list
|
||||
|
||||
|
||||
class FindInListBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
list: List[Any] = SchemaField(description="The list to search in.")
|
||||
value: Any = SchemaField(description="The value to search for.")
|
||||
|
||||
class Output(BlockSchema):
|
||||
index: int = SchemaField(description="The index of the value in the list.")
|
||||
found: bool = SchemaField(
|
||||
description="Whether the value was found in the list."
|
||||
)
|
||||
not_found_value: Any = SchemaField(
|
||||
description="The value that was not found in the list."
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="5e2c6d0a-1e37-489f-b1d0-8e1812b23333",
|
||||
description="Finds the index of the value in the list.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=FindInListBlock.Input,
|
||||
output_schema=FindInListBlock.Output,
|
||||
test_input=[
|
||||
{"list": [1, 2, 3, 4, 5], "value": 3},
|
||||
{"list": [1, 2, 3, 4, 5], "value": 6},
|
||||
],
|
||||
test_output=[
|
||||
("index", 2),
|
||||
("found", True),
|
||||
("found", False),
|
||||
("not_found_value", 6),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
yield "index", input_data.list.index(input_data.value)
|
||||
yield "found", True
|
||||
except ValueError:
|
||||
yield "found", False
|
||||
yield "not_found_value", input_data.value
|
||||
|
||||
|
||||
class GetListItemBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
list: List[Any] = SchemaField(description="The list to get the item from.")
|
||||
index: int = SchemaField(
|
||||
description="The 0-based index of the item (supports negative indices)."
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
item: Any = SchemaField(description="The item at the specified index.")
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="262ca24c-1025-43cf-a578-534e23234e97",
|
||||
description="Returns the element at the given index.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=GetListItemBlock.Input,
|
||||
output_schema=GetListItemBlock.Output,
|
||||
test_input=[
|
||||
{"list": [1, 2, 3], "index": 1},
|
||||
{"list": [1, 2, 3], "index": -1},
|
||||
],
|
||||
test_output=[
|
||||
("item", 2),
|
||||
("item", 3),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
yield "item", input_data.list[input_data.index]
|
||||
except IndexError:
|
||||
yield "error", "Index out of range"
|
||||
|
||||
|
||||
class RemoveFromListBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
list: List[Any] = SchemaField(description="The list to modify.")
|
||||
value: Any = SchemaField(
|
||||
default=None, description="Value to remove from the list."
|
||||
)
|
||||
index: int | None = SchemaField(
|
||||
default=None,
|
||||
description="Index of the item to pop (supports negative indices).",
|
||||
)
|
||||
return_item: bool = SchemaField(
|
||||
default=False, description="Whether to return the removed item."
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
updated_list: List[Any] = SchemaField(description="The list after removal.")
|
||||
removed_item: Any = SchemaField(description="The removed item if requested.")
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d93c5a93-ac7e-41c1-ae5c-ef67e6e9b826",
|
||||
description="Removes an item from a list by value or index.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=RemoveFromListBlock.Input,
|
||||
output_schema=RemoveFromListBlock.Output,
|
||||
test_input=[
|
||||
{"list": [1, 2, 3], "index": 1, "return_item": True},
|
||||
{"list": ["a", "b", "c"], "value": "b"},
|
||||
],
|
||||
test_output=[
|
||||
("updated_list", [1, 3]),
|
||||
("removed_item", 2),
|
||||
("updated_list", ["a", "c"]),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
lst = input_data.list.copy()
|
||||
removed = None
|
||||
try:
|
||||
if input_data.index is not None:
|
||||
removed = lst.pop(input_data.index)
|
||||
elif input_data.value is not None:
|
||||
lst.remove(input_data.value)
|
||||
removed = input_data.value
|
||||
else:
|
||||
raise ValueError("No index or value provided for removal")
|
||||
except (IndexError, ValueError):
|
||||
yield "error", "Index or value not found"
|
||||
return
|
||||
|
||||
yield "updated_list", lst
|
||||
if input_data.return_item:
|
||||
yield "removed_item", removed
|
||||
|
||||
|
||||
class ReplaceListItemBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
list: List[Any] = SchemaField(description="The list to modify.")
|
||||
index: int = SchemaField(
|
||||
description="Index of the item to replace (supports negative indices)."
|
||||
)
|
||||
value: Any = SchemaField(description="The new value for the given index.")
|
||||
|
||||
class Output(BlockSchema):
|
||||
updated_list: List[Any] = SchemaField(description="The list after replacement.")
|
||||
old_item: Any = SchemaField(description="The item that was replaced.")
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="fbf62922-bea1-4a3d-8bac-23587f810b38",
|
||||
description="Replaces an item at the specified index.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=ReplaceListItemBlock.Input,
|
||||
output_schema=ReplaceListItemBlock.Output,
|
||||
test_input=[
|
||||
{"list": [1, 2, 3], "index": 1, "value": 99},
|
||||
{"list": ["a", "b"], "index": -1, "value": "c"},
|
||||
],
|
||||
test_output=[
|
||||
("updated_list", [1, 99, 3]),
|
||||
("old_item", 2),
|
||||
("updated_list", ["a", "c"]),
|
||||
("old_item", "b"),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
lst = input_data.list.copy()
|
||||
try:
|
||||
old = lst[input_data.index]
|
||||
lst[input_data.index] = input_data.value
|
||||
except IndexError:
|
||||
yield "error", "Index out of range"
|
||||
return
|
||||
|
||||
yield "updated_list", lst
|
||||
yield "old_item", old
|
||||
|
||||
|
||||
class ListIsEmptyBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
list: List[Any] = SchemaField(description="The list to check.")
|
||||
|
||||
class Output(BlockSchema):
|
||||
is_empty: bool = SchemaField(description="True if the list is empty.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="896ed73b-27d0-41be-813c-c1c1dc856c03",
|
||||
description="Checks if a list is empty.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=ListIsEmptyBlock.Input,
|
||||
output_schema=ListIsEmptyBlock.Output,
|
||||
test_input=[{"list": []}, {"list": [1]}],
|
||||
test_output=[("is_empty", True), ("is_empty", False)],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
yield "is_empty", len(input_data.list) == 0
|
||||
@@ -34,6 +34,6 @@ This is a "quoted" string.""",
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
decoded_text = codecs.decode(input_data.text, "unicode_escape")
|
||||
yield "decoded_text", decoded_text
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -121,7 +121,7 @@ class SendEmailBlock(Block):
|
||||
|
||||
return "Email sent successfully"
|
||||
|
||||
async def run(
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: SMTPCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
yield "status", self.send_email(
|
||||
|
||||
@@ -1,408 +0,0 @@
|
||||
"""
|
||||
API module for Enrichlayer integration.
|
||||
|
||||
This module provides a client for interacting with the Enrichlayer API,
|
||||
which allows fetching LinkedIn profile data and related information.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import enum
|
||||
import logging
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, Optional, TypeVar
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.model import APIKeyCredentials
|
||||
from backend.util.request import Requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class EnrichlayerAPIException(Exception):
|
||||
"""Exception raised for Enrichlayer API errors."""
|
||||
|
||||
def __init__(self, message: str, status_code: int):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class FallbackToCache(enum.Enum):
|
||||
ON_ERROR = "on-error"
|
||||
NEVER = "never"
|
||||
|
||||
|
||||
class UseCache(enum.Enum):
|
||||
IF_PRESENT = "if-present"
|
||||
NEVER = "never"
|
||||
|
||||
|
||||
class SocialMediaProfiles(BaseModel):
|
||||
"""Social media profiles model."""
|
||||
|
||||
twitter: Optional[str] = None
|
||||
facebook: Optional[str] = None
|
||||
github: Optional[str] = None
|
||||
|
||||
|
||||
class Experience(BaseModel):
|
||||
"""Experience model for LinkedIn profiles."""
|
||||
|
||||
company: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
location: Optional[str] = None
|
||||
starts_at: Optional[dict[str, int]] = None
|
||||
ends_at: Optional[dict[str, int]] = None
|
||||
company_linkedin_profile_url: Optional[str] = None
|
||||
|
||||
|
||||
class Education(BaseModel):
|
||||
"""Education model for LinkedIn profiles."""
|
||||
|
||||
school: Optional[str] = None
|
||||
degree_name: Optional[str] = None
|
||||
field_of_study: Optional[str] = None
|
||||
starts_at: Optional[dict[str, int]] = None
|
||||
ends_at: Optional[dict[str, int]] = None
|
||||
school_linkedin_profile_url: Optional[str] = None
|
||||
|
||||
|
||||
class PersonProfileResponse(BaseModel):
|
||||
"""Response model for LinkedIn person profile.
|
||||
|
||||
This model represents the response from Enrichlayer's LinkedIn profile API.
|
||||
The API returns comprehensive profile data including work experience,
|
||||
education, skills, and contact information (when available).
|
||||
|
||||
Example API Response:
|
||||
{
|
||||
"public_identifier": "johnsmith",
|
||||
"full_name": "John Smith",
|
||||
"occupation": "Software Engineer at Tech Corp",
|
||||
"experiences": [
|
||||
{
|
||||
"company": "Tech Corp",
|
||||
"title": "Software Engineer",
|
||||
"starts_at": {"year": 2020, "month": 1}
|
||||
}
|
||||
],
|
||||
"education": [...],
|
||||
"skills": ["Python", "JavaScript", ...]
|
||||
}
|
||||
"""
|
||||
|
||||
public_identifier: Optional[str] = None
|
||||
profile_pic_url: Optional[str] = None
|
||||
full_name: Optional[str] = None
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
occupation: Optional[str] = None
|
||||
headline: Optional[str] = None
|
||||
summary: Optional[str] = None
|
||||
country: Optional[str] = None
|
||||
country_full_name: Optional[str] = None
|
||||
city: Optional[str] = None
|
||||
state: Optional[str] = None
|
||||
experiences: Optional[list[Experience]] = None
|
||||
education: Optional[list[Education]] = None
|
||||
languages: Optional[list[str]] = None
|
||||
skills: Optional[list[str]] = None
|
||||
inferred_salary: Optional[dict[str, Any]] = None
|
||||
personal_email: Optional[str] = None
|
||||
personal_contact_number: Optional[str] = None
|
||||
social_media_profiles: Optional[SocialMediaProfiles] = None
|
||||
extra: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class SimilarProfile(BaseModel):
|
||||
"""Similar profile model for LinkedIn person lookup."""
|
||||
|
||||
similarity: float
|
||||
linkedin_profile_url: str
|
||||
|
||||
|
||||
class PersonLookupResponse(BaseModel):
|
||||
"""Response model for LinkedIn person lookup.
|
||||
|
||||
This model represents the response from Enrichlayer's person lookup API.
|
||||
The API returns a LinkedIn profile URL and similarity scores when
|
||||
searching for a person by name and company.
|
||||
|
||||
Example API Response:
|
||||
{
|
||||
"url": "https://www.linkedin.com/in/johnsmith/",
|
||||
"name_similarity_score": 0.95,
|
||||
"company_similarity_score": 0.88,
|
||||
"title_similarity_score": 0.75,
|
||||
"location_similarity_score": 0.60
|
||||
}
|
||||
"""
|
||||
|
||||
url: str | None = None
|
||||
name_similarity_score: float | None
|
||||
company_similarity_score: float | None
|
||||
title_similarity_score: float | None
|
||||
location_similarity_score: float | None
|
||||
last_updated: datetime.datetime | None = None
|
||||
profile: PersonProfileResponse | None = None
|
||||
|
||||
|
||||
class RoleLookupResponse(BaseModel):
|
||||
"""Response model for LinkedIn role lookup.
|
||||
|
||||
This model represents the response from Enrichlayer's role lookup API.
|
||||
The API returns LinkedIn profile data for a specific role at a company.
|
||||
|
||||
Example API Response:
|
||||
{
|
||||
"linkedin_profile_url": "https://www.linkedin.com/in/johnsmith/",
|
||||
"profile_data": {...} // Full PersonProfileResponse data when enrich_profile=True
|
||||
}
|
||||
"""
|
||||
|
||||
linkedin_profile_url: Optional[str] = None
|
||||
profile_data: Optional[PersonProfileResponse] = None
|
||||
|
||||
|
||||
class ProfilePictureResponse(BaseModel):
|
||||
"""Response model for LinkedIn profile picture.
|
||||
|
||||
This model represents the response from Enrichlayer's profile picture API.
|
||||
The API returns a URL to the person's LinkedIn profile picture.
|
||||
|
||||
Example API Response:
|
||||
{
|
||||
"tmp_profile_pic_url": "https://media.licdn.com/dms/image/..."
|
||||
}
|
||||
"""
|
||||
|
||||
tmp_profile_pic_url: str = Field(
|
||||
..., description="URL of the profile picture", alias="tmp_profile_pic_url"
|
||||
)
|
||||
|
||||
@property
|
||||
def profile_picture_url(self) -> str:
|
||||
"""Backward compatibility property for profile_picture_url."""
|
||||
return self.tmp_profile_pic_url
|
||||
|
||||
|
||||
class EnrichlayerClient:
|
||||
"""Client for interacting with the Enrichlayer API."""
|
||||
|
||||
API_BASE_URL = "https://enrichlayer.com/api/v2"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
credentials: Optional[APIKeyCredentials] = None,
|
||||
custom_requests: Optional[Requests] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the Enrichlayer client.
|
||||
|
||||
Args:
|
||||
credentials: The credentials to use for authentication.
|
||||
custom_requests: Custom Requests instance for testing.
|
||||
"""
|
||||
if custom_requests:
|
||||
self._requests = custom_requests
|
||||
else:
|
||||
headers: dict[str, str] = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if credentials:
|
||||
headers["Authorization"] = (
|
||||
f"Bearer {credentials.api_key.get_secret_value()}"
|
||||
)
|
||||
|
||||
self._requests = Requests(
|
||||
extra_headers=headers,
|
||||
raise_for_status=False,
|
||||
)
|
||||
|
||||
async def _handle_response(self, response) -> Any:
|
||||
"""
|
||||
Handle API response and check for errors.
|
||||
|
||||
Args:
|
||||
response: The response object from the request.
|
||||
|
||||
Returns:
|
||||
The response data.
|
||||
|
||||
Raises:
|
||||
EnrichlayerAPIException: If the API request fails.
|
||||
"""
|
||||
if not response.ok:
|
||||
try:
|
||||
error_data = response.json()
|
||||
error_message = error_data.get("message", "")
|
||||
except JSONDecodeError:
|
||||
error_message = response.text
|
||||
|
||||
raise EnrichlayerAPIException(
|
||||
f"Enrichlayer API request failed ({response.status_code}): {error_message}",
|
||||
response.status_code,
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
async def fetch_profile(
|
||||
self,
|
||||
linkedin_url: str,
|
||||
fallback_to_cache: FallbackToCache = FallbackToCache.ON_ERROR,
|
||||
use_cache: UseCache = UseCache.IF_PRESENT,
|
||||
include_skills: bool = False,
|
||||
include_inferred_salary: bool = False,
|
||||
include_personal_email: bool = False,
|
||||
include_personal_contact_number: bool = False,
|
||||
include_social_media: bool = False,
|
||||
include_extra: bool = False,
|
||||
) -> PersonProfileResponse:
|
||||
"""
|
||||
Fetch a LinkedIn profile with optional parameters.
|
||||
|
||||
Args:
|
||||
linkedin_url: The LinkedIn profile URL to fetch.
|
||||
fallback_to_cache: Cache usage if live fetch fails ('on-error' or 'never').
|
||||
use_cache: Cache utilization ('if-present' or 'never').
|
||||
include_skills: Whether to include skills data.
|
||||
include_inferred_salary: Whether to include inferred salary data.
|
||||
include_personal_email: Whether to include personal email.
|
||||
include_personal_contact_number: Whether to include personal contact number.
|
||||
include_social_media: Whether to include social media profiles.
|
||||
include_extra: Whether to include additional data.
|
||||
|
||||
Returns:
|
||||
The LinkedIn profile data.
|
||||
|
||||
Raises:
|
||||
EnrichlayerAPIException: If the API request fails.
|
||||
"""
|
||||
params = {
|
||||
"url": linkedin_url,
|
||||
"fallback_to_cache": fallback_to_cache.value.lower(),
|
||||
"use_cache": use_cache.value.lower(),
|
||||
}
|
||||
|
||||
if include_skills:
|
||||
params["skills"] = "include"
|
||||
if include_inferred_salary:
|
||||
params["inferred_salary"] = "include"
|
||||
if include_personal_email:
|
||||
params["personal_email"] = "include"
|
||||
if include_personal_contact_number:
|
||||
params["personal_contact_number"] = "include"
|
||||
if include_social_media:
|
||||
params["twitter_profile_id"] = "include"
|
||||
params["facebook_profile_id"] = "include"
|
||||
params["github_profile_id"] = "include"
|
||||
if include_extra:
|
||||
params["extra"] = "include"
|
||||
|
||||
response = await self._requests.get(
|
||||
f"{self.API_BASE_URL}/profile", params=params
|
||||
)
|
||||
return PersonProfileResponse(**await self._handle_response(response))
|
||||
|
||||
async def lookup_person(
|
||||
self,
|
||||
first_name: str,
|
||||
company_domain: str,
|
||||
last_name: str | None = None,
|
||||
location: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
include_similarity_checks: bool = False,
|
||||
enrich_profile: bool = False,
|
||||
) -> PersonLookupResponse:
|
||||
"""
|
||||
Look up a LinkedIn profile by person's information.
|
||||
|
||||
Args:
|
||||
first_name: The person's first name.
|
||||
last_name: The person's last name.
|
||||
company_domain: The domain of the company they work for.
|
||||
location: The person's location.
|
||||
title: The person's job title.
|
||||
include_similarity_checks: Whether to include similarity checks.
|
||||
enrich_profile: Whether to enrich the profile.
|
||||
|
||||
Returns:
|
||||
The LinkedIn profile lookup result.
|
||||
|
||||
Raises:
|
||||
EnrichlayerAPIException: If the API request fails.
|
||||
"""
|
||||
params = {"first_name": first_name, "company_domain": company_domain}
|
||||
|
||||
if last_name:
|
||||
params["last_name"] = last_name
|
||||
if location:
|
||||
params["location"] = location
|
||||
if title:
|
||||
params["title"] = title
|
||||
if include_similarity_checks:
|
||||
params["similarity_checks"] = "include"
|
||||
if enrich_profile:
|
||||
params["enrich_profile"] = "enrich"
|
||||
|
||||
response = await self._requests.get(
|
||||
f"{self.API_BASE_URL}/profile/resolve", params=params
|
||||
)
|
||||
return PersonLookupResponse(**await self._handle_response(response))
|
||||
|
||||
async def lookup_role(
|
||||
self, role: str, company_name: str, enrich_profile: bool = False
|
||||
) -> RoleLookupResponse:
|
||||
"""
|
||||
Look up a LinkedIn profile by role in a company.
|
||||
|
||||
Args:
|
||||
role: The role title (e.g., CEO, CTO).
|
||||
company_name: The name of the company.
|
||||
enrich_profile: Whether to enrich the profile.
|
||||
|
||||
Returns:
|
||||
The LinkedIn profile lookup result.
|
||||
|
||||
Raises:
|
||||
EnrichlayerAPIException: If the API request fails.
|
||||
"""
|
||||
params = {
|
||||
"role": role,
|
||||
"company_name": company_name,
|
||||
}
|
||||
|
||||
if enrich_profile:
|
||||
params["enrich_profile"] = "enrich"
|
||||
|
||||
response = await self._requests.get(
|
||||
f"{self.API_BASE_URL}/find/company/role", params=params
|
||||
)
|
||||
return RoleLookupResponse(**await self._handle_response(response))
|
||||
|
||||
async def get_profile_picture(
|
||||
self, linkedin_profile_url: str
|
||||
) -> ProfilePictureResponse:
|
||||
"""
|
||||
Get a LinkedIn profile picture URL.
|
||||
|
||||
Args:
|
||||
linkedin_profile_url: The LinkedIn profile URL.
|
||||
|
||||
Returns:
|
||||
The profile picture URL.
|
||||
|
||||
Raises:
|
||||
EnrichlayerAPIException: If the API request fails.
|
||||
"""
|
||||
params = {
|
||||
"linkedin_person_profile_url": linkedin_profile_url,
|
||||
}
|
||||
|
||||
response = await self._requests.get(
|
||||
f"{self.API_BASE_URL}/person/profile-picture", params=params
|
||||
)
|
||||
return ProfilePictureResponse(**await self._handle_response(response))
|
||||
@@ -1,34 +0,0 @@
|
||||
"""
|
||||
Authentication module for Enrichlayer API integration.
|
||||
|
||||
This module provides credential types and test credentials for the Enrichlayer API.
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import APIKeyCredentials, CredentialsMetaInput
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
# Define the type of credentials input expected for Enrichlayer API
|
||||
EnrichlayerCredentialsInput = CredentialsMetaInput[
|
||||
Literal[ProviderName.ENRICHLAYER], Literal["api_key"]
|
||||
]
|
||||
|
||||
# Mock credentials for testing Enrichlayer API integration
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="1234a567-89bc-4def-ab12-3456cdef7890",
|
||||
provider="enrichlayer",
|
||||
api_key=SecretStr("mock-enrichlayer-api-key"),
|
||||
title="Mock Enrichlayer API key",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
# Dictionary representation of test credentials for input fields
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user