Merge branch 'dev' into swiftyos/secrt-887-financial-advisor-agent
13 .github/PULL_REQUEST_TEMPLATE.md vendored
@@ -21,3 +21,16 @@ Here is a list of our critical paths, if you need some inspiration on what and h
- Upload agent to marketplace
- Import an agent from marketplace and confirm it executes correctly
- Edit an agent from monitor, and confirm it executes correctly

### Configuration Changes 📝
> [!NOTE]
> Only for the new autogpt platform, currently in autogpt_platform/

If you're making configuration or infrastructure changes, please remember to check you've updated the related infrastructure code in the autogpt_platform/infra folder.

Examples of such changes might include:

- Changing ports
- Adding new services that need to communicate with each other
- Secrets or environment variable changes
- New or infrastructure changes such as databases
179 .github/dependabot.yml vendored Normal file
@@ -0,0 +1,179 @@
|
||||
version: 2
|
||||
updates:
|
||||
# autogpt_libs (Poetry project)
|
||||
- package-ecosystem: "pip"
|
||||
directory: "autogpt_platform/autogpt_libs"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
open-pull-requests-limit: 10
|
||||
target-branch: "dev"
|
||||
groups:
|
||||
production-dependencies:
|
||||
dependency-type: "production"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
development-dependencies:
|
||||
dependency-type: "development"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
|
||||
# backend (Poetry project)
|
||||
- package-ecosystem: "pip"
|
||||
directory: "autogpt_platform/backend"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
open-pull-requests-limit: 10
|
||||
target-branch: "dev"
|
||||
groups:
|
||||
production-dependencies:
|
||||
dependency-type: "production"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
development-dependencies:
|
||||
dependency-type: "development"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
|
||||
|
||||
# frontend (Next.js project)
|
||||
- package-ecosystem: "npm"
|
||||
directory: "autogpt_platform/frontend"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
open-pull-requests-limit: 10
|
||||
target-branch: "dev"
|
||||
groups:
|
||||
production-dependencies:
|
||||
dependency-type: "production"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
development-dependencies:
|
||||
dependency-type: "development"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
|
||||
|
||||
# infra (Terraform)
|
||||
- package-ecosystem: "terraform"
|
||||
directory: "autogpt_platform/infra"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
open-pull-requests-limit: 5
|
||||
target-branch: "dev"
|
||||
groups:
|
||||
production-dependencies:
|
||||
dependency-type: "production"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
development-dependencies:
|
||||
dependency-type: "development"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
|
||||
|
||||
# market (Poetry project)
|
||||
- package-ecosystem: "pip"
|
||||
directory: "autogpt_platform/market"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
open-pull-requests-limit: 10
|
||||
target-branch: "dev"
|
||||
groups:
|
||||
production-dependencies:
|
||||
dependency-type: "production"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
development-dependencies:
|
||||
dependency-type: "development"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
|
||||
|
||||
# GitHub Actions
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
open-pull-requests-limit: 5
|
||||
target-branch: "dev"
|
||||
groups:
|
||||
production-dependencies:
|
||||
dependency-type: "production"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
development-dependencies:
|
||||
dependency-type: "development"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
|
||||
|
||||
# Docker
|
||||
- package-ecosystem: "docker"
|
||||
directory: "autogpt_platform/"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
open-pull-requests-limit: 5
|
||||
target-branch: "dev"
|
||||
groups:
|
||||
production-dependencies:
|
||||
dependency-type: "production"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
development-dependencies:
|
||||
dependency-type: "development"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
|
||||
|
||||
# Submodules
|
||||
- package-ecosystem: "gitsubmodule"
|
||||
directory: "autogpt_platform/supabase"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
open-pull-requests-limit: 1
|
||||
target-branch: "dev"
|
||||
groups:
|
||||
production-dependencies:
|
||||
dependency-type: "production"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
development-dependencies:
|
||||
dependency-type: "development"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
|
||||
|
||||
# Docs
|
||||
- package-ecosystem: 'pip'
|
||||
directory: "docs/"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
open-pull-requests-limit: 1
|
||||
target-branch: "dev"
|
||||
groups:
|
||||
production-dependencies:
|
||||
dependency-type: "production"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
development-dependencies:
|
||||
dependency-type: "development"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
5 .github/labeler.yml vendored
@@ -25,3 +25,8 @@ platform/frontend:
platform/backend:
  - changed-files:
      - any-glob-to-any-file: autogpt_platform/backend/**
      - all-globs-to-all-files: '!autogpt_platform/backend/backend/blocks/**'

platform/blocks:
  - changed-files:
      - any-glob-to-any-file: autogpt_platform/backend/backend/blocks/**
4 .github/workflows/classic-autogpt-ci.yml vendored
@@ -2,12 +2,12 @@ name: Classic - AutoGPT CI

on:
  push:
    branches: [ master, development, ci-test* ]
    branches: [ master, dev, ci-test* ]
    paths:
      - '.github/workflows/classic-autogpt-ci.yml'
      - 'classic/original_autogpt/**'
  pull_request:
    branches: [ master, development, release-* ]
    branches: [ master, dev, release-* ]
    paths:
      - '.github/workflows/classic-autogpt-ci.yml'
      - 'classic/original_autogpt/**'

.github/workflows/classic-autogpt-docker-ci.yml vendored
@@ -8,7 +8,7 @@ on:
      - 'classic/original_autogpt/**'
      - 'classic/forge/**'
  pull_request:
    branches: [ master, development, release-* ]
    branches: [ master, dev, release-* ]
    paths:
      - '.github/workflows/classic-autogpt-docker-ci.yml'
      - 'classic/original_autogpt/**'
4 .github/workflows/classic-autogpts-ci.yml vendored
@@ -5,7 +5,7 @@ on:
  schedule:
    - cron: '0 8 * * *'
  push:
    branches: [ master, development, ci-test* ]
    branches: [ master, dev, ci-test* ]
    paths:
      - '.github/workflows/classic-autogpts-ci.yml'
      - 'classic/original_autogpt/**'
@@ -16,7 +16,7 @@ on:
      - 'classic/setup.py'
      - '!**/*.md'
  pull_request:
    branches: [ master, development, release-* ]
    branches: [ master, dev, release-* ]
    paths:
      - '.github/workflows/classic-autogpts-ci.yml'
      - 'classic/original_autogpt/**'
4 .github/workflows/classic-benchmark-ci.yml vendored
@@ -2,13 +2,13 @@ name: Classic - AGBenchmark CI

on:
  push:
    branches: [ master, development, ci-test* ]
    branches: [ master, dev, ci-test* ]
    paths:
      - 'classic/benchmark/**'
      - '!classic/benchmark/reports/**'
      - .github/workflows/classic-benchmark-ci.yml
  pull_request:
    branches: [ master, development, release-* ]
    branches: [ master, dev, release-* ]
    paths:
      - 'classic/benchmark/**'
      - '!classic/benchmark/reports/**'
4 .github/workflows/classic-forge-ci.yml vendored
@@ -2,13 +2,13 @@ name: Classic - Forge CI

on:
  push:
    branches: [ master, development, ci-test* ]
    branches: [ master, dev, ci-test* ]
    paths:
      - '.github/workflows/classic-forge-ci.yml'
      - 'classic/forge/**'
      - '!classic/forge/tests/vcr_cassettes'
  pull_request:
    branches: [ master, development, release-* ]
    branches: [ master, dev, release-* ]
    paths:
      - '.github/workflows/classic-forge-ci.yml'
      - 'classic/forge/**'
2 .github/workflows/classic-frontend-ci.yml vendored
@@ -49,7 +49,7 @@ jobs:

      - name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
        if: github.event_name == 'push'
        uses: peter-evans/create-pull-request@v6
        uses: peter-evans/create-pull-request@v7
        with:
          add-paths: classic/frontend/build/web
          base: ${{ github.ref_name }}
4 .github/workflows/classic-python-checks.yml vendored
@@ -2,7 +2,7 @@ name: Classic - Python checks

on:
  push:
    branches: [ master, development, ci-test* ]
    branches: [ master, dev, ci-test* ]
    paths:
      - '.github/workflows/classic-python-checks-ci.yml'
      - 'classic/original_autogpt/**'
@@ -11,7 +11,7 @@ on:
      - '**.py'
      - '!classic/forge/tests/vcr_cassettes'
  pull_request:
    branches: [ master, development, release-* ]
    branches: [ master, dev, release-* ]
    paths:
      - '.github/workflows/classic-python-checks-ci.yml'
      - 'classic/original_autogpt/**'
182 .github/workflows/platform-autgpt-deploy-prod.yml vendored Normal file
@@ -0,0 +1,182 @@
|
||||
name: AutoGPT Platform - Build, Push, and Deploy Prod Environment
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
permissions:
|
||||
contents: 'read'
|
||||
id-token: 'write'
|
||||
|
||||
env:
|
||||
PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }}
|
||||
GKE_CLUSTER: prod-gke-cluster
|
||||
GKE_ZONE: us-central1-a
|
||||
NAMESPACE: prod-agpt
|
||||
|
||||
jobs:
|
||||
migrate:
|
||||
environment: production
|
||||
name: Run migrations for AutoGPT Platform
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install prisma
|
||||
|
||||
- name: Run Backend Migrations
|
||||
working-directory: ./autogpt_platform/backend
|
||||
run: |
|
||||
python -m prisma migrate deploy
|
||||
env:
|
||||
DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
|
||||
|
||||
- name: Run Market Migrations
|
||||
working-directory: ./autogpt_platform/market
|
||||
run: |
|
||||
python -m prisma migrate deploy
|
||||
env:
|
||||
DATABASE_URL: ${{ secrets.MARKET_DATABASE_URL }}
|
||||
|
||||
build-push-deploy:
|
||||
environment: production
|
||||
name: Build, Push, and Deploy
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- id: 'auth'
|
||||
uses: 'google-github-actions/auth@v2'
|
||||
with:
|
||||
workload_identity_provider: 'projects/1021527134101/locations/global/workloadIdentityPools/prod-pool/providers/github'
|
||||
service_account: 'prod-github-actions-sa@agpt-prod.iam.gserviceaccount.com'
|
||||
token_format: 'access_token'
|
||||
create_credentials_file: true
|
||||
|
||||
- name: 'Set up Cloud SDK'
|
||||
uses: 'google-github-actions/setup-gcloud@v2'
|
||||
|
||||
- name: 'Configure Docker'
|
||||
run: |
|
||||
gcloud auth configure-docker us-east1-docker.pkg.dev
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Cache Docker layers
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: /tmp/.buildx-cache
|
||||
key: ${{ runner.os }}-buildx-${{ github.sha }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-buildx-
|
||||
|
||||
- name: Check for changes
|
||||
id: check_changes
|
||||
run: |
|
||||
git fetch origin master
|
||||
BACKEND_CHANGED=$(git diff --name-only origin/master HEAD | grep "^autogpt_platform/backend/" && echo "true" || echo "false")
|
||||
FRONTEND_CHANGED=$(git diff --name-only origin/master HEAD | grep "^autogpt_platform/frontend/" && echo "true" || echo "false")
|
||||
MARKET_CHANGED=$(git diff --name-only origin/master HEAD | grep "^autogpt_platform/market/" && echo "true" || echo "false")
|
||||
echo "backend_changed=$BACKEND_CHANGED" >> $GITHUB_OUTPUT
|
||||
echo "frontend_changed=$FRONTEND_CHANGED" >> $GITHUB_OUTPUT
|
||||
echo "market_changed=$MARKET_CHANGED" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Get GKE credentials
|
||||
uses: 'google-github-actions/get-gke-credentials@v2'
|
||||
with:
|
||||
cluster_name: ${{ env.GKE_CLUSTER }}
|
||||
location: ${{ env.GKE_ZONE }}
|
||||
|
||||
- name: Build and Push Backend
|
||||
if: steps.check_changes.outputs.backend_changed == 'true'
|
||||
uses: docker/build-push-action@v2
|
||||
with:
|
||||
context: .
|
||||
file: ./autogpt_platform/backend/Dockerfile
|
||||
push: true
|
||||
tags: us-east1-docker.pkg.dev/agpt-prod/agpt-backend-prod/agpt-backend-prod:${{ github.sha }}
|
||||
cache-from: type=local,src=/tmp/.buildx-cache
|
||||
cache-to: type=local,dest=/tmp/.buildx-cache-new,mode=max
|
||||
|
||||
- name: Build and Push Frontend
|
||||
if: steps.check_changes.outputs.frontend_changed == 'true'
|
||||
uses: docker/build-push-action@v2
|
||||
with:
|
||||
context: .
|
||||
file: ./autogpt_platform/frontend/Dockerfile
|
||||
push: true
|
||||
tags: us-east1-docker.pkg.dev/agpt-prod/agpt-frontend-prod/agpt-frontend-prod:${{ github.sha }}
|
||||
cache-from: type=local,src=/tmp/.buildx-cache
|
||||
cache-to: type=local,dest=/tmp/.buildx-cache-new,mode=max
|
||||
|
||||
- name: Build and Push Market
|
||||
if: steps.check_changes.outputs.market_changed == 'true'
|
||||
uses: docker/build-push-action@v2
|
||||
with:
|
||||
context: .
|
||||
file: ./autogpt_platform/market/Dockerfile
|
||||
push: true
|
||||
tags: us-east1-docker.pkg.dev/agpt-prod/agpt-market-prod/agpt-market-prod:${{ github.sha }}
|
||||
cache-from: type=local,src=/tmp/.buildx-cache
|
||||
cache-to: type=local,dest=/tmp/.buildx-cache-new,mode=max
|
||||
|
||||
- name: Move cache
|
||||
run: |
|
||||
rm -rf /tmp/.buildx-cache
|
||||
mv /tmp/.buildx-cache-new /tmp/.buildx-cache
|
||||
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@v4
|
||||
with:
|
||||
version: v3.4.0
|
||||
|
||||
- name: Deploy Backend
|
||||
if: steps.check_changes.outputs.backend_changed == 'true'
|
||||
run: |
|
||||
helm upgrade autogpt-server ./autogpt-server \
|
||||
--namespace ${{ env.NAMESPACE }} \
|
||||
-f autogpt-server/values.yaml \
|
||||
-f autogpt-server/values.prod.yaml \
|
||||
--set image.tag=${{ github.sha }}
|
||||
|
||||
- name: Deploy Websocket
|
||||
if: steps.check_changes.outputs.backend_changed == 'true'
|
||||
run: |
|
||||
helm upgrade autogpt-websocket-server ./autogpt-websocket-server \
|
||||
--namespace ${{ env.NAMESPACE }} \
|
||||
-f autogpt-websocket-server/values.yaml \
|
||||
-f autogpt-websocket-server/values.prod.yaml \
|
||||
--set image.tag=${{ github.sha }}
|
||||
|
||||
- name: Deploy Market
|
||||
if: steps.check_changes.outputs.market_changed == 'true'
|
||||
run: |
|
||||
helm upgrade autogpt-market ./autogpt-market \
|
||||
--namespace ${{ env.NAMESPACE }} \
|
||||
-f autogpt-market/values.yaml \
|
||||
-f autogpt-market/values.prod.yaml \
|
||||
--set image.tag=${{ github.sha }}
|
||||
|
||||
- name: Deploy Frontend
|
||||
if: steps.check_changes.outputs.frontend_changed == 'true'
|
||||
run: |
|
||||
helm upgrade autogpt-builder ./autogpt-builder \
|
||||
--namespace ${{ env.NAMESPACE }} \
|
||||
-f autogpt-builder/values.yaml \
|
||||
-f autogpt-builder/values.prod.yaml \
|
||||
--set image.tag=${{ github.sha }}
|
||||
186 .github/workflows/platform-autogpt-deploy.yaml vendored Normal file
@@ -0,0 +1,186 @@
|
||||
name: AutoGPT Platform - Build, Push, and Deploy Dev Environment
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ dev ]
|
||||
paths:
|
||||
- 'autogpt_platform/backend/**'
|
||||
- 'autogpt_platform/frontend/**'
|
||||
- 'autogpt_platform/market/**'
|
||||
|
||||
permissions:
|
||||
contents: 'read'
|
||||
id-token: 'write'
|
||||
|
||||
env:
|
||||
PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }}
|
||||
GKE_CLUSTER: dev-gke-cluster
|
||||
GKE_ZONE: us-central1-a
|
||||
NAMESPACE: dev-agpt
|
||||
|
||||
jobs:
|
||||
migrate:
|
||||
environment: develop
|
||||
name: Run migrations for AutoGPT Platform
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install prisma
|
||||
|
||||
- name: Run Backend Migrations
|
||||
working-directory: ./autogpt_platform/backend
|
||||
run: |
|
||||
python -m prisma migrate deploy
|
||||
env:
|
||||
DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
|
||||
|
||||
- name: Run Market Migrations
|
||||
working-directory: ./autogpt_platform/market
|
||||
run: |
|
||||
python -m prisma migrate deploy
|
||||
env:
|
||||
DATABASE_URL: ${{ secrets.MARKET_DATABASE_URL }}
|
||||
|
||||
build-push-deploy:
|
||||
name: Build, Push, and Deploy
|
||||
needs: migrate
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- id: 'auth'
|
||||
uses: 'google-github-actions/auth@v2'
|
||||
with:
|
||||
workload_identity_provider: 'projects/638488734936/locations/global/workloadIdentityPools/dev-pool/providers/github'
|
||||
service_account: 'dev-github-actions-sa@agpt-dev.iam.gserviceaccount.com'
|
||||
token_format: 'access_token'
|
||||
create_credentials_file: true
|
||||
|
||||
- name: 'Set up Cloud SDK'
|
||||
uses: 'google-github-actions/setup-gcloud@v2'
|
||||
|
||||
- name: 'Configure Docker'
|
||||
run: |
|
||||
gcloud auth configure-docker us-east1-docker.pkg.dev
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Cache Docker layers
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: /tmp/.buildx-cache
|
||||
key: ${{ runner.os }}-buildx-${{ github.sha }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-buildx-
|
||||
|
||||
- name: Check for changes
|
||||
id: check_changes
|
||||
run: |
|
||||
git fetch origin dev
|
||||
BACKEND_CHANGED=$(git diff --name-only origin/dev HEAD | grep "^autogpt_platform/backend/" && echo "true" || echo "false")
|
||||
FRONTEND_CHANGED=$(git diff --name-only origin/dev HEAD | grep "^autogpt_platform/frontend/" && echo "true" || echo "false")
|
||||
MARKET_CHANGED=$(git diff --name-only origin/dev HEAD | grep "^autogpt_platform/market/" && echo "true" || echo "false")
|
||||
echo "backend_changed=$BACKEND_CHANGED" >> $GITHUB_OUTPUT
|
||||
echo "frontend_changed=$FRONTEND_CHANGED" >> $GITHUB_OUTPUT
|
||||
echo "market_changed=$MARKET_CHANGED" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Get GKE credentials
|
||||
uses: 'google-github-actions/get-gke-credentials@v2'
|
||||
with:
|
||||
cluster_name: ${{ env.GKE_CLUSTER }}
|
||||
location: ${{ env.GKE_ZONE }}
|
||||
|
||||
- name: Build and Push Backend
|
||||
if: steps.check_changes.outputs.backend_changed == 'true'
|
||||
uses: docker/build-push-action@v2
|
||||
with:
|
||||
context: .
|
||||
file: ./autogpt_platform/backend/Dockerfile
|
||||
push: true
|
||||
tags: us-east1-docker.pkg.dev/agpt-dev/agpt-backend-dev/agpt-backend-dev:${{ github.sha }}
|
||||
cache-from: type=local,src=/tmp/.buildx-cache
|
||||
cache-to: type=local,dest=/tmp/.buildx-cache-new,mode=max
|
||||
|
||||
- name: Build and Push Frontend
|
||||
if: steps.check_changes.outputs.frontend_changed == 'true'
|
||||
uses: docker/build-push-action@v2
|
||||
with:
|
||||
context: .
|
||||
file: ./autogpt_platform/frontend/Dockerfile
|
||||
push: true
|
||||
tags: us-east1-docker.pkg.dev/agpt-dev/agpt-frontend-dev/agpt-frontend-dev:${{ github.sha }}
|
||||
cache-from: type=local,src=/tmp/.buildx-cache
|
||||
cache-to: type=local,dest=/tmp/.buildx-cache-new,mode=max
|
||||
|
||||
- name: Build and Push Market
|
||||
if: steps.check_changes.outputs.market_changed == 'true'
|
||||
uses: docker/build-push-action@v2
|
||||
with:
|
||||
context: .
|
||||
file: ./autogpt_platform/market/Dockerfile
|
||||
push: true
|
||||
tags: us-east1-docker.pkg.dev/agpt-dev/agpt-market-dev/agpt-market-dev:${{ github.sha }}
|
||||
cache-from: type=local,src=/tmp/.buildx-cache
|
||||
cache-to: type=local,dest=/tmp/.buildx-cache-new,mode=max
|
||||
|
||||
- name: Move cache
|
||||
run: |
|
||||
rm -rf /tmp/.buildx-cache
|
||||
mv /tmp/.buildx-cache-new /tmp/.buildx-cache
|
||||
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@v4
|
||||
with:
|
||||
version: v3.4.0
|
||||
|
||||
- name: Deploy Backend
|
||||
if: steps.check_changes.outputs.backend_changed == 'true'
|
||||
run: |
|
||||
helm upgrade autogpt-server ./autogpt-server \
|
||||
--namespace ${{ env.NAMESPACE }} \
|
||||
-f autogpt-server/values.yaml \
|
||||
-f autogpt-server/values.dev.yaml \
|
||||
--set image.tag=${{ github.sha }}
|
||||
|
||||
- name: Deploy Websocket
|
||||
if: steps.check_changes.outputs.backend_changed == 'true'
|
||||
run: |
|
||||
helm upgrade autogpt-websocket-server ./autogpt-websocket-server \
|
||||
--namespace ${{ env.NAMESPACE }} \
|
||||
-f autogpt-websocket-server/values.yaml \
|
||||
-f autogpt-websocket-server/values.dev.yaml \
|
||||
--set image.tag=${{ github.sha }}
|
||||
|
||||
- name: Deploy Market
|
||||
if: steps.check_changes.outputs.market_changed == 'true'
|
||||
run: |
|
||||
helm upgrade autogpt-market ./autogpt-market \
|
||||
--namespace ${{ env.NAMESPACE }} \
|
||||
-f autogpt-market/values.yaml \
|
||||
-f autogpt-market/values.dev.yaml \
|
||||
--set image.tag=${{ github.sha }}
|
||||
|
||||
- name: Deploy Frontend
|
||||
if: steps.check_changes.outputs.frontend_changed == 'true'
|
||||
run: |
|
||||
helm upgrade autogpt-builder ./autogpt-builder \
|
||||
--namespace ${{ env.NAMESPACE }} \
|
||||
-f autogpt-builder/values.yaml \
|
||||
-f autogpt-builder/values.dev.yaml \
|
||||
--set image.tag=${{ github.sha }}
|
||||
.github/workflows/platform-autogpt-infra-ci.yml vendored
@@ -2,7 +2,7 @@ name: AutoGPT Platform - Infra

on:
  push:
    branches: [ master ]
    branches: [ master, dev ]
    paths:
      - '.github/workflows/platform-autogpt-infra-ci.yml'
      - 'autogpt_platform/infra/**'
@@ -36,12 +36,12 @@ jobs:
          tflint_changed_only: false

      - name: Set up Helm
        uses: azure/setup-helm@v4.2.0
        uses: azure/setup-helm@v4
        with:
          version: v3.14.4

      - name: Set up chart-testing
        uses: helm/chart-testing-action@v2.6.0
        uses: helm/chart-testing-action@v2.6.1

      - name: Run chart-testing (list-changed)
        id: list-changed
20 .github/workflows/platform-backend-ci.yml vendored
@@ -2,12 +2,12 @@ name: AutoGPT Platform - Backend CI

on:
  push:
    branches: [master, development, ci-test*]
    branches: [master, dev, ci-test*]
    paths:
      - ".github/workflows/platform-backend-ci.yml"
      - "autogpt_platform/backend/**"
  pull_request:
    branches: [master, development, release-*]
    branches: [master, dev, release-*]
    paths:
      - ".github/workflows/platform-backend-ci.yml"
      - "autogpt_platform/backend/**"
@@ -32,6 +32,14 @@ jobs:
        python-version: ["3.10"]
    runs-on: ubuntu-latest

    services:
      redis:
        image: bitnami/redis:6.2
        env:
          REDIS_PASSWORD: testpassword
        ports:
          - 6379:6379

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
@@ -96,9 +104,9 @@ jobs:
      - name: Run pytest with coverage
        run: |
          if [[ "${{ runner.debug }}" == "1" ]]; then
            poetry run pytest -vv -o log_cli=true -o log_cli_level=DEBUG test
            poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG test
          else
            poetry run pytest -vv test
            poetry run pytest -s -vv test
          fi
        if: success() || (failure() && steps.lint.outcome == 'failure')
        env:
@@ -107,6 +115,10 @@ jobs:
          SUPABASE_URL: ${{ steps.supabase.outputs.API_URL }}
          SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
          SUPABASE_JWT_SECRET: ${{ steps.supabase.outputs.JWT_SECRET }}
          REDIS_HOST: 'localhost'
          REDIS_PORT: '6379'
          REDIS_PASSWORD: 'testpassword'

    env:
      CI: true
      PLAIN_OUTPUT: True
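The hunk above adds a Redis service container beside the test job. As a rough illustration (not part of this diff), a test process on the runner could reach that container with redis-py using the same host, port, and password the workflow exports as REDIS_HOST, REDIS_PORT, and REDIS_PASSWORD:

# Hedged sketch only: mirrors the workflow's Redis settings; not taken from the repo.
from redis import Redis

r = Redis(host="localhost", port=6379, password="testpassword")
r.ping()  # raises a connection error if the service container is not reachable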
33 .github/workflows/platform-frontend-ci.yml vendored
@@ -2,7 +2,7 @@ name: AutoGPT Platform - Frontend CI

on:
  push:
    branches: [master]
    branches: [master, dev]
    paths:
      - ".github/workflows/platform-frontend-ci.yml"
      - "autogpt_platform/frontend/**"
@@ -29,24 +29,37 @@ jobs:

      - name: Install dependencies
        run: |
          npm install

      - name: Check formatting with Prettier
        run: |
          npx prettier --check .
          yarn install --frozen-lockfile

      - name: Run lint
        run: |
          npm run lint
          yarn lint

  test:
    runs-on: ubuntu-latest

    steps:
      - name: Free Disk Space (Ubuntu)
        uses: jlumbroso/free-disk-space@main
        with:
          # this might remove tools that are actually needed,
          # if set to "true" but frees about 6 GB
          tool-cache: false

          # all of these default to true, but feel free to set to
          # "false" if necessary for your workflow
          android: false
          dotnet: false
          haskell: false
          large-packages: true
          docker-images: true
          swap-storage: true

      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          submodules: recursive

      - name: Set up Node.js
        uses: actions/setup-node@v4
        with:
@@ -62,18 +75,18 @@ jobs:

      - name: Install dependencies
        run: |
          npm install
          yarn install --frozen-lockfile

      - name: Setup Builder .env
        run: |
          cp .env.example .env

      - name: Install Playwright Browsers
        run: npx playwright install --with-deps
        run: yarn playwright install --with-deps

      - name: Run tests
        run: |
          npm run test
          yarn test

      - uses: actions/upload-artifact@v4
        if: ${{ !cancelled() }}
125 .github/workflows/platform-market-ci.yml vendored Normal file
@@ -0,0 +1,125 @@
|
||||
name: AutoGPT Platform - Backend CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [master, dev, ci-test*]
|
||||
paths:
|
||||
- ".github/workflows/platform-market-ci.yml"
|
||||
- "autogpt_platform/market/**"
|
||||
pull_request:
|
||||
branches: [master, dev, release-*]
|
||||
paths:
|
||||
- ".github/workflows/platform-market-ci.yml"
|
||||
- "autogpt_platform/market/**"
|
||||
|
||||
concurrency:
|
||||
group: ${{ format('backend-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
|
||||
cancel-in-progress: ${{ startsWith(github.event_name, 'pull_request') }}
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: autogpt_platform/market
|
||||
|
||||
jobs:
|
||||
test:
|
||||
permissions:
|
||||
contents: read
|
||||
timeout-minutes: 30
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
submodules: true
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Setup Supabase
|
||||
uses: supabase/setup-cli@v1
|
||||
with:
|
||||
version: latest
|
||||
|
||||
- id: get_date
|
||||
name: Get date
|
||||
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/market/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry (Unix)
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
if [ "${{ runner.os }}" = "macOS" ]; then
|
||||
PATH="$HOME/.local/bin:$PATH"
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
fi
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Generate Prisma Client
|
||||
run: poetry run prisma generate
|
||||
|
||||
- id: supabase
|
||||
name: Start Supabase
|
||||
working-directory: .
|
||||
run: |
|
||||
supabase init
|
||||
supabase start --exclude postgres-meta,realtime,storage-api,imgproxy,inbucket,studio,edge-runtime,logflare,vector,supavisor
|
||||
supabase status -o env | sed 's/="/=/; s/"$//' >> $GITHUB_OUTPUT
|
||||
# outputs:
|
||||
# DB_URL, API_URL, GRAPHQL_URL, ANON_KEY, SERVICE_ROLE_KEY, JWT_SECRET
|
||||
|
||||
- name: Run Database Migrations
|
||||
run: poetry run prisma migrate dev --name updates
|
||||
env:
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
|
||||
- id: lint
|
||||
name: Run Linter
|
||||
run: poetry run lint
|
||||
|
||||
# Tests comment out because they do not work with prisma mock, nor have they been updated since they were created
|
||||
# - name: Run pytest with coverage
|
||||
# run: |
|
||||
# if [[ "${{ runner.debug }}" == "1" ]]; then
|
||||
# poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG test
|
||||
# else
|
||||
# poetry run pytest -s -vv test
|
||||
# fi
|
||||
# if: success() || (failure() && steps.lint.outcome == 'failure')
|
||||
# env:
|
||||
# LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
|
||||
# DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
# SUPABASE_URL: ${{ steps.supabase.outputs.API_URL }}
|
||||
# SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
|
||||
# SUPABASE_JWT_SECRET: ${{ steps.supabase.outputs.JWT_SECRET }}
|
||||
# REDIS_HOST: 'localhost'
|
||||
# REDIS_PORT: '6379'
|
||||
# REDIS_PASSWORD: 'testpassword'
|
||||
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
RUN_ENV: local
|
||||
PORT: 8080
|
||||
|
||||
# - name: Upload coverage reports to Codecov
|
||||
# uses: codecov/codecov-action@v4
|
||||
# with:
|
||||
# token: ${{ secrets.CODECOV_TOKEN }}
|
||||
# flags: backend,${{ runner.os }}
|
||||
21 .github/workflows/repo-pr-enforce-base-branch.yml vendored Normal file
@@ -0,0 +1,21 @@
name: Repo - Enforce dev as base branch
on:
  pull_request_target:
    branches: [ master ]
    types: [ opened ]

jobs:
  check_pr_target:
    runs-on: ubuntu-latest
    permissions:
      pull-requests: write
    steps:
      - name: Check if PR is from dev or hotfix
        if: ${{ !(startsWith(github.event.pull_request.head.ref, 'hotfix/') || github.event.pull_request.head.ref == 'dev') }}
        run: |
          gh pr comment ${{ github.event.number }} --repo "$REPO" \
            --body $'This PR targets the `master` branch but does not come from `dev` or a `hotfix/*` branch.\n\nAutomatically setting the base branch to `dev`.'
          gh pr edit ${{ github.event.number }} --base dev --repo "$REPO"
        env:
          GITHUB_TOKEN: ${{ github.token }}
          REPO: ${{ github.repository }}
2 .github/workflows/repo-pr-label.yml vendored
@@ -3,7 +3,7 @@ name: Repo - Pull Request auto-label
on:
  # So that PRs touching the same files as the push are updated
  push:
    branches: [ master, development, release-* ]
    branches: [ master, dev, release-* ]
    paths-ignore:
      - 'classic/forge/tests/vcr_cassettes'
      - 'classic/benchmark/reports/**'

CONTRIBUTING.md
@@ -11,7 +11,7 @@ Also check out our [🚀 Roadmap][roadmap] for information about our priorities
[kanban board]: https://github.com/orgs/Significant-Gravitas/projects/1

## Contributing to the AutoGPT Platform Folder
All contributions to [the autogpt_platform folder](https://github.com/Significant-Gravitas/AutoGPT/blob/master/autogpt_platform) will be under our [Contribution License Agreement](https://github.com/Significant-Gravitas/AutoGPT/blob/master/autogpt_platform/Contributor%20License%20Agreement%20(CLA).md). By making a pull request contributing to this folder, you agree to the terms of our CLA for your contribution.
All contributions to [the autogpt_platform folder](https://github.com/Significant-Gravitas/AutoGPT/blob/master/autogpt_platform) will be under our [Contribution License Agreement](https://github.com/Significant-Gravitas/AutoGPT/blob/master/autogpt_platform/Contributor%20License%20Agreement%20(CLA).md). By making a pull request contributing to this folder, you agree to the terms of our CLA for your contribution. All contributions to other folders will be under the MIT license.

## In short
1. Avoid duplicate work, issues, PRs etc.
8 LICENSE
@@ -1,7 +1,13 @@
All portions of this repository are under one of two licenses. The majority of the AutoGPT repository is under the MIT License below. The autogpt_platform folder is under the Polyform Shield License.

MIT License

Copyright (c) 2023 Toran Bruce Richards

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
@@ -9,9 +15,11 @@ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 README.md
@@ -65,6 +65,7 @@ Here are two examples of what you can do with AutoGPT:
These examples show just a glimpse of what you can achieve with AutoGPT! You can create customized workflows to build agents for any use case.

---
### Mission and Licensing
Our mission is to provide the tools, so that you can focus on what matters:

- 🏗️ **Building** - Lay the foundation for something amazing.
@@ -77,6 +78,13 @@ Be part of the revolution! **AutoGPT** is here to stay, at the forefront of AI i
**🚀 [Contributing](CONTRIBUTING.md)**

**Licensing:**

MIT License: The majority of the AutoGPT repository is under the MIT License.

Polyform Shield License: This license applies to the autogpt_platform folder.

For more information, see https://agpt.co/blog/introducing-the-autogpt-platform

---
## 🤖 AutoGPT Classic
@@ -101,7 +109,7 @@ This guide will walk you through the process of creating your own agent and usin

📦 [`agbenchmark`](https://pypi.org/project/agbenchmark/) on Pypi

📘 [Learn More](https://github.com/Significant-Gravitas/AutoGPT/blob/master/benchmark) about the Benchmark
📘 [Learn More](https://github.com/Significant-Gravitas/AutoGPT/tree/master/classic/benchmark) about the Benchmark

### 💻 UI

@@ -150,6 +158,8 @@ To maintain a uniform standard and ensure seamless compatibility with many curre

---

## Stars stats

<p align="center">
<a href="https://star-history.com/#Significant-Gravitas/AutoGPT">
<picture>
@@ -159,3 +169,10 @@ To maintain a uniform standard and ensure seamless compatibility with many curre
</picture>
</a>
</p>

## ⚡ Contributors

<a href="https://github.com/Significant-Gravitas/AutoGPT/graphs/contributors" alt="View Contributors">
<img src="https://contrib.rocks/image?repo=Significant-Gravitas/AutoGPT&max=1000&columns=10" alt="Contributors" />
</a>

@@ -149,6 +149,3 @@ To persist data for PostgreSQL and Redis, you can modify the `docker-compose.yml
3. Save the file and run `docker compose up -d` to apply the changes.

This configuration will create named volumes for PostgreSQL and Redis, ensuring that your data persists across container restarts.
@@ -1,8 +1,9 @@
from .store import SupabaseIntegrationCredentialsStore
from .types import APIKeyCredentials, OAuth2Credentials
from .types import Credentials, APIKeyCredentials, OAuth2Credentials

__all__ = [
    "SupabaseIntegrationCredentialsStore",
    "Credentials",
    "APIKeyCredentials",
    "OAuth2Credentials",
]
@@ -1,8 +1,13 @@
import secrets
from datetime import datetime, timedelta, timezone
from typing import cast
from typing import TYPE_CHECKING

from supabase import Client
if TYPE_CHECKING:
    from redis import Redis
    from backend.executor.database import DatabaseManager

from autogpt_libs.utils.cache import thread_cached
from autogpt_libs.utils.synchronize import RedisKeyedMutex

from .types import (
    Credentials,
@@ -14,26 +19,36 @@ from .types import (


class SupabaseIntegrationCredentialsStore:
    def __init__(self, supabase: Client):
        self.supabase = supabase
    def __init__(self, redis: "Redis"):
        self.locks = RedisKeyedMutex(redis)

    @property
    @thread_cached
    def db_manager(self) -> "DatabaseManager":
        from backend.executor.database import DatabaseManager
        from backend.util.service import get_service_client
        return get_service_client(DatabaseManager)

    def add_creds(self, user_id: str, credentials: Credentials) -> None:
        if self.get_creds_by_id(user_id, credentials.id):
            raise ValueError(
                f"Can not re-create existing credentials with ID {credentials.id} "
                f"for user with ID {user_id}"
        with self.locked_user_metadata(user_id):
            if self.get_creds_by_id(user_id, credentials.id):
                raise ValueError(
                    f"Can not re-create existing credentials #{credentials.id} "
                    f"for user #{user_id}"
                )
            self._set_user_integration_creds(
                user_id, [*self.get_all_creds(user_id), credentials]
            )
        self._set_user_integration_creds(
            user_id, [*self.get_all_creds(user_id), credentials]
        )

    def get_all_creds(self, user_id: str) -> list[Credentials]:
        user_metadata = self._get_user_metadata(user_id)
        return UserMetadata.model_validate(user_metadata).integration_credentials
        return UserMetadata.model_validate(
            user_metadata.model_dump()
        ).integration_credentials

    def get_creds_by_id(self, user_id: str, credentials_id: str) -> Credentials | None:
        credentials = self.get_all_creds(user_id)
        return next((c for c in credentials if c.id == credentials_id), None)
        all_credentials = self.get_all_creds(user_id)
        return next((c for c in all_credentials if c.id == credentials_id), None)

    def get_creds_by_provider(self, user_id: str, provider: str) -> list[Credentials]:
        credentials = self.get_all_creds(user_id)
|
||||
return list(set(c.provider for c in credentials))
|
||||
|
||||
def update_creds(self, user_id: str, updated: Credentials) -> None:
|
||||
current = self.get_creds_by_id(user_id, updated.id)
|
||||
if not current:
|
||||
raise ValueError(
|
||||
f"Credentials with ID {updated.id} "
|
||||
f"for user with ID {user_id} not found"
|
||||
)
|
||||
if type(current) is not type(updated):
|
||||
raise TypeError(
|
||||
f"Can not update credentials with ID {updated.id} "
|
||||
f"from type {type(current)} "
|
||||
f"to type {type(updated)}"
|
||||
)
|
||||
with self.locked_user_metadata(user_id):
|
||||
current = self.get_creds_by_id(user_id, updated.id)
|
||||
if not current:
|
||||
raise ValueError(
|
||||
f"Credentials with ID {updated.id} "
|
||||
f"for user with ID {user_id} not found"
|
||||
)
|
||||
if type(current) is not type(updated):
|
||||
raise TypeError(
|
||||
f"Can not update credentials with ID {updated.id} "
|
||||
f"from type {type(current)} "
|
||||
f"to type {type(updated)}"
|
||||
)
|
||||
|
||||
# Ensure no scopes are removed when updating credentials
|
||||
if (
|
||||
isinstance(updated, OAuth2Credentials)
|
||||
and isinstance(current, OAuth2Credentials)
|
||||
and not set(updated.scopes).issuperset(current.scopes)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Can not update credentials with ID {updated.id} "
|
||||
f"and scopes {current.scopes} "
|
||||
f"to more restrictive set of scopes {updated.scopes}"
|
||||
)
|
||||
# Ensure no scopes are removed when updating credentials
|
||||
if (
|
||||
isinstance(updated, OAuth2Credentials)
|
||||
and isinstance(current, OAuth2Credentials)
|
||||
and not set(updated.scopes).issuperset(current.scopes)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Can not update credentials with ID {updated.id} "
|
||||
f"and scopes {current.scopes} "
|
||||
f"to more restrictive set of scopes {updated.scopes}"
|
||||
)
|
||||
|
||||
# Update the credentials
|
||||
updated_credentials_list = [
|
||||
updated if c.id == updated.id else c for c in self.get_all_creds(user_id)
|
||||
]
|
||||
self._set_user_integration_creds(user_id, updated_credentials_list)
|
||||
# Update the credentials
|
||||
updated_credentials_list = [
|
||||
updated if c.id == updated.id else c
|
||||
for c in self.get_all_creds(user_id)
|
||||
]
|
||||
self._set_user_integration_creds(user_id, updated_credentials_list)
|
||||
|
||||
def delete_creds_by_id(self, user_id: str, credentials_id: str) -> None:
|
||||
filtered_credentials = [
|
||||
c for c in self.get_all_creds(user_id) if c.id != credentials_id
|
||||
]
|
||||
self._set_user_integration_creds(user_id, filtered_credentials)
|
||||
with self.locked_user_metadata(user_id):
|
||||
filtered_credentials = [
|
||||
c for c in self.get_all_creds(user_id) if c.id != credentials_id
|
||||
]
|
||||
self._set_user_integration_creds(user_id, filtered_credentials)
|
||||
|
||||
async def store_state_token(self, user_id: str, provider: str) -> str:
|
||||
def store_state_token(self, user_id: str, provider: str, scopes: list[str]) -> str:
|
||||
token = secrets.token_urlsafe(32)
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(minutes=10)
|
||||
|
||||
state = OAuthState(
|
||||
token=token, provider=provider, expires_at=int(expires_at.timestamp())
|
||||
token=token,
|
||||
provider=provider,
|
||||
expires_at=int(expires_at.timestamp()),
|
||||
scopes=scopes,
|
||||
)
|
||||
|
||||
user_metadata = self._get_user_metadata(user_id)
|
||||
oauth_states = user_metadata.get("integration_oauth_states", [])
|
||||
oauth_states.append(state.model_dump())
|
||||
user_metadata["integration_oauth_states"] = oauth_states
|
||||
with self.locked_user_metadata(user_id):
|
||||
user_metadata = self._get_user_metadata(user_id)
|
||||
oauth_states = user_metadata.integration_oauth_states
|
||||
oauth_states.append(state.model_dump())
|
||||
user_metadata.integration_oauth_states = oauth_states
|
||||
|
||||
self.supabase.auth.admin.update_user_by_id(
|
||||
user_id, {"user_metadata": user_metadata}
|
||||
)
|
||||
self.db_manager.update_user_metadata(
|
||||
user_id=user_id, metadata=user_metadata
|
||||
)
|
||||
|
||||
return token
|
||||
|
||||
async def verify_state_token(self, user_id: str, token: str, provider: str) -> bool:
|
||||
def get_any_valid_scopes_from_state_token(
|
||||
self, user_id: str, token: str, provider: str
|
||||
) -> list[str]:
|
||||
"""
|
||||
Get the valid scopes from the OAuth state token. This will return any valid scopes
|
||||
from any OAuth state token for the given provider. If no valid scopes are found,
|
||||
an empty list is returned. DO NOT RELY ON THIS TOKEN TO AUTHENTICATE A USER, AS IT
|
||||
IS TO CHECK IF THE USER HAS GIVEN PERMISSIONS TO THE APPLICATION BEFORE EXCHANGING
|
||||
THE CODE FOR TOKENS.
|
||||
"""
|
||||
user_metadata = self._get_user_metadata(user_id)
|
||||
oauth_states = user_metadata.get("integration_oauth_states", [])
|
||||
oauth_states = user_metadata.integration_oauth_states
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
valid_state = next(
|
||||
@@ -117,13 +148,33 @@ class SupabaseIntegrationCredentialsStore:
        )

        if valid_state:
            # Remove the used state
            oauth_states.remove(valid_state)
            user_metadata["integration_oauth_states"] = oauth_states
            self.supabase.auth.admin.update_user_by_id(
                user_id, {"user_metadata": user_metadata}
            return valid_state.get("scopes", [])

        return []

    def verify_state_token(self, user_id: str, token: str, provider: str) -> bool:
        with self.locked_user_metadata(user_id):
            user_metadata = self._get_user_metadata(user_id)
            oauth_states = user_metadata.integration_oauth_states

            now = datetime.now(timezone.utc)
            valid_state = next(
                (
                    state
                    for state in oauth_states
                    if state["token"] == token
                    and state["provider"] == provider
                    and state["expires_at"] > now.timestamp()
                ),
                None,
            )
            return True

            if valid_state:
                # Remove the used state
                oauth_states.remove(valid_state)
                user_metadata.integration_oauth_states = oauth_states
                self.db_manager.update_user_metadata(user_id, user_metadata)
                return True

        return False
@@ -131,15 +182,13 @@ class SupabaseIntegrationCredentialsStore:
        self, user_id: str, credentials: list[Credentials]
    ) -> None:
        raw_metadata = self._get_user_metadata(user_id)
        raw_metadata.update(
            {"integration_credentials": [c.model_dump() for c in credentials]}
        )
        self.supabase.auth.admin.update_user_by_id(
            user_id, {"user_metadata": raw_metadata}
        )
        raw_metadata.integration_credentials = [c.model_dump() for c in credentials]
        self.db_manager.update_user_metadata(user_id, raw_metadata)

    def _get_user_metadata(self, user_id: str) -> UserMetadataRaw:
        response = self.supabase.auth.admin.get_user_by_id(user_id)
        if not response.user:
            raise ValueError(f"User with ID {user_id} not found")
        return cast(UserMetadataRaw, response.user.user_metadata)
        metadata: UserMetadataRaw = self.db_manager.get_user_metadata(user_id=user_id)
        return metadata

    def locked_user_metadata(self, user_id: str):
        key = (self.db_manager, f"user:{user_id}", "metadata")
        return self.locks.locked(key)
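For context on why this refactor wraps every read-modify-write in locked_user_metadata: the metadata is fetched, mutated in memory, and written back, so two concurrent writers without the lock could silently overwrite each other. The following is an illustrative sketch only; the store and locks objects are hypothetical stand-ins, not APIs from this diff:

# Illustrative only: the read-modify-write pattern, with and without a keyed lock.
def add_item_unsafe(store, user_id: str, item: dict) -> None:
    items = store.read(user_id)    # hypothetical read
    items.append(item)             # modify in memory
    store.write(user_id, items)    # write back; a concurrent writer's update can be lost

def add_item_safe(store, locks, user_id: str, item: dict) -> None:
    # Same shape as the class above: hold the per-user lock for the whole cycle.
    with locks.locked(("user", user_id, "metadata")):
        items = store.read(user_id)
        items.append(item)
        store.write(user_id, items)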
@@ -56,6 +56,7 @@ class OAuthState(BaseModel):
    token: str
    provider: str
    expires_at: int
    scopes: list[str]
    """Unix timestamp (seconds) indicating when this OAuth state expires"""


@@ -64,6 +65,6 @@ class UserMetadata(BaseModel):
    integration_oauth_states: list[OAuthState] = Field(default_factory=list)


class UserMetadataRaw(TypedDict, total=False):
    integration_credentials: list[dict]
    integration_oauth_states: list[dict]
class UserMetadataRaw(BaseModel):
    integration_credentials: list[dict] = Field(default_factory=list)
    integration_oauth_states: list[dict] = Field(default_factory=list)
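A small sketch (not from the diff) of what the TypedDict-to-BaseModel switch buys: a UserMetadataRaw built with no arguments now validates and defaults both lists to empty, instead of being a plain dict with possibly missing keys. Assumes pydantic v2, as the models above do:

from pydantic import BaseModel, Field

class UserMetadataRaw(BaseModel):  # mirrors the new model in the hunk above
    integration_credentials: list[dict] = Field(default_factory=list)
    integration_oauth_states: list[dict] = Field(default_factory=list)

meta = UserMetadataRaw()                      # missing fields become empty lists
meta.integration_credentials.append({"id": "abc", "provider": "github"})
print(meta.model_dump())                      # plain dict, ready to persist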
20 autogpt_platform/autogpt_libs/autogpt_libs/utils/cache.py Normal file
@@ -0,0 +1,20 @@
from typing import Callable, TypeVar, ParamSpec
import threading

P = ParamSpec("P")
R = TypeVar("R")


def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
    thread_local = threading.local()

    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        cache = getattr(thread_local, "cache", None)
        if cache is None:
            cache = thread_local.cache = {}
        key = (args, tuple(sorted(kwargs.items())))
        if key not in cache:
            cache[key] = func(*args, **kwargs)
        return cache[key]

    return wrapper
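A brief usage sketch for the decorator above (illustrative only; the import path matches the one store.py uses earlier in this diff): each thread keeps its own cache, so the wrapped function runs at most once per argument tuple per thread.

from autogpt_libs.utils.cache import thread_cached

@thread_cached
def expensive_lookup(name: str) -> str:
    print(f"computing {name}")     # printed once per thread for a given name
    return name.upper()

expensive_lookup("redis")          # computed
expensive_lookup("redis")          # served from this thread's cache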
56 autogpt_platform/autogpt_libs/autogpt_libs/utils/synchronize.py Normal file
@@ -0,0 +1,56 @@
from contextlib import contextmanager
from threading import Lock
from typing import TYPE_CHECKING, Any

from expiringdict import ExpiringDict

if TYPE_CHECKING:
    from redis import Redis
    from redis.lock import Lock as RedisLock


class RedisKeyedMutex:
    """
    This class provides a mutex that can be locked and unlocked by a specific key,
    using Redis as a distributed locking provider.
    It uses an ExpiringDict to automatically clear the mutex after a specified timeout,
    in case the key is not unlocked for a specified duration, to prevent memory leaks.
    """

    def __init__(self, redis: "Redis", timeout: int | None = 60):
        self.redis = redis
        self.timeout = timeout
        self.locks: dict[Any, "RedisLock"] = ExpiringDict(
            max_len=6000, max_age_seconds=self.timeout
        )
        self.locks_lock = Lock()

    @contextmanager
    def locked(self, key: Any):
        lock = self.acquire(key)
        try:
            yield
        finally:
            lock.release()

    def acquire(self, key: Any) -> "RedisLock":
        """Acquires and returns a lock with the given key"""
        with self.locks_lock:
            if key not in self.locks:
                self.locks[key] = self.redis.lock(
                    str(key), self.timeout, thread_local=False
                )
            lock = self.locks[key]
        lock.acquire()
        return lock

    def release(self, key: Any):
        if lock := self.locks.get(key):
            lock.release()

    def release_all_locks(self):
        """Call this on process termination to ensure all locks are released"""
        self.locks_lock.acquire(blocking=False)
        for lock in self.locks.values():
            if lock.locked() and lock.owned():
                lock.release()
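A minimal usage sketch for RedisKeyedMutex (illustrative; assumes a reachable Redis instance and the import path used by store.py in this diff):

from redis import Redis
from autogpt_libs.utils.synchronize import RedisKeyedMutex

locks = RedisKeyedMutex(Redis(host="localhost", port=6379))

with locks.locked(("user:123", "metadata")):
    # Critical section: only one holder of this key across processes gets here
    # until the lock is released or its 60-second timeout expires.
    ...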
294 autogpt_platform/autogpt_libs/poetry.lock generated
@@ -377,6 +377,20 @@ files = [
|
||||
[package.extras]
|
||||
test = ["pytest (>=6)"]
|
||||
|
||||
[[package]]
|
||||
name = "expiringdict"
|
||||
version = "1.2.2"
|
||||
description = "Dictionary with auto-expiring values for caching purposes"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "expiringdict-1.2.2-py3-none-any.whl", hash = "sha256:09a5d20bc361163e6432a874edd3179676e935eb81b925eccef48d409a8a45e8"},
|
||||
{file = "expiringdict-1.2.2.tar.gz", hash = "sha256:300fb92a7e98f15b05cf9a856c1415b3bc4f2e132be07daa326da6414c23ee09"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
tests = ["coverage", "coveralls", "dill", "mock", "nose"]
|
||||
|
||||
[[package]]
|
||||
name = "frozenlist"
|
||||
version = "1.4.1"
|
||||
@@ -569,13 +583,13 @@ grpc = ["grpcio (>=1.38.0,<2.0dev)", "grpcio-status (>=1.38.0,<2.0.dev0)"]
|
||||
|
||||
[[package]]
|
||||
name = "google-cloud-logging"
|
||||
version = "3.11.2"
|
||||
version = "3.11.3"
|
||||
description = "Stackdriver Logging API client library"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "google_cloud_logging-3.11.2-py2.py3-none-any.whl", hash = "sha256:0a755f04f184fbe77ad608258dc283a032485ebb4d0e2b2501964059ee9c898f"},
|
||||
{file = "google_cloud_logging-3.11.2.tar.gz", hash = "sha256:4897441c2b74f6eda9181c23a8817223b6145943314a821d64b729d30766cb2b"},
|
||||
{file = "google_cloud_logging-3.11.3-py2.py3-none-any.whl", hash = "sha256:b8ec23f2998f76a58f8492db26a0f4151dd500425c3f08448586b85972f3c494"},
|
||||
{file = "google_cloud_logging-3.11.3.tar.gz", hash = "sha256:0a73cd94118875387d4535371d9e9426861edef8e44fba1261e86782d5b8d54f"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -612,17 +626,17 @@ grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"]
|
||||
|
||||
[[package]]
|
||||
name = "gotrue"
|
||||
version = "2.8.1"
|
||||
version = "2.9.3"
|
||||
description = "Python Client Library for Supabase Auth"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.8"
|
||||
python-versions = "<4.0,>=3.9"
|
||||
files = [
|
||||
{file = "gotrue-2.8.1-py3-none-any.whl", hash = "sha256:97dff077d71cca629f046c35ba34fae132b69c55fe271651766ddcf6d8132468"},
|
||||
{file = "gotrue-2.8.1.tar.gz", hash = "sha256:644d0096c4c390f7e36d9cb05271a7091c01e7dc6d506eb117b8fe8fc48eb8d9"},
|
||||
{file = "gotrue-2.9.3-py3-none-any.whl", hash = "sha256:9d2e9c74405d879f4828e0a7b94daf167a6e109c10ae6e5c59a0e21446f6e423"},
|
||||
{file = "gotrue-2.9.3.tar.gz", hash = "sha256:051551d80e642bdd2ab42cac78207745d89a2a08f429a1512d82624e675d8255"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
httpx = {version = ">=0.24,<0.28", extras = ["http2"]}
|
||||
httpx = {version = ">=0.26,<0.28", extras = ["http2"]}
|
||||
pydantic = ">=1.10,<3"
|
||||
|
||||
[[package]]
|
||||
@@ -972,20 +986,20 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "postgrest"
|
||||
version = "0.16.11"
|
||||
version = "0.17.2"
|
||||
description = "PostgREST client for Python. This library provides an ORM interface to PostgREST."
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.8"
|
||||
python-versions = "<4.0,>=3.9"
|
||||
files = [
|
||||
{file = "postgrest-0.16.11-py3-none-any.whl", hash = "sha256:22fb6b817ace1f68aa648fd4ce0f56d2786c9260fa4ed2cb9046191231a682b8"},
|
||||
{file = "postgrest-0.16.11.tar.gz", hash = "sha256:10af51b4c39e288ad7df2db92d6a61fb3c4683131b40561f473e3de116e83fa5"},
|
||||
{file = "postgrest-0.17.2-py3-none-any.whl", hash = "sha256:f7c4f448e5a5e2d4c1dcf192edae9d1007c4261e9a6fb5116783a0046846ece2"},
|
||||
{file = "postgrest-0.17.2.tar.gz", hash = "sha256:445cd4e4a191e279492549df0c4e827d32f9d01d0852599bb8a6efb0f07fcf78"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
deprecation = ">=2.1.0,<3.0.0"
|
||||
httpx = {version = ">=0.24,<0.28", extras = ["http2"]}
|
||||
httpx = {version = ">=0.26,<0.28", extras = ["http2"]}
|
||||
pydantic = ">=1.9,<3.0"
|
||||
strenum = ">=0.4.9,<0.5.0"
|
||||
strenum = {version = ">=0.4.9,<0.5.0", markers = "python_version < \"3.11\""}
|
||||
|
||||
[[package]]
|
||||
name = "proto-plus"
|
||||
@@ -1031,6 +1045,7 @@ description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629"},
|
||||
{file = "pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034"},
|
||||
]
|
||||
|
||||
@@ -1041,6 +1056,7 @@ description = "A collection of ASN.1-based protocols modules"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd"},
|
||||
{file = "pyasn1_modules-0.4.1.tar.gz", hash = "sha256:c28e2dbf9c06ad61c71a075c7e0f9fd0f1b0bb2d2ad4377f240d33ac2ab60a7c"},
|
||||
]
|
||||
|
||||
@@ -1049,18 +1065,18 @@ pyasn1 = ">=0.4.6,<0.7.0"
|
||||
|
||||
[[package]]
|
||||
name = "pydantic"
|
||||
version = "2.9.1"
|
||||
version = "2.9.2"
|
||||
description = "Data validation using Python type hints"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pydantic-2.9.1-py3-none-any.whl", hash = "sha256:7aff4db5fdf3cf573d4b3c30926a510a10e19a0774d38fc4967f78beb6deb612"},
|
||||
{file = "pydantic-2.9.1.tar.gz", hash = "sha256:1363c7d975c7036df0db2b4a61f2e062fbc0aa5ab5f2772e0ffc7191a4f4bce2"},
|
||||
{file = "pydantic-2.9.2-py3-none-any.whl", hash = "sha256:f048cec7b26778210e28a0459867920654d48e5e62db0958433636cde4254f12"},
|
||||
{file = "pydantic-2.9.2.tar.gz", hash = "sha256:d155cef71265d1e9807ed1c32b4c8deec042a44a50a4188b25ac67ecd81a9c0f"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
annotated-types = ">=0.6.0"
|
||||
pydantic-core = "2.23.3"
|
||||
pydantic-core = "2.23.4"
|
||||
typing-extensions = [
|
||||
{version = ">=4.12.2", markers = "python_version >= \"3.13\""},
|
||||
{version = ">=4.6.1", markers = "python_version < \"3.13\""},
|
||||
@@ -1072,100 +1088,100 @@ timezone = ["tzdata"]
|
||||
|
||||
[[package]]
|
||||
name = "pydantic-core"
|
||||
version = "2.23.3"
|
||||
version = "2.23.4"
|
||||
description = "Core functionality for Pydantic validation and serialization"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pydantic_core-2.23.3-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:7f10a5d1b9281392f1bf507d16ac720e78285dfd635b05737c3911637601bae6"},
|
||||
{file = "pydantic_core-2.23.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3c09a7885dd33ee8c65266e5aa7fb7e2f23d49d8043f089989726391dd7350c5"},
|
||||
{file = "pydantic_core-2.23.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6470b5a1ec4d1c2e9afe928c6cb37eb33381cab99292a708b8cb9aa89e62429b"},
|
||||
{file = "pydantic_core-2.23.3-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9172d2088e27d9a185ea0a6c8cebe227a9139fd90295221d7d495944d2367700"},
|
||||
{file = "pydantic_core-2.23.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:86fc6c762ca7ac8fbbdff80d61b2c59fb6b7d144aa46e2d54d9e1b7b0e780e01"},
|
||||
{file = "pydantic_core-2.23.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f0cb80fd5c2df4898693aa841425ea1727b1b6d2167448253077d2a49003e0ed"},
|
||||
{file = "pydantic_core-2.23.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:03667cec5daf43ac4995cefa8aaf58f99de036204a37b889c24a80927b629cec"},
|
||||
{file = "pydantic_core-2.23.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:047531242f8e9c2db733599f1c612925de095e93c9cc0e599e96cf536aaf56ba"},
|
||||
{file = "pydantic_core-2.23.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:5499798317fff7f25dbef9347f4451b91ac2a4330c6669821c8202fd354c7bee"},
|
||||
{file = "pydantic_core-2.23.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bbb5e45eab7624440516ee3722a3044b83fff4c0372efe183fd6ba678ff681fe"},
|
||||
{file = "pydantic_core-2.23.3-cp310-none-win32.whl", hash = "sha256:8b5b3ed73abb147704a6e9f556d8c5cb078f8c095be4588e669d315e0d11893b"},
|
||||
{file = "pydantic_core-2.23.3-cp310-none-win_amd64.whl", hash = "sha256:2b603cde285322758a0279995b5796d64b63060bfbe214b50a3ca23b5cee3e83"},
|
||||
{file = "pydantic_core-2.23.3-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:c889fd87e1f1bbeb877c2ee56b63bb297de4636661cc9bbfcf4b34e5e925bc27"},
|
||||
{file = "pydantic_core-2.23.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ea85bda3189fb27503af4c45273735bcde3dd31c1ab17d11f37b04877859ef45"},
|
||||
{file = "pydantic_core-2.23.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a7f7f72f721223f33d3dc98a791666ebc6a91fa023ce63733709f4894a7dc611"},
|
||||
{file = "pydantic_core-2.23.3-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2b2b55b0448e9da68f56b696f313949cda1039e8ec7b5d294285335b53104b61"},
|
||||
{file = "pydantic_core-2.23.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c24574c7e92e2c56379706b9a3f07c1e0c7f2f87a41b6ee86653100c4ce343e5"},
|
||||
{file = "pydantic_core-2.23.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f2b05e6ccbee333a8f4b8f4d7c244fdb7a979e90977ad9c51ea31261e2085ce0"},
|
||||
{file = "pydantic_core-2.23.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2c409ce1c219c091e47cb03feb3c4ed8c2b8e004efc940da0166aaee8f9d6c8"},
|
||||
{file = "pydantic_core-2.23.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d965e8b325f443ed3196db890d85dfebbb09f7384486a77461347f4adb1fa7f8"},
|
||||
{file = "pydantic_core-2.23.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f56af3a420fb1ffaf43ece3ea09c2d27c444e7c40dcb7c6e7cf57aae764f2b48"},
|
||||
{file = "pydantic_core-2.23.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5b01a078dd4f9a52494370af21aa52964e0a96d4862ac64ff7cea06e0f12d2c5"},
|
||||
{file = "pydantic_core-2.23.3-cp311-none-win32.whl", hash = "sha256:560e32f0df04ac69b3dd818f71339983f6d1f70eb99d4d1f8e9705fb6c34a5c1"},
|
||||
{file = "pydantic_core-2.23.3-cp311-none-win_amd64.whl", hash = "sha256:c744fa100fdea0d000d8bcddee95213d2de2e95b9c12be083370b2072333a0fa"},
|
||||
{file = "pydantic_core-2.23.3-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:e0ec50663feedf64d21bad0809f5857bac1ce91deded203efc4a84b31b2e4305"},
|
||||
{file = "pydantic_core-2.23.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:db6e6afcb95edbe6b357786684b71008499836e91f2a4a1e55b840955b341dbb"},
|
||||
{file = "pydantic_core-2.23.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:98ccd69edcf49f0875d86942f4418a4e83eb3047f20eb897bffa62a5d419c8fa"},
|
||||
{file = "pydantic_core-2.23.3-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a678c1ac5c5ec5685af0133262103defb427114e62eafeda12f1357a12140162"},
|
||||
{file = "pydantic_core-2.23.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:01491d8b4d8db9f3391d93b0df60701e644ff0894352947f31fff3e52bd5c801"},
|
||||
{file = "pydantic_core-2.23.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fcf31facf2796a2d3b7fe338fe8640aa0166e4e55b4cb108dbfd1058049bf4cb"},
|
||||
{file = "pydantic_core-2.23.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7200fd561fb3be06827340da066df4311d0b6b8eb0c2116a110be5245dceb326"},
|
||||
{file = "pydantic_core-2.23.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:dc1636770a809dee2bd44dd74b89cc80eb41172bcad8af75dd0bc182c2666d4c"},
|
||||
{file = "pydantic_core-2.23.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:67a5def279309f2e23014b608c4150b0c2d323bd7bccd27ff07b001c12c2415c"},
|
||||
{file = "pydantic_core-2.23.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:748bdf985014c6dd3e1e4cc3db90f1c3ecc7246ff5a3cd4ddab20c768b2f1dab"},
|
||||
{file = "pydantic_core-2.23.3-cp312-none-win32.whl", hash = "sha256:255ec6dcb899c115f1e2a64bc9ebc24cc0e3ab097775755244f77360d1f3c06c"},
|
||||
{file = "pydantic_core-2.23.3-cp312-none-win_amd64.whl", hash = "sha256:40b8441be16c1e940abebed83cd006ddb9e3737a279e339dbd6d31578b802f7b"},
|
||||
{file = "pydantic_core-2.23.3-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:6daaf5b1ba1369a22c8b050b643250e3e5efc6a78366d323294aee54953a4d5f"},
|
||||
{file = "pydantic_core-2.23.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:d015e63b985a78a3d4ccffd3bdf22b7c20b3bbd4b8227809b3e8e75bc37f9cb2"},
|
||||
{file = "pydantic_core-2.23.3-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a3fc572d9b5b5cfe13f8e8a6e26271d5d13f80173724b738557a8c7f3a8a3791"},
|
||||
{file = "pydantic_core-2.23.3-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f6bd91345b5163ee7448bee201ed7dd601ca24f43f439109b0212e296eb5b423"},
|
||||
{file = "pydantic_core-2.23.3-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fc379c73fd66606628b866f661e8785088afe2adaba78e6bbe80796baf708a63"},
|
||||
{file = "pydantic_core-2.23.3-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fbdce4b47592f9e296e19ac31667daed8753c8367ebb34b9a9bd89dacaa299c9"},
|
||||
{file = "pydantic_core-2.23.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc3cf31edf405a161a0adad83246568647c54404739b614b1ff43dad2b02e6d5"},
|
||||
{file = "pydantic_core-2.23.3-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8e22b477bf90db71c156f89a55bfe4d25177b81fce4aa09294d9e805eec13855"},
|
||||
{file = "pydantic_core-2.23.3-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:0a0137ddf462575d9bce863c4c95bac3493ba8e22f8c28ca94634b4a1d3e2bb4"},
|
||||
{file = "pydantic_core-2.23.3-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:203171e48946c3164fe7691fc349c79241ff8f28306abd4cad5f4f75ed80bc8d"},
|
||||
{file = "pydantic_core-2.23.3-cp313-none-win32.whl", hash = "sha256:76bdab0de4acb3f119c2a4bff740e0c7dc2e6de7692774620f7452ce11ca76c8"},
|
||||
{file = "pydantic_core-2.23.3-cp313-none-win_amd64.whl", hash = "sha256:37ba321ac2a46100c578a92e9a6aa33afe9ec99ffa084424291d84e456f490c1"},
|
||||
{file = "pydantic_core-2.23.3-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:d063c6b9fed7d992bcbebfc9133f4c24b7a7f215d6b102f3e082b1117cddb72c"},
|
||||
{file = "pydantic_core-2.23.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6cb968da9a0746a0cf521b2b5ef25fc5a0bee9b9a1a8214e0a1cfaea5be7e8a4"},
|
||||
{file = "pydantic_core-2.23.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:edbefe079a520c5984e30e1f1f29325054b59534729c25b874a16a5048028d16"},
|
||||
{file = "pydantic_core-2.23.3-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cbaaf2ef20d282659093913da9d402108203f7cb5955020bd8d1ae5a2325d1c4"},
|
||||
{file = "pydantic_core-2.23.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fb539d7e5dc4aac345846f290cf504d2fd3c1be26ac4e8b5e4c2b688069ff4cf"},
|
||||
{file = "pydantic_core-2.23.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7e6f33503c5495059148cc486867e1d24ca35df5fc064686e631e314d959ad5b"},
|
||||
{file = "pydantic_core-2.23.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:04b07490bc2f6f2717b10c3969e1b830f5720b632f8ae2f3b8b1542394c47a8e"},
|
||||
{file = "pydantic_core-2.23.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:03795b9e8a5d7fda05f3873efc3f59105e2dcff14231680296b87b80bb327295"},
|
||||
{file = "pydantic_core-2.23.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:c483dab0f14b8d3f0df0c6c18d70b21b086f74c87ab03c59250dbf6d3c89baba"},
|
||||
{file = "pydantic_core-2.23.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8b2682038e255e94baf2c473dca914a7460069171ff5cdd4080be18ab8a7fd6e"},
|
||||
{file = "pydantic_core-2.23.3-cp38-none-win32.whl", hash = "sha256:f4a57db8966b3a1d1a350012839c6a0099f0898c56512dfade8a1fe5fb278710"},
|
||||
{file = "pydantic_core-2.23.3-cp38-none-win_amd64.whl", hash = "sha256:13dd45ba2561603681a2676ca56006d6dee94493f03d5cadc055d2055615c3ea"},
|
||||
{file = "pydantic_core-2.23.3-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:82da2f4703894134a9f000e24965df73cc103e31e8c31906cc1ee89fde72cbd8"},
|
||||
{file = "pydantic_core-2.23.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:dd9be0a42de08f4b58a3cc73a123f124f65c24698b95a54c1543065baca8cf0e"},
|
||||
{file = "pydantic_core-2.23.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:89b731f25c80830c76fdb13705c68fef6a2b6dc494402987c7ea9584fe189f5d"},
|
||||
{file = "pydantic_core-2.23.3-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c6de1ec30c4bb94f3a69c9f5f2182baeda5b809f806676675e9ef6b8dc936f28"},
|
||||
{file = "pydantic_core-2.23.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bb68b41c3fa64587412b104294b9cbb027509dc2f6958446c502638d481525ef"},
|
||||
{file = "pydantic_core-2.23.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1c3980f2843de5184656aab58698011b42763ccba11c4a8c35936c8dd6c7068c"},
|
||||
{file = "pydantic_core-2.23.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:94f85614f2cba13f62c3c6481716e4adeae48e1eaa7e8bac379b9d177d93947a"},
|
||||
{file = "pydantic_core-2.23.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:510b7fb0a86dc8f10a8bb43bd2f97beb63cffad1203071dc434dac26453955cd"},
|
||||
{file = "pydantic_core-2.23.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:1eba2f7ce3e30ee2170410e2171867ea73dbd692433b81a93758ab2de6c64835"},
|
||||
{file = "pydantic_core-2.23.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:4b259fd8409ab84b4041b7b3f24dcc41e4696f180b775961ca8142b5b21d0e70"},
|
||||
{file = "pydantic_core-2.23.3-cp39-none-win32.whl", hash = "sha256:40d9bd259538dba2f40963286009bf7caf18b5112b19d2b55b09c14dde6db6a7"},
|
||||
{file = "pydantic_core-2.23.3-cp39-none-win_amd64.whl", hash = "sha256:5a8cd3074a98ee70173a8633ad3c10e00dcb991ecec57263aacb4095c5efb958"},
|
||||
{file = "pydantic_core-2.23.3-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f399e8657c67313476a121a6944311fab377085ca7f490648c9af97fc732732d"},
|
||||
{file = "pydantic_core-2.23.3-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:6b5547d098c76e1694ba85f05b595720d7c60d342f24d5aad32c3049131fa5c4"},
|
||||
{file = "pydantic_core-2.23.3-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0dda0290a6f608504882d9f7650975b4651ff91c85673341789a476b1159f211"},
|
||||
{file = "pydantic_core-2.23.3-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65b6e5da855e9c55a0c67f4db8a492bf13d8d3316a59999cfbaf98cc6e401961"},
|
||||
{file = "pydantic_core-2.23.3-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:09e926397f392059ce0afdcac920df29d9c833256354d0c55f1584b0b70cf07e"},
|
||||
{file = "pydantic_core-2.23.3-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:87cfa0ed6b8c5bd6ae8b66de941cece179281239d482f363814d2b986b79cedc"},
|
||||
{file = "pydantic_core-2.23.3-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:e61328920154b6a44d98cabcb709f10e8b74276bc709c9a513a8c37a18786cc4"},
|
||||
{file = "pydantic_core-2.23.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:ce3317d155628301d649fe5e16a99528d5680af4ec7aa70b90b8dacd2d725c9b"},
|
||||
{file = "pydantic_core-2.23.3-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:e89513f014c6be0d17b00a9a7c81b1c426f4eb9224b15433f3d98c1a071f8433"},
|
||||
{file = "pydantic_core-2.23.3-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:4f62c1c953d7ee375df5eb2e44ad50ce2f5aff931723b398b8bc6f0ac159791a"},
|
||||
{file = "pydantic_core-2.23.3-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2718443bc671c7ac331de4eef9b673063b10af32a0bb385019ad61dcf2cc8f6c"},
|
||||
{file = "pydantic_core-2.23.3-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0d90e08b2727c5d01af1b5ef4121d2f0c99fbee692c762f4d9d0409c9da6541"},
|
||||
{file = "pydantic_core-2.23.3-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2b676583fc459c64146debea14ba3af54e540b61762dfc0613dc4e98c3f66eeb"},
|
||||
{file = "pydantic_core-2.23.3-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:50e4661f3337977740fdbfbae084ae5693e505ca2b3130a6d4eb0f2281dc43b8"},
|
||||
{file = "pydantic_core-2.23.3-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:68f4cf373f0de6abfe599a38307f4417c1c867ca381c03df27c873a9069cda25"},
|
||||
{file = "pydantic_core-2.23.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:59d52cf01854cb26c46958552a21acb10dd78a52aa34c86f284e66b209db8cab"},
|
||||
{file = "pydantic_core-2.23.3.tar.gz", hash = "sha256:3cb0f65d8b4121c1b015c60104a685feb929a29d7cf204387c7f2688c7974690"},
|
||||
{file = "pydantic_core-2.23.4-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:b10bd51f823d891193d4717448fab065733958bdb6a6b351967bd349d48d5c9b"},
|
||||
{file = "pydantic_core-2.23.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4fc714bdbfb534f94034efaa6eadd74e5b93c8fa6315565a222f7b6f42ca1166"},
|
||||
{file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63e46b3169866bd62849936de036f901a9356e36376079b05efa83caeaa02ceb"},
|
||||
{file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ed1a53de42fbe34853ba90513cea21673481cd81ed1be739f7f2efb931b24916"},
|
||||
{file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cfdd16ab5e59fc31b5e906d1a3f666571abc367598e3e02c83403acabc092e07"},
|
||||
{file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:255a8ef062cbf6674450e668482456abac99a5583bbafb73f9ad469540a3a232"},
|
||||
{file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a7cd62e831afe623fbb7aabbb4fe583212115b3ef38a9f6b71869ba644624a2"},
|
||||
{file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f09e2ff1f17c2b51f2bc76d1cc33da96298f0a036a137f5440ab3ec5360b624f"},
|
||||
{file = "pydantic_core-2.23.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e38e63e6f3d1cec5a27e0afe90a085af8b6806ee208b33030e65b6516353f1a3"},
|
||||
{file = "pydantic_core-2.23.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0dbd8dbed2085ed23b5c04afa29d8fd2771674223135dc9bc937f3c09284d071"},
|
||||
{file = "pydantic_core-2.23.4-cp310-none-win32.whl", hash = "sha256:6531b7ca5f951d663c339002e91aaebda765ec7d61b7d1e3991051906ddde119"},
|
||||
{file = "pydantic_core-2.23.4-cp310-none-win_amd64.whl", hash = "sha256:7c9129eb40958b3d4500fa2467e6a83356b3b61bfff1b414c7361d9220f9ae8f"},
|
||||
{file = "pydantic_core-2.23.4-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:77733e3892bb0a7fa797826361ce8a9184d25c8dffaec60b7ffe928153680ba8"},
|
||||
{file = "pydantic_core-2.23.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1b84d168f6c48fabd1f2027a3d1bdfe62f92cade1fb273a5d68e621da0e44e6d"},
|
||||
{file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:df49e7a0861a8c36d089c1ed57d308623d60416dab2647a4a17fe050ba85de0e"},
|
||||
{file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ff02b6d461a6de369f07ec15e465a88895f3223eb75073ffea56b84d9331f607"},
|
||||
{file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:996a38a83508c54c78a5f41456b0103c30508fed9abcad0a59b876d7398f25fd"},
|
||||
{file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d97683ddee4723ae8c95d1eddac7c192e8c552da0c73a925a89fa8649bf13eea"},
|
||||
{file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:216f9b2d7713eb98cb83c80b9c794de1f6b7e3145eef40400c62e86cee5f4e1e"},
|
||||
{file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6f783e0ec4803c787bcea93e13e9932edab72068f68ecffdf86a99fd5918878b"},
|
||||
{file = "pydantic_core-2.23.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d0776dea117cf5272382634bd2a5c1b6eb16767c223c6a5317cd3e2a757c61a0"},
|
||||
{file = "pydantic_core-2.23.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d5f7a395a8cf1621939692dba2a6b6a830efa6b3cee787d82c7de1ad2930de64"},
|
||||
{file = "pydantic_core-2.23.4-cp311-none-win32.whl", hash = "sha256:74b9127ffea03643e998e0c5ad9bd3811d3dac8c676e47db17b0ee7c3c3bf35f"},
|
||||
{file = "pydantic_core-2.23.4-cp311-none-win_amd64.whl", hash = "sha256:98d134c954828488b153d88ba1f34e14259284f256180ce659e8d83e9c05eaa3"},
|
||||
{file = "pydantic_core-2.23.4-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:f3e0da4ebaef65158d4dfd7d3678aad692f7666877df0002b8a522cdf088f231"},
|
||||
{file = "pydantic_core-2.23.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f69a8e0b033b747bb3e36a44e7732f0c99f7edd5cea723d45bc0d6e95377ffee"},
|
||||
{file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:723314c1d51722ab28bfcd5240d858512ffd3116449c557a1336cbe3919beb87"},
|
||||
{file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bb2802e667b7051a1bebbfe93684841cc9351004e2badbd6411bf357ab8d5ac8"},
|
||||
{file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d18ca8148bebe1b0a382a27a8ee60350091a6ddaf475fa05ef50dc35b5df6327"},
|
||||
{file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:33e3d65a85a2a4a0dc3b092b938a4062b1a05f3a9abde65ea93b233bca0e03f2"},
|
||||
{file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:128585782e5bfa515c590ccee4b727fb76925dd04a98864182b22e89a4e6ed36"},
|
||||
{file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:68665f4c17edcceecc112dfed5dbe6f92261fb9d6054b47d01bf6371a6196126"},
|
||||
{file = "pydantic_core-2.23.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:20152074317d9bed6b7a95ade3b7d6054845d70584216160860425f4fbd5ee9e"},
|
||||
{file = "pydantic_core-2.23.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:9261d3ce84fa1d38ed649c3638feefeae23d32ba9182963e465d58d62203bd24"},
|
||||
{file = "pydantic_core-2.23.4-cp312-none-win32.whl", hash = "sha256:4ba762ed58e8d68657fc1281e9bb72e1c3e79cc5d464be146e260c541ec12d84"},
|
||||
{file = "pydantic_core-2.23.4-cp312-none-win_amd64.whl", hash = "sha256:97df63000f4fea395b2824da80e169731088656d1818a11b95f3b173747b6cd9"},
|
||||
{file = "pydantic_core-2.23.4-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:7530e201d10d7d14abce4fb54cfe5b94a0aefc87da539d0346a484ead376c3cc"},
|
||||
{file = "pydantic_core-2.23.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:df933278128ea1cd77772673c73954e53a1c95a4fdf41eef97c2b779271bd0bd"},
|
||||
{file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0cb3da3fd1b6a5d0279a01877713dbda118a2a4fc6f0d821a57da2e464793f05"},
|
||||
{file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:42c6dcb030aefb668a2b7009c85b27f90e51e6a3b4d5c9bc4c57631292015b0d"},
|
||||
{file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:696dd8d674d6ce621ab9d45b205df149399e4bb9aa34102c970b721554828510"},
|
||||
{file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2971bb5ffe72cc0f555c13e19b23c85b654dd2a8f7ab493c262071377bfce9f6"},
|
||||
{file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8394d940e5d400d04cad4f75c0598665cbb81aecefaca82ca85bd28264af7f9b"},
|
||||
{file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0dff76e0602ca7d4cdaacc1ac4c005e0ce0dcfe095d5b5259163a80d3a10d327"},
|
||||
{file = "pydantic_core-2.23.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:7d32706badfe136888bdea71c0def994644e09fff0bfe47441deaed8e96fdbc6"},
|
||||
{file = "pydantic_core-2.23.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:ed541d70698978a20eb63d8c5d72f2cc6d7079d9d90f6b50bad07826f1320f5f"},
|
||||
{file = "pydantic_core-2.23.4-cp313-none-win32.whl", hash = "sha256:3d5639516376dce1940ea36edf408c554475369f5da2abd45d44621cb616f769"},
|
||||
{file = "pydantic_core-2.23.4-cp313-none-win_amd64.whl", hash = "sha256:5a1504ad17ba4210df3a045132a7baeeba5a200e930f57512ee02909fc5c4cb5"},
|
||||
{file = "pydantic_core-2.23.4-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:d4488a93b071c04dc20f5cecc3631fc78b9789dd72483ba15d423b5b3689b555"},
|
||||
{file = "pydantic_core-2.23.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:81965a16b675b35e1d09dd14df53f190f9129c0202356ed44ab2728b1c905658"},
|
||||
{file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ffa2ebd4c8530079140dd2d7f794a9d9a73cbb8e9d59ffe24c63436efa8f271"},
|
||||
{file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:61817945f2fe7d166e75fbfb28004034b48e44878177fc54d81688e7b85a3665"},
|
||||
{file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:29d2c342c4bc01b88402d60189f3df065fb0dda3654744d5a165a5288a657368"},
|
||||
{file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5e11661ce0fd30a6790e8bcdf263b9ec5988e95e63cf901972107efc49218b13"},
|
||||
{file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d18368b137c6295db49ce7218b1a9ba15c5bc254c96d7c9f9e924a9bc7825ad"},
|
||||
{file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ec4e55f79b1c4ffb2eecd8a0cfba9955a2588497d96851f4c8f99aa4a1d39b12"},
|
||||
{file = "pydantic_core-2.23.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:374a5e5049eda9e0a44c696c7ade3ff355f06b1fe0bb945ea3cac2bc336478a2"},
|
||||
{file = "pydantic_core-2.23.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5c364564d17da23db1106787675fc7af45f2f7b58b4173bfdd105564e132e6fb"},
|
||||
{file = "pydantic_core-2.23.4-cp38-none-win32.whl", hash = "sha256:d7a80d21d613eec45e3d41eb22f8f94ddc758a6c4720842dc74c0581f54993d6"},
|
||||
{file = "pydantic_core-2.23.4-cp38-none-win_amd64.whl", hash = "sha256:5f5ff8d839f4566a474a969508fe1c5e59c31c80d9e140566f9a37bba7b8d556"},
|
||||
{file = "pydantic_core-2.23.4-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:a4fa4fc04dff799089689f4fd502ce7d59de529fc2f40a2c8836886c03e0175a"},
|
||||
{file = "pydantic_core-2.23.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0a7df63886be5e270da67e0966cf4afbae86069501d35c8c1b3b6c168f42cb36"},
|
||||
{file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dcedcd19a557e182628afa1d553c3895a9f825b936415d0dbd3cd0bbcfd29b4b"},
|
||||
{file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5f54b118ce5de9ac21c363d9b3caa6c800341e8c47a508787e5868c6b79c9323"},
|
||||
{file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:86d2f57d3e1379a9525c5ab067b27dbb8a0642fb5d454e17a9ac434f9ce523e3"},
|
||||
{file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:de6d1d1b9e5101508cb37ab0d972357cac5235f5c6533d1071964c47139257df"},
|
||||
{file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1278e0d324f6908e872730c9102b0112477a7f7cf88b308e4fc36ce1bdb6d58c"},
|
||||
{file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9a6b5099eeec78827553827f4c6b8615978bb4b6a88e5d9b93eddf8bb6790f55"},
|
||||
{file = "pydantic_core-2.23.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:e55541f756f9b3ee346b840103f32779c695a19826a4c442b7954550a0972040"},
|
||||
{file = "pydantic_core-2.23.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a5c7ba8ffb6d6f8f2ab08743be203654bb1aaa8c9dcb09f82ddd34eadb695605"},
|
||||
{file = "pydantic_core-2.23.4-cp39-none-win32.whl", hash = "sha256:37b0fe330e4a58d3c58b24d91d1eb102aeec675a3db4c292ec3928ecd892a9a6"},
|
||||
{file = "pydantic_core-2.23.4-cp39-none-win_amd64.whl", hash = "sha256:1498bec4c05c9c787bde9125cfdcc63a41004ff167f495063191b863399b1a29"},
|
||||
{file = "pydantic_core-2.23.4-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f455ee30a9d61d3e1a15abd5068827773d6e4dc513e795f380cdd59932c782d5"},
|
||||
{file = "pydantic_core-2.23.4-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:1e90d2e3bd2c3863d48525d297cd143fe541be8bbf6f579504b9712cb6b643ec"},
|
||||
{file = "pydantic_core-2.23.4-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e203fdf807ac7e12ab59ca2bfcabb38c7cf0b33c41efeb00f8e5da1d86af480"},
|
||||
{file = "pydantic_core-2.23.4-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e08277a400de01bc72436a0ccd02bdf596631411f592ad985dcee21445bd0068"},
|
||||
{file = "pydantic_core-2.23.4-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f220b0eea5965dec25480b6333c788fb72ce5f9129e8759ef876a1d805d00801"},
|
||||
{file = "pydantic_core-2.23.4-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:d06b0c8da4f16d1d1e352134427cb194a0a6e19ad5db9161bf32b2113409e728"},
|
||||
{file = "pydantic_core-2.23.4-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:ba1a0996f6c2773bd83e63f18914c1de3c9dd26d55f4ac302a7efe93fb8e7433"},
|
||||
{file = "pydantic_core-2.23.4-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:9a5bce9d23aac8f0cf0836ecfc033896aa8443b501c58d0602dbfd5bd5b37753"},
|
||||
{file = "pydantic_core-2.23.4-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:78ddaaa81421a29574a682b3179d4cf9e6d405a09b99d93ddcf7e5239c742e21"},
|
||||
{file = "pydantic_core-2.23.4-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:883a91b5dd7d26492ff2f04f40fbb652de40fcc0afe07e8129e8ae779c2110eb"},
|
||||
{file = "pydantic_core-2.23.4-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88ad334a15b32a791ea935af224b9de1bf99bcd62fabf745d5f3442199d86d59"},
|
||||
{file = "pydantic_core-2.23.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:233710f069d251feb12a56da21e14cca67994eab08362207785cf8c598e74577"},
|
||||
{file = "pydantic_core-2.23.4-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:19442362866a753485ba5e4be408964644dd6a09123d9416c54cd49171f50744"},
|
||||
{file = "pydantic_core-2.23.4-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:624e278a7d29b6445e4e813af92af37820fafb6dcc55c012c834f9e26f9aaaef"},
|
||||
{file = "pydantic_core-2.23.4-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f5ef8f42bec47f21d07668a043f077d507e5bf4e668d5c6dfe6aaba89de1a5b8"},
|
||||
{file = "pydantic_core-2.23.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:aea443fffa9fbe3af1a9ba721a87f926fe548d32cab71d188a6ede77d0ff244e"},
|
||||
{file = "pydantic_core-2.23.4.tar.gz", hash = "sha256:2584f7cf844ac4d970fba483a717dbe10c1c1c96a969bf65d61ffe94df1b2863"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -1173,13 +1189,13 @@ typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0"
|
||||
|
||||
[[package]]
|
||||
name = "pydantic-settings"
|
||||
version = "2.5.2"
|
||||
version = "2.6.0"
|
||||
description = "Settings management using Pydantic"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pydantic_settings-2.5.2-py3-none-any.whl", hash = "sha256:2c912e55fd5794a59bf8c832b9de832dcfdf4778d79ff79b708744eed499a907"},
|
||||
{file = "pydantic_settings-2.5.2.tar.gz", hash = "sha256:f90b139682bee4d2065273d5185d71d37ea46cfe57e1b5ae184fc6a0b2484ca0"},
|
||||
{file = "pydantic_settings-2.6.0-py3-none-any.whl", hash = "sha256:4a819166f119b74d7f8c765196b165f95cc7487ce58ea27dec8a5a26be0970e0"},
|
||||
{file = "pydantic_settings-2.6.0.tar.gz", hash = "sha256:44a1804abffac9e6a30372bb45f6cafab945ef5af25e66b1c634c01dd39e0188"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -1253,6 +1269,24 @@ python-dateutil = ">=2.8.1,<3.0.0"
|
||||
typing-extensions = ">=4.12.2,<5.0.0"
|
||||
websockets = ">=11,<13"
|
||||
|
||||
[[package]]
|
||||
name = "redis"
|
||||
version = "5.1.1"
|
||||
description = "Python client for Redis database and key-value store"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "redis-5.1.1-py3-none-any.whl", hash = "sha256:f8ea06b7482a668c6475ae202ed8d9bcaa409f6e87fb77ed1043d912afd62e24"},
|
||||
{file = "redis-5.1.1.tar.gz", hash = "sha256:f6c997521fedbae53387307c5d0bf784d9acc28d9f1d058abeac566ec4dbed72"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\""}
|
||||
|
||||
[package.extras]
|
||||
hiredis = ["hiredis (>=3.0.0)"]
|
||||
ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==23.2.1)", "requests (>=2.31.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "requests"
|
||||
version = "2.32.3"
|
||||
@@ -1312,17 +1346,17 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "storage3"
|
||||
version = "0.7.7"
|
||||
version = "0.8.2"
|
||||
description = "Supabase Storage client for Python."
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.8"
|
||||
python-versions = "<4.0,>=3.9"
|
||||
files = [
|
||||
{file = "storage3-0.7.7-py3-none-any.whl", hash = "sha256:ed80a2546cd0b5c22e2c30ea71096db6c99268daf2958c603488e7d72efb8426"},
|
||||
{file = "storage3-0.7.7.tar.gz", hash = "sha256:9fba680cf761d139ad764f43f0e91c245d1ce1af2cc3afe716652f835f48f83e"},
|
||||
{file = "storage3-0.8.2-py3-none-any.whl", hash = "sha256:f2e995b18c77a2a9265d1a33047d43e4d6abb11eb3ca5067959f68281c305de3"},
|
||||
{file = "storage3-0.8.2.tar.gz", hash = "sha256:db05d3fe8fb73bd30c814c4c4749664f37a5dfc78b629e8c058ef558c2b89f5a"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
httpx = {version = ">=0.24,<0.28", extras = ["http2"]}
|
||||
httpx = {version = ">=0.26,<0.28", extras = ["http2"]}
|
||||
python-dateutil = ">=2.8.2,<3.0.0"
|
||||
typing-extensions = ">=4.2.0,<5.0.0"
|
||||
|
||||
@@ -1344,36 +1378,36 @@ test = ["pylint", "pytest", "pytest-black", "pytest-cov", "pytest-pylint"]
|
||||
|
||||
[[package]]
|
||||
name = "supabase"
|
||||
version = "2.7.4"
|
||||
version = "2.9.1"
|
||||
description = "Supabase client for Python."
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
files = [
|
||||
{file = "supabase-2.7.4-py3-none-any.whl", hash = "sha256:01815fbc30cac753933d4a44a2529fd13cb7634b56c705c65b12a02c8e75982b"},
|
||||
{file = "supabase-2.7.4.tar.gz", hash = "sha256:5a979c7711b3c5ce688514fa0afc015780522569494e1a9a9d25d03b7c3d654b"},
|
||||
{file = "supabase-2.9.1-py3-none-any.whl", hash = "sha256:a96f857a465712cb551679c1df66ba772c834f861756ce4aa2aa4cb703f6aeb7"},
|
||||
{file = "supabase-2.9.1.tar.gz", hash = "sha256:51fce39c9eb50573126dabb342541ec5e1f13e7476938768f4b0ccfdb8c522cd"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
gotrue = ">=1.3,<3.0"
|
||||
httpx = ">=0.24,<0.28"
|
||||
postgrest = ">=0.14,<0.17.0"
|
||||
gotrue = ">=2.9.0,<3.0.0"
|
||||
httpx = ">=0.26,<0.28"
|
||||
postgrest = ">=0.17.0,<0.18.0"
|
||||
realtime = ">=2.0.0,<3.0.0"
|
||||
storage3 = ">=0.5.3,<0.8.0"
|
||||
supafunc = ">=0.3.1,<0.6.0"
|
||||
storage3 = ">=0.8.0,<0.9.0"
|
||||
supafunc = ">=0.6.0,<0.7.0"
|
||||
|
||||
[[package]]
|
||||
name = "supafunc"
|
||||
version = "0.5.1"
|
||||
version = "0.6.2"
|
||||
description = "Library for Supabase Functions"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.8"
|
||||
python-versions = "<4.0,>=3.9"
|
||||
files = [
|
||||
{file = "supafunc-0.5.1-py3-none-any.whl", hash = "sha256:b05e99a2b41270211a3f90ec843c04c5f27a5618f2d2d2eb8e07f41eb962a910"},
|
||||
{file = "supafunc-0.5.1.tar.gz", hash = "sha256:1ae9dce6bd935939c561650e86abb676af9665ecf5d4ffc1c7ec3c4932c84334"},
|
||||
{file = "supafunc-0.6.2-py3-none-any.whl", hash = "sha256:101b30616b0a1ce8cf938eca1df362fa4cf1deacb0271f53ebbd674190fb0da5"},
|
||||
{file = "supafunc-0.6.2.tar.gz", hash = "sha256:c7dfa20db7182f7fe4ae436e94e05c06cd7ed98d697fed75d68c7b9792822adc"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
httpx = {version = ">=0.24,<0.28", extras = ["http2"]}
|
||||
httpx = {version = ">=0.26,<0.28", extras = ["http2"]}
|
||||
|
||||
[[package]]
|
||||
name = "typing-extensions"
|
||||
@@ -1690,4 +1724,4 @@ type = ["pytest-mypy"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.10,<4.0"
|
||||
content-hash = "e9b6e5d877eeb9c9f1ebc69dead1985d749facc160afbe61f3bf37e9a6e35aa5"
|
||||
content-hash = "44af7722ca3d2788fc817129ac43477b71eea9921d51502a63f755cb04e3f254"
|
||||
|
||||
@@ -8,13 +8,17 @@ packages = [{ include = "autogpt_libs" }]

[tool.poetry.dependencies]
colorama = "^0.4.6"
google-cloud-logging = "^3.8.0"
pydantic = "^2.8.2"
pydantic-settings = "^2.5.2"
expiringdict = "^1.2.2"
google-cloud-logging = "^3.11.3"
pydantic = "^2.9.2"
pydantic-settings = "^2.6.0"
pyjwt = "^2.8.0"
python = ">=3.10,<4.0"
python-dotenv = "^1.0.1"
supabase = "^2.7.2"
supabase = "^2.9.1"

[tool.poetry.group.dev.dependencies]
redis = "^5.0.8"

[build-system]
requires = ["poetry-core"]

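Note: `autogpt_libs` now pins `supabase` 2.9.x, which is what pulls in the stricter `gotrue`, `postgrest`, `storage3`, `supafunc`, and `httpx >=0.26,<0.28` ranges shown in the lock diff above. The bump tightens transitive constraints rather than changing the public client API. A minimal sketch of using the pinned client, for illustration only (not code from this commit; the table name is hypothetical, and the environment variables are the ones defined in the .env example below):

```python
# Illustrative sketch only: create a Supabase client with the pinned supabase-py 2.9.x API.
import os

from supabase import Client, create_client

url = os.environ["SUPABASE_URL"]                  # e.g. http://localhost:8000
key = os.environ["SUPABASE_SERVICE_ROLE_KEY"]     # service-role key from the .env example

client: Client = create_client(url, key)

# Hypothetical table name, just to show the PostgREST-backed query interface.
rows = client.table("agents").select("*").execute()
print(rows.data)
```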
@@ -12,18 +12,21 @@ REDIS_PORT=6379
REDIS_PASSWORD=password

ENABLE_CREDIT=false
APP_ENV="local"
# What environment things should be logged under: local dev or prod
APP_ENV=local
# What environment to behave as: "local" or "cloud"
BEHAVE_AS=local
PYRO_HOST=localhost
SENTRY_DSN=

## User auth with Supabase is required for any of the 3rd party integrations with auth to work.
ENABLE_AUTH=false
ENABLE_AUTH=true
SUPABASE_URL=http://localhost:8000
SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long

# For local development, you may need to set FRONTEND_BASE_URL for the OAuth flow for integrations to work.
# FRONTEND_BASE_URL=http://localhost:3000
FRONTEND_BASE_URL=http://localhost:3000

## == INTEGRATION CREDENTIALS == ##
# Each set of server side credentials is required for the corresponding 3rd party
@@ -36,6 +39,15 @@ SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
GITHUB_CLIENT_ID=
GITHUB_CLIENT_SECRET=

# Google OAuth App server credentials - https://console.cloud.google.com/apis/credentials, and enable gmail api and set scopes
# https://console.cloud.google.com/apis/credentials/consent ?project=<your_project_id>

# You'll need to add/enable the following scopes (minimum):
# https://console.developers.google.com/apis/api/gmail.googleapis.com/overview ?project=<your_project_id>
# https://console.cloud.google.com/apis/library/sheets.googleapis.com/ ?project=<your_project_id>
GOOGLE_CLIENT_ID=
GOOGLE_CLIENT_SECRET=

## ===== OPTIONAL API KEYS ===== ##

# LLM
@@ -74,6 +86,14 @@ SMTP_PASSWORD=
MEDIUM_API_KEY=
MEDIUM_AUTHOR_ID=

# Google Maps
GOOGLE_MAPS_API_KEY=

# Replicate
REPLICATE_API_KEY=

# Ideogram
IDEOGRAM_API_KEY=

# Logging Configuration
LOG_LEVEL=INFO

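Note: the `pydantic-settings` dependency bumped above is the usual way such a .env file gets consumed on the backend. As a hedged illustration only (the class and field names below are hypothetical, not the platform's real settings module), the new `APP_ENV` / `BEHAVE_AS` / `ENABLE_AUTH` style variables can be loaded like this:

```python
# Minimal pydantic-settings sketch, assuming a .env file like the example above.
# AppSettings and its fields are illustrative, not code from this commit.
from pydantic_settings import BaseSettings, SettingsConfigDict


class AppSettings(BaseSettings):
    model_config = SettingsConfigDict(env_file=".env", extra="ignore")

    app_env: str = "local"                              # APP_ENV
    behave_as: str = "local"                            # BEHAVE_AS
    enable_auth: bool = True                            # ENABLE_AUTH
    supabase_url: str = "http://localhost:8000"         # SUPABASE_URL
    supabase_jwt_secret: str = ""                       # SUPABASE_JWT_SECRET
    frontend_base_url: str = "http://localhost:3000"    # FRONTEND_BASE_URL


settings = AppSettings()  # env var matching is case-insensitive by default
print(settings.app_env, settings.behave_as, settings.enable_auth)
```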
@@ -8,7 +8,7 @@ WORKDIR /app

# Install build dependencies
RUN apt-get update \
&& apt-get install -y build-essential curl ffmpeg wget libcurl4-gnutls-dev libexpat1-dev gettext libz-dev libssl-dev postgresql-client git \
&& apt-get install -y build-essential curl ffmpeg wget libcurl4-gnutls-dev libexpat1-dev libpq5 gettext libz-dev libssl-dev postgresql-client git \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*

@@ -37,7 +37,7 @@ We use the Poetry to manage the dependencies. To set up the project, follow thes
5. Generate the Prisma client

```sh
poetry run prisma generate --schema postgres/schema.prisma
poetry run prisma generate
```

@@ -61,7 +61,7 @@ We use the Poetry to manage the dependencies. To set up the project, follow thes

```sh
cd ../backend
prisma migrate dev --schema postgres/schema.prisma
prisma migrate deploy
```

## Running The Server

@@ -58,17 +58,18 @@ We use the Poetry to manage the dependencies. To set up the project, follow thes
6. Migrate the database. Be careful because this deletes current data in the database.

```sh
docker compose up db redis -d
poetry run prisma migrate dev
docker compose up db -d
poetry run prisma migrate deploy
```

## Running The Server

### Starting the server without Docker

Run the following command to build the dockerfiles:
Run the following command to run database in docker but the application locally:

```sh
docker compose --profile local up deps --build --detach
poetry run app
```

@@ -24,10 +24,12 @@ def main(**kwargs):
    Run all the processes required for the AutoGPT-server (REST and WebSocket APIs).
    """

    from backend.executor import ExecutionManager, ExecutionScheduler
    from backend.server import AgentServer, WebsocketServer
    from backend.executor import DatabaseManager, ExecutionManager, ExecutionScheduler
    from backend.server.rest_api import AgentServer
    from backend.server.ws_api import WebsocketServer

    run_processes(
        DatabaseManager(),
        ExecutionManager(),
        ExecutionScheduler(),
        WebsocketServer(),

@@ -2,6 +2,7 @@ import importlib
import os
import re
from pathlib import Path
from typing import Type, TypeVar

from backend.data.block import Block

@@ -24,28 +25,31 @@ for module in modules:
    AVAILABLE_MODULES.append(module)

# Load all Block instances from the available modules
AVAILABLE_BLOCKS = {}
AVAILABLE_BLOCKS: dict[str, Type[Block]] = {}


def all_subclasses(clz):
    subclasses = clz.__subclasses__()
T = TypeVar("T")


def all_subclasses(cls: Type[T]) -> list[Type[T]]:
    subclasses = cls.__subclasses__()
    for subclass in subclasses:
        subclasses += all_subclasses(subclass)
    return subclasses


for cls in all_subclasses(Block):
    name = cls.__name__
for block_cls in all_subclasses(Block):
    name = block_cls.__name__

    if cls.__name__.endswith("Base"):
    if block_cls.__name__.endswith("Base"):
        continue

    if not cls.__name__.endswith("Block"):
    if not block_cls.__name__.endswith("Block"):
        raise ValueError(
            f"Block class {cls.__name__} does not end with 'Block', If you are creating an abstract class, please name the class with 'Base' at the end"
            f"Block class {block_cls.__name__} does not end with 'Block', If you are creating an abstract class, please name the class with 'Base' at the end"
        )

    block = cls()
    block = block_cls.create()

    if not isinstance(block.id, str) or len(block.id) != 36:
        raise ValueError(f"Block ID {block.name} error: {block.id} is not a valid UUID")
@@ -53,15 +57,33 @@ for cls in all_subclasses(Block):
    if block.id in AVAILABLE_BLOCKS:
        raise ValueError(f"Block ID {block.name} error: {block.id} is already in use")

    input_schema = block.input_schema.model_fields
    output_schema = block.output_schema.model_fields

    # Prevent duplicate field name in input_schema and output_schema
    duplicate_field_names = set(block.input_schema.model_fields.keys()) & set(
        block.output_schema.model_fields.keys()
    )
    duplicate_field_names = set(input_schema.keys()) & set(output_schema.keys())
    if duplicate_field_names:
        raise ValueError(
            f"{block.name} has duplicate field names in input_schema and output_schema: {duplicate_field_names}"
        )

    # Make sure `error` field is a string in the output schema
    if "error" in output_schema and output_schema["error"].annotation is not str:
        raise ValueError(
            f"{block.name} `error` field in output_schema must be a string"
        )

    # Make sure all fields in input_schema and output_schema are annotated and has a value
    for field_name, field in [*input_schema.items(), *output_schema.items()]:
        if field.annotation is None:
            raise ValueError(
                f"{block.name} has a field {field_name} that is not annotated"
            )
        if field.json_schema_extra is None:
            raise ValueError(
                f"{block.name} has a field {field_name} not defined as SchemaField"
            )

    for field in block.input_schema.model_fields.values():
        if field.annotation is bool and field.default not in (True, False):
            raise ValueError(f"{block.name} has a boolean field with no default value")
@@ -69,6 +91,6 @@ for cls in all_subclasses(Block):
    if block.disabled:
        continue

    AVAILABLE_BLOCKS[block.id] = block
    AVAILABLE_BLOCKS[block.id] = block_cls

__all__ = ["AVAILABLE_MODULES", "AVAILABLE_BLOCKS"]

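Note: the loader hunk above tightens what every registered block must satisfy: the class name must end in `Block` (or `Base` for abstract classes), `block.id` must be a 36-character UUID string, every input/output field must be declared via `SchemaField` (so `json_schema_extra` is set), an `error` output must be typed `str`, and boolean inputs need an explicit default. The sketch below is a minimal, hypothetical block written against those rules; the UUID, field names, and the `BlockCategory.BASIC` category are assumptions for illustration, not a block from this commit:

```python
# Hypothetical example block that would pass the loader checks above (illustration only).
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class EchoTextBlock(Block):
    class Input(BlockSchema):
        text: str = SchemaField(description="Text to echo back")
        shout: bool = SchemaField(description="Uppercase the output", default=False)

    class Output(BlockSchema):
        echoed: str = SchemaField(description="The echoed text")
        error: str = SchemaField(description="Error message if the block fails")  # must be str

    def __init__(self):
        super().__init__(
            id="11111111-2222-3333-4444-555555555555",  # any fixed 36-character UUID string
            description="Echoes its input text (illustration only)",
            categories={BlockCategory.BASIC},  # assumed category
            input_schema=EchoTextBlock.Input,
            output_schema=EchoTextBlock.Output,
        )

    def run(self, input_data: Input, **kwargs) -> BlockOutput:
        # Yield (output_name, value) pairs, matching the block pattern used elsewhere in this diff.
        yield "echoed", input_data.text.upper() if input_data.shout else input_data.text
```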
@@ -0,0 +1,298 @@
|
||||
import logging
|
||||
import time
|
||||
from enum import Enum
|
||||
|
||||
import requests
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import BlockSecret, SchemaField, SecretField
|
||||
|
||||
|
||||
class AudioTrack(str, Enum):
|
||||
OBSERVER = ("Observer",)
|
||||
FUTURISTIC_BEAT = ("Futuristic Beat",)
|
||||
SCIENCE_DOCUMENTARY = ("Science Documentary",)
|
||||
HOTLINE = ("Hotline",)
|
||||
BLADERUNNER_2049 = ("Bladerunner 2049",)
|
||||
A_FUTURE = ("A Future",)
|
||||
ELYSIAN_EMBERS = ("Elysian Embers",)
|
||||
INSPIRING_CINEMATIC = ("Inspiring Cinematic",)
|
||||
BLADERUNNER_REMIX = ("Bladerunner Remix",)
|
||||
IZZAMUZZIC = ("Izzamuzzic",)
|
||||
NAS = ("Nas",)
|
||||
PARIS_ELSE = ("Paris - Else",)
|
||||
SNOWFALL = ("Snowfall",)
|
||||
BURLESQUE = ("Burlesque",)
|
||||
CORNY_CANDY = ("Corny Candy",)
|
||||
HIGHWAY_NOCTURNE = ("Highway Nocturne",)
|
||||
I_DONT_THINK_SO = ("I Don't Think So",)
|
||||
LOSING_YOUR_MARBLES = ("Losing Your Marbles",)
|
||||
REFRESHER = ("Refresher",)
|
||||
TOURIST = ("Tourist",)
|
||||
TWIN_TYCHES = ("Twin Tyches",)
|
||||
|
||||
@property
|
||||
def audio_url(self):
|
||||
audio_urls = {
|
||||
AudioTrack.OBSERVER: "https://cdn.tfrv.xyz/audio/observer.mp3",
|
||||
AudioTrack.FUTURISTIC_BEAT: "https://cdn.tfrv.xyz/audio/_futuristic-beat.mp3",
|
||||
AudioTrack.SCIENCE_DOCUMENTARY: "https://cdn.tfrv.xyz/audio/_science-documentary.mp3",
|
||||
AudioTrack.HOTLINE: "https://cdn.tfrv.xyz/audio/_hotline.mp3",
|
||||
AudioTrack.BLADERUNNER_2049: "https://cdn.tfrv.xyz/audio/_bladerunner-2049.mp3",
|
||||
AudioTrack.A_FUTURE: "https://cdn.tfrv.xyz/audio/a-future.mp3",
|
||||
AudioTrack.ELYSIAN_EMBERS: "https://cdn.tfrv.xyz/audio/elysian-embers.mp3",
|
||||
AudioTrack.INSPIRING_CINEMATIC: "https://cdn.tfrv.xyz/audio/inspiring-cinematic-ambient.mp3",
|
||||
AudioTrack.BLADERUNNER_REMIX: "https://cdn.tfrv.xyz/audio/bladerunner-remix.mp3",
|
||||
AudioTrack.IZZAMUZZIC: "https://cdn.tfrv.xyz/audio/_izzamuzzic.mp3",
|
||||
AudioTrack.NAS: "https://cdn.tfrv.xyz/audio/_nas.mp3",
|
||||
AudioTrack.PARIS_ELSE: "https://cdn.tfrv.xyz/audio/_paris-else.mp3",
|
||||
AudioTrack.SNOWFALL: "https://cdn.tfrv.xyz/audio/_snowfall.mp3",
|
||||
AudioTrack.BURLESQUE: "https://cdn.tfrv.xyz/audio/burlesque.mp3",
|
||||
AudioTrack.CORNY_CANDY: "https://cdn.tfrv.xyz/audio/corny-candy.mp3",
|
||||
AudioTrack.HIGHWAY_NOCTURNE: "https://cdn.tfrv.xyz/audio/highway-nocturne.mp3",
|
||||
AudioTrack.I_DONT_THINK_SO: "https://cdn.tfrv.xyz/audio/i-dont-think-so.mp3",
|
||||
AudioTrack.LOSING_YOUR_MARBLES: "https://cdn.tfrv.xyz/audio/losing-your-marbles.mp3",
|
||||
AudioTrack.REFRESHER: "https://cdn.tfrv.xyz/audio/refresher.mp3",
|
||||
AudioTrack.TOURIST: "https://cdn.tfrv.xyz/audio/tourist.mp3",
|
||||
AudioTrack.TWIN_TYCHES: "https://cdn.tfrv.xyz/audio/twin-tynches.mp3",
|
||||
}
|
||||
return audio_urls[self]
|
||||
|
||||
|
||||
class GenerationPreset(str, Enum):
|
||||
LEONARDO = ("Default",)
|
||||
ANIME = ("Anime",)
|
||||
REALISM = ("Realist",)
|
||||
ILLUSTRATION = ("Illustration",)
|
||||
SKETCH_COLOR = ("Sketch Color",)
|
||||
SKETCH_BW = ("Sketch B&W",)
|
||||
PIXAR = ("Pixar",)
|
||||
INK = ("Japanese Ink",)
|
||||
RENDER_3D = ("3D Render",)
|
||||
LEGO = ("Lego",)
|
||||
SCIFI = ("Sci-Fi",)
|
||||
RECRO_CARTOON = ("Retro Cartoon",)
|
||||
PIXEL_ART = ("Pixel Art",)
|
||||
CREATIVE = ("Creative",)
|
||||
PHOTOGRAPHY = ("Photography",)
|
||||
RAYTRACED = ("Raytraced",)
|
||||
ENVIRONMENT = ("Environment",)
|
||||
FANTASY = ("Fantasy",)
|
||||
ANIME_SR = ("Anime Realism",)
|
||||
MOVIE = ("Movie",)
|
||||
STYLIZED_ILLUSTRATION = ("Stylized Illustration",)
|
||||
MANGA = ("Manga",)
|
||||
|
||||
|
||||
class Voice(str, Enum):
|
||||
LILY = "Lily"
|
||||
DANIEL = "Daniel"
|
||||
BRIAN = "Brian"
|
||||
JESSICA = "Jessica"
|
||||
CHARLOTTE = "Charlotte"
|
||||
CALLUM = "Callum"
|
||||
|
||||
@property
|
||||
def voice_id(self):
|
||||
voice_id_map = {
|
||||
Voice.LILY: "pFZP5JQG7iQjIQuC4Bku",
|
||||
Voice.DANIEL: "onwK4e9ZLuTAKqWW03F9",
|
||||
Voice.BRIAN: "nPczCjzI2devNBz1zQrb",
|
||||
Voice.JESSICA: "cgSgspJ2msm6clMCkdW9",
|
||||
Voice.CHARLOTTE: "XB0fDUnXU5powFXDhCwa",
|
||||
Voice.CALLUM: "N2lVS1w4EtoT3dr4eOWO",
|
||||
}
|
||||
return voice_id_map[self]
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
|
||||
class VisualMediaType(str, Enum):
|
||||
STOCK_VIDEOS = ("stockVideo",)
|
||||
MOVING_AI_IMAGES = ("movingImage",)
|
||||
AI_VIDEO = ("aiVideo",)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AIShortformVideoCreatorBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
api_key: BlockSecret = SecretField(
|
||||
key="revid_api_key",
|
||||
description="Your revid.ai API key",
|
||||
placeholder="Enter your revid.ai API key",
|
||||
)
|
||||
script: str = SchemaField(
|
||||
description="""1. Use short and punctuated sentences\n\n2. Use linebreaks to create a new clip\n\n3. Text outside of brackets is spoken by the AI, and [text between brackets] will be used to guide the visual generation. For example, [close-up of a cat] will show a close-up of a cat.""",
|
||||
placeholder="[close-up of a cat] Meow!",
|
||||
)
|
||||
ratio: str = SchemaField(
|
||||
description="Aspect ratio of the video", default="9 / 16"
|
||||
)
|
||||
resolution: str = SchemaField(
|
||||
description="Resolution of the video", default="720p"
|
||||
)
|
||||
frame_rate: int = SchemaField(description="Frame rate of the video", default=60)
|
||||
generation_preset: GenerationPreset = SchemaField(
|
||||
description="Generation preset for visual style - only effects AI generated visuals",
|
||||
default=GenerationPreset.LEONARDO,
|
||||
placeholder=GenerationPreset.LEONARDO,
|
||||
)
|
||||
background_music: AudioTrack = SchemaField(
|
||||
description="Background music track",
|
||||
default=AudioTrack.HIGHWAY_NOCTURNE,
|
||||
placeholder=AudioTrack.HIGHWAY_NOCTURNE,
|
||||
)
|
||||
voice: Voice = SchemaField(
|
||||
description="AI voice to use for narration",
|
||||
default=Voice.LILY,
|
||||
placeholder=Voice.LILY,
|
||||
)
|
||||
video_style: VisualMediaType = SchemaField(
|
||||
description="Type of visual media to use for the video",
|
||||
default=VisualMediaType.STOCK_VIDEOS,
|
||||
placeholder=VisualMediaType.STOCK_VIDEOS,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
video_url: str = SchemaField(description="The URL of the created video")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
        super().__init__(
            id="361697fb-0c4f-4feb-aed3-8320c88c771b",
            description="Creates a shortform video using revid.ai",
            categories={BlockCategory.SOCIAL, BlockCategory.AI},
            input_schema=AIShortformVideoCreatorBlock.Input,
            output_schema=AIShortformVideoCreatorBlock.Output,
            test_input={
                "api_key": "test_api_key",
                "script": "[close-up of a cat] Meow!",
                "ratio": "9 / 16",
                "resolution": "720p",
                "frame_rate": 60,
                "generation_preset": GenerationPreset.LEONARDO,
                "background_music": AudioTrack.HIGHWAY_NOCTURNE,
                "voice": Voice.LILY,
                "video_style": VisualMediaType.STOCK_VIDEOS,
            },
            test_output=(
                "video_url",
                "https://example.com/video.mp4",
            ),
            test_mock={
                "create_webhook": lambda: (
                    "test_uuid",
                    "https://webhook.site/test_uuid",
                ),
                "create_video": lambda api_key, payload: {"pid": "test_pid"},
                "wait_for_video": lambda api_key, pid, webhook_token, max_wait_time=1000: "https://example.com/video.mp4",
            },
        )

    def create_webhook(self):
        url = "https://webhook.site/token"
        headers = {"Accept": "application/json", "Content-Type": "application/json"}
        response = requests.post(url, headers=headers)
        response.raise_for_status()
        webhook_data = response.json()
        return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"

    def create_video(self, api_key: str, payload: dict) -> dict:
        url = "https://www.revid.ai/api/public/v2/render"
        headers = {"key": api_key}
        response = requests.post(url, json=payload, headers=headers)
        logger.debug(
            f"API Response Status Code: {response.status_code}, Content: {response.text}"
        )
        response.raise_for_status()
        return response.json()

    def check_video_status(self, api_key: str, pid: str) -> dict:
        url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
        headers = {"key": api_key}
        response = requests.get(url, headers=headers)
        response.raise_for_status()
        return response.json()

    def wait_for_video(
        self, api_key: str, pid: str, webhook_token: str, max_wait_time: int = 1000
    ) -> str:
        start_time = time.time()
        while time.time() - start_time < max_wait_time:
            status = self.check_video_status(api_key, pid)
            logger.debug(f"Video status: {status}")

            if status.get("status") == "ready" and "videoUrl" in status:
                return status["videoUrl"]
            elif status.get("status") == "error":
                error_message = status.get("error", "Unknown error occurred")
                logger.error(f"Video creation failed: {error_message}")
                raise ValueError(f"Video creation failed: {error_message}")
            elif status.get("status") in ["FAILED", "CANCELED"]:
                logger.error(f"Video creation failed: {status.get('message')}")
                raise ValueError(f"Video creation failed: {status.get('message')}")

            time.sleep(10)

        logger.error("Video creation timed out")
        raise TimeoutError("Video creation timed out")

    def run(self, input_data: Input, **kwargs) -> BlockOutput:
        # Create a new Webhook.site URL
        webhook_token, webhook_url = self.create_webhook()
        logger.debug(f"Webhook URL: {webhook_url}")

        audio_url = input_data.background_music.audio_url

        payload = {
            "frameRate": input_data.frame_rate,
            "resolution": input_data.resolution,
            "frameDurationMultiplier": 18,
            "webhook": webhook_url,
            "creationParams": {
                "mediaType": input_data.video_style,
                "captionPresetName": "Wrap 1",
                "selectedVoice": input_data.voice.voice_id,
                "hasEnhancedGeneration": True,
                "generationPreset": input_data.generation_preset.name,
                "selectedAudio": input_data.background_music,
                "origin": "/create",
                "inputText": input_data.script,
                "flowType": "text-to-video",
                "slug": "create-tiktok-video",
                "hasToGenerateVoice": True,
                "hasToTranscript": False,
                "hasToSearchMedia": True,
                "hasAvatar": False,
                "hasWebsiteRecorder": False,
                "hasTextSmallAtBottom": False,
                "ratio": input_data.ratio,
                "sourceType": "contentScraping",
                "selectedStoryStyle": {"value": "custom", "label": "Custom"},
                "hasToGenerateVideos": input_data.video_style
                != VisualMediaType.STOCK_VIDEOS,
                "audioUrl": audio_url,
            },
        }

        logger.debug("Creating video...")
        response = self.create_video(input_data.api_key.get_secret_value(), payload)
        pid = response.get("pid")

        if not pid:
            logger.error(
                f"Failed to create video: No project ID returned. API Response: {response}"
            )
            raise RuntimeError("Failed to create video: No project ID returned")
        else:
            logger.debug(
                f"Video created with project ID: {pid}. Waiting for completion..."
            )
            video_url = self.wait_for_video(
                input_data.api_key.get_secret_value(), pid, webhook_token
            )
            logger.debug(f"Video ready: {video_url}")
            yield "video_url", video_url
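For context, the block above drives revid.ai in two steps: it POSTs a render job and then polls the status endpoint until the video is ready or the wait times out. A minimal standalone sketch of that same create-then-poll flow, assuming only the two endpoints used above; the API key and payload here are placeholders, not values from the diff:

import time
import requests

API_KEY = "YOUR_REVID_API_KEY"  # placeholder, not a real key

def render_and_wait(payload: dict, max_wait_time: int = 1000) -> str:
    # Submit the render job (same endpoint as create_video above).
    response = requests.post(
        "https://www.revid.ai/api/public/v2/render",
        json=payload,
        headers={"key": API_KEY},
    )
    response.raise_for_status()
    pid = response.json()["pid"]

    # Poll the status endpoint until the video is ready (same fields as wait_for_video above).
    start = time.time()
    while time.time() - start < max_wait_time:
        status = requests.get(
            f"https://www.revid.ai/api/public/v2/status?pid={pid}",
            headers={"key": API_KEY},
        ).json()
        if status.get("status") == "ready" and "videoUrl" in status:
            return status["videoUrl"]
        if status.get("status") in ("error", "FAILED", "CANCELED"):
            raise ValueError(f"Video creation failed: {status}")
        time.sleep(10)
    raise TimeoutError("Video creation timed out")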
@@ -2,7 +2,6 @@ import re
from typing import Any, List

from jinja2 import BaseLoader, Environment
from pydantic import Field

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
from backend.data.model import SchemaField
@@ -19,18 +18,18 @@ class StoreValueBlock(Block):
    """

    class Input(BlockSchema):
        input: Any = Field(
        input: Any = SchemaField(
            description="Trigger the block to produce the output. "
            "The value is only used when `data` is None."
        )
        data: Any = Field(
        data: Any = SchemaField(
            description="The constant data to be retained in the block. "
            "This value is passed as `output`.",
            default=None,
        )

    class Output(BlockSchema):
        output: Any
        output: Any = SchemaField(description="The stored data retained in the block.")

    def __init__(self):
        super().__init__(
@@ -56,10 +55,10 @@ class StoreValueBlock(Block):

class PrintToConsoleBlock(Block):
    class Input(BlockSchema):
        text: str
        text: str = SchemaField(description="The text to print to the console.")

    class Output(BlockSchema):
        status: str
        status: str = SchemaField(description="The status of the print operation.")

    def __init__(self):
        super().__init__(
@@ -79,16 +78,18 @@ class PrintToConsoleBlock(Block):

class FindInDictionaryBlock(Block):
    class Input(BlockSchema):
        input: Any = Field(description="Dictionary to lookup from")
        key: str | int = Field(description="Key to lookup in the dictionary")
        input: Any = SchemaField(description="Dictionary to lookup from")
        key: str | int = SchemaField(description="Key to lookup in the dictionary")

    class Output(BlockSchema):
        output: Any = Field(description="Value found for the given key")
        missing: Any = Field(description="Value of the input that missing the key")
        output: Any = SchemaField(description="Value found for the given key")
        missing: Any = SchemaField(
            description="Value of the input that missing the key"
        )

    def __init__(self):
        super().__init__(
            id="b2g2c3d4-5e6f-7g8h-9i0j-k1l2m3n4o5p6",
            id="0e50422c-6dee-4145-83d6-3a5a392f65de",
            description="Lookup the given key in the input dictionary/object/list and return the value.",
            input_schema=FindInDictionaryBlock.Input,
            output_schema=FindInDictionaryBlock.Output,
@@ -330,20 +331,17 @@ class AddToDictionaryBlock(Block):
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
# If no dictionary is provided, create a new one
|
||||
if input_data.dictionary is None:
|
||||
updated_dict = {}
|
||||
else:
|
||||
# Create a copy of the input dictionary to avoid modifying the original
|
||||
updated_dict = input_data.dictionary.copy()
|
||||
# If no dictionary is provided, create a new one
|
||||
if input_data.dictionary is None:
|
||||
updated_dict = {}
|
||||
else:
|
||||
# Create a copy of the input dictionary to avoid modifying the original
|
||||
updated_dict = input_data.dictionary.copy()
|
||||
|
||||
# Add the new key-value pair
|
||||
updated_dict[input_data.key] = input_data.value
|
||||
# Add the new key-value pair
|
||||
updated_dict[input_data.key] = input_data.value
|
||||
|
||||
yield "updated_dictionary", updated_dict
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to add entry to dictionary: {str(e)}"
|
||||
yield "updated_dictionary", updated_dict
|
||||
|
||||
|
||||
class AddToListBlock(Block):
|
||||
@@ -401,23 +399,20 @@ class AddToListBlock(Block):
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
# If no list is provided, create a new one
|
||||
if input_data.list is None:
|
||||
updated_list = []
|
||||
else:
|
||||
# Create a copy of the input list to avoid modifying the original
|
||||
updated_list = input_data.list.copy()
|
||||
# If no list is provided, create a new one
|
||||
if input_data.list is None:
|
||||
updated_list = []
|
||||
else:
|
||||
# Create a copy of the input list to avoid modifying the original
|
||||
updated_list = input_data.list.copy()
|
||||
|
||||
# Add the new entry
|
||||
if input_data.position is None:
|
||||
updated_list.append(input_data.entry)
|
||||
else:
|
||||
updated_list.insert(input_data.position, input_data.entry)
|
||||
# Add the new entry
|
||||
if input_data.position is None:
|
||||
updated_list.append(input_data.entry)
|
||||
else:
|
||||
updated_list.insert(input_data.position, input_data.entry)
|
||||
|
||||
yield "updated_list", updated_list
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to add entry to list: {str(e)}"
|
||||
yield "updated_list", updated_list
|
||||
|
||||
|
||||
class NoteBlock(Block):
|
||||
@@ -429,7 +424,7 @@ class NoteBlock(Block):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="31d1064e-7446-4693-o7d4-65e5ca9110d1",
|
||||
id="cc10ff7b-7753-4ff2-9af6-9399b1a7eddc",
|
||||
description="This block is used to display a sticky note with the given text.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=NoteBlock.Input,
|
||||
|
||||
@@ -3,6 +3,7 @@ import re
|
||||
from typing import Type
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class BlockInstallationBlock(Block):
|
||||
@@ -15,11 +16,17 @@ class BlockInstallationBlock(Block):
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
code: str
|
||||
code: str = SchemaField(
|
||||
description="Python code of the block to be installed",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: str
|
||||
error: str
|
||||
success: str = SchemaField(
|
||||
description="Success message if the block is installed successfully",
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the block installation fails",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -37,14 +44,12 @@ class BlockInstallationBlock(Block):
|
||||
if search := re.search(r"class (\w+)\(Block\):", code):
|
||||
class_name = search.group(1)
|
||||
else:
|
||||
yield "error", "No class found in the code."
|
||||
return
|
||||
raise RuntimeError("No class found in the code.")
|
||||
|
||||
if search := re.search(r"id=\"(\w+-\w+-\w+-\w+-\w+)\"", code):
|
||||
file_name = search.group(1)
|
||||
else:
|
||||
yield "error", "No UUID found in the code."
|
||||
return
|
||||
raise RuntimeError("No UUID found in the code.")
|
||||
|
||||
block_dir = os.path.dirname(__file__)
|
||||
file_path = f"{block_dir}/{file_name}.py"
|
||||
@@ -63,4 +68,4 @@ class BlockInstallationBlock(Block):
|
||||
yield "success", "Block installed successfully."
|
||||
except Exception as e:
|
||||
os.remove(file_path)
|
||||
yield "error", f"[Code]\n{code}\n\n[Error]\n{str(e)}"
|
||||
raise RuntimeError(f"[Code]\n{code}\n\n[Error]\n{str(e)}")
|
||||
|
||||
@@ -1,21 +1,49 @@
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import ContributorDetails
|
||||
from backend.data.model import ContributorDetails, SchemaField
|
||||
|
||||
|
||||
class ReadCsvBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
contents: str
|
||||
delimiter: str = ","
|
||||
quotechar: str = '"'
|
||||
escapechar: str = "\\"
|
||||
has_header: bool = True
|
||||
skip_rows: int = 0
|
||||
strip: bool = True
|
||||
skip_columns: list[str] = []
|
||||
contents: str = SchemaField(
|
||||
description="The contents of the CSV file to read",
|
||||
placeholder="a, b, c\n1,2,3\n4,5,6",
|
||||
)
|
||||
delimiter: str = SchemaField(
|
||||
description="The delimiter used in the CSV file",
|
||||
default=",",
|
||||
)
|
||||
quotechar: str = SchemaField(
|
||||
description="The character used to quote fields",
|
||||
default='"',
|
||||
)
|
||||
escapechar: str = SchemaField(
|
||||
description="The character used to escape the delimiter",
|
||||
default="\\",
|
||||
)
|
||||
has_header: bool = SchemaField(
|
||||
description="Whether the CSV file has a header row",
|
||||
default=True,
|
||||
)
|
||||
skip_rows: int = SchemaField(
|
||||
description="The number of rows to skip from the start of the file",
|
||||
default=0,
|
||||
)
|
||||
strip: bool = SchemaField(
|
||||
description="Whether to strip whitespace from the values",
|
||||
default=True,
|
||||
)
|
||||
skip_columns: list[str] = SchemaField(
|
||||
description="The columns to skip from the start of the row",
|
||||
default=[],
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
row: dict[str, str]
|
||||
all_data: list[dict[str, str]]
|
||||
row: dict[str, str] = SchemaField(
|
||||
description="The data produced from each row in the CSV file"
|
||||
)
|
||||
all_data: list[dict[str, str]] = SchemaField(
|
||||
description="All the data in the CSV file as a list of rows"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -24,7 +52,7 @@ class ReadCsvBlock(Block):
|
||||
output_schema=ReadCsvBlock.Output,
|
||||
description="Reads a CSV file and outputs the data as a list of dictionaries and individual rows via rows.",
|
||||
contributors=[ContributorDetails(name="Nicholas Tindle")],
|
||||
categories={BlockCategory.TEXT},
|
||||
categories={BlockCategory.TEXT, BlockCategory.DATA},
|
||||
test_input={
|
||||
"contents": "a, b, c\n1,2,3\n4,5,6",
|
||||
},
|
||||
|
||||
39
autogpt_platform/backend/backend/blocks/decoder_block.py
Normal file
@@ -0,0 +1,39 @@
import codecs

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class TextDecoderBlock(Block):
    class Input(BlockSchema):
        text: str = SchemaField(
            description="A string containing escaped characters to be decoded",
            placeholder='Your entire text block with \\n and \\" escaped characters',
        )

    class Output(BlockSchema):
        decoded_text: str = SchemaField(
            description="The decoded text with escape sequences processed"
        )

    def __init__(self):
        super().__init__(
            id="2570e8fe-8447-43ed-84c7-70d657923231",
            description="Decodes a string containing escape sequences into actual text",
            categories={BlockCategory.TEXT},
            input_schema=TextDecoderBlock.Input,
            output_schema=TextDecoderBlock.Output,
            test_input={"text": """Hello\nWorld!\nThis is a \"quoted\" string."""},
            test_output=[
                (
                    "decoded_text",
                    """Hello
World!
This is a "quoted" string.""",
                )
            ],
        )

    def run(self, input_data: Input, **kwargs) -> BlockOutput:
        decoded_text = codecs.decode(input_data.text, "unicode_escape")
        yield "decoded_text", decoded_text
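The decoder above leans entirely on codecs.decode(text, "unicode_escape"), which turns literal escape sequences such as \n and \" into the characters they name. A small illustrative snippet, separate from the diff itself:

import codecs

raw = 'Hello\\nWorld!\\nThis is a \\"quoted\\" string.'  # contains literal backslash escapes
print(codecs.decode(raw, "unicode_escape"))
# Hello
# World!
# This is a "quoted" string.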
@@ -2,10 +2,9 @@ import asyncio
|
||||
|
||||
import aiohttp
|
||||
import discord
|
||||
from pydantic import Field
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import BlockSecret, SecretField
|
||||
from backend.data.model import BlockSecret, SchemaField, SecretField
|
||||
|
||||
|
||||
class ReadDiscordMessagesBlock(Block):
|
||||
@@ -13,22 +12,24 @@ class ReadDiscordMessagesBlock(Block):
|
||||
discord_bot_token: BlockSecret = SecretField(
|
||||
key="discord_bot_token", description="Discord bot token"
|
||||
)
|
||||
continuous_read: bool = Field(
|
||||
continuous_read: bool = SchemaField(
|
||||
description="Whether to continuously read messages", default=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
message_content: str = Field(description="The content of the message received")
|
||||
channel_name: str = Field(
|
||||
message_content: str = SchemaField(
|
||||
description="The content of the message received"
|
||||
)
|
||||
channel_name: str = SchemaField(
|
||||
description="The name of the channel the message was received from"
|
||||
)
|
||||
username: str = Field(
|
||||
username: str = SchemaField(
|
||||
description="The username of the user who sent the message"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d3f4g5h6-1i2j-3k4l-5m6n-7o8p9q0r1s2t", # Unique ID for the node
|
||||
id="df06086a-d5ac-4abb-9996-2ad0acb2eff7",
|
||||
input_schema=ReadDiscordMessagesBlock.Input, # Assign input schema
|
||||
output_schema=ReadDiscordMessagesBlock.Output, # Assign output schema
|
||||
description="Reads messages from a Discord channel using a bot token.",
|
||||
@@ -134,19 +135,21 @@ class SendDiscordMessageBlock(Block):
|
||||
discord_bot_token: BlockSecret = SecretField(
|
||||
key="discord_bot_token", description="Discord bot token"
|
||||
)
|
||||
message_content: str = Field(description="The content of the message received")
|
||||
channel_name: str = Field(
|
||||
message_content: str = SchemaField(
|
||||
description="The content of the message received"
|
||||
)
|
||||
channel_name: str = SchemaField(
|
||||
description="The name of the channel the message was received from"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
status: str = Field(
|
||||
status: str = SchemaField(
|
||||
description="The status of the operation (e.g., 'Message sent', 'Error')"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="h1i2j3k4-5l6m-7n8o-9p0q-r1s2t3u4v5w6", # Unique ID for the node
|
||||
id="d0822ab5-9f8a-44a3-8971-531dd0178b6b",
|
||||
input_schema=SendDiscordMessageBlock.Input, # Assign input schema
|
||||
output_schema=SendDiscordMessageBlock.Output, # Assign output schema
|
||||
description="Sends a message to a Discord channel using a bot token.",
|
||||
|
||||
@@ -2,17 +2,17 @@ import smtplib
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import BlockSecret, SchemaField, SecretField
|
||||
|
||||
|
||||
class EmailCredentials(BaseModel):
|
||||
smtp_server: str = Field(
|
||||
smtp_server: str = SchemaField(
|
||||
default="smtp.gmail.com", description="SMTP server address"
|
||||
)
|
||||
smtp_port: int = Field(default=25, description="SMTP port number")
|
||||
smtp_port: int = SchemaField(default=25, description="SMTP port number")
|
||||
smtp_username: BlockSecret = SecretField(key="smtp_username")
|
||||
smtp_password: BlockSecret = SecretField(key="smtp_password")
|
||||
|
||||
@@ -30,7 +30,7 @@ class SendEmailBlock(Block):
|
||||
body: str = SchemaField(
|
||||
description="Body of the email", placeholder="Enter the email body"
|
||||
)
|
||||
creds: EmailCredentials = Field(
|
||||
creds: EmailCredentials = SchemaField(
|
||||
description="SMTP credentials",
|
||||
default=EmailCredentials(),
|
||||
)
|
||||
@@ -43,7 +43,7 @@ class SendEmailBlock(Block):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a1234567-89ab-cdef-0123-456789abcdef",
|
||||
id="4335878a-394e-4e67-adf2-919877ff49ae",
|
||||
description="This block sends an email using the provided SMTP credentials.",
|
||||
categories={BlockCategory.OUTPUT},
|
||||
input_schema=SendEmailBlock.Input,
|
||||
@@ -67,35 +67,28 @@ class SendEmailBlock(Block):
|
||||
def send_email(
|
||||
creds: EmailCredentials, to_email: str, subject: str, body: str
|
||||
) -> str:
|
||||
try:
|
||||
smtp_server = creds.smtp_server
|
||||
smtp_port = creds.smtp_port
|
||||
smtp_username = creds.smtp_username.get_secret_value()
|
||||
smtp_password = creds.smtp_password.get_secret_value()
|
||||
smtp_server = creds.smtp_server
|
||||
smtp_port = creds.smtp_port
|
||||
smtp_username = creds.smtp_username.get_secret_value()
|
||||
smtp_password = creds.smtp_password.get_secret_value()
|
||||
|
||||
msg = MIMEMultipart()
|
||||
msg["From"] = smtp_username
|
||||
msg["To"] = to_email
|
||||
msg["Subject"] = subject
|
||||
msg.attach(MIMEText(body, "plain"))
|
||||
msg = MIMEMultipart()
|
||||
msg["From"] = smtp_username
|
||||
msg["To"] = to_email
|
||||
msg["Subject"] = subject
|
||||
msg.attach(MIMEText(body, "plain"))
|
||||
|
||||
with smtplib.SMTP(smtp_server, smtp_port) as server:
|
||||
server.starttls()
|
||||
server.login(smtp_username, smtp_password)
|
||||
server.sendmail(smtp_username, to_email, msg.as_string())
|
||||
with smtplib.SMTP(smtp_server, smtp_port) as server:
|
||||
server.starttls()
|
||||
server.login(smtp_username, smtp_password)
|
||||
server.sendmail(smtp_username, to_email, msg.as_string())
|
||||
|
||||
return "Email sent successfully"
|
||||
except Exception as e:
|
||||
return f"Failed to send email: {str(e)}"
|
||||
return "Email sent successfully"
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
status = self.send_email(
|
||||
yield "status", self.send_email(
|
||||
input_data.creds,
|
||||
input_data.to_email,
|
||||
input_data.subject,
|
||||
input_data.body,
|
||||
)
|
||||
if "successfully" in status:
|
||||
yield "status", status
|
||||
else:
|
||||
yield "error", status
|
||||
|
||||
@@ -13,6 +13,7 @@ from ._auth import (
|
||||
)
|
||||
|
||||
|
||||
# --8<-- [start:GithubCommentBlockExample]
|
||||
class GithubCommentBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
@@ -92,16 +93,16 @@ class GithubCommentBlock(Block):
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
id, url = self.post_comment(
|
||||
credentials,
|
||||
input_data.issue_url,
|
||||
input_data.comment,
|
||||
)
|
||||
yield "id", id
|
||||
yield "url", url
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to post comment: {str(e)}"
|
||||
id, url = self.post_comment(
|
||||
credentials,
|
||||
input_data.issue_url,
|
||||
input_data.comment,
|
||||
)
|
||||
yield "id", id
|
||||
yield "url", url
|
||||
|
||||
|
||||
# --8<-- [end:GithubCommentBlockExample]
|
||||
|
||||
|
||||
class GithubMakeIssueBlock(Block):
|
||||
@@ -175,17 +176,14 @@ class GithubMakeIssueBlock(Block):
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
number, url = self.create_issue(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.title,
|
||||
input_data.body,
|
||||
)
|
||||
yield "number", number
|
||||
yield "url", url
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to create issue: {str(e)}"
|
||||
number, url = self.create_issue(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.title,
|
||||
input_data.body,
|
||||
)
|
||||
yield "number", number
|
||||
yield "url", url
|
||||
|
||||
|
||||
class GithubReadIssueBlock(Block):
|
||||
@@ -258,16 +256,13 @@ class GithubReadIssueBlock(Block):
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
title, body, user = self.read_issue(
|
||||
credentials,
|
||||
input_data.issue_url,
|
||||
)
|
||||
yield "title", title
|
||||
yield "body", body
|
||||
yield "user", user
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to read issue: {str(e)}"
|
||||
title, body, user = self.read_issue(
|
||||
credentials,
|
||||
input_data.issue_url,
|
||||
)
|
||||
yield "title", title
|
||||
yield "body", body
|
||||
yield "user", user
|
||||
|
||||
|
||||
class GithubListIssuesBlock(Block):
|
||||
@@ -346,14 +341,11 @@ class GithubListIssuesBlock(Block):
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
issues = self.list_issues(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
)
|
||||
yield from (("issue", issue) for issue in issues)
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to list issues: {str(e)}"
|
||||
issues = self.list_issues(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
)
|
||||
yield from (("issue", issue) for issue in issues)
|
||||
|
||||
|
||||
class GithubAddLabelBlock(Block):
|
||||
@@ -424,15 +416,12 @@ class GithubAddLabelBlock(Block):
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = self.add_label(
|
||||
credentials,
|
||||
input_data.issue_url,
|
||||
input_data.label,
|
||||
)
|
||||
yield "status", status
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to add label: {str(e)}"
|
||||
status = self.add_label(
|
||||
credentials,
|
||||
input_data.issue_url,
|
||||
input_data.label,
|
||||
)
|
||||
yield "status", status
|
||||
|
||||
|
||||
class GithubRemoveLabelBlock(Block):
|
||||
@@ -508,15 +497,12 @@ class GithubRemoveLabelBlock(Block):
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = self.remove_label(
|
||||
credentials,
|
||||
input_data.issue_url,
|
||||
input_data.label,
|
||||
)
|
||||
yield "status", status
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to remove label: {str(e)}"
|
||||
status = self.remove_label(
|
||||
credentials,
|
||||
input_data.issue_url,
|
||||
input_data.label,
|
||||
)
|
||||
yield "status", status
|
||||
|
||||
|
||||
class GithubAssignIssueBlock(Block):
|
||||
@@ -590,15 +576,12 @@ class GithubAssignIssueBlock(Block):
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = self.assign_issue(
|
||||
credentials,
|
||||
input_data.issue_url,
|
||||
input_data.assignee,
|
||||
)
|
||||
yield "status", status
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to assign issue: {str(e)}"
|
||||
status = self.assign_issue(
|
||||
credentials,
|
||||
input_data.issue_url,
|
||||
input_data.assignee,
|
||||
)
|
||||
yield "status", status
|
||||
|
||||
|
||||
class GithubUnassignIssueBlock(Block):
|
||||
@@ -672,12 +655,9 @@ class GithubUnassignIssueBlock(Block):
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = self.unassign_issue(
|
||||
credentials,
|
||||
input_data.issue_url,
|
||||
input_data.assignee,
|
||||
)
|
||||
yield "status", status
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to unassign issue: {str(e)}"
|
||||
status = self.unassign_issue(
|
||||
credentials,
|
||||
input_data.issue_url,
|
||||
input_data.assignee,
|
||||
)
|
||||
yield "status", status
|
||||
|
||||
@@ -87,14 +87,11 @@ class GithubListPullRequestsBlock(Block):
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
pull_requests = self.list_prs(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
)
|
||||
yield from (("pull_request", pr) for pr in pull_requests)
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to list pull requests: {str(e)}"
|
||||
pull_requests = self.list_prs(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
)
|
||||
yield from (("pull_request", pr) for pr in pull_requests)
|
||||
|
||||
|
||||
class GithubMakePullRequestBlock(Block):
|
||||
@@ -203,9 +200,7 @@ class GithubMakePullRequestBlock(Block):
|
||||
error_message = error_details.get("message", "Unknown error")
|
||||
else:
|
||||
error_message = str(http_err)
|
||||
yield "error", f"Failed to create pull request: {error_message}"
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to create pull request: {str(e)}"
|
||||
raise RuntimeError(f"Failed to create pull request: {error_message}")
|
||||
|
||||
|
||||
class GithubReadPullRequestBlock(Block):
|
||||
@@ -313,23 +308,20 @@ class GithubReadPullRequestBlock(Block):
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
title, body, author = self.read_pr(
|
||||
title, body, author = self.read_pr(
|
||||
credentials,
|
||||
input_data.pr_url,
|
||||
)
|
||||
yield "title", title
|
||||
yield "body", body
|
||||
yield "author", author
|
||||
|
||||
if input_data.include_pr_changes:
|
||||
changes = self.read_pr_changes(
|
||||
credentials,
|
||||
input_data.pr_url,
|
||||
)
|
||||
yield "title", title
|
||||
yield "body", body
|
||||
yield "author", author
|
||||
|
||||
if input_data.include_pr_changes:
|
||||
changes = self.read_pr_changes(
|
||||
credentials,
|
||||
input_data.pr_url,
|
||||
)
|
||||
yield "changes", changes
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to read pull request: {str(e)}"
|
||||
yield "changes", changes
|
||||
|
||||
|
||||
class GithubAssignPRReviewerBlock(Block):
|
||||
@@ -418,9 +410,7 @@ class GithubAssignPRReviewerBlock(Block):
|
||||
)
|
||||
else:
|
||||
error_msg = f"HTTP error: {http_err} - {http_err.response.text}"
|
||||
yield "error", error_msg
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to assign reviewer: {str(e)}"
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
|
||||
class GithubUnassignPRReviewerBlock(Block):
|
||||
@@ -490,15 +480,12 @@ class GithubUnassignPRReviewerBlock(Block):
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = self.unassign_reviewer(
|
||||
credentials,
|
||||
input_data.pr_url,
|
||||
input_data.reviewer,
|
||||
)
|
||||
yield "status", status
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to unassign reviewer: {str(e)}"
|
||||
status = self.unassign_reviewer(
|
||||
credentials,
|
||||
input_data.pr_url,
|
||||
input_data.reviewer,
|
||||
)
|
||||
yield "status", status
|
||||
|
||||
|
||||
class GithubListPRReviewersBlock(Block):
|
||||
@@ -586,11 +573,8 @@ class GithubListPRReviewersBlock(Block):
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
reviewers = self.list_reviewers(
|
||||
credentials,
|
||||
input_data.pr_url,
|
||||
)
|
||||
yield from (("reviewer", reviewer) for reviewer in reviewers)
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to list reviewers: {str(e)}"
|
||||
reviewers = self.list_reviewers(
|
||||
credentials,
|
||||
input_data.pr_url,
|
||||
)
|
||||
yield from (("reviewer", reviewer) for reviewer in reviewers)
|
||||
|
||||
@@ -96,14 +96,11 @@ class GithubListTagsBlock(Block):
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
tags = self.list_tags(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
)
|
||||
yield from (("tag", tag) for tag in tags)
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to list tags: {str(e)}"
|
||||
tags = self.list_tags(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
)
|
||||
yield from (("tag", tag) for tag in tags)
|
||||
|
||||
|
||||
class GithubListBranchesBlock(Block):
|
||||
@@ -183,14 +180,11 @@ class GithubListBranchesBlock(Block):
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
branches = self.list_branches(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
)
|
||||
yield from (("branch", branch) for branch in branches)
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to list branches: {str(e)}"
|
||||
branches = self.list_branches(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
)
|
||||
yield from (("branch", branch) for branch in branches)
|
||||
|
||||
|
||||
class GithubListDiscussionsBlock(Block):
|
||||
@@ -294,13 +288,10 @@ class GithubListDiscussionsBlock(Block):
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
discussions = self.list_discussions(
|
||||
credentials, input_data.repo_url, input_data.num_discussions
|
||||
)
|
||||
yield from (("discussion", discussion) for discussion in discussions)
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to list discussions: {str(e)}"
|
||||
discussions = self.list_discussions(
|
||||
credentials, input_data.repo_url, input_data.num_discussions
|
||||
)
|
||||
yield from (("discussion", discussion) for discussion in discussions)
|
||||
|
||||
|
||||
class GithubListReleasesBlock(Block):
|
||||
@@ -381,14 +372,11 @@ class GithubListReleasesBlock(Block):
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
releases = self.list_releases(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
)
|
||||
yield from (("release", release) for release in releases)
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to list releases: {str(e)}"
|
||||
releases = self.list_releases(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
)
|
||||
yield from (("release", release) for release in releases)
|
||||
|
||||
|
||||
class GithubReadFileBlock(Block):
|
||||
@@ -474,18 +462,15 @@ class GithubReadFileBlock(Block):
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
raw_content, size = self.read_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path.lstrip("/"),
|
||||
input_data.branch,
|
||||
)
|
||||
yield "raw_content", raw_content
|
||||
yield "text_content", base64.b64decode(raw_content).decode("utf-8")
|
||||
yield "size", size
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to read file: {str(e)}"
|
||||
raw_content, size = self.read_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path.lstrip("/"),
|
||||
input_data.branch,
|
||||
)
|
||||
yield "raw_content", raw_content
|
||||
yield "text_content", base64.b64decode(raw_content).decode("utf-8")
|
||||
yield "size", size
|
||||
|
||||
|
||||
class GithubReadFolderBlock(Block):
|
||||
@@ -612,17 +597,14 @@ class GithubReadFolderBlock(Block):
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
files, dirs = self.read_folder(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.folder_path.lstrip("/"),
|
||||
input_data.branch,
|
||||
)
|
||||
yield from (("file", file) for file in files)
|
||||
yield from (("dir", dir) for dir in dirs)
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to read folder: {str(e)}"
|
||||
files, dirs = self.read_folder(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.folder_path.lstrip("/"),
|
||||
input_data.branch,
|
||||
)
|
||||
yield from (("file", file) for file in files)
|
||||
yield from (("dir", dir) for dir in dirs)
|
||||
|
||||
|
||||
class GithubMakeBranchBlock(Block):
|
||||
@@ -703,16 +685,13 @@ class GithubMakeBranchBlock(Block):
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = self.create_branch(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.new_branch,
|
||||
input_data.source_branch,
|
||||
)
|
||||
yield "status", status
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to create branch: {str(e)}"
|
||||
status = self.create_branch(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.new_branch,
|
||||
input_data.source_branch,
|
||||
)
|
||||
yield "status", status
|
||||
|
||||
|
||||
class GithubDeleteBranchBlock(Block):
|
||||
@@ -775,12 +754,9 @@ class GithubDeleteBranchBlock(Block):
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = self.delete_branch(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
)
|
||||
yield "status", status
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to delete branch: {str(e)}"
|
||||
status = self.delete_branch(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
)
|
||||
yield "status", status
|
||||
|
||||
54
autogpt_platform/backend/backend/blocks/google/_auth.py
Normal file
@@ -0,0 +1,54 @@
from typing import Literal

from autogpt_libs.supabase_integration_credentials_store.types import OAuth2Credentials
from pydantic import SecretStr

from backend.data.model import CredentialsField, CredentialsMetaInput
from backend.util.settings import Secrets

# --8<-- [start:GoogleOAuthIsConfigured]
secrets = Secrets()
GOOGLE_OAUTH_IS_CONFIGURED = bool(
    secrets.google_client_id and secrets.google_client_secret
)
# --8<-- [end:GoogleOAuthIsConfigured]
GoogleCredentials = OAuth2Credentials
GoogleCredentialsInput = CredentialsMetaInput[Literal["google"], Literal["oauth2"]]


def GoogleCredentialsField(scopes: list[str]) -> GoogleCredentialsInput:
    """
    Creates a Google credentials input on a block.

    Params:
        scopes: The authorization scopes needed for the block to work.
    """
    return CredentialsField(
        provider="google",
        supported_credential_types={"oauth2"},
        required_scopes=set(scopes),
        description="The Google integration requires OAuth2 authentication.",
    )


TEST_CREDENTIALS = OAuth2Credentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="google",
    access_token=SecretStr("mock-google-access-token"),
    refresh_token=SecretStr("mock-google-refresh-token"),
    access_token_expires_at=1234567890,
    scopes=[
        "https://www.googleapis.com/auth/gmail.readonly",
        "https://www.googleapis.com/auth/gmail.send",
    ],
    title="Mock Google OAuth2 Credentials",
    username="mock-google-username",
    refresh_token_expires_at=1234567890,
)

TEST_CREDENTIALS_INPUT = {
    "provider": TEST_CREDENTIALS.provider,
    "id": TEST_CREDENTIALS.id,
    "type": TEST_CREDENTIALS.type,
    "title": TEST_CREDENTIALS.title,
}
503
autogpt_platform/backend/backend/blocks/google/gmail.py
Normal file
@@ -0,0 +1,503 @@
|
||||
import base64
|
||||
from email.utils import parseaddr
|
||||
from typing import List
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
from googleapiclient.discovery import build
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from ._auth import (
|
||||
GOOGLE_OAUTH_IS_CONFIGURED,
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
GoogleCredentials,
|
||||
GoogleCredentialsField,
|
||||
GoogleCredentialsInput,
|
||||
)
|
||||
|
||||
|
||||
class Attachment(BaseModel):
|
||||
filename: str
|
||||
content_type: str
|
||||
size: int
|
||||
attachment_id: str
|
||||
|
||||
|
||||
class Email(BaseModel):
|
||||
id: str
|
||||
subject: str
|
||||
snippet: str
|
||||
from_: str
|
||||
to: str
|
||||
date: str
|
||||
body: str = "" # Default to an empty string
|
||||
sizeEstimate: int
|
||||
attachments: List[Attachment]
|
||||
|
||||
|
||||
class GmailReadBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GoogleCredentialsInput = GoogleCredentialsField(
|
||||
["https://www.googleapis.com/auth/gmail.readonly"]
|
||||
)
|
||||
query: str = SchemaField(
|
||||
description="Search query for reading emails",
|
||||
default="is:unread",
|
||||
)
|
||||
max_results: int = SchemaField(
|
||||
description="Maximum number of emails to retrieve",
|
||||
default=10,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
email: Email = SchemaField(
|
||||
description="Email data",
|
||||
)
|
||||
emails: list[Email] = SchemaField(
|
||||
description="List of email data",
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if any",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="25310c70-b89b-43ba-b25c-4dfa7e2a481c",
|
||||
description="This block reads emails from Gmail.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
disabled=not GOOGLE_OAUTH_IS_CONFIGURED,
|
||||
input_schema=GmailReadBlock.Input,
|
||||
output_schema=GmailReadBlock.Output,
|
||||
test_input={
|
||||
"query": "is:unread",
|
||||
"max_results": 5,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"result",
|
||||
[
|
||||
{
|
||||
"id": "1",
|
||||
"subject": "Test Email",
|
||||
"snippet": "This is a test email",
|
||||
}
|
||||
],
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"_read_emails": lambda *args, **kwargs: [
|
||||
{
|
||||
"id": "1",
|
||||
"subject": "Test Email",
|
||||
"snippet": "This is a test email",
|
||||
}
|
||||
],
|
||||
"_send_email": lambda *args, **kwargs: {"id": "1", "status": "sent"},
|
||||
},
|
||||
)
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
service = self._build_service(credentials, **kwargs)
|
||||
messages = self._read_emails(service, input_data.query, input_data.max_results)
|
||||
for email in messages:
|
||||
yield "email", email
|
||||
yield "emails", messages
|
||||
|
||||
@staticmethod
|
||||
def _build_service(credentials: GoogleCredentials, **kwargs):
|
||||
creds = Credentials(
|
||||
token=(
|
||||
credentials.access_token.get_secret_value()
|
||||
if credentials.access_token
|
||||
else None
|
||||
),
|
||||
refresh_token=(
|
||||
credentials.refresh_token.get_secret_value()
|
||||
if credentials.refresh_token
|
||||
else None
|
||||
),
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
client_id=kwargs.get("client_id"),
|
||||
client_secret=kwargs.get("client_secret"),
|
||||
scopes=credentials.scopes,
|
||||
)
|
||||
return build("gmail", "v1", credentials=creds)
|
||||
|
||||
def _read_emails(
|
||||
self, service, query: str | None, max_results: int | None
|
||||
) -> list[Email]:
|
||||
results = (
|
||||
service.users()
|
||||
.messages()
|
||||
.list(userId="me", q=query or "", maxResults=max_results or 10)
|
||||
.execute()
|
||||
)
|
||||
messages = results.get("messages", [])
|
||||
|
||||
email_data = []
|
||||
for message in messages:
|
||||
msg = (
|
||||
service.users()
|
||||
.messages()
|
||||
.get(userId="me", id=message["id"], format="full")
|
||||
.execute()
|
||||
)
|
||||
|
||||
headers = {
|
||||
header["name"].lower(): header["value"]
|
||||
for header in msg["payload"]["headers"]
|
||||
}
|
||||
|
||||
attachments = self._get_attachments(service, msg)
|
||||
|
||||
email = Email(
|
||||
id=msg["id"],
|
||||
subject=headers.get("subject", "No Subject"),
|
||||
snippet=msg["snippet"],
|
||||
from_=parseaddr(headers.get("from", ""))[1],
|
||||
to=parseaddr(headers.get("to", ""))[1],
|
||||
date=headers.get("date", ""),
|
||||
body=self._get_email_body(msg),
|
||||
sizeEstimate=msg["sizeEstimate"],
|
||||
attachments=attachments,
|
||||
)
|
||||
email_data.append(email)
|
||||
|
||||
return email_data
|
||||
|
||||
def _get_email_body(self, msg):
|
||||
if "parts" in msg["payload"]:
|
||||
for part in msg["payload"]["parts"]:
|
||||
if part["mimeType"] == "text/plain":
|
||||
return base64.urlsafe_b64decode(part["body"]["data"]).decode(
|
||||
"utf-8"
|
||||
)
|
||||
elif msg["payload"]["mimeType"] == "text/plain":
|
||||
return base64.urlsafe_b64decode(msg["payload"]["body"]["data"]).decode(
|
||||
"utf-8"
|
||||
)
|
||||
|
||||
return "This email does not contain a text body."
|
||||
|
||||
def _get_attachments(self, service, message):
|
||||
attachments = []
|
||||
if "parts" in message["payload"]:
|
||||
for part in message["payload"]["parts"]:
|
||||
if part["filename"]:
|
||||
attachment = Attachment(
|
||||
filename=part["filename"],
|
||||
content_type=part["mimeType"],
|
||||
size=int(part["body"].get("size", 0)),
|
||||
attachment_id=part["body"]["attachmentId"],
|
||||
)
|
||||
attachments.append(attachment)
|
||||
return attachments
|
||||
|
||||
# Add a new method to download attachment content
|
||||
def download_attachment(self, service, message_id: str, attachment_id: str):
|
||||
attachment = (
|
||||
service.users()
|
||||
.messages()
|
||||
.attachments()
|
||||
.get(userId="me", messageId=message_id, id=attachment_id)
|
||||
.execute()
|
||||
)
|
||||
file_data = base64.urlsafe_b64decode(attachment["data"].encode("UTF-8"))
|
||||
return file_data
|
||||
|
||||
|
||||
class GmailSendBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GoogleCredentialsInput = GoogleCredentialsField(
|
||||
["https://www.googleapis.com/auth/gmail.send"]
|
||||
)
|
||||
to: str = SchemaField(
|
||||
description="Recipient email address",
|
||||
)
|
||||
subject: str = SchemaField(
|
||||
description="Email subject",
|
||||
)
|
||||
body: str = SchemaField(
|
||||
description="Email body",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: dict = SchemaField(
|
||||
description="Send confirmation",
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if any",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="6c27abc2-e51d-499e-a85f-5a0041ba94f0",
|
||||
description="This block sends an email using Gmail.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=GmailSendBlock.Input,
|
||||
output_schema=GmailSendBlock.Output,
|
||||
disabled=not GOOGLE_OAUTH_IS_CONFIGURED,
|
||||
test_input={
|
||||
"to": "recipient@example.com",
|
||||
"subject": "Test Email",
|
||||
"body": "This is a test email sent from GmailSendBlock.",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("result", {"id": "1", "status": "sent"}),
|
||||
],
|
||||
test_mock={
|
||||
"_send_email": lambda *args, **kwargs: {"id": "1", "status": "sent"},
|
||||
},
|
||||
)
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
service = GmailReadBlock._build_service(credentials, **kwargs)
|
||||
send_result = self._send_email(
|
||||
service, input_data.to, input_data.subject, input_data.body
|
||||
)
|
||||
yield "result", send_result
|
||||
|
||||
def _send_email(self, service, to: str, subject: str, body: str) -> dict:
|
||||
if not to or not subject or not body:
|
||||
raise ValueError("To, subject, and body are required for sending an email")
|
||||
message = self._create_message(to, subject, body)
|
||||
sent_message = (
|
||||
service.users().messages().send(userId="me", body=message).execute()
|
||||
)
|
||||
return {"id": sent_message["id"], "status": "sent"}
|
||||
|
||||
def _create_message(self, to: str, subject: str, body: str) -> dict:
|
||||
import base64
|
||||
from email.mime.text import MIMEText
|
||||
|
||||
message = MIMEText(body)
|
||||
message["to"] = to
|
||||
message["subject"] = subject
|
||||
raw_message = base64.urlsafe_b64encode(message.as_bytes()).decode("utf-8")
|
||||
return {"raw": raw_message}
|
||||
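For reference, the Gmail API's users().messages().send() call expects the full RFC 2822 message base64url-encoded under a single "raw" key, which is what _create_message assembles above. A minimal sketch with placeholder values:

import base64
from email.mime.text import MIMEText

message = MIMEText("Hello from a sketch")   # placeholder body
message["to"] = "recipient@example.com"     # placeholder recipient
message["subject"] = "Test Email"

# base64url-encode the serialized message, as _create_message does above
raw = base64.urlsafe_b64encode(message.as_bytes()).decode("utf-8")
request_body = {"raw": raw}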
|
||||
|
||||
class GmailListLabelsBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GoogleCredentialsInput = GoogleCredentialsField(
|
||||
["https://www.googleapis.com/auth/gmail.labels"]
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: list[dict] = SchemaField(
|
||||
description="List of labels",
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if any",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3e1c2c1c-c689-4520-b956-1f3bf4e02bb7",
|
||||
description="This block lists all labels in Gmail.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=GmailListLabelsBlock.Input,
|
||||
output_schema=GmailListLabelsBlock.Output,
|
||||
disabled=not GOOGLE_OAUTH_IS_CONFIGURED,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"result",
|
||||
[
|
||||
{"id": "Label_1", "name": "Important"},
|
||||
{"id": "Label_2", "name": "Work"},
|
||||
],
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"_list_labels": lambda *args, **kwargs: [
|
||||
{"id": "Label_1", "name": "Important"},
|
||||
{"id": "Label_2", "name": "Work"},
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
service = GmailReadBlock._build_service(credentials, **kwargs)
|
||||
labels = self._list_labels(service)
|
||||
yield "result", labels
|
||||
|
||||
def _list_labels(self, service) -> list[dict]:
|
||||
results = service.users().labels().list(userId="me").execute()
|
||||
labels = results.get("labels", [])
|
||||
return [{"id": label["id"], "name": label["name"]} for label in labels]
|
||||
|
||||
|
||||
class GmailAddLabelBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GoogleCredentialsInput = GoogleCredentialsField(
|
||||
["https://www.googleapis.com/auth/gmail.modify"]
|
||||
)
|
||||
message_id: str = SchemaField(
|
||||
description="Message ID to add label to",
|
||||
)
|
||||
label_name: str = SchemaField(
|
||||
description="Label name to add",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: dict = SchemaField(
|
||||
description="Label addition result",
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if any",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f884b2fb-04f4-4265-9658-14f433926ac9",
|
||||
description="This block adds a label to a Gmail message.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=GmailAddLabelBlock.Input,
|
||||
output_schema=GmailAddLabelBlock.Output,
|
||||
disabled=not GOOGLE_OAUTH_IS_CONFIGURED,
|
||||
test_input={
|
||||
"message_id": "12345",
|
||||
"label_name": "Important",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"result",
|
||||
{"status": "Label added successfully", "label_id": "Label_1"},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"_add_label": lambda *args, **kwargs: {
|
||||
"status": "Label added successfully",
|
||||
"label_id": "Label_1",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
service = GmailReadBlock._build_service(credentials, **kwargs)
|
||||
result = self._add_label(service, input_data.message_id, input_data.label_name)
|
||||
yield "result", result
|
||||
|
||||
def _add_label(self, service, message_id: str, label_name: str) -> dict:
|
||||
label_id = self._get_or_create_label(service, label_name)
|
||||
service.users().messages().modify(
|
||||
userId="me", id=message_id, body={"addLabelIds": [label_id]}
|
||||
).execute()
|
||||
return {"status": "Label added successfully", "label_id": label_id}
|
||||
|
||||
def _get_or_create_label(self, service, label_name: str) -> str:
|
||||
label_id = self._get_label_id(service, label_name)
|
||||
if not label_id:
|
||||
label = (
|
||||
service.users()
|
||||
.labels()
|
||||
.create(userId="me", body={"name": label_name})
|
||||
.execute()
|
||||
)
|
||||
label_id = label["id"]
|
||||
return label_id
|
||||
|
||||
def _get_label_id(self, service, label_name: str) -> str | None:
|
||||
results = service.users().labels().list(userId="me").execute()
|
||||
labels = results.get("labels", [])
|
||||
for label in labels:
|
||||
if label["name"] == label_name:
|
||||
return label["id"]
|
||||
return None
|
||||
|
||||
|
||||
class GmailRemoveLabelBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GoogleCredentialsInput = GoogleCredentialsField(
|
||||
["https://www.googleapis.com/auth/gmail.modify"]
|
||||
)
|
||||
message_id: str = SchemaField(
|
||||
description="Message ID to remove label from",
|
||||
)
|
||||
label_name: str = SchemaField(
|
||||
description="Label name to remove",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: dict = SchemaField(
|
||||
description="Label removal result",
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if any",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0afc0526-aba1-4b2b-888e-a22b7c3f359d",
|
||||
description="This block removes a label from a Gmail message.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=GmailRemoveLabelBlock.Input,
|
||||
output_schema=GmailRemoveLabelBlock.Output,
|
||||
disabled=not GOOGLE_OAUTH_IS_CONFIGURED,
|
||||
test_input={
|
||||
"message_id": "12345",
|
||||
"label_name": "Important",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"result",
|
||||
{"status": "Label removed successfully", "label_id": "Label_1"},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"_remove_label": lambda *args, **kwargs: {
|
||||
"status": "Label removed successfully",
|
||||
"label_id": "Label_1",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
service = GmailReadBlock._build_service(credentials, **kwargs)
|
||||
result = self._remove_label(
|
||||
service, input_data.message_id, input_data.label_name
|
||||
)
|
||||
yield "result", result
|
||||
|
||||
def _remove_label(self, service, message_id: str, label_name: str) -> dict:
|
||||
label_id = self._get_label_id(service, label_name)
|
||||
if label_id:
|
||||
service.users().messages().modify(
|
||||
userId="me", id=message_id, body={"removeLabelIds": [label_id]}
|
||||
).execute()
|
||||
return {"status": "Label removed successfully", "label_id": label_id}
|
||||
else:
|
||||
return {"status": "Label not found", "label_name": label_name}
|
||||
|
||||
def _get_label_id(self, service, label_name: str) -> str | None:
|
||||
results = service.users().labels().list(userId="me").execute()
|
||||
labels = results.get("labels", [])
|
||||
for label in labels:
|
||||
if label["name"] == label_name:
|
||||
return label["id"]
|
||||
return None
|
||||
184
autogpt_platform/backend/backend/blocks/google/sheets.py
Normal file
@@ -0,0 +1,184 @@
|
||||
from google.oauth2.credentials import Credentials
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from ._auth import (
|
||||
GOOGLE_OAUTH_IS_CONFIGURED,
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
GoogleCredentials,
|
||||
GoogleCredentialsField,
|
||||
GoogleCredentialsInput,
|
||||
)
|
||||
|
||||
|
||||
class GoogleSheetsReadBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GoogleCredentialsInput = GoogleCredentialsField(
|
||||
["https://www.googleapis.com/auth/spreadsheets.readonly"]
|
||||
)
|
||||
spreadsheet_id: str = SchemaField(
|
||||
description="The ID of the spreadsheet to read from",
|
||||
)
|
||||
range: str = SchemaField(
|
||||
description="The A1 notation of the range to read",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: list[list[str]] = SchemaField(
|
||||
description="The data read from the spreadsheet",
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if any",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="5724e902-3635-47e9-a108-aaa0263a4988",
|
||||
description="This block reads data from a Google Sheets spreadsheet.",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=GoogleSheetsReadBlock.Input,
|
||||
output_schema=GoogleSheetsReadBlock.Output,
|
||||
disabled=not GOOGLE_OAUTH_IS_CONFIGURED,
|
||||
test_input={
|
||||
"spreadsheet_id": "1BxiMVs0XRA5nFMdKvBdBZjgmUUqptlbs74OgvE2upms",
|
||||
"range": "Sheet1!A1:B2",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"result",
|
||||
[
|
||||
["Name", "Score"],
|
||||
["Alice", "85"],
|
||||
],
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"_read_sheet": lambda *args, **kwargs: [
|
||||
["Name", "Score"],
|
||||
["Alice", "85"],
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
service = self._build_service(credentials, **kwargs)
|
||||
data = self._read_sheet(service, input_data.spreadsheet_id, input_data.range)
|
||||
yield "result", data
|
||||
|
||||
@staticmethod
|
||||
def _build_service(credentials: GoogleCredentials, **kwargs):
|
||||
creds = Credentials(
|
||||
token=(
|
||||
credentials.access_token.get_secret_value()
|
||||
if credentials.access_token
|
||||
else None
|
||||
),
|
||||
refresh_token=(
|
||||
credentials.refresh_token.get_secret_value()
|
||||
if credentials.refresh_token
|
||||
else None
|
||||
),
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
client_id=kwargs.get("client_id"),
|
||||
client_secret=kwargs.get("client_secret"),
|
||||
scopes=credentials.scopes,
|
||||
)
|
||||
return build("sheets", "v4", credentials=creds)
|
||||
|
||||
def _read_sheet(self, service, spreadsheet_id: str, range: str) -> list[list[str]]:
|
||||
sheet = service.spreadsheets()
|
||||
result = sheet.values().get(spreadsheetId=spreadsheet_id, range=range).execute()
|
||||
return result.get("values", [])
|
||||
|
||||
|
||||
class GoogleSheetsWriteBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GoogleCredentialsInput = GoogleCredentialsField(
|
||||
["https://www.googleapis.com/auth/spreadsheets"]
|
||||
)
|
||||
spreadsheet_id: str = SchemaField(
|
||||
description="The ID of the spreadsheet to write to",
|
||||
)
|
||||
range: str = SchemaField(
|
||||
description="The A1 notation of the range to write",
|
||||
)
|
||||
values: list[list[str]] = SchemaField(
|
||||
description="The data to write to the spreadsheet",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: dict = SchemaField(
|
||||
description="The result of the write operation",
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if any",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d9291e87-301d-47a8-91fe-907fb55460e5",
|
||||
description="This block writes data to a Google Sheets spreadsheet.",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=GoogleSheetsWriteBlock.Input,
|
||||
output_schema=GoogleSheetsWriteBlock.Output,
|
||||
disabled=not GOOGLE_OAUTH_IS_CONFIGURED,
|
||||
test_input={
|
||||
"spreadsheet_id": "1BxiMVs0XRA5nFMdKvBdBZjgmUUqptlbs74OgvE2upms",
|
||||
"range": "Sheet1!A1:B2",
|
||||
"values": [
|
||||
["Name", "Score"],
|
||||
["Bob", "90"],
|
||||
],
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"result",
|
||||
{"updatedCells": 4, "updatedColumns": 2, "updatedRows": 2},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"_write_sheet": lambda *args, **kwargs: {
|
||||
"updatedCells": 4,
|
||||
"updatedColumns": 2,
|
||||
"updatedRows": 2,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
service = GoogleSheetsReadBlock._build_service(credentials, **kwargs)
|
||||
result = self._write_sheet(
|
||||
service,
|
||||
input_data.spreadsheet_id,
|
||||
input_data.range,
|
||||
input_data.values,
|
||||
)
|
||||
yield "result", result
|
||||
|
||||
def _write_sheet(
|
||||
self, service, spreadsheet_id: str, range: str, values: list[list[str]]
|
||||
) -> dict:
|
||||
body = {"values": values}
|
||||
result = (
|
||||
service.spreadsheets()
|
||||
.values()
|
||||
.update(
|
||||
spreadsheetId=spreadsheet_id,
|
||||
range=range,
|
||||
valueInputOption="USER_ENTERED",
|
||||
body=body,
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
return result
|
||||
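# Illustrative sketch (not part of this commit): exercising the read helper above
# with a stubbed Sheets service. The _Fake* classes are assumptions for the example
# only; they are not real googleapiclient types.
class _FakeValues:
    def get(self, spreadsheetId, range):
        return self  # chainable, like the real API resource

    def execute(self):
        return {"values": [["Name", "Score"], ["Alice", "85"]]}


class _FakeService:
    def spreadsheets(self):
        return self

    def values(self):
        return _FakeValues()


_rows = GoogleSheetsReadBlock()._read_sheet(
    _FakeService(), "1BxiMVs0XRA5nFMdKvBdBZjgmUUqptlbs74OgvE2upms", "Sheet1!A1:B2"
)
assert _rows == [["Name", "Score"], ["Alice", "85"]]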
124 autogpt_platform/backend/backend/blocks/google_maps.py Normal file
@@ -0,0 +1,124 @@
|
||||
import googlemaps
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import BlockSecret, SchemaField, SecretField
|
||||
|
||||
|
||||
class Place(BaseModel):
|
||||
name: str
|
||||
address: str
|
||||
phone: str
|
||||
rating: float
|
||||
reviews: int
|
||||
website: str
|
||||
|
||||
|
||||
class GoogleMapsSearchBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
api_key: BlockSecret = SecretField(
|
||||
key="google_maps_api_key",
|
||||
description="Google Maps API Key",
|
||||
)
|
||||
query: str = SchemaField(
|
||||
description="Search query for local businesses",
|
||||
placeholder="e.g., 'restaurants in New York'",
|
||||
)
|
||||
radius: int = SchemaField(
|
||||
description="Search radius in meters (max 50000)",
|
||||
default=5000,
|
||||
ge=1,
|
||||
le=50000,
|
||||
)
|
||||
max_results: int = SchemaField(
|
||||
description="Maximum number of results to return (max 60)",
|
||||
default=20,
|
||||
ge=1,
|
||||
le=60,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
place: Place = SchemaField(description="Place found")
|
||||
error: str = SchemaField(description="Error message if the search failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f47ac10b-58cc-4372-a567-0e02b2c3d479",
|
||||
description="This block searches for local businesses using Google Maps API.",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=GoogleMapsSearchBlock.Input,
|
||||
output_schema=GoogleMapsSearchBlock.Output,
|
||||
test_input={
|
||||
"api_key": "your_test_api_key",
|
||||
"query": "restaurants in new york",
|
||||
"radius": 5000,
|
||||
"max_results": 5,
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"place",
|
||||
{
|
||||
"name": "Test Restaurant",
|
||||
"address": "123 Test St, New York, NY 10001",
|
||||
"phone": "+1 (555) 123-4567",
|
||||
"rating": 4.5,
|
||||
"reviews": 100,
|
||||
"website": "https://testrestaurant.com",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"search_places": lambda *args, **kwargs: [
|
||||
{
|
||||
"name": "Test Restaurant",
|
||||
"address": "123 Test St, New York, NY 10001",
|
||||
"phone": "+1 (555) 123-4567",
|
||||
"rating": 4.5,
|
||||
"reviews": 100,
|
||||
"website": "https://testrestaurant.com",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
places = self.search_places(
|
||||
input_data.api_key.get_secret_value(),
|
||||
input_data.query,
|
||||
input_data.radius,
|
||||
input_data.max_results,
|
||||
)
|
||||
for place in places:
|
||||
yield "place", place
|
||||
|
||||
def search_places(self, api_key, query, radius, max_results):
|
||||
client = googlemaps.Client(key=api_key)
|
||||
return self._search_places(client, query, radius, max_results)
|
||||
|
||||
def _search_places(self, client, query, radius, max_results):
|
||||
results = []
|
||||
next_page_token = None
|
||||
while len(results) < max_results:
|
||||
response = client.places(
|
||||
query=query,
|
||||
radius=radius,
|
||||
page_token=next_page_token,
|
||||
)
|
||||
for place in response["results"]:
|
||||
if len(results) >= max_results:
|
||||
break
|
||||
place_details = client.place(place["place_id"])["result"]
|
||||
results.append(
|
||||
Place(
|
||||
name=place_details.get("name", ""),
|
||||
address=place_details.get("formatted_address", ""),
|
||||
phone=place_details.get("formatted_phone_number", ""),
|
||||
rating=place_details.get("rating", 0),
|
||||
reviews=place_details.get("user_ratings_total", 0),
|
||||
website=place_details.get("website", ""),
|
||||
)
|
||||
)
|
||||
next_page_token = response.get("next_page_token")
|
||||
if not next_page_token:
|
||||
break
|
||||
return results
|
||||
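# Illustrative sketch (not part of this commit): the pagination loop above can be
# driven without network access by stubbing the googlemaps client. _FakeMapsClient
# is an assumption for the example only.
class _FakeMapsClient:
    def places(self, query, radius, page_token=None):
        # One page of results, no follow-up page.
        return {"results": [{"place_id": "abc123"}], "next_page_token": None}

    def place(self, place_id):
        return {
            "result": {
                "name": "Test Restaurant",
                "formatted_address": "123 Test St, New York, NY 10001",
                "formatted_phone_number": "+1 (555) 123-4567",
                "rating": 4.5,
                "user_ratings_total": 100,
                "website": "https://testrestaurant.com",
            }
        }


_places = GoogleMapsSearchBlock()._search_places(
    _FakeMapsClient(), "restaurants in new york", radius=5000, max_results=1
)
assert _places[0].name == "Test Restaurant"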
@@ -4,6 +4,7 @@ from enum import Enum
|
||||
import requests
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class HttpMethod(Enum):
|
||||
@@ -18,15 +19,27 @@ class HttpMethod(Enum):
|
||||
|
||||
class SendWebRequestBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
url: str
|
||||
method: HttpMethod = HttpMethod.POST
|
||||
headers: dict[str, str] = {}
|
||||
body: object = {}
|
||||
url: str = SchemaField(
|
||||
description="The URL to send the request to",
|
||||
placeholder="https://api.example.com",
|
||||
)
|
||||
method: HttpMethod = SchemaField(
|
||||
description="The HTTP method to use for the request",
|
||||
default=HttpMethod.POST,
|
||||
)
|
||||
headers: dict[str, str] = SchemaField(
|
||||
description="The headers to include in the request",
|
||||
default={},
|
||||
)
|
||||
body: object = SchemaField(
|
||||
description="The body of the request",
|
||||
default={},
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
response: object
|
||||
client_error: object
|
||||
server_error: object
|
||||
response: object = SchemaField(description="The response from the server")
|
||||
client_error: object = SchemaField(description="The error on 4xx status codes")
|
||||
server_error: object = SchemaField(description="The error on 5xx status codes")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
|
||||
253 autogpt_platform/backend/backend/blocks/ideogram.py Normal file
@@ -0,0 +1,253 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import BlockSecret, SchemaField, SecretField
|
||||
|
||||
|
||||
class IdeogramModelName(str, Enum):
|
||||
V2 = "V_2"
|
||||
V1 = "V_1"
|
||||
V1_TURBO = "V_1_TURBO"
|
||||
V2_TURBO = "V_2_TURBO"
|
||||
|
||||
|
||||
class MagicPromptOption(str, Enum):
|
||||
AUTO = "AUTO"
|
||||
ON = "ON"
|
||||
OFF = "OFF"
|
||||
|
||||
|
||||
class StyleType(str, Enum):
|
||||
AUTO = "AUTO"
|
||||
GENERAL = "GENERAL"
|
||||
REALISTIC = "REALISTIC"
|
||||
DESIGN = "DESIGN"
|
||||
RENDER_3D = "RENDER_3D"
|
||||
ANIME = "ANIME"
|
||||
|
||||
|
||||
class ColorPalettePreset(str, Enum):
|
||||
NONE = "NONE"
|
||||
EMBER = "EMBER"
|
||||
FRESH = "FRESH"
|
||||
JUNGLE = "JUNGLE"
|
||||
MAGIC = "MAGIC"
|
||||
MELON = "MELON"
|
||||
MOSAIC = "MOSAIC"
|
||||
PASTEL = "PASTEL"
|
||||
ULTRAMARINE = "ULTRAMARINE"
|
||||
|
||||
|
||||
class AspectRatio(str, Enum):
|
||||
ASPECT_10_16 = "ASPECT_10_16"
|
||||
ASPECT_16_10 = "ASPECT_16_10"
|
||||
ASPECT_9_16 = "ASPECT_9_16"
|
||||
ASPECT_16_9 = "ASPECT_16_9"
|
||||
ASPECT_3_2 = "ASPECT_3_2"
|
||||
ASPECT_2_3 = "ASPECT_2_3"
|
||||
ASPECT_4_3 = "ASPECT_4_3"
|
||||
ASPECT_3_4 = "ASPECT_3_4"
|
||||
ASPECT_1_1 = "ASPECT_1_1"
|
||||
ASPECT_1_3 = "ASPECT_1_3"
|
||||
ASPECT_3_1 = "ASPECT_3_1"
|
||||
|
||||
|
||||
class UpscaleOption(str, Enum):
|
||||
AI_UPSCALE = "AI Upscale"
|
||||
NO_UPSCALE = "No Upscale"
|
||||
|
||||
|
||||
class IdeogramModelBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
api_key: BlockSecret = SecretField(
|
||||
key="ideogram_api_key",
|
||||
description="Ideogram API Key",
|
||||
)
|
||||
prompt: str = SchemaField(
|
||||
description="Text prompt for image generation",
|
||||
placeholder="e.g., 'A futuristic cityscape at sunset'",
|
||||
title="Prompt",
|
||||
)
|
||||
ideogram_model_name: IdeogramModelName = SchemaField(
|
||||
description="The name of the Image Generation Model, e.g., V_2",
|
||||
default=IdeogramModelName.V2,
|
||||
title="Image Generation Model",
|
||||
advanced=False,
|
||||
)
|
||||
aspect_ratio: AspectRatio = SchemaField(
|
||||
description="Aspect ratio for the generated image",
|
||||
default=AspectRatio.ASPECT_1_1,
|
||||
title="Aspect Ratio",
|
||||
advanced=False,
|
||||
)
|
||||
upscale: UpscaleOption = SchemaField(
|
||||
description="Upscale the generated image",
|
||||
default=UpscaleOption.NO_UPSCALE,
|
||||
title="Upscale Image",
|
||||
advanced=False,
|
||||
)
|
||||
magic_prompt_option: MagicPromptOption = SchemaField(
|
||||
description="Whether to use MagicPrompt for enhancing the request",
|
||||
default=MagicPromptOption.AUTO,
|
||||
title="Magic Prompt Option",
|
||||
advanced=True,
|
||||
)
|
||||
seed: Optional[int] = SchemaField(
|
||||
description="Random seed. Set for reproducible generation",
|
||||
default=None,
|
||||
title="Seed",
|
||||
advanced=True,
|
||||
)
|
||||
style_type: StyleType = SchemaField(
|
||||
description="Style type to apply, applicable for V_2 and above",
|
||||
default=StyleType.AUTO,
|
||||
title="Style Type",
|
||||
advanced=True,
|
||||
)
|
||||
negative_prompt: Optional[str] = SchemaField(
|
||||
description="Description of what to exclude from the image",
|
||||
default=None,
|
||||
title="Negative Prompt",
|
||||
advanced=True,
|
||||
)
|
||||
color_palette_name: ColorPalettePreset = SchemaField(
|
||||
description="Color palette preset name, choose 'None' to skip",
|
||||
default=ColorPalettePreset.NONE,
|
||||
title="Color Palette Preset",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: str = SchemaField(description="Generated image URL")
|
||||
error: str = SchemaField(description="Error message if the model run failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="6ab085e2-20b3-4055-bc3e-08036e01eca6",
|
||||
description="This block runs Ideogram models with both simple and advanced settings.",
|
||||
categories={BlockCategory.AI},
|
||||
input_schema=IdeogramModelBlock.Input,
|
||||
output_schema=IdeogramModelBlock.Output,
|
||||
test_input={
|
||||
"api_key": "test_api_key",
|
||||
"ideogram_model_name": IdeogramModelName.V2,
|
||||
"prompt": "A futuristic cityscape at sunset",
|
||||
"aspect_ratio": AspectRatio.ASPECT_1_1,
|
||||
"upscale": UpscaleOption.NO_UPSCALE,
|
||||
"magic_prompt_option": MagicPromptOption.AUTO,
|
||||
"seed": None,
|
||||
"style_type": StyleType.AUTO,
|
||||
"negative_prompt": None,
|
||||
"color_palette_name": ColorPalettePreset.NONE,
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"result",
|
||||
"https://ideogram.ai/api/images/test-generated-image-url.png",
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"run_model": lambda api_key, model_name, prompt, seed, aspect_ratio, magic_prompt_option, style_type, negative_prompt, color_palette_name: "https://ideogram.ai/api/images/test-generated-image-url.png",
|
||||
"upscale_image": lambda api_key, image_url: "https://ideogram.ai/api/images/test-upscaled-image-url.png",
|
||||
},
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
seed = input_data.seed
|
||||
|
||||
# Step 1: Generate the image
|
||||
result = self.run_model(
|
||||
api_key=input_data.api_key.get_secret_value(),
|
||||
model_name=input_data.ideogram_model_name.value,
|
||||
prompt=input_data.prompt,
|
||||
seed=seed,
|
||||
aspect_ratio=input_data.aspect_ratio.value,
|
||||
magic_prompt_option=input_data.magic_prompt_option.value,
|
||||
style_type=input_data.style_type.value,
|
||||
negative_prompt=input_data.negative_prompt,
|
||||
color_palette_name=input_data.color_palette_name.value,
|
||||
)
|
||||
|
||||
# Step 2: Upscale the image if requested
|
||||
if input_data.upscale == UpscaleOption.AI_UPSCALE:
|
||||
result = self.upscale_image(
|
||||
api_key=input_data.api_key.get_secret_value(),
|
||||
image_url=result,
|
||||
)
|
||||
|
||||
yield "result", result
|
||||
|
||||
def run_model(
|
||||
self,
|
||||
api_key: str,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
seed: Optional[int],
|
||||
aspect_ratio: str,
|
||||
magic_prompt_option: str,
|
||||
style_type: str,
|
||||
negative_prompt: Optional[str],
|
||||
color_palette_name: str,
|
||||
):
|
||||
url = "https://api.ideogram.ai/generate"
|
||||
headers = {"Api-Key": api_key, "Content-Type": "application/json"}
|
||||
|
||||
data: Dict[str, Any] = {
|
||||
"image_request": {
|
||||
"prompt": prompt,
|
||||
"model": model_name,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
"magic_prompt_option": magic_prompt_option,
|
||||
"style_type": style_type,
|
||||
}
|
||||
}
|
||||
|
||||
if seed is not None:
|
||||
data["image_request"]["seed"] = seed
|
||||
|
||||
if negative_prompt:
|
||||
data["image_request"]["negative_prompt"] = negative_prompt
|
||||
|
||||
if color_palette_name != "NONE":
|
||||
data["image_request"]["color_palette"] = {"name": color_palette_name}
|
||||
|
||||
try:
|
||||
response = requests.post(url, json=data, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response.json()["data"][0]["url"]
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise Exception(f"Failed to fetch image: {str(e)}")
|
||||
|
||||
def upscale_image(self, api_key: str, image_url: str):
|
||||
url = "https://api.ideogram.ai/upscale"
|
||||
headers = {
|
||||
"Api-Key": api_key,
|
||||
}
|
||||
|
||||
try:
|
||||
# Step 1: Download the image from the provided URL
|
||||
image_response = requests.get(image_url)
|
||||
image_response.raise_for_status()
|
||||
|
||||
# Step 2: Send the downloaded image to the upscale API
|
||||
files = {
|
||||
"image_file": ("image.png", image_response.content, "image/png"),
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
url,
|
||||
headers=headers,
|
||||
data={
|
||||
"image_request": "{}", # Empty JSON object
|
||||
},
|
||||
files=files,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
return response.json()["data"][0]["url"]
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise Exception(f"Failed to upscale image: {str(e)}")
|
||||
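# Illustrative sketch (not part of this commit): the JSON body that run_model() above
# assembles for a typical call. The concrete values are examples only.
_example_body = {
    "image_request": {
        "prompt": "A futuristic cityscape at sunset",
        "model": "V_2",
        "aspect_ratio": "ASPECT_16_9",
        "magic_prompt_option": "AUTO",
        "style_type": "AUTO",
        "seed": 42,                        # only included when a seed is provided
        "negative_prompt": "no people",    # only included when non-empty
        "color_palette": {"name": "EMBER"},  # omitted when the preset is NONE
    }
}
# The body is POSTed to https://api.ideogram.ai/generate with an "Api-Key" header,
# and the generated image URL is read from response.json()["data"][0]["url"].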
@@ -1,37 +1,52 @@
|
||||
from typing import Any, List, Tuple
|
||||
from typing import Any
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class ListIteratorBlock(Block):
|
||||
class StepThroughItemsBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
items: List[Any] = SchemaField(
|
||||
description="The list of items to iterate over",
|
||||
placeholder="[1, 2, 3, 4, 5]",
|
||||
items: list | dict = SchemaField(
|
||||
description="The list or dictionary of items to iterate over",
|
||||
placeholder="[1, 2, 3, 4, 5] or {'key1': 'value1', 'key2': 'value2'}",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
item: Tuple[int, Any] = SchemaField(
|
||||
description="A tuple with the index and current item in the iteration"
|
||||
item: Any = SchemaField(description="The current item in the iteration")
|
||||
key: Any = SchemaField(
|
||||
description="The key or index of the current item in the iteration",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f8e7d6c5-b4a3-2c1d-0e9f-8g7h6i5j4k3l",
|
||||
input_schema=ListIteratorBlock.Input,
|
||||
output_schema=ListIteratorBlock.Output,
|
||||
description="Iterates over a list of items and outputs each item with its index.",
|
||||
id="f66a3543-28d3-4ab5-8945-9b336371e2ce",
|
||||
input_schema=StepThroughItemsBlock.Input,
|
||||
output_schema=StepThroughItemsBlock.Output,
|
||||
categories={BlockCategory.LOGIC},
|
||||
test_input={"items": [1, "two", {"three": 3}, [4, 5]]},
|
||||
description="Iterates over a list or dictionary and outputs each item.",
|
||||
test_input={"items": [1, 2, 3, {"key1": "value1", "key2": "value2"}]},
|
||||
test_output=[
|
||||
("item", (0, 1)),
|
||||
("item", (1, "two")),
|
||||
("item", (2, {"three": 3})),
|
||||
("item", (3, [4, 5])),
|
||||
("item", 1),
|
||||
("key", 0),
|
||||
("item", 2),
|
||||
("key", 1),
|
||||
("item", 3),
|
||||
("key", 2),
|
||||
("item", {"key1": "value1", "key2": "value2"}),
|
||||
("key", 3),
|
||||
],
|
||||
test_mock={},
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
for index, item in enumerate(input_data.items):
|
||||
yield "item", (index, item)
|
||||
        items = input_data.items
        if isinstance(items, dict):
            # If items is a dictionary, iterate over its key-value pairs
            for key, item in items.items():
                yield "item", item
                yield "key", key
        else:
            # If items is a list, iterate over it with the index as the key
            for index, item in enumerate(items):
                yield "item", item
                yield "key", index
|
||||
|
||||
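# Illustrative sketch (not part of this commit): pairing up the "item"/"key" outputs
# yielded by the block above for a list input and for a dict input (assuming dict
# inputs emit the dictionary key on the "key" output).
_block = StepThroughItemsBlock()

_list_outputs = list(_block.run(StepThroughItemsBlock.Input(items=[10, 20])))
# [("item", 10), ("key", 0), ("item", 20), ("key", 1)]

_dict_outputs = list(_block.run(StepThroughItemsBlock.Input(items={"a": 1, "b": 2})))
# [("item", 1), ("key", "a"), ("item", 2), ("key", "b")]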
39 autogpt_platform/backend/backend/blocks/jina/_auth.py Normal file
@@ -0,0 +1,39 @@
|
||||
from typing import Literal
|
||||
|
||||
from autogpt_libs.supabase_integration_credentials_store.types import APIKeyCredentials
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import CredentialsField, CredentialsMetaInput
|
||||
|
||||
JinaCredentials = APIKeyCredentials
|
||||
JinaCredentialsInput = CredentialsMetaInput[
|
||||
Literal["jina"],
|
||||
Literal["api_key"],
|
||||
]
|
||||
|
||||
|
||||
def JinaCredentialsField() -> JinaCredentialsInput:
|
||||
"""
|
||||
Creates a Jina credentials input on a block.
|
||||
|
||||
"""
|
||||
return CredentialsField(
|
||||
provider="jina",
|
||||
supported_credential_types={"api_key"},
|
||||
description="The Jina integration can be used with an API Key.",
|
||||
)
|
||||
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="jina",
|
||||
api_key=SecretStr("mock-jina-api-key"),
|
||||
title="Mock Jina API key",
|
||||
expires_at=None,
|
||||
)
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.type,
|
||||
}
|
||||
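# Illustrative sketch (not part of this commit): a minimal Input schema wired up with
# the helpers above, mirroring the chunking/embedding blocks later in this diff.
# _ExampleJinaInput is an assumption for the example only.
from backend.data.block import BlockSchema
from backend.data.model import SchemaField


class _ExampleJinaInput(BlockSchema):
    credentials: JinaCredentialsInput = JinaCredentialsField()
    text: str = SchemaField(description="Text to send to a Jina endpoint")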
69 autogpt_platform/backend/backend/blocks/jina/chunking.py Normal file
@@ -0,0 +1,69 @@
|
||||
import requests
|
||||
|
||||
from backend.blocks.jina._auth import (
|
||||
JinaCredentials,
|
||||
JinaCredentialsField,
|
||||
JinaCredentialsInput,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class JinaChunkingBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
texts: list = SchemaField(description="List of texts to chunk")
|
||||
|
||||
credentials: JinaCredentialsInput = JinaCredentialsField()
|
||||
max_chunk_length: int = SchemaField(
|
||||
description="Maximum length of each chunk", default=1000
|
||||
)
|
||||
return_tokens: bool = SchemaField(
|
||||
description="Whether to return token information", default=False
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
chunks: list = SchemaField(description="List of chunked texts")
|
||||
tokens: list = SchemaField(
|
||||
description="List of token information for each chunk", optional=True
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="806fb15e-830f-4796-8692-557d300ff43c",
|
||||
description="Chunks texts using Jina AI's segmentation service",
|
||||
categories={BlockCategory.AI, BlockCategory.TEXT},
|
||||
input_schema=JinaChunkingBlock.Input,
|
||||
output_schema=JinaChunkingBlock.Output,
|
||||
)
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: JinaCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
url = "https://segment.jina.ai/"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {credentials.api_key.get_secret_value()}",
|
||||
}
|
||||
|
||||
all_chunks = []
|
||||
all_tokens = []
|
||||
|
||||
for text in input_data.texts:
|
||||
data = {
|
||||
"content": text,
|
||||
"return_tokens": str(input_data.return_tokens).lower(),
|
||||
"return_chunks": "true",
|
||||
"max_chunk_length": str(input_data.max_chunk_length),
|
||||
}
|
||||
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
all_chunks.extend(result.get("chunks", []))
|
||||
if input_data.return_tokens:
|
||||
all_tokens.extend(result.get("tokens", []))
|
||||
|
||||
yield "chunks", all_chunks
|
||||
if input_data.return_tokens:
|
||||
yield "tokens", all_tokens
|
||||
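# Illustrative sketch (not part of this commit): the response shape the loop above
# consumes from https://segment.jina.ai/. The concrete values are made-up examples;
# only the "chunks"/"tokens" field names are taken from the code above.
_example_response = {
    "chunks": ["First chunk of the text...", "Second chunk of the text..."],
    "tokens": [],  # per-chunk token info, returned when return_tokens is requested
}
# all_chunks.extend(_example_response.get("chunks", []))
# all_tokens.extend(_example_response.get("tokens", []))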
44 autogpt_platform/backend/backend/blocks/jina/embeddings.py Normal file
@@ -0,0 +1,44 @@
|
||||
import requests
|
||||
|
||||
from backend.blocks.jina._auth import (
|
||||
JinaCredentials,
|
||||
JinaCredentialsField,
|
||||
JinaCredentialsInput,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class JinaEmbeddingBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
texts: list = SchemaField(description="List of texts to embed")
|
||||
credentials: JinaCredentialsInput = JinaCredentialsField()
|
||||
model: str = SchemaField(
|
||||
description="Jina embedding model to use",
|
||||
default="jina-embeddings-v2-base-en",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
embeddings: list = SchemaField(description="List of embeddings")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="7c56b3ab-62e7-43a2-a2dc-4ec4245660b6",
|
||||
description="Generates embeddings using Jina AI",
|
||||
categories={BlockCategory.AI},
|
||||
input_schema=JinaEmbeddingBlock.Input,
|
||||
output_schema=JinaEmbeddingBlock.Output,
|
||||
)
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: JinaCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
url = "https://api.jina.ai/v1/embeddings"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {credentials.api_key.get_secret_value()}",
|
||||
}
|
||||
data = {"input": input_data.texts, "model": input_data.model}
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
embeddings = [e["embedding"] for e in response.json()["data"]]
|
||||
yield "embeddings", embeddings
|
||||
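# Illustrative sketch (not part of this commit): the request body and the response
# fields that run() above relies on. The concrete values are examples only.
_example_request = {
    "input": ["first text", "second text"],
    "model": "jina-embeddings-v2-base-en",
}
_example_response = {
    "data": [
        {"embedding": [0.01, -0.02, 0.03]},
        {"embedding": [0.04, 0.05, -0.06]},
    ]
}
_embeddings = [e["embedding"] for e in _example_response["data"]]
assert len(_embeddings) == 2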
@@ -1,7 +1,12 @@
|
||||
import ast
|
||||
import logging
|
||||
from enum import Enum
|
||||
from enum import Enum, EnumMeta
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, List, NamedTuple
|
||||
from types import MappingProxyType
|
||||
from typing import TYPE_CHECKING, Any, List, NamedTuple
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from enum import _EnumMemberT
|
||||
|
||||
import anthropic
|
||||
import ollama
|
||||
@@ -11,6 +16,7 @@ from groq import Groq
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import BlockSecret, SchemaField, SecretField
|
||||
from backend.util import json
|
||||
from backend.util.settings import BehaveAs, Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -28,7 +34,26 @@ class ModelMetadata(NamedTuple):
|
||||
cost_factor: int
|
||||
|
||||
|
||||
class LlmModel(str, Enum):
|
||||
class LlmModelMeta(EnumMeta):
|
||||
@property
|
||||
def __members__(
|
||||
self: type["_EnumMemberT"],
|
||||
) -> MappingProxyType[str, "_EnumMemberT"]:
|
||||
if Settings().config.behave_as == BehaveAs.LOCAL:
|
||||
members = super().__members__
|
||||
return members
|
||||
else:
|
||||
removed_providers = ["ollama"]
|
||||
existing_members = super().__members__
|
||||
members = {
|
||||
name: member
|
||||
for name, member in existing_members.items()
|
||||
if LlmModel[name].provider not in removed_providers
|
||||
}
|
||||
return MappingProxyType(members)
|
||||
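# Illustrative sketch (not part of this commit): the same __members__-filtering idea
# in a self-contained form, independent of the Settings/BehaveAs machinery and using
# the Enum/EnumMeta/MappingProxyType imports above. The _Feature enum and the BETA_
# prefix are assumptions for the example only.
class _HideBetaMeta(EnumMeta):
    @property
    def __members__(cls):
        members = super().__members__
        return MappingProxyType(
            {name: member for name, member in members.items() if not name.startswith("BETA_")}
        )


class _Feature(str, Enum, metaclass=_HideBetaMeta):
    STABLE_A = "a"
    BETA_B = "b"


assert list(_Feature.__members__) == ["STABLE_A"]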
|
||||
|
||||
class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
# OpenAI models
|
||||
O1_PREVIEW = "o1-preview"
|
||||
O1_MINI = "o1-mini"
|
||||
@@ -37,7 +62,7 @@ class LlmModel(str, Enum):
|
||||
GPT4_TURBO = "gpt-4-turbo"
|
||||
GPT3_5_TURBO = "gpt-3.5-turbo"
|
||||
# Anthropic models
|
||||
CLAUDE_3_5_SONNET = "claude-3-5-sonnet-20240620"
|
||||
CLAUDE_3_5_SONNET = "claude-3-5-sonnet-latest"
|
||||
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
|
||||
# Groq models
|
||||
LLAMA3_8B = "llama3-8b-8192"
|
||||
@@ -57,27 +82,39 @@ class LlmModel(str, Enum):
|
||||
def metadata(self) -> ModelMetadata:
|
||||
return MODEL_METADATA[self]
|
||||
|
||||
@property
|
||||
def provider(self) -> str:
|
||||
return self.metadata.provider
|
||||
|
||||
@property
|
||||
def context_window(self) -> int:
|
||||
return self.metadata.context_window
|
||||
|
||||
@property
|
||||
def cost_factor(self) -> int:
|
||||
return self.metadata.cost_factor
|
||||
|
||||
|
||||
MODEL_METADATA = {
|
||||
LlmModel.O1_PREVIEW: ModelMetadata("openai", 32000, cost_factor=60),
|
||||
LlmModel.O1_MINI: ModelMetadata("openai", 62000, cost_factor=30),
|
||||
LlmModel.GPT4O_MINI: ModelMetadata("openai", 128000, cost_factor=10),
|
||||
LlmModel.GPT4O: ModelMetadata("openai", 128000, cost_factor=12),
|
||||
LlmModel.GPT4_TURBO: ModelMetadata("openai", 128000, cost_factor=11),
|
||||
LlmModel.GPT3_5_TURBO: ModelMetadata("openai", 16385, cost_factor=8),
|
||||
LlmModel.CLAUDE_3_5_SONNET: ModelMetadata("anthropic", 200000, cost_factor=14),
|
||||
LlmModel.CLAUDE_3_HAIKU: ModelMetadata("anthropic", 200000, cost_factor=13),
|
||||
LlmModel.LLAMA3_8B: ModelMetadata("groq", 8192, cost_factor=6),
|
||||
LlmModel.LLAMA3_70B: ModelMetadata("groq", 8192, cost_factor=9),
|
||||
LlmModel.MIXTRAL_8X7B: ModelMetadata("groq", 32768, cost_factor=7),
|
||||
LlmModel.GEMMA_7B: ModelMetadata("groq", 8192, cost_factor=6),
|
||||
LlmModel.GEMMA2_9B: ModelMetadata("groq", 8192, cost_factor=7),
|
||||
LlmModel.LLAMA3_1_405B: ModelMetadata("groq", 8192, cost_factor=10),
|
||||
LlmModel.O1_PREVIEW: ModelMetadata("openai", 32000, cost_factor=16),
|
||||
LlmModel.O1_MINI: ModelMetadata("openai", 62000, cost_factor=4),
|
||||
LlmModel.GPT4O_MINI: ModelMetadata("openai", 128000, cost_factor=1),
|
||||
LlmModel.GPT4O: ModelMetadata("openai", 128000, cost_factor=3),
|
||||
LlmModel.GPT4_TURBO: ModelMetadata("openai", 128000, cost_factor=10),
|
||||
LlmModel.GPT3_5_TURBO: ModelMetadata("openai", 16385, cost_factor=1),
|
||||
LlmModel.CLAUDE_3_5_SONNET: ModelMetadata("anthropic", 200000, cost_factor=4),
|
||||
LlmModel.CLAUDE_3_HAIKU: ModelMetadata("anthropic", 200000, cost_factor=1),
|
||||
LlmModel.LLAMA3_8B: ModelMetadata("groq", 8192, cost_factor=1),
|
||||
LlmModel.LLAMA3_70B: ModelMetadata("groq", 8192, cost_factor=1),
|
||||
LlmModel.MIXTRAL_8X7B: ModelMetadata("groq", 32768, cost_factor=1),
|
||||
LlmModel.GEMMA_7B: ModelMetadata("groq", 8192, cost_factor=1),
|
||||
LlmModel.GEMMA2_9B: ModelMetadata("groq", 8192, cost_factor=1),
|
||||
LlmModel.LLAMA3_1_405B: ModelMetadata("groq", 8192, cost_factor=1),
|
||||
# Limited to 16k during preview
|
||||
LlmModel.LLAMA3_1_70B: ModelMetadata("groq", 131072, cost_factor=15),
|
||||
LlmModel.LLAMA3_1_8B: ModelMetadata("groq", 131072, cost_factor=13),
|
||||
LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata("ollama", 8192, cost_factor=7),
|
||||
LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata("ollama", 8192, cost_factor=11),
|
||||
LlmModel.LLAMA3_1_70B: ModelMetadata("groq", 131072, cost_factor=1),
|
||||
LlmModel.LLAMA3_1_8B: ModelMetadata("groq", 131072, cost_factor=1),
|
||||
LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata("ollama", 8192, cost_factor=1),
|
||||
LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata("ollama", 8192, cost_factor=1),
|
||||
}
|
||||
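# Illustrative sketch (not part of this commit): the properties defined above turn the
# MODEL_METADATA table into simple attribute lookups.
assert LlmModel.GPT4O.provider == "openai"
assert LlmModel.CLAUDE_3_HAIKU.context_window == 200000
assert LlmModel.LLAMA3_8B.cost_factor == 1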
|
||||
for model in LlmModel:
|
||||
@@ -85,9 +122,23 @@ for model in LlmModel:
|
||||
raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")
|
||||
|
||||
|
||||
class MessageRole(str, Enum):
|
||||
SYSTEM = "system"
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
|
||||
|
||||
class Message(BlockSchema):
|
||||
role: MessageRole
|
||||
content: str
|
||||
|
||||
|
||||
class AIStructuredResponseGeneratorBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
prompt: str
|
||||
prompt: str = SchemaField(
|
||||
description="The prompt to send to the language model.",
|
||||
placeholder="Enter your prompt here...",
|
||||
)
|
||||
expected_format: dict[str, str] = SchemaField(
|
||||
description="Expected format of the response. If provided, the response will be validated against this format. "
|
||||
"The keys should be the expected fields in the response, and the values should be the description of the field.",
|
||||
@@ -99,15 +150,34 @@ class AIStructuredResponseGeneratorBlock(Block):
|
||||
advanced=False,
|
||||
)
|
||||
api_key: BlockSecret = SecretField(value="")
|
||||
sys_prompt: str = ""
|
||||
retry: int = 3
|
||||
sys_prompt: str = SchemaField(
|
||||
title="System Prompt",
|
||||
default="",
|
||||
description="The system prompt to provide additional context to the model.",
|
||||
)
|
||||
conversation_history: list[Message] = SchemaField(
|
||||
default=[],
|
||||
description="The conversation history to provide context for the prompt.",
|
||||
)
|
||||
retry: int = SchemaField(
|
||||
title="Retry Count",
|
||||
default=3,
|
||||
description="Number of times to retry the LLM call if the response does not match the expected format.",
|
||||
)
|
||||
prompt_values: dict[str, str] = SchemaField(
|
||||
advanced=False, default={}, description="Values used to fill in the prompt."
|
||||
)
|
||||
max_tokens: int | None = SchemaField(
|
||||
advanced=True,
|
||||
default=None,
|
||||
description="The maximum number of tokens to generate in the chat completion.",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
response: dict[str, Any]
|
||||
error: str
|
||||
response: dict[str, Any] = SchemaField(
|
||||
description="The response object generated by the language model."
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the API call failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -127,26 +197,47 @@ class AIStructuredResponseGeneratorBlock(Block):
|
||||
},
|
||||
test_output=("response", {"key1": "key1Value", "key2": "key2Value"}),
|
||||
test_mock={
|
||||
"llm_call": lambda *args, **kwargs: json.dumps(
|
||||
{
|
||||
"key1": "key1Value",
|
||||
"key2": "key2Value",
|
||||
}
|
||||
"llm_call": lambda *args, **kwargs: (
|
||||
json.dumps(
|
||||
{
|
||||
"key1": "key1Value",
|
||||
"key2": "key2Value",
|
||||
}
|
||||
),
|
||||
0,
|
||||
0,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def llm_call(
|
||||
api_key: str, model: LlmModel, prompt: list[dict], json_format: bool
|
||||
) -> str:
|
||||
provider = model.metadata.provider
|
||||
api_key: str,
|
||||
llm_model: LlmModel,
|
||||
prompt: list[dict],
|
||||
json_format: bool,
|
||||
max_tokens: int | None = None,
|
||||
) -> tuple[str, int, int]:
|
||||
"""
|
||||
Args:
|
||||
api_key: API key for the LLM provider.
|
||||
llm_model: The LLM model to use.
|
||||
prompt: The prompt to send to the LLM.
|
||||
json_format: Whether the response should be in JSON format.
|
||||
max_tokens: The maximum number of tokens to generate in the chat completion.
|
||||
|
||||
Returns:
|
||||
The response from the LLM.
|
||||
The number of tokens used in the prompt.
|
||||
The number of tokens used in the completion.
|
||||
"""
|
||||
provider = llm_model.metadata.provider
|
||||
|
||||
if provider == "openai":
|
||||
openai.api_key = api_key
|
||||
response_format = None
|
||||
|
||||
if model in [LlmModel.O1_MINI, LlmModel.O1_PREVIEW]:
|
||||
if llm_model in [LlmModel.O1_MINI, LlmModel.O1_PREVIEW]:
|
||||
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
||||
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
|
||||
prompt = [
|
||||
@@ -157,11 +248,17 @@ class AIStructuredResponseGeneratorBlock(Block):
|
||||
response_format = {"type": "json_object"}
|
||||
|
||||
response = openai.chat.completions.create(
|
||||
model=model.value,
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
response_format=response_format, # type: ignore
|
||||
max_completion_tokens=max_tokens,
|
||||
)
|
||||
|
||||
return (
|
||||
response.choices[0].message.content or "",
|
||||
response.usage.prompt_tokens if response.usage else 0,
|
||||
response.usage.completion_tokens if response.usage else 0,
|
||||
)
|
||||
return response.choices[0].message.content or ""
|
||||
elif provider == "anthropic":
|
||||
system_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
||||
sysprompt = " ".join(system_messages)
|
||||
@@ -179,13 +276,18 @@ class AIStructuredResponseGeneratorBlock(Block):
|
||||
|
||||
client = anthropic.Anthropic(api_key=api_key)
|
||||
try:
|
||||
response = client.messages.create(
|
||||
model=model.value,
|
||||
max_tokens=4096,
|
||||
resp = client.messages.create(
|
||||
model=llm_model.value,
|
||||
system=sysprompt,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens or 8192,
|
||||
)
|
||||
|
||||
return (
|
||||
resp.content[0].text if resp.content else "",
|
||||
resp.usage.input_tokens,
|
||||
resp.usage.output_tokens,
|
||||
)
|
||||
return response.content[0].text if response.content else ""
|
||||
except anthropic.APIError as e:
|
||||
error_message = f"Anthropic API error: {str(e)}"
|
||||
logger.error(error_message)
|
||||
@@ -194,22 +296,35 @@ class AIStructuredResponseGeneratorBlock(Block):
|
||||
client = Groq(api_key=api_key)
|
||||
response_format = {"type": "json_object"} if json_format else None
|
||||
response = client.chat.completions.create(
|
||||
model=model.value,
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
response_format=response_format, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
return (
|
||||
response.choices[0].message.content or "",
|
||||
response.usage.prompt_tokens if response.usage else 0,
|
||||
response.usage.completion_tokens if response.usage else 0,
|
||||
)
|
||||
return response.choices[0].message.content or ""
|
||||
elif provider == "ollama":
|
||||
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
||||
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
|
||||
response = ollama.generate(
|
||||
model=model.value,
|
||||
prompt=prompt[0]["content"],
|
||||
model=llm_model.value,
|
||||
prompt=f"{sys_messages}\n\n{usr_messages}",
|
||||
stream=False,
|
||||
)
|
||||
return (
|
||||
response.get("response") or "",
|
||||
response.get("prompt_eval_count") or 0,
|
||||
response.get("eval_count") or 0,
|
||||
)
|
||||
return response["response"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
prompt = []
|
||||
logger.debug(f"Calling LLM with input data: {input_data}")
|
||||
prompt = [p.model_dump() for p in input_data.conversation_history]
|
||||
|
||||
def trim_prompt(s: str) -> str:
|
||||
lines = s.strip().split("\n")
|
||||
@@ -238,7 +353,8 @@ class AIStructuredResponseGeneratorBlock(Block):
|
||||
)
|
||||
prompt.append({"role": "system", "content": sys_prompt})
|
||||
|
||||
prompt.append({"role": "user", "content": input_data.prompt})
|
||||
if input_data.prompt:
|
||||
prompt.append({"role": "user", "content": input_data.prompt})
|
||||
|
||||
def parse_response(resp: str) -> tuple[dict[str, Any], str | None]:
|
||||
try:
|
||||
@@ -254,19 +370,26 @@ class AIStructuredResponseGeneratorBlock(Block):
|
||||
|
||||
logger.info(f"LLM request: {prompt}")
|
||||
retry_prompt = ""
|
||||
model = input_data.model
|
||||
llm_model = input_data.model
|
||||
api_key = (
|
||||
input_data.api_key.get_secret_value()
|
||||
or LlmApiKeys[model.metadata.provider].get_secret_value()
|
||||
or LlmApiKeys[llm_model.metadata.provider].get_secret_value()
|
||||
)
|
||||
|
||||
for retry_count in range(input_data.retry):
|
||||
try:
|
||||
response_text = self.llm_call(
|
||||
response_text, input_token, output_token = self.llm_call(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
llm_model=llm_model,
|
||||
prompt=prompt,
|
||||
json_format=bool(input_data.expected_format),
|
||||
max_tokens=input_data.max_tokens,
|
||||
)
|
||||
self.merge_stats(
|
||||
{
|
||||
"input_token_count": input_token,
|
||||
"output_token_count": output_token,
|
||||
}
|
||||
)
|
||||
logger.info(f"LLM attempt-{retry_count} response: {response_text}")
|
||||
|
||||
@@ -303,15 +426,25 @@ class AIStructuredResponseGeneratorBlock(Block):
|
||||
)
|
||||
prompt.append({"role": "user", "content": retry_prompt})
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling LLM: {e}")
|
||||
logger.exception(f"Error calling LLM: {e}")
|
||||
retry_prompt = f"Error calling LLM: {e}"
|
||||
finally:
|
||||
self.merge_stats(
|
||||
{
|
||||
"llm_call_count": retry_count + 1,
|
||||
"llm_retry_count": retry_count,
|
||||
}
|
||||
)
|
||||
|
||||
yield "error", retry_prompt
|
||||
raise RuntimeError(retry_prompt)
|
||||
|
||||
|
||||
class AITextGeneratorBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
prompt: str
|
||||
prompt: str = SchemaField(
|
||||
description="The prompt to send to the language model.",
|
||||
placeholder="Enter your prompt here...",
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=LlmModel.GPT4_TURBO,
|
||||
@@ -319,15 +452,30 @@ class AITextGeneratorBlock(Block):
|
||||
advanced=False,
|
||||
)
|
||||
api_key: BlockSecret = SecretField(value="")
|
||||
sys_prompt: str = ""
|
||||
retry: int = 3
|
||||
sys_prompt: str = SchemaField(
|
||||
title="System Prompt",
|
||||
default="",
|
||||
description="The system prompt to provide additional context to the model.",
|
||||
)
|
||||
retry: int = SchemaField(
|
||||
title="Retry Count",
|
||||
default=3,
|
||||
description="Number of times to retry the LLM call if the response does not match the expected format.",
|
||||
)
|
||||
prompt_values: dict[str, str] = SchemaField(
|
||||
advanced=False, default={}, description="Values used to fill in the prompt."
|
||||
)
|
||||
max_tokens: int | None = SchemaField(
|
||||
advanced=True,
|
||||
default=None,
|
||||
description="The maximum number of tokens to generate in the chat completion.",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
response: str
|
||||
error: str
|
||||
response: str = SchemaField(
|
||||
description="The response generated by the language model."
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the API call failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -341,47 +489,70 @@ class AITextGeneratorBlock(Block):
|
||||
test_mock={"llm_call": lambda *args, **kwargs: "Response text"},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def llm_call(input_data: AIStructuredResponseGeneratorBlock.Input) -> str:
|
||||
object_block = AIStructuredResponseGeneratorBlock()
|
||||
for output_name, output_data in object_block.run(input_data):
|
||||
if output_name == "response":
|
||||
return output_data["response"]
|
||||
else:
|
||||
raise RuntimeError(output_data)
|
||||
raise ValueError("Failed to get a response from the LLM.")
|
||||
def llm_call(self, input_data: AIStructuredResponseGeneratorBlock.Input) -> str:
|
||||
block = AIStructuredResponseGeneratorBlock()
|
||||
response = block.run_once(input_data, "response")
|
||||
self.merge_stats(block.execution_stats)
|
||||
return response["response"]
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
object_input_data = AIStructuredResponseGeneratorBlock.Input(
|
||||
**{attr: getattr(input_data, attr) for attr in input_data.model_fields},
|
||||
expected_format={},
|
||||
)
|
||||
yield "response", self.llm_call(object_input_data)
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
object_input_data = AIStructuredResponseGeneratorBlock.Input(
|
||||
**{attr: getattr(input_data, attr) for attr in input_data.model_fields},
|
||||
expected_format={},
|
||||
)
|
||||
yield "response", self.llm_call(object_input_data)
|
||||
|
||||
|
||||
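# Illustrative sketch (not part of this commit): the delegation pattern used by
# llm_call() above — build the structured block's Input, request a single named
# output via run_once, and fold the inner block's stats into the caller.
# run_once, merge_stats and execution_stats are taken from the surrounding code,
# not invented here; _delegate_text_generation is an assumption for the example.
def _delegate_text_generation(outer_block: Block, prompt: str, model: LlmModel) -> str:
    inner = AIStructuredResponseGeneratorBlock()
    response = inner.run_once(
        AIStructuredResponseGeneratorBlock.Input(
            prompt=prompt,
            model=model,
            expected_format={},  # free-form text rather than structured JSON
        ),
        "response",
    )
    outer_block.merge_stats(inner.execution_stats)
    return response["response"]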
class SummaryStyle(Enum):
|
||||
CONCISE = "concise"
|
||||
DETAILED = "detailed"
|
||||
BULLET_POINTS = "bullet points"
|
||||
NUMBERED_LIST = "numbered list"
|
||||
|
||||
|
||||
class AITextSummarizerBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
text: str
|
||||
text: str = SchemaField(
|
||||
description="The text to summarize.",
|
||||
placeholder="Enter the text to summarize here...",
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=LlmModel.GPT4_TURBO,
|
||||
description="The language model to use for summarizing the text.",
|
||||
)
|
||||
focus: str = SchemaField(
|
||||
title="Focus",
|
||||
default="general information",
|
||||
description="The topic to focus on in the summary",
|
||||
)
|
||||
style: SummaryStyle = SchemaField(
|
||||
title="Summary Style",
|
||||
default=SummaryStyle.CONCISE,
|
||||
description="The style of the summary to generate.",
|
||||
)
|
||||
api_key: BlockSecret = SecretField(value="")
|
||||
# TODO: Make this dynamic
|
||||
max_tokens: int = 4000 # Adjust based on the model's context window
|
||||
chunk_overlap: int = 100 # Overlap between chunks to maintain context
|
||||
max_tokens: int = SchemaField(
|
||||
title="Max Tokens",
|
||||
default=4096,
|
||||
description="The maximum number of tokens to generate in the chat completion.",
|
||||
ge=1,
|
||||
)
|
||||
chunk_overlap: int = SchemaField(
|
||||
title="Chunk Overlap",
|
||||
default=100,
|
||||
description="The number of overlapping tokens between chunks to maintain context.",
|
||||
ge=0,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
summary: str
|
||||
error: str
|
||||
summary: str = SchemaField(description="The final summary of the text.")
|
||||
error: str = SchemaField(description="Error message if the API call failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c3d4e5f6-7g8h-9i0j-1k2l-m3n4o5p6q7r8",
|
||||
id="a0a69be1-4528-491c-a85a-a4ab6873e3f0",
|
||||
description="Utilize a Large Language Model (LLM) to summarize a long text.",
|
||||
categories={BlockCategory.AI, BlockCategory.TEXT},
|
||||
input_schema=AITextSummarizerBlock.Input,
|
||||
@@ -398,11 +569,8 @@ class AITextSummarizerBlock(Block):
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
for output in self._run(input_data):
|
||||
yield output
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
for output in self._run(input_data):
|
||||
yield output
|
||||
|
||||
def _run(self, input_data: Input) -> BlockOutput:
|
||||
chunks = self._split_text(
|
||||
@@ -429,18 +597,14 @@ class AITextSummarizerBlock(Block):
|
||||
|
||||
return chunks
|
||||
|
||||
@staticmethod
|
||||
def llm_call(
|
||||
input_data: AIStructuredResponseGeneratorBlock.Input,
|
||||
) -> dict[str, str]:
|
||||
llm_block = AIStructuredResponseGeneratorBlock()
|
||||
for output_name, output_data in llm_block.run(input_data):
|
||||
if output_name == "response":
|
||||
return output_data
|
||||
raise ValueError("Failed to get a response from the LLM.")
|
||||
def llm_call(self, input_data: AIStructuredResponseGeneratorBlock.Input) -> dict:
|
||||
block = AIStructuredResponseGeneratorBlock()
|
||||
response = block.run_once(input_data, "response")
|
||||
self.merge_stats(block.execution_stats)
|
||||
return response
|
||||
|
||||
def _summarize_chunk(self, chunk: str, input_data: Input) -> str:
|
||||
prompt = f"Summarize the following text concisely:\n\n{chunk}"
|
||||
prompt = f"Summarize the following text in a {input_data.style} form. Focus your summary on the topic of `{input_data.focus}` if present, otherwise just provide a general summary:\n\n```{chunk}```"
|
||||
|
||||
llm_response = self.llm_call(
|
||||
AIStructuredResponseGeneratorBlock.Input(
|
||||
@@ -454,13 +618,10 @@ class AITextSummarizerBlock(Block):
|
||||
return llm_response["summary"]
|
||||
|
||||
def _combine_summaries(self, summaries: list[str], input_data: Input) -> str:
|
||||
combined_text = " ".join(summaries)
|
||||
combined_text = "\n\n".join(summaries)
|
||||
|
||||
if len(combined_text.split()) <= input_data.max_tokens:
|
||||
prompt = (
|
||||
"Provide a final, concise summary of the following summaries:\n\n"
|
||||
+ combined_text
|
||||
)
|
||||
prompt = f"Provide a final summary of the following section summaries in a {input_data.style} form, focus your summary on the topic of `{input_data.focus}` if present:\n\n ```{combined_text}```\n\n Just respond with the final_summary in the format specified."
|
||||
|
||||
llm_response = self.llm_call(
|
||||
AIStructuredResponseGeneratorBlock.Input(
|
||||
@@ -489,17 +650,6 @@ class AITextSummarizerBlock(Block):
|
||||
] # Get the first yielded value
|
||||
|
||||
|
||||
class MessageRole(str, Enum):
|
||||
SYSTEM = "system"
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
|
||||
|
||||
class Message(BlockSchema):
|
||||
role: MessageRole
|
||||
content: str
|
||||
|
||||
|
||||
class AIConversationBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
messages: List[Message] = SchemaField(
|
||||
@@ -514,9 +664,9 @@ class AIConversationBlock(Block):
|
||||
value="", description="API key for the chosen language model provider."
|
||||
)
|
||||
max_tokens: int | None = SchemaField(
|
||||
advanced=True,
|
||||
default=None,
|
||||
description="The maximum number of tokens to generate in the chat completion.",
|
||||
ge=1,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
@@ -527,7 +677,7 @@ class AIConversationBlock(Block):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c3d4e5f6-g7h8-i9j0-k1l2-m3n4o5p6q7r8",
|
||||
id="32a87eab-381e-4dd4-bdb8-4c47151be35a",
|
||||
description="Advanced LLM call that takes a list of messages and sends them to the language model.",
|
||||
categories={BlockCategory.AI},
|
||||
input_schema=AIConversationBlock.Input,
|
||||
@@ -554,65 +704,253 @@ class AIConversationBlock(Block):
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def llm_call(
|
||||
api_key: str,
|
||||
model: LlmModel,
|
||||
messages: List[dict[str, str]],
|
||||
max_tokens: int | None = None,
|
||||
) -> str:
|
||||
provider = model.metadata.provider
|
||||
|
||||
if provider == "openai":
|
||||
openai.api_key = api_key
|
||||
response = openai.chat.completions.create(
|
||||
model=model.value,
|
||||
messages=messages, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
return response.choices[0].message.content or ""
|
||||
elif provider == "anthropic":
|
||||
client = anthropic.Anthropic(api_key=api_key)
|
||||
response = client.messages.create(
|
||||
model=model.value,
|
||||
max_tokens=max_tokens or 4096,
|
||||
messages=messages, # type: ignore
|
||||
)
|
||||
return response.content[0].text if response.content else ""
|
||||
elif provider == "groq":
|
||||
client = Groq(api_key=api_key)
|
||||
response = client.chat.completions.create(
|
||||
model=model.value,
|
||||
messages=messages, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
return response.choices[0].message.content or ""
|
||||
elif provider == "ollama":
|
||||
response = ollama.chat(
|
||||
model=model.value,
|
||||
messages=messages, # type: ignore
|
||||
stream=False, # type: ignore
|
||||
)
|
||||
return response["message"]["content"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
def llm_call(self, input_data: AIStructuredResponseGeneratorBlock.Input) -> str:
|
||||
block = AIStructuredResponseGeneratorBlock()
|
||||
response = block.run_once(input_data, "response")
|
||||
self.merge_stats(block.execution_stats)
|
||||
return response["response"]
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
api_key = (
|
||||
input_data.api_key.get_secret_value()
|
||||
or LlmApiKeys[input_data.model.metadata.provider].get_secret_value()
|
||||
)
|
||||
|
||||
messages = [message.model_dump() for message in input_data.messages]
|
||||
|
||||
response = self.llm_call(
|
||||
api_key=api_key,
|
||||
response = self.llm_call(
|
||||
AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="",
|
||||
api_key=input_data.api_key,
|
||||
model=input_data.model,
|
||||
messages=messages,
|
||||
conversation_history=input_data.messages,
|
||||
max_tokens=input_data.max_tokens,
|
||||
expected_format={},
|
||||
)
|
||||
)
|
||||
|
||||
yield "response", response
|
||||
except Exception as e:
|
||||
yield "error", f"Error calling LLM: {str(e)}"
|
||||
yield "response", response
|
||||
|
||||
|
||||
class AIListGeneratorBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
focus: str | None = SchemaField(
|
||||
description="The focus of the list to generate.",
|
||||
placeholder="The top 5 most interesting news stories in the data.",
|
||||
default=None,
|
||||
advanced=False,
|
||||
)
|
||||
source_data: str | None = SchemaField(
|
||||
description="The data to generate the list from.",
|
||||
placeholder="News Today: Humans land on Mars: Today humans landed on mars. -- AI wins Nobel Prize: AI wins Nobel Prize for solving world hunger. -- New AI Model: A new AI model has been released.",
|
||||
default=None,
|
||||
advanced=False,
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=LlmModel.GPT4_TURBO,
|
||||
description="The language model to use for generating the list.",
|
||||
advanced=True,
|
||||
)
|
||||
api_key: BlockSecret = SecretField(value="")
|
||||
max_retries: int = SchemaField(
|
||||
default=3,
|
||||
description="Maximum number of retries for generating a valid list.",
|
||||
ge=1,
|
||||
le=5,
|
||||
)
|
||||
max_tokens: int | None = SchemaField(
|
||||
advanced=True,
|
||||
default=None,
|
||||
description="The maximum number of tokens to generate in the chat completion.",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
generated_list: List[str] = SchemaField(description="The generated list.")
|
||||
list_item: str = SchemaField(
|
||||
description="Each individual item in the list.",
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the list generation failed."
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="9c0b0450-d199-458b-a731-072189dd6593",
|
||||
description="Generate a Python list based on the given prompt using a Large Language Model (LLM).",
|
||||
categories={BlockCategory.AI, BlockCategory.TEXT},
|
||||
input_schema=AIListGeneratorBlock.Input,
|
||||
output_schema=AIListGeneratorBlock.Output,
|
||||
test_input={
|
||||
"focus": "planets",
|
||||
"source_data": (
|
||||
"Zylora Prime is a glowing jungle world with bioluminescent plants, "
|
||||
"while Kharon-9 is a harsh desert planet with underground cities. "
|
||||
"Vortexia's constant storms power floating cities, and Oceara is a water-covered world home to "
|
||||
"intelligent marine life. On icy Draknos, ancient ruins lie buried beneath its frozen landscape, "
|
||||
"drawing explorers to uncover its mysteries. Each planet showcases the limitless possibilities of "
|
||||
"fictional worlds."
|
||||
),
|
||||
"model": LlmModel.GPT4_TURBO,
|
||||
"api_key": "test_api_key",
|
||||
"max_retries": 3,
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"generated_list",
|
||||
["Zylora Prime", "Kharon-9", "Vortexia", "Oceara", "Draknos"],
|
||||
),
|
||||
("list_item", "Zylora Prime"),
|
||||
("list_item", "Kharon-9"),
|
||||
("list_item", "Vortexia"),
|
||||
("list_item", "Oceara"),
|
||||
("list_item", "Draknos"),
|
||||
],
|
||||
test_mock={
|
||||
"llm_call": lambda input_data: {
|
||||
"response": "['Zylora Prime', 'Kharon-9', 'Vortexia', 'Oceara', 'Draknos']"
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def llm_call(
|
||||
input_data: AIStructuredResponseGeneratorBlock.Input,
|
||||
) -> dict[str, str]:
|
||||
llm_block = AIStructuredResponseGeneratorBlock()
|
||||
response = llm_block.run_once(input_data, "response")
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def string_to_list(string):
|
||||
"""
|
||||
Converts a string representation of a list into an actual Python list object.
|
||||
"""
|
||||
logger.debug(f"Converting string to list. Input string: {string}")
|
||||
try:
|
||||
# Use ast.literal_eval to safely evaluate the string
|
||||
python_list = ast.literal_eval(string)
|
||||
if isinstance(python_list, list):
|
||||
logger.debug(f"Successfully converted string to list: {python_list}")
|
||||
return python_list
|
||||
else:
|
||||
logger.error(f"The provided string '{string}' is not a valid list")
|
||||
raise ValueError(f"The provided string '{string}' is not a valid list.")
|
||||
except (SyntaxError, ValueError) as e:
|
||||
logger.error(f"Failed to convert string to list: {e}")
|
||||
raise ValueError("Invalid list format. Could not convert to list.")
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
logger.debug(f"Starting AIListGeneratorBlock.run with input data: {input_data}")
|
||||
|
||||
# Check for API key
|
||||
api_key_check = (
|
||||
input_data.api_key.get_secret_value()
|
||||
or LlmApiKeys[input_data.model.metadata.provider].get_secret_value()
|
||||
)
|
||||
if not api_key_check:
|
||||
raise ValueError("No LLM API key provided.")
|
||||
|
||||
# Prepare the system prompt
|
||||
sys_prompt = """You are a Python list generator. Your task is to generate a Python list based on the user's prompt.
|
||||
|Respond ONLY with a valid python list.
|
||||
|The list can contain strings, numbers, or nested lists as appropriate.
|
||||
|Do not include any explanations or additional text.
|
||||
|
||||
|Valid Example string formats:
|
||||
|
||||
|Example 1:
|
||||
|```
|
||||
|['1', '2', '3', '4']
|
||||
|```
|
||||
|
||||
|Example 2:
|
||||
|```
|
||||
|[['1', '2'], ['3', '4'], ['5', '6']]
|
||||
|```
|
||||
|
||||
|Example 3:
|
||||
|```
|
||||
|['1', ['2', '3'], ['4', ['5', '6']]]
|
||||
|```
|
||||
|
||||
|Example 4:
|
||||
|```
|
||||
|['a', 'b', 'c']
|
||||
|```
|
||||
|
||||
|Example 5:
|
||||
|```
|
||||
|['1', '2.5', 'string', 'True', ['False', 'None']]
|
||||
|```
|
||||
|
||||
|Do not include any explanations or additional text, just respond with the list in the format specified above.
|
||||
"""
|
||||
# If a focus is provided, add it to the prompt
|
||||
if input_data.focus:
|
||||
prompt = f"Generate a list with the following focus:\n<focus>\n\n{input_data.focus}</focus>"
|
||||
else:
|
||||
# If there's source data
|
||||
if input_data.source_data:
|
||||
prompt = "Extract the main focus of the source data to a list.\ni.e if the source data is a news website, the focus would be the news stories rather than the social links in the footer."
|
||||
else:
|
||||
                # No focus or source data provided, generate a random list
|
||||
prompt = "Generate a random list."
|
||||
|
||||
# If the source data is provided, add it to the prompt
|
||||
if input_data.source_data:
|
||||
prompt += f"\n\nUse the following source data to generate the list from:\n\n<source_data>\n\n{input_data.source_data}</source_data>\n\nDo not invent fictional data that is not present in the source data."
|
||||
# Else, tell the LLM to synthesize the data
|
||||
else:
|
||||
prompt += "\n\nInvent the data to generate the list from."
|
||||
|
||||
for attempt in range(input_data.max_retries):
|
||||
try:
|
||||
logger.debug("Calling LLM")
|
||||
llm_response = self.llm_call(
|
||||
AIStructuredResponseGeneratorBlock.Input(
|
||||
sys_prompt=sys_prompt,
|
||||
prompt=prompt,
|
||||
api_key=input_data.api_key,
|
||||
model=input_data.model,
|
||||
expected_format={}, # Do not use structured response
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(f"LLM response: {llm_response}")
|
||||
|
||||
# Extract Response string
|
||||
response_string = llm_response["response"]
|
||||
logger.debug(f"Response string: {response_string}")
|
||||
|
||||
# Convert the string to a Python list
|
||||
logger.debug("Converting string to Python list")
|
||||
parsed_list = self.string_to_list(response_string)
|
||||
logger.debug(f"Parsed list: {parsed_list}")
|
||||
|
||||
# If we reach here, we have a valid Python list
|
||||
logger.debug("Successfully generated a valid Python list")
|
||||
yield "generated_list", parsed_list
|
||||
|
||||
# Yield each item in the list
|
||||
for item in parsed_list:
|
||||
yield "list_item", item
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in attempt {attempt + 1}: {str(e)}")
|
||||
if attempt == input_data.max_retries - 1:
|
||||
logger.error(
|
||||
f"Failed to generate a valid Python list after {input_data.max_retries} attempts"
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Failed to generate a valid Python list after {input_data.max_retries} attempts. Last error: {str(e)}"
|
||||
)
|
||||
else:
|
||||
# Add a retry prompt
|
||||
logger.debug("Preparing retry prompt")
|
||||
prompt = f"""
|
||||
The previous attempt failed due to `{e}`
|
||||
Generate a valid Python list based on the original prompt.
|
||||
Remember to respond ONLY with a valid Python list as per the format specified earlier.
|
||||
Original prompt:
|
||||
```{prompt}```
|
||||
|
||||
Respond only with the list in the format specified with no commentary or apologies.
|
||||
"""
|
||||
logger.debug(f"Retry prompt: {prompt}")
|
||||
|
||||
logger.debug("AIListGeneratorBlock.run completed")
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
|
||||
import requests
|
||||
@@ -6,6 +7,12 @@ from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import BlockSecret, SchemaField, SecretField
|
||||
|
||||
|
||||
class PublishToMediumStatus(str, Enum):
|
||||
PUBLIC = "public"
|
||||
DRAFT = "draft"
|
||||
UNLISTED = "unlisted"
|
||||
|
||||
|
||||
class PublishToMediumBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
author_id: BlockSecret = SecretField(
|
||||
@@ -34,9 +41,9 @@ class PublishToMediumBlock(Block):
|
||||
description="The original home of this content, if it was originally published elsewhere",
|
||||
placeholder="https://yourblog.com/original-post",
|
||||
)
|
||||
publish_status: str = SchemaField(
|
||||
description="The publish status: 'public', 'draft', or 'unlisted'",
|
||||
placeholder="public",
|
||||
publish_status: PublishToMediumStatus = SchemaField(
|
||||
description="The publish status",
|
||||
placeholder=PublishToMediumStatus.DRAFT,
|
||||
)
|
||||
license: str = SchemaField(
|
||||
default="all-rights-reserved",
|
||||
@@ -79,7 +86,7 @@ class PublishToMediumBlock(Block):
|
||||
"tags": ["test", "automation"],
|
||||
"license": "all-rights-reserved",
|
||||
"notify_followers": False,
|
||||
"publish_status": "draft",
|
||||
"publish_status": PublishToMediumStatus.DRAFT.value,
|
||||
"api_key": "your_test_api_key",
|
||||
},
|
||||
test_output=[
|
||||
@@ -138,31 +145,25 @@ class PublishToMediumBlock(Block):
|
||||
return response.json()
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
response = self.create_post(
|
||||
input_data.api_key.get_secret_value(),
|
||||
input_data.author_id.get_secret_value(),
|
||||
input_data.title,
|
||||
input_data.content,
|
||||
input_data.content_format,
|
||||
input_data.tags,
|
||||
input_data.canonical_url,
|
||||
input_data.publish_status,
|
||||
input_data.license,
|
||||
input_data.notify_followers,
|
||||
response = self.create_post(
|
||||
input_data.api_key.get_secret_value(),
|
||||
input_data.author_id.get_secret_value(),
|
||||
input_data.title,
|
||||
input_data.content,
|
||||
input_data.content_format,
|
||||
input_data.tags,
|
||||
input_data.canonical_url,
|
||||
input_data.publish_status,
|
||||
input_data.license,
|
||||
input_data.notify_followers,
|
||||
)
|
||||
|
||||
if "data" in response:
|
||||
yield "post_id", response["data"]["id"]
|
||||
yield "post_url", response["data"]["url"]
|
||||
yield "published_at", response["data"]["publishedAt"]
|
||||
else:
|
||||
error_message = response.get("errors", [{}])[0].get(
|
||||
"message", "Unknown error occurred"
|
||||
)
|
||||
|
||||
if "data" in response:
|
||||
yield "post_id", response["data"]["id"]
|
||||
yield "post_url", response["data"]["url"]
|
||||
yield "published_at", response["data"]["publishedAt"]
|
||||
else:
|
||||
error_message = response.get("errors", [{}])[0].get(
|
||||
"message", "Unknown error occurred"
|
||||
)
|
||||
yield "error", f"Failed to create Medium post: {error_message}"
|
||||
|
||||
except requests.RequestException as e:
|
||||
yield "error", f"Network error occurred while creating Medium post: {str(e)}"
|
||||
except Exception as e:
|
||||
yield "error", f"Error occurred while creating Medium post: {str(e)}"
|
||||
raise RuntimeError(f"Failed to create Medium post: {error_message}")
|
||||
|
||||
131
autogpt_platform/backend/backend/blocks/pinecone.py
Normal file
@@ -0,0 +1,131 @@
|
||||
from typing import Literal
|
||||
|
||||
from autogpt_libs.supabase_integration_credentials_store import APIKeyCredentials
|
||||
from pinecone import Pinecone, ServerlessSpec
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import CredentialsField, CredentialsMetaInput, SchemaField
|
||||
|
||||
PineconeCredentials = APIKeyCredentials
|
||||
PineconeCredentialsInput = CredentialsMetaInput[
|
||||
Literal["pinecone"],
|
||||
Literal["api_key"],
|
||||
]
|
||||
|
||||
|
||||
def PineconeCredentialsField() -> PineconeCredentialsInput:
|
||||
"""
|
||||
Creates a Pinecone credentials input on a block.
|
||||
|
||||
"""
|
||||
return CredentialsField(
|
||||
provider="pinecone",
|
||||
supported_credential_types={"api_key"},
|
||||
description="The Pinecone integration can be used with an API Key.",
|
||||
)
|
||||
|
||||
|
||||
class PineconeInitBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: PineconeCredentialsInput = PineconeCredentialsField()
|
||||
index_name: str = SchemaField(description="Name of the Pinecone index")
|
||||
dimension: int = SchemaField(
|
||||
description="Dimension of the vectors", default=768
|
||||
)
|
||||
metric: str = SchemaField(
|
||||
description="Distance metric for the index", default="cosine"
|
||||
)
|
||||
cloud: str = SchemaField(
|
||||
description="Cloud provider for serverless", default="aws"
|
||||
)
|
||||
region: str = SchemaField(
|
||||
description="Region for serverless", default="us-east-1"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
index: str = SchemaField(description="Name of the initialized Pinecone index")
|
||||
message: str = SchemaField(description="Status message")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="48d8fdab-8f03-41f3-8407-8107ba11ec9b",
|
||||
description="Initializes a Pinecone index",
|
||||
categories={BlockCategory.LOGIC},
|
||||
input_schema=PineconeInitBlock.Input,
|
||||
output_schema=PineconeInitBlock.Output,
|
||||
)
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
pc = Pinecone(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
try:
|
||||
existing_indexes = pc.list_indexes()
|
||||
if input_data.index_name not in [index.name for index in existing_indexes]:
|
||||
pc.create_index(
|
||||
name=input_data.index_name,
|
||||
dimension=input_data.dimension,
|
||||
metric=input_data.metric,
|
||||
spec=ServerlessSpec(
|
||||
cloud=input_data.cloud, region=input_data.region
|
||||
),
|
||||
)
|
||||
message = f"Created new index: {input_data.index_name}"
|
||||
else:
|
||||
message = f"Using existing index: {input_data.index_name}"
|
||||
|
||||
yield "index", input_data.index_name
|
||||
yield "message", message
|
||||
except Exception as e:
|
||||
yield "message", f"Error initializing Pinecone index: {str(e)}"
|
||||
|
||||
|
||||
class PineconeQueryBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: PineconeCredentialsInput = PineconeCredentialsField()
|
||||
query_vector: list = SchemaField(description="Query vector")
|
||||
namespace: str = SchemaField(
|
||||
description="Namespace to query in Pinecone", default=""
|
||||
)
|
||||
top_k: int = SchemaField(
|
||||
description="Number of top results to return", default=3
|
||||
)
|
||||
include_values: bool = SchemaField(
|
||||
description="Whether to include vector values in the response",
|
||||
default=False,
|
||||
)
|
||||
include_metadata: bool = SchemaField(
|
||||
description="Whether to include metadata in the response", default=True
|
||||
)
|
||||
host: str = SchemaField(description="Host for pinecone")
|
||||
|
||||
class Output(BlockSchema):
|
||||
results: dict = SchemaField(description="Query results from Pinecone")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="9ad93d0f-91b4-4c9c-8eb1-82e26b4a01c5",
|
||||
description="Queries a Pinecone index",
|
||||
categories={BlockCategory.LOGIC},
|
||||
input_schema=PineconeQueryBlock.Input,
|
||||
output_schema=PineconeQueryBlock.Output,
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
pc = Pinecone(api_key=credentials.api_key.get_secret_value())
|
||||
idx = pc.Index(host=input_data.host)
|
||||
results = idx.query(
|
||||
namespace=input_data.namespace,
|
||||
vector=input_data.query_vector,
|
||||
top_k=input_data.top_k,
|
||||
include_values=input_data.include_values,
|
||||
include_metadata=input_data.include_metadata,
|
||||
)
|
||||
yield "results", results
|
||||
@@ -2,10 +2,10 @@ from datetime import datetime, timezone
|
||||
from typing import Iterator
|
||||
|
||||
import praw
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import BlockSecret, SecretField
|
||||
from backend.data.model import BlockSecret, SchemaField, SecretField
|
||||
from backend.util.mock import MockObject
|
||||
|
||||
|
||||
@@ -48,25 +48,25 @@ def get_praw(creds: RedditCredentials) -> praw.Reddit:
|
||||
|
||||
class GetRedditPostsBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
subreddit: str = Field(description="Subreddit name")
|
||||
creds: RedditCredentials = Field(
|
||||
subreddit: str = SchemaField(description="Subreddit name")
|
||||
creds: RedditCredentials = SchemaField(
|
||||
description="Reddit credentials",
|
||||
default=RedditCredentials(),
|
||||
)
|
||||
last_minutes: int | None = Field(
|
||||
last_minutes: int | None = SchemaField(
|
||||
description="Post time to stop minutes ago while fetching posts",
|
||||
default=None,
|
||||
)
|
||||
last_post: str | None = Field(
|
||||
last_post: str | None = SchemaField(
|
||||
description="Post ID to stop when reached while fetching posts",
|
||||
default=None,
|
||||
)
|
||||
post_limit: int | None = Field(
|
||||
post_limit: int | None = SchemaField(
|
||||
description="Number of posts to fetch", default=10
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
post: RedditPost = Field(description="Reddit post")
|
||||
post: RedditPost = SchemaField(description="Reddit post")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -140,13 +140,13 @@ class GetRedditPostsBlock(Block):
|
||||
|
||||
class PostRedditCommentBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
creds: RedditCredentials = Field(
|
||||
creds: RedditCredentials = SchemaField(
|
||||
description="Reddit credentials", default=RedditCredentials()
|
||||
)
|
||||
data: RedditComment = Field(description="Reddit comment")
|
||||
data: RedditComment = SchemaField(description="Reddit comment")
|
||||
|
||||
class Output(BlockSchema):
|
||||
comment_id: str
|
||||
comment_id: str = SchemaField(description="Posted comment ID")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
|
||||
@@ -0,0 +1,201 @@
|
||||
import os
|
||||
from enum import Enum
|
||||
|
||||
import replicate
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import BlockSecret, SchemaField, SecretField
|
||||
|
||||
|
||||
# Model name enum
|
||||
class ReplicateFluxModelName(str, Enum):
|
||||
FLUX_SCHNELL = ("Flux Schnell",)
|
||||
FLUX_PRO = ("Flux Pro",)
|
||||
FLUX_PRO1_1 = ("Flux Pro 1.1",)
|
||||
|
||||
@property
|
||||
def api_name(self):
|
||||
api_names = {
|
||||
ReplicateFluxModelName.FLUX_SCHNELL: "black-forest-labs/flux-schnell",
|
||||
ReplicateFluxModelName.FLUX_PRO: "black-forest-labs/flux-pro",
|
||||
ReplicateFluxModelName.FLUX_PRO1_1: "black-forest-labs/flux-1.1-pro",
|
||||
}
|
||||
return api_names[self]
|
||||
|
||||
|
||||
# Image type Enum
|
||||
class ImageType(str, Enum):
|
||||
WEBP = "webp"
|
||||
JPG = "jpg"
|
||||
PNG = "png"
|
||||
|
||||
|
||||
class ReplicateFluxAdvancedModelBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
api_key: BlockSecret = SecretField(
|
||||
key="replicate_api_key",
|
||||
description="Replicate API Key",
|
||||
)
|
||||
prompt: str = SchemaField(
|
||||
description="Text prompt for image generation",
|
||||
placeholder="e.g., 'A futuristic cityscape at sunset'",
|
||||
title="Prompt",
|
||||
)
|
||||
replicate_model_name: ReplicateFluxModelName = SchemaField(
|
||||
description="The name of the Image Generation Model, i.e Flux Schnell",
|
||||
default=ReplicateFluxModelName.FLUX_SCHNELL,
|
||||
title="Image Generation Model",
|
||||
advanced=False,
|
||||
)
|
||||
seed: int | None = SchemaField(
|
||||
description="Random seed. Set for reproducible generation",
|
||||
default=None,
|
||||
title="Seed",
|
||||
)
|
||||
steps: int = SchemaField(
|
||||
description="Number of diffusion steps",
|
||||
default=25,
|
||||
title="Steps",
|
||||
)
|
||||
guidance: float = SchemaField(
|
||||
description=(
|
||||
"Controls the balance between adherence to the text prompt and image quality/diversity. "
|
||||
"Higher values make the output more closely match the prompt but may reduce overall image quality."
|
||||
),
|
||||
default=3,
|
||||
title="Guidance",
|
||||
)
|
||||
interval: float = SchemaField(
|
||||
description=(
|
||||
"Interval is a setting that increases the variance in possible outputs. "
|
||||
"Setting this value low will ensure strong prompt following with more consistent outputs."
|
||||
),
|
||||
default=2,
|
||||
title="Interval",
|
||||
)
|
||||
aspect_ratio: str = SchemaField(
|
||||
description="Aspect ratio for the generated image",
|
||||
default="1:1",
|
||||
title="Aspect Ratio",
|
||||
placeholder="Choose from: 1:1, 16:9, 2:3, 3:2, 4:5, 5:4, 9:16",
|
||||
)
|
||||
output_format: ImageType = SchemaField(
|
||||
description="File format of the output image",
|
||||
default=ImageType.WEBP,
|
||||
title="Output Format",
|
||||
)
|
||||
output_quality: int = SchemaField(
|
||||
description=(
|
||||
"Quality when saving the output images, from 0 to 100. "
|
||||
"Not relevant for .png outputs"
|
||||
),
|
||||
default=80,
|
||||
title="Output Quality",
|
||||
)
|
||||
safety_tolerance: int = SchemaField(
|
||||
description="Safety tolerance, 1 is most strict and 5 is most permissive",
|
||||
default=2,
|
||||
title="Safety Tolerance",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: str = SchemaField(description="Generated output")
|
||||
error: str = SchemaField(description="Error message if the model run failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="90f8c45e-e983-4644-aa0b-b4ebe2f531bc",
|
||||
description="This block runs Flux models on Replicate with advanced settings.",
|
||||
categories={BlockCategory.AI},
|
||||
input_schema=ReplicateFluxAdvancedModelBlock.Input,
|
||||
output_schema=ReplicateFluxAdvancedModelBlock.Output,
|
||||
test_input={
|
||||
"api_key": "test_api_key",
|
||||
"replicate_model_name": ReplicateFluxModelName.FLUX_SCHNELL,
|
||||
"prompt": "A beautiful landscape painting of a serene lake at sunrise",
|
||||
"seed": None,
|
||||
"steps": 25,
|
||||
"guidance": 3.0,
|
||||
"interval": 2.0,
|
||||
"aspect_ratio": "1:1",
|
||||
"output_format": ImageType.PNG,
|
||||
"output_quality": 80,
|
||||
"safety_tolerance": 2,
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"result",
|
||||
"https://replicate.com/output/generated-image-url.jpg",
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"run_model": lambda api_key, model_name, prompt, seed, steps, guidance, interval, aspect_ratio, output_format, output_quality, safety_tolerance: "https://replicate.com/output/generated-image-url.jpg",
|
||||
},
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
# If the seed is not provided, generate a random seed
|
||||
seed = input_data.seed
|
||||
if seed is None:
|
||||
seed = int.from_bytes(os.urandom(4), "big")
|
||||
|
||||
# Run the model using the provided inputs
|
||||
result = self.run_model(
|
||||
api_key=input_data.api_key.get_secret_value(),
|
||||
model_name=input_data.replicate_model_name.api_name,
|
||||
prompt=input_data.prompt,
|
||||
seed=seed,
|
||||
steps=input_data.steps,
|
||||
guidance=input_data.guidance,
|
||||
interval=input_data.interval,
|
||||
aspect_ratio=input_data.aspect_ratio,
|
||||
output_format=input_data.output_format,
|
||||
output_quality=input_data.output_quality,
|
||||
safety_tolerance=input_data.safety_tolerance,
|
||||
)
|
||||
yield "result", result
|
||||
|
||||
def run_model(
|
||||
self,
|
||||
api_key,
|
||||
model_name,
|
||||
prompt,
|
||||
seed,
|
||||
steps,
|
||||
guidance,
|
||||
interval,
|
||||
aspect_ratio,
|
||||
output_format,
|
||||
output_quality,
|
||||
safety_tolerance,
|
||||
):
|
||||
# Initialize Replicate client with the API key
|
||||
client = replicate.Client(api_token=api_key)
|
||||
|
||||
# Run the model with additional parameters
|
||||
output = client.run(
|
||||
f"{model_name}",
|
||||
input={
|
||||
"prompt": prompt,
|
||||
"seed": seed,
|
||||
"steps": steps,
|
||||
"guidance": guidance,
|
||||
"interval": interval,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
"output_format": output_format,
|
||||
"output_quality": output_quality,
|
||||
"safety_tolerance": safety_tolerance,
|
||||
},
|
||||
)
|
||||
|
||||
# Check if output is a list or a string and extract accordingly; otherwise, assign a default message
|
||||
if isinstance(output, list) and len(output) > 0:
|
||||
result_url = output[0] # If output is a list, get the first element
|
||||
elif isinstance(output, str):
|
||||
result_url = output # If output is a string, use it directly
|
||||
else:
|
||||
result_url = (
|
||||
"No output received" # Fallback message if output is not as expected
|
||||
)
|
||||
|
||||
return result_url
|
||||
@@ -43,7 +43,7 @@ class ReadRSSFeedBlock(Block):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c6731acb-4105-4zp1-bc9b-03d0036h370g",
|
||||
id="5ebe6768-8e5d-41e3-9134-1c7bd89a8d52",
|
||||
input_schema=ReadRSSFeedBlock.Input,
|
||||
output_schema=ReadRSSFeedBlock.Output,
|
||||
description="Reads RSS feed entries from a given URL.",
|
||||
|
||||
@@ -4,7 +4,7 @@ from urllib.parse import quote
|
||||
import requests
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import BlockSecret, SecretField
|
||||
from backend.data.model import BlockSecret, SchemaField, SecretField
|
||||
|
||||
|
||||
class GetRequest:
|
||||
@@ -17,15 +17,17 @@ class GetRequest:
|
||||
|
||||
class GetWikipediaSummaryBlock(Block, GetRequest):
|
||||
class Input(BlockSchema):
|
||||
topic: str
|
||||
topic: str = SchemaField(description="The topic to fetch the summary for")
|
||||
|
||||
class Output(BlockSchema):
|
||||
summary: str
|
||||
error: str
|
||||
summary: str = SchemaField(description="The summary of the given topic")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the summary cannot be retrieved"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="h5e7f8g9-1b2c-3d4e-5f6g-7h8i9j0k1l2m",
|
||||
id="f5b0f5d0-1862-4d61-94be-3ad0fa772760",
|
||||
description="This block fetches the summary of a given topic from Wikipedia.",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=GetWikipediaSummaryBlock.Input,
|
||||
@@ -36,33 +38,27 @@ class GetWikipediaSummaryBlock(Block, GetRequest):
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
topic = input_data.topic
|
||||
url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{topic}"
|
||||
response = self.get_request(url, json=True)
|
||||
yield "summary", response["extract"]
|
||||
|
||||
except requests.exceptions.HTTPError as http_err:
|
||||
yield "error", f"HTTP error occurred: {http_err}"
|
||||
|
||||
except requests.RequestException as e:
|
||||
yield "error", f"Request to Wikipedia failed: {e}"
|
||||
|
||||
except KeyError as e:
|
||||
yield "error", f"Error parsing Wikipedia response: {e}"
|
||||
topic = input_data.topic
|
||||
url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{topic}"
|
||||
response = self.get_request(url, json=True)
|
||||
if "extract" not in response:
|
||||
raise RuntimeError(f"Unable to parse Wikipedia response: {response}")
|
||||
yield "summary", response["extract"]
|
||||
|
||||
|
||||
class SearchTheWebBlock(Block, GetRequest):
|
||||
class Input(BlockSchema):
|
||||
query: str # The search query
|
||||
query: str = SchemaField(description="The search query to search the web for")
|
||||
|
||||
class Output(BlockSchema):
|
||||
results: str # The search results including content from top 5 URLs
|
||||
error: str # Error message if the search fails
|
||||
results: str = SchemaField(
|
||||
description="The search results including content from top 5 URLs"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the search fails")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b2c3d4e5-6f7g-8h9i-0j1k-l2m3n4o5p6q7",
|
||||
id="87840993-2053-44b7-8da4-187ad4ee518c",
|
||||
description="This block searches the internet for the given search query.",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=SearchTheWebBlock.Input,
|
||||
@@ -73,37 +69,38 @@ class SearchTheWebBlock(Block, GetRequest):
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
# Encode the search query
|
||||
encoded_query = quote(input_data.query)
|
||||
# Encode the search query
|
||||
encoded_query = quote(input_data.query)
|
||||
|
||||
# Prepend the Jina Search URL to the encoded query
|
||||
jina_search_url = f"https://s.jina.ai/{encoded_query}"
|
||||
# Prepend the Jina Search URL to the encoded query
|
||||
jina_search_url = f"https://s.jina.ai/{encoded_query}"
|
||||
|
||||
# Make the request to Jina Search
|
||||
response = self.get_request(jina_search_url, json=False)
|
||||
# Make the request to Jina Search
|
||||
response = self.get_request(jina_search_url, json=False)
|
||||
|
||||
# Output the search results
|
||||
yield "results", response
|
||||
|
||||
except requests.exceptions.HTTPError as http_err:
|
||||
yield "error", f"HTTP error occurred: {http_err}"
|
||||
|
||||
except requests.RequestException as e:
|
||||
yield "error", f"Request to Jina Search failed: {e}"
|
||||
# Output the search results
|
||||
yield "results", response
|
||||
|
||||
|
||||
class ExtractWebsiteContentBlock(Block, GetRequest):
|
||||
class Input(BlockSchema):
|
||||
url: str # The URL to scrape
|
||||
url: str = SchemaField(description="The URL to scrape the content from")
|
||||
raw_content: bool = SchemaField(
|
||||
default=False,
|
||||
title="Raw Content",
|
||||
description="Whether to do a raw scrape of the content or use Jina-ai Reader to scrape the content",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
content: str # The scraped content from the URL
|
||||
error: str
|
||||
content: str = SchemaField(description="The scraped content from the given URL")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the content cannot be retrieved"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a1b2c3d4-5e6f-7g8h-9i0j-k1l2m3n4o5p6", # Unique ID for the block
|
||||
id="436c3984-57fd-4b85-8e9a-459b356883bd",
|
||||
description="This block scrapes the content from the given web URL.",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ExtractWebsiteContentBlock.Input,
|
||||
@@ -114,34 +111,37 @@ class ExtractWebsiteContentBlock(Block, GetRequest):
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
# Prepend the Jina-ai Reader URL to the input URL
|
||||
jina_url = f"https://r.jina.ai/{input_data.url}"
|
||||
if input_data.raw_content:
|
||||
url = input_data.url
|
||||
else:
|
||||
url = f"https://r.jina.ai/{input_data.url}"
|
||||
|
||||
# Make the request to Jina-ai Reader
|
||||
response = self.get_request(jina_url, json=False)
|
||||
|
||||
# Output the scraped content
|
||||
yield "content", response
|
||||
|
||||
except requests.exceptions.HTTPError as http_err:
|
||||
yield "error", f"HTTP error occurred: {http_err}"
|
||||
|
||||
except requests.RequestException as e:
|
||||
yield "error", f"Request to Jina-ai Reader failed: {e}"
|
||||
content = self.get_request(url, json=False)
|
||||
yield "content", content
|
||||
|
||||
|
||||
class GetWeatherInformationBlock(Block, GetRequest):
|
||||
class Input(BlockSchema):
|
||||
location: str
|
||||
location: str = SchemaField(
|
||||
description="Location to get weather information for"
|
||||
)
|
||||
api_key: BlockSecret = SecretField(key="openweathermap_api_key")
|
||||
use_celsius: bool = True
|
||||
use_celsius: bool = SchemaField(
|
||||
default=True,
|
||||
description="Whether to use Celsius or Fahrenheit for temperature",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
temperature: str
|
||||
humidity: str
|
||||
condition: str
|
||||
error: str
|
||||
temperature: str = SchemaField(
|
||||
description="Temperature in the specified location"
|
||||
)
|
||||
humidity: str = SchemaField(description="Humidity in the specified location")
|
||||
condition: str = SchemaField(
|
||||
description="Weather condition in the specified location"
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the weather information cannot be retrieved"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -168,26 +168,15 @@ class GetWeatherInformationBlock(Block, GetRequest):
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
units = "metric" if input_data.use_celsius else "imperial"
|
||||
api_key = input_data.api_key.get_secret_value()
|
||||
location = input_data.location
|
||||
url = f"http://api.openweathermap.org/data/2.5/weather?q={quote(location)}&appid={api_key}&units={units}"
|
||||
weather_data = self.get_request(url, json=True)
|
||||
units = "metric" if input_data.use_celsius else "imperial"
|
||||
api_key = input_data.api_key.get_secret_value()
|
||||
location = input_data.location
|
||||
url = f"http://api.openweathermap.org/data/2.5/weather?q={quote(location)}&appid={api_key}&units={units}"
|
||||
weather_data = self.get_request(url, json=True)
|
||||
|
||||
if "main" in weather_data and "weather" in weather_data:
|
||||
yield "temperature", str(weather_data["main"]["temp"])
|
||||
yield "humidity", str(weather_data["main"]["humidity"])
|
||||
yield "condition", weather_data["weather"][0]["description"]
|
||||
else:
|
||||
yield "error", f"Expected keys not found in response: {weather_data}"
|
||||
|
||||
except requests.exceptions.HTTPError as http_err:
|
||||
if http_err.response.status_code == 403:
|
||||
yield "error", "Request to weather API failed: 403 Forbidden. Check your API key and permissions."
|
||||
else:
|
||||
yield "error", f"HTTP error occurred: {http_err}"
|
||||
except requests.RequestException as e:
|
||||
yield "error", f"Request to weather API failed: {e}"
|
||||
except KeyError as e:
|
||||
yield "error", f"Error processing weather data: {e}"
|
||||
if "main" in weather_data and "weather" in weather_data:
|
||||
yield "temperature", str(weather_data["main"]["temp"])
|
||||
yield "humidity", str(weather_data["main"]["humidity"])
|
||||
yield "condition", weather_data["weather"][0]["description"]
|
||||
else:
|
||||
raise RuntimeError(f"Expected keys not found in response: {weather_data}")
|
||||
|
||||
@@ -13,7 +13,8 @@ class CreateTalkingAvatarVideoBlock(Block):
|
||||
key="did_api_key", description="D-ID API Key"
|
||||
)
|
||||
script_input: str = SchemaField(
|
||||
description="The text input for the script", default="Welcome to AutoGPT"
|
||||
description="The text input for the script",
|
||||
placeholder="Welcome to AutoGPT",
|
||||
)
|
||||
provider: Literal["microsoft", "elevenlabs", "amazon"] = SchemaField(
|
||||
description="The voice provider to use", default="microsoft"
|
||||
@@ -106,41 +107,40 @@ class CreateTalkingAvatarVideoBlock(Block):
|
||||
return response.json()
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
# Create the clip
|
||||
payload = {
|
||||
"script": {
|
||||
"type": "text",
|
||||
"subtitles": str(input_data.subtitles).lower(),
|
||||
"provider": {
|
||||
"type": input_data.provider,
|
||||
"voice_id": input_data.voice_id,
|
||||
},
|
||||
"ssml": str(input_data.ssml).lower(),
|
||||
"input": input_data.script_input,
|
||||
# Create the clip
|
||||
payload = {
|
||||
"script": {
|
||||
"type": "text",
|
||||
"subtitles": str(input_data.subtitles).lower(),
|
||||
"provider": {
|
||||
"type": input_data.provider,
|
||||
"voice_id": input_data.voice_id,
|
||||
},
|
||||
"config": {"result_format": input_data.result_format},
|
||||
"presenter_config": {"crop": {"type": input_data.crop_type}},
|
||||
"presenter_id": input_data.presenter_id,
|
||||
"driver_id": input_data.driver_id,
|
||||
}
|
||||
"ssml": str(input_data.ssml).lower(),
|
||||
"input": input_data.script_input,
|
||||
},
|
||||
"config": {"result_format": input_data.result_format},
|
||||
"presenter_config": {"crop": {"type": input_data.crop_type}},
|
||||
"presenter_id": input_data.presenter_id,
|
||||
"driver_id": input_data.driver_id,
|
||||
}
|
||||
|
||||
response = self.create_clip(input_data.api_key.get_secret_value(), payload)
|
||||
clip_id = response["id"]
|
||||
response = self.create_clip(input_data.api_key.get_secret_value(), payload)
|
||||
clip_id = response["id"]
|
||||
|
||||
# Poll for clip status
|
||||
for _ in range(input_data.max_polling_attempts):
|
||||
status_response = self.get_clip_status(
|
||||
input_data.api_key.get_secret_value(), clip_id
|
||||
# Poll for clip status
|
||||
for _ in range(input_data.max_polling_attempts):
|
||||
status_response = self.get_clip_status(
|
||||
input_data.api_key.get_secret_value(), clip_id
|
||||
)
|
||||
if status_response["status"] == "done":
|
||||
yield "video_url", status_response["result_url"]
|
||||
return
|
||||
elif status_response["status"] == "error":
|
||||
raise RuntimeError(
|
||||
f"Clip creation failed: {status_response.get('error', 'Unknown error')}"
|
||||
)
|
||||
if status_response["status"] == "done":
|
||||
yield "video_url", status_response["result_url"]
|
||||
return
|
||||
elif status_response["status"] == "error":
|
||||
yield "error", f"Clip creation failed: {status_response.get('error', 'Unknown error')}"
|
||||
return
|
||||
time.sleep(input_data.polling_interval)
|
||||
|
||||
yield "error", "Clip creation timed out"
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
time.sleep(input_data.polling_interval)
|
||||
|
||||
raise TimeoutError("Clip creation timed out")
|
||||
|
||||
@@ -2,9 +2,9 @@ import re
|
||||
from typing import Any
|
||||
|
||||
from jinja2 import BaseLoader, Environment
|
||||
from pydantic import Field
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util import json
|
||||
|
||||
jinja = Environment(loader=BaseLoader())
|
||||
@@ -12,15 +12,17 @@ jinja = Environment(loader=BaseLoader())
|
||||
|
||||
class MatchTextPatternBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
text: Any = Field(description="Text to match")
|
||||
match: str = Field(description="Pattern (Regex) to match")
|
||||
data: Any = Field(description="Data to be forwarded to output")
|
||||
case_sensitive: bool = Field(description="Case sensitive match", default=True)
|
||||
dot_all: bool = Field(description="Dot matches all", default=True)
|
||||
text: Any = SchemaField(description="Text to match")
|
||||
match: str = SchemaField(description="Pattern (Regex) to match")
|
||||
data: Any = SchemaField(description="Data to be forwarded to output")
|
||||
case_sensitive: bool = SchemaField(
|
||||
description="Case sensitive match", default=True
|
||||
)
|
||||
dot_all: bool = SchemaField(description="Dot matches all", default=True)
|
||||
|
||||
class Output(BlockSchema):
|
||||
positive: Any = Field(description="Output data if match is found")
|
||||
negative: Any = Field(description="Output data if match is not found")
|
||||
positive: Any = SchemaField(description="Output data if match is found")
|
||||
negative: Any = SchemaField(description="Output data if match is not found")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -64,15 +66,17 @@ class MatchTextPatternBlock(Block):
|
||||
|
||||
class ExtractTextInformationBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
text: Any = Field(description="Text to parse")
|
||||
pattern: str = Field(description="Pattern (Regex) to parse")
|
||||
group: int = Field(description="Group number to extract", default=0)
|
||||
case_sensitive: bool = Field(description="Case sensitive match", default=True)
|
||||
dot_all: bool = Field(description="Dot matches all", default=True)
|
||||
text: Any = SchemaField(description="Text to parse")
|
||||
pattern: str = SchemaField(description="Pattern (Regex) to parse")
|
||||
group: int = SchemaField(description="Group number to extract", default=0)
|
||||
case_sensitive: bool = SchemaField(
|
||||
description="Case sensitive match", default=True
|
||||
)
|
||||
dot_all: bool = SchemaField(description="Dot matches all", default=True)
|
||||
|
||||
class Output(BlockSchema):
|
||||
positive: str = Field(description="Extracted text")
|
||||
negative: str = Field(description="Original text")
|
||||
positive: str = SchemaField(description="Extracted text")
|
||||
negative: str = SchemaField(description="Original text")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -116,11 +120,15 @@ class ExtractTextInformationBlock(Block):
|
||||
|
||||
class FillTextTemplateBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
values: dict[str, Any] = Field(description="Values (dict) to be used in format")
|
||||
format: str = Field(description="Template to format the text using `values`")
|
||||
values: dict[str, Any] = SchemaField(
|
||||
description="Values (dict) to be used in format"
|
||||
)
|
||||
format: str = SchemaField(
|
||||
description="Template to format the text using `values`"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
output: str
|
||||
output: str = SchemaField(description="Formatted text")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -155,11 +163,13 @@ class FillTextTemplateBlock(Block):
|
||||
|
||||
class CombineTextsBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
input: list[str] = Field(description="text input to combine")
|
||||
delimiter: str = Field(description="Delimiter to combine texts", default="")
|
||||
input: list[str] = SchemaField(description="text input to combine")
|
||||
delimiter: str = SchemaField(
|
||||
description="Delimiter to combine texts", default=""
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
output: str = Field(description="Combined text")
|
||||
output: str = SchemaField(description="Combined text")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import BlockSecret, SchemaField, SecretField
|
||||
|
||||
|
||||
class UnrealTextToSpeechBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
text: str = SchemaField(
|
||||
description="The text to be converted to speech",
|
||||
placeholder="Enter the text you want to convert to speech",
|
||||
)
|
||||
voice_id: str = SchemaField(
|
||||
description="The voice ID to use for text-to-speech conversion",
|
||||
placeholder="Scarlett",
|
||||
default="Scarlett",
|
||||
)
|
||||
api_key: BlockSecret = SecretField(
|
||||
key="unreal_speech_api_key", description="Your Unreal Speech API key"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
mp3_url: str = SchemaField(description="The URL of the generated MP3 file")
|
||||
error: str = SchemaField(description="Error message if the API call failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="4ff1ff6d-cc40-4caa-ae69-011daa20c378",
|
||||
description="Converts text to speech using the Unreal Speech API",
|
||||
categories={BlockCategory.AI, BlockCategory.TEXT},
|
||||
input_schema=UnrealTextToSpeechBlock.Input,
|
||||
output_schema=UnrealTextToSpeechBlock.Output,
|
||||
test_input={
|
||||
"text": "This is a test of the text to speech API.",
|
||||
"voice_id": "Scarlett",
|
||||
"api_key": "test_api_key",
|
||||
},
|
||||
test_output=[("mp3_url", "https://example.com/test.mp3")],
|
||||
test_mock={
|
||||
"call_unreal_speech_api": lambda *args, **kwargs: {
|
||||
"OutputUri": "https://example.com/test.mp3"
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def call_unreal_speech_api(
|
||||
api_key: str, text: str, voice_id: str
|
||||
) -> dict[str, Any]:
|
||||
url = "https://api.v7.unrealspeech.com/speech"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
data = {
|
||||
"Text": text,
|
||||
"VoiceId": voice_id,
|
||||
"Bitrate": "192k",
|
||||
"Speed": "0",
|
||||
"Pitch": "1",
|
||||
"TimestampType": "sentence",
|
||||
}
|
||||
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
api_response = self.call_unreal_speech_api(
|
||||
input_data.api_key.get_secret_value(),
|
||||
input_data.text,
|
||||
input_data.voice_id,
|
||||
)
|
||||
yield "mp3_url", api_response["OutputUri"]
|
||||
@@ -3,14 +3,22 @@ from datetime import datetime, timedelta
|
||||
from typing import Any, Union
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class GetCurrentTimeBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
trigger: str
|
||||
trigger: str = SchemaField(
|
||||
description="Trigger any data to output the current time"
|
||||
)
|
||||
format: str = SchemaField(
|
||||
description="Format of the time to output", default="%H:%M:%S"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
time: str
|
||||
time: str = SchemaField(
|
||||
description="Current time in the specified format (default: %H:%M:%S)"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -20,25 +28,38 @@ class GetCurrentTimeBlock(Block):
|
||||
input_schema=GetCurrentTimeBlock.Input,
|
||||
output_schema=GetCurrentTimeBlock.Output,
|
||||
test_input=[
|
||||
{"trigger": "Hello", "format": "{time}"},
|
||||
{"trigger": "Hello"},
|
||||
{"trigger": "Hello", "format": "%H:%M"},
|
||||
],
|
||||
test_output=[
|
||||
("time", lambda _: time.strftime("%H:%M:%S")),
|
||||
("time", lambda _: time.strftime("%H:%M")),
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
current_time = time.strftime("%H:%M:%S")
|
||||
current_time = time.strftime(input_data.format)
|
||||
yield "time", current_time
|
||||
|
||||
|
||||
class GetCurrentDateBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
trigger: str
|
||||
offset: Union[int, str]
|
||||
trigger: str = SchemaField(
|
||||
description="Trigger any data to output the current date"
|
||||
)
|
||||
offset: Union[int, str] = SchemaField(
|
||||
title="Days Offset",
|
||||
description="Offset in days from the current date",
|
||||
default=0,
|
||||
)
|
||||
format: str = SchemaField(
|
||||
description="Format of the date to output", default="%Y-%m-%d"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
date: str
|
||||
date: str = SchemaField(
|
||||
description="Current date in the specified format (default: YYYY-MM-DD)"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -48,7 +69,8 @@ class GetCurrentDateBlock(Block):
|
||||
input_schema=GetCurrentDateBlock.Input,
|
||||
output_schema=GetCurrentDateBlock.Output,
|
||||
test_input=[
|
||||
{"trigger": "Hello", "format": "{date}", "offset": "7"},
|
||||
{"trigger": "Hello", "offset": "7"},
|
||||
{"trigger": "Hello", "offset": "7", "format": "%m/%d/%Y"},
|
||||
],
|
||||
test_output=[
|
||||
(
|
||||
@@ -56,6 +78,12 @@ class GetCurrentDateBlock(Block):
|
||||
lambda t: abs(datetime.now() - datetime.strptime(t, "%Y-%m-%d"))
|
||||
< timedelta(days=8), # 7 days difference + 1 day error margin.
|
||||
),
|
||||
(
|
||||
"date",
|
||||
lambda t: abs(datetime.now() - datetime.strptime(t, "%m/%d/%Y"))
|
||||
< timedelta(days=8),
|
||||
# 7 days difference + 1 day error margin.
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -65,25 +93,33 @@ class GetCurrentDateBlock(Block):
|
||||
except ValueError:
|
||||
offset = 0
|
||||
current_date = datetime.now() - timedelta(days=offset)
|
||||
yield "date", current_date.strftime("%Y-%m-%d")
|
||||
yield "date", current_date.strftime(input_data.format)
|
||||
|
||||
|
||||
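A rough illustration of the new `offset`/`format` handling above (assumed from the hunk; the helper below is hypothetical, not part of the block):

```python
from datetime import datetime, timedelta

def demo(offset_raw: str, fmt: str) -> str:
    try:
        offset = int(offset_raw)
    except ValueError:
        offset = 0  # same fallback the block uses for non-numeric offsets
    return (datetime.now() - timedelta(days=offset)).strftime(fmt)

print(demo("7", "%m/%d/%Y"))     # the date a week ago, US-style
print(demo("oops", "%Y-%m-%d"))  # non-numeric offset falls back to today
```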
class GetCurrentDateAndTimeBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
trigger: str
|
||||
trigger: str = SchemaField(
|
||||
description="Trigger any data to output the current date and time"
|
||||
)
|
||||
format: str = SchemaField(
|
||||
description="Format of the date and time to output",
|
||||
default="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
date_time: str
|
||||
date_time: str = SchemaField(
|
||||
description="Current date and time in the specified format (default: YYYY-MM-DD HH:MM:SS)"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0h2",
|
||||
id="716a67b3-6760-42e7-86dc-18645c6e00fc",
|
||||
description="This block outputs the current date and time.",
|
||||
categories={BlockCategory.TEXT},
|
||||
input_schema=GetCurrentDateAndTimeBlock.Input,
|
||||
output_schema=GetCurrentDateAndTimeBlock.Output,
|
||||
test_input=[
|
||||
{"trigger": "Hello", "format": "{date_time}"},
|
||||
{"trigger": "Hello"},
|
||||
],
|
||||
test_output=[
|
||||
(
|
||||
@@ -97,20 +133,29 @@ class GetCurrentDateAndTimeBlock(Block):
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
current_date_time = time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
current_date_time = time.strftime(input_data.format)
|
||||
yield "date_time", current_date_time
|
||||
|
||||
|
||||
class CountdownTimerBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
input_message: Any = "timer finished"
|
||||
seconds: Union[int, str] = 0
|
||||
minutes: Union[int, str] = 0
|
||||
hours: Union[int, str] = 0
|
||||
days: Union[int, str] = 0
|
||||
input_message: Any = SchemaField(
|
||||
description="Message to output after the timer finishes",
|
||||
default="timer finished",
|
||||
)
|
||||
seconds: Union[int, str] = SchemaField(
|
||||
description="Duration in seconds", default=0
|
||||
)
|
||||
minutes: Union[int, str] = SchemaField(
|
||||
description="Duration in minutes", default=0
|
||||
)
|
||||
hours: Union[int, str] = SchemaField(description="Duration in hours", default=0)
|
||||
days: Union[int, str] = SchemaField(description="Duration in days", default=0)
|
||||
|
||||
class Output(BlockSchema):
|
||||
output_message: str
|
||||
output_message: str = SchemaField(
|
||||
description="Message after the timer finishes"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
|
||||
@@ -7,9 +7,10 @@ from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class TranscribeYouTubeVideoBlock(Block):
|
||||
class TranscribeYoutubeVideoBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
youtube_url: str = SchemaField(
|
||||
title="YouTube URL",
|
||||
description="The URL of the YouTube video to transcribe",
|
||||
placeholder="https://www.youtube.com/watch?v=dQw4w9WgXcQ",
|
||||
)
|
||||
@@ -24,8 +25,8 @@ class TranscribeYouTubeVideoBlock(Block):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f3a8f7e1-4b1d-4e5f-9f2a-7c3d5a2e6b4c",
|
||||
input_schema=TranscribeYouTubeVideoBlock.Input,
|
||||
output_schema=TranscribeYouTubeVideoBlock.Output,
|
||||
input_schema=TranscribeYoutubeVideoBlock.Input,
|
||||
output_schema=TranscribeYoutubeVideoBlock.Output,
|
||||
description="Transcribes a YouTube video.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
test_input={"youtube_url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ"},
|
||||
@@ -64,14 +65,11 @@ class TranscribeYouTubeVideoBlock(Block):
|
||||
return YouTubeTranscriptApi.get_transcript(video_id)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
video_id = self.extract_video_id(input_data.youtube_url)
|
||||
yield "video_id", video_id
|
||||
video_id = self.extract_video_id(input_data.youtube_url)
|
||||
yield "video_id", video_id
|
||||
|
||||
transcript = self.get_transcript(video_id)
|
||||
formatter = TextFormatter()
|
||||
transcript_text = formatter.format_transcript(transcript)
|
||||
transcript = self.get_transcript(video_id)
|
||||
formatter = TextFormatter()
|
||||
transcript_text = formatter.format_transcript(transcript)
|
||||
|
||||
yield "transcript", transcript_text
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
yield "transcript", transcript_text
|
||||
|
||||
@@ -217,13 +217,13 @@ def websocket(server_address: str, graph_id: str):
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
import websockets
|
||||
import websockets.asyncio.client
|
||||
|
||||
from backend.server.ws_api import ExecutionSubscription, Methods, WsMessage
|
||||
|
||||
async def send_message(server_address: str):
|
||||
uri = f"ws://{server_address}"
|
||||
async with websockets.connect(uri) as websocket:
|
||||
async with websockets.asyncio.client.connect(uri) as websocket:
|
||||
try:
|
||||
msg = WsMessage(
|
||||
method=Methods.SUBSCRIBE,
|
||||
|
||||
@@ -45,7 +45,9 @@ class BlockCategory(Enum):
|
||||
INPUT = "Block that interacts with input of the graph."
|
||||
OUTPUT = "Block that interacts with output of the graph."
|
||||
LOGIC = "Programming logic to control the flow of your agent"
|
||||
COMMUNICATION = "Block that interacts with communication platforms."
|
||||
DEVELOPER_TOOLS = "Developer tools such as GitHub blocks."
|
||||
DATA = "Block that interacts with structured data."
|
||||
|
||||
def dict(self) -> dict[str, str]:
|
||||
return {"category": self.name, "description": self.value}
|
||||
@@ -228,6 +230,11 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
self.disabled = disabled
|
||||
self.static_output = static_output
|
||||
self.block_type = block_type
|
||||
self.execution_stats = {}
|
||||
|
||||
@classmethod
|
||||
def create(cls: Type["Block"]) -> "Block":
|
||||
return cls()
|
||||
|
||||
@abstractmethod
|
||||
def run(self, input_data: BlockSchemaInputType, **kwargs) -> BlockOutput:
|
||||
@@ -242,6 +249,26 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
"""
|
||||
pass
|
||||
|
||||
def run_once(self, input_data: BlockSchemaInputType, output: str, **kwargs) -> Any:
|
||||
for name, data in self.run(input_data, **kwargs):
|
||||
if name == output:
|
||||
return data
|
||||
raise ValueError(f"{self.name} did not produce any output for {output}")
|
||||
|
||||
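A standalone sketch of what `run_once` does (stand-in functions, not the real block classes): it drains the `(output_name, data)` generator and returns the first value emitted under the requested name, raising `ValueError` if that name never appears.

```python
from typing import Any, Generator

def run(_input: dict) -> Generator[tuple[str, Any], None, None]:
    # Stand-in for Block.run: yields (output_name, data) pairs.
    yield "prompt_tokens", 42
    yield "response", "['a', 'b', 'c']"

def run_once(input_data: dict, output: str) -> Any:
    for name, data in run(input_data):
        if name == output:
            return data
    raise ValueError(f"run did not produce any output for {output}")

print(run_once({}, "response"))  # -> "['a', 'b', 'c']"
```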
def merge_stats(self, stats: dict[str, Any]) -> dict[str, Any]:
|
||||
for key, value in stats.items():
|
||||
if isinstance(value, dict):
|
||||
self.execution_stats.setdefault(key, {}).update(value)
|
||||
elif isinstance(value, (int, float)):
|
||||
self.execution_stats.setdefault(key, 0)
|
||||
self.execution_stats[key] += value
|
||||
elif isinstance(value, list):
|
||||
self.execution_stats.setdefault(key, [])
|
||||
self.execution_stats[key].extend(value)
|
||||
else:
|
||||
self.execution_stats[key] = value
|
||||
return self.execution_stats
|
||||
|
||||
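The accumulation rules in `merge_stats` (numbers add, nested dicts merge, lists extend, anything else overwrites) roughly behave as below; the stats payloads are hypothetical:

```python
stats: dict = {}

def merge(update: dict) -> None:
    # Same logic as Block.merge_stats, applied to a local dict for illustration.
    for key, value in update.items():
        if isinstance(value, dict):
            stats.setdefault(key, {}).update(value)
        elif isinstance(value, (int, float)):
            stats.setdefault(key, 0)
            stats[key] += value
        elif isinstance(value, list):
            stats.setdefault(key, [])
            stats[key].extend(value)
        else:
            stats[key] = value

merge({"walltime": 1.5, "llm": {"input_tokens": 10}})
merge({"walltime": 0.5, "llm": {"output_tokens": 4}, "errors": ["timeout"]})
print(stats)
# {'walltime': 2.0, 'llm': {'input_tokens': 10, 'output_tokens': 4}, 'errors': ['timeout']}
```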
@property
|
||||
def name(self):
|
||||
return self.__class__.__name__
|
||||
@@ -270,6 +297,8 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
for output_name, output_data in self.run(
|
||||
self.input_schema(**input_data), **kwargs
|
||||
):
|
||||
if output_name == "error":
|
||||
raise RuntimeError(output_data)
|
||||
if error := self.output_schema.validate_field(output_name, output_data):
|
||||
raise ValueError(f"Block produced an invalid output data: {error}")
|
||||
yield output_name, output_data
|
||||
@@ -278,15 +307,18 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
# ======================= Block Helper Functions ======================= #
|
||||
|
||||
|
||||
def get_blocks() -> dict[str, Block]:
|
||||
def get_blocks() -> dict[str, Type[Block]]:
|
||||
from backend.blocks import AVAILABLE_BLOCKS # noqa: E402
|
||||
|
||||
return AVAILABLE_BLOCKS
|
||||
|
||||
|
||||
async def initialize_blocks() -> None:
|
||||
for block in get_blocks().values():
|
||||
existing_block = await AgentBlock.prisma().find_unique(where={"id": block.id})
|
||||
for cls in get_blocks().values():
|
||||
block = cls()
|
||||
existing_block = await AgentBlock.prisma().find_first(
|
||||
where={"OR": [{"id": block.id}, {"name": block.name}]}
|
||||
)
|
||||
if not existing_block:
|
||||
await AgentBlock.prisma().create(
|
||||
data={
|
||||
@@ -301,13 +333,15 @@ async def initialize_blocks() -> None:
|
||||
input_schema = json.dumps(block.input_schema.jsonschema())
|
||||
output_schema = json.dumps(block.output_schema.jsonschema())
|
||||
if (
|
||||
block.name != existing_block.name
|
||||
block.id != existing_block.id
|
||||
or block.name != existing_block.name
|
||||
or input_schema != existing_block.inputSchema
|
||||
or output_schema != existing_block.outputSchema
|
||||
):
|
||||
await AgentBlock.prisma().update(
|
||||
where={"id": block.id},
|
||||
where={"id": existing_block.id},
|
||||
data={
|
||||
"id": block.id,
|
||||
"name": block.name,
|
||||
"inputSchema": input_schema,
|
||||
"outputSchema": output_schema,
|
||||
@@ -316,4 +350,5 @@ async def initialize_blocks() -> None:
|
||||
|
||||
|
||||
def get_block(block_id: str) -> Block | None:
|
||||
return get_blocks().get(block_id)
|
||||
cls = get_blocks().get(block_id)
|
||||
return cls() if cls else None
|
||||
|
||||
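A hedged usage note on the `get_blocks`/`get_block` change: the registry now maps block IDs to classes, and `get_block` returns a fresh instance per call. The ID below is the ExtractWebsiteContentBlock ID from the hunk further up.

```python
from backend.data.block import get_block, get_blocks

block_cls = get_blocks()["436c3984-57fd-4b85-8e9a-459b356883bd"]  # class object
block = get_block("436c3984-57fd-4b85-8e9a-459b356883bd")         # fresh instance, or None if unknown
```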
@@ -17,8 +17,9 @@ from backend.blocks.llm import (
|
||||
AITextSummarizerBlock,
|
||||
LlmModel,
|
||||
)
|
||||
from backend.blocks.search import ExtractWebsiteContentBlock, SearchTheWebBlock
|
||||
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
|
||||
from backend.data.block import Block, BlockInput
|
||||
from backend.data.block import Block, BlockInput, get_block
|
||||
from backend.util.settings import Config
|
||||
|
||||
|
||||
@@ -74,6 +75,10 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
|
||||
CreateTalkingAvatarVideoBlock: [
|
||||
BlockCost(cost_amount=15, cost_filter={"api_key": None})
|
||||
],
|
||||
SearchTheWebBlock: [BlockCost(cost_amount=1)],
|
||||
ExtractWebsiteContentBlock: [
|
||||
BlockCost(cost_amount=1, cost_filter={"raw_content": False})
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@@ -96,7 +101,7 @@ class UserCreditBase(ABC):
|
||||
self,
|
||||
user_id: str,
|
||||
user_credit: int,
|
||||
block: Block,
|
||||
block_id: str,
|
||||
input_data: BlockInput,
|
||||
data_size: float,
|
||||
run_time: float,
|
||||
@@ -107,7 +112,7 @@ class UserCreditBase(ABC):
|
||||
Args:
|
||||
user_id (str): The user ID.
|
||||
user_credit (int): The current credit for the user.
|
||||
block (Block): The block that is being used.
|
||||
block_id (str): The block ID.
|
||||
input_data (BlockInput): The input data for the block.
|
||||
data_size (float): The size of the data being processed.
|
||||
run_time (float): The time taken to run the block.
|
||||
@@ -208,12 +213,16 @@ class UserCredit(UserCreditBase):
|
||||
self,
|
||||
user_id: str,
|
||||
user_credit: int,
|
||||
block: Block,
|
||||
block_id: str,
|
||||
input_data: BlockInput,
|
||||
data_size: float,
|
||||
run_time: float,
|
||||
validate_balance: bool = True,
|
||||
) -> int:
|
||||
block = get_block(block_id)
|
||||
if not block:
|
||||
raise ValueError(f"Block not found: {block_id}")
|
||||
|
||||
cost, matching_filter = self._block_usage_cost(
|
||||
block=block, input_data=input_data, data_size=data_size, run_time=run_time
|
||||
)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
@@ -8,40 +7,30 @@ from dotenv import load_dotenv
|
||||
from prisma import Prisma
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from backend.util.retry import conn_retry
|
||||
|
||||
load_dotenv()
|
||||
|
||||
PRISMA_SCHEMA = os.getenv("PRISMA_SCHEMA", "schema.prisma")
|
||||
os.environ["PRISMA_SCHEMA_PATH"] = PRISMA_SCHEMA
|
||||
|
||||
prisma, conn_id = Prisma(auto_register=True), ""
|
||||
prisma = Prisma(auto_register=True)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def connect(call_count=0):
|
||||
global conn_id
|
||||
if not conn_id:
|
||||
conn_id = str(uuid4())
|
||||
|
||||
try:
|
||||
logger.info(f"[Prisma-{conn_id}] Acquiring connection..")
|
||||
if not prisma.is_connected():
|
||||
await prisma.connect()
|
||||
logger.info(f"[Prisma-{conn_id}] Connection acquired!")
|
||||
except Exception as e:
|
||||
if call_count <= 5:
|
||||
logger.info(f"[Prisma-{conn_id}] Connection failed: {e}. Retrying now..")
|
||||
await asyncio.sleep(2**call_count)
|
||||
await connect(call_count + 1)
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
||||
async def disconnect():
|
||||
@conn_retry("Prisma", "Acquiring connection")
|
||||
async def connect():
|
||||
if prisma.is_connected():
|
||||
logger.info(f"[Prisma-{conn_id}] Releasing connection.")
|
||||
await prisma.disconnect()
|
||||
logger.info(f"[Prisma-{conn_id}] Connection released.")
|
||||
return
|
||||
await prisma.connect()
|
||||
|
||||
|
||||
@conn_retry("Prisma", "Releasing connection")
|
||||
async def disconnect():
|
||||
if not prisma.is_connected():
|
||||
return
|
||||
await prisma.disconnect()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
|
||||
@@ -3,7 +3,6 @@ from datetime import datetime, timezone
|
||||
from multiprocessing import Manager
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from autogpt_libs.supabase_integration_credentials_store.types import Credentials
|
||||
from prisma.enums import AgentExecutionStatus
|
||||
from prisma.models import (
|
||||
AgentGraphExecution,
|
||||
@@ -26,7 +25,6 @@ class GraphExecution(BaseModel):
|
||||
graph_exec_id: str
|
||||
graph_id: str
|
||||
start_node_execs: list["NodeExecution"]
|
||||
node_input_credentials: dict[str, Credentials] # dict[node_id, Credentials]
|
||||
|
||||
|
||||
class NodeExecution(BaseModel):
|
||||
@@ -268,10 +266,29 @@ async def update_graph_execution_start_time(graph_exec_id: str):
|
||||
)
|
||||
|
||||
|
||||
async def update_graph_execution_stats(graph_exec_id: str, stats: dict[str, Any]):
|
||||
async def update_graph_execution_stats(
|
||||
graph_exec_id: str,
|
||||
error: Exception | None,
|
||||
wall_time: float,
|
||||
cpu_time: float,
|
||||
node_count: int,
|
||||
):
|
||||
status = ExecutionStatus.FAILED if error else ExecutionStatus.COMPLETED
|
||||
stats = (
|
||||
{
|
||||
"walltime": wall_time,
|
||||
"cputime": cpu_time,
|
||||
"nodecount": node_count,
|
||||
"error": str(error) if error else None,
|
||||
},
|
||||
)
|
||||
|
||||
await AgentGraphExecution.prisma().update(
|
||||
where={"id": graph_exec_id},
|
||||
data={"executionStatus": ExecutionStatus.COMPLETED, "stats": json.dumps(stats)},
|
||||
data={
|
||||
"executionStatus": status,
|
||||
"stats": json.dumps(stats),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -2,20 +2,18 @@ import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
import prisma.types
|
||||
from prisma.models import AgentGraph, AgentGraphExecution, AgentNode, AgentNodeLink
|
||||
from prisma.types import AgentGraphInclude
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
from pydantic import BaseModel
|
||||
from pydantic_core import PydanticUndefinedType
|
||||
|
||||
from backend.blocks.basic import AgentInputBlock, AgentOutputBlock
|
||||
from backend.data.block import BlockInput, get_block, get_blocks
|
||||
from backend.data.db import BaseDbModel, transaction
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.util import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -53,17 +51,8 @@ class Node(BaseDbModel):
|
||||
block_id: str
|
||||
input_default: BlockInput = {} # dict[input_name, default_value]
|
||||
metadata: dict[str, Any] = {}
|
||||
|
||||
_input_links: list[Link] = PrivateAttr(default=[])
|
||||
_output_links: list[Link] = PrivateAttr(default=[])
|
||||
|
||||
@property
|
||||
def input_links(self) -> list[Link]:
|
||||
return self._input_links
|
||||
|
||||
@property
|
||||
def output_links(self) -> list[Link]:
|
||||
return self._output_links
|
||||
input_links: list[Link] = []
|
||||
output_links: list[Link] = []
|
||||
|
||||
@staticmethod
|
||||
def from_db(node: AgentNode):
|
||||
@@ -75,8 +64,8 @@ class Node(BaseDbModel):
|
||||
input_default=json.loads(node.constantInput),
|
||||
metadata=json.loads(node.metadata),
|
||||
)
|
||||
obj._input_links = [Link.from_db(link) for link in node.Input or []]
|
||||
obj._output_links = [Link.from_db(link) for link in node.Output or []]
|
||||
obj.input_links = [Link.from_db(link) for link in node.Input or []]
|
||||
obj.output_links = [Link.from_db(link) for link in node.Output or []]
|
||||
return obj
|
||||
|
||||
|
||||
@@ -268,7 +257,7 @@ class Graph(GraphMeta):
|
||||
|
||||
block = get_block(node.block_id)
|
||||
if not block:
|
||||
blocks = {v.id: v.name for v in get_blocks().values()}
|
||||
blocks = {v().id: v().name for v in get_blocks().values()}
|
||||
raise ValueError(
|
||||
f"{suffix}, {node.block_id} is invalid block id, available blocks: {blocks}"
|
||||
)
|
||||
@@ -330,7 +319,7 @@ class Graph(GraphMeta):
|
||||
return input_schema
|
||||
|
||||
@staticmethod
|
||||
def from_db(graph: AgentGraph):
|
||||
def from_db(graph: AgentGraph, hide_credentials: bool = False):
|
||||
nodes = [
|
||||
*(graph.AgentNodes or []),
|
||||
*(
|
||||
@@ -341,7 +330,7 @@ class Graph(GraphMeta):
|
||||
]
|
||||
return Graph(
|
||||
**GraphMeta.from_db(graph).model_dump(),
|
||||
nodes=[Node.from_db(node) for node in nodes],
|
||||
nodes=[Graph._process_node(node, hide_credentials) for node in nodes],
|
||||
links=list(
|
||||
{
|
||||
Link.from_db(link)
|
||||
@@ -355,6 +344,31 @@ class Graph(GraphMeta):
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _process_node(node: AgentNode, hide_credentials: bool) -> Node:
|
||||
node_dict = node.model_dump()
|
||||
if hide_credentials and "constantInput" in node_dict:
|
||||
constant_input = json.loads(node_dict["constantInput"])
|
||||
constant_input = Graph._hide_credentials_in_input(constant_input)
|
||||
node_dict["constantInput"] = json.dumps(constant_input)
|
||||
return Node.from_db(AgentNode(**node_dict))
|
||||
|
||||
@staticmethod
|
||||
def _hide_credentials_in_input(input_data: dict[str, Any]) -> dict[str, Any]:
|
||||
sensitive_keys = ["credentials", "api_key", "password", "token", "secret"]
|
||||
result = {}
|
||||
for key, value in input_data.items():
|
||||
if isinstance(value, dict):
|
||||
result[key] = Graph._hide_credentials_in_input(value)
|
||||
elif isinstance(value, str) and any(
|
||||
sensitive_key in key.lower() for sensitive_key in sensitive_keys
|
||||
):
|
||||
# Skip this key-value pair in the result
|
||||
continue
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
|
||||
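For clarity, a small illustration of what Graph._hide_credentials_in_input does to a node's stored input; the keys and values here are made up:

# Hypothetical constantInput contents before filtering:
raw_input = {
    "repo": "Significant-Gravitas/AutoGPT",
    "api_key": "sk-live-abc123",  # string value under a sensitive key -> dropped
    "credentials": {"id": "cred-1", "provider": "github"},  # dict -> recursed into and kept
}

filtered = Graph._hide_credentials_in_input(raw_input)
# {'repo': 'Significant-Gravitas/AutoGPT',
#  'credentials': {'id': 'cred-1', 'provider': 'github'}}
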
AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
|
||||
"Input": True,
|
||||
@@ -382,9 +396,9 @@ async def get_node(node_id: str) -> Node:
|
||||
|
||||
|
||||
async def get_graphs_meta(
|
||||
user_id: str,
|
||||
include_executions: bool = False,
|
||||
filter_by: Literal["active", "template"] | None = "active",
|
||||
user_id: str | None = None,
|
||||
) -> list[GraphMeta]:
|
||||
"""
|
||||
Retrieves graph metadata objects.
|
||||
@@ -393,6 +407,7 @@ async def get_graphs_meta(
|
||||
Args:
|
||||
include_executions: Whether to include executions in the graph metadata.
|
||||
filter_by: An optional filter to either select templates or active graphs.
|
||||
user_id: The ID of the user that owns the graph.
|
||||
|
||||
Returns:
|
||||
list[GraphMeta]: A list of objects representing the retrieved graph metadata.
|
||||
@@ -404,8 +419,7 @@ async def get_graphs_meta(
|
||||
elif filter_by == "template":
|
||||
where_clause["isTemplate"] = True
|
||||
|
||||
if user_id and filter_by != "template":
|
||||
where_clause["userId"] = user_id
|
||||
where_clause["userId"] = user_id
|
||||
|
||||
graphs = await AgentGraph.prisma().find_many(
|
||||
where=where_clause,
|
||||
@@ -431,6 +445,7 @@ async def get_graph(
|
||||
version: int | None = None,
|
||||
template: bool = False,
|
||||
user_id: str | None = None,
|
||||
hide_credentials: bool = False,
|
||||
) -> Graph | None:
|
||||
"""
|
||||
Retrieves a graph from the DB.
|
||||
@@ -456,7 +471,7 @@ async def get_graph(
|
||||
include=AGENT_GRAPH_INCLUDE,
|
||||
order={"version": "desc"},
|
||||
)
|
||||
return Graph.from_db(graph) if graph else None
|
||||
return Graph.from_db(graph, hide_credentials) if graph else None
|
||||
|
||||
|
||||
async def set_graph_active_version(graph_id: str, version: int, user_id: str) -> None:
|
||||
@@ -500,6 +515,15 @@ async def get_graph_all_versions(graph_id: str, user_id: str) -> list[Graph]:
|
||||
return [Graph.from_db(graph) for graph in graph_versions]
|
||||
|
||||
|
||||
async def delete_graph(graph_id: str, user_id: str) -> int:
|
||||
entries_count = await AgentGraph.prisma().delete_many(
|
||||
where={"id": graph_id, "userId": user_id}
|
||||
)
|
||||
if entries_count:
|
||||
logger.info(f"Deleted {entries_count} graph entries for Graph #{graph_id}")
|
||||
return entries_count
|
||||
|
||||
|
||||
async def create_graph(graph: Graph, user_id: str) -> Graph:
|
||||
async with transaction() as tx:
|
||||
await __create_graph(tx, graph, user_id)
|
||||
@@ -576,30 +600,3 @@ async def __create_graph(tx, graph: Graph, user_id: str):
|
||||
for link in graph.links
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# --------------------- Helper functions --------------------- #
|
||||
|
||||
|
||||
TEMPLATES_DIR = Path(__file__).parent.parent.parent / "graph_templates"
|
||||
|
||||
|
||||
async def import_packaged_templates() -> None:
|
||||
templates_in_db = await get_graphs_meta(filter_by="template")
|
||||
|
||||
logging.info("Loading templates...")
|
||||
for template_file in TEMPLATES_DIR.glob("*.json"):
|
||||
template_data = json.loads(template_file.read_bytes())
|
||||
|
||||
template = Graph.model_validate(template_data)
|
||||
if not template.is_template:
|
||||
logging.warning(
|
||||
f"pre-packaged graph file {template_file} is not a template"
|
||||
)
|
||||
continue
|
||||
if (
|
||||
exists := next((t for t in templates_in_db if t.id == template.id), None)
|
||||
) and exists.version >= template.version:
|
||||
continue
|
||||
await create_graph(template, DEFAULT_USER_ID)
|
||||
logging.info(f"Loaded template '{template.name}' ({template.id})")
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Any, AsyncGenerator, Generator, Generic, TypeVar
|
||||
|
||||
from redis.asyncio import Redis
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio.client import PubSub as AsyncPubSub
|
||||
from redis.client import PubSub
|
||||
|
||||
from backend.data import redis
|
||||
from backend.data.execution import ExecutionResult
|
||||
from backend.util.settings import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = Config()
|
||||
|
||||
|
||||
class DateTimeEncoder(json.JSONEncoder):
|
||||
@@ -18,60 +23,122 @@ class DateTimeEncoder(json.JSONEncoder):
|
||||
return super().default(o)
|
||||
|
||||
|
||||
class AsyncEventQueue(ABC):
|
||||
M = TypeVar("M", bound=BaseModel)
|
||||
|
||||
|
||||
class BaseRedisEventBus(Generic[M], ABC):
|
||||
Model: type[M]
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
async def connect(self):
|
||||
def event_bus_name(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def close(self):
|
||||
pass
|
||||
def _serialize_message(self, item: M, channel_key: str) -> tuple[str, str]:
|
||||
message = json.dumps(item.model_dump(), cls=DateTimeEncoder)
|
||||
channel_name = f"{self.event_bus_name}-{channel_key}"
|
||||
logger.info(f"[{channel_name}] Publishing an event to Redis {message}")
|
||||
return message, channel_name
|
||||
|
||||
@abstractmethod
|
||||
async def put(self, execution_result: ExecutionResult):
|
||||
pass
|
||||
def _deserialize_message(self, msg: Any, channel_key: str) -> M | None:
|
||||
message_type = "pmessage" if "*" in channel_key else "message"
|
||||
if msg["type"] != message_type:
|
||||
return None
|
||||
try:
|
||||
data = json.loads(msg["data"])
|
||||
logger.info(f"Consuming an event from Redis {data}")
|
||||
return self.Model(**data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse event result from Redis {msg} {e}")
|
||||
|
||||
@abstractmethod
|
||||
async def get(self) -> ExecutionResult | None:
|
||||
pass
|
||||
def _subscribe(
|
||||
self, connection: redis.Redis | redis.AsyncRedis, channel_key: str
|
||||
) -> tuple[PubSub | AsyncPubSub, str]:
|
||||
channel_name = f"{self.event_bus_name}-{channel_key}"
|
||||
pubsub = connection.pubsub()
|
||||
return pubsub, channel_name
|
||||
|
||||
|
||||
class AsyncRedisEventQueue(AsyncEventQueue):
|
||||
def __init__(self):
|
||||
self.host = os.getenv("REDIS_HOST", "localhost")
|
||||
self.port = int(os.getenv("REDIS_PORT", "6379"))
|
||||
self.password = os.getenv("REDIS_PASSWORD", "password")
|
||||
self.queue_name = os.getenv("REDIS_QUEUE", "execution_events")
|
||||
self.connection = None
|
||||
class RedisEventBus(BaseRedisEventBus[M], ABC):
|
||||
Model: type[M]
|
||||
|
||||
async def connect(self):
|
||||
if not self.connection:
|
||||
self.connection = Redis(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
password=self.password,
|
||||
decode_responses=True,
|
||||
)
|
||||
await self.connection.ping()
|
||||
logger.info(f"Connected to Redis on {self.host}:{self.port}")
|
||||
@property
|
||||
def connection(self) -> redis.Redis:
|
||||
return redis.get_redis()
|
||||
|
||||
async def put(self, execution_result: ExecutionResult):
|
||||
if self.connection:
|
||||
message = json.dumps(execution_result.model_dump(), cls=DateTimeEncoder)
|
||||
logger.info(f"Putting execution result to Redis {message}")
|
||||
await self.connection.lpush(self.queue_name, message) # type: ignore
|
||||
def publish_event(self, event: M, channel_key: str):
|
||||
message, channel_name = self._serialize_message(event, channel_key)
|
||||
self.connection.publish(channel_name, message)
|
||||
|
||||
async def get(self) -> ExecutionResult | None:
|
||||
if self.connection:
|
||||
message = await self.connection.rpop(self.queue_name) # type: ignore
|
||||
if message is not None and isinstance(message, (str, bytes, bytearray)):
|
||||
data = json.loads(message)
|
||||
logger.info(f"Getting execution result from Redis {data}")
|
||||
return ExecutionResult(**data)
|
||||
return None
|
||||
def listen_events(self, channel_key: str) -> Generator[M, None, None]:
|
||||
pubsub, channel_name = self._subscribe(self.connection, channel_key)
|
||||
assert isinstance(pubsub, PubSub)
|
||||
|
||||
async def close(self):
|
||||
if self.connection:
|
||||
await self.connection.close()
|
||||
self.connection = None
|
||||
logger.info("Closed connection to Redis")
|
||||
if "*" in channel_key:
|
||||
pubsub.psubscribe(channel_name)
|
||||
else:
|
||||
pubsub.subscribe(channel_name)
|
||||
|
||||
for message in pubsub.listen():
|
||||
if event := self._deserialize_message(message, channel_key):
|
||||
yield event
|
||||
|
||||
|
||||
class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
|
||||
Model: type[M]
|
||||
|
||||
@property
|
||||
async def connection(self) -> redis.AsyncRedis:
|
||||
return await redis.get_redis_async()
|
||||
|
||||
async def publish_event(self, event: M, channel_key: str):
|
||||
message, channel_name = self._serialize_message(event, channel_key)
|
||||
connection = await self.connection
|
||||
await connection.publish(channel_name, message)
|
||||
|
||||
async def listen_events(self, channel_key: str) -> AsyncGenerator[M, None]:
|
||||
pubsub, channel_name = self._subscribe(await self.connection, channel_key)
|
||||
assert isinstance(pubsub, AsyncPubSub)
|
||||
|
||||
if "*" in channel_key:
|
||||
await pubsub.psubscribe(channel_name)
|
||||
else:
|
||||
await pubsub.subscribe(channel_name)
|
||||
|
||||
async for message in pubsub.listen():
|
||||
if event := self._deserialize_message(message, channel_key):
|
||||
yield event
|
||||
|
||||
|
||||
class RedisExecutionEventBus(RedisEventBus[ExecutionResult]):
|
||||
Model = ExecutionResult
|
||||
|
||||
@property
|
||||
def event_bus_name(self) -> str:
|
||||
return config.execution_event_bus_name
|
||||
|
||||
def publish(self, res: ExecutionResult):
|
||||
self.publish_event(res, f"{res.graph_id}-{res.graph_exec_id}")
|
||||
|
||||
def listen(
|
||||
self, graph_id: str = "*", graph_exec_id: str = "*"
|
||||
) -> Generator[ExecutionResult, None, None]:
|
||||
for execution_result in self.listen_events(f"{graph_id}-{graph_exec_id}"):
|
||||
yield execution_result
|
||||
|
||||
|
||||
class AsyncRedisExecutionEventBus(AsyncRedisEventBus[ExecutionResult]):
|
||||
Model = ExecutionResult
|
||||
|
||||
@property
|
||||
def event_bus_name(self) -> str:
|
||||
return config.execution_event_bus_name
|
||||
|
||||
async def publish(self, res: ExecutionResult):
|
||||
await self.publish_event(res, f"{res.graph_id}-{res.graph_exec_id}")
|
||||
|
||||
async def listen(
|
||||
self, graph_id: str = "*", graph_exec_id: str = "*"
|
||||
) -> AsyncGenerator[ExecutionResult, None]:
|
||||
async for execution_result in self.listen_events(f"{graph_id}-{graph_exec_id}"):
|
||||
yield execution_result
|
||||
|
||||
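The refactor above replaces the list-based AsyncRedisEventQueue (LPUSH/RPOP) with pub/sub event buses keyed by "{graph_id}-{graph_exec_id}". A usage sketch, assuming the shared Redis connection from backend.data.redis is already established; the ExecutionResult field names are the ones used elsewhere in this diff:

from backend.data.queue import AsyncRedisExecutionEventBus, RedisExecutionEventBus


def send_update(execution_result):
    # Synchronous side (e.g. DatabaseManager.send_execution_update) publishes onto
    # the channel "<event_bus_name>-<graph_id>-<graph_exec_id>".
    RedisExecutionEventBus().publish(execution_result)


async def forward_updates():
    # Asynchronous consumers (e.g. the websocket API) subscribe, optionally with wildcards.
    bus = AsyncRedisExecutionEventBus()
    async for result in bus.listen(graph_id="*", graph_exec_id="*"):
        print(result.node_exec_id, result.status)
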
84 autogpt_platform/backend/backend/data/redis.py Normal file
@@ -0,0 +1,84 @@
import logging
import os

from dotenv import load_dotenv
from redis import Redis
from redis.asyncio import Redis as AsyncRedis

from backend.util.retry import conn_retry

load_dotenv()

HOST = os.getenv("REDIS_HOST", "localhost")
PORT = int(os.getenv("REDIS_PORT", "6379"))
PASSWORD = os.getenv("REDIS_PASSWORD", "password")

logger = logging.getLogger(__name__)
connection: Redis | None = None
connection_async: AsyncRedis | None = None


@conn_retry("Redis", "Acquiring connection")
def connect() -> Redis:
global connection
if connection:
return connection

c = Redis(
host=HOST,
port=PORT,
password=PASSWORD,
decode_responses=True,
)
c.ping()
connection = c
return connection


@conn_retry("Redis", "Releasing connection")
def disconnect():
global connection
if connection:
connection.close()
connection = None


def get_redis(auto_connect: bool = True) -> Redis:
if connection:
return connection
if auto_connect:
return connect()
raise RuntimeError("Redis connection is not established")


@conn_retry("AsyncRedis", "Acquiring connection")
async def connect_async() -> AsyncRedis:
global connection_async
if connection_async:
return connection_async

c = AsyncRedis(
host=HOST,
port=PORT,
password=PASSWORD,
decode_responses=True,
)
await c.ping()
connection_async = c
return connection_async


@conn_retry("AsyncRedis", "Releasing connection")
async def disconnect_async():
global connection_async
if connection_async:
await connection_async.close()
connection_async = None


async def get_redis_async(auto_connect: bool = True) -> AsyncRedis:
if connection_async:
return connection_async
if auto_connect:
return await connect_async()
raise RuntimeError("AsyncRedis connection is not established")
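Going by the executor changes later in this diff, the new module is used as a process-wide singleton: connect on process start, fetch the shared connection wherever it is needed, disconnect on shutdown. A short sketch (not part of the commit):

from backend.data import redis

redis.connect()            # on process start, e.g. Executor.on_node_executor_start
conn = redis.get_redis()   # shared connection; auto-connects if needed
conn.ping()
redis.disconnect()         # on shutdown, e.g. Executor.on_node_executor_stop
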
@@ -1,6 +1,8 @@
|
||||
from typing import Optional
|
||||
|
||||
from autogpt_libs.supabase_integration_credentials_store.types import UserMetadataRaw
|
||||
from fastapi import HTTPException
|
||||
from prisma import Json
|
||||
from prisma.models import User
|
||||
|
||||
from backend.data.db import prisma
|
||||
@@ -35,16 +37,32 @@ async def get_user_by_id(user_id: str) -> Optional[User]:
|
||||
return User.model_validate(user) if user else None
|
||||
|
||||
|
||||
async def create_default_user(enable_auth: str) -> Optional[User]:
|
||||
if not enable_auth.lower() == "true":
|
||||
user = await prisma.user.find_unique(where={"id": DEFAULT_USER_ID})
|
||||
if not user:
|
||||
user = await prisma.user.create(
|
||||
data={
|
||||
"id": DEFAULT_USER_ID,
|
||||
"email": "default@example.com",
|
||||
"name": "Default User",
|
||||
}
|
||||
)
|
||||
return User.model_validate(user)
|
||||
return None
|
||||
async def create_default_user() -> Optional[User]:
|
||||
user = await prisma.user.find_unique(where={"id": DEFAULT_USER_ID})
|
||||
if not user:
|
||||
user = await prisma.user.create(
|
||||
data={
|
||||
"id": DEFAULT_USER_ID,
|
||||
"email": "default@example.com",
|
||||
"name": "Default User",
|
||||
}
|
||||
)
|
||||
return User.model_validate(user)
|
||||
|
||||
|
||||
async def get_user_metadata(user_id: str) -> UserMetadataRaw:
|
||||
user = await User.prisma().find_unique_or_raise(
|
||||
where={"id": user_id},
|
||||
)
|
||||
return (
|
||||
UserMetadataRaw.model_validate(user.metadata)
|
||||
if user.metadata
|
||||
else UserMetadataRaw()
|
||||
)
|
||||
|
||||
|
||||
async def update_user_metadata(user_id: str, metadata: UserMetadataRaw):
|
||||
await User.prisma().update(
|
||||
where={"id": user_id},
|
||||
data={"metadata": Json(metadata.model_dump())},
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
from backend.app import run_processes
from backend.executor import ExecutionManager
from backend.executor import DatabaseManager, ExecutionManager


def main():
@@ -7,6 +7,7 @@ def main():
Run all the processes required for the AutoGPT-server REST API.
"""
run_processes(
DatabaseManager(),
ExecutionManager(),
)


@@ -1,7 +1,9 @@
from .database import DatabaseManager
from .manager import ExecutionManager
from .scheduler import ExecutionScheduler

__all__ = [
"DatabaseManager",
"ExecutionManager",
"ExecutionScheduler",
]

84 autogpt_platform/backend/backend/executor/database.py Normal file
@@ -0,0 +1,84 @@
from functools import wraps
from typing import Any, Callable, Concatenate, Coroutine, ParamSpec, TypeVar, cast

from backend.data.credit import get_user_credit_model
from backend.data.execution import (
ExecutionResult,
create_graph_execution,
get_execution_results,
get_incomplete_executions,
get_latest_execution,
update_execution_status,
update_graph_execution_stats,
update_node_execution_stats,
upsert_execution_input,
upsert_execution_output,
)
from backend.data.graph import get_graph, get_node
from backend.data.queue import RedisExecutionEventBus
from backend.data.user import get_user_metadata, update_user_metadata
from backend.util.service import AppService, expose
from backend.util.settings import Config

P = ParamSpec("P")
R = TypeVar("R")


class DatabaseManager(AppService):

def __init__(self):
super().__init__()
self.use_db = True
self.use_redis = True
self.event_queue = RedisExecutionEventBus()

@classmethod
def get_port(cls) -> int:
return Config().database_api_port

@expose
def send_execution_update(self, execution_result_dict: dict[Any, Any]):
self.event_queue.publish(ExecutionResult(**execution_result_dict))

@staticmethod
def exposed_run_and_wait(
f: Callable[P, Coroutine[None, None, R]]
) -> Callable[Concatenate[object, P], R]:
@expose
@wraps(f)
def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
coroutine = f(*args, **kwargs)
res = self.run_and_wait(coroutine)
return res

return wrapper

# Executions
create_graph_execution = exposed_run_and_wait(create_graph_execution)
get_execution_results = exposed_run_and_wait(get_execution_results)
get_incomplete_executions = exposed_run_and_wait(get_incomplete_executions)
get_latest_execution = exposed_run_and_wait(get_latest_execution)
update_execution_status = exposed_run_and_wait(update_execution_status)
update_graph_execution_stats = exposed_run_and_wait(update_graph_execution_stats)
update_node_execution_stats = exposed_run_and_wait(update_node_execution_stats)
upsert_execution_input = exposed_run_and_wait(upsert_execution_input)
upsert_execution_output = exposed_run_and_wait(upsert_execution_output)

# Graphs
get_node = exposed_run_and_wait(get_node)
get_graph = exposed_run_and_wait(get_graph)

# Credits
user_credit_model = get_user_credit_model()
get_or_refill_credit = cast(
Callable[[Any, str], int],
exposed_run_and_wait(user_credit_model.get_or_refill_credit),
)
spend_credits = cast(
Callable[[Any, str, int, str, dict[str, str], float, float], int],
exposed_run_and_wait(user_credit_model.spend_credits),
)

# User + User Metadata
get_user_metadata = exposed_run_and_wait(get_user_metadata)
update_user_metadata = exposed_run_and_wait(update_user_metadata)
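The effect of exposed_run_and_wait is that each async data-layer function becomes a synchronous RPC method on the DatabaseManager service. A sketch of the calling side, matching the get_db_client() helper added to executor/manager.py later in this diff (the IDs are placeholders):

from backend.executor import DatabaseManager
from backend.util.service import get_service_client

db_client = get_service_client(DatabaseManager)

# Each call is a blocking RPC into the DatabaseManager process, which runs the
# underlying coroutine via run_and_wait():
node = db_client.get_node("some-node-id")
credit = db_client.get_or_refill_credit("some-user-id")
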
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import atexit
|
||||
import logging
|
||||
import multiprocessing
|
||||
@@ -9,45 +8,40 @@ import threading
|
||||
from concurrent.futures import Future, ProcessPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from multiprocessing.pool import AsyncResult, Pool
|
||||
from typing import TYPE_CHECKING, Any, Coroutine, Generator, TypeVar, cast
|
||||
from typing import TYPE_CHECKING, Any, Generator, TypeVar, cast
|
||||
|
||||
from autogpt_libs.supabase_integration_credentials_store.types import Credentials
|
||||
from pydantic import BaseModel
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.server.rest_api import AgentServer
|
||||
from backend.executor import DatabaseManager
|
||||
|
||||
from backend.data import db
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
|
||||
from backend.data import redis
|
||||
from backend.data.block import Block, BlockData, BlockInput, BlockType, get_block
|
||||
from backend.data.credit import get_user_credit_model
|
||||
from backend.data.execution import (
|
||||
ExecutionQueue,
|
||||
ExecutionResult,
|
||||
ExecutionStatus,
|
||||
GraphExecution,
|
||||
NodeExecution,
|
||||
create_graph_execution,
|
||||
get_execution_results,
|
||||
get_incomplete_executions,
|
||||
get_latest_execution,
|
||||
merge_execution_input,
|
||||
parse_execution_output,
|
||||
update_execution_status,
|
||||
update_graph_execution_stats,
|
||||
update_node_execution_stats,
|
||||
upsert_execution_input,
|
||||
upsert_execution_output,
|
||||
)
|
||||
from backend.data.graph import Graph, Link, Node, get_graph, get_node
|
||||
from backend.data.graph import Graph, Link, Node
|
||||
from backend.data.model import CREDENTIALS_FIELD_NAME, CredentialsMetaInput
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util import json
|
||||
from backend.util.decorator import error_logged, time_measured
|
||||
from backend.util.logging import configure_logging
|
||||
from backend.util.process import set_service_name
|
||||
from backend.util.service import AppService, expose, get_service_client
|
||||
from backend.util.settings import Config
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.type import convert
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
|
||||
class LogMetadata:
|
||||
@@ -100,10 +94,9 @@ ExecutionStream = Generator[NodeExecution, None, None]
|
||||
|
||||
|
||||
def execute_node(
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
api_client: "AgentServer",
|
||||
db_client: "DatabaseManager",
|
||||
creds_manager: IntegrationCredentialsManager,
|
||||
data: NodeExecution,
|
||||
input_credentials: Credentials | None = None,
|
||||
execution_stats: dict[str, Any] | None = None,
|
||||
) -> ExecutionStream:
|
||||
"""
|
||||
@@ -111,8 +104,8 @@ def execute_node(
|
||||
persist the execution result, and return the subsequent node to be executed.
|
||||
|
||||
Args:
|
||||
loop: The event loop to run the async functions.
|
||||
api_client: The client to send execution updates to the server.
|
||||
db_client: The client to send execution updates to the server.
|
||||
creds_manager: The manager to acquire and release credentials.
|
||||
data: The execution data for executing the current node.
|
||||
execution_stats: The execution statistics to be updated.
|
||||
|
||||
@@ -125,17 +118,12 @@ def execute_node(
|
||||
node_exec_id = data.node_exec_id
|
||||
node_id = data.node_id
|
||||
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
def wait(f: Coroutine[Any, Any, T]) -> T:
|
||||
return loop.run_until_complete(f)
|
||||
|
||||
def update_execution(status: ExecutionStatus) -> ExecutionResult:
|
||||
exec_update = wait(update_execution_status(node_exec_id, status))
|
||||
api_client.send_execution_update(exec_update.model_dump())
|
||||
exec_update = db_client.update_execution_status(node_exec_id, status)
|
||||
db_client.send_execution_update(exec_update.model_dump())
|
||||
return exec_update
|
||||
|
||||
node = wait(get_node(node_id))
|
||||
node = db_client.get_node(node_id)
|
||||
|
||||
node_block = get_block(node.block_id)
|
||||
if not node_block:
|
||||
@@ -161,28 +149,34 @@ def execute_node(
|
||||
input_size = len(input_data_str)
|
||||
log_metadata.info("Executed node with input", input=input_data_str)
|
||||
update_execution(ExecutionStatus.RUNNING)
|
||||
user_credit = get_user_credit_model()
|
||||
|
||||
extra_exec_kwargs = {}
|
||||
if input_credentials:
|
||||
extra_exec_kwargs["credentials"] = input_credentials
|
||||
# Last-minute fetch credentials + acquire a system-wide read-write lock to prevent
|
||||
# changes during execution. ⚠️ This means a set of credentials can only be used by
|
||||
# one (running) block at a time; simultaneous execution of blocks using the same
|
||||
# credentials is not supported.
|
||||
creds_lock = None
|
||||
if CREDENTIALS_FIELD_NAME in input_data:
|
||||
credentials_meta = CredentialsMetaInput(**input_data[CREDENTIALS_FIELD_NAME])
|
||||
credentials, creds_lock = creds_manager.acquire(user_id, credentials_meta.id)
|
||||
extra_exec_kwargs["credentials"] = credentials
|
||||
|
||||
output_size = 0
|
||||
try:
|
||||
credit = wait(user_credit.get_or_refill_credit(user_id))
|
||||
if credit < 0:
|
||||
raise ValueError(f"Insufficient credit: {credit}")
|
||||
end_status = ExecutionStatus.COMPLETED
|
||||
credit = db_client.get_or_refill_credit(user_id)
|
||||
if credit < 0:
|
||||
raise ValueError(f"Insufficient credit: {credit}")
|
||||
|
||||
try:
|
||||
for output_name, output_data in node_block.execute(
|
||||
input_data, **extra_exec_kwargs
|
||||
):
|
||||
output_size += len(json.dumps(output_data))
|
||||
log_metadata.info("Node produced output", output_name=output_data)
|
||||
wait(upsert_execution_output(node_exec_id, output_name, output_data))
|
||||
db_client.upsert_execution_output(node_exec_id, output_name, output_data)
|
||||
|
||||
for execution in _enqueue_next_nodes(
|
||||
api_client=api_client,
|
||||
loop=loop,
|
||||
db_client=db_client,
|
||||
node=node,
|
||||
output=(output_name, output_data),
|
||||
user_id=user_id,
|
||||
@@ -192,41 +186,52 @@ def execute_node(
|
||||
):
|
||||
yield execution
|
||||
|
||||
r = update_execution(ExecutionStatus.COMPLETED)
|
||||
s = input_size + output_size
|
||||
t = (
|
||||
(r.end_time - r.start_time).total_seconds()
|
||||
if r.end_time and r.start_time
|
||||
else 0
|
||||
)
|
||||
wait(user_credit.spend_credits(user_id, credit, node_block, input_data, s, t))
|
||||
|
||||
except Exception as e:
|
||||
end_status = ExecutionStatus.FAILED
|
||||
error_msg = str(e)
|
||||
log_metadata.exception(f"Node execution failed with error {error_msg}")
|
||||
wait(upsert_execution_output(node_exec_id, "error", error_msg))
|
||||
update_execution(ExecutionStatus.FAILED)
|
||||
db_client.upsert_execution_output(node_exec_id, "error", error_msg)
|
||||
|
||||
for execution in _enqueue_next_nodes(
|
||||
db_client=db_client,
|
||||
node=node,
|
||||
output=("error", error_msg),
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
log_metadata=log_metadata,
|
||||
):
|
||||
yield execution
|
||||
|
||||
raise e
|
||||
|
||||
finally:
|
||||
# Ensure credentials are released even if execution fails
|
||||
if creds_lock:
|
||||
try:
|
||||
creds_lock.release()
|
||||
except Exception as e:
|
||||
log_metadata.error(f"Failed to release credentials lock: {e}")
|
||||
|
||||
# Update execution status and spend credits
|
||||
res = update_execution(end_status)
|
||||
if end_status == ExecutionStatus.COMPLETED:
|
||||
s = input_size + output_size
|
||||
t = (
|
||||
(res.end_time - res.start_time).total_seconds()
|
||||
if res.end_time and res.start_time
|
||||
else 0
|
||||
)
|
||||
db_client.spend_credits(user_id, credit, node_block.id, input_data, s, t)
|
||||
|
||||
# Update execution stats
|
||||
if execution_stats is not None:
|
||||
execution_stats.update(node_block.execution_stats)
|
||||
execution_stats["input_size"] = input_size
|
||||
execution_stats["output_size"] = output_size
|
||||
|
||||
|
||||
@contextmanager
|
||||
def synchronized(api_client: "AgentServer", key: Any):
|
||||
api_client.acquire_lock(key)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
api_client.release_lock(key)
|
||||
|
||||
|
||||
def _enqueue_next_nodes(
|
||||
api_client: "AgentServer",
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
db_client: "DatabaseManager",
|
||||
node: Node,
|
||||
output: BlockData,
|
||||
user_id: str,
|
||||
@@ -234,16 +239,14 @@ def _enqueue_next_nodes(
|
||||
graph_id: str,
|
||||
log_metadata: LogMetadata,
|
||||
) -> list[NodeExecution]:
|
||||
def wait(f: Coroutine[Any, Any, T]) -> T:
|
||||
return loop.run_until_complete(f)
|
||||
|
||||
def add_enqueued_execution(
|
||||
node_exec_id: str, node_id: str, data: BlockInput
|
||||
) -> NodeExecution:
|
||||
exec_update = wait(
|
||||
update_execution_status(node_exec_id, ExecutionStatus.QUEUED, data)
|
||||
exec_update = db_client.update_execution_status(
|
||||
node_exec_id, ExecutionStatus.QUEUED, data
|
||||
)
|
||||
api_client.send_execution_update(exec_update.model_dump())
|
||||
db_client.send_execution_update(exec_update.model_dump())
|
||||
return NodeExecution(
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
@@ -263,20 +266,18 @@ def _enqueue_next_nodes(
|
||||
if next_data is None:
|
||||
return enqueued_executions
|
||||
|
||||
next_node = wait(get_node(next_node_id))
|
||||
next_node = db_client.get_node(next_node_id)
|
||||
|
||||
# Multiple nodes can register the same next node; this needs to be atomic
# to avoid the same execution being enqueued multiple times,
# or the same input being consumed multiple times.
|
||||
with synchronized(api_client, ("upsert_input", next_node_id, graph_exec_id)):
|
||||
with synchronized(f"upsert_input-{next_node_id}-{graph_exec_id}"):
|
||||
# Add output data to the earliest incomplete execution, or create a new one.
|
||||
next_node_exec_id, next_node_input = wait(
|
||||
upsert_execution_input(
|
||||
node_id=next_node_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
input_name=next_input_name,
|
||||
input_data=next_data,
|
||||
)
|
||||
next_node_exec_id, next_node_input = db_client.upsert_execution_input(
|
||||
node_id=next_node_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
input_name=next_input_name,
|
||||
input_data=next_data,
|
||||
)
|
||||
|
||||
# Complete missing static input pins data using the last execution input.
|
||||
@@ -286,8 +287,8 @@ def _enqueue_next_nodes(
|
||||
if link.is_static and link.sink_name not in next_node_input
|
||||
}
|
||||
if static_link_names and (
|
||||
latest_execution := wait(
|
||||
get_latest_execution(next_node_id, graph_exec_id)
|
||||
latest_execution := db_client.get_latest_execution(
|
||||
next_node_id, graph_exec_id
|
||||
)
|
||||
):
|
||||
for name in static_link_names:
|
||||
@@ -314,7 +315,9 @@ def _enqueue_next_nodes(
|
||||
|
||||
# If link is static, there could be some incomplete executions waiting for it.
|
||||
# Load and complete the missing input data, and try to re-enqueue them.
|
||||
for iexec in wait(get_incomplete_executions(next_node_id, graph_exec_id)):
|
||||
for iexec in db_client.get_incomplete_executions(
|
||||
next_node_id, graph_exec_id
|
||||
):
|
||||
idata = iexec.input_data
|
||||
ineid = iexec.node_exec_id
|
||||
|
||||
@@ -399,12 +402,6 @@ def validate_exec(
|
||||
return data, node_block.name
|
||||
|
||||
|
||||
def get_agent_server_client() -> "AgentServer":
|
||||
from backend.server.rest_api import AgentServer
|
||||
|
||||
return get_service_client(AgentServer, Config().agent_server_port)
|
||||
|
||||
|
||||
class Executor:
|
||||
"""
|
||||
This class contains event handlers for the process pool executor events.
|
||||
@@ -433,12 +430,11 @@ class Executor:
|
||||
@classmethod
|
||||
def on_node_executor_start(cls):
|
||||
configure_logging()
|
||||
|
||||
cls.loop = asyncio.new_event_loop()
|
||||
set_service_name("NodeExecutor")
|
||||
redis.connect()
|
||||
cls.pid = os.getpid()
|
||||
|
||||
cls.loop.run_until_complete(db.connect())
|
||||
cls.agent_server_client = get_agent_server_client()
|
||||
cls.db_client = get_db_client()
|
||||
cls.creds_manager = IntegrationCredentialsManager()
|
||||
|
||||
# Set up shutdown handlers
|
||||
cls.shutdown_lock = threading.Lock()
|
||||
@@ -452,19 +448,23 @@ class Executor:
|
||||
if not cls.shutdown_lock.acquire(blocking=False):
|
||||
return # already shutting down
|
||||
|
||||
logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting DB...")
|
||||
cls.loop.run_until_complete(db.disconnect())
|
||||
logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Releasing locks...")
|
||||
cls.creds_manager.release_all_locks()
|
||||
logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
|
||||
redis.disconnect()
|
||||
logger.info(f"[on_node_executor_stop {cls.pid}] ✅ Finished cleanup")
|
||||
|
||||
@classmethod
|
||||
def on_node_executor_sigterm(cls):
|
||||
llprint(f"[on_node_executor_sigterm {cls.pid}] ⚠️ SIGTERM received")
|
||||
if not cls.shutdown_lock.acquire(blocking=False):
|
||||
return # already shutting down, no need to self-terminate
|
||||
return # already shutting down
|
||||
|
||||
llprint(f"[on_node_executor_sigterm {cls.pid}] ⏳ Disconnecting DB...")
|
||||
cls.loop.run_until_complete(db.disconnect())
|
||||
llprint(f"[on_node_executor_sigterm {cls.pid}] ✅ Finished cleanup")
|
||||
llprint(f"[on_node_executor_stop {cls.pid}] ⏳ Releasing locks...")
|
||||
cls.creds_manager.release_all_locks()
|
||||
llprint(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
|
||||
redis.disconnect()
|
||||
llprint(f"[on_node_executor_stop {cls.pid}] ✅ Finished cleanup")
|
||||
sys.exit(0)
|
||||
|
||||
@classmethod
|
||||
@@ -473,7 +473,6 @@ class Executor:
|
||||
cls,
|
||||
q: ExecutionQueue[NodeExecution],
|
||||
node_exec: NodeExecution,
|
||||
input_credentials: Credentials | None,
|
||||
):
|
||||
log_metadata = LogMetadata(
|
||||
user_id=node_exec.user_id,
|
||||
@@ -486,13 +485,13 @@ class Executor:
|
||||
|
||||
execution_stats = {}
|
||||
timing_info, _ = cls._on_node_execution(
|
||||
q, node_exec, input_credentials, log_metadata, execution_stats
|
||||
q, node_exec, log_metadata, execution_stats
|
||||
)
|
||||
execution_stats["walltime"] = timing_info.wall_time
|
||||
execution_stats["cputime"] = timing_info.cpu_time
|
||||
|
||||
cls.loop.run_until_complete(
|
||||
update_node_execution_stats(node_exec.node_exec_id, execution_stats)
|
||||
cls.db_client.update_node_execution_stats(
|
||||
node_exec.node_exec_id, execution_stats
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -501,14 +500,13 @@ class Executor:
|
||||
cls,
|
||||
q: ExecutionQueue[NodeExecution],
|
||||
node_exec: NodeExecution,
|
||||
input_credentials: Credentials | None,
|
||||
log_metadata: LogMetadata,
|
||||
stats: dict[str, Any] | None = None,
|
||||
):
|
||||
try:
|
||||
log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
|
||||
for execution in execute_node(
|
||||
cls.loop, cls.agent_server_client, node_exec, input_credentials, stats
|
||||
cls.db_client, cls.creds_manager, node_exec, stats
|
||||
):
|
||||
q.add(execution)
|
||||
log_metadata.info(f"Finished node execution {node_exec.node_exec_id}")
|
||||
@@ -520,12 +518,11 @@ class Executor:
|
||||
@classmethod
|
||||
def on_graph_executor_start(cls):
|
||||
configure_logging()
|
||||
set_service_name("GraphExecutor")
|
||||
|
||||
cls.pool_size = Config().num_node_workers
|
||||
cls.loop = asyncio.new_event_loop()
|
||||
cls.db_client = get_db_client()
|
||||
cls.pool_size = settings.config.num_node_workers
|
||||
cls.pid = os.getpid()
|
||||
|
||||
cls.loop.run_until_complete(db.connect())
|
||||
cls._init_node_executor_pool()
|
||||
logger.info(
|
||||
f"Graph executor {cls.pid} started with {cls.pool_size} node workers"
|
||||
@@ -537,8 +534,6 @@ class Executor:
|
||||
@classmethod
|
||||
def on_graph_executor_stop(cls):
|
||||
prefix = f"[on_graph_executor_stop {cls.pid}]"
|
||||
logger.info(f"{prefix} ⏳ Disconnecting DB...")
|
||||
cls.loop.run_until_complete(db.disconnect())
|
||||
logger.info(f"{prefix} ⏳ Terminating node executor pool...")
|
||||
cls.executor.terminate()
|
||||
logger.info(f"{prefix} ✅ Finished cleanup")
|
||||
@@ -561,19 +556,16 @@ class Executor:
|
||||
node_eid="*",
|
||||
block_name="-",
|
||||
)
|
||||
timing_info, node_count = cls._on_graph_execution(
|
||||
timing_info, (node_count, error) = cls._on_graph_execution(
|
||||
graph_exec, cancel, log_metadata
|
||||
)
|
||||
|
||||
cls.loop.run_until_complete(
|
||||
update_graph_execution_stats(
|
||||
graph_exec.graph_exec_id,
|
||||
{
|
||||
"walltime": timing_info.wall_time,
|
||||
"cputime": timing_info.cpu_time,
|
||||
"nodecount": node_count,
|
||||
},
|
||||
)
|
||||
cls.db_client.update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
error=error,
|
||||
wall_time=timing_info.wall_time,
|
||||
cpu_time=timing_info.cpu_time,
|
||||
node_count=node_count,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -583,9 +575,15 @@ class Executor:
|
||||
graph_exec: GraphExecution,
|
||||
cancel: threading.Event,
|
||||
log_metadata: LogMetadata,
|
||||
) -> int:
|
||||
) -> tuple[int, Exception | None]:
|
||||
"""
|
||||
Returns:
|
||||
The number of node executions completed.
|
||||
The error that occurred during the execution.
|
||||
"""
|
||||
log_metadata.info(f"Start graph execution {graph_exec.graph_exec_id}")
|
||||
n_node_executions = 0
|
||||
error = None
|
||||
finished = False
|
||||
|
||||
def cancel_handler():
|
||||
@@ -619,7 +617,8 @@ class Executor:
|
||||
|
||||
while not queue.empty():
|
||||
if cancel.is_set():
|
||||
return n_node_executions
|
||||
error = RuntimeError("Execution is cancelled")
|
||||
return n_node_executions, error
|
||||
|
||||
exec_data = queue.get()
|
||||
|
||||
@@ -638,11 +637,7 @@ class Executor:
|
||||
)
|
||||
running_executions[exec_data.node_id] = cls.executor.apply_async(
|
||||
cls.on_node_execution,
|
||||
(
|
||||
queue,
|
||||
exec_data,
|
||||
graph_exec.node_input_credentials.get(exec_data.node_id),
|
||||
),
|
||||
(queue, exec_data),
|
||||
callback=make_exec_callback(exec_data),
|
||||
)
|
||||
|
||||
@@ -653,7 +648,8 @@ class Executor:
|
||||
)
|
||||
for node_id, execution in list(running_executions.items()):
|
||||
if cancel.is_set():
|
||||
return n_node_executions
|
||||
error = RuntimeError("Execution is cancelled")
|
||||
return n_node_executions, error
|
||||
|
||||
if not queue.empty():
|
||||
break # yield to parent loop to execute new queue items
|
||||
@@ -666,29 +662,37 @@ class Executor:
|
||||
log_metadata.exception(
|
||||
f"Failed graph execution {graph_exec.graph_exec_id}: {e}"
|
||||
)
|
||||
error = e
|
||||
finally:
|
||||
if not cancel.is_set():
|
||||
finished = True
|
||||
cancel.set()
|
||||
cancel_thread.join()
|
||||
return n_node_executions
|
||||
return n_node_executions, error
|
||||
|
||||
|
||||
class ExecutionManager(AppService):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(port=Config().execution_manager_port)
|
||||
self.use_db = True
|
||||
super().__init__()
|
||||
self.use_redis = True
|
||||
self.use_supabase = True
|
||||
self.pool_size = Config().num_graph_workers
|
||||
self.pool_size = settings.config.num_graph_workers
|
||||
self.queue = ExecutionQueue[GraphExecution]()
|
||||
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
|
||||
|
||||
@classmethod
|
||||
def get_port(cls) -> int:
|
||||
return settings.config.execution_manager_port
|
||||
|
||||
def run_service(self):
|
||||
from autogpt_libs.supabase_integration_credentials_store import (
|
||||
SupabaseIntegrationCredentialsStore,
|
||||
)
|
||||
|
||||
self.credentials_store = SupabaseIntegrationCredentialsStore(self.supabase)
|
||||
self.credentials_store = SupabaseIntegrationCredentialsStore(
|
||||
redis=redis.get_redis()
|
||||
)
|
||||
self.executor = ProcessPoolExecutor(
|
||||
max_workers=self.pool_size,
|
||||
initializer=Executor.on_graph_executor_start,
|
||||
@@ -719,19 +723,19 @@ class ExecutionManager(AppService):
|
||||
super().cleanup()
|
||||
|
||||
@property
|
||||
def agent_server_client(self) -> "AgentServer":
|
||||
return get_agent_server_client()
|
||||
def db_client(self) -> "DatabaseManager":
|
||||
return get_db_client()
|
||||
|
||||
@expose
|
||||
def add_execution(
|
||||
self, graph_id: str, data: BlockInput, user_id: str
|
||||
) -> dict[str, Any]:
|
||||
graph: Graph | None = self.run_and_wait(get_graph(graph_id, user_id=user_id))
|
||||
graph: Graph | None = self.db_client.get_graph(graph_id, user_id=user_id)
|
||||
if not graph:
|
||||
raise Exception(f"Graph #{graph_id} not found.")
|
||||
|
||||
graph.validate_graph(for_run=True)
|
||||
node_input_credentials = self._get_node_input_credentials(graph, user_id)
|
||||
self._validate_node_input_credentials(graph, user_id)
|
||||
|
||||
nodes_input = []
|
||||
for node in graph.starting_nodes:
|
||||
@@ -754,13 +758,11 @@ class ExecutionManager(AppService):
|
||||
else:
|
||||
nodes_input.append((node.id, input_data))
|
||||
|
||||
graph_exec_id, node_execs = self.run_and_wait(
|
||||
create_graph_execution(
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
nodes_input=nodes_input,
|
||||
user_id=user_id,
|
||||
)
|
||||
graph_exec_id, node_execs = self.db_client.create_graph_execution(
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
nodes_input=nodes_input,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
starting_node_execs = []
|
||||
@@ -775,19 +777,16 @@ class ExecutionManager(AppService):
|
||||
data=node_exec.input_data,
|
||||
)
|
||||
)
|
||||
exec_update = self.run_and_wait(
|
||||
update_execution_status(
|
||||
node_exec.node_exec_id, ExecutionStatus.QUEUED, node_exec.input_data
|
||||
)
|
||||
exec_update = self.db_client.update_execution_status(
|
||||
node_exec.node_exec_id, ExecutionStatus.QUEUED, node_exec.input_data
|
||||
)
|
||||
self.agent_server_client.send_execution_update(exec_update.model_dump())
|
||||
self.db_client.send_execution_update(exec_update.model_dump())
|
||||
|
||||
graph_exec = GraphExecution(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
start_node_execs=starting_node_execs,
|
||||
node_input_credentials=node_input_credentials,
|
||||
)
|
||||
self.queue.add(graph_exec)
|
||||
|
||||
@@ -816,30 +815,22 @@ class ExecutionManager(AppService):
|
||||
future.result()
|
||||
|
||||
# Update the status of the unfinished node executions
|
||||
node_execs = self.run_and_wait(get_execution_results(graph_exec_id))
|
||||
node_execs = self.db_client.get_execution_results(graph_exec_id)
|
||||
for node_exec in node_execs:
|
||||
if node_exec.status not in (
|
||||
ExecutionStatus.COMPLETED,
|
||||
ExecutionStatus.FAILED,
|
||||
):
|
||||
self.run_and_wait(
|
||||
upsert_execution_output(
|
||||
node_exec.node_exec_id, "error", "TERMINATED"
|
||||
)
|
||||
self.db_client.upsert_execution_output(
|
||||
node_exec.node_exec_id, "error", "TERMINATED"
|
||||
)
|
||||
exec_update = self.run_and_wait(
|
||||
update_execution_status(
|
||||
node_exec.node_exec_id, ExecutionStatus.FAILED
|
||||
)
|
||||
exec_update = self.db_client.update_execution_status(
|
||||
node_exec.node_exec_id, ExecutionStatus.FAILED
|
||||
)
|
||||
self.agent_server_client.send_execution_update(exec_update.model_dump())
|
||||
self.db_client.send_execution_update(exec_update.model_dump())
|
||||
|
||||
def _get_node_input_credentials(
|
||||
self, graph: Graph, user_id: str
|
||||
) -> dict[str, Credentials]:
|
||||
"""Gets all credentials for all nodes of the graph"""
|
||||
|
||||
node_credentials: dict[str, Credentials] = {}
|
||||
def _validate_node_input_credentials(self, graph: Graph, user_id: str):
|
||||
"""Checks all credentials for all nodes of the graph"""
|
||||
|
||||
for node in graph.nodes:
|
||||
block = get_block(node.block_id)
|
||||
@@ -882,9 +873,26 @@ class ExecutionManager(AppService):
|
||||
f"Invalid credentials #{credentials.id} for node #{node.id}: "
|
||||
"type/provider mismatch"
|
||||
)
|
||||
node_credentials[node.id] = credentials
|
||||
|
||||
return node_credentials
|
||||
|
||||
# ------- UTILITIES ------- #
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_db_client() -> "DatabaseManager":
|
||||
from backend.executor import DatabaseManager
|
||||
|
||||
return get_service_client(DatabaseManager)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def synchronized(key: str, timeout: int = 60):
|
||||
lock: RedisLock = redis.get_redis().lock(f"lock:{key}", timeout=timeout)
|
||||
try:
|
||||
lock.acquire()
|
||||
yield
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
|
||||
def llprint(message: str):
|
||||
|
||||
@@ -4,9 +4,16 @@ from datetime import datetime
|
||||
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
|
||||
from backend.data import schedule as model
|
||||
from backend.data.block import BlockInput
|
||||
from backend.data.schedule import (
|
||||
ExecutionSchedule,
|
||||
add_schedule,
|
||||
get_active_schedules,
|
||||
get_schedules,
|
||||
update_schedule,
|
||||
)
|
||||
from backend.executor.manager import ExecutionManager
|
||||
from backend.util.service import AppService, expose, get_service_client
|
||||
from backend.util.settings import Config
|
||||
@@ -19,16 +26,21 @@ def log(msg, **kwargs):
|
||||
|
||||
|
||||
class ExecutionScheduler(AppService):
|
||||
|
||||
def __init__(self, refresh_interval=10):
|
||||
super().__init__(port=Config().execution_scheduler_port)
|
||||
super().__init__()
|
||||
self.use_db = True
|
||||
self.last_check = datetime.min
|
||||
self.refresh_interval = refresh_interval
|
||||
self.use_redis = False
|
||||
|
||||
@classmethod
|
||||
def get_port(cls) -> int:
|
||||
return Config().execution_scheduler_port
|
||||
|
||||
@property
|
||||
def execution_manager_client(self) -> ExecutionManager:
|
||||
return get_service_client(ExecutionManager, Config().execution_manager_port)
|
||||
@thread_cached
|
||||
def execution_client(self) -> ExecutionManager:
|
||||
return get_service_client(ExecutionManager)
|
||||
|
||||
def run_service(self):
|
||||
scheduler = BackgroundScheduler()
|
||||
@@ -38,7 +50,7 @@ class ExecutionScheduler(AppService):
|
||||
time.sleep(self.refresh_interval)
|
||||
|
||||
def __refresh_jobs_from_db(self, scheduler: BackgroundScheduler):
|
||||
schedules = self.run_and_wait(model.get_active_schedules(self.last_check))
|
||||
schedules = self.run_and_wait(get_active_schedules(self.last_check))
|
||||
for schedule in schedules:
|
||||
if schedule.last_updated:
|
||||
self.last_check = max(self.last_check, schedule.last_updated)
|
||||
@@ -60,14 +72,13 @@ class ExecutionScheduler(AppService):
|
||||
def __execute_graph(self, graph_id: str, input_data: dict, user_id: str):
|
||||
try:
|
||||
log(f"Executing recurring job for graph #{graph_id}")
|
||||
execution_manager = self.execution_manager_client
|
||||
execution_manager.add_execution(graph_id, input_data, user_id)
|
||||
self.execution_client.add_execution(graph_id, input_data, user_id)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error executing graph {graph_id}: {e}")
|
||||
|
||||
@expose
|
||||
def update_schedule(self, schedule_id: str, is_enabled: bool, user_id: str) -> str:
|
||||
self.run_and_wait(model.update_schedule(schedule_id, is_enabled, user_id))
|
||||
self.run_and_wait(update_schedule(schedule_id, is_enabled, user_id))
|
||||
return schedule_id
|
||||
|
||||
@expose
|
||||
@@ -79,17 +90,16 @@ class ExecutionScheduler(AppService):
|
||||
input_data: BlockInput,
|
||||
user_id: str,
|
||||
) -> str:
|
||||
schedule = model.ExecutionSchedule(
|
||||
schedule = ExecutionSchedule(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
graph_version=graph_version,
|
||||
schedule=cron,
|
||||
input_data=input_data,
|
||||
)
|
||||
return self.run_and_wait(model.add_schedule(schedule)).id
|
||||
return self.run_and_wait(add_schedule(schedule)).id
|
||||
|
||||
@expose
|
||||
def get_execution_schedules(self, graph_id: str, user_id: str) -> dict[str, str]:
|
||||
query = model.get_schedules(graph_id, user_id=user_id)
|
||||
schedules: list[model.ExecutionSchedule] = self.run_and_wait(query)
|
||||
schedules = self.run_and_wait(get_schedules(graph_id, user_id=user_id))
|
||||
return {v.id: v.schedule for v in schedules}
|
||||
|
||||
170 autogpt_platform/backend/backend/integrations/creds_manager.py Normal file
@@ -0,0 +1,170 @@
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
|
||||
from autogpt_libs.supabase_integration_credentials_store import (
|
||||
Credentials,
|
||||
SupabaseIntegrationCredentialsStore,
|
||||
)
from autogpt_libs.utils.synchronize import RedisKeyedMutex
from redis.lock import Lock as RedisLock

from backend.data import redis
from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
from backend.util.settings import Settings

logger = logging.getLogger(__name__)
settings = Settings()


class IntegrationCredentialsManager:
    """
    Handles the lifecycle of integration credentials.
    - Automatically refreshes requested credentials if needed.
    - Uses locking mechanisms to ensure system-wide consistency and
      prevent invalidation of in-use tokens.

    ### ⚠️ Gotcha
    With `acquire(..)`, credentials can only be in use in one place at a time (e.g. one
    block execution).

    ### Locking mechanism
    - Because *getting* credentials can result in a refresh (= *invalidation* +
      *replacement*) of the stored credentials, *getting* is an operation that
      potentially requires read/write access.
    - Checking whether a token has to be refreshed is subject to an additional `refresh`
      scoped lock to prevent unnecessary sequential refreshes when multiple executions
      try to access the same credentials simultaneously.
    - We MUST lock credentials while in use to prevent them from being invalidated while
      they are in use, e.g. because they are being refreshed by a different part
      of the system.
    - The `!time_sensitive` lock in `acquire(..)` is part of a two-tier locking
      mechanism in which *updating* gets priority over *getting* credentials.
      This is to prevent a long queue of waiting *get* requests from blocking essential
      credential refreshes or user-initiated updates.

    It is possible to implement a reader/writer locking system where either multiple
    readers or a single writer can have simultaneous access, but this would add a lot of
    complexity to the mechanism. I don't expect the current ("simple") mechanism to
    cause so much latency that it's worth implementing.
    """

    def __init__(self):
        redis_conn = redis.get_redis()
        self._locks = RedisKeyedMutex(redis_conn)
        self.store = SupabaseIntegrationCredentialsStore(redis=redis_conn)

    def create(self, user_id: str, credentials: Credentials) -> None:
        return self.store.add_creds(user_id, credentials)

    def exists(self, user_id: str, credentials_id: str) -> bool:
        return self.store.get_creds_by_id(user_id, credentials_id) is not None

    def get(
        self, user_id: str, credentials_id: str, lock: bool = True
    ) -> Credentials | None:
        credentials = self.store.get_creds_by_id(user_id, credentials_id)
        if not credentials:
            return None

        # Refresh OAuth credentials if needed
        if credentials.type == "oauth2" and credentials.access_token_expires_at:
            logger.debug(
                f"Credentials #{credentials.id} expire at "
                f"{datetime.fromtimestamp(credentials.access_token_expires_at)}; "
                f"current time is {datetime.now()}"
            )

            with self._locked(user_id, credentials_id, "refresh"):
                oauth_handler = _get_provider_oauth_handler(credentials.provider)
                if oauth_handler.needs_refresh(credentials):
                    logger.debug(
                        f"Refreshing '{credentials.provider}' "
                        f"credentials #{credentials.id}"
                    )
                    _lock = None
                    if lock:
                        # Wait until the credentials are no longer in use anywhere
                        _lock = self._acquire_lock(user_id, credentials_id)

                    fresh_credentials = oauth_handler.refresh_tokens(credentials)
                    self.store.update_creds(user_id, fresh_credentials)
                    if _lock:
                        _lock.release()

                    credentials = fresh_credentials
        else:
            logger.debug(f"Credentials #{credentials.id} never expire")

        return credentials

    def acquire(
        self, user_id: str, credentials_id: str
    ) -> tuple[Credentials, RedisLock]:
        """
        ⚠️ WARNING: this locks credentials system-wide and blocks both acquiring
        and updating them elsewhere until the lock is released.
        See the class docstring for more info.
        """
        # Use a low-priority (!time_sensitive) locking queue on top of the general lock
        # to allow priority access for refreshing/updating the tokens.
        with self._locked(user_id, credentials_id, "!time_sensitive"):
            lock = self._acquire_lock(user_id, credentials_id)
        credentials = self.get(user_id, credentials_id, lock=False)
        if not credentials:
            raise ValueError(
                f"Credentials #{credentials_id} for user #{user_id} not found"
            )
        return credentials, lock

    def update(self, user_id: str, updated: Credentials) -> None:
        with self._locked(user_id, updated.id):
            self.store.update_creds(user_id, updated)

    def delete(self, user_id: str, credentials_id: str) -> None:
        with self._locked(user_id, credentials_id):
            self.store.delete_creds_by_id(user_id, credentials_id)

    # -- Locking utilities -- #

    def _acquire_lock(self, user_id: str, credentials_id: str, *args: str) -> RedisLock:
        key = (
            self.store.db_manager,
            f"user:{user_id}",
            f"credentials:{credentials_id}",
            *args,
        )
        return self._locks.acquire(key)

    @contextmanager
    def _locked(self, user_id: str, credentials_id: str, *args: str):
        lock = self._acquire_lock(user_id, credentials_id, *args)
        try:
            yield
        finally:
            lock.release()

    def release_all_locks(self):
        """Call this on process termination to ensure all locks are released"""
        self._locks.release_all_locks()
        self.store.locks.release_all_locks()


def _get_provider_oauth_handler(provider_name: str) -> BaseOAuthHandler:
    if provider_name not in HANDLERS_BY_NAME:
        raise KeyError(f"Unknown provider '{provider_name}'")

    client_id = getattr(settings.secrets, f"{provider_name}_client_id")
    client_secret = getattr(settings.secrets, f"{provider_name}_client_secret")
    if not (client_id and client_secret):
        raise Exception(  # TODO: ConfigError
            f"Integration with provider '{provider_name}' is not configured",
        )

    handler_class = HANDLERS_BY_NAME[provider_name]
    frontend_base_url = settings.config.frontend_base_url
    return handler_class(
        client_id=client_id,
        client_secret=client_secret,
        redirect_uri=f"{frontend_base_url}/auth/integrations/oauth_callback",
    )
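A minimal usage sketch of the two-tier locking flow described in the class docstring above; the surrounding function and its name are illustrative assumptions, not part of this changeset:

creds_manager = IntegrationCredentialsManager()

def run_block_with_credentials(user_id: str, credentials_id: str):
    # acquire() first waits in the low-priority "!time_sensitive" queue, then takes the
    # system-wide credentials lock, so a concurrent refresh cannot invalidate the token mid-use.
    credentials, lock = creds_manager.acquire(user_id, credentials_id)
    try:
        ...  # call the third-party API with the acquired credentials here
    finally:
        # Always release; otherwise other acquires and updates stay blocked
        # until the Redis lock expires.
        lock.release()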
@@ -3,6 +3,7 @@ from .github import GitHubOAuthHandler
from .google import GoogleOAuthHandler
from .notion import NotionOAuthHandler

# --8<-- [start:HANDLERS_BY_NAMEExample]
HANDLERS_BY_NAME: dict[str, type[BaseOAuthHandler]] = {
    handler.PROVIDER_NAME: handler
    for handler in [
@@ -11,5 +12,6 @@ HANDLERS_BY_NAME: dict[str, type[BaseOAuthHandler]] = {
        NotionOAuthHandler,
    ]
}
# --8<-- [end:HANDLERS_BY_NAMEExample]

__all__ = ["HANDLERS_BY_NAME"]
@@ -1,31 +1,56 @@
import logging
import time
from abc import ABC, abstractmethod
from typing import ClassVar

from autogpt_libs.supabase_integration_credentials_store import OAuth2Credentials

logger = logging.getLogger(__name__)


class BaseOAuthHandler(ABC):
    # --8<-- [start:BaseOAuthHandler1]
    PROVIDER_NAME: ClassVar[str]
    DEFAULT_SCOPES: ClassVar[list[str]] = []
    # --8<-- [end:BaseOAuthHandler1]

    @abstractmethod
    # --8<-- [start:BaseOAuthHandler2]
    def __init__(self, client_id: str, client_secret: str, redirect_uri: str): ...

    # --8<-- [end:BaseOAuthHandler2]

    @abstractmethod
    # --8<-- [start:BaseOAuthHandler3]
    def get_login_url(self, scopes: list[str], state: str) -> str:
        # --8<-- [end:BaseOAuthHandler3]
        """Constructs a login URL that the user can be redirected to"""
        ...

    @abstractmethod
    def exchange_code_for_tokens(self, code: str) -> OAuth2Credentials:
    # --8<-- [start:BaseOAuthHandler4]
    def exchange_code_for_tokens(
        self, code: str, scopes: list[str]
    ) -> OAuth2Credentials:
        # --8<-- [end:BaseOAuthHandler4]
        """Exchanges the acquired authorization code from login for a set of tokens"""
        ...

    @abstractmethod
    # --8<-- [start:BaseOAuthHandler5]
    def _refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
        # --8<-- [end:BaseOAuthHandler5]
        """Implements the token refresh mechanism"""
        ...

    @abstractmethod
    # --8<-- [start:BaseOAuthHandler6]
    def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
        # --8<-- [end:BaseOAuthHandler6]
        """Revokes the given token at the provider;
        returns False if the provider does not support it"""
        ...

    def refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
        if credentials.provider != self.PROVIDER_NAME:
            raise ValueError(
@@ -46,3 +71,11 @@ class BaseOAuthHandler(ABC):
            credentials.access_token_expires_at is not None
            and credentials.access_token_expires_at < int(time.time()) + 300
        )

    def handle_default_scopes(self, scopes: list[str]) -> list[str]:
        """Handles the default scopes for the provider"""
        # If scopes are empty, use the default scopes for the provider
        if not scopes:
            logger.debug(f"Using default scopes for provider {self.PROVIDER_NAME}")
            scopes = self.DEFAULT_SCOPES
        return scopes
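For orientation, a bare-bones handler sketch against the interface above; the provider name, URL, and scope are made up for illustration, and the real GitHub, Google, and Notion handlers follow below:

class ExampleOAuthHandler(BaseOAuthHandler):
    PROVIDER_NAME = "example"  # hypothetical provider
    DEFAULT_SCOPES = ["profile"]  # substituted by handle_default_scopes() when no scopes are requested

    def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
        self.client_id = client_id
        self.client_secret = client_secret
        self.redirect_uri = redirect_uri

    def get_login_url(self, scopes: list[str], state: str) -> str:
        return f"https://example.com/oauth/authorize?client_id={self.client_id}&state={state}"

    def exchange_code_for_tokens(
        self, code: str, scopes: list[str]
    ) -> OAuth2Credentials:
        ...  # POST the code to the provider's token endpoint and build OAuth2Credentials

    def _refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
        ...  # exchange credentials.refresh_token for a new access token

    def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
        return False  # this hypothetical provider does not support revocation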
@@ -8,6 +8,7 @@ from autogpt_libs.supabase_integration_credentials_store import OAuth2Credential
from .base import BaseOAuthHandler


# --8<-- [start:GithubOAuthHandlerExample]
class GitHubOAuthHandler(BaseOAuthHandler):
    """
    Based on the documentation at:
@@ -23,7 +24,6 @@ class GitHubOAuthHandler(BaseOAuthHandler):
    """  # noqa

    PROVIDER_NAME = "github"
    EMAIL_ENDPOINT = "https://api.github.com/user/emails"

    def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
        self.client_id = client_id
@@ -31,6 +31,7 @@ class GitHubOAuthHandler(BaseOAuthHandler):
        self.redirect_uri = redirect_uri
        self.auth_base_url = "https://github.com/login/oauth/authorize"
        self.token_url = "https://github.com/login/oauth/access_token"
        self.revoke_url = "https://api.github.com/applications/{client_id}/token"

    def get_login_url(self, scopes: list[str], state: str) -> str:
        params = {
@@ -41,9 +42,29 @@ class GitHubOAuthHandler(BaseOAuthHandler):
        }
        return f"{self.auth_base_url}?{urlencode(params)}"

    def exchange_code_for_tokens(self, code: str) -> OAuth2Credentials:
    def exchange_code_for_tokens(
        self, code: str, scopes: list[str]
    ) -> OAuth2Credentials:
        return self._request_tokens({"code": code, "redirect_uri": self.redirect_uri})

    def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
        if not credentials.access_token:
            raise ValueError("No access token to revoke")

        headers = {
            "Accept": "application/vnd.github+json",
            "X-GitHub-Api-Version": "2022-11-28",
        }

        response = requests.delete(
            url=self.revoke_url.format(client_id=self.client_id),
            auth=(self.client_id, self.client_secret),
            headers=headers,
            json={"access_token": credentials.access_token.get_secret_value()},
        )
        response.raise_for_status()
        return True

    def _refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
        if not credentials.refresh_token:
            return credentials
@@ -117,3 +138,6 @@ class GitHubOAuthHandler(BaseOAuthHandler):

        # Get the login (username)
        return response.json().get("login")


# --8<-- [end:GithubOAuthHandlerExample]
@@ -1,3 +1,5 @@
import logging

from autogpt_libs.supabase_integration_credentials_store import OAuth2Credentials
from google.auth.external_account_authorized_user import (
    Credentials as ExternalAccountCredentials,
@@ -9,7 +11,10 @@ from pydantic import SecretStr

from .base import BaseOAuthHandler

logger = logging.getLogger(__name__)


# --8<-- [start:GoogleOAuthHandlerExample]
class GoogleOAuthHandler(BaseOAuthHandler):
    """
    Based on the documentation at https://developers.google.com/identity/protocols/oauth2/web-server
@@ -17,15 +22,24 @@ class GoogleOAuthHandler(BaseOAuthHandler):

    PROVIDER_NAME = "google"
    EMAIL_ENDPOINT = "https://www.googleapis.com/oauth2/v2/userinfo"
    DEFAULT_SCOPES = [
        "https://www.googleapis.com/auth/userinfo.email",
        "https://www.googleapis.com/auth/userinfo.profile",
        "openid",
    ]
    # --8<-- [end:GoogleOAuthHandlerExample]

    def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
        self.client_id = client_id
        self.client_secret = client_secret
        self.redirect_uri = redirect_uri
        self.token_uri = "https://oauth2.googleapis.com/token"
        self.revoke_uri = "https://oauth2.googleapis.com/revoke"

    def get_login_url(self, scopes: list[str], state: str) -> str:
        flow = self._setup_oauth_flow(scopes)
        all_scopes = list(set(scopes + self.DEFAULT_SCOPES))
        logger.debug(f"Setting up OAuth flow with scopes: {all_scopes}")
        flow = self._setup_oauth_flow(all_scopes)
        flow.redirect_uri = self.redirect_uri
        authorization_url, _ = flow.authorization_url(
            access_type="offline",
@@ -35,29 +49,67 @@ class GoogleOAuthHandler(BaseOAuthHandler):
        )
        return authorization_url

    def exchange_code_for_tokens(self, code: str) -> OAuth2Credentials:
        flow = self._setup_oauth_flow(None)
    def exchange_code_for_tokens(
        self, code: str, scopes: list[str]
    ) -> OAuth2Credentials:
        logger.debug(f"Exchanging code for tokens with scopes: {scopes}")

        # Use the scopes from the initial request
        flow = self._setup_oauth_flow(scopes)
        flow.redirect_uri = self.redirect_uri
        flow.fetch_token(code=code)

        logger.debug("Fetching token from Google")

        # Disable scope check in fetch_token
        flow.oauth2session.scope = None
        token = flow.fetch_token(code=code)
        logger.debug("Token fetched successfully")

        # Get the actual scopes granted by Google
        granted_scopes: list[str] = token.get("scope", [])

        logger.debug(f"Scopes granted by Google: {granted_scopes}")

        google_creds = flow.credentials
        username = self._request_email(google_creds)
        logger.debug(f"Received credentials: {google_creds}")

        logger.debug("Requesting user email")
        username = self._request_email(google_creds)
        logger.debug(f"User email retrieved: {username}")

        # Google's OAuth library is poorly typed so we need some of these:
        assert google_creds.token
        assert google_creds.refresh_token
        assert google_creds.expiry
        assert google_creds.scopes
        return OAuth2Credentials(
        assert granted_scopes

        # Create OAuth2Credentials with the granted scopes
        credentials = OAuth2Credentials(
            provider=self.PROVIDER_NAME,
            title=None,
            username=username,
            access_token=SecretStr(google_creds.token),
            refresh_token=SecretStr(google_creds.refresh_token),
            access_token_expires_at=int(google_creds.expiry.timestamp()),
            refresh_token=(SecretStr(google_creds.refresh_token)),
            access_token_expires_at=(
                int(google_creds.expiry.timestamp()) if google_creds.expiry else None
            ),
            refresh_token_expires_at=None,
            scopes=google_creds.scopes,
            scopes=granted_scopes,
        )
        logger.debug(
            f"OAuth2Credentials object created successfully with scopes: {credentials.scopes}"
        )

        return credentials

    def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
        session = AuthorizedSession(credentials)
        response = session.post(
            self.revoke_uri,
            params={"token": credentials.access_token.get_secret_value()},
            headers={"content-type": "application/x-www-form-urlencoded"},
        )
        response.raise_for_status()
        return True

    def _request_email(
        self, creds: Credentials | ExternalAccountCredentials
@@ -65,6 +117,9 @@ class GoogleOAuthHandler(BaseOAuthHandler):
        session = AuthorizedSession(creds)
        response = session.get(self.EMAIL_ENDPOINT)
        if not response.ok:
            logger.error(
                f"Failed to get user email. Status code: {response.status_code}"
            )
            return None
        return response.json()["email"]

@@ -99,7 +154,7 @@ class GoogleOAuthHandler(BaseOAuthHandler):
            scopes=google_creds.scopes,
        )

    def _setup_oauth_flow(self, scopes: list[str] | None) -> Flow:
    def _setup_oauth_flow(self, scopes: list[str]) -> Flow:
        return Flow.from_client_config(
            {
                "web": {
@@ -35,7 +35,9 @@ class NotionOAuthHandler(BaseOAuthHandler):
        }
        return f"{self.auth_base_url}?{urlencode(params)}"

    def exchange_code_for_tokens(self, code: str) -> OAuth2Credentials:
    def exchange_code_for_tokens(
        self, code: str, scopes: list[str]
    ) -> OAuth2Credentials:
        request_body = {
            "grant_type": "authorization_code",
            "code": code,
@@ -75,6 +77,10 @@ class NotionOAuthHandler(BaseOAuthHandler):
            },
        )

    def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
        # Notion doesn't support token revocation
        return False

    def _refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
        # Notion doesn't support token refresh
        return credentials
@@ -1,6 +1,6 @@
from backend.app import run_processes
from backend.executor import ExecutionScheduler
from backend.server import AgentServer
from backend.server.rest_api import AgentServer


def main():

@@ -1,4 +0,0 @@
from .rest_api import AgentServer
from .ws_api import WebsocketServer

__all__ = ["AgentServer", "WebsocketServer"]
@@ -1,40 +1,26 @@
import logging
from typing import Annotated
from typing import Annotated, Literal

from autogpt_libs.supabase_integration_credentials_store import (
    SupabaseIntegrationCredentialsStore,
)
from autogpt_libs.supabase_integration_credentials_store.types import (
    APIKeyCredentials,
    Credentials,
    CredentialsType,
    OAuth2Credentials,
)
from fastapi import (
    APIRouter,
    Body,
    Depends,
    HTTPException,
    Path,
    Query,
    Request,
    Response,
)
from pydantic import BaseModel, SecretStr
from supabase import Client
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
from pydantic import BaseModel, Field, SecretStr

from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
from backend.util.settings import Settings

from ..utils import get_supabase, get_user_id
from ..utils import get_user_id

logger = logging.getLogger(__name__)
settings = Settings()
router = APIRouter()


def get_store(supabase: Client = Depends(get_supabase)):
    return SupabaseIntegrationCredentialsStore(supabase)
creds_manager = IntegrationCredentialsManager()


class LoginResponse(BaseModel):
@@ -43,21 +29,23 @@ class LoginResponse(BaseModel):


@router.get("/{provider}/login")
async def login(
def login(
    provider: Annotated[str, Path(title="The provider to initiate an OAuth flow for")],
    user_id: Annotated[str, Depends(get_user_id)],
    request: Request,
    store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
    scopes: Annotated[
        str, Query(title="Comma-separated list of authorization scopes")
    ] = "",
) -> LoginResponse:
    handler = _get_provider_oauth_handler(request, provider)

    # Generate and store a secure random state token
    state_token = await store.store_state_token(user_id, provider)

    requested_scopes = scopes.split(",") if scopes else []

    # Generate and store a secure random state token along with the scopes
    state_token = creds_manager.store.store_state_token(
        user_id, provider, requested_scopes
    )

    login_url = handler.get_login_url(requested_scopes, state_token)

    return LoginResponse(login_url=login_url, state_token=state_token)
@@ -72,28 +60,51 @@ class CredentialsMetaResponse(BaseModel):


@router.post("/{provider}/callback")
async def callback(
def callback(
    provider: Annotated[str, Path(title="The target provider for this OAuth exchange")],
    code: Annotated[str, Body(title="Authorization code acquired by user login")],
    state_token: Annotated[str, Body(title="Anti-CSRF nonce")],
    store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
    user_id: Annotated[str, Depends(get_user_id)],
    request: Request,
) -> CredentialsMetaResponse:
    logger.debug(f"Received OAuth callback for provider: {provider}")
    handler = _get_provider_oauth_handler(request, provider)

    # Verify the state token
    if not await store.verify_state_token(user_id, state_token, provider):
    if not creds_manager.store.verify_state_token(user_id, state_token, provider):
        logger.warning(f"Invalid or expired state token for user {user_id}")
        raise HTTPException(status_code=400, detail="Invalid or expired state token")

    try:
        credentials = handler.exchange_code_for_tokens(code)
        scopes = creds_manager.store.get_any_valid_scopes_from_state_token(
            user_id, state_token, provider
        )
        logger.debug(f"Retrieved scopes from state token: {scopes}")

        scopes = handler.handle_default_scopes(scopes)

        credentials = handler.exchange_code_for_tokens(code, scopes)
        logger.debug(f"Received credentials with final scopes: {credentials.scopes}")

        # Check if the granted scopes are sufficient for the requested scopes
        if not set(scopes).issubset(set(credentials.scopes)):
            # For now, we'll just log the warning and continue
            logger.warning(
                f"Granted scopes {credentials.scopes} for {provider} do not include all requested scopes {scopes}"
            )

    except Exception as e:
        logger.warning(f"Code->Token exchange failed for provider {provider}: {e}")
        raise HTTPException(status_code=400, detail=str(e))
        logger.error(f"Code->Token exchange failed for provider {provider}: {e}")
        raise HTTPException(
            status_code=400, detail=f"Failed to exchange code for tokens: {str(e)}"
        )

    # TODO: Allow specifying `title` to set on `credentials`
    store.add_creds(user_id, credentials)
    creds_manager.create(user_id, credentials)

    logger.debug(
        f"Successfully processed OAuth callback for user {user_id} and provider {provider}"
    )
    return CredentialsMetaResponse(
        id=credentials.id,
        type=credentials.type,
@@ -104,12 +115,11 @@ async def callback(


@router.get("/{provider}/credentials")
async def list_credentials(
def list_credentials(
    provider: Annotated[str, Path(title="The provider to list credentials for")],
    user_id: Annotated[str, Depends(get_user_id)],
    store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
) -> list[CredentialsMetaResponse]:
    credentials = store.get_creds_by_provider(user_id, provider)
    credentials = creds_manager.store.get_creds_by_provider(user_id, provider)
    return [
        CredentialsMetaResponse(
            id=cred.id,
@@ -123,13 +133,12 @@ async def list_credentials(


@router.get("/{provider}/credentials/{cred_id}")
async def get_credential(
def get_credential(
    provider: Annotated[str, Path(title="The provider to retrieve credentials for")],
    cred_id: Annotated[str, Path(title="The ID of the credentials to retrieve")],
    user_id: Annotated[str, Depends(get_user_id)],
    store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
) -> Credentials:
    credential = store.get_creds_by_id(user_id, cred_id)
    credential = creds_manager.get(user_id, cred_id)
    if not credential:
        raise HTTPException(status_code=404, detail="Credentials not found")
    if credential.provider != provider:
@@ -140,8 +149,7 @@ async def get_credential(


@router.post("/{provider}/credentials", status_code=201)
async def create_api_key_credentials(
    store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
def create_api_key_credentials(
    user_id: Annotated[str, Depends(get_user_id)],
    provider: Annotated[str, Path(title="The provider to create credentials for")],
    api_key: Annotated[str, Body(title="The API key to store")],
@@ -158,7 +166,7 @@ async def create_api_key_credentials(
    )

    try:
        store.add_creds(user_id, new_credentials)
        creds_manager.create(user_id, new_credentials)
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Failed to store credentials: {str(e)}"
@@ -166,14 +174,23 @@ async def create_api_key_credentials(
    return new_credentials


@router.delete("/{provider}/credentials/{cred_id}", status_code=204)
async def delete_credential(
class CredentialsDeletionResponse(BaseModel):
    deleted: Literal[True] = True
    revoked: bool | None = Field(
        description="Indicates whether the credentials were also revoked by their "
        "provider. `None`/`null` if not applicable, e.g. when deleting "
        "non-revocable credentials such as API keys."
    )


@router.delete("/{provider}/credentials/{cred_id}")
def delete_credentials(
    request: Request,
    provider: Annotated[str, Path(title="The provider to delete credentials for")],
    cred_id: Annotated[str, Path(title="The ID of the credentials to delete")],
    user_id: Annotated[str, Depends(get_user_id)],
    store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
):
    creds = store.get_creds_by_id(user_id, cred_id)
) -> CredentialsDeletionResponse:
    creds = creds_manager.store.get_creds_by_id(user_id, cred_id)
    if not creds:
        raise HTTPException(status_code=404, detail="Credentials not found")
    if creds.provider != provider:
@@ -181,8 +198,14 @@ async def delete_credential(
            status_code=404, detail="Credentials do not match the specified provider"
        )

    store.delete_creds_by_id(user_id, cred_id)
    return Response(status_code=204)
    creds_manager.delete(user_id, cred_id)

    tokens_revoked = None
    if isinstance(creds, OAuth2Credentials):
        handler = _get_provider_oauth_handler(request, provider)
        tokens_revoked = handler.revoke_tokens(creds)

    return CredentialsDeletionResponse(revoked=tokens_revoked)


# -------- UTILITIES --------- #
@@ -0,0 +1,11 @@
from supabase import Client, create_client

from backend.util.settings import Settings

settings = Settings()


def get_supabase() -> Client:
    return create_client(
        settings.secrets.supabase_url, settings.secrets.supabase_service_role_key
    )
@@ -1,3 +1,4 @@
import asyncio
import inspect
import logging
from collections import defaultdict
@@ -7,23 +8,22 @@ from typing import Annotated, Any, Dict

import uvicorn
from autogpt_libs.auth.middleware import auth_middleware
from autogpt_libs.utils.cache import thread_cached
from fastapi import APIRouter, Body, Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from typing_extensions import TypedDict

from backend.data import block, db
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data import user as user_db
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.credit import get_block_costs, get_user_credit_model
from backend.data.queue import AsyncEventQueue, AsyncRedisEventQueue
from backend.data.user import get_or_create_user
from backend.executor import ExecutionManager, ExecutionScheduler
from backend.server.model import CreateGraph, SetGraphActiveVersion
from backend.util.lock import KeyedMutex
from backend.util.service import AppService, expose, get_service_client
from backend.util.settings import Config, Settings
from backend.util.service import AppService, get_service_client
from backend.util.settings import AppEnvironment, Config, Settings

from .utils import get_user_id

@@ -32,27 +32,26 @@ logger = logging.getLogger(__name__)


class AgentServer(AppService):
    mutex = KeyedMutex()
    use_redis = True
    _test_dependency_overrides = {}
    _user_credit_model = get_user_credit_model()

    def __init__(self, event_queue: AsyncEventQueue | None = None):
        super().__init__(port=Config().agent_server_port)
        self.event_queue = event_queue or AsyncRedisEventQueue()
    def __init__(self):
        super().__init__()
        self.use_redis = True

    @classmethod
    def get_port(cls) -> int:
        return Config().agent_server_port

    @asynccontextmanager
    async def lifespan(self, _: FastAPI):
        await db.connect()
        self.run_and_wait(self.event_queue.connect())
        await block.initialize_blocks()
        if await user_db.create_default_user(settings.config.enable_auth):
            await graph_db.import_packaged_templates()
        yield
        await self.event_queue.close()
        await db.disconnect()

    def run_service(self):
        docs_url = "/docs" if settings.config.app_env == AppEnvironment.LOCAL else None
        app = FastAPI(
            title="AutoGPT Agent Server",
            description=(
@@ -62,6 +61,7 @@ class AgentServer(AppService):
            summary="AutoGPT Agent Server",
            version="0.1",
            lifespan=self.lifespan,
            docs_url=docs_url,
        )

        if self._test_dependency_overrides:
@@ -79,16 +79,24 @@ class AgentServer(AppService):
            allow_headers=["*"],  # Allows all headers
        )

        health_router = APIRouter()
        health_router.add_api_route(
            path="/health",
            endpoint=self.health,
            methods=["GET"],
            tags=["health"],
        )

        # Define the API routes
        api_router = APIRouter(prefix="/api")
        api_router.dependencies.append(Depends(auth_middleware))

        # Import & Attach sub-routers
        import backend.server.integrations.router
        import backend.server.routers.analytics
        import backend.server.routers.integrations

        api_router.include_router(
            backend.server.routers.integrations.router,
            backend.server.integrations.router.router,
            prefix="/integrations",
            tags=["integrations"],
            dependencies=[Depends(auth_middleware)],
@@ -168,6 +176,12 @@ class AgentServer(AppService):
            methods=["PUT"],
            tags=["templates", "graphs"],
        )
        api_router.add_api_route(
            path="/graphs/{graph_id}",
            endpoint=self.delete_graph,
            methods=["DELETE"],
            tags=["graphs"],
        )
        api_router.add_api_route(
            path="/graphs/{graph_id}/versions",
            endpoint=self.get_graph_all_versions,
@@ -256,6 +270,7 @@ class AgentServer(AppService):
        app.add_exception_handler(500, self.handle_internal_http_error)

        app.include_router(api_router)
        app.include_router(health_router)

        uvicorn.run(
            app,
@@ -294,12 +309,14 @@ class AgentServer(AppService):
        return wrapper

    @property
    @thread_cached
    def execution_manager_client(self) -> ExecutionManager:
        return get_service_client(ExecutionManager, Config().execution_manager_port)
        return get_service_client(ExecutionManager)

    @property
    @thread_cached
    def execution_scheduler_client(self) -> ExecutionScheduler:
        return get_service_client(ExecutionScheduler, Config().execution_scheduler_port)
        return get_service_client(ExecutionScheduler)

    @classmethod
    def handle_internal_http_error(cls, request: Request, exc: Exception):
@@ -318,9 +335,9 @@ class AgentServer(AppService):

    @classmethod
    def get_graph_blocks(cls) -> list[dict[Any, Any]]:
        blocks = block.get_blocks()
        blocks = [cls() for cls in block.get_blocks().values()]
        costs = get_block_costs()
        return [{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks.values()]
        return [{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks]

    @classmethod
    def execute_graph_block(
@@ -346,8 +363,10 @@ class AgentServer(AppService):
        )

    @classmethod
    async def get_templates(cls) -> list[graph_db.GraphMeta]:
        return await graph_db.get_graphs_meta(filter_by="template")
    async def get_templates(
        cls, user_id: Annotated[str, Depends(get_user_id)]
    ) -> list[graph_db.GraphMeta]:
        return await graph_db.get_graphs_meta(filter_by="template", user_id=user_id)

    @classmethod
    async def get_graph(
@@ -355,8 +374,11 @@ class AgentServer(AppService):
        graph_id: str,
        user_id: Annotated[str, Depends(get_user_id)],
        version: int | None = None,
        hide_credentials: bool = False,
    ) -> graph_db.Graph:
        graph = await graph_db.get_graph(graph_id, version, user_id=user_id)
        graph = await graph_db.get_graph(
            graph_id, version, user_id=user_id, hide_credentials=hide_credentials
        )
        if not graph:
            raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
        return graph
@@ -393,6 +415,17 @@ class AgentServer(AppService):
    ) -> graph_db.Graph:
        return await cls.create_graph(create_graph, is_template=True, user_id=user_id)

    class DeleteGraphResponse(TypedDict):
        version_counts: int

    @classmethod
    async def delete_graph(
        cls, graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
    ) -> DeleteGraphResponse:
        return {
            "version_counts": await graph_db.delete_graph(graph_id, user_id=user_id)
        }

    @classmethod
    async def create_graph(
        cls,
@@ -486,7 +519,7 @@ class AgentServer(AppService):
            user_id=user_id,
        )

    async def execute_graph(
    def execute_graph(
        self,
        graph_id: str,
        node_input: dict[Any, Any],
@@ -509,7 +542,9 @@ class AgentServer(AppService):
                404, detail=f"Agent execution #{graph_exec_id} not found"
            )

        self.execution_manager_client.cancel_execution(graph_exec_id)
        await asyncio.to_thread(
            lambda: self.execution_manager_client.cancel_execution(graph_exec_id)
        )

        # Retrieve & return canceled graph execution in its final state
        return await execution_db.get_execution_results(graph_exec_id)
@@ -584,10 +619,16 @@ class AgentServer(AppService):
        graph = await graph_db.get_graph(graph_id, user_id=user_id)
        if not graph:
            raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
        execution_scheduler = self.execution_scheduler_client

        return {
            "id": execution_scheduler.add_execution_schedule(
                graph_id, graph.version, cron, input_data, user_id=user_id
            "id": await asyncio.to_thread(
                lambda: self.execution_scheduler_client.add_execution_schedule(
                    graph_id=graph_id,
                    graph_version=graph.version,
                    cron=cron,
                    input_data=input_data,
                    user_id=user_id,
                )
            )
        }

@@ -613,18 +654,8 @@ class AgentServer(AppService):
        execution_scheduler = self.execution_scheduler_client
        return execution_scheduler.get_execution_schedules(graph_id, user_id)

    @expose
    def send_execution_update(self, execution_result_dict: dict[Any, Any]):
        execution_result = execution_db.ExecutionResult(**execution_result_dict)
        self.run_and_wait(self.event_queue.put(execution_result))

    @expose
    def acquire_lock(self, key: Any):
        self.mutex.lock(key)

    @expose
    def release_lock(self, key: Any):
        self.mutex.unlock(key)
    async def health(self):
        return {"status": "healthy"}

    @classmethod
    def update_configuration(
@@ -1,6 +1,5 @@
from autogpt_libs.auth.middleware import auth_middleware
from fastapi import Depends, HTTPException
from supabase import Client, create_client

from backend.data.user import DEFAULT_USER_ID
from backend.util.settings import Settings
@@ -17,9 +16,3 @@ def get_user_id(payload: dict = Depends(auth_middleware)) -> str:
    if not user_id:
        raise HTTPException(status_code=401, detail="User ID not found in token")
    return user_id


def get_supabase() -> Client:
    return create_client(
        settings.secrets.supabase_url, settings.secrets.supabase_service_role_key
    )
@@ -1,23 +1,34 @@
import asyncio
import logging
from contextlib import asynccontextmanager

import uvicorn
from autogpt_libs.auth import parse_jwt_token
from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware

from backend.data.queue import AsyncRedisEventQueue
from backend.data import redis
from backend.data.queue import AsyncRedisExecutionEventBus
from backend.data.user import DEFAULT_USER_ID
from backend.server.conn_manager import ConnectionManager
from backend.server.model import ExecutionSubscription, Methods, WsMessage
from backend.util.service import AppProcess
from backend.util.settings import Config, Settings
from backend.util.settings import AppEnvironment, Config, Settings

logger = logging.getLogger(__name__)
settings = Settings()

app = FastAPI()
event_queue = AsyncRedisEventQueue()

@asynccontextmanager
async def lifespan(app: FastAPI):
    manager = get_connection_manager()
    fut = asyncio.create_task(event_broadcaster(manager))
    fut.add_done_callback(lambda _: logger.info("Event broadcaster stopped"))
    yield


docs_url = "/docs" if settings.config.app_env == AppEnvironment.LOCAL else None
app = FastAPI(lifespan=lifespan, docs_url=docs_url)
_connection_manager = None

logger.info(f"CORS allow origins: {settings.config.backend_cors_allow_origins}")
@@ -37,27 +48,21 @@ def get_connection_manager():
    return _connection_manager


@app.on_event("startup")
async def startup_event():
    await event_queue.connect()
    manager = get_connection_manager()
    asyncio.create_task(event_broadcaster(manager))


@app.on_event("shutdown")
async def shutdown_event():
    await event_queue.close()


async def event_broadcaster(manager: ConnectionManager):
    while True:
        event = await event_queue.get()
        if event is not None:
    try:
        redis.connect()
        event_queue = AsyncRedisExecutionEventBus()
        async for event in event_queue.listen():
            await manager.send_execution_result(event)
    except Exception as e:
        logger.exception(f"Event broadcaster error: {e}")
        raise
    finally:
        redis.disconnect()


async def authenticate_websocket(websocket: WebSocket) -> str:
    if settings.config.enable_auth.lower() == "true":
    if settings.config.enable_auth:
        token = websocket.query_params.get("token")
        if not token:
            await websocket.close(code=4001, reason="Missing authentication token")
@@ -252,7 +252,7 @@ async def block_autogen_agent():
        test_user = await create_test_user()
        test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
        input_data = {"input": "Write me a block that writes a string into a file."}
        response = await server.agent_server.execute_graph(
        response = server.agent_server.execute_graph(
            test_graph.id, input_data, test_user.id
        )
        print(response)

@@ -156,7 +156,7 @@ async def reddit_marketing_agent():
        test_user = await create_test_user()
        test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
        input_data = {"subreddit": "AutoGPT"}
        response = await server.agent_server.execute_graph(
        response = server.agent_server.execute_graph(
            test_graph.id, input_data, test_user.id
        )
        print(response)

@@ -78,7 +78,7 @@ async def sample_agent():
        test_user = await create_test_user()
        test_graph = await create_graph(create_test_graph(), test_user.id)
        input_data = {"input_1": "Hello", "input_2": "World"}
        response = await server.agent_server.execute_graph(
        response = server.agent_server.execute_graph(
            test_graph.id, input_data, test_user.id
        )
        print(response)
@@ -1,31 +0,0 @@
from threading import Lock
from typing import Any

from expiringdict import ExpiringDict


class KeyedMutex:
    """
    This class provides a mutex that can be locked and unlocked by a specific key.
    It uses an ExpiringDict to automatically clear the mutex after a specified timeout,
    in case the key is not unlocked for a specified duration, to prevent memory leaks.
    """

    def __init__(self):
        self.locks: dict[Any, tuple[Lock, int]] = ExpiringDict(
            max_len=6000, max_age_seconds=60
        )
        self.locks_lock = Lock()

    def lock(self, key: Any):
        with self.locks_lock:
            lock, request_count = self.locks.get(key, (Lock(), 0))
            self.locks[key] = (lock, request_count + 1)
        lock.acquire()

    def unlock(self, key: Any):
        with self.locks_lock:
            lock, request_count = self.locks.pop(key)
            if request_count > 1:
                self.locks[key] = (lock, request_count - 1)
        lock.release()
@@ -1,4 +1,6 @@
import os
from backend.util.settings import AppEnvironment, BehaveAs, Settings

settings = Settings()


def configure_logging():
@@ -6,7 +8,10 @@ def configure_logging():

    import autogpt_libs.logging.config

    if os.getenv("APP_ENV") != "cloud":
    if (
        settings.config.behave_as == BehaveAs.LOCAL
        or settings.config.app_env == AppEnvironment.LOCAL
    ):
        autogpt_libs.logging.config.configure_logging(force_cloud_logging=False)
    else:
        autogpt_libs.logging.config.configure_logging(force_cloud_logging=True)
@@ -10,6 +10,16 @@ from backend.util.logging import configure_logging
from backend.util.metrics import sentry_init

logger = logging.getLogger(__name__)
_SERVICE_NAME = "MainProcess"


def get_service_name():
    return _SERVICE_NAME


def set_service_name(name: str):
    global _SERVICE_NAME
    _SERVICE_NAME = name


class AppProcess(ABC):
@@ -32,6 +42,11 @@ class AppProcess(ABC):
        """
        pass

    @classmethod
    @property
    def service_name(cls) -> str:
        return cls.__name__

    def cleanup(self):
        """
        Implement this method on a subclass to do post-execution cleanup,
@@ -52,10 +67,12 @@ class AppProcess(ABC):
            if silent:
                sys.stdout = open(os.devnull, "w")
                sys.stderr = open(os.devnull, "w")
            logger.info(f"[{self.__class__.__name__}] Starting...")

            set_service_name(self.service_name)
            logger.info(f"[{self.service_name}] Starting...")
            self.run()
        except (KeyboardInterrupt, SystemExit) as e:
            logger.warning(f"[{self.__class__.__name__}] Terminated: {e}; quitting...")
            logger.warning(f"[{self.service_name}] Terminated: {e}; quitting...")

    def _self_terminate(self, signum: int, frame):
        self.cleanup()