Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-04-08 03:00:28 -04:00)

Compare commits: 48 Commits (aarushikan ... aarushikan)
Commits in this comparison (SHA1):

9334eee41d
e4a9c8216f
f19ed9f652
30376a8ec8
32680a549e
68158de126
6f3828fc99
26b1bca033
7f6354caae
5d4d2486da
2c0286e411
fca8d61cc4
2b4af19799
615a9746dc
ac33c1eb03
d6d2820b92
3982e20faa
c029fde502
405dd1659e
2d0e51fe28
6f07d24e93
9292597d56
f6eebcab6e
9fe3fed1a2
769ab18cca
d46219c80f
97015a91ad
a2ef456525
1c71351652
bd5d2b1e86
8502928a21
c1f97415fb
74e677baec
992989ee71
d8145c158c
9ad5e1f808
7b92bae942
c03e2fb949
dbc603c6eb
c582b5512a
e654aa1e7a
e37744b9f2
bc1df92c29
04473cad1e
2a74381ae8
d42ed088dd
2aed470d26
61f1d0cdb5
4  .github/workflows/classic-autogpt-ci.yml  vendored

@@ -2,12 +2,12 @@ name: Classic - AutoGPT CI

 on:
   push:
-    branches: [ master, development, ci-test* ]
+    branches: [ master, dev, ci-test* ]
     paths:
       - '.github/workflows/classic-autogpt-ci.yml'
       - 'classic/original_autogpt/**'
   pull_request:
-    branches: [ master, development, release-* ]
+    branches: [ master, dev, release-* ]
     paths:
       - '.github/workflows/classic-autogpt-ci.yml'
       - 'classic/original_autogpt/**'

.github/workflows/classic-autogpt-docker-ci.yml  vendored

@@ -8,7 +8,7 @@ on:
       - 'classic/original_autogpt/**'
       - 'classic/forge/**'
   pull_request:
-    branches: [ master, development, release-* ]
+    branches: [ master, dev, release-* ]
     paths:
       - '.github/workflows/classic-autogpt-docker-ci.yml'
       - 'classic/original_autogpt/**'
4  .github/workflows/classic-autogpts-ci.yml  vendored

@@ -5,7 +5,7 @@ on:
   schedule:
     - cron: '0 8 * * *'
   push:
-    branches: [ master, development, ci-test* ]
+    branches: [ master, dev, ci-test* ]
     paths:
       - '.github/workflows/classic-autogpts-ci.yml'
       - 'classic/original_autogpt/**'
@@ -16,7 +16,7 @@ on:
       - 'classic/setup.py'
       - '!**/*.md'
   pull_request:
-    branches: [ master, development, release-* ]
+    branches: [ master, dev, release-* ]
     paths:
       - '.github/workflows/classic-autogpts-ci.yml'
       - 'classic/original_autogpt/**'
4  .github/workflows/classic-benchmark-ci.yml  vendored

@@ -2,13 +2,13 @@ name: Classic - AGBenchmark CI

 on:
   push:
-    branches: [ master, development, ci-test* ]
+    branches: [ master, dev, ci-test* ]
     paths:
       - 'classic/benchmark/**'
       - '!classic/benchmark/reports/**'
       - .github/workflows/classic-benchmark-ci.yml
   pull_request:
-    branches: [ master, development, release-* ]
+    branches: [ master, dev, release-* ]
     paths:
       - 'classic/benchmark/**'
       - '!classic/benchmark/reports/**'
4  .github/workflows/classic-forge-ci.yml  vendored

@@ -2,13 +2,13 @@ name: Classic - Forge CI

 on:
   push:
-    branches: [ master, development, ci-test* ]
+    branches: [ master, dev, ci-test* ]
     paths:
       - '.github/workflows/classic-forge-ci.yml'
       - 'classic/forge/**'
       - '!classic/forge/tests/vcr_cassettes'
   pull_request:
-    branches: [ master, development, release-* ]
+    branches: [ master, dev, release-* ]
     paths:
       - '.github/workflows/classic-forge-ci.yml'
       - 'classic/forge/**'
4  .github/workflows/classic-python-checks.yml  vendored

@@ -2,7 +2,7 @@ name: Classic - Python checks

 on:
   push:
-    branches: [ master, development, ci-test* ]
+    branches: [ master, dev, ci-test* ]
     paths:
       - '.github/workflows/classic-python-checks-ci.yml'
      - 'classic/original_autogpt/**'
@@ -11,7 +11,7 @@ on:
       - '**.py'
       - '!classic/forge/tests/vcr_cassettes'
   pull_request:
-    branches: [ master, development, release-* ]
+    branches: [ master, dev, release-* ]
     paths:
       - '.github/workflows/classic-python-checks-ci.yml'
       - 'classic/original_autogpt/**'
.github/workflows/platform-autogpt-infra-ci.yml  vendored

@@ -2,7 +2,7 @@ name: AutoGPT Platform - Infra

 on:
   push:
-    branches: [ master ]
+    branches: [ master, dev ]
     paths:
       - '.github/workflows/platform-autogpt-infra-ci.yml'
       - 'autogpt_platform/infra/**'
4  .github/workflows/platform-backend-ci.yml  vendored

@@ -2,12 +2,12 @@ name: AutoGPT Platform - Backend CI

 on:
   push:
-    branches: [master, development, ci-test*]
+    branches: [master, dev, ci-test*]
     paths:
       - ".github/workflows/platform-backend-ci.yml"
       - "autogpt_platform/backend/**"
   pull_request:
-    branches: [master, development, release-*]
+    branches: [master, dev, release-*]
     paths:
       - ".github/workflows/platform-backend-ci.yml"
       - "autogpt_platform/backend/**"
16  .github/workflows/platform-frontend-ci.yml  vendored

@@ -2,7 +2,7 @@ name: AutoGPT Platform - Frontend CI

 on:
   push:
-    branches: [master]
+    branches: [master, dev]
     paths:
       - ".github/workflows/platform-frontend-ci.yml"
      - "autogpt_platform/frontend/**"
@@ -29,15 +29,11 @@ jobs:

       - name: Install dependencies
         run: |
-          npm install
-
-      - name: Check formatting with Prettier
-        run: |
-          npx prettier --check .
+          yarn install --frozen-lockfile

       - name: Run lint
         run: |
-          npm run lint
+          yarn lint

   test:
     runs-on: ubuntu-latest
@@ -62,18 +58,18 @@ jobs:

       - name: Install dependencies
         run: |
-          npm install
+          yarn install --frozen-lockfile

       - name: Setup Builder .env
         run: |
           cp .env.example .env

       - name: Install Playwright Browsers
-        run: npx playwright install --with-deps
+        run: yarn playwright install --with-deps

       - name: Run tests
         run: |
-          npm run test
+          yarn test

       - uses: actions/upload-artifact@v4
         if: ${{ !cancelled() }}
125  .github/workflows/platform-market-ci.yml  vendored  Normal file

@@ -0,0 +1,125 @@
+name: AutoGPT Platform - Backend CI
+
+on:
+  push:
+    branches: [master, dev, ci-test*]
+    paths:
+      - ".github/workflows/platform-market-ci.yml"
+      - "autogpt_platform/market/**"
+  pull_request:
+    branches: [master, dev, release-*]
+    paths:
+      - ".github/workflows/platform-market-ci.yml"
+      - "autogpt_platform/market/**"
+
+concurrency:
+  group: ${{ format('backend-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
+  cancel-in-progress: ${{ startsWith(github.event_name, 'pull_request') }}
+
+defaults:
+  run:
+    shell: bash
+    working-directory: autogpt_platform/market
+
+jobs:
+  test:
+    permissions:
+      contents: read
+    timeout-minutes: 30
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: ["3.10"]
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+          submodules: true
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Setup Supabase
+        uses: supabase/setup-cli@v1
+        with:
+          version: latest
+
+      - id: get_date
+        name: Get date
+        run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
+
+      - name: Set up Python dependency cache
+        uses: actions/cache@v4
+        with:
+          path: ~/.cache/pypoetry
+          key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/market/poetry.lock') }}
+
+      - name: Install Poetry (Unix)
+        run: |
+          curl -sSL https://install.python-poetry.org | python3 -
+
+          if [ "${{ runner.os }}" = "macOS" ]; then
+            PATH="$HOME/.local/bin:$PATH"
+            echo "$HOME/.local/bin" >> $GITHUB_PATH
+          fi
+
+      - name: Install Python dependencies
+        run: poetry install
+
+      - name: Generate Prisma Client
+        run: poetry run prisma generate
+
+      - id: supabase
+        name: Start Supabase
+        working-directory: .
+        run: |
+          supabase init
+          supabase start --exclude postgres-meta,realtime,storage-api,imgproxy,inbucket,studio,edge-runtime,logflare,vector,supavisor
+          supabase status -o env | sed 's/="/=/; s/"$//' >> $GITHUB_OUTPUT
+          # outputs:
+          # DB_URL, API_URL, GRAPHQL_URL, ANON_KEY, SERVICE_ROLE_KEY, JWT_SECRET
+
+      - name: Run Database Migrations
+        run: poetry run prisma migrate dev --name updates
+        env:
+          DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
+
+      - id: lint
+        name: Run Linter
+        run: poetry run lint
+
+      # Tests commented out because they do not work with the Prisma mock, nor have they been updated since they were created
+      # - name: Run pytest with coverage
+      #   run: |
+      #     if [[ "${{ runner.debug }}" == "1" ]]; then
+      #       poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG test
+      #     else
+      #       poetry run pytest -s -vv test
+      #     fi
+      #   if: success() || (failure() && steps.lint.outcome == 'failure')
+      #   env:
+      #     LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
+      #     DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
+      #     SUPABASE_URL: ${{ steps.supabase.outputs.API_URL }}
+      #     SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
+      #     SUPABASE_JWT_SECRET: ${{ steps.supabase.outputs.JWT_SECRET }}
+      #     REDIS_HOST: 'localhost'
+      #     REDIS_PORT: '6379'
+      #     REDIS_PASSWORD: 'testpassword'
+
+    env:
+      CI: true
+      PLAIN_OUTPUT: True
+      RUN_ENV: local
+      PORT: 8080
+
+      # - name: Upload coverage reports to Codecov
+      #   uses: codecov/codecov-action@v4
+      #   with:
+      #     token: ${{ secrets.CODECOV_TOKEN }}
+      #     flags: backend,${{ runner.os }}
21  .github/workflows/repo-pr-enforce-base-branch.yml  vendored  Normal file

@@ -0,0 +1,21 @@
+name: Repo - Enforce dev as base branch
+on:
+  pull_request_target:
+    branches: [ master ]
+    types: [ opened ]
+
+jobs:
+  check_pr_target:
+    runs-on: ubuntu-latest
+    permissions:
+      pull-requests: write
+    steps:
+      - name: Check if PR is from dev or hotfix
+        if: ${{ !(startsWith(github.event.pull_request.head.ref, 'hotfix/') || github.event.pull_request.head.ref == 'dev') }}
+        run: |
+          gh pr comment ${{ github.event.number }} --repo "$REPO" \
+            --body $'This PR targets the `master` branch but does not come from `dev` or a `hotfix/*` branch.\n\nAutomatically setting the base branch to `dev`.'
+          gh pr edit ${{ github.event.number }} --base dev --repo "$REPO"
+        env:
+          GITHUB_TOKEN: ${{ github.token }}
+          REPO: ${{ github.repository }}
2  .github/workflows/repo-pr-label.yml  vendored

@@ -3,7 +3,7 @@ name: Repo - Pull Request auto-label
 on:
   # So that PRs touching the same files as the push are updated
   push:
-    branches: [ master, development, release-* ]
+    branches: [ master, dev, release-* ]
     paths-ignore:
       - 'classic/forge/tests/vcr_cassettes'
       - 'classic/benchmark/reports/**'
CONTRIBUTING.md

@@ -11,7 +11,7 @@ Also check out our [🚀 Roadmap][roadmap] for information about our priorities
 [kanban board]: https://github.com/orgs/Significant-Gravitas/projects/1

 ## Contributing to the AutoGPT Platform Folder
-All contributions to [the autogpt_platform folder](https://github.com/Significant-Gravitas/AutoGPT/blob/master/autogpt_platform) will be under our [Contribution License Agreement](https://github.com/Significant-Gravitas/AutoGPT/blob/master/autogpt_platform/Contributor%20License%20Agreement%20(CLA).md). By making a pull request contributing to this folder, you agree to the terms of our CLA for your contribution.
+All contributions to [the autogpt_platform folder](https://github.com/Significant-Gravitas/AutoGPT/blob/master/autogpt_platform) will be under our [Contribution License Agreement](https://github.com/Significant-Gravitas/AutoGPT/blob/master/autogpt_platform/Contributor%20License%20Agreement%20(CLA).md). By making a pull request contributing to this folder, you agree to the terms of our CLA for your contribution. All contributions to other folders will be under the MIT license.

 ## In short
 1. Avoid duplicate work, issues, PRs etc.
8  LICENSE

@@ -1,7 +1,13 @@
+All portions of this repository are under one of two licenses. The majority of the AutoGPT repository is under the MIT License below. The autogpt_platform folder is under the
+Polyform Shield License.
+
+
+
+
 MIT License

 Copyright (c) 2023 Toran Bruce Richards

 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
 in the Software without restriction, including without limitation the rights
@@ -9,9 +15,11 @@ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 copies of the Software, and to permit persons to whom the Software is
 furnished to do so, subject to the following conditions:
+
 The above copyright notice and this permission notice shall be included in all
 copies or substantial portions of the Software.
+
 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17  README.md

@@ -65,6 +65,7 @@ Here are two examples of what you can do with AutoGPT:
 These examples show just a glimpse of what you can achieve with AutoGPT! You can create customized workflows to build agents for any use case.

 ---
+### Mission and Licencing
 Our mission is to provide the tools, so that you can focus on what matters:

 - 🏗️ **Building** - Lay the foundation for something amazing.
@@ -77,6 +78,13 @@ Be part of the revolution! **AutoGPT** is here to stay, at the forefront of AI i
  | 
 **🚀 [Contributing](CONTRIBUTING.md)**

+**Licensing:**
+
+MIT License: The majority of the AutoGPT repository is under the MIT License.
+
+Polyform Shield License: This license applies to the autogpt_platform folder.
+
+For more information, see https://agpt.co/blog/introducing-the-autogpt-platform
+
 ---
 ## 🤖 AutoGPT Classic
@@ -150,6 +158,8 @@ To maintain a uniform standard and ensure seamless compatibility with many curre
 ---

+## Stars stats
+
 <p align="center">
   <a href="https://star-history.com/#Significant-Gravitas/AutoGPT">
     <picture>
@@ -159,3 +169,10 @@ To maintain a uniform standard and ensure seamless compatibility with many curre
     </picture>
   </a>
 </p>
+
+
+## ⚡ Contributors
+
+<a href="https://github.com/Significant-Gravitas/AutoGPT/graphs/contributors" alt="View Contributors">
+	<img src="https://contrib.rocks/image?repo=Significant-Gravitas/AutoGPT&max=1000&columns=10" alt="Contributors" />
+</a>
@@ -149,6 +149,3 @@ To persist data for PostgreSQL and Redis, you can modify the `docker-compose.yml
 3. Save the file and run `docker compose up -d` to apply the changes.

 This configuration will create named volumes for PostgreSQL and Redis, ensuring that your data persists across container restarts.
-
-
-
@@ -1,8 +1,9 @@
 from .store import SupabaseIntegrationCredentialsStore
-from .types import APIKeyCredentials, OAuth2Credentials
+from .types import Credentials, APIKeyCredentials, OAuth2Credentials

 __all__ = [
     "SupabaseIntegrationCredentialsStore",
+    "Credentials",
     "APIKeyCredentials",
     "OAuth2Credentials",
 ]
@@ -1,8 +1,12 @@
 import secrets
 from datetime import datetime, timedelta, timezone
-from typing import cast
+from typing import TYPE_CHECKING, cast

-from supabase import Client
+if TYPE_CHECKING:
+    from redis import Redis
+    from supabase import Client
+
+from autogpt_libs.utils.synchronize import RedisKeyedMutex

 from .types import (
     Credentials,
@@ -14,26 +18,28 @@ from .types import (

 class SupabaseIntegrationCredentialsStore:
-    def __init__(self, supabase: Client):
+    def __init__(self, supabase: "Client", redis: "Redis"):
         self.supabase = supabase
+        self.locks = RedisKeyedMutex(redis)

     def add_creds(self, user_id: str, credentials: Credentials) -> None:
-        if self.get_creds_by_id(user_id, credentials.id):
-            raise ValueError(
-                f"Can not re-create existing credentials with ID {credentials.id} "
-                f"for user with ID {user_id}"
-            )
-        self._set_user_integration_creds(
-            user_id, [*self.get_all_creds(user_id), credentials]
-        )
+        with self.locked_user_metadata(user_id):
+            if self.get_creds_by_id(user_id, credentials.id):
+                raise ValueError(
+                    f"Can not re-create existing credentials #{credentials.id} "
+                    f"for user #{user_id}"
+                )
+            self._set_user_integration_creds(
+                user_id, [*self.get_all_creds(user_id), credentials]
+            )

     def get_all_creds(self, user_id: str) -> list[Credentials]:
         user_metadata = self._get_user_metadata(user_id)
         return UserMetadata.model_validate(user_metadata).integration_credentials

     def get_creds_by_id(self, user_id: str, credentials_id: str) -> Credentials | None:
-        credentials = self.get_all_creds(user_id)
-        return next((c for c in credentials if c.id == credentials_id), None)
+        all_credentials = self.get_all_creds(user_id)
+        return next((c for c in all_credentials if c.id == credentials_id), None)

     def get_creds_by_provider(self, user_id: str, provider: str) -> list[Credentials]:
         credentials = self.get_all_creds(user_id)
@@ -44,42 +50,45 @@ class SupabaseIntegrationCredentialsStore:
         return list(set(c.provider for c in credentials))

     def update_creds(self, user_id: str, updated: Credentials) -> None:
-        current = self.get_creds_by_id(user_id, updated.id)
-        if not current:
-            raise ValueError(
-                f"Credentials with ID {updated.id} "
-                f"for user with ID {user_id} not found"
-            )
-        if type(current) is not type(updated):
-            raise TypeError(
-                f"Can not update credentials with ID {updated.id} "
-                f"from type {type(current)} "
-                f"to type {type(updated)}"
-            )
+        with self.locked_user_metadata(user_id):
+            current = self.get_creds_by_id(user_id, updated.id)
+            if not current:
+                raise ValueError(
+                    f"Credentials with ID {updated.id} "
+                    f"for user with ID {user_id} not found"
+                )
+            if type(current) is not type(updated):
+                raise TypeError(
+                    f"Can not update credentials with ID {updated.id} "
+                    f"from type {type(current)} "
+                    f"to type {type(updated)}"
+                )

-        # Ensure no scopes are removed when updating credentials
-        if (
-            isinstance(updated, OAuth2Credentials)
-            and isinstance(current, OAuth2Credentials)
-            and not set(updated.scopes).issuperset(current.scopes)
-        ):
-            raise ValueError(
-                f"Can not update credentials with ID {updated.id} "
-                f"and scopes {current.scopes} "
-                f"to more restrictive set of scopes {updated.scopes}"
-            )
+            # Ensure no scopes are removed when updating credentials
+            if (
+                isinstance(updated, OAuth2Credentials)
+                and isinstance(current, OAuth2Credentials)
+                and not set(updated.scopes).issuperset(current.scopes)
+            ):
+                raise ValueError(
+                    f"Can not update credentials with ID {updated.id} "
+                    f"and scopes {current.scopes} "
+                    f"to more restrictive set of scopes {updated.scopes}"
+                )

-        # Update the credentials
-        updated_credentials_list = [
-            updated if c.id == updated.id else c for c in self.get_all_creds(user_id)
-        ]
-        self._set_user_integration_creds(user_id, updated_credentials_list)
+            # Update the credentials
+            updated_credentials_list = [
+                updated if c.id == updated.id else c
+                for c in self.get_all_creds(user_id)
+            ]
+            self._set_user_integration_creds(user_id, updated_credentials_list)

     def delete_creds_by_id(self, user_id: str, credentials_id: str) -> None:
-        filtered_credentials = [
-            c for c in self.get_all_creds(user_id) if c.id != credentials_id
-        ]
-        self._set_user_integration_creds(user_id, filtered_credentials)
+        with self.locked_user_metadata(user_id):
+            filtered_credentials = [
+                c for c in self.get_all_creds(user_id) if c.id != credentials_id
+            ]
+            self._set_user_integration_creds(user_id, filtered_credentials)

     async def store_state_token(
         self, user_id: str, provider: str, scopes: list[str]
@@ -94,14 +103,15 @@ class SupabaseIntegrationCredentialsStore:
             scopes=scopes,
         )

-        user_metadata = self._get_user_metadata(user_id)
-        oauth_states = user_metadata.get("integration_oauth_states", [])
-        oauth_states.append(state.model_dump())
-        user_metadata["integration_oauth_states"] = oauth_states
+        with self.locked_user_metadata(user_id):
+            user_metadata = self._get_user_metadata(user_id)
+            oauth_states = user_metadata.get("integration_oauth_states", [])
+            oauth_states.append(state.model_dump())
+            user_metadata["integration_oauth_states"] = oauth_states

-        self.supabase.auth.admin.update_user_by_id(
-            user_id, {"user_metadata": user_metadata}
-        )
+            self.supabase.auth.admin.update_user_by_id(
+                user_id, {"user_metadata": user_metadata}
+            )

         return token
@@ -136,29 +146,30 @@ class SupabaseIntegrationCredentialsStore:
         return []

     async def verify_state_token(self, user_id: str, token: str, provider: str) -> bool:
-        user_metadata = self._get_user_metadata(user_id)
-        oauth_states = user_metadata.get("integration_oauth_states", [])
+        with self.locked_user_metadata(user_id):
+            user_metadata = self._get_user_metadata(user_id)
+            oauth_states = user_metadata.get("integration_oauth_states", [])

-        now = datetime.now(timezone.utc)
-        valid_state = next(
-            (
-                state
-                for state in oauth_states
-                if state["token"] == token
-                and state["provider"] == provider
-                and state["expires_at"] > now.timestamp()
-            ),
-            None,
-        )
-
-        if valid_state:
-            # Remove the used state
-            oauth_states.remove(valid_state)
-            user_metadata["integration_oauth_states"] = oauth_states
-            self.supabase.auth.admin.update_user_by_id(
-                user_id, {"user_metadata": user_metadata}
+            now = datetime.now(timezone.utc)
+            valid_state = next(
+                (
+                    state
+                    for state in oauth_states
+                    if state["token"] == token
+                    and state["provider"] == provider
+                    and state["expires_at"] > now.timestamp()
+                ),
+                None,
             )
-            return True

+            if valid_state:
+                # Remove the used state
+                oauth_states.remove(valid_state)
+                user_metadata["integration_oauth_states"] = oauth_states
+                self.supabase.auth.admin.update_user_by_id(
+                    user_id, {"user_metadata": user_metadata}
+                )
+                return True

         return False
@@ -178,3 +189,7 @@
         if not response.user:
             raise ValueError(f"User with ID {user_id} not found")
         return cast(UserMetadataRaw, response.user.user_metadata)
+
+    def locked_user_metadata(self, user_id: str):
+        key = (self.supabase.supabase_url, f"user:{user_id}", "metadata")
+        return self.locks.locked(key)
autogpt_libs/utils/synchronize.py  Normal file

@@ -0,0 +1,56 @@
+from contextlib import contextmanager
+from threading import Lock
+from typing import TYPE_CHECKING, Any
+
+from expiringdict import ExpiringDict
+
+if TYPE_CHECKING:
+    from redis import Redis
+    from redis.lock import Lock as RedisLock
+
+
+class RedisKeyedMutex:
+    """
+    This class provides a mutex that can be locked and unlocked by a specific key,
+    using Redis as a distributed locking provider.
+    It uses an ExpiringDict to automatically clear the mutex after a specified timeout,
+    in case the key is not unlocked for a specified duration, to prevent memory leaks.
+    """
+
+    def __init__(self, redis: "Redis", timeout: int | None = 60):
+        self.redis = redis
+        self.timeout = timeout
+        self.locks: dict[Any, "RedisLock"] = ExpiringDict(
+            max_len=6000, max_age_seconds=self.timeout
+        )
+        self.locks_lock = Lock()
+
+    @contextmanager
+    def locked(self, key: Any):
+        lock = self.acquire(key)
+        try:
+            yield
+        finally:
+            lock.release()
+
+    def acquire(self, key: Any) -> "RedisLock":
+        """Acquires and returns a lock with the given key"""
+        with self.locks_lock:
+            if key not in self.locks:
+                self.locks[key] = self.redis.lock(
+                    str(key), self.timeout, thread_local=False
+                )
+            lock = self.locks[key]
+        lock.acquire()
+        return lock
+
+    def release(self, key: Any):
+        if lock := self.locks.get(key):
+            lock.release()
+
+    def release_all_locks(self):
+        """Call this on process termination to ensure all locks are released"""
+        self.locks_lock.acquire(blocking=False)
+        for lock in self.locks.values():
+            if lock.locked() and lock.owned():
+                lock.release()
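The credentials store above serializes every metadata read-modify-write through this keyed mutex. A minimal usage sketch, assuming a Redis server on localhost:6379; the key tuple mirrors the one `locked_user_metadata` builds, and the Supabase URL and user ID are placeholders:

from redis import Redis

from autogpt_libs.utils.synchronize import RedisKeyedMutex

# Placeholder connection details; any Redis instance shared by the
# cooperating processes works.
redis_client = Redis(host="localhost", port=6379)
mutex = RedisKeyedMutex(redis_client, timeout=60)

# Same key shape as SupabaseIntegrationCredentialsStore.locked_user_metadata.
key = ("https://example.supabase.co", "user:1234", "metadata")

with mutex.locked(key):
    # Processes sharing this Redis instance and key run this critical
    # section one at a time; the Redis lock expires after `timeout`
    # seconds if the holder dies without releasing it.
    ...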
36  autogpt_platform/autogpt_libs/poetry.lock  generated

@@ -377,6 +377,20 @@ files = [
 [package.extras]
 test = ["pytest (>=6)"]

+[[package]]
+name = "expiringdict"
+version = "1.2.2"
+description = "Dictionary with auto-expiring values for caching purposes"
+optional = false
+python-versions = "*"
+files = [
+    {file = "expiringdict-1.2.2-py3-none-any.whl", hash = "sha256:09a5d20bc361163e6432a874edd3179676e935eb81b925eccef48d409a8a45e8"},
+    {file = "expiringdict-1.2.2.tar.gz", hash = "sha256:300fb92a7e98f15b05cf9a856c1415b3bc4f2e132be07daa326da6414c23ee09"},
+]
+
+[package.extras]
+tests = ["coverage", "coveralls", "dill", "mock", "nose"]
+
 [[package]]
 name = "frozenlist"
 version = "1.4.1"
@@ -1031,6 +1045,7 @@ description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs
 optional = false
 python-versions = ">=3.8"
 files = [
+    {file = "pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629"},
     {file = "pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034"},
 ]
@@ -1041,6 +1056,7 @@ description = "A collection of ASN.1-based protocols modules"
 optional = false
 python-versions = ">=3.8"
 files = [
+    {file = "pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd"},
     {file = "pyasn1_modules-0.4.1.tar.gz", hash = "sha256:c28e2dbf9c06ad61c71a075c7e0f9fd0f1b0bb2d2ad4377f240d33ac2ab60a7c"},
 ]
@@ -1253,6 +1269,24 @@ python-dateutil = ">=2.8.1,<3.0.0"
 typing-extensions = ">=4.12.2,<5.0.0"
 websockets = ">=11,<13"

+[[package]]
+name = "redis"
+version = "5.1.1"
+description = "Python client for Redis database and key-value store"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "redis-5.1.1-py3-none-any.whl", hash = "sha256:f8ea06b7482a668c6475ae202ed8d9bcaa409f6e87fb77ed1043d912afd62e24"},
+    {file = "redis-5.1.1.tar.gz", hash = "sha256:f6c997521fedbae53387307c5d0bf784d9acc28d9f1d058abeac566ec4dbed72"},
+]
+
+[package.dependencies]
+async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\""}
+
+[package.extras]
+hiredis = ["hiredis (>=3.0.0)"]
+ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==23.2.1)", "requests (>=2.31.0)"]
+
 [[package]]
 name = "requests"
 version = "2.32.3"
@@ -1690,4 +1724,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<4.0"
-content-hash = "e9b6e5d877eeb9c9f1ebc69dead1985d749facc160afbe61f3bf37e9a6e35aa5"
+content-hash = "ad9a4c8b399f6480a9f70319d13df810f92f63b532d4e10503d283f0948bed6c"
autogpt_platform/autogpt_libs/pyproject.toml

@@ -8,6 +8,7 @@ packages = [{ include = "autogpt_libs" }]

 [tool.poetry.dependencies]
 colorama = "^0.4.6"
+expiringdict = "^1.2.2"
 google-cloud-logging = "^3.8.0"
 pydantic = "^2.8.2"
 pydantic-settings = "^2.5.2"
@@ -16,6 +17,9 @@ python = ">=3.10,<4.0"
 python-dotenv = "^1.0.1"
 supabase = "^2.7.2"

+[tool.poetry.group.dev.dependencies]
+redis = "^5.0.8"
+
 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
@@ -12,7 +12,10 @@ REDIS_PORT=6379
 REDIS_PASSWORD=password

 ENABLE_CREDIT=false
-APP_ENV="local"
+# What environment things should be logged under: local dev or prod
+APP_ENV=local
+# What environment to behave as: "local" or "cloud"
+BEHAVE_AS=local
 PYRO_HOST=localhost
 SENTRY_DSN=
@@ -24,10 +24,12 @@ def main(**kwargs):
     Run all the processes required for the AutoGPT-server (REST and WebSocket APIs).
     """

-    from backend.executor import ExecutionManager, ExecutionScheduler
-    from backend.server import AgentServer, WebsocketServer
+    from backend.executor import DatabaseManager, ExecutionManager, ExecutionScheduler
+    from backend.server.rest_api import AgentServer
+    from backend.server.ws_api import WebsocketServer

     run_processes(
+        DatabaseManager(),
         ExecutionManager(),
         ExecutionScheduler(),
         WebsocketServer(),
@@ -53,15 +53,33 @@ for cls in all_subclasses(Block):
     if block.id in AVAILABLE_BLOCKS:
         raise ValueError(f"Block ID {block.name} error: {block.id} is already in use")

+    input_schema = block.input_schema.model_fields
+    output_schema = block.output_schema.model_fields
+
     # Prevent duplicate field name in input_schema and output_schema
-    duplicate_field_names = set(block.input_schema.model_fields.keys()) & set(
-        block.output_schema.model_fields.keys()
-    )
+    duplicate_field_names = set(input_schema.keys()) & set(output_schema.keys())
     if duplicate_field_names:
         raise ValueError(
             f"{block.name} has duplicate field names in input_schema and output_schema: {duplicate_field_names}"
         )

+    # Make sure `error` field is a string in the output schema
+    if "error" in output_schema and output_schema["error"].annotation is not str:
+        raise ValueError(
+            f"{block.name} `error` field in output_schema must be a string"
+        )
+
+    # Make sure all fields in input_schema and output_schema are annotated and has a value
+    for field_name, field in [*input_schema.items(), *output_schema.items()]:
+        if field.annotation is None:
+            raise ValueError(
+                f"{block.name} has a field {field_name} that is not annotated"
+            )
+        if field.json_schema_extra is None:
+            raise ValueError(
+                f"{block.name} has a field {field_name} not defined as SchemaField"
+            )
+
     for field in block.input_schema.model_fields.values():
         if field.annotation is bool and field.default not in (True, False):
            raise ValueError(f"{block.name} has a boolean field with no default value")
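For reference, a schema that passes all of the checks above would look like the following sketch; the block and field names are illustrative, not taken from this diff. Every field is declared through SchemaField (so json_schema_extra is populated), no name appears in both schemas, `error` is a plain str, and the boolean input has a default:

# Illustrative block passing the tightened loader validation, assuming
# the Block/BlockSchema/SchemaField APIs shown elsewhere in this diff.
# (Block metadata like id/description is omitted; this sketches only
# the schema rules.)
from backend.data.block import Block, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class ExampleBlock(Block):
    class Input(BlockSchema):
        text: str = SchemaField(description="Text to process")
        verbose: bool = SchemaField(description="Log extra detail", default=False)

    class Output(BlockSchema):
        result: str = SchemaField(description="Processed text")  # no overlap with Input
        error: str = SchemaField(description="Error message if processing failed")

    def run(self, input_data: Input, **kwargs) -> BlockOutput:
        yield "result", input_data.text.upper()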
@@ -1,10 +1,8 @@
 import logging
 import time
 from enum import Enum
-from typing import Optional

 import requests
-from pydantic import Field

 from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
 from backend.data.model import BlockSecret, SchemaField, SecretField
@@ -130,9 +128,13 @@ class AIShortformVideoCreatorBlock(Block):
             description="""1. Use short and punctuated sentences\n\n2. Use linebreaks to create a new clip\n\n3. Text outside of brackets is spoken by the AI, and [text between brackets] will be used to guide the visual generation. For example, [close-up of a cat] will show a close-up of a cat.""",
             placeholder="[close-up of a cat] Meow!",
         )
-        ratio: str = Field(description="Aspect ratio of the video", default="9 / 16")
-        resolution: str = Field(description="Resolution of the video", default="720p")
-        frame_rate: int = Field(description="Frame rate of the video", default=60)
+        ratio: str = SchemaField(
+            description="Aspect ratio of the video", default="9 / 16"
+        )
+        resolution: str = SchemaField(
+            description="Resolution of the video", default="720p"
+        )
+        frame_rate: int = SchemaField(description="Frame rate of the video", default=60)
         generation_preset: GenerationPreset = SchemaField(
             description="Generation preset for visual style - only effects AI generated visuals",
             default=GenerationPreset.LEONARDO,
@@ -155,8 +157,8 @@ class AIShortformVideoCreatorBlock(Block):
         )

     class Output(BlockSchema):
-        video_url: str = Field(description="The URL of the created video")
-        error: Optional[str] = Field(description="Error message if the request failed")
+        video_url: str = SchemaField(description="The URL of the created video")
+        error: str = SchemaField(description="Error message if the request failed")

     def __init__(self):
         super().__init__(
@@ -239,69 +241,58 @@ class AIShortformVideoCreatorBlock(Block):
         raise TimeoutError("Video creation timed out")

     def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        try:
-            # Create a new Webhook.site URL
-            webhook_token, webhook_url = self.create_webhook()
-            logger.debug(f"Webhook URL: {webhook_url}")
+        # Create a new Webhook.site URL
+        webhook_token, webhook_url = self.create_webhook()
+        logger.debug(f"Webhook URL: {webhook_url}")

-            audio_url = input_data.background_music.audio_url
+        audio_url = input_data.background_music.audio_url

-            payload = {
-                "frameRate": input_data.frame_rate,
-                "resolution": input_data.resolution,
-                "frameDurationMultiplier": 18,
-                "webhook": webhook_url,
-                "creationParams": {
-                    "mediaType": input_data.video_style,
-                    "captionPresetName": "Wrap 1",
-                    "selectedVoice": input_data.voice.voice_id,
-                    "hasEnhancedGeneration": True,
-                    "generationPreset": input_data.generation_preset.name,
-                    "selectedAudio": input_data.background_music,
-                    "origin": "/create",
-                    "inputText": input_data.script,
-                    "flowType": "text-to-video",
-                    "slug": "create-tiktok-video",
-                    "hasToGenerateVoice": True,
-                    "hasToTranscript": False,
-                    "hasToSearchMedia": True,
-                    "hasAvatar": False,
-                    "hasWebsiteRecorder": False,
-                    "hasTextSmallAtBottom": False,
-                    "ratio": input_data.ratio,
-                    "sourceType": "contentScraping",
-                    "selectedStoryStyle": {"value": "custom", "label": "Custom"},
-                    "hasToGenerateVideos": input_data.video_style
-                    != VisualMediaType.STOCK_VIDEOS,
-                    "audioUrl": audio_url,
-                },
-            }
+        payload = {
+            "frameRate": input_data.frame_rate,
+            "resolution": input_data.resolution,
+            "frameDurationMultiplier": 18,
+            "webhook": webhook_url,
+            "creationParams": {
+                "mediaType": input_data.video_style,
+                "captionPresetName": "Wrap 1",
+                "selectedVoice": input_data.voice.voice_id,
+                "hasEnhancedGeneration": True,
+                "generationPreset": input_data.generation_preset.name,
+                "selectedAudio": input_data.background_music,
+                "origin": "/create",
+                "inputText": input_data.script,
+                "flowType": "text-to-video",
+                "slug": "create-tiktok-video",
+                "hasToGenerateVoice": True,
+                "hasToTranscript": False,
+                "hasToSearchMedia": True,
+                "hasAvatar": False,
+                "hasWebsiteRecorder": False,
+                "hasTextSmallAtBottom": False,
+                "ratio": input_data.ratio,
+                "sourceType": "contentScraping",
+                "selectedStoryStyle": {"value": "custom", "label": "Custom"},
+                "hasToGenerateVideos": input_data.video_style
+                != VisualMediaType.STOCK_VIDEOS,
+                "audioUrl": audio_url,
+            },
+        }

-            logger.debug("Creating video...")
-            response = self.create_video(input_data.api_key.get_secret_value(), payload)
-            pid = response.get("pid")
+        logger.debug("Creating video...")
+        response = self.create_video(input_data.api_key.get_secret_value(), payload)
+        pid = response.get("pid")

-            if not pid:
-                logger.error(
-                    f"Failed to create video: No project ID returned. API Response: {response}"
-                )
-                yield "error", "Failed to create video: No project ID returned"
-            else:
-                logger.debug(
-                    f"Video created with project ID: {pid}. Waiting for completion..."
-                )
-                video_url = self.wait_for_video(
-                    input_data.api_key.get_secret_value(), pid, webhook_token
-                )
-                logger.debug(f"Video ready: {video_url}")
-                yield "video_url", video_url
-
-        except requests.RequestException as e:
-            logger.exception("Error creating video")
-            yield "error", f"Error creating video: {str(e)}"
-        except ValueError as e:
-            logger.exception("Error in video creation process")
-            yield "error", str(e)
-        except TimeoutError as e:
-            logger.exception("Video creation timed out")
-            yield "error", str(e)
+        if not pid:
+            logger.error(
+                f"Failed to create video: No project ID returned. API Response: {response}"
+            )
+            raise RuntimeError("Failed to create video: No project ID returned")
+        else:
+            logger.debug(
+                f"Video created with project ID: {pid}. Waiting for completion..."
+            )
+            video_url = self.wait_for_video(
+                input_data.api_key.get_secret_value(), pid, webhook_token
+            )
+            logger.debug(f"Video ready: {video_url}")
+            yield "video_url", video_url
@@ -2,7 +2,6 @@ import re
 from typing import Any, List

 from jinja2 import BaseLoader, Environment
-from pydantic import Field

 from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
 from backend.data.model import SchemaField
@@ -19,18 +18,18 @@ class StoreValueBlock(Block):
     """

     class Input(BlockSchema):
-        input: Any = Field(
+        input: Any = SchemaField(
             description="Trigger the block to produce the output. "
             "The value is only used when `data` is None."
         )
-        data: Any = Field(
+        data: Any = SchemaField(
             description="The constant data to be retained in the block. "
             "This value is passed as `output`.",
             default=None,
         )

     class Output(BlockSchema):
-        output: Any
+        output: Any = SchemaField(description="The stored data retained in the block.")

     def __init__(self):
         super().__init__(
@@ -56,10 +55,10 @@ class StoreValueBlock(Block):

 class PrintToConsoleBlock(Block):
     class Input(BlockSchema):
-        text: str
+        text: str = SchemaField(description="The text to print to the console.")

     class Output(BlockSchema):
-        status: str
+        status: str = SchemaField(description="The status of the print operation.")

     def __init__(self):
         super().__init__(
@@ -79,12 +78,14 @@ class PrintToConsoleBlock(Block):

 class FindInDictionaryBlock(Block):
     class Input(BlockSchema):
-        input: Any = Field(description="Dictionary to lookup from")
-        key: str | int = Field(description="Key to lookup in the dictionary")
+        input: Any = SchemaField(description="Dictionary to lookup from")
+        key: str | int = SchemaField(description="Key to lookup in the dictionary")

     class Output(BlockSchema):
-        output: Any = Field(description="Value found for the given key")
-        missing: Any = Field(description="Value of the input that missing the key")
+        output: Any = SchemaField(description="Value found for the given key")
+        missing: Any = SchemaField(
+            description="Value of the input that missing the key"
+        )

     def __init__(self):
         super().__init__(
@@ -330,20 +331,17 @@ class AddToDictionaryBlock(Block):
         )

     def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        try:
-            # If no dictionary is provided, create a new one
-            if input_data.dictionary is None:
-                updated_dict = {}
-            else:
-                # Create a copy of the input dictionary to avoid modifying the original
-                updated_dict = input_data.dictionary.copy()
+        # If no dictionary is provided, create a new one
+        if input_data.dictionary is None:
+            updated_dict = {}
+        else:
+            # Create a copy of the input dictionary to avoid modifying the original
+            updated_dict = input_data.dictionary.copy()

-            # Add the new key-value pair
-            updated_dict[input_data.key] = input_data.value
+        # Add the new key-value pair
+        updated_dict[input_data.key] = input_data.value

-            yield "updated_dictionary", updated_dict
-        except Exception as e:
-            yield "error", f"Failed to add entry to dictionary: {str(e)}"
+        yield "updated_dictionary", updated_dict


 class AddToListBlock(Block):
@@ -401,23 +399,20 @@ class AddToListBlock(Block):
         )

     def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        try:
-            # If no list is provided, create a new one
-            if input_data.list is None:
-                updated_list = []
-            else:
-                # Create a copy of the input list to avoid modifying the original
-                updated_list = input_data.list.copy()
+        # If no list is provided, create a new one
+        if input_data.list is None:
+            updated_list = []
+        else:
+            # Create a copy of the input list to avoid modifying the original
+            updated_list = input_data.list.copy()

-            # Add the new entry
-            if input_data.position is None:
-                updated_list.append(input_data.entry)
-            else:
-                updated_list.insert(input_data.position, input_data.entry)
+        # Add the new entry
+        if input_data.position is None:
+            updated_list.append(input_data.entry)
+        else:
+            updated_list.insert(input_data.position, input_data.entry)

-            yield "updated_list", updated_list
-        except Exception as e:
-            yield "error", f"Failed to add entry to list: {str(e)}"
+        yield "updated_list", updated_list


 class NoteBlock(Block):
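A recurring pattern across these block changes is dropping per-block try/except and letting exceptions propagate, presumably so failures can be handled in one place rather than in every run method. A sketch of what that caller-side handler could look like, under that assumption (run_block is a hypothetical helper, not part of this diff):

# Sketch of centralized error handling once blocks raise instead of
# yielding "error" themselves. The run_block helper is illustrative.
def run_block(block, input_data):
    try:
        yield from block.run(input_data)
    except Exception as e:
        # A single choke point turns any block failure into the standard
        # "error" output, matching the str-typed `error` field the loader
        # now enforces on output schemas.
        yield "error", str(e)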
@@ -3,6 +3,7 @@ import re
 from typing import Type

 from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.model import SchemaField


 class BlockInstallationBlock(Block):
@@ -15,11 +16,17 @@ class BlockInstallationBlock(Block):
     """

     class Input(BlockSchema):
-        code: str
+        code: str = SchemaField(
+            description="Python code of the block to be installed",
+        )

     class Output(BlockSchema):
-        success: str
-        error: str
+        success: str = SchemaField(
+            description="Success message if the block is installed successfully",
+        )
+        error: str = SchemaField(
+            description="Error message if the block installation fails",
+        )

     def __init__(self):
         super().__init__(
@@ -37,14 +44,12 @@ class BlockInstallationBlock(Block):
         if search := re.search(r"class (\w+)\(Block\):", code):
             class_name = search.group(1)
         else:
-            yield "error", "No class found in the code."
-            return
+            raise RuntimeError("No class found in the code.")

         if search := re.search(r"id=\"(\w+-\w+-\w+-\w+-\w+)\"", code):
             file_name = search.group(1)
         else:
-            yield "error", "No UUID found in the code."
-            return
+            raise RuntimeError("No UUID found in the code.")

         block_dir = os.path.dirname(__file__)
         file_path = f"{block_dir}/{file_name}.py"
@@ -63,4 +68,4 @@ class BlockInstallationBlock(Block):
             yield "success", "Block installed successfully."
         except Exception as e:
             os.remove(file_path)
-            yield "error", f"[Code]\n{code}\n\n[Error]\n{str(e)}"
+            raise RuntimeError(f"[Code]\n{code}\n\n[Error]\n{str(e)}")
@@ -1,21 +1,49 @@
 from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
-from backend.data.model import ContributorDetails
+from backend.data.model import ContributorDetails, SchemaField


 class ReadCsvBlock(Block):
     class Input(BlockSchema):
-        contents: str
-        delimiter: str = ","
-        quotechar: str = '"'
-        escapechar: str = "\\"
-        has_header: bool = True
-        skip_rows: int = 0
-        strip: bool = True
-        skip_columns: list[str] = []
+        contents: str = SchemaField(
+            description="The contents of the CSV file to read",
+            placeholder="a, b, c\n1,2,3\n4,5,6",
+        )
+        delimiter: str = SchemaField(
+            description="The delimiter used in the CSV file",
+            default=",",
+        )
+        quotechar: str = SchemaField(
+            description="The character used to quote fields",
+            default='"',
+        )
+        escapechar: str = SchemaField(
+            description="The character used to escape the delimiter",
+            default="\\",
+        )
+        has_header: bool = SchemaField(
+            description="Whether the CSV file has a header row",
+            default=True,
+        )
+        skip_rows: int = SchemaField(
+            description="The number of rows to skip from the start of the file",
+            default=0,
+        )
+        strip: bool = SchemaField(
+            description="Whether to strip whitespace from the values",
+            default=True,
+        )
+        skip_columns: list[str] = SchemaField(
+            description="The columns to skip from the start of the row",
+            default=[],
+        )

     class Output(BlockSchema):
-        row: dict[str, str]
-        all_data: list[dict[str, str]]
+        row: dict[str, str] = SchemaField(
+            description="The data produced from each row in the CSV file"
+        )
+        all_data: list[dict[str, str]] = SchemaField(
+            description="All the data in the CSV file as a list of rows"
+        )

     def __init__(self):
         super().__init__(
@@ -35,8 +35,5 @@ This is a "quoted" string.""",
         )

     def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        try:
-            decoded_text = codecs.decode(input_data.text, "unicode_escape")
-            yield "decoded_text", decoded_text
-        except Exception as e:
-            yield "error", f"Error decoding text: {str(e)}"
+        decoded_text = codecs.decode(input_data.text, "unicode_escape")
+        yield "decoded_text", decoded_text
@@ -2,10 +2,9 @@ import asyncio

 import aiohttp
 import discord
-from pydantic import Field

 from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
-from backend.data.model import BlockSecret, SecretField
+from backend.data.model import BlockSecret, SchemaField, SecretField


 class ReadDiscordMessagesBlock(Block):
@@ -13,16 +12,18 @@ class ReadDiscordMessagesBlock(Block):
         discord_bot_token: BlockSecret = SecretField(
             key="discord_bot_token", description="Discord bot token"
         )
-        continuous_read: bool = Field(
+        continuous_read: bool = SchemaField(
             description="Whether to continuously read messages", default=True
         )

     class Output(BlockSchema):
-        message_content: str = Field(description="The content of the message received")
-        channel_name: str = Field(
+        message_content: str = SchemaField(
+            description="The content of the message received"
+        )
+        channel_name: str = SchemaField(
             description="The name of the channel the message was received from"
         )
-        username: str = Field(
+        username: str = SchemaField(
             description="The username of the user who sent the message"
         )
@@ -134,13 +135,15 @@ class SendDiscordMessageBlock(Block):
         discord_bot_token: BlockSecret = SecretField(
             key="discord_bot_token", description="Discord bot token"
         )
-        message_content: str = Field(description="The content of the message received")
-        channel_name: str = Field(
+        message_content: str = SchemaField(
+            description="The content of the message received"
+        )
+        channel_name: str = SchemaField(
             description="The name of the channel the message was received from"
         )

     class Output(BlockSchema):
-        status: str = Field(
+        status: str = SchemaField(
             description="The status of the operation (e.g., 'Message sent', 'Error')"
         )
@@ -2,17 +2,17 @@ import smtplib
 from email.mime.multipart import MIMEMultipart
 from email.mime.text import MIMEText

-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict

 from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
 from backend.data.model import BlockSecret, SchemaField, SecretField


 class EmailCredentials(BaseModel):
-    smtp_server: str = Field(
+    smtp_server: str = SchemaField(
         default="smtp.gmail.com", description="SMTP server address"
     )
-    smtp_port: int = Field(default=25, description="SMTP port number")
+    smtp_port: int = SchemaField(default=25, description="SMTP port number")
     smtp_username: BlockSecret = SecretField(key="smtp_username")
     smtp_password: BlockSecret = SecretField(key="smtp_password")
@@ -30,7 +30,7 @@ class SendEmailBlock(Block):
         body: str = SchemaField(
             description="Body of the email", placeholder="Enter the email body"
         )
-        creds: EmailCredentials = Field(
+        creds: EmailCredentials = SchemaField(
             description="SMTP credentials",
             default=EmailCredentials(),
         )
@@ -67,35 +67,28 @@ class SendEmailBlock(Block):
     def send_email(
         creds: EmailCredentials, to_email: str, subject: str, body: str
     ) -> str:
-        try:
-            smtp_server = creds.smtp_server
-            smtp_port = creds.smtp_port
-            smtp_username = creds.smtp_username.get_secret_value()
-            smtp_password = creds.smtp_password.get_secret_value()
+        smtp_server = creds.smtp_server
+        smtp_port = creds.smtp_port
+        smtp_username = creds.smtp_username.get_secret_value()
+        smtp_password = creds.smtp_password.get_secret_value()

-            msg = MIMEMultipart()
-            msg["From"] = smtp_username
-            msg["To"] = to_email
-            msg["Subject"] = subject
-            msg.attach(MIMEText(body, "plain"))
+        msg = MIMEMultipart()
+        msg["From"] = smtp_username
+        msg["To"] = to_email
+        msg["Subject"] = subject
+        msg.attach(MIMEText(body, "plain"))

-            with smtplib.SMTP(smtp_server, smtp_port) as server:
-                server.starttls()
-                server.login(smtp_username, smtp_password)
-                server.sendmail(smtp_username, to_email, msg.as_string())
+        with smtplib.SMTP(smtp_server, smtp_port) as server:
+            server.starttls()
+            server.login(smtp_username, smtp_password)
+            server.sendmail(smtp_username, to_email, msg.as_string())

-            return "Email sent successfully"
-        except Exception as e:
-            return f"Failed to send email: {str(e)}"
+        return "Email sent successfully"

     def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        status = self.send_email(
+        yield "status", self.send_email(
             input_data.creds,
             input_data.to_email,
             input_data.subject,
             input_data.body,
         )
-        if "successfully" in status:
-            yield "status", status
-        else:
-            yield "error", status
@@ -13,6 +13,7 @@ from ._auth import (
 )


+# --8<-- [start:GithubCommentBlockExample]
 class GithubCommentBlock(Block):
     class Input(BlockSchema):
         credentials: GithubCredentialsInput = GithubCredentialsField("repo")

@@ -92,16 +93,16 @@ class GithubCommentBlock(Block):
         credentials: GithubCredentials,
         **kwargs,
     ) -> BlockOutput:
-        try:
-            id, url = self.post_comment(
-                credentials,
-                input_data.issue_url,
-                input_data.comment,
-            )
-            yield "id", id
-            yield "url", url
-        except Exception as e:
-            yield "error", f"Failed to post comment: {str(e)}"
+        id, url = self.post_comment(
+            credentials,
+            input_data.issue_url,
+            input_data.comment,
+        )
+        yield "id", id
+        yield "url", url


+# --8<-- [end:GithubCommentBlockExample]


 class GithubMakeIssueBlock(Block):

@@ -175,17 +176,14 @@ class GithubMakeIssueBlock(Block):
         credentials: GithubCredentials,
         **kwargs,
     ) -> BlockOutput:
-        try:
-            number, url = self.create_issue(
-                credentials,
-                input_data.repo_url,
-                input_data.title,
-                input_data.body,
-            )
-            yield "number", number
-            yield "url", url
-        except Exception as e:
-            yield "error", f"Failed to create issue: {str(e)}"
+        number, url = self.create_issue(
+            credentials,
+            input_data.repo_url,
+            input_data.title,
+            input_data.body,
+        )
+        yield "number", number
+        yield "url", url


 class GithubReadIssueBlock(Block):

@@ -258,16 +256,13 @@ class GithubReadIssueBlock(Block):
         credentials: GithubCredentials,
         **kwargs,
     ) -> BlockOutput:
-        try:
-            title, body, user = self.read_issue(
-                credentials,
-                input_data.issue_url,
-            )
-            yield "title", title
-            yield "body", body
-            yield "user", user
-        except Exception as e:
-            yield "error", f"Failed to read issue: {str(e)}"
+        title, body, user = self.read_issue(
+            credentials,
+            input_data.issue_url,
+        )
+        yield "title", title
+        yield "body", body
+        yield "user", user


 class GithubListIssuesBlock(Block):

@@ -346,14 +341,11 @@ class GithubListIssuesBlock(Block):
         credentials: GithubCredentials,
         **kwargs,
     ) -> BlockOutput:
-        try:
-            issues = self.list_issues(
-                credentials,
-                input_data.repo_url,
-            )
-            yield from (("issue", issue) for issue in issues)
-        except Exception as e:
-            yield "error", f"Failed to list issues: {str(e)}"
+        issues = self.list_issues(
+            credentials,
+            input_data.repo_url,
+        )
+        yield from (("issue", issue) for issue in issues)


 class GithubAddLabelBlock(Block):
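
The list-style blocks above emit one output tuple per item through a generator expression. A small self-contained sketch of the same pattern (names invented for illustration):

    from typing import Any, Generator, Tuple

    def list_numbers() -> Generator[Tuple[str, Any], None, None]:
        numbers = [1, 2, 3]
        # Equivalent to a for-loop yielding ("number", n) for each item;
        # each tuple becomes one emission on the block's "number" output pin.
        yield from (("number", n) for n in numbers)

    print(list(list_numbers()))  # [('number', 1), ('number', 2), ('number', 3)]
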
@@ -424,15 +416,12 @@ class GithubAddLabelBlock(Block):
         credentials: GithubCredentials,
         **kwargs,
     ) -> BlockOutput:
-        try:
-            status = self.add_label(
-                credentials,
-                input_data.issue_url,
-                input_data.label,
-            )
-            yield "status", status
-        except Exception as e:
-            yield "error", f"Failed to add label: {str(e)}"
+        status = self.add_label(
+            credentials,
+            input_data.issue_url,
+            input_data.label,
+        )
+        yield "status", status


 class GithubRemoveLabelBlock(Block):

@@ -508,15 +497,12 @@ class GithubRemoveLabelBlock(Block):
         credentials: GithubCredentials,
         **kwargs,
     ) -> BlockOutput:
-        try:
-            status = self.remove_label(
-                credentials,
-                input_data.issue_url,
-                input_data.label,
-            )
-            yield "status", status
-        except Exception as e:
-            yield "error", f"Failed to remove label: {str(e)}"
+        status = self.remove_label(
+            credentials,
+            input_data.issue_url,
+            input_data.label,
+        )
+        yield "status", status


 class GithubAssignIssueBlock(Block):

@@ -590,15 +576,12 @@ class GithubAssignIssueBlock(Block):
         credentials: GithubCredentials,
         **kwargs,
     ) -> BlockOutput:
-        try:
-            status = self.assign_issue(
-                credentials,
-                input_data.issue_url,
-                input_data.assignee,
-            )
-            yield "status", status
-        except Exception as e:
-            yield "error", f"Failed to assign issue: {str(e)}"
+        status = self.assign_issue(
+            credentials,
+            input_data.issue_url,
+            input_data.assignee,
+        )
+        yield "status", status


 class GithubUnassignIssueBlock(Block):

@@ -672,12 +655,9 @@ class GithubUnassignIssueBlock(Block):
         credentials: GithubCredentials,
         **kwargs,
     ) -> BlockOutput:
-        try:
-            status = self.unassign_issue(
-                credentials,
-                input_data.issue_url,
-                input_data.assignee,
-            )
-            yield "status", status
-        except Exception as e:
-            yield "error", f"Failed to unassign issue: {str(e)}"
+        status = self.unassign_issue(
+            credentials,
+            input_data.issue_url,
+            input_data.assignee,
+        )
+        yield "status", status
@@ -87,14 +87,11 @@ class GithubListPullRequestsBlock(Block):
         credentials: GithubCredentials,
         **kwargs,
     ) -> BlockOutput:
-        try:
-            pull_requests = self.list_prs(
-                credentials,
-                input_data.repo_url,
-            )
-            yield from (("pull_request", pr) for pr in pull_requests)
-        except Exception as e:
-            yield "error", f"Failed to list pull requests: {str(e)}"
+        pull_requests = self.list_prs(
+            credentials,
+            input_data.repo_url,
+        )
+        yield from (("pull_request", pr) for pr in pull_requests)


 class GithubMakePullRequestBlock(Block):

@@ -203,9 +200,7 @@ class GithubMakePullRequestBlock(Block):
                 error_message = error_details.get("message", "Unknown error")
             else:
                 error_message = str(http_err)
-            yield "error", f"Failed to create pull request: {error_message}"
-        except Exception as e:
-            yield "error", f"Failed to create pull request: {str(e)}"
+            raise RuntimeError(f"Failed to create pull request: {error_message}")


 class GithubReadPullRequestBlock(Block):

@@ -313,23 +308,20 @@ class GithubReadPullRequestBlock(Block):
         credentials: GithubCredentials,
         **kwargs,
     ) -> BlockOutput:
-        try:
-            title, body, author = self.read_pr(
-                credentials,
-                input_data.pr_url,
-            )
-            yield "title", title
-            yield "body", body
-            yield "author", author
-
-            if input_data.include_pr_changes:
-                changes = self.read_pr_changes(
-                    credentials,
-                    input_data.pr_url,
-                )
-                yield "changes", changes
-        except Exception as e:
-            yield "error", f"Failed to read pull request: {str(e)}"
+        title, body, author = self.read_pr(
+            credentials,
+            input_data.pr_url,
+        )
+        yield "title", title
+        yield "body", body
+        yield "author", author
+
+        if input_data.include_pr_changes:
+            changes = self.read_pr_changes(
+                credentials,
+                input_data.pr_url,
+            )
+            yield "changes", changes


 class GithubAssignPRReviewerBlock(Block):

@@ -418,9 +410,7 @@ class GithubAssignPRReviewerBlock(Block):
                 )
             else:
                 error_msg = f"HTTP error: {http_err} - {http_err.response.text}"
-            yield "error", error_msg
-        except Exception as e:
-            yield "error", f"Failed to assign reviewer: {str(e)}"
+            raise RuntimeError(error_msg)


 class GithubUnassignPRReviewerBlock(Block):

@@ -490,15 +480,12 @@ class GithubUnassignPRReviewerBlock(Block):
         credentials: GithubCredentials,
         **kwargs,
     ) -> BlockOutput:
-        try:
-            status = self.unassign_reviewer(
-                credentials,
-                input_data.pr_url,
-                input_data.reviewer,
-            )
-            yield "status", status
-        except Exception as e:
-            yield "error", f"Failed to unassign reviewer: {str(e)}"
+        status = self.unassign_reviewer(
+            credentials,
+            input_data.pr_url,
+            input_data.reviewer,
+        )
+        yield "status", status


 class GithubListPRReviewersBlock(Block):

@@ -586,11 +573,8 @@ class GithubListPRReviewersBlock(Block):
         credentials: GithubCredentials,
         **kwargs,
     ) -> BlockOutput:
-        try:
-            reviewers = self.list_reviewers(
-                credentials,
-                input_data.pr_url,
-            )
-            yield from (("reviewer", reviewer) for reviewer in reviewers)
-        except Exception as e:
-            yield "error", f"Failed to list reviewers: {str(e)}"
+        reviewers = self.list_reviewers(
+            credentials,
+            input_data.pr_url,
+        )
+        yield from (("reviewer", reviewer) for reviewer in reviewers)
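
The pull-request hunks keep the HTTP-specific error parsing but end it with a raise instead of an error yield. A sketch of that parsing step using plain requests — the JSON shape mirrors GitHub's usual error responses, but treat the details as an assumption rather than a spec:

    import requests

    def message_from_http_error(http_err: requests.exceptions.HTTPError) -> str:
        # GitHub error bodies are usually JSON with a "message" field;
        # fall back to the exception text when the body is not JSON.
        try:
            error_details = http_err.response.json()
            return error_details.get("message", "Unknown error")
        except ValueError:
            return str(http_err)

    # Usage:
    #   raise RuntimeError(f"Failed to create pull request: {message_from_http_error(err)}")
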
@@ -96,14 +96,11 @@ class GithubListTagsBlock(Block):
         credentials: GithubCredentials,
         **kwargs,
     ) -> BlockOutput:
-        try:
-            tags = self.list_tags(
-                credentials,
-                input_data.repo_url,
-            )
-            yield from (("tag", tag) for tag in tags)
-        except Exception as e:
-            yield "error", f"Failed to list tags: {str(e)}"
+        tags = self.list_tags(
+            credentials,
+            input_data.repo_url,
+        )
+        yield from (("tag", tag) for tag in tags)


 class GithubListBranchesBlock(Block):

@@ -183,14 +180,11 @@ class GithubListBranchesBlock(Block):
         credentials: GithubCredentials,
         **kwargs,
     ) -> BlockOutput:
-        try:
-            branches = self.list_branches(
-                credentials,
-                input_data.repo_url,
-            )
-            yield from (("branch", branch) for branch in branches)
-        except Exception as e:
-            yield "error", f"Failed to list branches: {str(e)}"
+        branches = self.list_branches(
+            credentials,
+            input_data.repo_url,
+        )
+        yield from (("branch", branch) for branch in branches)


 class GithubListDiscussionsBlock(Block):

@@ -294,13 +288,10 @@ class GithubListDiscussionsBlock(Block):
         credentials: GithubCredentials,
         **kwargs,
     ) -> BlockOutput:
-        try:
-            discussions = self.list_discussions(
-                credentials, input_data.repo_url, input_data.num_discussions
-            )
-            yield from (("discussion", discussion) for discussion in discussions)
-        except Exception as e:
-            yield "error", f"Failed to list discussions: {str(e)}"
+        discussions = self.list_discussions(
+            credentials, input_data.repo_url, input_data.num_discussions
+        )
+        yield from (("discussion", discussion) for discussion in discussions)


 class GithubListReleasesBlock(Block):

@@ -381,14 +372,11 @@ class GithubListReleasesBlock(Block):
         credentials: GithubCredentials,
         **kwargs,
     ) -> BlockOutput:
-        try:
-            releases = self.list_releases(
-                credentials,
-                input_data.repo_url,
-            )
-            yield from (("release", release) for release in releases)
-        except Exception as e:
-            yield "error", f"Failed to list releases: {str(e)}"
+        releases = self.list_releases(
+            credentials,
+            input_data.repo_url,
+        )
+        yield from (("release", release) for release in releases)


 class GithubReadFileBlock(Block):

@@ -474,18 +462,15 @@ class GithubReadFileBlock(Block):
         credentials: GithubCredentials,
         **kwargs,
     ) -> BlockOutput:
-        try:
-            raw_content, size = self.read_file(
-                credentials,
-                input_data.repo_url,
-                input_data.file_path.lstrip("/"),
-                input_data.branch,
-            )
-            yield "raw_content", raw_content
-            yield "text_content", base64.b64decode(raw_content).decode("utf-8")
-            yield "size", size
-        except Exception as e:
-            yield "error", f"Failed to read file: {str(e)}"
+        raw_content, size = self.read_file(
+            credentials,
+            input_data.repo_url,
+            input_data.file_path.lstrip("/"),
+            input_data.branch,
+        )
+        yield "raw_content", raw_content
+        yield "text_content", base64.b64decode(raw_content).decode("utf-8")
+        yield "size", size


 class GithubReadFolderBlock(Block):
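
GithubReadFileBlock decodes raw_content with base64 because the GitHub contents API returns file bodies base64-encoded. The decode step in isolation:

    import base64

    raw_content = "aGVsbG8gd29ybGQ="  # example payload, as the API would return it
    text_content = base64.b64decode(raw_content).decode("utf-8")
    print(text_content)  # hello world
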
@@ -612,17 +597,14 @@ class GithubReadFolderBlock(Block):
         credentials: GithubCredentials,
         **kwargs,
     ) -> BlockOutput:
-        try:
-            files, dirs = self.read_folder(
-                credentials,
-                input_data.repo_url,
-                input_data.folder_path.lstrip("/"),
-                input_data.branch,
-            )
-            yield from (("file", file) for file in files)
-            yield from (("dir", dir) for dir in dirs)
-        except Exception as e:
-            yield "error", f"Failed to read folder: {str(e)}"
+        files, dirs = self.read_folder(
+            credentials,
+            input_data.repo_url,
+            input_data.folder_path.lstrip("/"),
+            input_data.branch,
+        )
+        yield from (("file", file) for file in files)
+        yield from (("dir", dir) for dir in dirs)


 class GithubMakeBranchBlock(Block):

@@ -703,16 +685,13 @@ class GithubMakeBranchBlock(Block):
         credentials: GithubCredentials,
         **kwargs,
     ) -> BlockOutput:
-        try:
-            status = self.create_branch(
-                credentials,
-                input_data.repo_url,
-                input_data.new_branch,
-                input_data.source_branch,
-            )
-            yield "status", status
-        except Exception as e:
-            yield "error", f"Failed to create branch: {str(e)}"
+        status = self.create_branch(
+            credentials,
+            input_data.repo_url,
+            input_data.new_branch,
+            input_data.source_branch,
+        )
+        yield "status", status


 class GithubDeleteBranchBlock(Block):

@@ -775,12 +754,9 @@ class GithubDeleteBranchBlock(Block):
         credentials: GithubCredentials,
         **kwargs,
     ) -> BlockOutput:
-        try:
-            status = self.delete_branch(
-                credentials,
-                input_data.repo_url,
-                input_data.branch,
-            )
-            yield "status", status
-        except Exception as e:
-            yield "error", f"Failed to delete branch: {str(e)}"
+        status = self.delete_branch(
+            credentials,
+            input_data.repo_url,
+            input_data.branch,
+        )
+        yield "status", status
@@ -6,11 +6,12 @@ from pydantic import SecretStr
 from backend.data.model import CredentialsField, CredentialsMetaInput
 from backend.util.settings import Secrets

+# --8<-- [start:GoogleOAuthIsConfigured]
 secrets = Secrets()
 GOOGLE_OAUTH_IS_CONFIGURED = bool(
     secrets.google_client_id and secrets.google_client_secret
 )
-
+# --8<-- [end:GoogleOAuthIsConfigured]
 GoogleCredentials = OAuth2Credentials
 GoogleCredentialsInput = CredentialsMetaInput[Literal["google"], Literal["oauth2"]]
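
The # --8<-- [start:...] / # --8<-- [end:...] comment pairs added here and around GithubCommentBlock are snippet section markers (the "scissors" syntax of the pymdown-extensions snippets plugin), which let the documentation embed exactly this region of the source file instead of carrying a hand-maintained copy that can drift out of date.
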
@@ -104,16 +104,11 @@ class GmailReadBlock(Block):
     def run(
         self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
     ) -> BlockOutput:
-        try:
-            service = self._build_service(credentials, **kwargs)
-            messages = self._read_emails(
-                service, input_data.query, input_data.max_results
-            )
-            for email in messages:
-                yield "email", email
-            yield "emails", messages
-        except Exception as e:
-            yield "error", str(e)
+        service = self._build_service(credentials, **kwargs)
+        messages = self._read_emails(service, input_data.query, input_data.max_results)
+        for email in messages:
+            yield "email", email
+        yield "emails", messages

     @staticmethod
     def _build_service(credentials: GoogleCredentials, **kwargs):

@@ -267,14 +262,11 @@ class GmailSendBlock(Block):
     def run(
         self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
    ) -> BlockOutput:
-        try:
-            service = GmailReadBlock._build_service(credentials, **kwargs)
-            send_result = self._send_email(
-                service, input_data.to, input_data.subject, input_data.body
-            )
-            yield "result", send_result
-        except Exception as e:
-            yield "error", str(e)
+        service = GmailReadBlock._build_service(credentials, **kwargs)
+        send_result = self._send_email(
+            service, input_data.to, input_data.subject, input_data.body
+        )
+        yield "result", send_result

     def _send_email(self, service, to: str, subject: str, body: str) -> dict:
         if not to or not subject or not body:

@@ -342,12 +334,9 @@ class GmailListLabelsBlock(Block):
     def run(
         self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
     ) -> BlockOutput:
-        try:
-            service = GmailReadBlock._build_service(credentials, **kwargs)
-            labels = self._list_labels(service)
-            yield "result", labels
-        except Exception as e:
-            yield "error", str(e)
+        service = GmailReadBlock._build_service(credentials, **kwargs)
+        labels = self._list_labels(service)
+        yield "result", labels

     def _list_labels(self, service) -> list[dict]:
         results = service.users().labels().list(userId="me").execute()

@@ -406,14 +395,9 @@ class GmailAddLabelBlock(Block):
     def run(
         self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
     ) -> BlockOutput:
-        try:
-            service = GmailReadBlock._build_service(credentials, **kwargs)
-            result = self._add_label(
-                service, input_data.message_id, input_data.label_name
-            )
-            yield "result", result
-        except Exception as e:
-            yield "error", str(e)
+        service = GmailReadBlock._build_service(credentials, **kwargs)
+        result = self._add_label(service, input_data.message_id, input_data.label_name)
+        yield "result", result

     def _add_label(self, service, message_id: str, label_name: str) -> dict:
         label_id = self._get_or_create_label(service, label_name)

@@ -494,14 +478,11 @@ class GmailRemoveLabelBlock(Block):
     def run(
         self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
     ) -> BlockOutput:
-        try:
-            service = GmailReadBlock._build_service(credentials, **kwargs)
-            result = self._remove_label(
-                service, input_data.message_id, input_data.label_name
-            )
-            yield "result", result
-        except Exception as e:
-            yield "error", str(e)
+        service = GmailReadBlock._build_service(credentials, **kwargs)
+        result = self._remove_label(
+            service, input_data.message_id, input_data.label_name
+        )
+        yield "result", result

     def _remove_label(self, service, message_id: str, label_name: str) -> dict:
         label_id = self._get_label_id(service, label_name)
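
All of the Gmail blocks route through GmailReadBlock._build_service. As a rough sketch of what such a helper typically does with google-api-python-client (an assumed dependency; the credential plumbing here is simplified and not the block's actual implementation):

    from google.oauth2.credentials import Credentials
    from googleapiclient.discovery import build

    def build_gmail_service(access_token: str):
        # Wrap a bearer token in google-auth credentials and construct the
        # Gmail v1 client; real code would also carry refresh tokens and scopes.
        creds = Credentials(token=access_token)
        return build("gmail", "v1", credentials=creds)

    # service.users().labels().list(userId="me").execute() then works exactly
    # as in GmailListLabelsBlock._list_labels above.
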
@@ -68,14 +68,9 @@ class GoogleSheetsReadBlock(Block):
     def run(
         self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
     ) -> BlockOutput:
-        try:
-            service = self._build_service(credentials, **kwargs)
-            data = self._read_sheet(
-                service, input_data.spreadsheet_id, input_data.range
-            )
-            yield "result", data
-        except Exception as e:
-            yield "error", str(e)
+        service = self._build_service(credentials, **kwargs)
+        data = self._read_sheet(service, input_data.spreadsheet_id, input_data.range)
+        yield "result", data

     @staticmethod
     def _build_service(credentials: GoogleCredentials, **kwargs):

@@ -162,17 +157,14 @@ class GoogleSheetsWriteBlock(Block):
     def run(
         self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
     ) -> BlockOutput:
-        try:
-            service = GoogleSheetsReadBlock._build_service(credentials, **kwargs)
-            result = self._write_sheet(
-                service,
-                input_data.spreadsheet_id,
-                input_data.range,
-                input_data.values,
-            )
-            yield "result", result
-        except Exception as e:
-            yield "error", str(e)
+        service = GoogleSheetsReadBlock._build_service(credentials, **kwargs)
+        result = self._write_sheet(
+            service,
+            input_data.spreadsheet_id,
+            input_data.range,
+            input_data.values,
+        )
+        yield "result", result

     def _write_sheet(
         self, service, spreadsheet_id: str, range: str, values: list[list[str]]

@@ -82,17 +82,14 @@ class GoogleMapsSearchBlock(Block):
         )

     def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        try:
-            places = self.search_places(
-                input_data.api_key.get_secret_value(),
-                input_data.query,
-                input_data.radius,
-                input_data.max_results,
-            )
-            for place in places:
-                yield "place", place
-        except Exception as e:
-            yield "error", str(e)
+        places = self.search_places(
+            input_data.api_key.get_secret_value(),
+            input_data.query,
+            input_data.radius,
+            input_data.max_results,
+        )
+        for place in places:
+            yield "place", place

     def search_places(self, api_key, query, radius, max_results):
         client = googlemaps.Client(key=api_key)
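
GoogleMapsSearchBlock delegates to the googlemaps client. A rough usage sketch of the underlying library call — the exact keyword set used by search_places is an assumption:

    import googlemaps

    def search_places(api_key: str, query: str, radius: int, max_results: int):
        client = googlemaps.Client(key=api_key)
        # Text-search the Places API and trim to the requested result count.
        response = client.places(query=query, radius=radius)
        return response.get("results", [])[:max_results]
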
@@ -4,6 +4,7 @@ from enum import Enum
 import requests

 from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.model import SchemaField


 class HttpMethod(Enum):

@@ -18,15 +19,27 @@ class HttpMethod(Enum):

 class SendWebRequestBlock(Block):
     class Input(BlockSchema):
-        url: str
-        method: HttpMethod = HttpMethod.POST
-        headers: dict[str, str] = {}
-        body: object = {}
+        url: str = SchemaField(
+            description="The URL to send the request to",
+            placeholder="https://api.example.com",
+        )
+        method: HttpMethod = SchemaField(
+            description="The HTTP method to use for the request",
+            default=HttpMethod.POST,
+        )
+        headers: dict[str, str] = SchemaField(
+            description="The headers to include in the request",
+            default={},
+        )
+        body: object = SchemaField(
+            description="The body of the request",
+            default={},
+        )

     class Output(BlockSchema):
-        response: object
-        client_error: object
-        server_error: object
+        response: object = SchemaField(description="The response from the server")
+        client_error: object = SchemaField(description="The error on 4xx status codes")
+        server_error: object = SchemaField(description="The error on 5xx status codes")

     def __init__(self):
         super().__init__(
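
This hunk replaces bare pydantic defaults with SchemaField so each input carries a description, default, and placeholder for the builder UI. SchemaField is AutoGPT's own helper; as an approximation, it behaves like pydantic's Field with extra UI metadata. The sketch below is a guess at the shape, not the real implementation:

    from pydantic import BaseModel, Field

    def schema_field(description: str, default=..., placeholder: str | None = None):
        # Approximation: forward UI hints through json_schema_extra so they
        # end up in the generated JSON schema that the frontend renders.
        extra = {"placeholder": placeholder} if placeholder else {}
        return Field(default, description=description, json_schema_extra=extra)

    class ExampleInput(BaseModel):
        url: str = schema_field(
            "The URL to send the request to", placeholder="https://api.example.com"
        )
        headers: dict[str, str] = schema_field(
            "The headers to include in the request", default={}
        )
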
@@ -75,28 +75,24 @@ class IdeogramModelBlock(Block):
             description="The name of the Image Generation Model, e.g., V_2",
             default=IdeogramModelName.V2,
             title="Image Generation Model",
-            enum=IdeogramModelName,
             advanced=False,
         )
         aspect_ratio: AspectRatio = SchemaField(
             description="Aspect ratio for the generated image",
             default=AspectRatio.ASPECT_1_1,
             title="Aspect Ratio",
-            enum=AspectRatio,
             advanced=False,
         )
         upscale: UpscaleOption = SchemaField(
             description="Upscale the generated image",
             default=UpscaleOption.NO_UPSCALE,
             title="Upscale Image",
-            enum=UpscaleOption,
             advanced=False,
         )
         magic_prompt_option: MagicPromptOption = SchemaField(
             description="Whether to use MagicPrompt for enhancing the request",
             default=MagicPromptOption.AUTO,
             title="Magic Prompt Option",
-            enum=MagicPromptOption,
             advanced=True,
         )
         seed: Optional[int] = SchemaField(

@@ -109,7 +105,6 @@ class IdeogramModelBlock(Block):
             description="Style type to apply, applicable for V_2 and above",
             default=StyleType.AUTO,
             title="Style Type",
-            enum=StyleType,
             advanced=True,
         )
         negative_prompt: Optional[str] = SchemaField(

@@ -122,15 +117,12 @@ class IdeogramModelBlock(Block):
             description="Color palette preset name, choose 'None' to skip",
             default=ColorPalettePreset.NONE,
             title="Color Palette Preset",
-            enum=ColorPalettePreset,
             advanced=True,
         )

     class Output(BlockSchema):
         result: str = SchemaField(description="Generated image URL")
-        error: Optional[str] = SchemaField(
-            description="Error message if the model run failed"
-        )
+        error: str = SchemaField(description="Error message if the model run failed")

     def __init__(self):
         super().__init__(

@@ -166,30 +158,27 @@ class IdeogramModelBlock(Block):
     def run(self, input_data: Input, **kwargs) -> BlockOutput:
         seed = input_data.seed

-        try:
-            # Step 1: Generate the image
-            result = self.run_model(
+        # Step 1: Generate the image
+        result = self.run_model(
             api_key=input_data.api_key.get_secret_value(),
             model_name=input_data.ideogram_model_name.value,
             prompt=input_data.prompt,
             seed=seed,
             aspect_ratio=input_data.aspect_ratio.value,
             magic_prompt_option=input_data.magic_prompt_option.value,
             style_type=input_data.style_type.value,
             negative_prompt=input_data.negative_prompt,
             color_palette_name=input_data.color_palette_name.value,
         )

-            # Step 2: Upscale the image if requested
-            if input_data.upscale == UpscaleOption.AI_UPSCALE:
-                result = self.upscale_image(
-                    api_key=input_data.api_key.get_secret_value(),
-                    model_name=input_data.ideogram_model_name.value,
-                    prompt=input_data.prompt,
-                    seed=seed,
-                    aspect_ratio=input_data.aspect_ratio.value,
-                    magic_prompt_option=input_data.magic_prompt_option.value,
-                    style_type=input_data.style_type.value,
-                    negative_prompt=input_data.negative_prompt,
-                    color_palette_name=input_data.color_palette_name.value,
-                    image_url=result,
-                )
+        # Step 2: Upscale the image if requested
+        if input_data.upscale == UpscaleOption.AI_UPSCALE:
+            result = self.upscale_image(
+                api_key=input_data.api_key.get_secret_value(),
+                image_url=result,
+            )

-            yield "result", result
-        except Exception as e:
-            yield "error", str(e)
+        yield "result", result

     def run_model(
         self,
@@ -1,8 +1,12 @@
 import ast
 import logging
-from enum import Enum
+from enum import Enum, EnumMeta
 from json import JSONDecodeError
-from typing import Any, List, NamedTuple
+from types import MappingProxyType
+from typing import TYPE_CHECKING, Any, List, NamedTuple
+
+if TYPE_CHECKING:
+    from enum import _EnumMemberT

 import anthropic
 import ollama

@@ -12,6 +16,7 @@ from groq import Groq
 from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
 from backend.data.model import BlockSecret, SchemaField, SecretField
 from backend.util import json
+from backend.util.settings import BehaveAs, Settings

 logger = logging.getLogger(__name__)

@@ -29,7 +34,26 @@ class ModelMetadata(NamedTuple):
     cost_factor: int


-class LlmModel(str, Enum):
+class LlmModelMeta(EnumMeta):
+    @property
+    def __members__(
+        self: type["_EnumMemberT"],
+    ) -> MappingProxyType[str, "_EnumMemberT"]:
+        if Settings().config.behave_as == BehaveAs.LOCAL:
+            members = super().__members__
+            return members
+        else:
+            removed_providers = ["ollama"]
+            existing_members = super().__members__
+            members = {
+                name: member
+                for name, member in existing_members.items()
+                if LlmModel[name].provider not in removed_providers
+            }
+            return MappingProxyType(members)
+
+
+class LlmModel(str, Enum, metaclass=LlmModelMeta):
     # OpenAI models
     O1_PREVIEW = "o1-preview"
     O1_MINI = "o1-mini"
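
LlmModelMeta hides ollama-hosted models outside local deployments by overriding __members__ on the enum's metaclass. A self-contained toy showing the same mechanism (all names here are invented for illustration):

    from enum import Enum, EnumMeta
    from types import MappingProxyType

    HIDE_BETA = True  # stand-in for the Settings().config.behave_as check

    class FilteredMeta(EnumMeta):
        @property
        def __members__(cls):
            members = super().__members__
            if not HIDE_BETA:
                return members
            # Re-expose only the members that pass the filter.
            return MappingProxyType(
                {name: m for name, m in members.items() if not name.startswith("BETA_")}
            )

    class Model(str, Enum, metaclass=FilteredMeta):
        STABLE_A = "stable-a"
        BETA_B = "beta-b"

    print(list(Model.__members__))  # ['STABLE_A'] when HIDE_BETA is True
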
@@ -58,27 +82,39 @@ class LlmModel(str, Enum):
     def metadata(self) -> ModelMetadata:
         return MODEL_METADATA[self]

+    @property
+    def provider(self) -> str:
+        return self.metadata.provider
+
+    @property
+    def context_window(self) -> int:
+        return self.metadata.context_window
+
+    @property
+    def cost_factor(self) -> int:
+        return self.metadata.cost_factor
+

 MODEL_METADATA = {
-    LlmModel.O1_PREVIEW: ModelMetadata("openai", 32000, cost_factor=60),
-    LlmModel.O1_MINI: ModelMetadata("openai", 62000, cost_factor=30),
-    LlmModel.GPT4O_MINI: ModelMetadata("openai", 128000, cost_factor=10),
-    LlmModel.GPT4O: ModelMetadata("openai", 128000, cost_factor=12),
-    LlmModel.GPT4_TURBO: ModelMetadata("openai", 128000, cost_factor=11),
-    LlmModel.GPT3_5_TURBO: ModelMetadata("openai", 16385, cost_factor=8),
-    LlmModel.CLAUDE_3_5_SONNET: ModelMetadata("anthropic", 200000, cost_factor=14),
-    LlmModel.CLAUDE_3_HAIKU: ModelMetadata("anthropic", 200000, cost_factor=13),
-    LlmModel.LLAMA3_8B: ModelMetadata("groq", 8192, cost_factor=6),
-    LlmModel.LLAMA3_70B: ModelMetadata("groq", 8192, cost_factor=9),
-    LlmModel.MIXTRAL_8X7B: ModelMetadata("groq", 32768, cost_factor=7),
-    LlmModel.GEMMA_7B: ModelMetadata("groq", 8192, cost_factor=6),
-    LlmModel.GEMMA2_9B: ModelMetadata("groq", 8192, cost_factor=7),
-    LlmModel.LLAMA3_1_405B: ModelMetadata("groq", 8192, cost_factor=10),
+    LlmModel.O1_PREVIEW: ModelMetadata("openai", 32000, cost_factor=16),
+    LlmModel.O1_MINI: ModelMetadata("openai", 62000, cost_factor=4),
+    LlmModel.GPT4O_MINI: ModelMetadata("openai", 128000, cost_factor=1),
+    LlmModel.GPT4O: ModelMetadata("openai", 128000, cost_factor=3),
+    LlmModel.GPT4_TURBO: ModelMetadata("openai", 128000, cost_factor=10),
+    LlmModel.GPT3_5_TURBO: ModelMetadata("openai", 16385, cost_factor=1),
+    LlmModel.CLAUDE_3_5_SONNET: ModelMetadata("anthropic", 200000, cost_factor=4),
+    LlmModel.CLAUDE_3_HAIKU: ModelMetadata("anthropic", 200000, cost_factor=1),
+    LlmModel.LLAMA3_8B: ModelMetadata("groq", 8192, cost_factor=1),
+    LlmModel.LLAMA3_70B: ModelMetadata("groq", 8192, cost_factor=1),
+    LlmModel.MIXTRAL_8X7B: ModelMetadata("groq", 32768, cost_factor=1),
+    LlmModel.GEMMA_7B: ModelMetadata("groq", 8192, cost_factor=1),
+    LlmModel.GEMMA2_9B: ModelMetadata("groq", 8192, cost_factor=1),
+    LlmModel.LLAMA3_1_405B: ModelMetadata("groq", 8192, cost_factor=1),
     # Limited to 16k during preview
-    LlmModel.LLAMA3_1_70B: ModelMetadata("groq", 131072, cost_factor=15),
-    LlmModel.LLAMA3_1_8B: ModelMetadata("groq", 131072, cost_factor=13),
-    LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata("ollama", 8192, cost_factor=7),
-    LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata("ollama", 8192, cost_factor=11),
+    LlmModel.LLAMA3_1_70B: ModelMetadata("groq", 131072, cost_factor=1),
+    LlmModel.LLAMA3_1_8B: ModelMetadata("groq", 131072, cost_factor=1),
+    LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata("ollama", 8192, cost_factor=1),
+    LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata("ollama", 8192, cost_factor=1),
 }

for model in LlmModel:
@@ -88,7 +124,10 @@ for model in LlmModel:

 class AIStructuredResponseGeneratorBlock(Block):
     class Input(BlockSchema):
-        prompt: str
+        prompt: str = SchemaField(
+            description="The prompt to send to the language model.",
+            placeholder="Enter your prompt here...",
+        )
         expected_format: dict[str, str] = SchemaField(
             description="Expected format of the response. If provided, the response will be validated against this format. "
             "The keys should be the expected fields in the response, and the values should be the description of the field.",

@@ -100,15 +139,25 @@ class AIStructuredResponseGeneratorBlock(Block):
             advanced=False,
         )
         api_key: BlockSecret = SecretField(value="")
-        sys_prompt: str = ""
-        retry: int = 3
+        sys_prompt: str = SchemaField(
+            title="System Prompt",
+            default="",
+            description="The system prompt to provide additional context to the model.",
+        )
+        retry: int = SchemaField(
+            title="Retry Count",
+            default=3,
+            description="Number of times to retry the LLM call if the response does not match the expected format.",
+        )
         prompt_values: dict[str, str] = SchemaField(
             advanced=False, default={}, description="Values used to fill in the prompt."
         )

     class Output(BlockSchema):
-        response: dict[str, Any]
-        error: str
+        response: dict[str, Any] = SchemaField(
+            description="The response object generated by the language model."
+        )
+        error: str = SchemaField(description="Error message if the API call failed.")

     def __init__(self):
         super().__init__(

@@ -308,12 +357,15 @@ class AIStructuredResponseGeneratorBlock(Block):
             logger.error(f"Error calling LLM: {e}")
             retry_prompt = f"Error calling LLM: {e}"

-        yield "error", retry_prompt
+        raise RuntimeError(retry_prompt)


 class AITextGeneratorBlock(Block):
     class Input(BlockSchema):
-        prompt: str
+        prompt: str = SchemaField(
+            description="The prompt to send to the language model.",
+            placeholder="Enter your prompt here...",
+        )
         model: LlmModel = SchemaField(
             title="LLM Model",
             default=LlmModel.GPT4_TURBO,

@@ -321,15 +373,25 @@ class AITextGeneratorBlock(Block):
             advanced=False,
         )
         api_key: BlockSecret = SecretField(value="")
-        sys_prompt: str = ""
-        retry: int = 3
+        sys_prompt: str = SchemaField(
+            title="System Prompt",
+            default="",
+            description="The system prompt to provide additional context to the model.",
+        )
+        retry: int = SchemaField(
+            title="Retry Count",
+            default=3,
+            description="Number of times to retry the LLM call if the response does not match the expected format.",
+        )
         prompt_values: dict[str, str] = SchemaField(
             advanced=False, default={}, description="Values used to fill in the prompt."
         )

     class Output(BlockSchema):
-        response: str
-        error: str
+        response: str = SchemaField(
+            description="The response generated by the language model."
+        )
+        error: str = SchemaField(description="Error message if the API call failed.")

     def __init__(self):
         super().__init__(

@@ -354,14 +416,11 @@ class AITextGeneratorBlock(Block):
         raise ValueError("Failed to get a response from the LLM.")

     def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        try:
-            object_input_data = AIStructuredResponseGeneratorBlock.Input(
-                **{attr: getattr(input_data, attr) for attr in input_data.model_fields},
-                expected_format={},
-            )
-            yield "response", self.llm_call(object_input_data)
-        except Exception as e:
-            yield "error", str(e)
+        object_input_data = AIStructuredResponseGeneratorBlock.Input(
+            **{attr: getattr(input_data, attr) for attr in input_data.model_fields},
+            expected_format={},
+        )
+        yield "response", self.llm_call(object_input_data)
@@ -373,22 +432,43 @@ class SummaryStyle(Enum):

 class AITextSummarizerBlock(Block):
     class Input(BlockSchema):
-        text: str
+        text: str = SchemaField(
+            description="The text to summarize.",
+            placeholder="Enter the text to summarize here...",
+        )
         model: LlmModel = SchemaField(
             title="LLM Model",
             default=LlmModel.GPT4_TURBO,
             description="The language model to use for summarizing the text.",
         )
-        focus: str = "general information"
-        style: SummaryStyle = SummaryStyle.CONCISE
+        focus: str = SchemaField(
+            title="Focus",
+            default="general information",
+            description="The topic to focus on in the summary",
+        )
+        style: SummaryStyle = SchemaField(
+            title="Summary Style",
+            default=SummaryStyle.CONCISE,
+            description="The style of the summary to generate.",
+        )
         api_key: BlockSecret = SecretField(value="")
         # TODO: Make this dynamic
-        max_tokens: int = 4000  # Adjust based on the model's context window
-        chunk_overlap: int = 100  # Overlap between chunks to maintain context
+        max_tokens: int = SchemaField(
+            title="Max Tokens",
+            default=4096,
+            description="The maximum number of tokens to generate in the chat completion.",
+            ge=1,
+        )
+        chunk_overlap: int = SchemaField(
+            title="Chunk Overlap",
+            default=100,
+            description="The number of overlapping tokens between chunks to maintain context.",
+            ge=0,
+        )

     class Output(BlockSchema):
-        summary: str
-        error: str
+        summary: str = SchemaField(description="The final summary of the text.")
+        error: str = SchemaField(description="Error message if the API call failed.")

     def __init__(self):
         super().__init__(

@@ -409,11 +489,8 @@ class AITextSummarizerBlock(Block):
         )

     def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        try:
-            for output in self._run(input_data):
-                yield output
-        except Exception as e:
-            yield "error", str(e)
+        for output in self._run(input_data):
+            yield output

     def _run(self, input_data: Input) -> BlockOutput:
         chunks = self._split_text(

@@ -606,24 +683,21 @@ class AIConversationBlock(Block):
         raise ValueError(f"Unsupported LLM provider: {provider}")

     def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        try:
-            api_key = (
-                input_data.api_key.get_secret_value()
-                or LlmApiKeys[input_data.model.metadata.provider].get_secret_value()
-            )
-
-            messages = [message.model_dump() for message in input_data.messages]
-
-            response = self.llm_call(
-                api_key=api_key,
-                model=input_data.model,
-                messages=messages,
-                max_tokens=input_data.max_tokens,
-            )
-
-            yield "response", response
-        except Exception as e:
-            yield "error", f"Error calling LLM: {str(e)}"
+        api_key = (
+            input_data.api_key.get_secret_value()
+            or LlmApiKeys[input_data.model.metadata.provider].get_secret_value()
+        )
+
+        messages = [message.model_dump() for message in input_data.messages]
+
+        response = self.llm_call(
+            api_key=api_key,
+            model=input_data.model,
+            messages=messages,
+            max_tokens=input_data.max_tokens,
+        )
+
+        yield "response", response


 class AIListGeneratorBlock(Block):
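
AITextSummarizerBlock splits long inputs into chunks of at most max_tokens with chunk_overlap tokens shared between neighbours, summarizes each chunk, then combines the results. A minimal sketch of the splitting logic — word-based here for brevity, whereas the real block counts model tokens:

    def split_text(text: str, max_tokens: int = 4096, overlap: int = 100) -> list[str]:
        # Assumes max_tokens > overlap; each chunk starts `step` words
        # after the previous one, so neighbours share `overlap` words.
        words = text.split()
        chunks = []
        step = max_tokens - overlap
        for start in range(0, len(words), step):
            chunks.append(" ".join(words[start:start + max_tokens]))
            if start + max_tokens >= len(words):
                break
        return chunks

    # The shared overlap keeps context at chunk boundaries from being lost
    # when each chunk is summarized independently.
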
@@ -741,9 +815,7 @@ class AIListGeneratorBlock(Block):
             or LlmApiKeys[input_data.model.metadata.provider].get_secret_value()
         )
         if not api_key_check:
-            logger.error("No LLM API key provided.")
-            yield "error", "No LLM API key provided."
-            return
+            raise ValueError("No LLM API key provided.")

         # Prepare the system prompt
         sys_prompt = """You are a Python list generator. Your task is to generate a Python list based on the user's prompt.

@@ -837,7 +909,9 @@ class AIListGeneratorBlock(Block):
                 logger.error(
                     f"Failed to generate a valid Python list after {input_data.max_retries} attempts"
                 )
-                yield "error", f"Failed to generate a valid Python list after {input_data.max_retries} attempts. Last error: {str(e)}"
+                raise RuntimeError(
+                    f"Failed to generate a valid Python list after {input_data.max_retries} attempts. Last error: {str(e)}"
+                )
             else:
                 # Add a retry prompt
                 logger.debug("Preparing retry prompt")
@@ -1,3 +1,4 @@
+from enum import Enum
 from typing import List

 import requests

@@ -6,6 +7,12 @@ from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
 from backend.data.model import BlockSecret, SchemaField, SecretField


+class PublishToMediumStatus(str, Enum):
+    PUBLIC = "public"
+    DRAFT = "draft"
+    UNLISTED = "unlisted"
+
+
 class PublishToMediumBlock(Block):
     class Input(BlockSchema):
         author_id: BlockSecret = SecretField(

@@ -34,9 +41,9 @@ class PublishToMediumBlock(Block):
             description="The original home of this content, if it was originally published elsewhere",
             placeholder="https://yourblog.com/original-post",
         )
-        publish_status: str = SchemaField(
-            description="The publish status: 'public', 'draft', or 'unlisted'",
-            placeholder="public",
+        publish_status: PublishToMediumStatus = SchemaField(
+            description="The publish status",
+            placeholder=PublishToMediumStatus.DRAFT,
         )
         license: str = SchemaField(
             default="all-rights-reserved",

@@ -79,7 +86,7 @@ class PublishToMediumBlock(Block):
                 "tags": ["test", "automation"],
                 "license": "all-rights-reserved",
                 "notify_followers": False,
-                "publish_status": "draft",
+                "publish_status": PublishToMediumStatus.DRAFT.value,
                 "api_key": "your_test_api_key",
             },
             test_output=[

@@ -138,31 +145,25 @@ class PublishToMediumBlock(Block):
         return response.json()

     def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        try:
-            response = self.create_post(
-                input_data.api_key.get_secret_value(),
-                input_data.author_id.get_secret_value(),
-                input_data.title,
-                input_data.content,
-                input_data.content_format,
-                input_data.tags,
-                input_data.canonical_url,
-                input_data.publish_status,
-                input_data.license,
-                input_data.notify_followers,
-            )
-
-            if "data" in response:
-                yield "post_id", response["data"]["id"]
-                yield "post_url", response["data"]["url"]
-                yield "published_at", response["data"]["publishedAt"]
-            else:
-                error_message = response.get("errors", [{}])[0].get(
-                    "message", "Unknown error occurred"
-                )
-                yield "error", f"Failed to create Medium post: {error_message}"
-
-        except requests.RequestException as e:
-            yield "error", f"Network error occurred while creating Medium post: {str(e)}"
-        except Exception as e:
-            yield "error", f"Error occurred while creating Medium post: {str(e)}"
+        response = self.create_post(
+            input_data.api_key.get_secret_value(),
+            input_data.author_id.get_secret_value(),
+            input_data.title,
+            input_data.content,
+            input_data.content_format,
+            input_data.tags,
+            input_data.canonical_url,
+            input_data.publish_status,
+            input_data.license,
+            input_data.notify_followers,
+        )
+
+        if "data" in response:
+            yield "post_id", response["data"]["id"]
+            yield "post_url", response["data"]["url"]
+            yield "published_at", response["data"]["publishedAt"]
+        else:
+            error_message = response.get("errors", [{}])[0].get(
+                "message", "Unknown error occurred"
+            )
+            raise RuntimeError(f"Failed to create Medium post: {error_message}")
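
PublishToMediumStatus subclasses both str and Enum, so its members compare and serialize as plain string values — which is why the test input can still pass "draft" (via PublishToMediumStatus.DRAFT.value) into the API payload. A quick demonstration with an illustrative twin of the enum:

    from enum import Enum

    class PublishStatus(str, Enum):  # illustrative twin of PublishToMediumStatus
        PUBLIC = "public"
        DRAFT = "draft"
        UNLISTED = "unlisted"

    assert PublishStatus.DRAFT == "draft"                 # str comparison works
    assert PublishStatus.DRAFT.value == "draft"           # explicit value access
    assert PublishStatus("draft") is PublishStatus.DRAFT  # parse from raw input
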
@@ -2,10 +2,10 @@ from datetime import datetime, timezone
 from typing import Iterator

 import praw
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict

 from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
-from backend.data.model import BlockSecret, SecretField
+from backend.data.model import BlockSecret, SchemaField, SecretField
 from backend.util.mock import MockObject

@@ -48,25 +48,25 @@ def get_praw(creds: RedditCredentials) -> praw.Reddit:

 class GetRedditPostsBlock(Block):
     class Input(BlockSchema):
-        subreddit: str = Field(description="Subreddit name")
-        creds: RedditCredentials = Field(
+        subreddit: str = SchemaField(description="Subreddit name")
+        creds: RedditCredentials = SchemaField(
             description="Reddit credentials",
             default=RedditCredentials(),
         )
-        last_minutes: int | None = Field(
+        last_minutes: int | None = SchemaField(
             description="Post time to stop minutes ago while fetching posts",
             default=None,
         )
-        last_post: str | None = Field(
+        last_post: str | None = SchemaField(
             description="Post ID to stop when reached while fetching posts",
             default=None,
         )
-        post_limit: int | None = Field(
+        post_limit: int | None = SchemaField(
             description="Number of posts to fetch", default=10
         )

     class Output(BlockSchema):
-        post: RedditPost = Field(description="Reddit post")
+        post: RedditPost = SchemaField(description="Reddit post")

     def __init__(self):
         super().__init__(

@@ -140,13 +140,13 @@ class GetRedditPostsBlock(Block):

 class PostRedditCommentBlock(Block):
     class Input(BlockSchema):
-        creds: RedditCredentials = Field(
+        creds: RedditCredentials = SchemaField(
             description="Reddit credentials", default=RedditCredentials()
         )
-        data: RedditComment = Field(description="Reddit comment")
+        data: RedditComment = SchemaField(description="Reddit comment")

     class Output(BlockSchema):
-        comment_id: str
+        comment_id: str = SchemaField(description="Posted comment ID")

     def __init__(self):
         super().__init__(
@@ -139,24 +139,21 @@ class ReplicateFluxAdvancedModelBlock(Block):
         if seed is None:
             seed = int.from_bytes(os.urandom(4), "big")

-        try:
-            # Run the model using the provided inputs
-            result = self.run_model(
-                api_key=input_data.api_key.get_secret_value(),
-                model_name=input_data.replicate_model_name.api_name,
-                prompt=input_data.prompt,
-                seed=seed,
-                steps=input_data.steps,
-                guidance=input_data.guidance,
-                interval=input_data.interval,
-                aspect_ratio=input_data.aspect_ratio,
-                output_format=input_data.output_format,
-                output_quality=input_data.output_quality,
-                safety_tolerance=input_data.safety_tolerance,
-            )
-            yield "result", result
-        except Exception as e:
-            yield "error", str(e)
+        # Run the model using the provided inputs
+        result = self.run_model(
+            api_key=input_data.api_key.get_secret_value(),
+            model_name=input_data.replicate_model_name.api_name,
+            prompt=input_data.prompt,
+            seed=seed,
+            steps=input_data.steps,
+            guidance=input_data.guidance,
+            interval=input_data.interval,
+            aspect_ratio=input_data.aspect_ratio,
+            output_format=input_data.output_format,
+            output_quality=input_data.output_quality,
+            safety_tolerance=input_data.safety_tolerance,
+        )
+        yield "result", result

     def run_model(
         self,
@@ -17,11 +17,13 @@ class GetRequest:

 class GetWikipediaSummaryBlock(Block, GetRequest):
     class Input(BlockSchema):
-        topic: str
+        topic: str = SchemaField(description="The topic to fetch the summary for")

     class Output(BlockSchema):
-        summary: str
-        error: str
+        summary: str = SchemaField(description="The summary of the given topic")
+        error: str = SchemaField(
+            description="Error message if the summary cannot be retrieved"
+        )

     def __init__(self):
         super().__init__(

@@ -36,29 +38,23 @@ class GetWikipediaSummaryBlock(Block, GetRequest):
         )

     def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        try:
-            topic = input_data.topic
-            url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{topic}"
-            response = self.get_request(url, json=True)
-            yield "summary", response["extract"]
-
-        except requests.exceptions.HTTPError as http_err:
-            yield "error", f"HTTP error occurred: {http_err}"
-
-        except requests.RequestException as e:
-            yield "error", f"Request to Wikipedia failed: {e}"
-
-        except KeyError as e:
-            yield "error", f"Error parsing Wikipedia response: {e}"
+        topic = input_data.topic
+        url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{topic}"
+        response = self.get_request(url, json=True)
+        if "extract" not in response:
+            raise RuntimeError(f"Unable to parse Wikipedia response: {response}")
+        yield "summary", response["extract"]


 class SearchTheWebBlock(Block, GetRequest):
     class Input(BlockSchema):
-        query: str  # The search query
+        query: str = SchemaField(description="The search query to search the web for")

     class Output(BlockSchema):
-        results: str  # The search results including content from top 5 URLs
-        error: str  # Error message if the search fails
+        results: str = SchemaField(
+            description="The search results including content from top 5 URLs"
+        )
+        error: str = SchemaField(description="Error message if the search fails")

     def __init__(self):
         super().__init__(

@@ -73,29 +69,22 @@ class SearchTheWebBlock(Block, GetRequest):
         )

     def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        try:
-            # Encode the search query
-            encoded_query = quote(input_data.query)
-
-            # Prepend the Jina Search URL to the encoded query
-            jina_search_url = f"https://s.jina.ai/{encoded_query}"
-
-            # Make the request to Jina Search
-            response = self.get_request(jina_search_url, json=False)
-
-            # Output the search results
-            yield "results", response
-
-        except requests.exceptions.HTTPError as http_err:
-            yield "error", f"HTTP error occurred: {http_err}"
-
-        except requests.RequestException as e:
-            yield "error", f"Request to Jina Search failed: {e}"
+        # Encode the search query
+        encoded_query = quote(input_data.query)
+
+        # Prepend the Jina Search URL to the encoded query
+        jina_search_url = f"https://s.jina.ai/{encoded_query}"
+
+        # Make the request to Jina Search
+        response = self.get_request(jina_search_url, json=False)
+
+        # Output the search results
+        yield "results", response


 class ExtractWebsiteContentBlock(Block, GetRequest):
     class Input(BlockSchema):
-        url: str  # The URL to scrape
+        url: str = SchemaField(description="The URL to scrape the content from")
         raw_content: bool = SchemaField(
             default=False,
             title="Raw Content",

@@ -104,8 +93,10 @@ class ExtractWebsiteContentBlock(Block, GetRequest):
         )

     class Output(BlockSchema):
-        content: str  # The scraped content from the URL
-        error: str
+        content: str = SchemaField(description="The scraped content from the given URL")
+        error: str = SchemaField(
+            description="Error message if the content cannot be retrieved"
+        )

     def __init__(self):
         super().__init__(

@@ -125,26 +116,32 @@ class ExtractWebsiteContentBlock(Block, GetRequest):
         else:
             url = f"https://r.jina.ai/{input_data.url}"

-        try:
-            content = self.get_request(url, json=False)
-            yield "content", content
-        except requests.exceptions.HTTPError as http_err:
-            yield "error", f"HTTP error occurred: {http_err}"
-        except requests.RequestException as e:
-            yield "error", f"Request to URL failed: {e}"
+        content = self.get_request(url, json=False)
+        yield "content", content


 class GetWeatherInformationBlock(Block, GetRequest):
     class Input(BlockSchema):
-        location: str
+        location: str = SchemaField(
+            description="Location to get weather information for"
+        )
         api_key: BlockSecret = SecretField(key="openweathermap_api_key")
-        use_celsius: bool = True
+        use_celsius: bool = SchemaField(
+            default=True,
+            description="Whether to use Celsius or Fahrenheit for temperature",
+        )

     class Output(BlockSchema):
-        temperature: str
-        humidity: str
-        condition: str
-        error: str
+        temperature: str = SchemaField(
+            description="Temperature in the specified location"
+        )
+        humidity: str = SchemaField(description="Humidity in the specified location")
+        condition: str = SchemaField(
+            description="Weather condition in the specified location"
+        )
+        error: str = SchemaField(
+            description="Error message if the weather information cannot be retrieved"
+        )

     def __init__(self):
         super().__init__(

@@ -171,26 +168,15 @@ class GetWeatherInformationBlock(Block, GetRequest):
         )

     def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        try:
-            units = "metric" if input_data.use_celsius else "imperial"
-            api_key = input_data.api_key.get_secret_value()
-            location = input_data.location
-            url = f"http://api.openweathermap.org/data/2.5/weather?q={quote(location)}&appid={api_key}&units={units}"
-            weather_data = self.get_request(url, json=True)
-
-            if "main" in weather_data and "weather" in weather_data:
-                yield "temperature", str(weather_data["main"]["temp"])
-                yield "humidity", str(weather_data["main"]["humidity"])
-                yield "condition", weather_data["weather"][0]["description"]
-            else:
-                yield "error", f"Expected keys not found in response: {weather_data}"
-
-        except requests.exceptions.HTTPError as http_err:
-            if http_err.response.status_code == 403:
-                yield "error", "Request to weather API failed: 403 Forbidden. Check your API key and permissions."
-            else:
-                yield "error", f"HTTP error occurred: {http_err}"
-        except requests.RequestException as e:
-            yield "error", f"Request to weather API failed: {e}"
-        except KeyError as e:
-            yield "error", f"Error processing weather data: {e}"
+        units = "metric" if input_data.use_celsius else "imperial"
+        api_key = input_data.api_key.get_secret_value()
+        location = input_data.location
+        url = f"http://api.openweathermap.org/data/2.5/weather?q={quote(location)}&appid={api_key}&units={units}"
+        weather_data = self.get_request(url, json=True)
+
+        if "main" in weather_data and "weather" in weather_data:
+            yield "temperature", str(weather_data["main"]["temp"])
+            yield "humidity", str(weather_data["main"]["humidity"])
+            yield "condition", weather_data["weather"][0]["description"]
+        else:
+            raise RuntimeError(f"Expected keys not found in response: {weather_data}")
@@ -13,7 +13,8 @@ class CreateTalkingAvatarVideoBlock(Block):
             key="did_api_key", description="D-ID API Key"
         )
         script_input: str = SchemaField(
-            description="The text input for the script", default="Welcome to AutoGPT"
+            description="The text input for the script",
+            placeholder="Welcome to AutoGPT",
         )
         provider: Literal["microsoft", "elevenlabs", "amazon"] = SchemaField(
             description="The voice provider to use", default="microsoft"

@@ -106,41 +107,40 @@ class CreateTalkingAvatarVideoBlock(Block):
         return response.json()

     def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        try:
-            # Create the clip
-            payload = {
-                "script": {
-                    "type": "text",
-                    "subtitles": str(input_data.subtitles).lower(),
-                    "provider": {
-                        "type": input_data.provider,
-                        "voice_id": input_data.voice_id,
-                    },
-                    "ssml": str(input_data.ssml).lower(),
-                    "input": input_data.script_input,
-                },
-                "config": {"result_format": input_data.result_format},
-                "presenter_config": {"crop": {"type": input_data.crop_type}},
-                "presenter_id": input_data.presenter_id,
-                "driver_id": input_data.driver_id,
-            }
-
-            response = self.create_clip(input_data.api_key.get_secret_value(), payload)
-            clip_id = response["id"]
-
-            # Poll for clip status
-            for _ in range(input_data.max_polling_attempts):
-                status_response = self.get_clip_status(
-                    input_data.api_key.get_secret_value(), clip_id
-                )
-                if status_response["status"] == "done":
-                    yield "video_url", status_response["result_url"]
-                    return
-                elif status_response["status"] == "error":
-                    yield "error", f"Clip creation failed: {status_response.get('error', 'Unknown error')}"
-                    return
-                time.sleep(input_data.polling_interval)
-
-            yield "error", "Clip creation timed out"
-        except Exception as e:
-            yield "error", str(e)
+        # Create the clip
+        payload = {
+            "script": {
+                "type": "text",
+                "subtitles": str(input_data.subtitles).lower(),
+                "provider": {
+                    "type": input_data.provider,
+                    "voice_id": input_data.voice_id,
+                },
+                "ssml": str(input_data.ssml).lower(),
+                "input": input_data.script_input,
+            },
+            "config": {"result_format": input_data.result_format},
+            "presenter_config": {"crop": {"type": input_data.crop_type}},
+            "presenter_id": input_data.presenter_id,
+            "driver_id": input_data.driver_id,
+        }

+        response = self.create_clip(input_data.api_key.get_secret_value(), payload)
+        clip_id = response["id"]
+
+        # Poll for clip status
+        for _ in range(input_data.max_polling_attempts):
+            status_response = self.get_clip_status(
+                input_data.api_key.get_secret_value(), clip_id
+            )
+            if status_response["status"] == "done":
+                yield "video_url", status_response["result_url"]
+                return
+            elif status_response["status"] == "error":
+                raise RuntimeError(
+                    f"Clip creation failed: {status_response.get('error', 'Unknown error')}"
+                )
+            time.sleep(input_data.polling_interval)
+
+        raise TimeoutError("Clip creation timed out")
@@ -2,9 +2,9 @@ import re
 from typing import Any

 from jinja2 import BaseLoader, Environment
-from pydantic import Field

 from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.model import SchemaField
 from backend.util import json

 jinja = Environment(loader=BaseLoader())
@@ -12,15 +12,17 @@ jinja = Environment(loader=BaseLoader())

 class MatchTextPatternBlock(Block):
     class Input(BlockSchema):
-        text: Any = Field(description="Text to match")
-        match: str = Field(description="Pattern (Regex) to match")
-        data: Any = Field(description="Data to be forwarded to output")
-        case_sensitive: bool = Field(description="Case sensitive match", default=True)
-        dot_all: bool = Field(description="Dot matches all", default=True)
+        text: Any = SchemaField(description="Text to match")
+        match: str = SchemaField(description="Pattern (Regex) to match")
+        data: Any = SchemaField(description="Data to be forwarded to output")
+        case_sensitive: bool = SchemaField(
+            description="Case sensitive match", default=True
+        )
+        dot_all: bool = SchemaField(description="Dot matches all", default=True)

     class Output(BlockSchema):
-        positive: Any = Field(description="Output data if match is found")
-        negative: Any = Field(description="Output data if match is not found")
+        positive: Any = SchemaField(description="Output data if match is found")
+        negative: Any = SchemaField(description="Output data if match is not found")

     def __init__(self):
         super().__init__(
@@ -64,15 +66,17 @@ class MatchTextPatternBlock(Block):

 class ExtractTextInformationBlock(Block):
     class Input(BlockSchema):
-        text: Any = Field(description="Text to parse")
-        pattern: str = Field(description="Pattern (Regex) to parse")
-        group: int = Field(description="Group number to extract", default=0)
-        case_sensitive: bool = Field(description="Case sensitive match", default=True)
-        dot_all: bool = Field(description="Dot matches all", default=True)
+        text: Any = SchemaField(description="Text to parse")
+        pattern: str = SchemaField(description="Pattern (Regex) to parse")
+        group: int = SchemaField(description="Group number to extract", default=0)
+        case_sensitive: bool = SchemaField(
+            description="Case sensitive match", default=True
+        )
+        dot_all: bool = SchemaField(description="Dot matches all", default=True)

     class Output(BlockSchema):
-        positive: str = Field(description="Extracted text")
-        negative: str = Field(description="Original text")
+        positive: str = SchemaField(description="Extracted text")
+        negative: str = SchemaField(description="Original text")

     def __init__(self):
         super().__init__(
@@ -116,11 +120,15 @@ class ExtractTextInformationBlock(Block):

 class FillTextTemplateBlock(Block):
     class Input(BlockSchema):
-        values: dict[str, Any] = Field(description="Values (dict) to be used in format")
-        format: str = Field(description="Template to format the text using `values`")
+        values: dict[str, Any] = SchemaField(
+            description="Values (dict) to be used in format"
+        )
+        format: str = SchemaField(
+            description="Template to format the text using `values`"
+        )

     class Output(BlockSchema):
-        output: str
+        output: str = SchemaField(description="Formatted text")

     def __init__(self):
         super().__init__(
@@ -155,11 +163,13 @@ class FillTextTemplateBlock(Block):

 class CombineTextsBlock(Block):
     class Input(BlockSchema):
-        input: list[str] = Field(description="text input to combine")
-        delimiter: str = Field(description="Delimiter to combine texts", default="")
+        input: list[str] = SchemaField(description="text input to combine")
+        delimiter: str = SchemaField(
+            description="Delimiter to combine texts", default=""
+        )

     class Output(BlockSchema):
-        output: str = Field(description="Combined text")
+        output: str = SchemaField(description="Combined text")

     def __init__(self):
         super().__init__(
@@ -68,12 +68,9 @@ class UnrealTextToSpeechBlock(Block):
         return response.json()

     def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        try:
-            api_response = self.call_unreal_speech_api(
-                input_data.api_key.get_secret_value(),
-                input_data.text,
-                input_data.voice_id,
-            )
-            yield "mp3_url", api_response["OutputUri"]
-        except Exception as e:
-            yield "error", str(e)
+        api_response = self.call_unreal_speech_api(
+            input_data.api_key.get_secret_value(),
+            input_data.text,
+            input_data.voice_id,
+        )
+        yield "mp3_url", api_response["OutputUri"]
@@ -3,14 +3,22 @@ from datetime import datetime, timedelta
 from typing import Any, Union

 from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.model import SchemaField


 class GetCurrentTimeBlock(Block):
     class Input(BlockSchema):
-        trigger: str
+        trigger: str = SchemaField(
+            description="Trigger any data to output the current time"
+        )
+        format: str = SchemaField(
+            description="Format of the time to output", default="%H:%M:%S"
+        )

     class Output(BlockSchema):
-        time: str
+        time: str = SchemaField(
+            description="Current time in the specified format (default: %H:%M:%S)"
+        )

     def __init__(self):
         super().__init__(
@@ -20,25 +28,38 @@ class GetCurrentTimeBlock(Block):
             input_schema=GetCurrentTimeBlock.Input,
             output_schema=GetCurrentTimeBlock.Output,
             test_input=[
-                {"trigger": "Hello", "format": "{time}"},
+                {"trigger": "Hello"},
+                {"trigger": "Hello", "format": "%H:%M"},
             ],
             test_output=[
                 ("time", lambda _: time.strftime("%H:%M:%S")),
+                ("time", lambda _: time.strftime("%H:%M")),
             ],
         )

     def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        current_time = time.strftime("%H:%M:%S")
+        current_time = time.strftime(input_data.format)
         yield "time", current_time


 class GetCurrentDateBlock(Block):
     class Input(BlockSchema):
-        trigger: str
-        offset: Union[int, str]
+        trigger: str = SchemaField(
+            description="Trigger any data to output the current date"
+        )
+        offset: Union[int, str] = SchemaField(
+            title="Days Offset",
+            description="Offset in days from the current date",
+            default=0,
+        )
+        format: str = SchemaField(
+            description="Format of the date to output", default="%Y-%m-%d"
+        )

     class Output(BlockSchema):
-        date: str
+        date: str = SchemaField(
+            description="Current date in the specified format (default: YYYY-MM-DD)"
+        )

     def __init__(self):
         super().__init__(
@@ -48,7 +69,8 @@ class GetCurrentDateBlock(Block):
             input_schema=GetCurrentDateBlock.Input,
             output_schema=GetCurrentDateBlock.Output,
             test_input=[
-                {"trigger": "Hello", "format": "{date}", "offset": "7"},
+                {"trigger": "Hello", "offset": "7"},
+                {"trigger": "Hello", "offset": "7", "format": "%m/%d/%Y"},
             ],
             test_output=[
                 (
@@ -56,6 +78,12 @@ class GetCurrentDateBlock(Block):
                     lambda t: abs(datetime.now() - datetime.strptime(t, "%Y-%m-%d"))
                     < timedelta(days=8),  # 7 days difference + 1 day error margin.
                 ),
+                (
+                    "date",
+                    lambda t: abs(datetime.now() - datetime.strptime(t, "%m/%d/%Y"))
+                    < timedelta(days=8),
+                    # 7 days difference + 1 day error margin.
+                ),
             ],
         )
@@ -65,15 +93,23 @@ class GetCurrentDateBlock(Block):
         except ValueError:
             offset = 0
         current_date = datetime.now() - timedelta(days=offset)
-        yield "date", current_date.strftime("%Y-%m-%d")
+        yield "date", current_date.strftime(input_data.format)


 class GetCurrentDateAndTimeBlock(Block):
     class Input(BlockSchema):
-        trigger: str
+        trigger: str = SchemaField(
+            description="Trigger any data to output the current date and time"
+        )
+        format: str = SchemaField(
+            description="Format of the date and time to output",
+            default="%Y-%m-%d %H:%M:%S",
+        )

     class Output(BlockSchema):
-        date_time: str
+        date_time: str = SchemaField(
+            description="Current date and time in the specified format (default: YYYY-MM-DD HH:MM:SS)"
+        )

     def __init__(self):
         super().__init__(
@@ -83,7 +119,7 @@ class GetCurrentDateAndTimeBlock(Block):
             input_schema=GetCurrentDateAndTimeBlock.Input,
             output_schema=GetCurrentDateAndTimeBlock.Output,
             test_input=[
-                {"trigger": "Hello", "format": "{date_time}"},
+                {"trigger": "Hello"},
             ],
             test_output=[
                 (
@@ -97,20 +133,29 @@ class GetCurrentDateAndTimeBlock(Block):
             )

     def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        current_date_time = time.strftime("%Y-%m-%d %H:%M:%S")
+        current_date_time = time.strftime(input_data.format)
         yield "date_time", current_date_time


 class CountdownTimerBlock(Block):
     class Input(BlockSchema):
-        input_message: Any = "timer finished"
-        seconds: Union[int, str] = 0
-        minutes: Union[int, str] = 0
-        hours: Union[int, str] = 0
-        days: Union[int, str] = 0
+        input_message: Any = SchemaField(
+            description="Message to output after the timer finishes",
+            default="timer finished",
+        )
+        seconds: Union[int, str] = SchemaField(
+            description="Duration in seconds", default=0
+        )
+        minutes: Union[int, str] = SchemaField(
+            description="Duration in minutes", default=0
+        )
+        hours: Union[int, str] = SchemaField(description="Duration in hours", default=0)
+        days: Union[int, str] = SchemaField(description="Duration in days", default=0)

     class Output(BlockSchema):
-        output_message: str
+        output_message: str = SchemaField(
+            description="Message after the timer finishes"
+        )

     def __init__(self):
         super().__init__(
@@ -7,9 +7,10 @@ from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
 from backend.data.model import SchemaField


-class TranscribeYouTubeVideoBlock(Block):
+class TranscribeYoutubeVideoBlock(Block):
     class Input(BlockSchema):
         youtube_url: str = SchemaField(
             title="YouTube URL",
             description="The URL of the YouTube video to transcribe",
             placeholder="https://www.youtube.com/watch?v=dQw4w9WgXcQ",
         )
@@ -24,8 +25,8 @@ class TranscribeYouTubeVideoBlock(Block):
     def __init__(self):
         super().__init__(
             id="f3a8f7e1-4b1d-4e5f-9f2a-7c3d5a2e6b4c",
-            input_schema=TranscribeYouTubeVideoBlock.Input,
-            output_schema=TranscribeYouTubeVideoBlock.Output,
+            input_schema=TranscribeYoutubeVideoBlock.Input,
+            output_schema=TranscribeYoutubeVideoBlock.Output,
             description="Transcribes a YouTube video.",
             categories={BlockCategory.SOCIAL},
             test_input={"youtube_url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ"},
@@ -64,14 +65,11 @@ class TranscribeYouTubeVideoBlock(Block):
         return YouTubeTranscriptApi.get_transcript(video_id)

     def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        try:
-            video_id = self.extract_video_id(input_data.youtube_url)
-            yield "video_id", video_id
+        video_id = self.extract_video_id(input_data.youtube_url)
+        yield "video_id", video_id

-            transcript = self.get_transcript(video_id)
-            formatter = TextFormatter()
-            transcript_text = formatter.format_transcript(transcript)
+        transcript = self.get_transcript(video_id)
+        formatter = TextFormatter()
+        transcript_text = formatter.format_transcript(transcript)

-            yield "transcript", transcript_text
-        except Exception as e:
-            yield "error", str(e)
+        yield "transcript", transcript_text
@@ -272,6 +272,8 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
             for output_name, output_data in self.run(
                 self.input_schema(**input_data), **kwargs
             ):
+                if output_name == "error":
+                    raise RuntimeError(output_data)
                 if error := self.output_schema.validate_field(output_name, output_data):
                     raise ValueError(f"Block produced an invalid output data: {error}")
                 yield output_name, output_data
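Note on this hunk: with the two added lines, a block that yields an "error" output now surfaces it as an exception instead of passing it downstream, which is why the block diffs above drop their try/except wrappers. A minimal self-contained sketch of that convention (the names here are simplified stand-ins, not the real Block class):

from typing import Any, Generator

BlockOutput = Generator[tuple[str, Any], None, None]

def execute(run_output: BlockOutput) -> BlockOutput:
    for output_name, output_data in run_output:
        if output_name == "error":
            # Mirrors the added lines: error outputs become exceptions.
            raise RuntimeError(output_data)
        yield output_name, output_data

def flaky_block() -> BlockOutput:
    yield "result", 42
    yield "error", "something went wrong"

try:
    for name, data in execute(flaky_block()):
        print(name, data)  # prints: result 42
except RuntimeError as e:
    print(f"surfaced as exception: {e}")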
@@ -17,8 +17,9 @@ from backend.blocks.llm import (
     AITextSummarizerBlock,
     LlmModel,
 )
+from backend.blocks.search import ExtractWebsiteContentBlock, SearchTheWebBlock
 from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
-from backend.data.block import Block, BlockInput
+from backend.data.block import Block, BlockInput, get_block
 from backend.util.settings import Config
@@ -74,6 +75,10 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
     CreateTalkingAvatarVideoBlock: [
         BlockCost(cost_amount=15, cost_filter={"api_key": None})
     ],
+    SearchTheWebBlock: [BlockCost(cost_amount=1)],
+    ExtractWebsiteContentBlock: [
+        BlockCost(cost_amount=1, cost_filter={"raw_content": False})
+    ],
 }
@@ -96,7 +101,7 @@ class UserCreditBase(ABC):
         self,
         user_id: str,
         user_credit: int,
-        block: Block,
+        block_id: str,
         input_data: BlockInput,
         data_size: float,
         run_time: float,
@@ -107,7 +112,7 @@ class UserCreditBase(ABC):
         Args:
             user_id (str): The user ID.
            user_credit (int): The current credit for the user.
-            block (Block): The block that is being used.
+            block_id (str): The block ID.
             input_data (BlockInput): The input data for the block.
             data_size (float): The size of the data being processed.
             run_time (float): The time taken to run the block.
@@ -208,12 +213,16 @@ class UserCredit(UserCreditBase):
         self,
         user_id: str,
         user_credit: int,
-        block: Block,
+        block_id: str,
         input_data: BlockInput,
         data_size: float,
         run_time: float,
         validate_balance: bool = True,
     ) -> int:
+        block = get_block(block_id)
+        if not block:
+            raise ValueError(f"Block not found: {block_id}")
+
         cost, matching_filter = self._block_usage_cost(
             block=block, input_data=input_data, data_size=data_size, run_time=run_time
         )
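For context, `_block_usage_cost` itself is not part of this diff; the sketch below is a hypothetical illustration of how a `cost_filter` such as `{"api_key": None}` could select a cost entry, under the assumption that a `None` filter value means "input absent or empty":

# Hypothetical sketch of cost_filter matching; the real _block_usage_cost is
# not shown in this diff, so the semantics here are assumptions.
from typing import Any

def filter_matches(cost_filter: dict[str, Any], input_data: dict[str, Any]) -> bool:
    for key, expected in cost_filter.items():
        actual = input_data.get(key)
        if expected is None:
            # Assumed meaning: this cost applies when the input is not provided.
            if actual not in (None, ""):
                return False
        elif actual != expected:
            return False
    return True

# E.g. BlockCost(cost_amount=15, cost_filter={"api_key": None}) would apply
# only when the user did not bring their own API key:
print(filter_matches({"api_key": None}, {"script_input": "hi"}))        # True
print(filter_matches({"api_key": None}, {"api_key": "user-provided"}))  # False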
@@ -3,7 +3,6 @@ from datetime import datetime, timezone
 from multiprocessing import Manager
 from typing import Any, Generic, TypeVar

-from autogpt_libs.supabase_integration_credentials_store.types import Credentials
 from prisma.enums import AgentExecutionStatus
 from prisma.models import (
     AgentGraphExecution,
@@ -26,7 +25,6 @@ class GraphExecution(BaseModel):
     graph_exec_id: str
     graph_id: str
     start_node_execs: list["NodeExecution"]
-    node_input_credentials: dict[str, Credentials]  # dict[node_id, Credentials]


 class NodeExecution(BaseModel):
@@ -40,28 +38,6 @@ class NodeExecution(BaseModel):

 ExecutionStatus = AgentExecutionStatus

-T = TypeVar("T")
-
-
-class ExecutionQueue(Generic[T]):
-    """
-    Queue for managing the execution of agents.
-    This will be shared between different processes
-    """
-
-    def __init__(self):
-        self.queue = Manager().Queue()
-
-    def add(self, execution: T) -> T:
-        self.queue.put(execution)
-        return execution
-
-    def get(self) -> T:
-        return self.queue.get()
-
-    def empty(self) -> bool:
-        return self.queue.empty()
-

 class ExecutionResult(BaseModel):
     graph_id: str
@@ -2,20 +2,18 @@ import asyncio
 import logging
 import uuid
 from datetime import datetime, timezone
-from pathlib import Path
 from typing import Any, Literal

 import prisma.types
 from prisma.models import AgentGraph, AgentGraphExecution, AgentNode, AgentNodeLink
 from prisma.types import AgentGraphInclude
-from pydantic import BaseModel, PrivateAttr
+from pydantic import BaseModel
 from pydantic_core import PydanticUndefinedType

 from backend.blocks.basic import AgentInputBlock, AgentOutputBlock
 from backend.data.block import BlockInput, get_block, get_blocks
 from backend.data.db import BaseDbModel, transaction
 from backend.data.execution import ExecutionStatus
-from backend.data.user import DEFAULT_USER_ID
 from backend.util import json

 logger = logging.getLogger(__name__)
@@ -53,17 +51,8 @@ class Node(BaseDbModel):
     block_id: str
     input_default: BlockInput = {}  # dict[input_name, default_value]
     metadata: dict[str, Any] = {}
-
-    _input_links: list[Link] = PrivateAttr(default=[])
-    _output_links: list[Link] = PrivateAttr(default=[])
-
-    @property
-    def input_links(self) -> list[Link]:
-        return self._input_links
-
-    @property
-    def output_links(self) -> list[Link]:
-        return self._output_links
+    input_links: list[Link] = []
+    output_links: list[Link] = []

     @staticmethod
     def from_db(node: AgentNode):
@@ -75,8 +64,8 @@ class Node(BaseDbModel):
             input_default=json.loads(node.constantInput),
             metadata=json.loads(node.metadata),
         )
-        obj._input_links = [Link.from_db(link) for link in node.Input or []]
-        obj._output_links = [Link.from_db(link) for link in node.Output or []]
+        obj.input_links = [Link.from_db(link) for link in node.Input or []]
+        obj.output_links = [Link.from_db(link) for link in node.Output or []]
         return obj
@@ -330,7 +319,7 @@ class Graph(GraphMeta):
         return input_schema

     @staticmethod
-    def from_db(graph: AgentGraph):
+    def from_db(graph: AgentGraph, hide_credentials: bool = False):
         nodes = [
             *(graph.AgentNodes or []),
             *(
@@ -341,7 +330,7 @@ class Graph(GraphMeta):
         ]
         return Graph(
             **GraphMeta.from_db(graph).model_dump(),
-            nodes=[Node.from_db(node) for node in nodes],
+            nodes=[Graph._process_node(node, hide_credentials) for node in nodes],
             links=list(
                 {
                     Link.from_db(link)
@@ -355,6 +344,31 @@ class Graph(GraphMeta):
             },
         )

+    @staticmethod
+    def _process_node(node: AgentNode, hide_credentials: bool) -> Node:
+        node_dict = node.model_dump()
+        if hide_credentials and "constantInput" in node_dict:
+            constant_input = json.loads(node_dict["constantInput"])
+            constant_input = Graph._hide_credentials_in_input(constant_input)
+            node_dict["constantInput"] = json.dumps(constant_input)
+        return Node.from_db(AgentNode(**node_dict))
+
+    @staticmethod
+    def _hide_credentials_in_input(input_data: dict[str, Any]) -> dict[str, Any]:
+        sensitive_keys = ["credentials", "api_key", "password", "token", "secret"]
+        result = {}
+        for key, value in input_data.items():
+            if isinstance(value, dict):
+                result[key] = Graph._hide_credentials_in_input(value)
+            elif isinstance(value, str) and any(
+                sensitive_key in key.lower() for sensitive_key in sensitive_keys
+            ):
+                # Skip this key-value pair in the result
+                continue
+            else:
+                result[key] = value
+        return result
+

 AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
     "Input": True,
@@ -382,9 +396,9 @@ async def get_node(node_id: str) -> Node:


 async def get_graphs_meta(
+    user_id: str,
     include_executions: bool = False,
     filter_by: Literal["active", "template"] | None = "active",
-    user_id: str | None = None,
 ) -> list[GraphMeta]:
     """
     Retrieves graph metadata objects.
@@ -393,6 +407,7 @@ async def get_graphs_meta(
     Args:
         include_executions: Whether to include executions in the graph metadata.
         filter_by: An optional filter to either select templates or active graphs.
+        user_id: The ID of the user that owns the graph.

     Returns:
         list[GraphMeta]: A list of objects representing the retrieved graph metadata.
@@ -404,8 +419,7 @@ async def get_graphs_meta(
     elif filter_by == "template":
         where_clause["isTemplate"] = True

-    if user_id and filter_by != "template":
-        where_clause["userId"] = user_id
+    where_clause["userId"] = user_id

     graphs = await AgentGraph.prisma().find_many(
         where=where_clause,
@@ -431,6 +445,7 @@ async def get_graph(
     version: int | None = None,
     template: bool = False,
     user_id: str | None = None,
+    hide_credentials: bool = False,
 ) -> Graph | None:
     """
     Retrieves a graph from the DB.
@@ -456,7 +471,7 @@ async def get_graph(
         include=AGENT_GRAPH_INCLUDE,
         order={"version": "desc"},
     )
-    return Graph.from_db(graph) if graph else None
+    return Graph.from_db(graph, hide_credentials) if graph else None


 async def set_graph_active_version(graph_id: str, version: int, user_id: str) -> None:
@@ -500,6 +515,15 @@ async def get_graph_all_versions(graph_id: str, user_id: str) -> list[Graph]:
     return [Graph.from_db(graph) for graph in graph_versions]


+async def delete_graph(graph_id: str, user_id: str) -> int:
+    entries_count = await AgentGraph.prisma().delete_many(
+        where={"id": graph_id, "userId": user_id}
+    )
+    if entries_count:
+        logger.info(f"Deleted {entries_count} graph entries for Graph #{graph_id}")
+    return entries_count
+
+
 async def create_graph(graph: Graph, user_id: str) -> Graph:
     async with transaction() as tx:
         await __create_graph(tx, graph, user_id)
@@ -576,30 +600,3 @@ async def __create_graph(tx, graph: Graph, user_id: str):
             for link in graph.links
         ]
     )
-
-
-# --------------------- Helper functions --------------------- #
-
-
-TEMPLATES_DIR = Path(__file__).parent.parent.parent / "graph_templates"
-
-
-async def import_packaged_templates() -> None:
-    templates_in_db = await get_graphs_meta(filter_by="template")
-
-    logging.info("Loading templates...")
-    for template_file in TEMPLATES_DIR.glob("*.json"):
-        template_data = json.loads(template_file.read_bytes())
-
-        template = Graph.model_validate(template_data)
-        if not template.is_template:
-            logging.warning(
-                f"pre-packaged graph file {template_file} is not a template"
-            )
-            continue
-        if (
-            exists := next((t for t in templates_in_db if t.id == template.id), None)
-        ) and exists.version >= template.version:
-            continue
-        await create_graph(template, DEFAULT_USER_ID)
-        logging.info(f"Loaded template '{template.name}' ({template.id})")
@@ -2,12 +2,14 @@ import json
 import logging
 from abc import ABC, abstractmethod
 from datetime import datetime
+from typing import Any, Generic, TypeVar

 from backend.data import redis
 from backend.data.execution import ExecutionResult

 logger = logging.getLogger(__name__)

+T = TypeVar("T")
+

 class DateTimeEncoder(json.JSONEncoder):
     def default(self, o):
@@ -17,14 +19,6 @@ class DateTimeEncoder(json.JSONEncoder):


 class AbstractEventQueue(ABC):
-    @abstractmethod
-    def connect(self):
-        pass
-
-    @abstractmethod
-    def close(self):
-        pass
-
     @abstractmethod
     def put(self, execution_result: ExecutionResult):
         pass
@@ -36,26 +30,41 @@ class AbstractEventQueue(ABC):

 class RedisEventQueue(AbstractEventQueue):
     def __init__(self):
-        self.connection = None
         self.queue_name = redis.QUEUE_NAME

-    def connect(self):
-        self.connection = redis.connect()
+    @property
+    def connection(self):
+        return redis.get_redis()

     def put(self, execution_result: ExecutionResult):
-        if self.connection:
-            message = json.dumps(execution_result.model_dump(), cls=DateTimeEncoder)
-            logger.info(f"Putting execution result to Redis {message}")
-            self.connection.lpush(self.queue_name, message)
+        message = json.dumps(execution_result.model_dump(), cls=DateTimeEncoder)
+        logger.info(f"Putting execution result to Redis {message}")
+        self.connection.lpush(self.queue_name, message)

     def get(self) -> ExecutionResult | None:
-        if self.connection:
-            message = self.connection.rpop(self.queue_name)
-            if message is not None and isinstance(message, (str, bytes, bytearray)):
-                data = json.loads(message)
-                logger.info(f"Getting execution result from Redis {data}")
-                return ExecutionResult(**data)
+        message = self.connection.rpop(self.queue_name)
+        if message is not None and isinstance(message, (str, bytes, bytearray)):
+            data = json.loads(message)
+            logger.info(f"Getting execution result from Redis {data}")
+            return ExecutionResult(**data)
         elif message is not None:
             logger.error(f"Failed to get execution result from Redis {message}")
         return None

-    def close(self):
-        redis.disconnect()
-
+
+class ExecutionQueue(Generic[T]):
+    def __init__(self, queue_name: str):
+        self.redis = redis.get_redis()
+        self.queue_name = queue_name
+
+    def add(self, item: T):
+        message = json.dumps(item.model_dump(), default=str)
+        self.redis.lpush(self.queue_name, message)
+
+    def get(self) -> T:
+        while True:
+            _, message = self.redis.brpop(self.queue_name)
+            return T.model_validate(json.loads(message))
+
+    def empty(self) -> bool:
+        return self.redis.llen(self.queue_name) == 0
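One caveat in the new `ExecutionQueue`: `get()` calls `T.model_validate(...)` on a TypeVar, which has no methods at runtime, so deserialization would fail once a message arrives. Below is a hedged sketch of one possible fix, passing the concrete Pydantic model class in explicitly; this is a workaround of mine, not code from the branch:

# Hypothetical fix, not from the branch: a TypeVar cannot deserialize, so the
# concrete Pydantic model class is supplied to the queue.
import json
from typing import Generic, Type, TypeVar

from pydantic import BaseModel

T = TypeVar("T", bound=BaseModel)

class TypedExecutionQueue(Generic[T]):
    def __init__(self, redis_conn, queue_name: str, model: Type[T]):
        self.redis = redis_conn
        self.queue_name = queue_name
        self.model = model  # concrete class used for deserialization

    def add(self, item: T) -> None:
        self.redis.lpush(self.queue_name, json.dumps(item.model_dump(), default=str))

    def get(self) -> T:
        _, message = self.redis.brpop(self.queue_name)  # blocks until available
        return self.model.model_validate(json.loads(message))

# Usage: TypedExecutionQueue(redis_conn, "node_execution_queue", NodeExecution)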
@@ -1,5 +1,5 @@
 from backend.app import run_processes
-from backend.executor import ExecutionManager
+from backend.executor import DatabaseManager, ExecutionManager


 def main():
@@ -7,6 +7,7 @@ def main():
     Run all the processes required for the AutoGPT-server REST API.
     """
     run_processes(
+        DatabaseManager(),
         ExecutionManager(),
     )
@@ -1,7 +1,9 @@
+from .database import DatabaseManager
 from .manager import ExecutionManager
 from .scheduler import ExecutionScheduler

 __all__ = [
+    "DatabaseManager",
     "ExecutionManager",
     "ExecutionScheduler",
 ]
autogpt_platform/backend/backend/executor/database.py (new file, 75 lines)
@@ -0,0 +1,75 @@
+from functools import wraps
+from typing import Any, Callable, Concatenate, Coroutine, ParamSpec, TypeVar, cast
+
+from backend.data.credit import get_user_credit_model
+from backend.data.execution import (
+    ExecutionResult,
+    create_graph_execution,
+    get_execution_results,
+    get_incomplete_executions,
+    get_latest_execution,
+    update_execution_status,
+    update_graph_execution_stats,
+    update_node_execution_stats,
+    upsert_execution_input,
+    upsert_execution_output,
+)
+from backend.data.graph import get_graph, get_node
+from backend.data.queue import RedisEventQueue
+from backend.util.service import AppService, expose
+from backend.util.settings import Config
+
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
+class DatabaseManager(AppService):
+
+    def __init__(self):
+        super().__init__(port=Config().database_api_port)
+        self.use_db = True
+        self.use_redis = True
+        self.event_queue = RedisEventQueue()
+
+    @expose
+    def send_execution_update(self, execution_result_dict: dict[Any, Any]):
+        self.event_queue.put(ExecutionResult(**execution_result_dict))
+
+    @staticmethod
+    def exposed_run_and_wait(
+        f: Callable[P, Coroutine[None, None, R]]
+    ) -> Callable[Concatenate[object, P], R]:
+        @expose
+        @wraps(f)
+        def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
+            coroutine = f(*args, **kwargs)
+            res = self.run_and_wait(coroutine)
+            return res
+
+        return wrapper
+
+    # Executions
+    create_graph_execution = exposed_run_and_wait(create_graph_execution)
+    get_execution_results = exposed_run_and_wait(get_execution_results)
+    get_incomplete_executions = exposed_run_and_wait(get_incomplete_executions)
+    get_latest_execution = exposed_run_and_wait(get_latest_execution)
+    update_execution_status = exposed_run_and_wait(update_execution_status)
+    update_graph_execution_stats = exposed_run_and_wait(update_graph_execution_stats)
+    update_node_execution_stats = exposed_run_and_wait(update_node_execution_stats)
+    upsert_execution_input = exposed_run_and_wait(upsert_execution_input)
+    upsert_execution_output = exposed_run_and_wait(upsert_execution_output)
+
+    # Graphs
+    get_node = exposed_run_and_wait(get_node)
+    get_graph = exposed_run_and_wait(get_graph)
+
+    # Credits
+    user_credit_model = get_user_credit_model()
+    get_or_refill_credit = cast(
+        Callable[[Any, str], int],
+        exposed_run_and_wait(user_credit_model.get_or_refill_credit),
+    )
+    spend_credits = cast(
+        Callable[[Any, str, int, str, dict[str, str], float, float], int],
+        exposed_run_and_wait(user_credit_model.spend_credits),
+    )
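The `exposed_run_and_wait` helper above turns module-level async DB functions into synchronous RPC endpoints on the service. A minimal self-contained sketch of the same wrap-and-wait idea, with the AppService/@expose machinery stubbed out (all names here are illustrative, not the real framework):

# Minimal sketch of the wrap-and-wait idea behind exposed_run_and_wait;
# MiniService stands in for AppService, and @expose is omitted.
import asyncio
from functools import wraps
from typing import Callable, Coroutine, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")

class MiniService:
    def __init__(self):
        self.loop = asyncio.new_event_loop()

    def run_and_wait(self, coro: Coroutine):
        return self.loop.run_until_complete(coro)

def exposed_run_and_wait(f: Callable[P, Coroutine[None, None, R]]):
    @wraps(f)
    def wrapper(self: MiniService, *args: P.args, **kwargs: P.kwargs) -> R:
        return self.run_and_wait(f(*args, **kwargs))  # sync facade over an async fn
    return wrapper

async def fetch_node(node_id: str) -> str:
    return f"node:{node_id}"

class MiniDatabaseManager(MiniService):
    get_node = exposed_run_and_wait(fetch_node)  # callable as a plain method

print(MiniDatabaseManager().get_node("abc"))  # -> node:abc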
@@ -1,4 +1,3 @@
-import asyncio
 import atexit
 import logging
 import multiprocessing
@@ -9,45 +8,40 @@ import threading
 from concurrent.futures import Future, ProcessPoolExecutor
 from contextlib import contextmanager
 from multiprocessing.pool import AsyncResult, Pool
-from typing import TYPE_CHECKING, Any, Coroutine, Generator, TypeVar, cast
+from typing import TYPE_CHECKING, Any, Generator, TypeVar, cast

 from autogpt_libs.supabase_integration_credentials_store.types import Credentials
 from pydantic import BaseModel
 from redis.lock import Lock as RedisLock

+from backend.data.queue import ExecutionQueue
+
 if TYPE_CHECKING:
-    from backend.server.rest_api import AgentServer
+    from backend.executor import DatabaseManager

-from backend.data import db, redis
+from backend.data import redis
 from backend.data.block import Block, BlockData, BlockInput, BlockType, get_block
-from backend.data.credit import get_user_credit_model
 from backend.data.execution import (
-    ExecutionQueue,
     ExecutionResult,
     ExecutionStatus,
     GraphExecution,
     NodeExecution,
-    create_graph_execution,
-    get_execution_results,
-    get_incomplete_executions,
-    get_latest_execution,
     merge_execution_input,
     parse_execution_output,
-    update_execution_status,
-    update_graph_execution_stats,
-    update_node_execution_stats,
-    upsert_execution_input,
-    upsert_execution_output,
 )
-from backend.data.graph import Graph, Link, Node, get_graph, get_node
+from backend.data.graph import Graph, Link, Node
+from backend.data.model import CREDENTIALS_FIELD_NAME, CredentialsMetaInput
+from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.util import json
+from backend.util.cache import thread_cached_property
 from backend.util.decorator import error_logged, time_measured
 from backend.util.logging import configure_logging
 from backend.util.process import set_service_name
 from backend.util.service import AppService, expose, get_service_client
-from backend.util.settings import Config
+from backend.util.settings import Settings
 from backend.util.type import convert

 logger = logging.getLogger(__name__)
+settings = Settings()


 class LogMetadata:
@@ -100,10 +94,9 @@ ExecutionStream = Generator[NodeExecution, None, None]


 def execute_node(
-    loop: asyncio.AbstractEventLoop,
-    api_client: "AgentServer",
+    db_client: "DatabaseManager",
+    creds_manager: IntegrationCredentialsManager,
     data: NodeExecution,
-    input_credentials: Credentials | None = None,
     execution_stats: dict[str, Any] | None = None,
 ) -> ExecutionStream:
     """
@@ -111,8 +104,7 @@ def execute_node(
     persist the execution result, and return the subsequent node to be executed.

     Args:
-        loop: The event loop to run the async functions.
-        api_client: The client to send execution updates to the server.
+        db_client: The client to send execution updates to the server.
         data: The execution data for executing the current node.
         execution_stats: The execution statistics to be updated.
@@ -125,17 +117,12 @@ def execute_node(
     node_exec_id = data.node_exec_id
     node_id = data.node_id

-    asyncio.set_event_loop(loop)
-
-    def wait(f: Coroutine[Any, Any, T]) -> T:
-        return loop.run_until_complete(f)
-
     def update_execution(status: ExecutionStatus) -> ExecutionResult:
-        exec_update = wait(update_execution_status(node_exec_id, status))
-        api_client.send_execution_update(exec_update.model_dump())
+        exec_update = db_client.update_execution_status(node_exec_id, status)
+        db_client.send_execution_update(exec_update.model_dump())
         return exec_update

-    node = wait(get_node(node_id))
+    node = db_client.get_node(node_id)

     node_block = get_block(node.block_id)
     if not node_block:
@@ -161,15 +148,21 @@ def execute_node(
     input_size = len(input_data_str)
     log_metadata.info("Executed node with input", input=input_data_str)
     update_execution(ExecutionStatus.RUNNING)
-    user_credit = get_user_credit_model()

     extra_exec_kwargs = {}
-    if input_credentials:
-        extra_exec_kwargs["credentials"] = input_credentials
+
+    # Last-minute fetch credentials + acquire a system-wide read-write lock to prevent
+    # changes during execution. ⚠️ This means a set of credentials can only be used by
+    # one (running) block at a time; simultaneous execution of blocks using same
+    # credentials is not supported.
+    credentials = creds_lock = None
+    if CREDENTIALS_FIELD_NAME in input_data:
+        credentials_meta = CredentialsMetaInput(**input_data[CREDENTIALS_FIELD_NAME])
+        credentials, creds_lock = creds_manager.acquire(user_id, credentials_meta.id)
+        extra_exec_kwargs["credentials"] = credentials

     output_size = 0
     try:
-        credit = wait(user_credit.get_or_refill_credit(user_id))
+        credit = db_client.get_or_refill_credit(user_id)
         if credit < 0:
             raise ValueError(f"Insufficient credit: {credit}")
@@ -178,11 +171,10 @@ def execute_node(
         ):
             output_size += len(json.dumps(output_data))
             log_metadata.info("Node produced output", output_name=output_data)
-            wait(upsert_execution_output(node_exec_id, output_name, output_data))
+            db_client.upsert_execution_output(node_exec_id, output_name, output_data)

             for execution in _enqueue_next_nodes(
-                api_client=api_client,
-                loop=loop,
+                db_client=db_client,
                 node=node,
                 output=(output_name, output_data),
                 user_id=user_id,
@@ -192,6 +184,10 @@ def execute_node(
             ):
                 yield execution

+            # Release lock on credentials ASAP
+            if creds_lock:
+                creds_lock.release()
+
             r = update_execution(ExecutionStatus.COMPLETED)
             s = input_size + output_size
             t = (
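The choreography in this hunk (acquire credentials before running the block, release as soon as downstream work is enqueued, and release again in `finally` as a safety net, as the next hunk shows) is worth seeing in isolation. A stand-in sketch under the assumption of an idempotent `release()`; the classes below are placeholders, not the real manager:

# Stand-in sketch of the acquire/early-release/finally pattern used above;
# FakeLock and FakeCredsManager are placeholders, not the real classes.
class FakeLock:
    def __init__(self) -> None:
        self.held = True
    def release(self) -> None:
        self.held = False  # idempotent by construction

class FakeCredsManager:
    def acquire(self, user_id: str, creds_id: str):
        return {"id": creds_id}, FakeLock()  # (credentials, lock), as in the diff

def run_node(creds_manager: FakeCredsManager, input_data: dict, user_id: str):
    credentials = creds_lock = None
    if "credentials" in input_data:
        credentials, creds_lock = creds_manager.acquire(
            user_id, input_data["credentials"]["id"]
        )
    try:
        yield "output", f"work done with {credentials}"
        # Release ASAP so other blocks waiting on the same credentials proceed.
        if creds_lock:
            creds_lock.release()
    finally:
        # Safety net: also runs on error paths and when the generator closes.
        if creds_lock:
            creds_lock.release()

print(list(run_node(FakeCredsManager(), {"credentials": {"id": "c1"}}, "u1")))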
@@ -199,35 +195,27 @@ def execute_node(
             if r.end_time and r.start_time
             else 0
         )
-        wait(user_credit.spend_credits(user_id, credit, node_block, input_data, s, t))
+        db_client.spend_credits(user_id, credit, node_block.id, input_data, s, t)

     except Exception as e:
         error_msg = str(e)
         log_metadata.exception(f"Node execution failed with error {error_msg}")
-        wait(upsert_execution_output(node_exec_id, "error", error_msg))
+        db_client.upsert_execution_output(node_exec_id, "error", error_msg)
         update_execution(ExecutionStatus.FAILED)

         raise e

     finally:
+        # Ensure credentials are released even if execution fails
+        if creds_lock:
+            creds_lock.release()
         if execution_stats is not None:
             execution_stats["input_size"] = input_size
             execution_stats["output_size"] = output_size


-@contextmanager
-def synchronized(key: str, timeout: int = 60):
-    lock = redis.get_redis().lock(f"lock:{key}", timeout=timeout)
-    try:
-        lock.acquire()
-        yield
-    finally:
-        lock.release()
-
-
 def _enqueue_next_nodes(
-    api_client: "AgentServer",
-    loop: asyncio.AbstractEventLoop,
+    db_client: "DatabaseManager",
     node: Node,
     output: BlockData,
     user_id: str,
@@ -235,16 +223,14 @@ def _enqueue_next_nodes(
     graph_id: str,
     log_metadata: LogMetadata,
 ) -> list[NodeExecution]:
-    def wait(f: Coroutine[Any, Any, T]) -> T:
-        return loop.run_until_complete(f)
-
     def add_enqueued_execution(
         node_exec_id: str, node_id: str, data: BlockInput
     ) -> NodeExecution:
-        exec_update = wait(
-            update_execution_status(node_exec_id, ExecutionStatus.QUEUED, data)
+        exec_update = db_client.update_execution_status(
+            node_exec_id, ExecutionStatus.QUEUED, data
         )
-        api_client.send_execution_update(exec_update.model_dump())
+        db_client.send_execution_update(exec_update.model_dump())
         return NodeExecution(
             user_id=user_id,
             graph_exec_id=graph_exec_id,
@@ -264,20 +250,18 @@ def _enqueue_next_nodes(
         if next_data is None:
             return enqueued_executions

-        next_node = wait(get_node(next_node_id))
+        next_node = db_client.get_node(next_node_id)

         # Multiple node can register the same next node, we need this to be atomic
         # To avoid same execution to be enqueued multiple times,
         # Or the same input to be consumed multiple times.
         with synchronized(f"upsert_input-{next_node_id}-{graph_exec_id}"):
             # Add output data to the earliest incomplete execution, or create a new one.
-            next_node_exec_id, next_node_input = wait(
-                upsert_execution_input(
-                    node_id=next_node_id,
-                    graph_exec_id=graph_exec_id,
-                    input_name=next_input_name,
-                    input_data=next_data,
-                )
+            next_node_exec_id, next_node_input = db_client.upsert_execution_input(
+                node_id=next_node_id,
+                graph_exec_id=graph_exec_id,
+                input_name=next_input_name,
+                input_data=next_data,
             )

             # Complete missing static input pins data using the last execution input.
@@ -287,8 +271,8 @@ def _enqueue_next_nodes(
                 if link.is_static and link.sink_name not in next_node_input
             }
             if static_link_names and (
-                latest_execution := wait(
-                    get_latest_execution(next_node_id, graph_exec_id)
+                latest_execution := db_client.get_latest_execution(
+                    next_node_id, graph_exec_id
                 )
             ):
                 for name in static_link_names:
@@ -315,7 +299,9 @@ def _enqueue_next_nodes(

             # If link is static, there could be some incomplete executions waiting for it.
             # Load and complete the input missing input data, and try to re-enqueue them.
-            for iexec in wait(get_incomplete_executions(next_node_id, graph_exec_id)):
+            for iexec in db_client.get_incomplete_executions(
+                next_node_id, graph_exec_id
+            ):
                 idata = iexec.input_data
                 ineid = iexec.node_exec_id
@@ -400,12 +386,6 @@ def validate_exec(
     return data, node_block.name


-def get_agent_server_client() -> "AgentServer":
-    from backend.server.rest_api import AgentServer
-
-    return get_service_client(AgentServer, Config().agent_server_port)
-
-
 class Executor:
     """
     This class contains event handlers for the process pool executor events.
@@ -434,13 +414,12 @@ class Executor:
     @classmethod
     def on_node_executor_start(cls):
         configure_logging()
-
-        cls.loop = asyncio.new_event_loop()
-        cls.pid = os.getpid()
-
         set_service_name("NodeExecutor")
         redis.connect()
-        cls.loop.run_until_complete(db.connect())
-        cls.agent_server_client = get_agent_server_client()
+        cls.node_queue = ExecutionQueue[NodeExecution]("node_execution_queue")
+        cls.pid = os.getpid()
+        cls.db_client = get_db_client()
+        cls.creds_manager = IntegrationCredentialsManager()

         # Set up shutdown handlers
         cls.shutdown_lock = threading.Lock()
@@ -454,8 +433,8 @@ class Executor:
         if not cls.shutdown_lock.acquire(blocking=False):
             return  # already shutting down

-        logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting DB...")
-        cls.loop.run_until_complete(db.disconnect())
+        logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Releasing locks...")
+        cls.creds_manager.release_all_locks()
         logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
         redis.disconnect()
         logger.info(f"[on_node_executor_stop {cls.pid}] ✅ Finished cleanup")
@@ -464,20 +443,20 @@ class Executor:
     def on_node_executor_sigterm(cls):
         llprint(f"[on_node_executor_sigterm {cls.pid}] ⚠️ SIGTERM received")
         if not cls.shutdown_lock.acquire(blocking=False):
-            return  # already shutting down, no need to self-terminate
+            return  # already shutting down

-        llprint(f"[on_node_executor_sigterm {cls.pid}] ⏳ Disconnecting DB...")
-        cls.loop.run_until_complete(db.disconnect())
-        llprint(f"[on_node_executor_sigterm {cls.pid}] ✅ Finished cleanup")
+        llprint(f"[on_node_executor_stop {cls.pid}] ⏳ Releasing locks...")
+        cls.creds_manager.release_all_locks()
+        llprint(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
+        redis.disconnect()
+        llprint(f"[on_node_executor_stop {cls.pid}] ✅ Finished cleanup")
         sys.exit(0)

     @classmethod
     @error_logged
     def on_node_execution(
         cls,
-        q: ExecutionQueue[NodeExecution],
         node_exec: NodeExecution,
-        input_credentials: Credentials | None,
     ):
         log_metadata = LogMetadata(
             user_id=node_exec.user_id,
@@ -487,34 +466,32 @@ class Executor:
             node_id=node_exec.node_id,
             block_name="-",
         )

+        q = cls.node_queue
         execution_stats = {}
         timing_info, _ = cls._on_node_execution(
-            q, node_exec, input_credentials, log_metadata, execution_stats
+            q, node_exec, log_metadata, execution_stats
        )
         execution_stats["walltime"] = timing_info.wall_time
         execution_stats["cputime"] = timing_info.cpu_time

-        cls.loop.run_until_complete(
-            update_node_execution_stats(node_exec.node_exec_id, execution_stats)
+        cls.db_client.update_node_execution_stats(
+            node_exec.node_exec_id, execution_stats
         )

     @classmethod
     @time_measured
     def _on_node_execution(
         cls,
         q: ExecutionQueue[NodeExecution],
         node_exec: NodeExecution,
-        input_credentials: Credentials | None,
         log_metadata: LogMetadata,
         stats: dict[str, Any] | None = None,
     ):
         try:
             log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
             for execution in execute_node(
-                cls.loop, cls.agent_server_client, node_exec, input_credentials, stats
+                cls.db_client, cls.creds_manager, node_exec, stats
             ):
-                q.add(execution)
+                cls.node_queue.add(execution)
             log_metadata.info(f"Finished node execution {node_exec.node_exec_id}")
         except Exception as e:
             log_metadata.exception(
@@ -524,12 +501,11 @@ class Executor:
     @classmethod
     def on_graph_executor_start(cls):
         configure_logging()
         set_service_name("GraphExecutor")
-
-        cls.pool_size = Config().num_node_workers
-        cls.loop = asyncio.new_event_loop()
+        cls.db_client = get_db_client()
+        cls.pool_size = settings.config.num_node_workers
         cls.pid = os.getpid()
-
-        cls.loop.run_until_complete(db.connect())
         cls._init_node_executor_pool()
         logger.info(
             f"Graph executor {cls.pid} started with {cls.pool_size} node workers"
@@ -541,8 +517,6 @@ class Executor:
     @classmethod
     def on_graph_executor_stop(cls):
         prefix = f"[on_graph_executor_stop {cls.pid}]"
-        logger.info(f"{prefix} ⏳ Disconnecting DB...")
-        cls.loop.run_until_complete(db.disconnect())
         logger.info(f"{prefix} ⏳ Terminating node executor pool...")
         cls.executor.terminate()
         logger.info(f"{prefix} ✅ Finished cleanup")
@@ -569,14 +543,12 @@ class Executor:
             graph_exec, cancel, log_metadata
         )

-        cls.loop.run_until_complete(
-            update_graph_execution_stats(
-                graph_exec_id=graph_exec.graph_exec_id,
-                error=error,
-                wall_time=timing_info.wall_time,
-                cpu_time=timing_info.cpu_time,
-                node_count=node_count,
-            )
+        cls.db_client.update_graph_execution_stats(
+            graph_exec_id=graph_exec.graph_exec_id,
+            error=error,
+            wall_time=timing_info.wall_time,
+            cpu_time=timing_info.cpu_time,
+            node_count=node_count,
        )

     @classmethod
@@ -610,7 +582,7 @@ class Executor:
         cancel_thread.start()

         try:
-            queue = ExecutionQueue[NodeExecution]()
+            queue = ExecutionQueue[NodeExecution]("node_execution_queue")
             for node_exec in graph_exec.start_node_execs:
                 queue.add(node_exec)
@@ -648,11 +620,7 @@ class Executor:
             )
             running_executions[exec_data.node_id] = cls.executor.apply_async(
                 cls.on_node_execution,
-                (
-                    queue,
-                    exec_data,
-                    graph_exec.node_input_credentials.get(exec_data.node_id),
-                ),
+                (exec_data,),
                 callback=make_exec_callback(exec_data),
             )
@@ -687,12 +655,13 @@ class Executor:


 class ExecutionManager(AppService):

     def __init__(self):
-        super().__init__(port=Config().execution_manager_port)
-        self.use_db = True
+        super().__init__(port=settings.config.execution_manager_port)
         self.use_redis = True
         self.use_supabase = True
-        self.pool_size = Config().num_graph_workers
-        self.queue = ExecutionQueue[GraphExecution]()
+        self.pool_size = settings.config.num_graph_workers
+        self.queue = ExecutionQueue[GraphExecution]("graph_execution_queue")
         self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}

     def run_service(self):
@@ -700,7 +669,9 @@ class ExecutionManager(AppService):
             SupabaseIntegrationCredentialsStore,
         )

-        self.credentials_store = SupabaseIntegrationCredentialsStore(self.supabase)
+        self.credentials_store = SupabaseIntegrationCredentialsStore(
+            self.supabase, redis.get_redis()
+        )
         self.executor = ProcessPoolExecutor(
             max_workers=self.pool_size,
             initializer=Executor.on_graph_executor_start,
@@ -730,20 +701,20 @@ class ExecutionManager(AppService):

         super().cleanup()

-    @property
-    def agent_server_client(self) -> "AgentServer":
-        return get_agent_server_client()
+    @thread_cached_property
+    def db_client(self) -> "DatabaseManager":
+        return get_db_client()

     @expose
     def add_execution(
         self, graph_id: str, data: BlockInput, user_id: str
     ) -> dict[str, Any]:
-        graph: Graph | None = self.run_and_wait(get_graph(graph_id, user_id=user_id))
+        graph: Graph | None = self.db_client.get_graph(graph_id, user_id=user_id)
         if not graph:
             raise Exception(f"Graph #{graph_id} not found.")

         graph.validate_graph(for_run=True)
-        node_input_credentials = self._get_node_input_credentials(graph, user_id)
+        self._validate_node_input_credentials(graph, user_id)

         nodes_input = []
         for node in graph.starting_nodes:
@@ -766,13 +737,11 @@ class ExecutionManager(AppService):
             else:
                 nodes_input.append((node.id, input_data))

-        graph_exec_id, node_execs = self.run_and_wait(
-            create_graph_execution(
-                graph_id=graph_id,
-                graph_version=graph.version,
-                nodes_input=nodes_input,
-                user_id=user_id,
-            )
+        graph_exec_id, node_execs = self.db_client.create_graph_execution(
+            graph_id=graph_id,
+            graph_version=graph.version,
+            nodes_input=nodes_input,
+            user_id=user_id,
         )

         starting_node_execs = []
@@ -787,19 +756,16 @@ class ExecutionManager(AppService):
                     data=node_exec.input_data,
                 )
             )
-            exec_update = self.run_and_wait(
-                update_execution_status(
-                    node_exec.node_exec_id, ExecutionStatus.QUEUED, node_exec.input_data
-                )
+            exec_update = self.db_client.update_execution_status(
+                node_exec.node_exec_id, ExecutionStatus.QUEUED, node_exec.input_data
             )
-            self.agent_server_client.send_execution_update(exec_update.model_dump())
+            self.db_client.send_execution_update(exec_update.model_dump())

         graph_exec = GraphExecution(
             user_id=user_id,
             graph_id=graph_id,
             graph_exec_id=graph_exec_id,
             start_node_execs=starting_node_execs,
-            node_input_credentials=node_input_credentials,
         )
         self.queue.add(graph_exec)
@@ -828,30 +794,22 @@ class ExecutionManager(AppService):
             future.result()

         # Update the status of the unfinished node executions
-        node_execs = self.run_and_wait(get_execution_results(graph_exec_id))
+        node_execs = self.db_client.get_execution_results(graph_exec_id)
         for node_exec in node_execs:
             if node_exec.status not in (
                 ExecutionStatus.COMPLETED,
                 ExecutionStatus.FAILED,
             ):
-                self.run_and_wait(
-                    upsert_execution_output(
-                        node_exec.node_exec_id, "error", "TERMINATED"
-                    )
+                self.db_client.upsert_execution_output(
+                    node_exec.node_exec_id, "error", "TERMINATED"
                 )
-                exec_update = self.run_and_wait(
-                    update_execution_status(
-                        node_exec.node_exec_id, ExecutionStatus.FAILED
-                    )
+                exec_update = self.db_client.update_execution_status(
+                    node_exec.node_exec_id, ExecutionStatus.FAILED
                 )
-                self.agent_server_client.send_execution_update(exec_update.model_dump())
+                self.db_client.send_execution_update(exec_update.model_dump())

-    def _get_node_input_credentials(
-        self, graph: Graph, user_id: str
-    ) -> dict[str, Credentials]:
-        """Gets all credentials for all nodes of the graph"""
-
-        node_credentials: dict[str, Credentials] = {}
+    def _validate_node_input_credentials(self, graph: Graph, user_id: str):
+        """Checks all credentials for all nodes of the graph"""

         for node in graph.nodes:
             block = get_block(node.block_id)
@@ -894,9 +852,25 @@ class ExecutionManager(AppService):
                     f"Invalid credentials #{credentials.id} for node #{node.id}: "
                     "type/provider mismatch"
                 )
-            node_credentials[node.id] = credentials
-
-        return node_credentials


 # ------- UTILITIES ------- #


+def get_db_client() -> "DatabaseManager":
+    from backend.executor import DatabaseManager
+
+    return get_service_client(DatabaseManager, settings.config.database_api_port)
+
+
+@contextmanager
+def synchronized(key: str, timeout: int = 60):
+    lock: RedisLock = redis.get_redis().lock(f"lock:{key}", timeout=timeout)
+    try:
+        lock.acquire()
+        yield
+    finally:
+        lock.release()


 def llprint(message: str):
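The `synchronized` helper relocated here is a thin wrapper over redis-py's `Lock`. A self-contained usage illustration, assuming a reachable Redis at localhost:6379 (the connection details are only for the demo):

# Illustration of the synchronized() helper above; assumes a local Redis.
from contextlib import contextmanager

import redis

r = redis.Redis(host="localhost", port=6379)

@contextmanager
def synchronized(key: str, timeout: int = 60):
    # One lock key per critical section; the timeout guards against a
    # crashed holder never releasing the lock.
    lock = r.lock(f"lock:{key}", timeout=timeout)
    try:
        lock.acquire()
        yield
    finally:
        lock.release()

# Only one process at a time may upsert input for a given (node, graph run):
with synchronized("upsert_input-node1-exec1"):
    pass  # read-modify-write of shared execution input goes here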
@@ -5,9 +5,16 @@ from datetime import datetime
 from apscheduler.schedulers.background import BackgroundScheduler
 from apscheduler.triggers.cron import CronTrigger

-from backend.data import schedule as model
 from backend.data.block import BlockInput
+from backend.data.schedule import (
+    ExecutionSchedule,
+    add_schedule,
+    get_active_schedules,
+    get_schedules,
+    update_schedule,
+)
 from backend.executor.manager import ExecutionManager
+from backend.util.cache import thread_cached_property
 from backend.util.service import AppService, expose, get_service_client
 from backend.util.settings import Config
@@ -19,14 +26,15 @@ def log(msg, **kwargs):


 class ExecutionScheduler(AppService):

     def __init__(self, refresh_interval=10):
         super().__init__(port=Config().execution_scheduler_port)
+        self.use_db = True
         self.last_check = datetime.min
         self.refresh_interval = refresh_interval

-    @property
-    def execution_manager_client(self) -> ExecutionManager:
+    @thread_cached_property
+    def execution_client(self) -> ExecutionManager:
         return get_service_client(ExecutionManager, Config().execution_manager_port)

     def run_service(self):
@@ -37,7 +45,7 @@ class ExecutionScheduler(AppService):
             time.sleep(self.refresh_interval)

     def __refresh_jobs_from_db(self, scheduler: BackgroundScheduler):
-        schedules = self.run_and_wait(model.get_active_schedules(self.last_check))
+        schedules = self.run_and_wait(get_active_schedules(self.last_check))
         for schedule in schedules:
             if schedule.last_updated:
                 self.last_check = max(self.last_check, schedule.last_updated)
@@ -59,14 +67,13 @@ class ExecutionScheduler(AppService):
     def __execute_graph(self, graph_id: str, input_data: dict, user_id: str):
         try:
             log(f"Executing recurring job for graph #{graph_id}")
-            execution_manager = self.execution_manager_client
-            execution_manager.add_execution(graph_id, input_data, user_id)
+            self.execution_client.add_execution(graph_id, input_data, user_id)
         except Exception as e:
             logger.exception(f"Error executing graph {graph_id}: {e}")

     @expose
     def update_schedule(self, schedule_id: str, is_enabled: bool, user_id: str) -> str:
-        self.run_and_wait(model.update_schedule(schedule_id, is_enabled, user_id))
+        self.run_and_wait(update_schedule(schedule_id, is_enabled, user_id))
         return schedule_id

     @expose
@@ -78,17 +85,16 @@ class ExecutionScheduler(AppService):
         input_data: BlockInput,
         user_id: str,
     ) -> str:
-        schedule = model.ExecutionSchedule(
+        schedule = ExecutionSchedule(
             graph_id=graph_id,
             user_id=user_id,
             graph_version=graph_version,
             schedule=cron,
             input_data=input_data,
         )
-        return self.run_and_wait(model.add_schedule(schedule)).id
+        return self.run_and_wait(add_schedule(schedule)).id

     @expose
     def get_execution_schedules(self, graph_id: str, user_id: str) -> dict[str, str]:
-        query = model.get_schedules(graph_id, user_id=user_id)
-        schedules: list[model.ExecutionSchedule] = self.run_and_wait(query)
+        schedules = self.run_and_wait(get_schedules(graph_id, user_id=user_id))
         return {v.id: v.schedule for v in schedules}
|
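Since the scheduler's `@expose`d methods are served over Pyro, other processes talk to it through a service client. A hedged sketch, using only the methods shown above (the graph and user ids are placeholders):

```python
# Hedged sketch: calling the scheduler from another process via the
# Pyro-based service client, mirroring how the REST API and tests do it.
from backend.executor import ExecutionScheduler
from backend.util.service import get_service_client
from backend.util.settings import Config

scheduler = get_service_client(ExecutionScheduler, Config().execution_scheduler_port)

schedules = scheduler.get_execution_schedules("my-graph-id", "my-user-id")
for schedule_id in schedules:
    # Disable every schedule for this (placeholder) graph
    scheduler.update_schedule(schedule_id, is_enabled=False, user_id="my-user-id")
```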

autogpt_platform/backend/backend/integrations/creds_manager.py (new file, 172 lines)
@@ -0,0 +1,172 @@
import logging
from contextlib import contextmanager
from datetime import datetime

from autogpt_libs.supabase_integration_credentials_store import (
    Credentials,
    SupabaseIntegrationCredentialsStore,
)
from autogpt_libs.utils.synchronize import RedisKeyedMutex
from redis.lock import Lock as RedisLock

from backend.data import redis
from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
from backend.util.settings import Settings

from ..server.integrations.utils import get_supabase

logger = logging.getLogger(__name__)
settings = Settings()


class IntegrationCredentialsManager:
    """
    Handles the lifecycle of integration credentials.
    - Automatically refreshes requested credentials if needed.
    - Uses locking mechanisms to ensure system-wide consistency and
      prevent invalidation of in-use tokens.

    ### ⚠️ Gotcha
    With `acquire(..)`, credentials can only be in use in one place at a time (e.g. one
    block execution).

    ### Locking mechanism
    - Because *getting* credentials can result in a refresh (= *invalidation* +
      *replacement*) of the stored credentials, *getting* is an operation that
      potentially requires read/write access.
    - Checking whether a token has to be refreshed is subject to an additional `refresh`
      scoped lock to prevent unnecessary sequential refreshes when multiple executions
      try to access the same credentials simultaneously.
    - We MUST lock credentials while in use to prevent them from being invalidated while
      they are in use, e.g. because they are being refreshed by a different part
      of the system.
    - The `!time_sensitive` lock in `acquire(..)` is part of a two-tier locking
      mechanism in which *updating* gets priority over *getting* credentials.
      This is to prevent a long queue of waiting *get* requests from blocking essential
      credential refreshes or user-initiated updates.

    It is possible to implement a reader/writer locking system where either multiple
    readers or a single writer can have simultaneous access, but this would add a lot of
    complexity to the mechanism. I don't expect the current ("simple") mechanism to
    cause so much latency that it's worth implementing.
    """

    def __init__(self):
        redis_conn = redis.get_redis()
        self._locks = RedisKeyedMutex(redis_conn)
        self.store = SupabaseIntegrationCredentialsStore(get_supabase(), redis_conn)

    def create(self, user_id: str, credentials: Credentials) -> None:
        return self.store.add_creds(user_id, credentials)

    def exists(self, user_id: str, credentials_id: str) -> bool:
        return self.store.get_creds_by_id(user_id, credentials_id) is not None

    def get(
        self, user_id: str, credentials_id: str, lock: bool = True
    ) -> Credentials | None:
        credentials = self.store.get_creds_by_id(user_id, credentials_id)
        if not credentials:
            return None

        # Refresh OAuth credentials if needed
        if credentials.type == "oauth2" and credentials.access_token_expires_at:
            logger.debug(
                f"Credentials #{credentials.id} expire at "
                f"{datetime.fromtimestamp(credentials.access_token_expires_at)}; "
                f"current time is {datetime.now()}"
            )

            with self._locked(user_id, credentials_id, "refresh"):
                oauth_handler = _get_provider_oauth_handler(credentials.provider)
                if oauth_handler.needs_refresh(credentials):
                    logger.debug(
                        f"Refreshing '{credentials.provider}' "
                        f"credentials #{credentials.id}"
                    )
                    _lock = None
                    if lock:
                        # Wait until the credentials are no longer in use anywhere
                        _lock = self._acquire_lock(user_id, credentials_id)

                    fresh_credentials = oauth_handler.refresh_tokens(credentials)
                    self.store.update_creds(user_id, fresh_credentials)
                    if _lock:
                        _lock.release()

                    credentials = fresh_credentials
        else:
            logger.debug(f"Credentials #{credentials.id} never expire")

        return credentials

    def acquire(
        self, user_id: str, credentials_id: str
    ) -> tuple[Credentials, RedisLock]:
        """
        ⚠️ WARNING: this locks credentials system-wide and blocks both acquiring
        and updating them elsewhere until the lock is released.
        See the class docstring for more info.
        """
        # Use a low-priority (!time_sensitive) locking queue on top of the general lock
        # to allow priority access for refreshing/updating the tokens.
        with self._locked(user_id, credentials_id, "!time_sensitive"):
            lock = self._acquire_lock(user_id, credentials_id)
        credentials = self.get(user_id, credentials_id, lock=False)
        if not credentials:
            raise ValueError(
                f"Credentials #{credentials_id} for user #{user_id} not found"
            )
        return credentials, lock

    def update(self, user_id: str, updated: Credentials) -> None:
        with self._locked(user_id, updated.id):
            self.store.update_creds(user_id, updated)

    def delete(self, user_id: str, credentials_id: str) -> None:
        with self._locked(user_id, credentials_id):
            self.store.delete_creds_by_id(user_id, credentials_id)

    # -- Locking utilities -- #

    def _acquire_lock(self, user_id: str, credentials_id: str, *args: str) -> RedisLock:
        key = (
            self.store.supabase.supabase_url,
            f"user:{user_id}",
            f"credentials:{credentials_id}",
            *args,
        )
        return self._locks.acquire(key)

    @contextmanager
    def _locked(self, user_id: str, credentials_id: str, *args: str):
        lock = self._acquire_lock(user_id, credentials_id, *args)
        try:
            yield
        finally:
            lock.release()

    def release_all_locks(self):
        """Call this on process termination to ensure all locks are released"""
        self._locks.release_all_locks()
        self.store.locks.release_all_locks()


def _get_provider_oauth_handler(provider_name: str) -> BaseOAuthHandler:
    if provider_name not in HANDLERS_BY_NAME:
        raise KeyError(f"Unknown provider '{provider_name}'")

    client_id = getattr(settings.secrets, f"{provider_name}_client_id")
    client_secret = getattr(settings.secrets, f"{provider_name}_client_secret")
    if not (client_id and client_secret):
        raise Exception(  # TODO: ConfigError
            f"Integration with provider '{provider_name}' is not configured",
        )

    handler_class = HANDLERS_BY_NAME[provider_name]
    frontend_base_url = settings.config.frontend_base_url
    return handler_class(
        client_id=client_id,
        client_secret=client_secret,
        redirect_uri=f"{frontend_base_url}/auth/integrations/oauth_callback",
    )
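A hedged sketch of the intended call pattern for `acquire(..)`: the caller receives the credentials together with the system-wide lock and must release the lock when done (the ids and the API call are placeholders):

```python
# Hypothetical caller of IntegrationCredentialsManager.acquire().
manager = IntegrationCredentialsManager()

credentials, lock = manager.acquire("user-123", "cred-456")
try:
    # While the lock is held, the credentials cannot be refreshed or
    # invalidated by any other part of the system.
    call_provider_api(credentials)  # assumed helper, for illustration
finally:
    lock.release()  # otherwise gets/updates elsewhere stay blocked until timeout
```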
@@ -3,6 +3,7 @@ from .github import GitHubOAuthHandler
from .google import GoogleOAuthHandler
from .notion import NotionOAuthHandler

# --8<-- [start:HANDLERS_BY_NAMEExample]
HANDLERS_BY_NAME: dict[str, type[BaseOAuthHandler]] = {
    handler.PROVIDER_NAME: handler
    for handler in [
@@ -11,5 +12,6 @@ HANDLERS_BY_NAME: dict[str, type[BaseOAuthHandler]] = {
        NotionOAuthHandler,
    ]
}
# --8<-- [end:HANDLERS_BY_NAMEExample]

__all__ = ["HANDLERS_BY_NAME"]
@@ -9,29 +9,48 @@ logger = logging.getLogger(__name__)


class BaseOAuthHandler(ABC):
    # --8<-- [start:BaseOAuthHandler1]
    PROVIDER_NAME: ClassVar[str]
    DEFAULT_SCOPES: ClassVar[list[str]] = []
    # --8<-- [end:BaseOAuthHandler1]

    @abstractmethod
    # --8<-- [start:BaseOAuthHandler2]
    def __init__(self, client_id: str, client_secret: str, redirect_uri: str): ...

    # --8<-- [end:BaseOAuthHandler2]

    @abstractmethod
    # --8<-- [start:BaseOAuthHandler3]
    def get_login_url(self, scopes: list[str], state: str) -> str:
        # --8<-- [end:BaseOAuthHandler3]
        """Constructs a login URL that the user can be redirected to"""
        ...

    @abstractmethod
    # --8<-- [start:BaseOAuthHandler4]
    def exchange_code_for_tokens(
        self, code: str, scopes: list[str]
    ) -> OAuth2Credentials:
        # --8<-- [end:BaseOAuthHandler4]
        """Exchanges the acquired authorization code from login for a set of tokens"""
        ...

    @abstractmethod
    # --8<-- [start:BaseOAuthHandler5]
    def _refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
        # --8<-- [end:BaseOAuthHandler5]
        """Implements the token refresh mechanism"""
        ...

    @abstractmethod
    # --8<-- [start:BaseOAuthHandler6]
    def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
        # --8<-- [end:BaseOAuthHandler6]
        """Revokes the given token at the provider;
        returns False if the provider does not support it"""
        ...

    def refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
        if credentials.provider != self.PROVIDER_NAME:
            raise ValueError(
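A hedged sketch of what a new provider handler would look like against the interface above (the provider name and URLs are hypothetical; the real handlers follow in this diff):

```python
# Hypothetical provider handler implementing BaseOAuthHandler above.
class ExampleOAuthHandler(BaseOAuthHandler):
    PROVIDER_NAME = "example"      # placeholder provider name
    DEFAULT_SCOPES = ["profile"]

    def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
        self.client_id = client_id
        self.client_secret = client_secret
        self.redirect_uri = redirect_uri

    def get_login_url(self, scopes: list[str], state: str) -> str:
        return f"https://auth.example.test/authorize?state={state}"

    def exchange_code_for_tokens(self, code: str, scopes: list[str]) -> OAuth2Credentials:
        ...  # POST `code` to the token endpoint, build OAuth2Credentials

    def _refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
        ...  # POST the refresh token to the token endpoint

    def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
        return False  # pretend the provider has no revocation endpoint
```

To be usable, such a handler would also be registered in `HANDLERS_BY_NAME` and given matching `example_client_id`/`example_client_secret` fields in `Secrets`, per `_get_provider_oauth_handler` above.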
@@ -8,6 +8,7 @@ from autogpt_libs.supabase_integration_credentials_store import OAuth2Credential

from .base import BaseOAuthHandler


# --8<-- [start:GithubOAuthHandlerExample]
class GitHubOAuthHandler(BaseOAuthHandler):
    """
    Based on the documentation at:
@@ -23,7 +24,6 @@ class GitHubOAuthHandler(BaseOAuthHandler):
    """  # noqa

    PROVIDER_NAME = "github"
    EMAIL_ENDPOINT = "https://api.github.com/user/emails"

    def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
        self.client_id = client_id
@@ -31,6 +31,7 @@ class GitHubOAuthHandler(BaseOAuthHandler):
        self.redirect_uri = redirect_uri
        self.auth_base_url = "https://github.com/login/oauth/authorize"
        self.token_url = "https://github.com/login/oauth/access_token"
        self.revoke_url = "https://api.github.com/applications/{client_id}/token"

    def get_login_url(self, scopes: list[str], state: str) -> str:
        params = {
@@ -46,6 +47,24 @@ class GitHubOAuthHandler(BaseOAuthHandler):
    ) -> OAuth2Credentials:
        return self._request_tokens({"code": code, "redirect_uri": self.redirect_uri})

    def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
        if not credentials.access_token:
            raise ValueError("No access token to revoke")

        headers = {
            "Accept": "application/vnd.github+json",
            "X-GitHub-Api-Version": "2022-11-28",
        }

        response = requests.delete(
            url=self.revoke_url.format(client_id=self.client_id),
            auth=(self.client_id, self.client_secret),
            headers=headers,
            json={"access_token": credentials.access_token.get_secret_value()},
        )
        response.raise_for_status()
        return True

    def _refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
        if not credentials.refresh_token:
            return credentials
@@ -119,3 +138,6 @@ class GitHubOAuthHandler(BaseOAuthHandler):

    # Get the login (username)
    return response.json().get("login")


# --8<-- [end:GithubOAuthHandlerExample]
@@ -14,6 +14,7 @@ from .base import BaseOAuthHandler

logger = logging.getLogger(__name__)


# --8<-- [start:GoogleOAuthHandlerExample]
class GoogleOAuthHandler(BaseOAuthHandler):
    """
    Based on the documentation at https://developers.google.com/identity/protocols/oauth2/web-server
@@ -26,12 +27,14 @@ class GoogleOAuthHandler(BaseOAuthHandler):
        "https://www.googleapis.com/auth/userinfo.profile",
        "openid",
    ]
    # --8<-- [end:GoogleOAuthHandlerExample]

    def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
        self.client_id = client_id
        self.client_secret = client_secret
        self.redirect_uri = redirect_uri
        self.token_uri = "https://oauth2.googleapis.com/token"
        self.revoke_uri = "https://oauth2.googleapis.com/revoke"

    def get_login_url(self, scopes: list[str], state: str) -> str:
        all_scopes = list(set(scopes + self.DEFAULT_SCOPES))
@@ -98,6 +101,16 @@ class GoogleOAuthHandler(BaseOAuthHandler):

        return credentials

    def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
        session = AuthorizedSession(credentials)
        response = session.post(
            self.revoke_uri,
            params={"token": credentials.access_token.get_secret_value()},
            headers={"content-type": "application/x-www-form-urlencoded"},
        )
        response.raise_for_status()
        return True

    def _request_email(
        self, creds: Credentials | ExternalAccountCredentials
    ) -> str | None:
@@ -77,6 +77,10 @@ class NotionOAuthHandler(BaseOAuthHandler):
            },
        )

    def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
        # Notion doesn't support token revocation
        return False

    def _refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
        # Notion doesn't support token refresh
        return credentials
@@ -1,6 +1,6 @@
from backend.app import run_processes
from backend.executor import ExecutionScheduler
from backend.server import AgentServer
from backend.server.rest_api import AgentServer


def main():

@@ -1,4 +0,0 @@
from .rest_api import AgentServer
from .ws_api import WebsocketServer

__all__ = ["AgentServer", "WebsocketServer"]
@@ -1,40 +1,25 @@
import logging
from typing import Annotated
from typing import Annotated, Literal

from autogpt_libs.supabase_integration_credentials_store import (
    SupabaseIntegrationCredentialsStore,
)
from autogpt_libs.supabase_integration_credentials_store.types import (
    APIKeyCredentials,
    Credentials,
    CredentialsType,
    OAuth2Credentials,
)
from fastapi import (
    APIRouter,
    Body,
    Depends,
    HTTPException,
    Path,
    Query,
    Request,
    Response,
)
from pydantic import BaseModel, SecretStr
from supabase import Client
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
from pydantic import BaseModel, Field, SecretStr

from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
from backend.util.settings import Settings

from ..utils import get_supabase, get_user_id
from ..utils import get_user_id

logger = logging.getLogger(__name__)
settings = Settings()
router = APIRouter()


def get_store(supabase: Client = Depends(get_supabase)):
    return SupabaseIntegrationCredentialsStore(supabase)
creds_manager = IntegrationCredentialsManager()


class LoginResponse(BaseModel):
@@ -47,7 +32,6 @@ async def login(
    provider: Annotated[str, Path(title="The provider to initiate an OAuth flow for")],
    user_id: Annotated[str, Depends(get_user_id)],
    request: Request,
    store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
    scopes: Annotated[
        str, Query(title="Comma-separated list of authorization scopes")
    ] = "",
@@ -57,7 +41,9 @@ async def login(
    requested_scopes = scopes.split(",") if scopes else []

    # Generate and store a secure random state token along with the scopes
    state_token = await store.store_state_token(user_id, provider, requested_scopes)
    state_token = await creds_manager.store.store_state_token(
        user_id, provider, requested_scopes
    )

    login_url = handler.get_login_url(requested_scopes, state_token)

@@ -77,7 +63,6 @@ async def callback(
    provider: Annotated[str, Path(title="The target provider for this OAuth exchange")],
    code: Annotated[str, Body(title="Authorization code acquired by user login")],
    state_token: Annotated[str, Body(title="Anti-CSRF nonce")],
    store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
    user_id: Annotated[str, Depends(get_user_id)],
    request: Request,
) -> CredentialsMetaResponse:
@@ -85,12 +70,12 @@ async def callback(
    handler = _get_provider_oauth_handler(request, provider)

    # Verify the state token
    if not await store.verify_state_token(user_id, state_token, provider):
    if not await creds_manager.store.verify_state_token(user_id, state_token, provider):
        logger.warning(f"Invalid or expired state token for user {user_id}")
        raise HTTPException(status_code=400, detail="Invalid or expired state token")

    try:
        scopes = await store.get_any_valid_scopes_from_state_token(
        scopes = await creds_manager.store.get_any_valid_scopes_from_state_token(
            user_id, state_token, provider
        )
        logger.debug(f"Retrieved scopes from state token: {scopes}")
@@ -114,7 +99,7 @@ async def callback(
    )

    # TODO: Allow specifying `title` to set on `credentials`
    store.add_creds(user_id, credentials)
    creds_manager.create(user_id, credentials)

    logger.debug(
        f"Successfully processed OAuth callback for user {user_id} and provider {provider}"
@@ -132,9 +117,8 @@ async def callback(
async def list_credentials(
    provider: Annotated[str, Path(title="The provider to list credentials for")],
    user_id: Annotated[str, Depends(get_user_id)],
    store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
) -> list[CredentialsMetaResponse]:
    credentials = store.get_creds_by_provider(user_id, provider)
    credentials = creds_manager.store.get_creds_by_provider(user_id, provider)
    return [
        CredentialsMetaResponse(
            id=cred.id,
@@ -152,9 +136,8 @@ async def get_credential(
    provider: Annotated[str, Path(title="The provider to retrieve credentials for")],
    cred_id: Annotated[str, Path(title="The ID of the credentials to retrieve")],
    user_id: Annotated[str, Depends(get_user_id)],
    store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
) -> Credentials:
    credential = store.get_creds_by_id(user_id, cred_id)
    credential = creds_manager.get(user_id, cred_id)
    if not credential:
        raise HTTPException(status_code=404, detail="Credentials not found")
    if credential.provider != provider:
@@ -166,7 +149,6 @@ async def get_credential(

@router.post("/{provider}/credentials", status_code=201)
async def create_api_key_credentials(
    store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
    user_id: Annotated[str, Depends(get_user_id)],
    provider: Annotated[str, Path(title="The provider to create credentials for")],
    api_key: Annotated[str, Body(title="The API key to store")],
@@ -183,7 +165,7 @@ async def create_api_key_credentials(
    )

    try:
        store.add_creds(user_id, new_credentials)
        creds_manager.create(user_id, new_credentials)
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Failed to store credentials: {str(e)}"
@@ -191,14 +173,23 @@ async def create_api_key_credentials(
    return new_credentials


@router.delete("/{provider}/credentials/{cred_id}", status_code=204)
async def delete_credential(
class CredentialsDeletionResponse(BaseModel):
    deleted: Literal[True] = True
    revoked: bool | None = Field(
        description="Indicates whether the credentials were also revoked by their "
        "provider. `None`/`null` if not applicable, e.g. when deleting "
        "non-revocable credentials such as API keys."
    )


@router.delete("/{provider}/credentials/{cred_id}")
async def delete_credentials(
    request: Request,
    provider: Annotated[str, Path(title="The provider to delete credentials for")],
    cred_id: Annotated[str, Path(title="The ID of the credentials to delete")],
    user_id: Annotated[str, Depends(get_user_id)],
    store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
):
    creds = store.get_creds_by_id(user_id, cred_id)
) -> CredentialsDeletionResponse:
    creds = creds_manager.store.get_creds_by_id(user_id, cred_id)
    if not creds:
        raise HTTPException(status_code=404, detail="Credentials not found")
    if creds.provider != provider:
@@ -206,8 +197,14 @@ async def delete_credential(
        status_code=404, detail="Credentials do not match the specified provider"
    )

    store.delete_creds_by_id(user_id, cred_id)
    return Response(status_code=204)
    creds_manager.delete(user_id, cred_id)

    tokens_revoked = None
    if isinstance(creds, OAuth2Credentials):
        handler = _get_provider_oauth_handler(request, provider)
        tokens_revoked = handler.revoke_tokens(creds)

    return CredentialsDeletionResponse(revoked=tokens_revoked)


# -------- UTILITIES --------- #
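With the response model above, a client can tell whether deletion also revoked the provider-side tokens. A hedged sketch of calling the endpoint (the base URL, JWT, and ids are placeholders):

```python
# Hypothetical client call to the delete_credentials endpoint above.
import httpx

resp = httpx.delete(
    "http://localhost:8000/api/integrations/github/credentials/cred-456",
    headers={"Authorization": "Bearer <jwt>"},  # auth_middleware expects a JWT
)
resp.raise_for_status()
print(resp.json())  # e.g. {"deleted": true, "revoked": true} for OAuth2 creds
```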
@@ -0,0 +1,11 @@
from supabase import Client, create_client

from backend.util.settings import Settings

settings = Settings()


def get_supabase() -> Client:
    return create_client(
        settings.secrets.supabase_url, settings.secrets.supabase_service_role_key
    )
@@ -10,19 +10,20 @@ from autogpt_libs.auth.middleware import auth_middleware
from fastapi import APIRouter, Body, Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from typing_extensions import TypedDict

from backend.data import block, db
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data import user as user_db
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.credit import get_block_costs, get_user_credit_model
from backend.data.queue import RedisEventQueue
from backend.data.user import get_or_create_user
from backend.executor import ExecutionManager, ExecutionScheduler
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.server.model import CreateGraph, SetGraphActiveVersion
from backend.util.service import AppService, expose, get_service_client
from backend.util.settings import Config, Settings
from backend.util.cache import thread_cached_property
from backend.util.service import AppService, get_service_client
from backend.util.settings import AppEnvironment, Config, Settings

from .utils import get_user_id

@@ -31,26 +32,22 @@ logger = logging.getLogger(__name__)


class AgentServer(AppService):
    use_queue = True
    _test_dependency_overrides = {}
    _user_credit_model = get_user_credit_model()

    def __init__(self):
        super().__init__(port=Config().agent_server_port)
        self.event_queue = RedisEventQueue()
        self.use_redis = True

    @asynccontextmanager
    async def lifespan(self, _: FastAPI):
        await db.connect()
        self.event_queue.connect()
        await block.initialize_blocks()
        if await user_db.create_default_user(settings.config.enable_auth):
            await graph_db.import_packaged_templates()
        yield
        self.event_queue.close()
        await db.disconnect()

    def run_service(self):
        docs_url = "/docs" if settings.config.app_env == AppEnvironment.LOCAL else None
        app = FastAPI(
            title="AutoGPT Agent Server",
            description=(
@@ -60,6 +57,7 @@ class AgentServer(AppService):
            summary="AutoGPT Agent Server",
            version="0.1",
            lifespan=self.lifespan,
            docs_url=docs_url,
        )

        if self._test_dependency_overrides:
@@ -77,20 +75,29 @@ class AgentServer(AppService):
            allow_headers=["*"],  # Allows all headers
        )

        health_router = APIRouter()
        health_router.add_api_route(
            path="/health",
            endpoint=self.health,
            methods=["GET"],
            tags=["health"],
        )

        # Define the API routes
        api_router = APIRouter(prefix="/api")
        api_router.dependencies.append(Depends(auth_middleware))

        # Import & Attach sub-routers
        import backend.server.integrations.router
        import backend.server.routers.analytics
        import backend.server.routers.integrations

        api_router.include_router(
            backend.server.routers.integrations.router,
            backend.server.integrations.router.router,
            prefix="/integrations",
            tags=["integrations"],
            dependencies=[Depends(auth_middleware)],
        )
        self.integration_creds_manager = IntegrationCredentialsManager()

        api_router.include_router(
            backend.server.routers.analytics.router,
@@ -166,6 +173,12 @@ class AgentServer(AppService):
            methods=["PUT"],
            tags=["templates", "graphs"],
        )
        api_router.add_api_route(
            path="/graphs/{graph_id}",
            endpoint=self.delete_graph,
            methods=["DELETE"],
            tags=["graphs"],
        )
        api_router.add_api_route(
            path="/graphs/{graph_id}/versions",
            endpoint=self.get_graph_all_versions,
@@ -254,6 +267,7 @@ class AgentServer(AppService):
        app.add_exception_handler(500, self.handle_internal_http_error)

        app.include_router(api_router)
        app.include_router(health_router)

        uvicorn.run(
            app,
@@ -291,11 +305,11 @@ class AgentServer(AppService):

        return wrapper

    @property
    @thread_cached_property
    def execution_manager_client(self) -> ExecutionManager:
        return get_service_client(ExecutionManager, Config().execution_manager_port)

    @property
    @thread_cached_property
    def execution_scheduler_client(self) -> ExecutionScheduler:
        return get_service_client(ExecutionScheduler, Config().execution_scheduler_port)

@@ -355,8 +369,11 @@ class AgentServer(AppService):
        graph_id: str,
        user_id: Annotated[str, Depends(get_user_id)],
        version: int | None = None,
        hide_credentials: bool = False,
    ) -> graph_db.Graph:
        graph = await graph_db.get_graph(graph_id, version, user_id=user_id)
        graph = await graph_db.get_graph(
            graph_id, version, user_id=user_id, hide_credentials=hide_credentials
        )
        if not graph:
            raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
        return graph
@@ -393,6 +410,17 @@ class AgentServer(AppService):
    ) -> graph_db.Graph:
        return await cls.create_graph(create_graph, is_template=True, user_id=user_id)

    class DeleteGraphResponse(TypedDict):
        version_counts: int

    @classmethod
    async def delete_graph(
        cls, graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
    ) -> DeleteGraphResponse:
        return {
            "version_counts": await graph_db.delete_graph(graph_id, user_id=user_id)
        }

    @classmethod
    async def create_graph(
        cls,
@@ -613,10 +641,8 @@ class AgentServer(AppService):
        execution_scheduler = self.execution_scheduler_client
        return execution_scheduler.get_execution_schedules(graph_id, user_id)

    @expose
    def send_execution_update(self, execution_result_dict: dict[Any, Any]):
        execution_result = execution_db.ExecutionResult(**execution_result_dict)
        self.event_queue.put(execution_result)
    async def health(self):
        return {"status": "healthy"}

    @classmethod
    def update_configuration(
@@ -1,6 +1,5 @@
from autogpt_libs.auth.middleware import auth_middleware
from fastapi import Depends, HTTPException
from supabase import Client, create_client

from backend.data.user import DEFAULT_USER_ID
from backend.util.settings import Settings
@@ -17,9 +16,3 @@ def get_user_id(payload: dict = Depends(auth_middleware)) -> str:
    if not user_id:
        raise HTTPException(status_code=401, detail="User ID not found in token")
    return user_id


def get_supabase() -> Client:
    return create_client(
        settings.secrets.supabase_url, settings.secrets.supabase_service_role_key
    )
@@ -7,12 +7,13 @@ from autogpt_libs.auth import parse_jwt_token
from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware

from backend.data import redis
from backend.data.queue import RedisEventQueue
from backend.data.user import DEFAULT_USER_ID
from backend.server.conn_manager import ConnectionManager
from backend.server.model import ExecutionSubscription, Methods, WsMessage
from backend.util.service import AppProcess
from backend.util.settings import Config, Settings
from backend.util.settings import AppEnvironment, Config, Settings

logger = logging.getLogger(__name__)
settings = Settings()
@@ -20,16 +21,14 @@ settings = Settings()

@asynccontextmanager
async def lifespan(app: FastAPI):
    event_queue.connect()
    manager = get_connection_manager()
    fut = asyncio.create_task(event_broadcaster(manager))
    fut.add_done_callback(lambda _: logger.info("Event broadcaster stopped"))
    yield
    event_queue.close()


docs_url = "/docs" if settings.config.app_env == AppEnvironment.LOCAL else None
app = FastAPI(lifespan=lifespan)
event_queue = RedisEventQueue()
_connection_manager = None

logger.info(f"CORS allow origins: {settings.config.backend_cors_allow_origins}")
@@ -50,12 +49,20 @@ def get_connection_manager():


async def event_broadcaster(manager: ConnectionManager):
    while True:
        event = event_queue.get()
        if event is not None:
            await manager.send_execution_result(event)
        else:
            await asyncio.sleep(0.1)
    try:
        redis.connect()
        event_queue = RedisEventQueue()
        while True:
            event = event_queue.get()
            if event:
                await manager.send_execution_result(event)
            else:
                await asyncio.sleep(0.1)
    except Exception as e:
        logger.exception(f"Event broadcaster error: {e}")
        raise
    finally:
        redis.disconnect()


async def authenticate_websocket(websocket: WebSocket) -> str:
autogpt_platform/backend/backend/util/cache.py (new file, 21 lines)
@@ -0,0 +1,21 @@
import threading
from functools import wraps
from typing import Callable, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def thread_cached_property(func: Callable[[T], R]) -> property:
    local_cache = threading.local()

    @wraps(func)
    def wrapper(self: T) -> R:
        if not hasattr(local_cache, "cache"):
            local_cache.cache = {}
        key = id(self)
        if key not in local_cache.cache:
            local_cache.cache[key] = func(self)
        return local_cache.cache[key]

    return property(wrapper)
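`thread_cached_property` computes the value once per instance *per thread*, which is what makes it safe to cache the Pyro service clients seen earlier in this diff. A small hedged sketch with a hypothetical class:

```python
# Hypothetical class demonstrating thread_cached_property above.
class ServiceHolder:
    @thread_cached_property
    def client(self) -> dict:
        print("building client")  # runs once per (thread, instance) pair
        return {"connected": True}

holder = ServiceHolder()
holder.client  # prints "building client"
holder.client  # served from this thread's cache; no print, same object
```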
@@ -1,4 +1,6 @@
import os

from backend.util.settings import AppEnvironment, BehaveAs, Settings

settings = Settings()


def configure_logging():
@@ -6,7 +8,10 @@ def configure_logging():

    import autogpt_libs.logging.config

    if os.getenv("APP_ENV") != "cloud":
    if (
        settings.config.behave_as == BehaveAs.LOCAL
        or settings.config.app_env == AppEnvironment.LOCAL
    ):
        autogpt_libs.logging.config.configure_logging(force_cloud_logging=False)
    else:
        autogpt_libs.logging.config.configure_logging(force_cloud_logging=True)
@@ -17,6 +17,11 @@ def get_service_name():
    return _SERVICE_NAME


def set_service_name(name: str):
    global _SERVICE_NAME
    _SERVICE_NAME = name


class AppProcess(ABC):
    """
    A class to represent an object that can be executed in a background process.
@@ -63,9 +68,7 @@ class AppProcess(ABC):
            sys.stdout = open(os.devnull, "w")
            sys.stderr = open(os.devnull, "w")

            global _SERVICE_NAME
            _SERVICE_NAME = self.service_name

            set_service_name(self.service_name)
            logger.info(f"[{self.service_name}] Starting...")
            self.run()
        except (KeyboardInterrupt, SystemExit) as e:
@@ -1,5 +1,6 @@
import logging
import os
from functools import wraps
from uuid import uuid4

from tenacity import retry, stop_after_attempt, wait_exponential
@@ -21,28 +22,33 @@ def _log_prefix(resource_name: str, conn_id: str):
def conn_retry(resource_name: str, action_name: str, max_retry: int = 5):
    conn_id = str(uuid4())

    def before_call(retry_state):
        prefix = _log_prefix(resource_name, conn_id)
        logger.info(f"{prefix} {action_name} started...")

    def after_call(retry_state):
        prefix = _log_prefix(resource_name, conn_id)
        if retry_state.outcome.failed:
            # Optionally, you can log something here if needed
            pass
        else:
            logger.info(f"{prefix} {action_name} completed!")

    def on_retry(retry_state):
        prefix = _log_prefix(resource_name, conn_id)
        exception = retry_state.outcome.exception()
        logger.info(f"{prefix} {action_name} failed: {exception}. Retrying now...")

    return retry(
        stop=stop_after_attempt(max_retry + 1),
        wait=wait_exponential(multiplier=1, min=1, max=30),
        before=before_call,
        after=after_call,
        before_sleep=on_retry,
        reraise=True,
    )
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            prefix = _log_prefix(resource_name, conn_id)
            logger.info(f"{prefix} {action_name} started...")

            # Define the retrying strategy
            retrying_func = retry(
                stop=stop_after_attempt(max_retry + 1),
                wait=wait_exponential(multiplier=1, min=1, max=30),
                before_sleep=on_retry,
                reraise=True,
            )(func)

            try:
                result = retrying_func(*args, **kwargs)
                logger.info(f"{prefix} {action_name} completed successfully.")
                return result
            except Exception as e:
                logger.error(f"{prefix} {action_name} failed after retries: {e}")
                raise

        return wrapper

    return decorator
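The rewritten `conn_retry` is now a decorator factory. A hedged usage sketch, mirroring the `@conn_retry("Pyro", "Starting Pyro Service")` use elsewhere in this diff (the decorated function here is hypothetical):

```python
# Hypothetical function guarded by the conn_retry decorator above.
@conn_retry("Redis", "Opening connection", max_retry=5)
def open_redis_connection():
    # On exceptions, tenacity retries with exponential backoff (1-30s)
    # up to max_retry + 1 attempts, then re-raises the last error.
    ...
```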
@@ -1,16 +1,36 @@
import asyncio
import builtins
import logging
import os
import threading
import time
from abc import abstractmethod
from typing import Any, Callable, Coroutine, Type, TypeVar, cast
import typing
from enum import Enum
from types import NoneType, UnionType
from typing import (
    Annotated,
    Any,
    Callable,
    Coroutine,
    Dict,
    FrozenSet,
    Iterator,
    List,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
    get_args,
    get_origin,
)

import Pyro5.api
from pydantic import BaseModel
from Pyro5 import api as pyro

from backend.data import db
from backend.data.queue import AbstractEventQueue, RedisEventQueue
from backend.data import db, redis
from backend.util.process import AppProcess
from backend.util.retry import conn_retry
from backend.util.settings import Config, Secrets
@@ -27,9 +47,8 @@ def expose(func: C) -> C:
    Decorator to mark a method or class to be exposed for remote calls.

    ## ⚠️ Gotcha
    The types on the exposed function signature are respected **as long as they are
    fully picklable**. This is not the case for Pydantic models, so if you really need
    to pass a model, try dumping the model and passing the resulting dict instead.
    Aside from "simple" types, only Pydantic models are passed unscathed *if annotated*.
    Any other passed or returned class objects are converted to dictionaries by Pyro.
    """

    def wrapper(*args, **kwargs):
@@ -38,24 +57,59 @@ def expose(func: C) -> C:
        except Exception as e:
            msg = f"Error in {func.__name__}: {e.__str__()}"
            logger.exception(msg)
            raise Exception(msg, e)
            raise

    # Register custom serializers and deserializers for annotated Pydantic models
    for name, annotation in func.__annotations__.items():
        try:
            pydantic_types = _pydantic_models_from_type_annotation(annotation)
        except Exception as e:
            raise TypeError(f"Error while exposing {func.__name__}: {e.__str__()}")

        for model in pydantic_types:
            logger.debug(
                f"Registering Pyro (de)serializers for {func.__name__} annotation "
                f"'{name}': {model.__qualname__}"
            )
            pyro.register_class_to_dict(model, _make_custom_serializer(model))
            pyro.register_dict_to_class(
                model.__qualname__, _make_custom_deserializer(model)
            )

    return pyro.expose(wrapper)  # type: ignore


def _make_custom_serializer(model: Type[BaseModel]):
    def custom_class_to_dict(obj):
        data = {
            "__class__": obj.__class__.__qualname__,
            **obj.model_dump(),
        }
        logger.debug(f"Serializing {obj.__class__.__qualname__} with data: {data}")
        return data

    return custom_class_to_dict


def _make_custom_deserializer(model: Type[BaseModel]):
    def custom_dict_to_class(qualname, data: dict):
        logger.debug(f"Deserializing {model.__qualname__} from data: {data}")
        return model(**data)

    return custom_dict_to_class


class AppService(AppProcess):
    shared_event_loop: asyncio.AbstractEventLoop
    event_queue: AbstractEventQueue = RedisEventQueue()
    use_db: bool = False
    use_queue: bool = False
    use_redis: bool = False
    use_supabase: bool = False

    def __init__(self, port):
        self.port = port
        self.uri = None

    @abstractmethod
    def run_service(self):
    def run_service(self) -> None:
        while True:
            time.sleep(10)

@@ -70,8 +124,8 @@ class AppService(AppProcess):
        self.shared_event_loop = asyncio.get_event_loop()
        if self.use_db:
            self.shared_event_loop.run_until_complete(db.connect())
        if self.use_queue:
            self.event_queue.connect()
        if self.use_redis:
            redis.connect()
        if self.use_supabase:
            from supabase import create_client

@@ -97,9 +151,9 @@ class AppService(AppProcess):
        if self.use_db:
            logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting DB...")
            self.run_and_wait(db.disconnect())
        if self.use_queue:
        if self.use_redis:
            logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting Redis...")
            self.event_queue.close()
            redis.disconnect()

    @conn_retry("Pyro", "Starting Pyro Service")
    def __start_pyro(self):
@@ -131,6 +185,53 @@ def get_service_client(service_type: Type[AS], port: int) -> AS:
    logger.debug(f"Successfully connected to service [{service_name}]")

    def __getattr__(self, name: str) -> Callable[..., Any]:
        return getattr(self.proxy, name)
        res = getattr(self.proxy, name)
        return res

    return cast(AS, DynamicClient())


# --------- UTILITIES --------- #

builtin_types = [*vars(builtins).values(), NoneType, Enum]


def _pydantic_models_from_type_annotation(annotation) -> Iterator[type[BaseModel]]:
    # Peel Annotated parameters
    if (origin := get_origin(annotation)) and origin is Annotated:
        annotation = get_args(annotation)[0]

    origin = get_origin(annotation)
    args = get_args(annotation)

    if origin in (
        Union,
        UnionType,
        list,
        List,
        tuple,
        Tuple,
        set,
        Set,
        frozenset,
        FrozenSet,
    ):
        for arg in args:
            yield from _pydantic_models_from_type_annotation(arg)
    elif origin in (dict, Dict):
        key_type, value_type = args
        yield from _pydantic_models_from_type_annotation(key_type)
        yield from _pydantic_models_from_type_annotation(value_type)
    else:
        annotype = annotation if origin is None else origin

        # Exclude generic types and aliases
        if (
            annotype is not None
            and not hasattr(typing, getattr(annotype, "__name__", ""))
            and isinstance(annotype, type)
        ):
            if issubclass(annotype, BaseModel):
                yield annotype
            elif annotype not in builtin_types and not issubclass(annotype, Enum):
                raise TypeError(f"Unsupported type encountered: {annotype}")
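A hedged sketch of the new Pydantic-aware `expose` behaviour: annotating an exposed method with a model is enough for Pyro to round-trip instances of it. The service class and model here are hypothetical:

```python
# Hypothetical service illustrating the Pydantic-aware expose() above.
from pydantic import BaseModel

class Greeting(BaseModel):  # assumed model, for illustration only
    text: str
    repeat: int

class GreeterService(AppService):  # assumed service
    def __init__(self):
        super().__init__(port=9999)  # placeholder port

    def run_service(self):
        super().run_service()

    @expose
    def greet(self, greeting: Greeting) -> Greeting:
        # Because `Greeting` appears in the annotations, expose() registers
        # Pyro (de)serializers for it, so callers send and receive model
        # instances instead of bare dicts.
        return Greeting(text=greeting.text * greeting.repeat, repeat=1)
```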
@@ -1,5 +1,6 @@
import json
import os
from enum import Enum
from typing import Any, Dict, Generic, List, Set, Tuple, Type, TypeVar

from pydantic import BaseModel, Field, PrivateAttr, field_validator
@@ -15,6 +16,17 @@ from backend.util.data import get_config_path, get_data_path, get_secrets_path
T = TypeVar("T", bound=BaseSettings)


class AppEnvironment(str, Enum):
    LOCAL = "local"
    DEVELOPMENT = "dev"
    PRODUCTION = "prod"


class BehaveAs(str, Enum):
    LOCAL = "local"
    CLOUD = "cloud"


class UpdateTrackingModel(BaseModel, Generic[T]):
    _updated_fields: Set[str] = PrivateAttr(default_factory=set)

@@ -105,6 +117,11 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
        description="The port for agent server daemon to run on",
    )

    database_api_port: int = Field(
        default=8005,
        description="The port for database server API to run on",
    )

    agent_api_host: str = Field(
        default="0.0.0.0",
        description="The host for agent server API to run on",
@@ -121,6 +138,16 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
        "This value is then used to generate redirect URLs for OAuth flows.",
    )

    app_env: AppEnvironment = Field(
        default=AppEnvironment.LOCAL,
        description="The name of the app environment: local or dev or prod",
    )

    behave_as: BehaveAs = Field(
        default=BehaveAs.LOCAL,
        description="What environment to behave as: local or cloud",
    )

    backend_cors_allow_origins: List[str] = Field(default_factory=list)

    @field_validator("backend_cors_allow_origins")
@@ -177,10 +204,12 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
    )

    # OAuth server credentials for integrations
    # --8<-- [start:OAuthServerCredentialsExample]
    github_client_id: str = Field(default="", description="GitHub OAuth client ID")
    github_client_secret: str = Field(
        default="", description="GitHub OAuth client secret"
    )
    # --8<-- [end:OAuthServerCredentialsExample]
    google_client_id: str = Field(default="", description="Google OAuth client ID")
    google_client_secret: str = Field(
        default="", description="Google OAuth client secret"
@@ -5,15 +5,15 @@ from backend.data.block import Block, initialize_blocks
from backend.data.execution import ExecutionStatus
from backend.data.model import CREDENTIALS_FIELD_NAME
from backend.data.user import create_default_user
from backend.executor import ExecutionManager, ExecutionScheduler
from backend.server import AgentServer
from backend.server.rest_api import get_user_id
from backend.executor import DatabaseManager, ExecutionManager, ExecutionScheduler
from backend.server.rest_api import AgentServer, get_user_id

log = print


class SpinTestServer:
    def __init__(self):
        self.db_api = DatabaseManager()
        self.exec_manager = ExecutionManager()
        self.agent_server = AgentServer()
        self.scheduler = ExecutionScheduler()
@@ -24,6 +24,7 @@ class SpinTestServer:

    async def __aenter__(self):
        self.setup_dependency_overrides()
        self.db_api.__enter__()
        self.agent_server.__enter__()
        self.exec_manager.__enter__()
        self.scheduler.__enter__()
@@ -40,6 +41,7 @@ class SpinTestServer:
        self.scheduler.__exit__(exc_type, exc_val, exc_tb)
        self.exec_manager.__exit__(exc_type, exc_val, exc_tb)
        self.agent_server.__exit__(exc_type, exc_val, exc_tb)
        self.db_api.__exit__(exc_type, exc_val, exc_tb)

    def setup_dependency_overrides(self):
        # Override get_user_id for testing
@@ -0,0 +1,5 @@
-- DropForeignKey
ALTER TABLE "AgentGraph" DROP CONSTRAINT "AgentGraph_agentGraphParentId_version_fkey";

-- AddForeignKey
ALTER TABLE "AgentGraph" ADD CONSTRAINT "AgentGraph_agentGraphParentId_version_fkey" FOREIGN KEY ("agentGraphParentId", "version") REFERENCES "AgentGraph"("id", "version") ON DELETE CASCADE ON UPDATE CASCADE;
autogpt_platform/backend/poetry.lock (generated, 3 changes)
@@ -293,6 +293,7 @@ develop = true

[package.dependencies]
colorama = "^0.4.6"
expiringdict = "^1.2.2"
google-cloud-logging = "^3.8.0"
pydantic = "^2.8.2"
pydantic-settings = "^2.5.2"
@@ -3667,4 +3668,4 @@ type = ["pytest-mypy"]

[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "3ab370b624b486517a2fbcdc17fb294fbd76b3ec6659c5b471c57bfd738e7277"
content-hash = "0962d61ced1a8154c64c6bbdb3f72aca558831adfbfda68eb66f39b535466f77"
@@ -16,7 +16,6 @@ autogpt-libs = { path = "../autogpt_libs", develop = true }
click = "^8.1.7"
croniter = "^2.0.5"
discord-py = "^2.4.0"
expiringdict = "^1.2.2"
fastapi = "^0.109.0"
feedparser = "^6.0.11"
flake8 = "^7.0.0"
@@ -53,7 +53,7 @@ model AgentGraph {
  // All sub-graphs are defined within this 1-level depth list (even if it's a nested graph).
  AgentSubGraphs     AgentGraph[] @relation("AgentSubGraph")
  agentGraphParentId String?
  AgentGraphParent   AgentGraph?  @relation("AgentSubGraph", fields: [agentGraphParentId, version], references: [id, version])
  AgentGraphParent   AgentGraph?  @relation("AgentSubGraph", fields: [agentGraphParentId, version], references: [id, version], onDelete: Cascade)

  @@id(name: "graphVersionId", [id, version])
}
@@ -63,7 +63,7 @@ model AgentNode {
  id String @id @default(uuid())

  agentBlockId String
  AgentBlock   AgentBlock @relation(fields: [agentBlockId], references: [id])
  AgentBlock   AgentBlock @relation(fields: [agentBlockId], references: [id], onUpdate: Cascade)

  agentGraphId      String
  agentGraphVersion Int    @default(1)
@@ -7,3 +7,28 @@ from backend.util.test import SpinTestServer
async def server():
    async with SpinTestServer() as server:
        yield server


@pytest.fixture(scope="session", autouse=True)
async def graph_cleanup(server):
    created_graph_ids = []
    original_create_graph = server.agent_server.create_graph

    async def create_graph_wrapper(*args, **kwargs):
        created_graph = await original_create_graph(*args, **kwargs)
        # Extract user_id correctly
        user_id = kwargs.get("user_id", args[2] if len(args) > 2 else None)
        created_graph_ids.append((created_graph.id, user_id))
        return created_graph

    try:
        server.agent_server.create_graph = create_graph_wrapper
        yield  # This runs the test function
    finally:
        server.agent_server.create_graph = original_create_graph

        # Delete the created graphs and assert they were deleted
        for graph_id, user_id in created_graph_ids:
            resp = await server.agent_server.delete_graph(graph_id, user_id)
            num_deleted = resp["version_counts"]
            assert num_deleted > 0, f"Graph {graph_id} was not deleted."
@@ -19,7 +19,7 @@ async def test_block_credit_usage(server: SpinTestServer):
    spending_amount_1 = await user_credit.spend_credits(
        DEFAULT_USER_ID,
        current_credit,
        AITextGeneratorBlock(),
        AITextGeneratorBlock().id,
        {"model": "gpt-4-turbo"},
        0.0,
        0.0,
@@ -30,7 +30,7 @@ async def test_block_credit_usage(server: SpinTestServer):
    spending_amount_2 = await user_credit.spend_credits(
        DEFAULT_USER_ID,
        current_credit,
        AITextGeneratorBlock(),
        AITextGeneratorBlock().id,
        {"model": "gpt-4-turbo", "api_key": "owned_api_key"},
        0.0,
        0.0,
@@ -4,11 +4,16 @@ from prisma.models import User

from backend.blocks.basic import FindInDictionaryBlock, StoreValueBlock
from backend.blocks.maths import CalculatorBlock, Operation
from backend.data import execution, graph
from backend.server import AgentServer
from backend.server.model import CreateGraph
from backend.server.rest_api import AgentServer
from backend.usecases.sample import create_test_graph, create_test_user
from backend.util.test import SpinTestServer, wait_execution


async def create_graph(s: SpinTestServer, g: graph.Graph, u: User) -> graph.Graph:
    return await s.agent_server.create_graph(CreateGraph(graph=g), False, u.id)


async def execute_graph(
    agent_server: AgentServer,
    test_graph: graph.Graph,
@@ -99,9 +104,8 @@ async def assert_sample_graph_executions(

@pytest.mark.asyncio(scope="session")
async def test_agent_execution(server: SpinTestServer):
    test_graph = create_test_graph()
    test_user = await create_test_user()
    await graph.create_graph(test_graph, user_id=test_user.id)
    test_graph = await create_graph(server, create_test_graph(), test_user)
    data = {"input_1": "Hello", "input_2": "World"}
    graph_exec_id = await execute_graph(
        server.agent_server,
@@ -163,7 +167,7 @@ async def test_input_pin_always_waited(server: SpinTestServer):
        links=links,
    )
    test_user = await create_test_user()
    test_graph = await graph.create_graph(test_graph, user_id=test_user.id)
    test_graph = await create_graph(server, test_graph, test_user)
    graph_exec_id = await execute_graph(
        server.agent_server, test_graph, test_user, {}, 3
    )
@@ -244,7 +248,7 @@ async def test_static_input_link_on_graph(server: SpinTestServer):
        links=links,
    )
    test_user = await create_test_user()
    test_graph = await graph.create_graph(test_graph, user_id=test_user.id)
    test_graph = await create_graph(server, test_graph, test_user)
    graph_exec_id = await execute_graph(
        server.agent_server, test_graph, test_user, {}, 8
    )
@@ -1,7 +1,8 @@
import pytest

from backend.data import db, graph
from backend.data import db
from backend.executor import ExecutionScheduler
from backend.server.model import CreateGraph
from backend.usecases.sample import create_test_graph, create_test_user
from backend.util.service import get_service_client
from backend.util.settings import Config
@@ -12,7 +13,11 @@ from backend.util.test import SpinTestServer
async def test_agent_schedule(server: SpinTestServer):
    await db.connect()
    test_user = await create_test_user()
    test_graph = await graph.create_graph(create_test_graph(), user_id=test_user.id)
    test_graph = await server.agent_server.create_graph(
        create_graph=CreateGraph(graph=create_test_graph()),
        is_template=False,
        user_id=test_user.id,
    )

    scheduler = get_service_client(
        ExecutionScheduler, Config().execution_scheduler_port
@@ -2,13 +2,12 @@ import pytest

from backend.util.service import AppService, expose, get_service_client

TEST_SERVICE_PORT = 8765


class TestService(AppService):

class ServiceTest(AppService):
    def __init__(self):
        super().__init__(port=8005)

    def run_service(self):
        super().run_service()
        super().__init__(port=TEST_SERVICE_PORT)

    @expose
    def add(self, a: int, b: int) -> int:
@@ -28,8 +27,8 @@ class TestService(AppService):

@pytest.mark.asyncio(scope="session")
async def test_service_creation(server):
    with TestService():
        client = get_service_client(TestService, 8005)
    with ServiceTest():
        client = get_service_client(ServiceTest, TEST_SERVICE_PORT)
        assert client.add(5, 3) == 8
        assert client.subtract(10, 4) == 6
        assert client.fun_with_async(5, 3) == 8
@@ -103,6 +103,7 @@ services:
      - ENABLE_AUTH=true
      - PYRO_HOST=0.0.0.0
      - AGENTSERVER_HOST=rest_server
      - DATABASEMANAGER_HOST=0.0.0.0
    ports:
      - "8002:8000"
    networks:
@@ -1,3 +1,3 @@
{
  "extends": "next/core-web-vitals"
  "extends": ["next/core-web-vitals", "plugin:storybook/recommended"]
}
autogpt_platform/frontend/.gitignore (vendored, 3 changes)
@@ -42,3 +42,6 @@ node_modules/
/playwright-report/
/blob-report/
/playwright/.cache/

*storybook.log
storybook-static
autogpt_platform/frontend/.storybook/main.ts (new file, 18 lines)
@@ -0,0 +1,18 @@
import type { StorybookConfig } from "@storybook/nextjs";

const config: StorybookConfig = {
  stories: ["../src/**/*.mdx", "../src/**/*.stories.@(js|jsx|mjs|ts|tsx)"],
  addons: [
    "@storybook/addon-onboarding",
    "@storybook/addon-links",
    "@storybook/addon-essentials",
    "@chromatic-com/storybook",
    "@storybook/addon-interactions",
  ],
  framework: {
    name: "@storybook/nextjs",
    options: {},
  },
  staticDirs: ["../public"],
};

export default config;
15
autogpt_platform/frontend/.storybook/preview.ts
Normal file
@@ -0,0 +1,15 @@
import type { Preview } from "@storybook/react";
import "../src/app/globals.css";

const preview: Preview = {
  parameters: {
    controls: {
      matchers: {
        color: /(background|color)$/i,
        date: /Date$/i,
      },
    },
  },
};

export default preview;
@@ -14,7 +14,7 @@ CMD ["yarn", "run", "dev"]

# Build stage for prod
FROM base AS build
COPY autogpt_platform/frontend/ .
RUN npm run build
RUN yarn build

# Prod stage
FROM node:21-alpine AS prod
@@ -29,4 +29,4 @@ COPY --from=build /app/public ./public
COPY --from=build /app/next.config.mjs ./next.config.mjs

EXPOSE 3000
CMD ["npm", "start"]
CMD ["yarn", "start"]
@@ -39,3 +39,50 @@ This project uses [`next/font`](https://nextjs.org/docs/basic-features/font-opti
## Deploy

TODO

## Storybook

Storybook is a powerful development environment for UI components. It allows you to build UI components in isolation, making it easier to develop, test, and document your components independently from your main application.

### Purpose in the Development Process

1. **Component Development**: Develop and test UI components in isolation.
2. **Visual Testing**: Easily spot visual regressions.
3. **Documentation**: Automatically document components and their props.
4. **Collaboration**: Share components with your team or stakeholders for feedback.

### How to Use Storybook

1. **Start Storybook**:
   Run the following command to start the Storybook development server:

   ```bash
   npm run storybook
   ```

   This will start Storybook on port 6006. Open [http://localhost:6006](http://localhost:6006) in your browser to view your component library.

2. **Build Storybook**:
   To build a static version of Storybook for deployment, use:

   ```bash
   npm run build-storybook
   ```

3. **Running Storybook Tests**:
   Storybook tests can be run using:

   ```bash
   npm run test-storybook
   ```

   For CI environments, use:

   ```bash
   npm run test-storybook:ci
   ```

4. **Writing Stories**:
   Create `.stories.tsx` files alongside your components to define different states and variations of your components, as in the sketch below.
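
   For illustration, here is a minimal sketch of such a story file for a hypothetical `Button` component (the component name, props, and import path are assumptions for the example, not code from this repository):

   ```tsx
   // Button.stories.tsx — minimal CSF3 story sketch for an assumed Button component.
   import type { Meta, StoryObj } from "@storybook/react";

   import { Button } from "./Button"; // hypothetical component path

   const meta: Meta<typeof Button> = {
     title: "Example/Button",
     component: Button,
   };
   export default meta;

   type Story = StoryObj<typeof Button>;

   // Each named export renders one state of the component in Storybook.
   export const Primary: Story = {
     args: { children: "Click me" },
   };

   export const Disabled: Story = {
     args: { children: "Unavailable", disabled: true },
   };
   ```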

By integrating Storybook into our development workflow, we can streamline UI development, improve component reusability, and maintain a consistent design system across the project.
@@ -8,11 +8,15 @@
    "dev:test": "export NODE_ENV=test && next dev",
    "build": "next build",
    "start": "next start",
    "lint": "next lint",
    "lint": "next lint && prettier --check .",
    "format": "prettier --write .",
    "test": "playwright test",
    "test-ui": "playwright test --ui",
    "gentests": "playwright codegen http://localhost:3000"
    "gentests": "playwright codegen http://localhost:3000",
    "storybook": "storybook dev -p 6006",
    "build-storybook": "storybook build",
    "test-storybook": "test-storybook",
    "test-storybook:ci": "concurrently -k -s first -n \"SB,TEST\" -c \"magenta,blue\" \"npm run build-storybook -- --quiet && npx http-server storybook-static --port 6006 --silent\" \"wait-on tcp:6006 && npm run test-storybook\""
  },
  "browserslist": [
    "defaults"
@@ -23,6 +27,7 @@
    "@radix-ui/react-avatar": "^1.1.0",
    "@radix-ui/react-checkbox": "^1.1.1",
    "@radix-ui/react-collapsible": "^1.1.0",
    "@radix-ui/react-context-menu": "^2.2.1",
    "@radix-ui/react-dialog": "^1.1.1",
    "@radix-ui/react-dropdown-menu": "^2.1.1",
    "@radix-ui/react-icons": "^1.3.0",
@@ -39,7 +44,7 @@
    "@supabase/ssr": "^0.4.0",
    "@supabase/supabase-js": "^2.45.0",
    "@tanstack/react-table": "^8.20.5",
    "@xyflow/react": "^12.1.0",
    "@xyflow/react": "^12.3.1",
    "ajv": "^8.17.1",
    "class-variance-authority": "^0.7.0",
    "clsx": "^2.1.1",
@@ -65,17 +70,31 @@
    "zod": "^3.23.8"
  },
  "devDependencies": {
    "@chromatic-com/storybook": "^1.9.0",
    "@playwright/test": "^1.47.1",
    "@storybook/addon-essentials": "^8.3.5",
    "@storybook/addon-interactions": "^8.3.5",
    "@storybook/addon-links": "^8.3.5",
    "@storybook/addon-onboarding": "^8.3.5",
    "@storybook/blocks": "^8.3.5",
    "@storybook/nextjs": "^8.3.5",
    "@storybook/react": "^8.3.5",
    "@storybook/test": "^8.3.5",
    "@storybook/test-runner": "^0.19.1",
    "@types/node": "^22.7.3",
    "@types/react": "^18",
    "@types/react-dom": "^18",
    "@types/react-modal": "^3.16.3",
    "concurrently": "^9.0.1",
    "eslint": "^8",
    "eslint-config-next": "14.2.4",
    "eslint-plugin-storybook": "^0.9.0",
    "postcss": "^8",
    "prettier": "^3.3.3",
    "prettier-plugin-tailwindcss": "^0.6.6",
    "storybook": "^8.3.5",
    "tailwindcss": "^3.4.1",
    "typescript": "^5"
  }
  },
  "packageManager": "yarn@1.22.22+sha512.a6b2f7906b721bba3d67d4aff083df04dad64c399707841b7acf00f6b133b7ac24255f2652fa22ae3534329dc6180534e98d17432037ff6fd140556e2bb3137e"
}
@@ -10,7 +10,7 @@ async function AdminMarketplace() {

  return (
    <>
      <AdminMarketplaceAgentList agents={reviewableAgents.agents} />
      <AdminMarketplaceAgentList agents={reviewableAgents.items} />
      <Separator className="my-4" />
      <AdminFeaturedAgentsControl className="mt-4" />
    </>
@@ -27,7 +27,7 @@
    --destructive: 0 84.2% 60.2%;
    --destructive-foreground: 0 0% 98%;
    --border: 240 5.9% 90%;
    --input: 240 5.9% 90%;
    --input: 240 5.9% 85%;
    --ring: 240 5.9% 10%;
    --radius: 0.5rem;
    --chart-1: 12 76% 61%;
@@ -72,4 +72,12 @@
  body {
    @apply bg-background text-foreground;
  }

  .agpt-border-input {
    @apply border-input focus-visible:border-gray-400 focus-visible:outline-none;
  }

  .agpt-shadow-input {
    @apply shadow-sm focus-visible:shadow-md;
  }
}
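
These new `agpt-*` utilities are Tailwind `@apply` classes, so components opt in via `className`. A minimal hypothetical usage sketch (the component and file name are assumptions, not code from this diff):

```tsx
// InputExample.tsx — hypothetical consumer of the agpt-* utility classes above.
import React from "react";

export function InputExample() {
  return (
    <input
      type="text"
      placeholder="Agent name"
      // agpt-border-input: border color plus focus-visible border handling
      // agpt-shadow-input: subtle shadow that deepens on focus
      className="agpt-border-input agpt-shadow-input rounded-md px-3 py-2"
    />
  );
}
```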
@@ -114,7 +114,7 @@ export default function LoginPage() {
  return (
    <div className="flex h-[80vh] items-center justify-center">
      <div className="w-full max-w-md space-y-6 rounded-lg p-8 shadow-md">
        <div className="mb-6 space-y-2">
        {/* <div className="mb-6 space-y-2">
          <Button
            className="w-full"
            onClick={() => handleSignInWithProvider("google")}
@@ -145,7 +145,7 @@ export default function LoginPage() {
            <FaDiscord className="mr-2 h-4 w-4" />
            Sign in with Discord
          </Button>
        </div>
        </div> */}
        <Form {...form}>
          <form onSubmit={form.handleSubmit(onLogin)}>
            <FormField
Some files were not shown because too many files have changed in this diff.