Compare commits

...

48 Commits

Author SHA1 Message Date
Aarushi
9334eee41d update execution manager with redis 2024-10-21 14:02:43 +01:00
Aarushi
e4a9c8216f fix(platform/infra): Fix liveness probe (#8382)
fix liveness probe
2024-10-21 10:55:16 +01:00
Aarushi
f19ed9f652 fix spelling mistake (#8380) 2024-10-21 10:39:59 +01:00
Emmanuel Ferdman
30376a8ec8 fix: update types reference (#8366)
Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
2024-10-19 10:52:41 -05:00
Aarushi
32680a549e Merge branch 'master' into dev 2024-10-18 13:59:06 +01:00
Aarushi
68158de126 fix(dockercompose): Fix db manager connection (#8377)
* add db host

* remove unused variable
2024-10-18 13:57:14 +01:00
Aarushi
6f3828fc99 fix(dockercompose): Fix db manager connection (#8377)
* add db host

* remove unused variable
2024-10-18 12:49:56 +00:00
Zamil Majdy
26b1bca033 refactor(backend): Make block fields consistently use SchemaField (#8360) 2024-10-18 10:22:05 +07:00
Kushal Agrawal
7f6354caae Update README.md (#8319)
* Update README.md

* Update README.md

---------

Co-authored-by: Aarushi <50577581+aarushik93@users.noreply.github.com>
2024-10-17 16:26:35 +00:00
Reinier van der Leer
5d4d2486da ci: Enforce dev as base branch for development (#8369)
* Create repo-pr-enforce-base-branch.yml

* fix quotes

* test

* fix github token

* fix trigger and CLI config

* change back trigger because otherwise I can't test it

* fix the fix

* fix repo selection

* fix perms?

* fix quotes and newlines escaping in message

* Update repo-pr-enforce-base-branch.yml

* grrr escape sequences in bash

* test

* clean up

---------

Co-authored-by: Aarushi <50577581+aarushik93@users.noreply.github.com>
2024-10-17 17:15:15 +01:00
Zamil Majdy
2c0286e411 feat(backend): Add credit for Jina/Search & LLM blocks (#8361) 2024-10-17 20:51:10 +07:00
Aarushi
fca8d61cc4 Merge branch 'master' into dev 2024-10-17 12:25:54 +01:00
Aarushi
2b4af19799 fix(platform/frontend): Add behave as var (#8365)
add behave as variable
2024-10-17 12:24:24 +01:00
Aarushi
615a9746dc fix(market): Reseal Market URL (#8363)
reseal market url
2024-10-17 11:46:02 +01:00
Aarushi
ac33c1eb03 fix(platform): Include health router (#8362)
include health router
2024-10-17 11:09:33 +01:00
Nicholas Tindle
d6d2820b92 fix(market): agent pagination and search errors (#8336)
* fix(market): agent pagination and search errors

* fix(frontend): search was not paginated

* fix: linting

* feat(market): linting ci

* fix(ci): branch limit name
2024-10-16 20:29:53 +00:00
Bently
3982e20faa feat(frontend): Allow copy and pasting of blocks between flows (#8346) 2024-10-16 21:21:01 +01:00
Bently
c029fde502 docs(blocks): Add documentation for each block we currently have (#8289)
* Initial upload of block docks

* add github + google block docs

* small tweak
2024-10-16 15:09:34 +00:00
Aarushi
405dd1659e fix(infra): Update backend server health check endpoint (#8351)
* feat(platform): List and revoke credentials in user profile (#8207)

Display existing credentials (OAuth and API keys) for all current providers: Google, Github, Notion and allow user to remove them. For providers that support it, we also revoke the tokens through the API: of the providers we currently have, Google and GitHub support it; Notion doesn't.

- Add credentials list and `Delete` button in `/profile`
- Add `revoke_tokens` abstract method to `BaseOAuthHandler` and implement it in each provider
- Revoke OAuth tokens for providers on `DELETE` `/{provider}/credentials/{cred_id}`, and return whether tokens could be revoked
   - Update `autogpt-server-api/baseClient.ts:deleteCredentials` with `CredentialsDeleteResponse` return type

Bonus:
- Update `autogpt-server-api/baseClient.ts:_request` to properly handle empty server responses

* fix(backend): Lower the number of node workers to save DB connections (#8331)

Change [graph]×[node] worker limit from 10×5 to 10×3

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>

* fix(ci,platform): Add dev branch trigger to all ci (#8339)

* update ci for dev

* update classic

* remove duplicate dev

* fix(frontend): Fix styling inconsistencies in input elements (#8337)

- Apply consistent border styling to `Input`, `Select`, and `Textarea`
   - Remove `rounded-xl` from node input elements

- Add `whitespace-nowrap` to `CustomNode` header category tags

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>

* feat(builder): Use configmap for builder (#8343)

use configmap in builder

* fix(platform,infra): Checkin non secret values  (#8344)

checkin non secrets

* security(platform): Add sealed secrets (#8342)

* add sealed secrets

* add encrypted secrets

* remove extra space

* Tf public media buckets (#8324)

* fix(infra): Fix sealed secret names  (#8350)

* fix sealed secret names

* fix names and add annotation

* feat(backend): Introduce executors shared DB connection (#8340)

* update health checkendpoint

---------

Co-authored-by: Krzysztof Czerwinski <34861343+kcze@users.noreply.github.com>
Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
Co-authored-by: Swifty <craigswift13@gmail.com>
2024-10-16 15:02:21 +00:00
Aarushi
2d0e51fe28 security(platform/backend): Add health endpoint (#8341)
* add health endpoint

* fix linting

---------

Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
2024-10-16 14:57:23 +00:00
Zamil Majdy
6f07d24e93 feat(backend): Introduce executors shared DB connection (#8340) 2024-10-16 21:15:23 +07:00
Aarushi
9292597d56 fix(infra): Fix sealed secret names (#8350)
* fix sealed secret names

* fix names and add annotation
2024-10-16 14:54:35 +01:00
Swifty
f6eebcab6e Tf public media buckets (#8324) 2024-10-16 09:28:41 +00:00
Aarushi
9fe3fed1a2 security(platform): Add sealed secrets (#8342)
* add sealed secrets

* add encrypted secrets

* remove extra space
2024-10-16 09:59:54 +01:00
Aarushi
769ab18cca fix(platform,infra): Checkin non secret values (#8344)
checkin non secrets
2024-10-15 15:28:35 +01:00
Aarushi
d46219c80f feat(builder): Use configmap for builder (#8343)
use configmap in builder
2024-10-15 14:23:10 +01:00
Reinier van der Leer
97015a91ad fix(frontend): Fix styling inconsistencies in input elements (#8337)
- Apply consistent border styling to `Input`, `Select`, and `Textarea`
   - Remove `rounded-xl` from node input elements

- Add `whitespace-nowrap` to `CustomNode` header category tags

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2024-10-15 15:18:19 +02:00
Aarushi
a2ef456525 fix(ci,platform): Add dev branch trigger to all ci (#8339)
* update ci for dev

* update classic

* remove duplicate dev
2024-10-15 10:57:24 +01:00
Aarushi
1c71351652 fix(backend): Lower the number of node workers to save DB connections (#8331)
Change [graph]×[node] worker limit from 10×5 to 10×3

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
2024-10-14 15:59:59 +00:00
Krzysztof Czerwinski
bd5d2b1e86 feat(platform): List and revoke credentials in user profile (#8207)
Display existing credentials (OAuth and API keys) for all current providers: Google, Github, Notion and allow user to remove them. For providers that support it, we also revoke the tokens through the API: of the providers we currently have, Google and GitHub support it; Notion doesn't.

- Add credentials list and `Delete` button in `/profile`
- Add `revoke_tokens` abstract method to `BaseOAuthHandler` and implement it in each provider
- Revoke OAuth tokens for providers on `DELETE` `/{provider}/credentials/{cred_id}`, and return whether tokens could be revoked
   - Update `autogpt-server-api/baseClient.ts:deleteCredentials` with `CredentialsDeleteResponse` return type

Bonus:
- Update `autogpt-server-api/baseClient.ts:_request` to properly handle empty server responses
2024-10-14 17:50:55 +02:00
Zamil Majdy
8502928a21 feat(frontend): Update block UI (#8260)
* feat(platform): Update block UI

* add border on card

* added delay and badge data-id

* Fix border & width for block control list

* More cleanup on border & shadow

* Nav border consistency

* Simplify category badges

* restored backward compatablility

* fix alignement of sub handles

* Fix dynamic pin experience

* Added a timeout to prevent losing focus whilst typing

* Added flex-col back in removed timeout

* Clear nodes before tutorial

* Fix highlight on tutorial

* Sort blocks

* lint

* Fix tutorial and lint error

* w-fit

* Fix tutorial modals silly jumps!

* updates to tutorial

* prettier

* add data-id to save control bar

* prettier again

---------

Co-authored-by: Swifty <craigswift13@gmail.com>
Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
Co-authored-by: Bently <tomnoon9@gmail.com>
2024-10-12 10:33:29 +01:00
Bently
c1f97415fb refactor(frontend/login): Hide sign in with Google/GitHub/Discord for now (#8318)
hide sign in with Google/GitHub/Discord for now
2024-10-11 16:52:08 +00:00
Reinier van der Leer
74e677baec ci(frontend): Enforce consistent yarn.lock + minor DX improvements (#8316)
- ci(frontend): Ensure CI fails if `yarn.lock` is inconsistent with `package.json`
- dx(frontend): Add Prettier check to `lint` script in `package.json`
- dx(frontend): Add `packageManager` to `package.json` for Corepack support
- build(frontend): Use `yarn` consistently in the Dockerfile
2024-10-11 16:51:15 +02:00
Reinier van der Leer
992989ee71 feat(backend): Ensure validity of OAuth credentials during graph execution (#8191)
- feat(backend/executor): Change credential injection mechanism to acquire credentials from `AgentServer` just before execution
  - Also locks the credentials for the duration of the execution

- feat(backend/server): Add thread-safe `IntegrationCredentialsManager` to handle and synchronize credentials-related operations

- feat(libs): Add mutexes to `SupabaseIntegrationCredentialsStore` to ensure thread-safety

Also:
- feat(backend): Added Pydantic model (de)serialization support to `@expose` decorator

Refactorings:
- refactor(backend, libs): Move `KeyedMutex` to `autogpt_libs.utils.synchronize`
- refactor(backend/server): Make `backend.server.integrations` module with `router`, `creds_manager`, and `utils` in it
2024-10-10 16:45:43 +00:00
Swifty
d8145c158c tool(platform): Add storybooks to aid UI development (#8274)
* storybook init

* alert stories

* Avatar stories

* badge stories

* button stories

* calander stories

* stories default

* card stories

* checkbox stories

* formatting

* added tailwind config

* add collapsible story

* added command story

* rename use-toast.ts to tsx

* added more stories

* fix linting issues

* added stories for input

* added stories for label

* Added tests to button story

* added multiselect stories

* added popover stories

* added render stories

* scroll area stories

* more stories

* Added rest of the stories for the default components

* fmt

* add test runner

* added ci

* fix tests

* fixing ci

* remove storybook from ci

* removed styling

* added new line
2024-10-10 16:35:05 +00:00
Zamil Majdy
9ad5e1f808 fix(platform): Remove blind try-except for yielding error on block (#8287) 2024-10-10 23:25:29 +07:00
vishesh10
7b92bae942 Fix block execution status in case of error (#8267) 2024-10-10 02:59:26 +00:00
Aarushi
c03e2fb949 tweak(platform): Remove importing templates from local dir (#8276)
* always filter on user id

* add user id to doc string

* fix linting

* fix imports function

* remove import templates from local directory
2024-10-09 23:13:46 +00:00
Zamil Majdy
dbc603c6eb fix(platform): Fix unexpected connection clash on two dynamic pins link with the same keys (#8252) 2024-10-09 15:29:13 -05:00
Aarushi
c582b5512a tweak(platform): Add Anthropic (#8286)
add anthropic in helm
2024-10-09 11:54:49 +01:00
Nicholas Tindle
e654aa1e7a feat(backend): add the capibility to disable llm models in the cloud env (#8285)
* feat(backend): logic to disable enums based on python logic

* feat(backend): add behave as setting and clarify its purpose and APP_ENV

APP_ENV is used for not cloud vs local but the application environment such as local/dev/prod so we need BehaveAs as well

* fix(backend): various uses of AppEnvironment without the Enum or incorrectly

AppEnv in the logging library will never be cloud due to the restrictions applied when loading settings in by pydantic settings. This commit fixes this error, however the code path for logging may now be incorrect

* feat(backend): use a metaclass to disable ollama in the cloud environment

* fix: formatting

* fix(backend): typing improvements

* fix(backend): more linting 😭
2024-10-09 10:12:48 +01:00
Aarushi
e37744b9f2 fix(platform): Update deletion of secret values to not do it in place (#8284)
update deletion of secret values to not do it in place
2024-10-08 22:43:36 +01:00
Aarushi
bc1df92c29 fix(platform): Fix marketplace leaking secrets (#8281)
* add hide secrets param

* Update autogpt_platform/backend/backend/data/graph.py

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>

* Update autogpt_platform/frontend/src/lib/autogpt-server-api/baseClient.ts

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>

* rename hide secrets

* use builtin dict

* delete key

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
2024-10-08 20:20:18 +01:00
Nicholas Tindle
04473cad1e feat(docs): OAuth docs updates based on google block changes (#8243)
* feat(frontend,backend): testing

* feat: testing

* feat(backend): it works for reading email

* feat(backend): more docs on google

* fix(frontend,backend): formatting

* feat(backend): more logigin (i know this should be debug)

* feat(backend): make real the default scopes

* feat(backend): tests and linting

* fix: code review prep

* feat: sheets block

* feat: liniting

* Update route.ts

* Update autogpt_platform/backend/backend/integrations/oauth/google.py

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>

* Update autogpt_platform/backend/backend/server/routers/integrations.py

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>

* fix: revert opener change

* feat(frontend): add back opener

required to work on mac edge

* feat(frontend): drop typing list import from gmail

* fix: code review comments

* feat: code review changes

* feat: code review changes

* fix(backend): move from asserts to checks so they don't get optimized away in the future

* fix(backend): code review changes

* fix(backend): remove google specific check

* fix: add typing

* fix: only enable google blocks when oauth is configured for google

* fix: errors are real and valid outputs always when output

* fix(backend): add provider detail for debuging scope declines

* Update autogpt_platform/frontend/src/components/integrations/credentials-input.tsx

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>

* fix(frontend): enhance with comment, typeof error isn't known so this is best way to ensure the stringifyication will work

* feat: code review change requests

* fix: linting

* fix: reduce error catching

* fix: doc messages in code

* fix: check the correct scopes object 😄

* fix: remove double (and not needed) try catch

* fix: lint

* fix: scopes

* feat: handle the default scopes better

* feat: better email objectification

* feat: process attachements

turns out an email doesn't need a body

* fix: lint

* Update google.py

* Update autogpt_platform/backend/backend/data/block.py

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>

* fix: quit trying and except failure

* Update autogpt_platform/backend/backend/server/routers/integrations.py

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>

* feat: don't allow expired states

* fix: clarify function name and purpose

* feat: code links updates

* feat: additional docs on adding a block

* fix: type hint missing which means the block won't work

* fix: linting

* fix: docs formatting

* Update issues.py

* fix: improve the naming

* fix: formatting

* Update new_blocks.md

* Update new_blocks.md

* feat: better docs on what the args mean

* feat: more details on yield

* Update new_blocks.md

* fix: remove ignore from docs build

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2024-10-08 11:11:14 -05:00
Zamil Majdy
2a74381ae8 feat(platform): Add delete agent functionality (#8273) 2024-10-08 16:03:26 +00:00
Toran Bruce Richards
d42ed088dd tweak(docs): Further clarify licencing (#8282)
* Update files via upload

* Update README.md

* Update CONTRIBUTING.md

* Update CONTRIBUTING.md
2024-10-08 10:54:22 -05:00
Aarushi
2aed470d26 tweak(platform): Disable docs endpoint when not local (#8265)
* disable docs endpoint

* add to .env.example

* use enum for app env

* lint
2024-10-08 10:31:08 +01:00
Aarushi
61f1d0cdb5 fix(platform): Always filter on user id (#8275)
* always filter on user id

* add user id to doc string

* fix linting

* fix imports function
2024-10-07 15:47:49 +02:00
264 changed files with 16830 additions and 2486 deletions

View File

@@ -2,12 +2,12 @@ name: Classic - AutoGPT CI
on:
push:
branches: [ master, development, ci-test* ]
branches: [ master, dev, ci-test* ]
paths:
- '.github/workflows/classic-autogpt-ci.yml'
- 'classic/original_autogpt/**'
pull_request:
branches: [ master, development, release-* ]
branches: [ master, dev, release-* ]
paths:
- '.github/workflows/classic-autogpt-ci.yml'
- 'classic/original_autogpt/**'

View File

@@ -8,7 +8,7 @@ on:
- 'classic/original_autogpt/**'
- 'classic/forge/**'
pull_request:
branches: [ master, development, release-* ]
branches: [ master, dev, release-* ]
paths:
- '.github/workflows/classic-autogpt-docker-ci.yml'
- 'classic/original_autogpt/**'

View File

@@ -5,7 +5,7 @@ on:
schedule:
- cron: '0 8 * * *'
push:
branches: [ master, development, ci-test* ]
branches: [ master, dev, ci-test* ]
paths:
- '.github/workflows/classic-autogpts-ci.yml'
- 'classic/original_autogpt/**'
@@ -16,7 +16,7 @@ on:
- 'classic/setup.py'
- '!**/*.md'
pull_request:
branches: [ master, development, release-* ]
branches: [ master, dev, release-* ]
paths:
- '.github/workflows/classic-autogpts-ci.yml'
- 'classic/original_autogpt/**'

View File

@@ -2,13 +2,13 @@ name: Classic - AGBenchmark CI
on:
push:
branches: [ master, development, ci-test* ]
branches: [ master, dev, ci-test* ]
paths:
- 'classic/benchmark/**'
- '!classic/benchmark/reports/**'
- .github/workflows/classic-benchmark-ci.yml
pull_request:
branches: [ master, development, release-* ]
branches: [ master, dev, release-* ]
paths:
- 'classic/benchmark/**'
- '!classic/benchmark/reports/**'

View File

@@ -2,13 +2,13 @@ name: Classic - Forge CI
on:
push:
branches: [ master, development, ci-test* ]
branches: [ master, dev, ci-test* ]
paths:
- '.github/workflows/classic-forge-ci.yml'
- 'classic/forge/**'
- '!classic/forge/tests/vcr_cassettes'
pull_request:
branches: [ master, development, release-* ]
branches: [ master, dev, release-* ]
paths:
- '.github/workflows/classic-forge-ci.yml'
- 'classic/forge/**'

View File

@@ -2,7 +2,7 @@ name: Classic - Python checks
on:
push:
branches: [ master, development, ci-test* ]
branches: [ master, dev, ci-test* ]
paths:
- '.github/workflows/classic-python-checks-ci.yml'
- 'classic/original_autogpt/**'
@@ -11,7 +11,7 @@ on:
- '**.py'
- '!classic/forge/tests/vcr_cassettes'
pull_request:
branches: [ master, development, release-* ]
branches: [ master, dev, release-* ]
paths:
- '.github/workflows/classic-python-checks-ci.yml'
- 'classic/original_autogpt/**'

View File

@@ -2,7 +2,7 @@ name: AutoGPT Platform - Infra
on:
push:
branches: [ master ]
branches: [ master, dev ]
paths:
- '.github/workflows/platform-autogpt-infra-ci.yml'
- 'autogpt_platform/infra/**'

View File

@@ -2,12 +2,12 @@ name: AutoGPT Platform - Backend CI
on:
push:
branches: [master, development, ci-test*]
branches: [master, dev, ci-test*]
paths:
- ".github/workflows/platform-backend-ci.yml"
- "autogpt_platform/backend/**"
pull_request:
branches: [master, development, release-*]
branches: [master, dev, release-*]
paths:
- ".github/workflows/platform-backend-ci.yml"
- "autogpt_platform/backend/**"

View File

@@ -2,7 +2,7 @@ name: AutoGPT Platform - Frontend CI
on:
push:
branches: [master]
branches: [master, dev]
paths:
- ".github/workflows/platform-frontend-ci.yml"
- "autogpt_platform/frontend/**"
@@ -29,15 +29,11 @@ jobs:
- name: Install dependencies
run: |
npm install
- name: Check formatting with Prettier
run: |
npx prettier --check .
yarn install --frozen-lockfile
- name: Run lint
run: |
npm run lint
yarn lint
test:
runs-on: ubuntu-latest
@@ -62,18 +58,18 @@ jobs:
- name: Install dependencies
run: |
npm install
yarn install --frozen-lockfile
- name: Setup Builder .env
run: |
cp .env.example .env
- name: Install Playwright Browsers
run: npx playwright install --with-deps
run: yarn playwright install --with-deps
- name: Run tests
run: |
npm run test
yarn test
- uses: actions/upload-artifact@v4
if: ${{ !cancelled() }}

125
.github/workflows/platform-market-ci.yml vendored Normal file
View File

@@ -0,0 +1,125 @@
name: AutoGPT Platform - Backend CI
on:
push:
branches: [master, dev, ci-test*]
paths:
- ".github/workflows/platform-market-ci.yml"
- "autogpt_platform/market/**"
pull_request:
branches: [master, dev, release-*]
paths:
- ".github/workflows/platform-market-ci.yml"
- "autogpt_platform/market/**"
concurrency:
group: ${{ format('backend-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
cancel-in-progress: ${{ startsWith(github.event_name, 'pull_request') }}
defaults:
run:
shell: bash
working-directory: autogpt_platform/market
jobs:
test:
permissions:
contents: read
timeout-minutes: 30
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
submodules: true
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Setup Supabase
uses: supabase/setup-cli@v1
with:
version: latest
- id: get_date
name: Get date
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
- name: Set up Python dependency cache
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/market/poetry.lock') }}
- name: Install Poetry (Unix)
run: |
curl -sSL https://install.python-poetry.org | python3 -
if [ "${{ runner.os }}" = "macOS" ]; then
PATH="$HOME/.local/bin:$PATH"
echo "$HOME/.local/bin" >> $GITHUB_PATH
fi
- name: Install Python dependencies
run: poetry install
- name: Generate Prisma Client
run: poetry run prisma generate
- id: supabase
name: Start Supabase
working-directory: .
run: |
supabase init
supabase start --exclude postgres-meta,realtime,storage-api,imgproxy,inbucket,studio,edge-runtime,logflare,vector,supavisor
supabase status -o env | sed 's/="/=/; s/"$//' >> $GITHUB_OUTPUT
# outputs:
# DB_URL, API_URL, GRAPHQL_URL, ANON_KEY, SERVICE_ROLE_KEY, JWT_SECRET
- name: Run Database Migrations
run: poetry run prisma migrate dev --name updates
env:
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
- id: lint
name: Run Linter
run: poetry run lint
# Tests comment out because they do not work with prisma mock, nor have they been updated since they were created
# - name: Run pytest with coverage
# run: |
# if [[ "${{ runner.debug }}" == "1" ]]; then
# poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG test
# else
# poetry run pytest -s -vv test
# fi
# if: success() || (failure() && steps.lint.outcome == 'failure')
# env:
# LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
# DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
# SUPABASE_URL: ${{ steps.supabase.outputs.API_URL }}
# SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
# SUPABASE_JWT_SECRET: ${{ steps.supabase.outputs.JWT_SECRET }}
# REDIS_HOST: 'localhost'
# REDIS_PORT: '6379'
# REDIS_PASSWORD: 'testpassword'
env:
CI: true
PLAIN_OUTPUT: True
RUN_ENV: local
PORT: 8080
# - name: Upload coverage reports to Codecov
# uses: codecov/codecov-action@v4
# with:
# token: ${{ secrets.CODECOV_TOKEN }}
# flags: backend,${{ runner.os }}

View File

@@ -0,0 +1,21 @@
name: Repo - Enforce dev as base branch
on:
pull_request_target:
branches: [ master ]
types: [ opened ]
jobs:
check_pr_target:
runs-on: ubuntu-latest
permissions:
pull-requests: write
steps:
- name: Check if PR is from dev or hotfix
if: ${{ !(startsWith(github.event.pull_request.head.ref, 'hotfix/') || github.event.pull_request.head.ref == 'dev') }}
run: |
gh pr comment ${{ github.event.number }} --repo "$REPO" \
--body $'This PR targets the `master` branch but does not come from `dev` or a `hotfix/*` branch.\n\nAutomatically setting the base branch to `dev`.'
gh pr edit ${{ github.event.number }} --base dev --repo "$REPO"
env:
GITHUB_TOKEN: ${{ github.token }}
REPO: ${{ github.repository }}

View File

@@ -3,7 +3,7 @@ name: Repo - Pull Request auto-label
on:
# So that PRs touching the same files as the push are updated
push:
branches: [ master, development, release-* ]
branches: [ master, dev, release-* ]
paths-ignore:
- 'classic/forge/tests/vcr_cassettes'
- 'classic/benchmark/reports/**'

View File

@@ -11,7 +11,7 @@ Also check out our [🚀 Roadmap][roadmap] for information about our priorities
[kanban board]: https://github.com/orgs/Significant-Gravitas/projects/1
## Contributing to the AutoGPT Platform Folder
All contributions to [the autogpt_platform folder](https://github.com/Significant-Gravitas/AutoGPT/blob/master/autogpt_platform) will be under our [Contribution License Agreement](https://github.com/Significant-Gravitas/AutoGPT/blob/master/autogpt_platform/Contributor%20License%20Agreement%20(CLA).md). By making a pull request contributing to this folder, you agree to the terms of our CLA for your contribution.
All contributions to [the autogpt_platform folder](https://github.com/Significant-Gravitas/AutoGPT/blob/master/autogpt_platform) will be under our [Contribution License Agreement](https://github.com/Significant-Gravitas/AutoGPT/blob/master/autogpt_platform/Contributor%20License%20Agreement%20(CLA).md). By making a pull request contributing to this folder, you agree to the terms of our CLA for your contribution. All contributions to other folders will be under the MIT license.
## In short
1. Avoid duplicate work, issues, PRs etc.

View File

@@ -1,7 +1,13 @@
All portions of this repository are under one of two licenses. The majority of the AutoGPT repository is under the MIT License below. The autogpt_platform folder is under the
Polyform Shield License.
MIT License
Copyright (c) 2023 Toran Bruce Richards
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
@@ -9,9 +15,11 @@ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE

View File

@@ -65,6 +65,7 @@ Here are two examples of what you can do with AutoGPT:
These examples show just a glimpse of what you can achieve with AutoGPT! You can create customized workflows to build agents for any use case.
---
### Mission and Licencing
Our mission is to provide the tools, so that you can focus on what matters:
- 🏗️ **Building** - Lay the foundation for something amazing.
@@ -77,6 +78,13 @@ Be part of the revolution! **AutoGPT** is here to stay, at the forefront of AI i
&ensp;|&ensp;
**🚀 [Contributing](CONTRIBUTING.md)**
**Licensing:**
MIT License: The majority of the AutoGPT repository is under the MIT License.
Polyform Shield License: This license applies to the autogpt_platform folder.
For more information, see https://agpt.co/blog/introducing-the-autogpt-platform
---
## 🤖 AutoGPT Classic
@@ -150,6 +158,8 @@ To maintain a uniform standard and ensure seamless compatibility with many curre
---
## Stars stats
<p align="center">
<a href="https://star-history.com/#Significant-Gravitas/AutoGPT">
<picture>
@@ -159,3 +169,10 @@ To maintain a uniform standard and ensure seamless compatibility with many curre
</picture>
</a>
</p>
## ⚡ Contributors
<a href="https://github.com/Significant-Gravitas/AutoGPT/graphs/contributors" alt="View Contributors">
<img src="https://contrib.rocks/image?repo=Significant-Gravitas/AutoGPT&max=1000&columns=10" alt="Contributors" />
</a>

View File

@@ -149,6 +149,3 @@ To persist data for PostgreSQL and Redis, you can modify the `docker-compose.yml
3. Save the file and run `docker compose up -d` to apply the changes.
This configuration will create named volumes for PostgreSQL and Redis, ensuring that your data persists across container restarts.

View File

@@ -1,8 +1,9 @@
from .store import SupabaseIntegrationCredentialsStore
from .types import APIKeyCredentials, OAuth2Credentials
from .types import Credentials, APIKeyCredentials, OAuth2Credentials
__all__ = [
"SupabaseIntegrationCredentialsStore",
"Credentials",
"APIKeyCredentials",
"OAuth2Credentials",
]

View File

@@ -1,8 +1,12 @@
import secrets
from datetime import datetime, timedelta, timezone
from typing import cast
from typing import TYPE_CHECKING, cast
from supabase import Client
if TYPE_CHECKING:
from redis import Redis
from supabase import Client
from autogpt_libs.utils.synchronize import RedisKeyedMutex
from .types import (
Credentials,
@@ -14,26 +18,28 @@ from .types import (
class SupabaseIntegrationCredentialsStore:
def __init__(self, supabase: Client):
def __init__(self, supabase: "Client", redis: "Redis"):
self.supabase = supabase
self.locks = RedisKeyedMutex(redis)
def add_creds(self, user_id: str, credentials: Credentials) -> None:
if self.get_creds_by_id(user_id, credentials.id):
raise ValueError(
f"Can not re-create existing credentials with ID {credentials.id} "
f"for user with ID {user_id}"
with self.locked_user_metadata(user_id):
if self.get_creds_by_id(user_id, credentials.id):
raise ValueError(
f"Can not re-create existing credentials #{credentials.id} "
f"for user #{user_id}"
)
self._set_user_integration_creds(
user_id, [*self.get_all_creds(user_id), credentials]
)
self._set_user_integration_creds(
user_id, [*self.get_all_creds(user_id), credentials]
)
def get_all_creds(self, user_id: str) -> list[Credentials]:
user_metadata = self._get_user_metadata(user_id)
return UserMetadata.model_validate(user_metadata).integration_credentials
def get_creds_by_id(self, user_id: str, credentials_id: str) -> Credentials | None:
credentials = self.get_all_creds(user_id)
return next((c for c in credentials if c.id == credentials_id), None)
all_credentials = self.get_all_creds(user_id)
return next((c for c in all_credentials if c.id == credentials_id), None)
def get_creds_by_provider(self, user_id: str, provider: str) -> list[Credentials]:
credentials = self.get_all_creds(user_id)
@@ -44,42 +50,45 @@ class SupabaseIntegrationCredentialsStore:
return list(set(c.provider for c in credentials))
def update_creds(self, user_id: str, updated: Credentials) -> None:
current = self.get_creds_by_id(user_id, updated.id)
if not current:
raise ValueError(
f"Credentials with ID {updated.id} "
f"for user with ID {user_id} not found"
)
if type(current) is not type(updated):
raise TypeError(
f"Can not update credentials with ID {updated.id} "
f"from type {type(current)} "
f"to type {type(updated)}"
)
with self.locked_user_metadata(user_id):
current = self.get_creds_by_id(user_id, updated.id)
if not current:
raise ValueError(
f"Credentials with ID {updated.id} "
f"for user with ID {user_id} not found"
)
if type(current) is not type(updated):
raise TypeError(
f"Can not update credentials with ID {updated.id} "
f"from type {type(current)} "
f"to type {type(updated)}"
)
# Ensure no scopes are removed when updating credentials
if (
isinstance(updated, OAuth2Credentials)
and isinstance(current, OAuth2Credentials)
and not set(updated.scopes).issuperset(current.scopes)
):
raise ValueError(
f"Can not update credentials with ID {updated.id} "
f"and scopes {current.scopes} "
f"to more restrictive set of scopes {updated.scopes}"
)
# Ensure no scopes are removed when updating credentials
if (
isinstance(updated, OAuth2Credentials)
and isinstance(current, OAuth2Credentials)
and not set(updated.scopes).issuperset(current.scopes)
):
raise ValueError(
f"Can not update credentials with ID {updated.id} "
f"and scopes {current.scopes} "
f"to more restrictive set of scopes {updated.scopes}"
)
# Update the credentials
updated_credentials_list = [
updated if c.id == updated.id else c for c in self.get_all_creds(user_id)
]
self._set_user_integration_creds(user_id, updated_credentials_list)
# Update the credentials
updated_credentials_list = [
updated if c.id == updated.id else c
for c in self.get_all_creds(user_id)
]
self._set_user_integration_creds(user_id, updated_credentials_list)
def delete_creds_by_id(self, user_id: str, credentials_id: str) -> None:
filtered_credentials = [
c for c in self.get_all_creds(user_id) if c.id != credentials_id
]
self._set_user_integration_creds(user_id, filtered_credentials)
with self.locked_user_metadata(user_id):
filtered_credentials = [
c for c in self.get_all_creds(user_id) if c.id != credentials_id
]
self._set_user_integration_creds(user_id, filtered_credentials)
async def store_state_token(
self, user_id: str, provider: str, scopes: list[str]
@@ -94,14 +103,15 @@ class SupabaseIntegrationCredentialsStore:
scopes=scopes,
)
user_metadata = self._get_user_metadata(user_id)
oauth_states = user_metadata.get("integration_oauth_states", [])
oauth_states.append(state.model_dump())
user_metadata["integration_oauth_states"] = oauth_states
with self.locked_user_metadata(user_id):
user_metadata = self._get_user_metadata(user_id)
oauth_states = user_metadata.get("integration_oauth_states", [])
oauth_states.append(state.model_dump())
user_metadata["integration_oauth_states"] = oauth_states
self.supabase.auth.admin.update_user_by_id(
user_id, {"user_metadata": user_metadata}
)
self.supabase.auth.admin.update_user_by_id(
user_id, {"user_metadata": user_metadata}
)
return token
@@ -136,29 +146,30 @@ class SupabaseIntegrationCredentialsStore:
return []
async def verify_state_token(self, user_id: str, token: str, provider: str) -> bool:
user_metadata = self._get_user_metadata(user_id)
oauth_states = user_metadata.get("integration_oauth_states", [])
with self.locked_user_metadata(user_id):
user_metadata = self._get_user_metadata(user_id)
oauth_states = user_metadata.get("integration_oauth_states", [])
now = datetime.now(timezone.utc)
valid_state = next(
(
state
for state in oauth_states
if state["token"] == token
and state["provider"] == provider
and state["expires_at"] > now.timestamp()
),
None,
)
if valid_state:
# Remove the used state
oauth_states.remove(valid_state)
user_metadata["integration_oauth_states"] = oauth_states
self.supabase.auth.admin.update_user_by_id(
user_id, {"user_metadata": user_metadata}
now = datetime.now(timezone.utc)
valid_state = next(
(
state
for state in oauth_states
if state["token"] == token
and state["provider"] == provider
and state["expires_at"] > now.timestamp()
),
None,
)
return True
if valid_state:
# Remove the used state
oauth_states.remove(valid_state)
user_metadata["integration_oauth_states"] = oauth_states
self.supabase.auth.admin.update_user_by_id(
user_id, {"user_metadata": user_metadata}
)
return True
return False
@@ -178,3 +189,7 @@ class SupabaseIntegrationCredentialsStore:
if not response.user:
raise ValueError(f"User with ID {user_id} not found")
return cast(UserMetadataRaw, response.user.user_metadata)
def locked_user_metadata(self, user_id: str):
key = (self.supabase.supabase_url, f"user:{user_id}", "metadata")
return self.locks.locked(key)

View File

@@ -0,0 +1,56 @@
from contextlib import contextmanager
from threading import Lock
from typing import TYPE_CHECKING, Any
from expiringdict import ExpiringDict
if TYPE_CHECKING:
from redis import Redis
from redis.lock import Lock as RedisLock
class RedisKeyedMutex:
"""
This class provides a mutex that can be locked and unlocked by a specific key,
using Redis as a distributed locking provider.
It uses an ExpiringDict to automatically clear the mutex after a specified timeout,
in case the key is not unlocked for a specified duration, to prevent memory leaks.
"""
def __init__(self, redis: "Redis", timeout: int | None = 60):
self.redis = redis
self.timeout = timeout
self.locks: dict[Any, "RedisLock"] = ExpiringDict(
max_len=6000, max_age_seconds=self.timeout
)
self.locks_lock = Lock()
@contextmanager
def locked(self, key: Any):
lock = self.acquire(key)
try:
yield
finally:
lock.release()
def acquire(self, key: Any) -> "RedisLock":
"""Acquires and returns a lock with the given key"""
with self.locks_lock:
if key not in self.locks:
self.locks[key] = self.redis.lock(
str(key), self.timeout, thread_local=False
)
lock = self.locks[key]
lock.acquire()
return lock
def release(self, key: Any):
if lock := self.locks.get(key):
lock.release()
def release_all_locks(self):
"""Call this on process termination to ensure all locks are released"""
self.locks_lock.acquire(blocking=False)
for lock in self.locks.values():
if lock.locked() and lock.owned():
lock.release()

View File

@@ -377,6 +377,20 @@ files = [
[package.extras]
test = ["pytest (>=6)"]
[[package]]
name = "expiringdict"
version = "1.2.2"
description = "Dictionary with auto-expiring values for caching purposes"
optional = false
python-versions = "*"
files = [
{file = "expiringdict-1.2.2-py3-none-any.whl", hash = "sha256:09a5d20bc361163e6432a874edd3179676e935eb81b925eccef48d409a8a45e8"},
{file = "expiringdict-1.2.2.tar.gz", hash = "sha256:300fb92a7e98f15b05cf9a856c1415b3bc4f2e132be07daa326da6414c23ee09"},
]
[package.extras]
tests = ["coverage", "coveralls", "dill", "mock", "nose"]
[[package]]
name = "frozenlist"
version = "1.4.1"
@@ -1031,6 +1045,7 @@ description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs
optional = false
python-versions = ">=3.8"
files = [
{file = "pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629"},
{file = "pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034"},
]
@@ -1041,6 +1056,7 @@ description = "A collection of ASN.1-based protocols modules"
optional = false
python-versions = ">=3.8"
files = [
{file = "pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd"},
{file = "pyasn1_modules-0.4.1.tar.gz", hash = "sha256:c28e2dbf9c06ad61c71a075c7e0f9fd0f1b0bb2d2ad4377f240d33ac2ab60a7c"},
]
@@ -1253,6 +1269,24 @@ python-dateutil = ">=2.8.1,<3.0.0"
typing-extensions = ">=4.12.2,<5.0.0"
websockets = ">=11,<13"
[[package]]
name = "redis"
version = "5.1.1"
description = "Python client for Redis database and key-value store"
optional = false
python-versions = ">=3.8"
files = [
{file = "redis-5.1.1-py3-none-any.whl", hash = "sha256:f8ea06b7482a668c6475ae202ed8d9bcaa409f6e87fb77ed1043d912afd62e24"},
{file = "redis-5.1.1.tar.gz", hash = "sha256:f6c997521fedbae53387307c5d0bf784d9acc28d9f1d058abeac566ec4dbed72"},
]
[package.dependencies]
async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\""}
[package.extras]
hiredis = ["hiredis (>=3.0.0)"]
ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==23.2.1)", "requests (>=2.31.0)"]
[[package]]
name = "requests"
version = "2.32.3"
@@ -1690,4 +1724,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<4.0"
content-hash = "e9b6e5d877eeb9c9f1ebc69dead1985d749facc160afbe61f3bf37e9a6e35aa5"
content-hash = "ad9a4c8b399f6480a9f70319d13df810f92f63b532d4e10503d283f0948bed6c"

View File

@@ -8,6 +8,7 @@ packages = [{ include = "autogpt_libs" }]
[tool.poetry.dependencies]
colorama = "^0.4.6"
expiringdict = "^1.2.2"
google-cloud-logging = "^3.8.0"
pydantic = "^2.8.2"
pydantic-settings = "^2.5.2"
@@ -16,6 +17,9 @@ python = ">=3.10,<4.0"
python-dotenv = "^1.0.1"
supabase = "^2.7.2"
[tool.poetry.group.dev.dependencies]
redis = "^5.0.8"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

View File

@@ -12,7 +12,10 @@ REDIS_PORT=6379
REDIS_PASSWORD=password
ENABLE_CREDIT=false
APP_ENV="local"
# What environment things should be logged under: local dev or prod
APP_ENV=local
# What environment to behave as: "local" or "cloud"
BEHAVE_AS=local
PYRO_HOST=localhost
SENTRY_DSN=

View File

@@ -24,10 +24,12 @@ def main(**kwargs):
Run all the processes required for the AutoGPT-server (REST and WebSocket APIs).
"""
from backend.executor import ExecutionManager, ExecutionScheduler
from backend.server import AgentServer, WebsocketServer
from backend.executor import DatabaseManager, ExecutionManager, ExecutionScheduler
from backend.server.rest_api import AgentServer
from backend.server.ws_api import WebsocketServer
run_processes(
DatabaseManager(),
ExecutionManager(),
ExecutionScheduler(),
WebsocketServer(),

View File

@@ -53,15 +53,33 @@ for cls in all_subclasses(Block):
if block.id in AVAILABLE_BLOCKS:
raise ValueError(f"Block ID {block.name} error: {block.id} is already in use")
input_schema = block.input_schema.model_fields
output_schema = block.output_schema.model_fields
# Prevent duplicate field name in input_schema and output_schema
duplicate_field_names = set(block.input_schema.model_fields.keys()) & set(
block.output_schema.model_fields.keys()
)
duplicate_field_names = set(input_schema.keys()) & set(output_schema.keys())
if duplicate_field_names:
raise ValueError(
f"{block.name} has duplicate field names in input_schema and output_schema: {duplicate_field_names}"
)
# Make sure `error` field is a string in the output schema
if "error" in output_schema and output_schema["error"].annotation is not str:
raise ValueError(
f"{block.name} `error` field in output_schema must be a string"
)
# Make sure all fields in input_schema and output_schema are annotated and has a value
for field_name, field in [*input_schema.items(), *output_schema.items()]:
if field.annotation is None:
raise ValueError(
f"{block.name} has a field {field_name} that is not annotated"
)
if field.json_schema_extra is None:
raise ValueError(
f"{block.name} has a field {field_name} not defined as SchemaField"
)
for field in block.input_schema.model_fields.values():
if field.annotation is bool and field.default not in (True, False):
raise ValueError(f"{block.name} has a boolean field with no default value")

View File

@@ -1,10 +1,8 @@
import logging
import time
from enum import Enum
from typing import Optional
import requests
from pydantic import Field
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import BlockSecret, SchemaField, SecretField
@@ -130,9 +128,13 @@ class AIShortformVideoCreatorBlock(Block):
description="""1. Use short and punctuated sentences\n\n2. Use linebreaks to create a new clip\n\n3. Text outside of brackets is spoken by the AI, and [text between brackets] will be used to guide the visual generation. For example, [close-up of a cat] will show a close-up of a cat.""",
placeholder="[close-up of a cat] Meow!",
)
ratio: str = Field(description="Aspect ratio of the video", default="9 / 16")
resolution: str = Field(description="Resolution of the video", default="720p")
frame_rate: int = Field(description="Frame rate of the video", default=60)
ratio: str = SchemaField(
description="Aspect ratio of the video", default="9 / 16"
)
resolution: str = SchemaField(
description="Resolution of the video", default="720p"
)
frame_rate: int = SchemaField(description="Frame rate of the video", default=60)
generation_preset: GenerationPreset = SchemaField(
description="Generation preset for visual style - only effects AI generated visuals",
default=GenerationPreset.LEONARDO,
@@ -155,8 +157,8 @@ class AIShortformVideoCreatorBlock(Block):
)
class Output(BlockSchema):
video_url: str = Field(description="The URL of the created video")
error: Optional[str] = Field(description="Error message if the request failed")
video_url: str = SchemaField(description="The URL of the created video")
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
@@ -239,69 +241,58 @@ class AIShortformVideoCreatorBlock(Block):
raise TimeoutError("Video creation timed out")
def run(self, input_data: Input, **kwargs) -> BlockOutput:
try:
# Create a new Webhook.site URL
webhook_token, webhook_url = self.create_webhook()
logger.debug(f"Webhook URL: {webhook_url}")
# Create a new Webhook.site URL
webhook_token, webhook_url = self.create_webhook()
logger.debug(f"Webhook URL: {webhook_url}")
audio_url = input_data.background_music.audio_url
audio_url = input_data.background_music.audio_url
payload = {
"frameRate": input_data.frame_rate,
"resolution": input_data.resolution,
"frameDurationMultiplier": 18,
"webhook": webhook_url,
"creationParams": {
"mediaType": input_data.video_style,
"captionPresetName": "Wrap 1",
"selectedVoice": input_data.voice.voice_id,
"hasEnhancedGeneration": True,
"generationPreset": input_data.generation_preset.name,
"selectedAudio": input_data.background_music,
"origin": "/create",
"inputText": input_data.script,
"flowType": "text-to-video",
"slug": "create-tiktok-video",
"hasToGenerateVoice": True,
"hasToTranscript": False,
"hasToSearchMedia": True,
"hasAvatar": False,
"hasWebsiteRecorder": False,
"hasTextSmallAtBottom": False,
"ratio": input_data.ratio,
"sourceType": "contentScraping",
"selectedStoryStyle": {"value": "custom", "label": "Custom"},
"hasToGenerateVideos": input_data.video_style
!= VisualMediaType.STOCK_VIDEOS,
"audioUrl": audio_url,
},
}
payload = {
"frameRate": input_data.frame_rate,
"resolution": input_data.resolution,
"frameDurationMultiplier": 18,
"webhook": webhook_url,
"creationParams": {
"mediaType": input_data.video_style,
"captionPresetName": "Wrap 1",
"selectedVoice": input_data.voice.voice_id,
"hasEnhancedGeneration": True,
"generationPreset": input_data.generation_preset.name,
"selectedAudio": input_data.background_music,
"origin": "/create",
"inputText": input_data.script,
"flowType": "text-to-video",
"slug": "create-tiktok-video",
"hasToGenerateVoice": True,
"hasToTranscript": False,
"hasToSearchMedia": True,
"hasAvatar": False,
"hasWebsiteRecorder": False,
"hasTextSmallAtBottom": False,
"ratio": input_data.ratio,
"sourceType": "contentScraping",
"selectedStoryStyle": {"value": "custom", "label": "Custom"},
"hasToGenerateVideos": input_data.video_style
!= VisualMediaType.STOCK_VIDEOS,
"audioUrl": audio_url,
},
}
logger.debug("Creating video...")
response = self.create_video(input_data.api_key.get_secret_value(), payload)
pid = response.get("pid")
logger.debug("Creating video...")
response = self.create_video(input_data.api_key.get_secret_value(), payload)
pid = response.get("pid")
if not pid:
logger.error(
f"Failed to create video: No project ID returned. API Response: {response}"
)
yield "error", "Failed to create video: No project ID returned"
else:
logger.debug(
f"Video created with project ID: {pid}. Waiting for completion..."
)
video_url = self.wait_for_video(
input_data.api_key.get_secret_value(), pid, webhook_token
)
logger.debug(f"Video ready: {video_url}")
yield "video_url", video_url
except requests.RequestException as e:
logger.exception("Error creating video")
yield "error", f"Error creating video: {str(e)}"
except ValueError as e:
logger.exception("Error in video creation process")
yield "error", str(e)
except TimeoutError as e:
logger.exception("Video creation timed out")
yield "error", str(e)
if not pid:
logger.error(
f"Failed to create video: No project ID returned. API Response: {response}"
)
raise RuntimeError("Failed to create video: No project ID returned")
else:
logger.debug(
f"Video created with project ID: {pid}. Waiting for completion..."
)
video_url = self.wait_for_video(
input_data.api_key.get_secret_value(), pid, webhook_token
)
logger.debug(f"Video ready: {video_url}")
yield "video_url", video_url

View File

@@ -2,7 +2,6 @@ import re
from typing import Any, List
from jinja2 import BaseLoader, Environment
from pydantic import Field
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
from backend.data.model import SchemaField
@@ -19,18 +18,18 @@ class StoreValueBlock(Block):
"""
class Input(BlockSchema):
input: Any = Field(
input: Any = SchemaField(
description="Trigger the block to produce the output. "
"The value is only used when `data` is None."
)
data: Any = Field(
data: Any = SchemaField(
description="The constant data to be retained in the block. "
"This value is passed as `output`.",
default=None,
)
class Output(BlockSchema):
output: Any
output: Any = SchemaField(description="The stored data retained in the block.")
def __init__(self):
super().__init__(
@@ -56,10 +55,10 @@ class StoreValueBlock(Block):
class PrintToConsoleBlock(Block):
class Input(BlockSchema):
text: str
text: str = SchemaField(description="The text to print to the console.")
class Output(BlockSchema):
status: str
status: str = SchemaField(description="The status of the print operation.")
def __init__(self):
super().__init__(
@@ -79,12 +78,14 @@ class PrintToConsoleBlock(Block):
class FindInDictionaryBlock(Block):
class Input(BlockSchema):
input: Any = Field(description="Dictionary to lookup from")
key: str | int = Field(description="Key to lookup in the dictionary")
input: Any = SchemaField(description="Dictionary to lookup from")
key: str | int = SchemaField(description="Key to lookup in the dictionary")
class Output(BlockSchema):
output: Any = Field(description="Value found for the given key")
missing: Any = Field(description="Value of the input that missing the key")
output: Any = SchemaField(description="Value found for the given key")
missing: Any = SchemaField(
description="Value of the input that missing the key"
)
def __init__(self):
super().__init__(
@@ -330,20 +331,17 @@ class AddToDictionaryBlock(Block):
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
try:
# If no dictionary is provided, create a new one
if input_data.dictionary is None:
updated_dict = {}
else:
# Create a copy of the input dictionary to avoid modifying the original
updated_dict = input_data.dictionary.copy()
# If no dictionary is provided, create a new one
if input_data.dictionary is None:
updated_dict = {}
else:
# Create a copy of the input dictionary to avoid modifying the original
updated_dict = input_data.dictionary.copy()
# Add the new key-value pair
updated_dict[input_data.key] = input_data.value
# Add the new key-value pair
updated_dict[input_data.key] = input_data.value
yield "updated_dictionary", updated_dict
except Exception as e:
yield "error", f"Failed to add entry to dictionary: {str(e)}"
yield "updated_dictionary", updated_dict
class AddToListBlock(Block):
@@ -401,23 +399,20 @@ class AddToListBlock(Block):
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
try:
# If no list is provided, create a new one
if input_data.list is None:
updated_list = []
else:
# Create a copy of the input list to avoid modifying the original
updated_list = input_data.list.copy()
# If no list is provided, create a new one
if input_data.list is None:
updated_list = []
else:
# Create a copy of the input list to avoid modifying the original
updated_list = input_data.list.copy()
# Add the new entry
if input_data.position is None:
updated_list.append(input_data.entry)
else:
updated_list.insert(input_data.position, input_data.entry)
# Add the new entry
if input_data.position is None:
updated_list.append(input_data.entry)
else:
updated_list.insert(input_data.position, input_data.entry)
yield "updated_list", updated_list
except Exception as e:
yield "error", f"Failed to add entry to list: {str(e)}"
yield "updated_list", updated_list
class NoteBlock(Block):

View File

@@ -3,6 +3,7 @@ import re
from typing import Type
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
class BlockInstallationBlock(Block):
@@ -15,11 +16,17 @@ class BlockInstallationBlock(Block):
"""
class Input(BlockSchema):
code: str
code: str = SchemaField(
description="Python code of the block to be installed",
)
class Output(BlockSchema):
success: str
error: str
success: str = SchemaField(
description="Success message if the block is installed successfully",
)
error: str = SchemaField(
description="Error message if the block installation fails",
)
def __init__(self):
super().__init__(
@@ -37,14 +44,12 @@ class BlockInstallationBlock(Block):
if search := re.search(r"class (\w+)\(Block\):", code):
class_name = search.group(1)
else:
yield "error", "No class found in the code."
return
raise RuntimeError("No class found in the code.")
if search := re.search(r"id=\"(\w+-\w+-\w+-\w+-\w+)\"", code):
file_name = search.group(1)
else:
yield "error", "No UUID found in the code."
return
raise RuntimeError("No UUID found in the code.")
block_dir = os.path.dirname(__file__)
file_path = f"{block_dir}/{file_name}.py"
@@ -63,4 +68,4 @@ class BlockInstallationBlock(Block):
yield "success", "Block installed successfully."
except Exception as e:
os.remove(file_path)
yield "error", f"[Code]\n{code}\n\n[Error]\n{str(e)}"
raise RuntimeError(f"[Code]\n{code}\n\n[Error]\n{str(e)}")

View File

@@ -1,21 +1,49 @@
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import ContributorDetails
from backend.data.model import ContributorDetails, SchemaField
class ReadCsvBlock(Block):
class Input(BlockSchema):
contents: str
delimiter: str = ","
quotechar: str = '"'
escapechar: str = "\\"
has_header: bool = True
skip_rows: int = 0
strip: bool = True
skip_columns: list[str] = []
contents: str = SchemaField(
description="The contents of the CSV file to read",
placeholder="a, b, c\n1,2,3\n4,5,6",
)
delimiter: str = SchemaField(
description="The delimiter used in the CSV file",
default=",",
)
quotechar: str = SchemaField(
description="The character used to quote fields",
default='"',
)
escapechar: str = SchemaField(
description="The character used to escape the delimiter",
default="\\",
)
has_header: bool = SchemaField(
description="Whether the CSV file has a header row",
default=True,
)
skip_rows: int = SchemaField(
description="The number of rows to skip from the start of the file",
default=0,
)
strip: bool = SchemaField(
description="Whether to strip whitespace from the values",
default=True,
)
skip_columns: list[str] = SchemaField(
description="The columns to skip from the start of the row",
default=[],
)
class Output(BlockSchema):
row: dict[str, str]
all_data: list[dict[str, str]]
row: dict[str, str] = SchemaField(
description="The data produced from each row in the CSV file"
)
all_data: list[dict[str, str]] = SchemaField(
description="All the data in the CSV file as a list of rows"
)
def __init__(self):
super().__init__(

View File

@@ -35,8 +35,5 @@ This is a "quoted" string.""",
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
try:
decoded_text = codecs.decode(input_data.text, "unicode_escape")
yield "decoded_text", decoded_text
except Exception as e:
yield "error", f"Error decoding text: {str(e)}"
decoded_text = codecs.decode(input_data.text, "unicode_escape")
yield "decoded_text", decoded_text

View File

@@ -2,10 +2,9 @@ import asyncio
import aiohttp
import discord
from pydantic import Field
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import BlockSecret, SecretField
from backend.data.model import BlockSecret, SchemaField, SecretField
class ReadDiscordMessagesBlock(Block):
@@ -13,16 +12,18 @@ class ReadDiscordMessagesBlock(Block):
discord_bot_token: BlockSecret = SecretField(
key="discord_bot_token", description="Discord bot token"
)
continuous_read: bool = Field(
continuous_read: bool = SchemaField(
description="Whether to continuously read messages", default=True
)
class Output(BlockSchema):
message_content: str = Field(description="The content of the message received")
channel_name: str = Field(
message_content: str = SchemaField(
description="The content of the message received"
)
channel_name: str = SchemaField(
description="The name of the channel the message was received from"
)
username: str = Field(
username: str = SchemaField(
description="The username of the user who sent the message"
)
@@ -134,13 +135,15 @@ class SendDiscordMessageBlock(Block):
discord_bot_token: BlockSecret = SecretField(
key="discord_bot_token", description="Discord bot token"
)
message_content: str = Field(description="The content of the message received")
channel_name: str = Field(
message_content: str = SchemaField(
description="The content of the message received"
)
channel_name: str = SchemaField(
description="The name of the channel the message was received from"
)
class Output(BlockSchema):
status: str = Field(
status: str = SchemaField(
description="The status of the operation (e.g., 'Message sent', 'Error')"
)

View File

@@ -2,17 +2,17 @@ import smtplib
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import BlockSecret, SchemaField, SecretField
class EmailCredentials(BaseModel):
smtp_server: str = Field(
smtp_server: str = SchemaField(
default="smtp.gmail.com", description="SMTP server address"
)
smtp_port: int = Field(default=25, description="SMTP port number")
smtp_port: int = SchemaField(default=25, description="SMTP port number")
smtp_username: BlockSecret = SecretField(key="smtp_username")
smtp_password: BlockSecret = SecretField(key="smtp_password")
@@ -30,7 +30,7 @@ class SendEmailBlock(Block):
body: str = SchemaField(
description="Body of the email", placeholder="Enter the email body"
)
creds: EmailCredentials = Field(
creds: EmailCredentials = SchemaField(
description="SMTP credentials",
default=EmailCredentials(),
)
@@ -67,35 +67,28 @@ class SendEmailBlock(Block):
def send_email(
creds: EmailCredentials, to_email: str, subject: str, body: str
) -> str:
try:
smtp_server = creds.smtp_server
smtp_port = creds.smtp_port
smtp_username = creds.smtp_username.get_secret_value()
smtp_password = creds.smtp_password.get_secret_value()
smtp_server = creds.smtp_server
smtp_port = creds.smtp_port
smtp_username = creds.smtp_username.get_secret_value()
smtp_password = creds.smtp_password.get_secret_value()
msg = MIMEMultipart()
msg["From"] = smtp_username
msg["To"] = to_email
msg["Subject"] = subject
msg.attach(MIMEText(body, "plain"))
msg = MIMEMultipart()
msg["From"] = smtp_username
msg["To"] = to_email
msg["Subject"] = subject
msg.attach(MIMEText(body, "plain"))
with smtplib.SMTP(smtp_server, smtp_port) as server:
server.starttls()
server.login(smtp_username, smtp_password)
server.sendmail(smtp_username, to_email, msg.as_string())
with smtplib.SMTP(smtp_server, smtp_port) as server:
server.starttls()
server.login(smtp_username, smtp_password)
server.sendmail(smtp_username, to_email, msg.as_string())
return "Email sent successfully"
except Exception as e:
return f"Failed to send email: {str(e)}"
return "Email sent successfully"
def run(self, input_data: Input, **kwargs) -> BlockOutput:
status = self.send_email(
yield "status", self.send_email(
input_data.creds,
input_data.to_email,
input_data.subject,
input_data.body,
)
if "successfully" in status:
yield "status", status
else:
yield "error", status

View File

@@ -13,6 +13,7 @@ from ._auth import (
)
# --8<-- [start:GithubCommentBlockExample]
class GithubCommentBlock(Block):
class Input(BlockSchema):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
@@ -92,16 +93,16 @@ class GithubCommentBlock(Block):
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
id, url = self.post_comment(
credentials,
input_data.issue_url,
input_data.comment,
)
yield "id", id
yield "url", url
except Exception as e:
yield "error", f"Failed to post comment: {str(e)}"
id, url = self.post_comment(
credentials,
input_data.issue_url,
input_data.comment,
)
yield "id", id
yield "url", url
# --8<-- [end:GithubCommentBlockExample]
class GithubMakeIssueBlock(Block):
@@ -175,17 +176,14 @@ class GithubMakeIssueBlock(Block):
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
number, url = self.create_issue(
credentials,
input_data.repo_url,
input_data.title,
input_data.body,
)
yield "number", number
yield "url", url
except Exception as e:
yield "error", f"Failed to create issue: {str(e)}"
number, url = self.create_issue(
credentials,
input_data.repo_url,
input_data.title,
input_data.body,
)
yield "number", number
yield "url", url
class GithubReadIssueBlock(Block):
@@ -258,16 +256,13 @@ class GithubReadIssueBlock(Block):
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
title, body, user = self.read_issue(
credentials,
input_data.issue_url,
)
yield "title", title
yield "body", body
yield "user", user
except Exception as e:
yield "error", f"Failed to read issue: {str(e)}"
title, body, user = self.read_issue(
credentials,
input_data.issue_url,
)
yield "title", title
yield "body", body
yield "user", user
class GithubListIssuesBlock(Block):
@@ -346,14 +341,11 @@ class GithubListIssuesBlock(Block):
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
issues = self.list_issues(
credentials,
input_data.repo_url,
)
yield from (("issue", issue) for issue in issues)
except Exception as e:
yield "error", f"Failed to list issues: {str(e)}"
issues = self.list_issues(
credentials,
input_data.repo_url,
)
yield from (("issue", issue) for issue in issues)
class GithubAddLabelBlock(Block):
@@ -424,15 +416,12 @@ class GithubAddLabelBlock(Block):
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
status = self.add_label(
credentials,
input_data.issue_url,
input_data.label,
)
yield "status", status
except Exception as e:
yield "error", f"Failed to add label: {str(e)}"
status = self.add_label(
credentials,
input_data.issue_url,
input_data.label,
)
yield "status", status
class GithubRemoveLabelBlock(Block):
@@ -508,15 +497,12 @@ class GithubRemoveLabelBlock(Block):
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
status = self.remove_label(
credentials,
input_data.issue_url,
input_data.label,
)
yield "status", status
except Exception as e:
yield "error", f"Failed to remove label: {str(e)}"
status = self.remove_label(
credentials,
input_data.issue_url,
input_data.label,
)
yield "status", status
class GithubAssignIssueBlock(Block):
@@ -590,15 +576,12 @@ class GithubAssignIssueBlock(Block):
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
status = self.assign_issue(
credentials,
input_data.issue_url,
input_data.assignee,
)
yield "status", status
except Exception as e:
yield "error", f"Failed to assign issue: {str(e)}"
status = self.assign_issue(
credentials,
input_data.issue_url,
input_data.assignee,
)
yield "status", status
class GithubUnassignIssueBlock(Block):
@@ -672,12 +655,9 @@ class GithubUnassignIssueBlock(Block):
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
status = self.unassign_issue(
credentials,
input_data.issue_url,
input_data.assignee,
)
yield "status", status
except Exception as e:
yield "error", f"Failed to unassign issue: {str(e)}"
status = self.unassign_issue(
credentials,
input_data.issue_url,
input_data.assignee,
)
yield "status", status

View File

@@ -87,14 +87,11 @@ class GithubListPullRequestsBlock(Block):
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
pull_requests = self.list_prs(
credentials,
input_data.repo_url,
)
yield from (("pull_request", pr) for pr in pull_requests)
except Exception as e:
yield "error", f"Failed to list pull requests: {str(e)}"
pull_requests = self.list_prs(
credentials,
input_data.repo_url,
)
yield from (("pull_request", pr) for pr in pull_requests)
class GithubMakePullRequestBlock(Block):
@@ -203,9 +200,7 @@ class GithubMakePullRequestBlock(Block):
error_message = error_details.get("message", "Unknown error")
else:
error_message = str(http_err)
yield "error", f"Failed to create pull request: {error_message}"
except Exception as e:
yield "error", f"Failed to create pull request: {str(e)}"
raise RuntimeError(f"Failed to create pull request: {error_message}")
class GithubReadPullRequestBlock(Block):
@@ -313,23 +308,20 @@ class GithubReadPullRequestBlock(Block):
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
title, body, author = self.read_pr(
title, body, author = self.read_pr(
credentials,
input_data.pr_url,
)
yield "title", title
yield "body", body
yield "author", author
if input_data.include_pr_changes:
changes = self.read_pr_changes(
credentials,
input_data.pr_url,
)
yield "title", title
yield "body", body
yield "author", author
if input_data.include_pr_changes:
changes = self.read_pr_changes(
credentials,
input_data.pr_url,
)
yield "changes", changes
except Exception as e:
yield "error", f"Failed to read pull request: {str(e)}"
yield "changes", changes
class GithubAssignPRReviewerBlock(Block):
@@ -418,9 +410,7 @@ class GithubAssignPRReviewerBlock(Block):
)
else:
error_msg = f"HTTP error: {http_err} - {http_err.response.text}"
yield "error", error_msg
except Exception as e:
yield "error", f"Failed to assign reviewer: {str(e)}"
raise RuntimeError(error_msg)
class GithubUnassignPRReviewerBlock(Block):
@@ -490,15 +480,12 @@ class GithubUnassignPRReviewerBlock(Block):
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
status = self.unassign_reviewer(
credentials,
input_data.pr_url,
input_data.reviewer,
)
yield "status", status
except Exception as e:
yield "error", f"Failed to unassign reviewer: {str(e)}"
status = self.unassign_reviewer(
credentials,
input_data.pr_url,
input_data.reviewer,
)
yield "status", status
class GithubListPRReviewersBlock(Block):
@@ -586,11 +573,8 @@ class GithubListPRReviewersBlock(Block):
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
reviewers = self.list_reviewers(
credentials,
input_data.pr_url,
)
yield from (("reviewer", reviewer) for reviewer in reviewers)
except Exception as e:
yield "error", f"Failed to list reviewers: {str(e)}"
reviewers = self.list_reviewers(
credentials,
input_data.pr_url,
)
yield from (("reviewer", reviewer) for reviewer in reviewers)

View File

@@ -96,14 +96,11 @@ class GithubListTagsBlock(Block):
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
tags = self.list_tags(
credentials,
input_data.repo_url,
)
yield from (("tag", tag) for tag in tags)
except Exception as e:
yield "error", f"Failed to list tags: {str(e)}"
tags = self.list_tags(
credentials,
input_data.repo_url,
)
yield from (("tag", tag) for tag in tags)
class GithubListBranchesBlock(Block):
@@ -183,14 +180,11 @@ class GithubListBranchesBlock(Block):
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
branches = self.list_branches(
credentials,
input_data.repo_url,
)
yield from (("branch", branch) for branch in branches)
except Exception as e:
yield "error", f"Failed to list branches: {str(e)}"
branches = self.list_branches(
credentials,
input_data.repo_url,
)
yield from (("branch", branch) for branch in branches)
class GithubListDiscussionsBlock(Block):
@@ -294,13 +288,10 @@ class GithubListDiscussionsBlock(Block):
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
discussions = self.list_discussions(
credentials, input_data.repo_url, input_data.num_discussions
)
yield from (("discussion", discussion) for discussion in discussions)
except Exception as e:
yield "error", f"Failed to list discussions: {str(e)}"
discussions = self.list_discussions(
credentials, input_data.repo_url, input_data.num_discussions
)
yield from (("discussion", discussion) for discussion in discussions)
class GithubListReleasesBlock(Block):
@@ -381,14 +372,11 @@ class GithubListReleasesBlock(Block):
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
releases = self.list_releases(
credentials,
input_data.repo_url,
)
yield from (("release", release) for release in releases)
except Exception as e:
yield "error", f"Failed to list releases: {str(e)}"
releases = self.list_releases(
credentials,
input_data.repo_url,
)
yield from (("release", release) for release in releases)
class GithubReadFileBlock(Block):
@@ -474,18 +462,15 @@ class GithubReadFileBlock(Block):
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
raw_content, size = self.read_file(
credentials,
input_data.repo_url,
input_data.file_path.lstrip("/"),
input_data.branch,
)
yield "raw_content", raw_content
yield "text_content", base64.b64decode(raw_content).decode("utf-8")
yield "size", size
except Exception as e:
yield "error", f"Failed to read file: {str(e)}"
raw_content, size = self.read_file(
credentials,
input_data.repo_url,
input_data.file_path.lstrip("/"),
input_data.branch,
)
yield "raw_content", raw_content
yield "text_content", base64.b64decode(raw_content).decode("utf-8")
yield "size", size
class GithubReadFolderBlock(Block):
@@ -612,17 +597,14 @@ class GithubReadFolderBlock(Block):
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
files, dirs = self.read_folder(
credentials,
input_data.repo_url,
input_data.folder_path.lstrip("/"),
input_data.branch,
)
yield from (("file", file) for file in files)
yield from (("dir", dir) for dir in dirs)
except Exception as e:
yield "error", f"Failed to read folder: {str(e)}"
files, dirs = self.read_folder(
credentials,
input_data.repo_url,
input_data.folder_path.lstrip("/"),
input_data.branch,
)
yield from (("file", file) for file in files)
yield from (("dir", dir) for dir in dirs)
class GithubMakeBranchBlock(Block):
@@ -703,16 +685,13 @@ class GithubMakeBranchBlock(Block):
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
status = self.create_branch(
credentials,
input_data.repo_url,
input_data.new_branch,
input_data.source_branch,
)
yield "status", status
except Exception as e:
yield "error", f"Failed to create branch: {str(e)}"
status = self.create_branch(
credentials,
input_data.repo_url,
input_data.new_branch,
input_data.source_branch,
)
yield "status", status
class GithubDeleteBranchBlock(Block):
@@ -775,12 +754,9 @@ class GithubDeleteBranchBlock(Block):
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
try:
status = self.delete_branch(
credentials,
input_data.repo_url,
input_data.branch,
)
yield "status", status
except Exception as e:
yield "error", f"Failed to delete branch: {str(e)}"
status = self.delete_branch(
credentials,
input_data.repo_url,
input_data.branch,
)
yield "status", status

View File

@@ -6,11 +6,12 @@ from pydantic import SecretStr
from backend.data.model import CredentialsField, CredentialsMetaInput
from backend.util.settings import Secrets
# --8<-- [start:GoogleOAuthIsConfigured]
secrets = Secrets()
GOOGLE_OAUTH_IS_CONFIGURED = bool(
secrets.google_client_id and secrets.google_client_secret
)
# --8<-- [end:GoogleOAuthIsConfigured]
GoogleCredentials = OAuth2Credentials
GoogleCredentialsInput = CredentialsMetaInput[Literal["google"], Literal["oauth2"]]

View File

@@ -104,16 +104,11 @@ class GmailReadBlock(Block):
def run(
self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
) -> BlockOutput:
try:
service = self._build_service(credentials, **kwargs)
messages = self._read_emails(
service, input_data.query, input_data.max_results
)
for email in messages:
yield "email", email
yield "emails", messages
except Exception as e:
yield "error", str(e)
service = self._build_service(credentials, **kwargs)
messages = self._read_emails(service, input_data.query, input_data.max_results)
for email in messages:
yield "email", email
yield "emails", messages
@staticmethod
def _build_service(credentials: GoogleCredentials, **kwargs):
@@ -267,14 +262,11 @@ class GmailSendBlock(Block):
def run(
self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
) -> BlockOutput:
try:
service = GmailReadBlock._build_service(credentials, **kwargs)
send_result = self._send_email(
service, input_data.to, input_data.subject, input_data.body
)
yield "result", send_result
except Exception as e:
yield "error", str(e)
service = GmailReadBlock._build_service(credentials, **kwargs)
send_result = self._send_email(
service, input_data.to, input_data.subject, input_data.body
)
yield "result", send_result
def _send_email(self, service, to: str, subject: str, body: str) -> dict:
if not to or not subject or not body:
@@ -342,12 +334,9 @@ class GmailListLabelsBlock(Block):
def run(
self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
) -> BlockOutput:
try:
service = GmailReadBlock._build_service(credentials, **kwargs)
labels = self._list_labels(service)
yield "result", labels
except Exception as e:
yield "error", str(e)
service = GmailReadBlock._build_service(credentials, **kwargs)
labels = self._list_labels(service)
yield "result", labels
def _list_labels(self, service) -> list[dict]:
results = service.users().labels().list(userId="me").execute()
@@ -406,14 +395,9 @@ class GmailAddLabelBlock(Block):
def run(
self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
) -> BlockOutput:
try:
service = GmailReadBlock._build_service(credentials, **kwargs)
result = self._add_label(
service, input_data.message_id, input_data.label_name
)
yield "result", result
except Exception as e:
yield "error", str(e)
service = GmailReadBlock._build_service(credentials, **kwargs)
result = self._add_label(service, input_data.message_id, input_data.label_name)
yield "result", result
def _add_label(self, service, message_id: str, label_name: str) -> dict:
label_id = self._get_or_create_label(service, label_name)
@@ -494,14 +478,11 @@ class GmailRemoveLabelBlock(Block):
def run(
self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
) -> BlockOutput:
try:
service = GmailReadBlock._build_service(credentials, **kwargs)
result = self._remove_label(
service, input_data.message_id, input_data.label_name
)
yield "result", result
except Exception as e:
yield "error", str(e)
service = GmailReadBlock._build_service(credentials, **kwargs)
result = self._remove_label(
service, input_data.message_id, input_data.label_name
)
yield "result", result
def _remove_label(self, service, message_id: str, label_name: str) -> dict:
label_id = self._get_label_id(service, label_name)

View File

@@ -68,14 +68,9 @@ class GoogleSheetsReadBlock(Block):
def run(
self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
) -> BlockOutput:
try:
service = self._build_service(credentials, **kwargs)
data = self._read_sheet(
service, input_data.spreadsheet_id, input_data.range
)
yield "result", data
except Exception as e:
yield "error", str(e)
service = self._build_service(credentials, **kwargs)
data = self._read_sheet(service, input_data.spreadsheet_id, input_data.range)
yield "result", data
@staticmethod
def _build_service(credentials: GoogleCredentials, **kwargs):
@@ -162,17 +157,14 @@ class GoogleSheetsWriteBlock(Block):
def run(
self, input_data: Input, *, credentials: GoogleCredentials, **kwargs
) -> BlockOutput:
try:
service = GoogleSheetsReadBlock._build_service(credentials, **kwargs)
result = self._write_sheet(
service,
input_data.spreadsheet_id,
input_data.range,
input_data.values,
)
yield "result", result
except Exception as e:
yield "error", str(e)
service = GoogleSheetsReadBlock._build_service(credentials, **kwargs)
result = self._write_sheet(
service,
input_data.spreadsheet_id,
input_data.range,
input_data.values,
)
yield "result", result
def _write_sheet(
self, service, spreadsheet_id: str, range: str, values: list[list[str]]

View File

@@ -82,17 +82,14 @@ class GoogleMapsSearchBlock(Block):
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
try:
places = self.search_places(
input_data.api_key.get_secret_value(),
input_data.query,
input_data.radius,
input_data.max_results,
)
for place in places:
yield "place", place
except Exception as e:
yield "error", str(e)
places = self.search_places(
input_data.api_key.get_secret_value(),
input_data.query,
input_data.radius,
input_data.max_results,
)
for place in places:
yield "place", place
def search_places(self, api_key, query, radius, max_results):
client = googlemaps.Client(key=api_key)

View File

@@ -4,6 +4,7 @@ from enum import Enum
import requests
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
class HttpMethod(Enum):
@@ -18,15 +19,27 @@ class HttpMethod(Enum):
class SendWebRequestBlock(Block):
class Input(BlockSchema):
url: str
method: HttpMethod = HttpMethod.POST
headers: dict[str, str] = {}
body: object = {}
url: str = SchemaField(
description="The URL to send the request to",
placeholder="https://api.example.com",
)
method: HttpMethod = SchemaField(
description="The HTTP method to use for the request",
default=HttpMethod.POST,
)
headers: dict[str, str] = SchemaField(
description="The headers to include in the request",
default={},
)
body: object = SchemaField(
description="The body of the request",
default={},
)
class Output(BlockSchema):
response: object
client_error: object
server_error: object
response: object = SchemaField(description="The response from the server")
client_error: object = SchemaField(description="The error on 4xx status codes")
server_error: object = SchemaField(description="The error on 5xx status codes")
def __init__(self):
super().__init__(

View File

@@ -75,28 +75,24 @@ class IdeogramModelBlock(Block):
description="The name of the Image Generation Model, e.g., V_2",
default=IdeogramModelName.V2,
title="Image Generation Model",
enum=IdeogramModelName,
advanced=False,
)
aspect_ratio: AspectRatio = SchemaField(
description="Aspect ratio for the generated image",
default=AspectRatio.ASPECT_1_1,
title="Aspect Ratio",
enum=AspectRatio,
advanced=False,
)
upscale: UpscaleOption = SchemaField(
description="Upscale the generated image",
default=UpscaleOption.NO_UPSCALE,
title="Upscale Image",
enum=UpscaleOption,
advanced=False,
)
magic_prompt_option: MagicPromptOption = SchemaField(
description="Whether to use MagicPrompt for enhancing the request",
default=MagicPromptOption.AUTO,
title="Magic Prompt Option",
enum=MagicPromptOption,
advanced=True,
)
seed: Optional[int] = SchemaField(
@@ -109,7 +105,6 @@ class IdeogramModelBlock(Block):
description="Style type to apply, applicable for V_2 and above",
default=StyleType.AUTO,
title="Style Type",
enum=StyleType,
advanced=True,
)
negative_prompt: Optional[str] = SchemaField(
@@ -122,15 +117,12 @@ class IdeogramModelBlock(Block):
description="Color palette preset name, choose 'None' to skip",
default=ColorPalettePreset.NONE,
title="Color Palette Preset",
enum=ColorPalettePreset,
advanced=True,
)
class Output(BlockSchema):
result: str = SchemaField(description="Generated image URL")
error: Optional[str] = SchemaField(
description="Error message if the model run failed"
)
error: str = SchemaField(description="Error message if the model run failed")
def __init__(self):
super().__init__(
@@ -166,30 +158,27 @@ class IdeogramModelBlock(Block):
def run(self, input_data: Input, **kwargs) -> BlockOutput:
seed = input_data.seed
try:
# Step 1: Generate the image
result = self.run_model(
# Step 1: Generate the image
result = self.run_model(
api_key=input_data.api_key.get_secret_value(),
model_name=input_data.ideogram_model_name.value,
prompt=input_data.prompt,
seed=seed,
aspect_ratio=input_data.aspect_ratio.value,
magic_prompt_option=input_data.magic_prompt_option.value,
style_type=input_data.style_type.value,
negative_prompt=input_data.negative_prompt,
color_palette_name=input_data.color_palette_name.value,
)
# Step 2: Upscale the image if requested
if input_data.upscale == UpscaleOption.AI_UPSCALE:
result = self.upscale_image(
api_key=input_data.api_key.get_secret_value(),
model_name=input_data.ideogram_model_name.value,
prompt=input_data.prompt,
seed=seed,
aspect_ratio=input_data.aspect_ratio.value,
magic_prompt_option=input_data.magic_prompt_option.value,
style_type=input_data.style_type.value,
negative_prompt=input_data.negative_prompt,
color_palette_name=input_data.color_palette_name.value,
image_url=result,
)
# Step 2: Upscale the image if requested
if input_data.upscale == UpscaleOption.AI_UPSCALE:
result = self.upscale_image(
api_key=input_data.api_key.get_secret_value(),
image_url=result,
)
yield "result", result
except Exception as e:
yield "error", str(e)
yield "result", result
def run_model(
self,

View File

@@ -1,8 +1,12 @@
import ast
import logging
from enum import Enum
from enum import Enum, EnumMeta
from json import JSONDecodeError
from typing import Any, List, NamedTuple
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, List, NamedTuple
if TYPE_CHECKING:
from enum import _EnumMemberT
import anthropic
import ollama
@@ -12,6 +16,7 @@ from groq import Groq
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import BlockSecret, SchemaField, SecretField
from backend.util import json
from backend.util.settings import BehaveAs, Settings
logger = logging.getLogger(__name__)
@@ -29,7 +34,26 @@ class ModelMetadata(NamedTuple):
cost_factor: int
class LlmModel(str, Enum):
class LlmModelMeta(EnumMeta):
@property
def __members__(
self: type["_EnumMemberT"],
) -> MappingProxyType[str, "_EnumMemberT"]:
if Settings().config.behave_as == BehaveAs.LOCAL:
members = super().__members__
return members
else:
removed_providers = ["ollama"]
existing_members = super().__members__
members = {
name: member
for name, member in existing_members.items()
if LlmModel[name].provider not in removed_providers
}
return MappingProxyType(members)
class LlmModel(str, Enum, metaclass=LlmModelMeta):
# OpenAI models
O1_PREVIEW = "o1-preview"
O1_MINI = "o1-mini"
@@ -58,27 +82,39 @@ class LlmModel(str, Enum):
def metadata(self) -> ModelMetadata:
return MODEL_METADATA[self]
@property
def provider(self) -> str:
return self.metadata.provider
@property
def context_window(self) -> int:
return self.metadata.context_window
@property
def cost_factor(self) -> int:
return self.metadata.cost_factor
MODEL_METADATA = {
LlmModel.O1_PREVIEW: ModelMetadata("openai", 32000, cost_factor=60),
LlmModel.O1_MINI: ModelMetadata("openai", 62000, cost_factor=30),
LlmModel.GPT4O_MINI: ModelMetadata("openai", 128000, cost_factor=10),
LlmModel.GPT4O: ModelMetadata("openai", 128000, cost_factor=12),
LlmModel.GPT4_TURBO: ModelMetadata("openai", 128000, cost_factor=11),
LlmModel.GPT3_5_TURBO: ModelMetadata("openai", 16385, cost_factor=8),
LlmModel.CLAUDE_3_5_SONNET: ModelMetadata("anthropic", 200000, cost_factor=14),
LlmModel.CLAUDE_3_HAIKU: ModelMetadata("anthropic", 200000, cost_factor=13),
LlmModel.LLAMA3_8B: ModelMetadata("groq", 8192, cost_factor=6),
LlmModel.LLAMA3_70B: ModelMetadata("groq", 8192, cost_factor=9),
LlmModel.MIXTRAL_8X7B: ModelMetadata("groq", 32768, cost_factor=7),
LlmModel.GEMMA_7B: ModelMetadata("groq", 8192, cost_factor=6),
LlmModel.GEMMA2_9B: ModelMetadata("groq", 8192, cost_factor=7),
LlmModel.LLAMA3_1_405B: ModelMetadata("groq", 8192, cost_factor=10),
LlmModel.O1_PREVIEW: ModelMetadata("openai", 32000, cost_factor=16),
LlmModel.O1_MINI: ModelMetadata("openai", 62000, cost_factor=4),
LlmModel.GPT4O_MINI: ModelMetadata("openai", 128000, cost_factor=1),
LlmModel.GPT4O: ModelMetadata("openai", 128000, cost_factor=3),
LlmModel.GPT4_TURBO: ModelMetadata("openai", 128000, cost_factor=10),
LlmModel.GPT3_5_TURBO: ModelMetadata("openai", 16385, cost_factor=1),
LlmModel.CLAUDE_3_5_SONNET: ModelMetadata("anthropic", 200000, cost_factor=4),
LlmModel.CLAUDE_3_HAIKU: ModelMetadata("anthropic", 200000, cost_factor=1),
LlmModel.LLAMA3_8B: ModelMetadata("groq", 8192, cost_factor=1),
LlmModel.LLAMA3_70B: ModelMetadata("groq", 8192, cost_factor=1),
LlmModel.MIXTRAL_8X7B: ModelMetadata("groq", 32768, cost_factor=1),
LlmModel.GEMMA_7B: ModelMetadata("groq", 8192, cost_factor=1),
LlmModel.GEMMA2_9B: ModelMetadata("groq", 8192, cost_factor=1),
LlmModel.LLAMA3_1_405B: ModelMetadata("groq", 8192, cost_factor=1),
# Limited to 16k during preview
LlmModel.LLAMA3_1_70B: ModelMetadata("groq", 131072, cost_factor=15),
LlmModel.LLAMA3_1_8B: ModelMetadata("groq", 131072, cost_factor=13),
LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata("ollama", 8192, cost_factor=7),
LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata("ollama", 8192, cost_factor=11),
LlmModel.LLAMA3_1_70B: ModelMetadata("groq", 131072, cost_factor=1),
LlmModel.LLAMA3_1_8B: ModelMetadata("groq", 131072, cost_factor=1),
LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata("ollama", 8192, cost_factor=1),
LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata("ollama", 8192, cost_factor=1),
}
for model in LlmModel:
@@ -88,7 +124,10 @@ for model in LlmModel:
class AIStructuredResponseGeneratorBlock(Block):
class Input(BlockSchema):
prompt: str
prompt: str = SchemaField(
description="The prompt to send to the language model.",
placeholder="Enter your prompt here...",
)
expected_format: dict[str, str] = SchemaField(
description="Expected format of the response. If provided, the response will be validated against this format. "
"The keys should be the expected fields in the response, and the values should be the description of the field.",
@@ -100,15 +139,25 @@ class AIStructuredResponseGeneratorBlock(Block):
advanced=False,
)
api_key: BlockSecret = SecretField(value="")
sys_prompt: str = ""
retry: int = 3
sys_prompt: str = SchemaField(
title="System Prompt",
default="",
description="The system prompt to provide additional context to the model.",
)
retry: int = SchemaField(
title="Retry Count",
default=3,
description="Number of times to retry the LLM call if the response does not match the expected format.",
)
prompt_values: dict[str, str] = SchemaField(
advanced=False, default={}, description="Values used to fill in the prompt."
)
class Output(BlockSchema):
response: dict[str, Any]
error: str
response: dict[str, Any] = SchemaField(
description="The response object generated by the language model."
)
error: str = SchemaField(description="Error message if the API call failed.")
def __init__(self):
super().__init__(
@@ -308,12 +357,15 @@ class AIStructuredResponseGeneratorBlock(Block):
logger.error(f"Error calling LLM: {e}")
retry_prompt = f"Error calling LLM: {e}"
yield "error", retry_prompt
raise RuntimeError(retry_prompt)
class AITextGeneratorBlock(Block):
class Input(BlockSchema):
prompt: str
prompt: str = SchemaField(
description="The prompt to send to the language model.",
placeholder="Enter your prompt here...",
)
model: LlmModel = SchemaField(
title="LLM Model",
default=LlmModel.GPT4_TURBO,
@@ -321,15 +373,25 @@ class AITextGeneratorBlock(Block):
advanced=False,
)
api_key: BlockSecret = SecretField(value="")
sys_prompt: str = ""
retry: int = 3
sys_prompt: str = SchemaField(
title="System Prompt",
default="",
description="The system prompt to provide additional context to the model.",
)
retry: int = SchemaField(
title="Retry Count",
default=3,
description="Number of times to retry the LLM call if the response does not match the expected format.",
)
prompt_values: dict[str, str] = SchemaField(
advanced=False, default={}, description="Values used to fill in the prompt."
)
class Output(BlockSchema):
response: str
error: str
response: str = SchemaField(
description="The response generated by the language model."
)
error: str = SchemaField(description="Error message if the API call failed.")
def __init__(self):
super().__init__(
@@ -354,14 +416,11 @@ class AITextGeneratorBlock(Block):
raise ValueError("Failed to get a response from the LLM.")
def run(self, input_data: Input, **kwargs) -> BlockOutput:
try:
object_input_data = AIStructuredResponseGeneratorBlock.Input(
**{attr: getattr(input_data, attr) for attr in input_data.model_fields},
expected_format={},
)
yield "response", self.llm_call(object_input_data)
except Exception as e:
yield "error", str(e)
object_input_data = AIStructuredResponseGeneratorBlock.Input(
**{attr: getattr(input_data, attr) for attr in input_data.model_fields},
expected_format={},
)
yield "response", self.llm_call(object_input_data)
class SummaryStyle(Enum):
@@ -373,22 +432,43 @@ class SummaryStyle(Enum):
class AITextSummarizerBlock(Block):
class Input(BlockSchema):
text: str
text: str = SchemaField(
description="The text to summarize.",
placeholder="Enter the text to summarize here...",
)
model: LlmModel = SchemaField(
title="LLM Model",
default=LlmModel.GPT4_TURBO,
description="The language model to use for summarizing the text.",
)
focus: str = "general information"
style: SummaryStyle = SummaryStyle.CONCISE
focus: str = SchemaField(
title="Focus",
default="general information",
description="The topic to focus on in the summary",
)
style: SummaryStyle = SchemaField(
title="Summary Style",
default=SummaryStyle.CONCISE,
description="The style of the summary to generate.",
)
api_key: BlockSecret = SecretField(value="")
# TODO: Make this dynamic
max_tokens: int = 4000 # Adjust based on the model's context window
chunk_overlap: int = 100 # Overlap between chunks to maintain context
max_tokens: int = SchemaField(
title="Max Tokens",
default=4096,
description="The maximum number of tokens to generate in the chat completion.",
ge=1,
)
chunk_overlap: int = SchemaField(
title="Chunk Overlap",
default=100,
description="The number of overlapping tokens between chunks to maintain context.",
ge=0,
)
class Output(BlockSchema):
summary: str
error: str
summary: str = SchemaField(description="The final summary of the text.")
error: str = SchemaField(description="Error message if the API call failed.")
def __init__(self):
super().__init__(
@@ -409,11 +489,8 @@ class AITextSummarizerBlock(Block):
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
try:
for output in self._run(input_data):
yield output
except Exception as e:
yield "error", str(e)
for output in self._run(input_data):
yield output
def _run(self, input_data: Input) -> BlockOutput:
chunks = self._split_text(
@@ -606,24 +683,21 @@ class AIConversationBlock(Block):
raise ValueError(f"Unsupported LLM provider: {provider}")
def run(self, input_data: Input, **kwargs) -> BlockOutput:
try:
api_key = (
input_data.api_key.get_secret_value()
or LlmApiKeys[input_data.model.metadata.provider].get_secret_value()
)
api_key = (
input_data.api_key.get_secret_value()
or LlmApiKeys[input_data.model.metadata.provider].get_secret_value()
)
messages = [message.model_dump() for message in input_data.messages]
messages = [message.model_dump() for message in input_data.messages]
response = self.llm_call(
api_key=api_key,
model=input_data.model,
messages=messages,
max_tokens=input_data.max_tokens,
)
response = self.llm_call(
api_key=api_key,
model=input_data.model,
messages=messages,
max_tokens=input_data.max_tokens,
)
yield "response", response
except Exception as e:
yield "error", f"Error calling LLM: {str(e)}"
yield "response", response
class AIListGeneratorBlock(Block):
@@ -741,9 +815,7 @@ class AIListGeneratorBlock(Block):
or LlmApiKeys[input_data.model.metadata.provider].get_secret_value()
)
if not api_key_check:
logger.error("No LLM API key provided.")
yield "error", "No LLM API key provided."
return
raise ValueError("No LLM API key provided.")
# Prepare the system prompt
sys_prompt = """You are a Python list generator. Your task is to generate a Python list based on the user's prompt.
@@ -837,7 +909,9 @@ class AIListGeneratorBlock(Block):
logger.error(
f"Failed to generate a valid Python list after {input_data.max_retries} attempts"
)
yield "error", f"Failed to generate a valid Python list after {input_data.max_retries} attempts. Last error: {str(e)}"
raise RuntimeError(
f"Failed to generate a valid Python list after {input_data.max_retries} attempts. Last error: {str(e)}"
)
else:
# Add a retry prompt
logger.debug("Preparing retry prompt")

View File

@@ -1,3 +1,4 @@
from enum import Enum
from typing import List
import requests
@@ -6,6 +7,12 @@ from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import BlockSecret, SchemaField, SecretField
class PublishToMediumStatus(str, Enum):
PUBLIC = "public"
DRAFT = "draft"
UNLISTED = "unlisted"
class PublishToMediumBlock(Block):
class Input(BlockSchema):
author_id: BlockSecret = SecretField(
@@ -34,9 +41,9 @@ class PublishToMediumBlock(Block):
description="The original home of this content, if it was originally published elsewhere",
placeholder="https://yourblog.com/original-post",
)
publish_status: str = SchemaField(
description="The publish status: 'public', 'draft', or 'unlisted'",
placeholder="public",
publish_status: PublishToMediumStatus = SchemaField(
description="The publish status",
placeholder=PublishToMediumStatus.DRAFT,
)
license: str = SchemaField(
default="all-rights-reserved",
@@ -79,7 +86,7 @@ class PublishToMediumBlock(Block):
"tags": ["test", "automation"],
"license": "all-rights-reserved",
"notify_followers": False,
"publish_status": "draft",
"publish_status": PublishToMediumStatus.DRAFT.value,
"api_key": "your_test_api_key",
},
test_output=[
@@ -138,31 +145,25 @@ class PublishToMediumBlock(Block):
return response.json()
def run(self, input_data: Input, **kwargs) -> BlockOutput:
try:
response = self.create_post(
input_data.api_key.get_secret_value(),
input_data.author_id.get_secret_value(),
input_data.title,
input_data.content,
input_data.content_format,
input_data.tags,
input_data.canonical_url,
input_data.publish_status,
input_data.license,
input_data.notify_followers,
response = self.create_post(
input_data.api_key.get_secret_value(),
input_data.author_id.get_secret_value(),
input_data.title,
input_data.content,
input_data.content_format,
input_data.tags,
input_data.canonical_url,
input_data.publish_status,
input_data.license,
input_data.notify_followers,
)
if "data" in response:
yield "post_id", response["data"]["id"]
yield "post_url", response["data"]["url"]
yield "published_at", response["data"]["publishedAt"]
else:
error_message = response.get("errors", [{}])[0].get(
"message", "Unknown error occurred"
)
if "data" in response:
yield "post_id", response["data"]["id"]
yield "post_url", response["data"]["url"]
yield "published_at", response["data"]["publishedAt"]
else:
error_message = response.get("errors", [{}])[0].get(
"message", "Unknown error occurred"
)
yield "error", f"Failed to create Medium post: {error_message}"
except requests.RequestException as e:
yield "error", f"Network error occurred while creating Medium post: {str(e)}"
except Exception as e:
yield "error", f"Error occurred while creating Medium post: {str(e)}"
raise RuntimeError(f"Failed to create Medium post: {error_message}")

View File

@@ -2,10 +2,10 @@ from datetime import datetime, timezone
from typing import Iterator
import praw
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import BlockSecret, SecretField
from backend.data.model import BlockSecret, SchemaField, SecretField
from backend.util.mock import MockObject
@@ -48,25 +48,25 @@ def get_praw(creds: RedditCredentials) -> praw.Reddit:
class GetRedditPostsBlock(Block):
class Input(BlockSchema):
subreddit: str = Field(description="Subreddit name")
creds: RedditCredentials = Field(
subreddit: str = SchemaField(description="Subreddit name")
creds: RedditCredentials = SchemaField(
description="Reddit credentials",
default=RedditCredentials(),
)
last_minutes: int | None = Field(
last_minutes: int | None = SchemaField(
description="Post time to stop minutes ago while fetching posts",
default=None,
)
last_post: str | None = Field(
last_post: str | None = SchemaField(
description="Post ID to stop when reached while fetching posts",
default=None,
)
post_limit: int | None = Field(
post_limit: int | None = SchemaField(
description="Number of posts to fetch", default=10
)
class Output(BlockSchema):
post: RedditPost = Field(description="Reddit post")
post: RedditPost = SchemaField(description="Reddit post")
def __init__(self):
super().__init__(
@@ -140,13 +140,13 @@ class GetRedditPostsBlock(Block):
class PostRedditCommentBlock(Block):
class Input(BlockSchema):
creds: RedditCredentials = Field(
creds: RedditCredentials = SchemaField(
description="Reddit credentials", default=RedditCredentials()
)
data: RedditComment = Field(description="Reddit comment")
data: RedditComment = SchemaField(description="Reddit comment")
class Output(BlockSchema):
comment_id: str
comment_id: str = SchemaField(description="Posted comment ID")
def __init__(self):
super().__init__(

View File

@@ -139,24 +139,21 @@ class ReplicateFluxAdvancedModelBlock(Block):
if seed is None:
seed = int.from_bytes(os.urandom(4), "big")
try:
# Run the model using the provided inputs
result = self.run_model(
api_key=input_data.api_key.get_secret_value(),
model_name=input_data.replicate_model_name.api_name,
prompt=input_data.prompt,
seed=seed,
steps=input_data.steps,
guidance=input_data.guidance,
interval=input_data.interval,
aspect_ratio=input_data.aspect_ratio,
output_format=input_data.output_format,
output_quality=input_data.output_quality,
safety_tolerance=input_data.safety_tolerance,
)
yield "result", result
except Exception as e:
yield "error", str(e)
# Run the model using the provided inputs
result = self.run_model(
api_key=input_data.api_key.get_secret_value(),
model_name=input_data.replicate_model_name.api_name,
prompt=input_data.prompt,
seed=seed,
steps=input_data.steps,
guidance=input_data.guidance,
interval=input_data.interval,
aspect_ratio=input_data.aspect_ratio,
output_format=input_data.output_format,
output_quality=input_data.output_quality,
safety_tolerance=input_data.safety_tolerance,
)
yield "result", result
def run_model(
self,

View File

@@ -17,11 +17,13 @@ class GetRequest:
class GetWikipediaSummaryBlock(Block, GetRequest):
class Input(BlockSchema):
topic: str
topic: str = SchemaField(description="The topic to fetch the summary for")
class Output(BlockSchema):
summary: str
error: str
summary: str = SchemaField(description="The summary of the given topic")
error: str = SchemaField(
description="Error message if the summary cannot be retrieved"
)
def __init__(self):
super().__init__(
@@ -36,29 +38,23 @@ class GetWikipediaSummaryBlock(Block, GetRequest):
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
try:
topic = input_data.topic
url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{topic}"
response = self.get_request(url, json=True)
yield "summary", response["extract"]
except requests.exceptions.HTTPError as http_err:
yield "error", f"HTTP error occurred: {http_err}"
except requests.RequestException as e:
yield "error", f"Request to Wikipedia failed: {e}"
except KeyError as e:
yield "error", f"Error parsing Wikipedia response: {e}"
topic = input_data.topic
url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{topic}"
response = self.get_request(url, json=True)
if "extract" not in response:
raise RuntimeError(f"Unable to parse Wikipedia response: {response}")
yield "summary", response["extract"]
class SearchTheWebBlock(Block, GetRequest):
class Input(BlockSchema):
query: str # The search query
query: str = SchemaField(description="The search query to search the web for")
class Output(BlockSchema):
results: str # The search results including content from top 5 URLs
error: str # Error message if the search fails
results: str = SchemaField(
description="The search results including content from top 5 URLs"
)
error: str = SchemaField(description="Error message if the search fails")
def __init__(self):
super().__init__(
@@ -73,29 +69,22 @@ class SearchTheWebBlock(Block, GetRequest):
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
try:
# Encode the search query
encoded_query = quote(input_data.query)
# Encode the search query
encoded_query = quote(input_data.query)
# Prepend the Jina Search URL to the encoded query
jina_search_url = f"https://s.jina.ai/{encoded_query}"
# Prepend the Jina Search URL to the encoded query
jina_search_url = f"https://s.jina.ai/{encoded_query}"
# Make the request to Jina Search
response = self.get_request(jina_search_url, json=False)
# Make the request to Jina Search
response = self.get_request(jina_search_url, json=False)
# Output the search results
yield "results", response
except requests.exceptions.HTTPError as http_err:
yield "error", f"HTTP error occurred: {http_err}"
except requests.RequestException as e:
yield "error", f"Request to Jina Search failed: {e}"
# Output the search results
yield "results", response
class ExtractWebsiteContentBlock(Block, GetRequest):
class Input(BlockSchema):
url: str # The URL to scrape
url: str = SchemaField(description="The URL to scrape the content from")
raw_content: bool = SchemaField(
default=False,
title="Raw Content",
@@ -104,8 +93,10 @@ class ExtractWebsiteContentBlock(Block, GetRequest):
)
class Output(BlockSchema):
content: str # The scraped content from the URL
error: str
content: str = SchemaField(description="The scraped content from the given URL")
error: str = SchemaField(
description="Error message if the content cannot be retrieved"
)
def __init__(self):
super().__init__(
@@ -125,26 +116,32 @@ class ExtractWebsiteContentBlock(Block, GetRequest):
else:
url = f"https://r.jina.ai/{input_data.url}"
try:
content = self.get_request(url, json=False)
yield "content", content
except requests.exceptions.HTTPError as http_err:
yield "error", f"HTTP error occurred: {http_err}"
except requests.RequestException as e:
yield "error", f"Request to URL failed: {e}"
content = self.get_request(url, json=False)
yield "content", content
class GetWeatherInformationBlock(Block, GetRequest):
class Input(BlockSchema):
location: str
location: str = SchemaField(
description="Location to get weather information for"
)
api_key: BlockSecret = SecretField(key="openweathermap_api_key")
use_celsius: bool = True
use_celsius: bool = SchemaField(
default=True,
description="Whether to use Celsius or Fahrenheit for temperature",
)
class Output(BlockSchema):
temperature: str
humidity: str
condition: str
error: str
temperature: str = SchemaField(
description="Temperature in the specified location"
)
humidity: str = SchemaField(description="Humidity in the specified location")
condition: str = SchemaField(
description="Weather condition in the specified location"
)
error: str = SchemaField(
description="Error message if the weather information cannot be retrieved"
)
def __init__(self):
super().__init__(
@@ -171,26 +168,15 @@ class GetWeatherInformationBlock(Block, GetRequest):
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
try:
units = "metric" if input_data.use_celsius else "imperial"
api_key = input_data.api_key.get_secret_value()
location = input_data.location
url = f"http://api.openweathermap.org/data/2.5/weather?q={quote(location)}&appid={api_key}&units={units}"
weather_data = self.get_request(url, json=True)
units = "metric" if input_data.use_celsius else "imperial"
api_key = input_data.api_key.get_secret_value()
location = input_data.location
url = f"http://api.openweathermap.org/data/2.5/weather?q={quote(location)}&appid={api_key}&units={units}"
weather_data = self.get_request(url, json=True)
if "main" in weather_data and "weather" in weather_data:
yield "temperature", str(weather_data["main"]["temp"])
yield "humidity", str(weather_data["main"]["humidity"])
yield "condition", weather_data["weather"][0]["description"]
else:
yield "error", f"Expected keys not found in response: {weather_data}"
except requests.exceptions.HTTPError as http_err:
if http_err.response.status_code == 403:
yield "error", "Request to weather API failed: 403 Forbidden. Check your API key and permissions."
else:
yield "error", f"HTTP error occurred: {http_err}"
except requests.RequestException as e:
yield "error", f"Request to weather API failed: {e}"
except KeyError as e:
yield "error", f"Error processing weather data: {e}"
if "main" in weather_data and "weather" in weather_data:
yield "temperature", str(weather_data["main"]["temp"])
yield "humidity", str(weather_data["main"]["humidity"])
yield "condition", weather_data["weather"][0]["description"]
else:
raise RuntimeError(f"Expected keys not found in response: {weather_data}")

View File

@@ -13,7 +13,8 @@ class CreateTalkingAvatarVideoBlock(Block):
key="did_api_key", description="D-ID API Key"
)
script_input: str = SchemaField(
description="The text input for the script", default="Welcome to AutoGPT"
description="The text input for the script",
placeholder="Welcome to AutoGPT",
)
provider: Literal["microsoft", "elevenlabs", "amazon"] = SchemaField(
description="The voice provider to use", default="microsoft"
@@ -106,41 +107,40 @@ class CreateTalkingAvatarVideoBlock(Block):
return response.json()
def run(self, input_data: Input, **kwargs) -> BlockOutput:
try:
# Create the clip
payload = {
"script": {
"type": "text",
"subtitles": str(input_data.subtitles).lower(),
"provider": {
"type": input_data.provider,
"voice_id": input_data.voice_id,
},
"ssml": str(input_data.ssml).lower(),
"input": input_data.script_input,
# Create the clip
payload = {
"script": {
"type": "text",
"subtitles": str(input_data.subtitles).lower(),
"provider": {
"type": input_data.provider,
"voice_id": input_data.voice_id,
},
"config": {"result_format": input_data.result_format},
"presenter_config": {"crop": {"type": input_data.crop_type}},
"presenter_id": input_data.presenter_id,
"driver_id": input_data.driver_id,
}
"ssml": str(input_data.ssml).lower(),
"input": input_data.script_input,
},
"config": {"result_format": input_data.result_format},
"presenter_config": {"crop": {"type": input_data.crop_type}},
"presenter_id": input_data.presenter_id,
"driver_id": input_data.driver_id,
}
response = self.create_clip(input_data.api_key.get_secret_value(), payload)
clip_id = response["id"]
response = self.create_clip(input_data.api_key.get_secret_value(), payload)
clip_id = response["id"]
# Poll for clip status
for _ in range(input_data.max_polling_attempts):
status_response = self.get_clip_status(
input_data.api_key.get_secret_value(), clip_id
# Poll for clip status
for _ in range(input_data.max_polling_attempts):
status_response = self.get_clip_status(
input_data.api_key.get_secret_value(), clip_id
)
if status_response["status"] == "done":
yield "video_url", status_response["result_url"]
return
elif status_response["status"] == "error":
raise RuntimeError(
f"Clip creation failed: {status_response.get('error', 'Unknown error')}"
)
if status_response["status"] == "done":
yield "video_url", status_response["result_url"]
return
elif status_response["status"] == "error":
yield "error", f"Clip creation failed: {status_response.get('error', 'Unknown error')}"
return
time.sleep(input_data.polling_interval)
yield "error", "Clip creation timed out"
except Exception as e:
yield "error", str(e)
time.sleep(input_data.polling_interval)
raise TimeoutError("Clip creation timed out")

View File

@@ -2,9 +2,9 @@ import re
from typing import Any
from jinja2 import BaseLoader, Environment
from pydantic import Field
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util import json
jinja = Environment(loader=BaseLoader())
@@ -12,15 +12,17 @@ jinja = Environment(loader=BaseLoader())
class MatchTextPatternBlock(Block):
class Input(BlockSchema):
text: Any = Field(description="Text to match")
match: str = Field(description="Pattern (Regex) to match")
data: Any = Field(description="Data to be forwarded to output")
case_sensitive: bool = Field(description="Case sensitive match", default=True)
dot_all: bool = Field(description="Dot matches all", default=True)
text: Any = SchemaField(description="Text to match")
match: str = SchemaField(description="Pattern (Regex) to match")
data: Any = SchemaField(description="Data to be forwarded to output")
case_sensitive: bool = SchemaField(
description="Case sensitive match", default=True
)
dot_all: bool = SchemaField(description="Dot matches all", default=True)
class Output(BlockSchema):
positive: Any = Field(description="Output data if match is found")
negative: Any = Field(description="Output data if match is not found")
positive: Any = SchemaField(description="Output data if match is found")
negative: Any = SchemaField(description="Output data if match is not found")
def __init__(self):
super().__init__(
@@ -64,15 +66,17 @@ class MatchTextPatternBlock(Block):
class ExtractTextInformationBlock(Block):
class Input(BlockSchema):
text: Any = Field(description="Text to parse")
pattern: str = Field(description="Pattern (Regex) to parse")
group: int = Field(description="Group number to extract", default=0)
case_sensitive: bool = Field(description="Case sensitive match", default=True)
dot_all: bool = Field(description="Dot matches all", default=True)
text: Any = SchemaField(description="Text to parse")
pattern: str = SchemaField(description="Pattern (Regex) to parse")
group: int = SchemaField(description="Group number to extract", default=0)
case_sensitive: bool = SchemaField(
description="Case sensitive match", default=True
)
dot_all: bool = SchemaField(description="Dot matches all", default=True)
class Output(BlockSchema):
positive: str = Field(description="Extracted text")
negative: str = Field(description="Original text")
positive: str = SchemaField(description="Extracted text")
negative: str = SchemaField(description="Original text")
def __init__(self):
super().__init__(
@@ -116,11 +120,15 @@ class ExtractTextInformationBlock(Block):
class FillTextTemplateBlock(Block):
class Input(BlockSchema):
values: dict[str, Any] = Field(description="Values (dict) to be used in format")
format: str = Field(description="Template to format the text using `values`")
values: dict[str, Any] = SchemaField(
description="Values (dict) to be used in format"
)
format: str = SchemaField(
description="Template to format the text using `values`"
)
class Output(BlockSchema):
output: str
output: str = SchemaField(description="Formatted text")
def __init__(self):
super().__init__(
@@ -155,11 +163,13 @@ class FillTextTemplateBlock(Block):
class CombineTextsBlock(Block):
class Input(BlockSchema):
input: list[str] = Field(description="text input to combine")
delimiter: str = Field(description="Delimiter to combine texts", default="")
input: list[str] = SchemaField(description="text input to combine")
delimiter: str = SchemaField(
description="Delimiter to combine texts", default=""
)
class Output(BlockSchema):
output: str = Field(description="Combined text")
output: str = SchemaField(description="Combined text")
def __init__(self):
super().__init__(

View File

@@ -68,12 +68,9 @@ class UnrealTextToSpeechBlock(Block):
return response.json()
def run(self, input_data: Input, **kwargs) -> BlockOutput:
try:
api_response = self.call_unreal_speech_api(
input_data.api_key.get_secret_value(),
input_data.text,
input_data.voice_id,
)
yield "mp3_url", api_response["OutputUri"]
except Exception as e:
yield "error", str(e)
api_response = self.call_unreal_speech_api(
input_data.api_key.get_secret_value(),
input_data.text,
input_data.voice_id,
)
yield "mp3_url", api_response["OutputUri"]

View File

@@ -3,14 +3,22 @@ from datetime import datetime, timedelta
from typing import Any, Union
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
class GetCurrentTimeBlock(Block):
class Input(BlockSchema):
trigger: str
trigger: str = SchemaField(
description="Trigger any data to output the current time"
)
format: str = SchemaField(
description="Format of the time to output", default="%H:%M:%S"
)
class Output(BlockSchema):
time: str
time: str = SchemaField(
description="Current time in the specified format (default: %H:%M:%S)"
)
def __init__(self):
super().__init__(
@@ -20,25 +28,38 @@ class GetCurrentTimeBlock(Block):
input_schema=GetCurrentTimeBlock.Input,
output_schema=GetCurrentTimeBlock.Output,
test_input=[
{"trigger": "Hello", "format": "{time}"},
{"trigger": "Hello"},
{"trigger": "Hello", "format": "%H:%M"},
],
test_output=[
("time", lambda _: time.strftime("%H:%M:%S")),
("time", lambda _: time.strftime("%H:%M")),
],
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
current_time = time.strftime("%H:%M:%S")
current_time = time.strftime(input_data.format)
yield "time", current_time
class GetCurrentDateBlock(Block):
class Input(BlockSchema):
trigger: str
offset: Union[int, str]
trigger: str = SchemaField(
description="Trigger any data to output the current date"
)
offset: Union[int, str] = SchemaField(
title="Days Offset",
description="Offset in days from the current date",
default=0,
)
format: str = SchemaField(
description="Format of the date to output", default="%Y-%m-%d"
)
class Output(BlockSchema):
date: str
date: str = SchemaField(
description="Current date in the specified format (default: YYYY-MM-DD)"
)
def __init__(self):
super().__init__(
@@ -48,7 +69,8 @@ class GetCurrentDateBlock(Block):
input_schema=GetCurrentDateBlock.Input,
output_schema=GetCurrentDateBlock.Output,
test_input=[
{"trigger": "Hello", "format": "{date}", "offset": "7"},
{"trigger": "Hello", "offset": "7"},
{"trigger": "Hello", "offset": "7", "format": "%m/%d/%Y"},
],
test_output=[
(
@@ -56,6 +78,12 @@ class GetCurrentDateBlock(Block):
lambda t: abs(datetime.now() - datetime.strptime(t, "%Y-%m-%d"))
< timedelta(days=8), # 7 days difference + 1 day error margin.
),
(
"date",
lambda t: abs(datetime.now() - datetime.strptime(t, "%m/%d/%Y"))
< timedelta(days=8),
# 7 days difference + 1 day error margin.
),
],
)
@@ -65,15 +93,23 @@ class GetCurrentDateBlock(Block):
except ValueError:
offset = 0
current_date = datetime.now() - timedelta(days=offset)
yield "date", current_date.strftime("%Y-%m-%d")
yield "date", current_date.strftime(input_data.format)
class GetCurrentDateAndTimeBlock(Block):
class Input(BlockSchema):
trigger: str
trigger: str = SchemaField(
description="Trigger any data to output the current date and time"
)
format: str = SchemaField(
description="Format of the date and time to output",
default="%Y-%m-%d %H:%M:%S",
)
class Output(BlockSchema):
date_time: str
date_time: str = SchemaField(
description="Current date and time in the specified format (default: YYYY-MM-DD HH:MM:SS)"
)
def __init__(self):
super().__init__(
@@ -83,7 +119,7 @@ class GetCurrentDateAndTimeBlock(Block):
input_schema=GetCurrentDateAndTimeBlock.Input,
output_schema=GetCurrentDateAndTimeBlock.Output,
test_input=[
{"trigger": "Hello", "format": "{date_time}"},
{"trigger": "Hello"},
],
test_output=[
(
@@ -97,20 +133,29 @@ class GetCurrentDateAndTimeBlock(Block):
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
current_date_time = time.strftime("%Y-%m-%d %H:%M:%S")
current_date_time = time.strftime(input_data.format)
yield "date_time", current_date_time
class CountdownTimerBlock(Block):
class Input(BlockSchema):
input_message: Any = "timer finished"
seconds: Union[int, str] = 0
minutes: Union[int, str] = 0
hours: Union[int, str] = 0
days: Union[int, str] = 0
input_message: Any = SchemaField(
description="Message to output after the timer finishes",
default="timer finished",
)
seconds: Union[int, str] = SchemaField(
description="Duration in seconds", default=0
)
minutes: Union[int, str] = SchemaField(
description="Duration in minutes", default=0
)
hours: Union[int, str] = SchemaField(description="Duration in hours", default=0)
days: Union[int, str] = SchemaField(description="Duration in days", default=0)
class Output(BlockSchema):
output_message: str
output_message: str = SchemaField(
description="Message after the timer finishes"
)
def __init__(self):
super().__init__(

View File

@@ -7,9 +7,10 @@ from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
class TranscribeYouTubeVideoBlock(Block):
class TranscribeYoutubeVideoBlock(Block):
class Input(BlockSchema):
youtube_url: str = SchemaField(
title="YouTube URL",
description="The URL of the YouTube video to transcribe",
placeholder="https://www.youtube.com/watch?v=dQw4w9WgXcQ",
)
@@ -24,8 +25,8 @@ class TranscribeYouTubeVideoBlock(Block):
def __init__(self):
super().__init__(
id="f3a8f7e1-4b1d-4e5f-9f2a-7c3d5a2e6b4c",
input_schema=TranscribeYouTubeVideoBlock.Input,
output_schema=TranscribeYouTubeVideoBlock.Output,
input_schema=TranscribeYoutubeVideoBlock.Input,
output_schema=TranscribeYoutubeVideoBlock.Output,
description="Transcribes a YouTube video.",
categories={BlockCategory.SOCIAL},
test_input={"youtube_url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ"},
@@ -64,14 +65,11 @@ class TranscribeYouTubeVideoBlock(Block):
return YouTubeTranscriptApi.get_transcript(video_id)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
try:
video_id = self.extract_video_id(input_data.youtube_url)
yield "video_id", video_id
video_id = self.extract_video_id(input_data.youtube_url)
yield "video_id", video_id
transcript = self.get_transcript(video_id)
formatter = TextFormatter()
transcript_text = formatter.format_transcript(transcript)
transcript = self.get_transcript(video_id)
formatter = TextFormatter()
transcript_text = formatter.format_transcript(transcript)
yield "transcript", transcript_text
except Exception as e:
yield "error", str(e)
yield "transcript", transcript_text

View File

@@ -272,6 +272,8 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
for output_name, output_data in self.run(
self.input_schema(**input_data), **kwargs
):
if output_name == "error":
raise RuntimeError(output_data)
if error := self.output_schema.validate_field(output_name, output_data):
raise ValueError(f"Block produced an invalid output data: {error}")
yield output_name, output_data

View File

@@ -17,8 +17,9 @@ from backend.blocks.llm import (
AITextSummarizerBlock,
LlmModel,
)
from backend.blocks.search import ExtractWebsiteContentBlock, SearchTheWebBlock
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
from backend.data.block import Block, BlockInput
from backend.data.block import Block, BlockInput, get_block
from backend.util.settings import Config
@@ -74,6 +75,10 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
CreateTalkingAvatarVideoBlock: [
BlockCost(cost_amount=15, cost_filter={"api_key": None})
],
SearchTheWebBlock: [BlockCost(cost_amount=1)],
ExtractWebsiteContentBlock: [
BlockCost(cost_amount=1, cost_filter={"raw_content": False})
],
}
@@ -96,7 +101,7 @@ class UserCreditBase(ABC):
self,
user_id: str,
user_credit: int,
block: Block,
block_id: str,
input_data: BlockInput,
data_size: float,
run_time: float,
@@ -107,7 +112,7 @@ class UserCreditBase(ABC):
Args:
user_id (str): The user ID.
user_credit (int): The current credit for the user.
block (Block): The block that is being used.
block_id (str): The block ID.
input_data (BlockInput): The input data for the block.
data_size (float): The size of the data being processed.
run_time (float): The time taken to run the block.
@@ -208,12 +213,16 @@ class UserCredit(UserCreditBase):
self,
user_id: str,
user_credit: int,
block: Block,
block_id: str,
input_data: BlockInput,
data_size: float,
run_time: float,
validate_balance: bool = True,
) -> int:
block = get_block(block_id)
if not block:
raise ValueError(f"Block not found: {block_id}")
cost, matching_filter = self._block_usage_cost(
block=block, input_data=input_data, data_size=data_size, run_time=run_time
)

View File

@@ -3,7 +3,6 @@ from datetime import datetime, timezone
from multiprocessing import Manager
from typing import Any, Generic, TypeVar
from autogpt_libs.supabase_integration_credentials_store.types import Credentials
from prisma.enums import AgentExecutionStatus
from prisma.models import (
AgentGraphExecution,
@@ -26,7 +25,6 @@ class GraphExecution(BaseModel):
graph_exec_id: str
graph_id: str
start_node_execs: list["NodeExecution"]
node_input_credentials: dict[str, Credentials] # dict[node_id, Credentials]
class NodeExecution(BaseModel):
@@ -40,28 +38,6 @@ class NodeExecution(BaseModel):
ExecutionStatus = AgentExecutionStatus
T = TypeVar("T")
class ExecutionQueue(Generic[T]):
"""
Queue for managing the execution of agents.
This will be shared between different processes
"""
def __init__(self):
self.queue = Manager().Queue()
def add(self, execution: T) -> T:
self.queue.put(execution)
return execution
def get(self) -> T:
return self.queue.get()
def empty(self) -> bool:
return self.queue.empty()
class ExecutionResult(BaseModel):
graph_id: str

View File

@@ -2,20 +2,18 @@ import asyncio
import logging
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Literal
import prisma.types
from prisma.models import AgentGraph, AgentGraphExecution, AgentNode, AgentNodeLink
from prisma.types import AgentGraphInclude
from pydantic import BaseModel, PrivateAttr
from pydantic import BaseModel
from pydantic_core import PydanticUndefinedType
from backend.blocks.basic import AgentInputBlock, AgentOutputBlock
from backend.data.block import BlockInput, get_block, get_blocks
from backend.data.db import BaseDbModel, transaction
from backend.data.execution import ExecutionStatus
from backend.data.user import DEFAULT_USER_ID
from backend.util import json
logger = logging.getLogger(__name__)
@@ -53,17 +51,8 @@ class Node(BaseDbModel):
block_id: str
input_default: BlockInput = {} # dict[input_name, default_value]
metadata: dict[str, Any] = {}
_input_links: list[Link] = PrivateAttr(default=[])
_output_links: list[Link] = PrivateAttr(default=[])
@property
def input_links(self) -> list[Link]:
return self._input_links
@property
def output_links(self) -> list[Link]:
return self._output_links
input_links: list[Link] = []
output_links: list[Link] = []
@staticmethod
def from_db(node: AgentNode):
@@ -75,8 +64,8 @@ class Node(BaseDbModel):
input_default=json.loads(node.constantInput),
metadata=json.loads(node.metadata),
)
obj._input_links = [Link.from_db(link) for link in node.Input or []]
obj._output_links = [Link.from_db(link) for link in node.Output or []]
obj.input_links = [Link.from_db(link) for link in node.Input or []]
obj.output_links = [Link.from_db(link) for link in node.Output or []]
return obj
@@ -330,7 +319,7 @@ class Graph(GraphMeta):
return input_schema
@staticmethod
def from_db(graph: AgentGraph):
def from_db(graph: AgentGraph, hide_credentials: bool = False):
nodes = [
*(graph.AgentNodes or []),
*(
@@ -341,7 +330,7 @@ class Graph(GraphMeta):
]
return Graph(
**GraphMeta.from_db(graph).model_dump(),
nodes=[Node.from_db(node) for node in nodes],
nodes=[Graph._process_node(node, hide_credentials) for node in nodes],
links=list(
{
Link.from_db(link)
@@ -355,6 +344,31 @@ class Graph(GraphMeta):
},
)
@staticmethod
def _process_node(node: AgentNode, hide_credentials: bool) -> Node:
node_dict = node.model_dump()
if hide_credentials and "constantInput" in node_dict:
constant_input = json.loads(node_dict["constantInput"])
constant_input = Graph._hide_credentials_in_input(constant_input)
node_dict["constantInput"] = json.dumps(constant_input)
return Node.from_db(AgentNode(**node_dict))
@staticmethod
def _hide_credentials_in_input(input_data: dict[str, Any]) -> dict[str, Any]:
sensitive_keys = ["credentials", "api_key", "password", "token", "secret"]
result = {}
for key, value in input_data.items():
if isinstance(value, dict):
result[key] = Graph._hide_credentials_in_input(value)
elif isinstance(value, str) and any(
sensitive_key in key.lower() for sensitive_key in sensitive_keys
):
# Skip this key-value pair in the result
continue
else:
result[key] = value
return result
AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
"Input": True,
@@ -382,9 +396,9 @@ async def get_node(node_id: str) -> Node:
async def get_graphs_meta(
user_id: str,
include_executions: bool = False,
filter_by: Literal["active", "template"] | None = "active",
user_id: str | None = None,
) -> list[GraphMeta]:
"""
Retrieves graph metadata objects.
@@ -393,6 +407,7 @@ async def get_graphs_meta(
Args:
include_executions: Whether to include executions in the graph metadata.
filter_by: An optional filter to either select templates or active graphs.
user_id: The ID of the user that owns the graph.
Returns:
list[GraphMeta]: A list of objects representing the retrieved graph metadata.
@@ -404,8 +419,7 @@ async def get_graphs_meta(
elif filter_by == "template":
where_clause["isTemplate"] = True
if user_id and filter_by != "template":
where_clause["userId"] = user_id
where_clause["userId"] = user_id
graphs = await AgentGraph.prisma().find_many(
where=where_clause,
@@ -431,6 +445,7 @@ async def get_graph(
version: int | None = None,
template: bool = False,
user_id: str | None = None,
hide_credentials: bool = False,
) -> Graph | None:
"""
Retrieves a graph from the DB.
@@ -456,7 +471,7 @@ async def get_graph(
include=AGENT_GRAPH_INCLUDE,
order={"version": "desc"},
)
return Graph.from_db(graph) if graph else None
return Graph.from_db(graph, hide_credentials) if graph else None
async def set_graph_active_version(graph_id: str, version: int, user_id: str) -> None:
@@ -500,6 +515,15 @@ async def get_graph_all_versions(graph_id: str, user_id: str) -> list[Graph]:
return [Graph.from_db(graph) for graph in graph_versions]
async def delete_graph(graph_id: str, user_id: str) -> int:
entries_count = await AgentGraph.prisma().delete_many(
where={"id": graph_id, "userId": user_id}
)
if entries_count:
logger.info(f"Deleted {entries_count} graph entries for Graph #{graph_id}")
return entries_count
async def create_graph(graph: Graph, user_id: str) -> Graph:
async with transaction() as tx:
await __create_graph(tx, graph, user_id)
@@ -576,30 +600,3 @@ async def __create_graph(tx, graph: Graph, user_id: str):
for link in graph.links
]
)
# --------------------- Helper functions --------------------- #
TEMPLATES_DIR = Path(__file__).parent.parent.parent / "graph_templates"
async def import_packaged_templates() -> None:
templates_in_db = await get_graphs_meta(filter_by="template")
logging.info("Loading templates...")
for template_file in TEMPLATES_DIR.glob("*.json"):
template_data = json.loads(template_file.read_bytes())
template = Graph.model_validate(template_data)
if not template.is_template:
logging.warning(
f"pre-packaged graph file {template_file} is not a template"
)
continue
if (
exists := next((t for t in templates_in_db if t.id == template.id), None)
) and exists.version >= template.version:
continue
await create_graph(template, DEFAULT_USER_ID)
logging.info(f"Loaded template '{template.name}' ({template.id})")

View File

@@ -2,12 +2,14 @@ import json
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Generic, TypeVar
from backend.data import redis
from backend.data.execution import ExecutionResult
logger = logging.getLogger(__name__)
T = TypeVar("T")
class DateTimeEncoder(json.JSONEncoder):
def default(self, o):
@@ -17,14 +19,6 @@ class DateTimeEncoder(json.JSONEncoder):
class AbstractEventQueue(ABC):
@abstractmethod
def connect(self):
pass
@abstractmethod
def close(self):
pass
@abstractmethod
def put(self, execution_result: ExecutionResult):
pass
@@ -36,26 +30,41 @@ class AbstractEventQueue(ABC):
class RedisEventQueue(AbstractEventQueue):
def __init__(self):
self.connection = None
self.queue_name = redis.QUEUE_NAME
def connect(self):
self.connection = redis.connect()
@property
def connection(self):
return redis.get_redis()
def put(self, execution_result: ExecutionResult):
if self.connection:
message = json.dumps(execution_result.model_dump(), cls=DateTimeEncoder)
logger.info(f"Putting execution result to Redis {message}")
self.connection.lpush(self.queue_name, message)
message = json.dumps(execution_result.model_dump(), cls=DateTimeEncoder)
logger.info(f"Putting execution result to Redis {message}")
self.connection.lpush(self.queue_name, message)
def get(self) -> ExecutionResult | None:
if self.connection:
message = self.connection.rpop(self.queue_name)
if message is not None and isinstance(message, (str, bytes, bytearray)):
data = json.loads(message)
logger.info(f"Getting execution result from Redis {data}")
return ExecutionResult(**data)
message = self.connection.rpop(self.queue_name)
if message is not None and isinstance(message, (str, bytes, bytearray)):
data = json.loads(message)
logger.info(f"Getting execution result from Redis {data}")
return ExecutionResult(**data)
elif message is not None:
logger.error(f"Failed to get execution result from Redis {message}")
return None
def close(self):
redis.disconnect()
class ExecutionQueue(Generic[T]):
def __init__(self, queue_name: str):
self.redis = redis.get_redis()
self.queue_name = queue_name
def add(self, item: T):
message = json.dumps(item.model_dump(), default=str)
self.redis.lpush(self.queue_name, message)
def get(self) -> T:
while True:
_, message = self.redis.brpop(self.queue_name)
return T.model_validate(json.loads(message))
def empty(self) -> bool:
return self.redis.llen(self.queue_name) == 0

View File

@@ -1,5 +1,5 @@
from backend.app import run_processes
from backend.executor import ExecutionManager
from backend.executor import DatabaseManager, ExecutionManager
def main():
@@ -7,6 +7,7 @@ def main():
Run all the processes required for the AutoGPT-server REST API.
"""
run_processes(
DatabaseManager(),
ExecutionManager(),
)

View File

@@ -1,7 +1,9 @@
from .database import DatabaseManager
from .manager import ExecutionManager
from .scheduler import ExecutionScheduler
__all__ = [
"DatabaseManager",
"ExecutionManager",
"ExecutionScheduler",
]

View File

@@ -0,0 +1,75 @@
from functools import wraps
from typing import Any, Callable, Concatenate, Coroutine, ParamSpec, TypeVar, cast
from backend.data.credit import get_user_credit_model
from backend.data.execution import (
ExecutionResult,
create_graph_execution,
get_execution_results,
get_incomplete_executions,
get_latest_execution,
update_execution_status,
update_graph_execution_stats,
update_node_execution_stats,
upsert_execution_input,
upsert_execution_output,
)
from backend.data.graph import get_graph, get_node
from backend.data.queue import RedisEventQueue
from backend.util.service import AppService, expose
from backend.util.settings import Config
P = ParamSpec("P")
R = TypeVar("R")
class DatabaseManager(AppService):
def __init__(self):
super().__init__(port=Config().database_api_port)
self.use_db = True
self.use_redis = True
self.event_queue = RedisEventQueue()
@expose
def send_execution_update(self, execution_result_dict: dict[Any, Any]):
self.event_queue.put(ExecutionResult(**execution_result_dict))
@staticmethod
def exposed_run_and_wait(
f: Callable[P, Coroutine[None, None, R]]
) -> Callable[Concatenate[object, P], R]:
@expose
@wraps(f)
def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
coroutine = f(*args, **kwargs)
res = self.run_and_wait(coroutine)
return res
return wrapper
# Executions
create_graph_execution = exposed_run_and_wait(create_graph_execution)
get_execution_results = exposed_run_and_wait(get_execution_results)
get_incomplete_executions = exposed_run_and_wait(get_incomplete_executions)
get_latest_execution = exposed_run_and_wait(get_latest_execution)
update_execution_status = exposed_run_and_wait(update_execution_status)
update_graph_execution_stats = exposed_run_and_wait(update_graph_execution_stats)
update_node_execution_stats = exposed_run_and_wait(update_node_execution_stats)
upsert_execution_input = exposed_run_and_wait(upsert_execution_input)
upsert_execution_output = exposed_run_and_wait(upsert_execution_output)
# Graphs
get_node = exposed_run_and_wait(get_node)
get_graph = exposed_run_and_wait(get_graph)
# Credits
user_credit_model = get_user_credit_model()
get_or_refill_credit = cast(
Callable[[Any, str], int],
exposed_run_and_wait(user_credit_model.get_or_refill_credit),
)
spend_credits = cast(
Callable[[Any, str, int, str, dict[str, str], float, float], int],
exposed_run_and_wait(user_credit_model.spend_credits),
)

View File

@@ -1,4 +1,3 @@
import asyncio
import atexit
import logging
import multiprocessing
@@ -9,45 +8,40 @@ import threading
from concurrent.futures import Future, ProcessPoolExecutor
from contextlib import contextmanager
from multiprocessing.pool import AsyncResult, Pool
from typing import TYPE_CHECKING, Any, Coroutine, Generator, TypeVar, cast
from typing import TYPE_CHECKING, Any, Generator, TypeVar, cast
from autogpt_libs.supabase_integration_credentials_store.types import Credentials
from pydantic import BaseModel
from redis.lock import Lock as RedisLock
from backend.data.queue import ExecutionQueue
if TYPE_CHECKING:
from backend.server.rest_api import AgentServer
from backend.executor import DatabaseManager
from backend.data import db, redis
from backend.data import redis
from backend.data.block import Block, BlockData, BlockInput, BlockType, get_block
from backend.data.credit import get_user_credit_model
from backend.data.execution import (
ExecutionQueue,
ExecutionResult,
ExecutionStatus,
GraphExecution,
NodeExecution,
create_graph_execution,
get_execution_results,
get_incomplete_executions,
get_latest_execution,
merge_execution_input,
parse_execution_output,
update_execution_status,
update_graph_execution_stats,
update_node_execution_stats,
upsert_execution_input,
upsert_execution_output,
)
from backend.data.graph import Graph, Link, Node, get_graph, get_node
from backend.data.graph import Graph, Link, Node
from backend.data.model import CREDENTIALS_FIELD_NAME, CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util import json
from backend.util.cache import thread_cached_property
from backend.util.decorator import error_logged, time_measured
from backend.util.logging import configure_logging
from backend.util.process import set_service_name
from backend.util.service import AppService, expose, get_service_client
from backend.util.settings import Config
from backend.util.settings import Settings
from backend.util.type import convert
logger = logging.getLogger(__name__)
settings = Settings()
class LogMetadata:
@@ -100,10 +94,9 @@ ExecutionStream = Generator[NodeExecution, None, None]
def execute_node(
loop: asyncio.AbstractEventLoop,
api_client: "AgentServer",
db_client: "DatabaseManager",
creds_manager: IntegrationCredentialsManager,
data: NodeExecution,
input_credentials: Credentials | None = None,
execution_stats: dict[str, Any] | None = None,
) -> ExecutionStream:
"""
@@ -111,8 +104,7 @@ def execute_node(
persist the execution result, and return the subsequent node to be executed.
Args:
loop: The event loop to run the async functions.
api_client: The client to send execution updates to the server.
db_client: The client to send execution updates to the server.
data: The execution data for executing the current node.
execution_stats: The execution statistics to be updated.
@@ -125,17 +117,12 @@ def execute_node(
node_exec_id = data.node_exec_id
node_id = data.node_id
asyncio.set_event_loop(loop)
def wait(f: Coroutine[Any, Any, T]) -> T:
return loop.run_until_complete(f)
def update_execution(status: ExecutionStatus) -> ExecutionResult:
exec_update = wait(update_execution_status(node_exec_id, status))
api_client.send_execution_update(exec_update.model_dump())
exec_update = db_client.update_execution_status(node_exec_id, status)
db_client.send_execution_update(exec_update.model_dump())
return exec_update
node = wait(get_node(node_id))
node = db_client.get_node(node_id)
node_block = get_block(node.block_id)
if not node_block:
@@ -161,15 +148,21 @@ def execute_node(
input_size = len(input_data_str)
log_metadata.info("Executed node with input", input=input_data_str)
update_execution(ExecutionStatus.RUNNING)
user_credit = get_user_credit_model()
extra_exec_kwargs = {}
if input_credentials:
extra_exec_kwargs["credentials"] = input_credentials
# Last-minute fetch credentials + acquire a system-wide read-write lock to prevent
# changes during execution. ⚠️ This means a set of credentials can only be used by
# one (running) block at a time; simultaneous execution of blocks using same
# credentials is not supported.
credentials = creds_lock = None
if CREDENTIALS_FIELD_NAME in input_data:
credentials_meta = CredentialsMetaInput(**input_data[CREDENTIALS_FIELD_NAME])
credentials, creds_lock = creds_manager.acquire(user_id, credentials_meta.id)
extra_exec_kwargs["credentials"] = credentials
output_size = 0
try:
credit = wait(user_credit.get_or_refill_credit(user_id))
credit = db_client.get_or_refill_credit(user_id)
if credit < 0:
raise ValueError(f"Insufficient credit: {credit}")
@@ -178,11 +171,10 @@ def execute_node(
):
output_size += len(json.dumps(output_data))
log_metadata.info("Node produced output", output_name=output_data)
wait(upsert_execution_output(node_exec_id, output_name, output_data))
db_client.upsert_execution_output(node_exec_id, output_name, output_data)
for execution in _enqueue_next_nodes(
api_client=api_client,
loop=loop,
db_client=db_client,
node=node,
output=(output_name, output_data),
user_id=user_id,
@@ -192,6 +184,10 @@ def execute_node(
):
yield execution
# Release lock on credentials ASAP
if creds_lock:
creds_lock.release()
r = update_execution(ExecutionStatus.COMPLETED)
s = input_size + output_size
t = (
@@ -199,35 +195,27 @@ def execute_node(
if r.end_time and r.start_time
else 0
)
wait(user_credit.spend_credits(user_id, credit, node_block, input_data, s, t))
db_client.spend_credits(user_id, credit, node_block.id, input_data, s, t)
except Exception as e:
error_msg = str(e)
log_metadata.exception(f"Node execution failed with error {error_msg}")
wait(upsert_execution_output(node_exec_id, "error", error_msg))
db_client.upsert_execution_output(node_exec_id, "error", error_msg)
update_execution(ExecutionStatus.FAILED)
raise e
finally:
# Ensure credentials are released even if execution fails
if creds_lock:
creds_lock.release()
if execution_stats is not None:
execution_stats["input_size"] = input_size
execution_stats["output_size"] = output_size
@contextmanager
def synchronized(key: str, timeout: int = 60):
lock = redis.get_redis().lock(f"lock:{key}", timeout=timeout)
try:
lock.acquire()
yield
finally:
lock.release()
def _enqueue_next_nodes(
api_client: "AgentServer",
loop: asyncio.AbstractEventLoop,
db_client: "DatabaseManager",
node: Node,
output: BlockData,
user_id: str,
@@ -235,16 +223,14 @@ def _enqueue_next_nodes(
graph_id: str,
log_metadata: LogMetadata,
) -> list[NodeExecution]:
def wait(f: Coroutine[Any, Any, T]) -> T:
return loop.run_until_complete(f)
def add_enqueued_execution(
node_exec_id: str, node_id: str, data: BlockInput
) -> NodeExecution:
exec_update = wait(
update_execution_status(node_exec_id, ExecutionStatus.QUEUED, data)
exec_update = db_client.update_execution_status(
node_exec_id, ExecutionStatus.QUEUED, data
)
api_client.send_execution_update(exec_update.model_dump())
db_client.send_execution_update(exec_update.model_dump())
return NodeExecution(
user_id=user_id,
graph_exec_id=graph_exec_id,
@@ -264,20 +250,18 @@ def _enqueue_next_nodes(
if next_data is None:
return enqueued_executions
next_node = wait(get_node(next_node_id))
next_node = db_client.get_node(next_node_id)
# Multiple node can register the same next node, we need this to be atomic
# To avoid same execution to be enqueued multiple times,
# Or the same input to be consumed multiple times.
with synchronized(f"upsert_input-{next_node_id}-{graph_exec_id}"):
# Add output data to the earliest incomplete execution, or create a new one.
next_node_exec_id, next_node_input = wait(
upsert_execution_input(
node_id=next_node_id,
graph_exec_id=graph_exec_id,
input_name=next_input_name,
input_data=next_data,
)
next_node_exec_id, next_node_input = db_client.upsert_execution_input(
node_id=next_node_id,
graph_exec_id=graph_exec_id,
input_name=next_input_name,
input_data=next_data,
)
# Complete missing static input pins data using the last execution input.
@@ -287,8 +271,8 @@ def _enqueue_next_nodes(
if link.is_static and link.sink_name not in next_node_input
}
if static_link_names and (
latest_execution := wait(
get_latest_execution(next_node_id, graph_exec_id)
latest_execution := db_client.get_latest_execution(
next_node_id, graph_exec_id
)
):
for name in static_link_names:
@@ -315,7 +299,9 @@ def _enqueue_next_nodes(
# If link is static, there could be some incomplete executions waiting for it.
# Load and complete the input missing input data, and try to re-enqueue them.
for iexec in wait(get_incomplete_executions(next_node_id, graph_exec_id)):
for iexec in db_client.get_incomplete_executions(
next_node_id, graph_exec_id
):
idata = iexec.input_data
ineid = iexec.node_exec_id
@@ -400,12 +386,6 @@ def validate_exec(
return data, node_block.name
def get_agent_server_client() -> "AgentServer":
from backend.server.rest_api import AgentServer
return get_service_client(AgentServer, Config().agent_server_port)
class Executor:
"""
This class contains event handlers for the process pool executor events.
@@ -434,13 +414,12 @@ class Executor:
@classmethod
def on_node_executor_start(cls):
configure_logging()
cls.loop = asyncio.new_event_loop()
cls.pid = os.getpid()
set_service_name("NodeExecutor")
redis.connect()
cls.loop.run_until_complete(db.connect())
cls.agent_server_client = get_agent_server_client()
cls.node_queue = ExecutionQueue[NodeExecution]("node_execution_queue")
cls.pid = os.getpid()
cls.db_client = get_db_client()
cls.creds_manager = IntegrationCredentialsManager()
# Set up shutdown handlers
cls.shutdown_lock = threading.Lock()
@@ -454,8 +433,8 @@ class Executor:
if not cls.shutdown_lock.acquire(blocking=False):
return # already shutting down
logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting DB...")
cls.loop.run_until_complete(db.disconnect())
logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Releasing locks...")
cls.creds_manager.release_all_locks()
logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
redis.disconnect()
logger.info(f"[on_node_executor_stop {cls.pid}] ✅ Finished cleanup")
@@ -464,20 +443,20 @@ class Executor:
def on_node_executor_sigterm(cls):
llprint(f"[on_node_executor_sigterm {cls.pid}] ⚠️ SIGTERM received")
if not cls.shutdown_lock.acquire(blocking=False):
return # already shutting down, no need to self-terminate
return # already shutting down
llprint(f"[on_node_executor_sigterm {cls.pid}] ⏳ Disconnecting DB...")
cls.loop.run_until_complete(db.disconnect())
llprint(f"[on_node_executor_sigterm {cls.pid}] ✅ Finished cleanup")
llprint(f"[on_node_executor_stop {cls.pid}] ⏳ Releasing locks...")
cls.creds_manager.release_all_locks()
llprint(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
redis.disconnect()
llprint(f"[on_node_executor_stop {cls.pid}] ✅ Finished cleanup")
sys.exit(0)
@classmethod
@error_logged
def on_node_execution(
cls,
q: ExecutionQueue[NodeExecution],
node_exec: NodeExecution,
input_credentials: Credentials | None,
):
log_metadata = LogMetadata(
user_id=node_exec.user_id,
@@ -487,34 +466,32 @@ class Executor:
node_id=node_exec.node_id,
block_name="-",
)
q = cls.node_queue
execution_stats = {}
timing_info, _ = cls._on_node_execution(
q, node_exec, input_credentials, log_metadata, execution_stats
q, node_exec, log_metadata, execution_stats
)
execution_stats["walltime"] = timing_info.wall_time
execution_stats["cputime"] = timing_info.cpu_time
cls.loop.run_until_complete(
update_node_execution_stats(node_exec.node_exec_id, execution_stats)
cls.db_client.update_node_execution_stats(
node_exec.node_exec_id, execution_stats
)
@classmethod
@time_measured
def _on_node_execution(
cls,
q: ExecutionQueue[NodeExecution],
node_exec: NodeExecution,
input_credentials: Credentials | None,
log_metadata: LogMetadata,
stats: dict[str, Any] | None = None,
):
try:
log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
for execution in execute_node(
cls.loop, cls.agent_server_client, node_exec, input_credentials, stats
cls.db_client, cls.creds_manager, node_exec, stats
):
q.add(execution)
cls.node_queue.add(execution)
log_metadata.info(f"Finished node execution {node_exec.node_exec_id}")
except Exception as e:
log_metadata.exception(
@@ -524,12 +501,11 @@ class Executor:
@classmethod
def on_graph_executor_start(cls):
configure_logging()
set_service_name("GraphExecutor")
cls.pool_size = Config().num_node_workers
cls.loop = asyncio.new_event_loop()
cls.db_client = get_db_client()
cls.pool_size = settings.config.num_node_workers
cls.pid = os.getpid()
cls.loop.run_until_complete(db.connect())
cls._init_node_executor_pool()
logger.info(
f"Graph executor {cls.pid} started with {cls.pool_size} node workers"
@@ -541,8 +517,6 @@ class Executor:
@classmethod
def on_graph_executor_stop(cls):
prefix = f"[on_graph_executor_stop {cls.pid}]"
logger.info(f"{prefix} ⏳ Disconnecting DB...")
cls.loop.run_until_complete(db.disconnect())
logger.info(f"{prefix} ⏳ Terminating node executor pool...")
cls.executor.terminate()
logger.info(f"{prefix} ✅ Finished cleanup")
@@ -569,14 +543,12 @@ class Executor:
graph_exec, cancel, log_metadata
)
cls.loop.run_until_complete(
update_graph_execution_stats(
graph_exec_id=graph_exec.graph_exec_id,
error=error,
wall_time=timing_info.wall_time,
cpu_time=timing_info.cpu_time,
node_count=node_count,
)
cls.db_client.update_graph_execution_stats(
graph_exec_id=graph_exec.graph_exec_id,
error=error,
wall_time=timing_info.wall_time,
cpu_time=timing_info.cpu_time,
node_count=node_count,
)
@classmethod
@@ -610,7 +582,7 @@ class Executor:
cancel_thread.start()
try:
queue = ExecutionQueue[NodeExecution]()
queue = ExecutionQueue[NodeExecution]("node_execution_queue")
for node_exec in graph_exec.start_node_execs:
queue.add(node_exec)
@@ -648,11 +620,7 @@ class Executor:
)
running_executions[exec_data.node_id] = cls.executor.apply_async(
cls.on_node_execution,
(
queue,
exec_data,
graph_exec.node_input_credentials.get(exec_data.node_id),
),
(exec_data,),
callback=make_exec_callback(exec_data),
)
@@ -687,12 +655,13 @@ class Executor:
class ExecutionManager(AppService):
def __init__(self):
super().__init__(port=Config().execution_manager_port)
self.use_db = True
super().__init__(port=settings.config.execution_manager_port)
self.use_redis = True
self.use_supabase = True
self.pool_size = Config().num_graph_workers
self.queue = ExecutionQueue[GraphExecution]()
self.pool_size = settings.config.num_graph_workers
self.queue = ExecutionQueue[GraphExecution]("graph_execution_queue")
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
def run_service(self):
@@ -700,7 +669,9 @@ class ExecutionManager(AppService):
SupabaseIntegrationCredentialsStore,
)
self.credentials_store = SupabaseIntegrationCredentialsStore(self.supabase)
self.credentials_store = SupabaseIntegrationCredentialsStore(
self.supabase, redis.get_redis()
)
self.executor = ProcessPoolExecutor(
max_workers=self.pool_size,
initializer=Executor.on_graph_executor_start,
@@ -730,20 +701,20 @@ class ExecutionManager(AppService):
super().cleanup()
@property
def agent_server_client(self) -> "AgentServer":
return get_agent_server_client()
@thread_cached_property
def db_client(self) -> "DatabaseManager":
return get_db_client()
@expose
def add_execution(
self, graph_id: str, data: BlockInput, user_id: str
) -> dict[str, Any]:
graph: Graph | None = self.run_and_wait(get_graph(graph_id, user_id=user_id))
graph: Graph | None = self.db_client.get_graph(graph_id, user_id=user_id)
if not graph:
raise Exception(f"Graph #{graph_id} not found.")
graph.validate_graph(for_run=True)
node_input_credentials = self._get_node_input_credentials(graph, user_id)
self._validate_node_input_credentials(graph, user_id)
nodes_input = []
for node in graph.starting_nodes:
@@ -766,13 +737,11 @@ class ExecutionManager(AppService):
else:
nodes_input.append((node.id, input_data))
graph_exec_id, node_execs = self.run_and_wait(
create_graph_execution(
graph_id=graph_id,
graph_version=graph.version,
nodes_input=nodes_input,
user_id=user_id,
)
graph_exec_id, node_execs = self.db_client.create_graph_execution(
graph_id=graph_id,
graph_version=graph.version,
nodes_input=nodes_input,
user_id=user_id,
)
starting_node_execs = []
@@ -787,19 +756,16 @@ class ExecutionManager(AppService):
data=node_exec.input_data,
)
)
exec_update = self.run_and_wait(
update_execution_status(
node_exec.node_exec_id, ExecutionStatus.QUEUED, node_exec.input_data
)
exec_update = self.db_client.update_execution_status(
node_exec.node_exec_id, ExecutionStatus.QUEUED, node_exec.input_data
)
self.agent_server_client.send_execution_update(exec_update.model_dump())
self.db_client.send_execution_update(exec_update.model_dump())
graph_exec = GraphExecution(
user_id=user_id,
graph_id=graph_id,
graph_exec_id=graph_exec_id,
start_node_execs=starting_node_execs,
node_input_credentials=node_input_credentials,
)
self.queue.add(graph_exec)
@@ -828,30 +794,22 @@ class ExecutionManager(AppService):
future.result()
# Update the status of the unfinished node executions
node_execs = self.run_and_wait(get_execution_results(graph_exec_id))
node_execs = self.db_client.get_execution_results(graph_exec_id)
for node_exec in node_execs:
if node_exec.status not in (
ExecutionStatus.COMPLETED,
ExecutionStatus.FAILED,
):
self.run_and_wait(
upsert_execution_output(
node_exec.node_exec_id, "error", "TERMINATED"
)
self.db_client.upsert_execution_output(
node_exec.node_exec_id, "error", "TERMINATED"
)
exec_update = self.run_and_wait(
update_execution_status(
node_exec.node_exec_id, ExecutionStatus.FAILED
)
exec_update = self.db_client.update_execution_status(
node_exec.node_exec_id, ExecutionStatus.FAILED
)
self.agent_server_client.send_execution_update(exec_update.model_dump())
self.db_client.send_execution_update(exec_update.model_dump())
def _get_node_input_credentials(
self, graph: Graph, user_id: str
) -> dict[str, Credentials]:
"""Gets all credentials for all nodes of the graph"""
node_credentials: dict[str, Credentials] = {}
def _validate_node_input_credentials(self, graph: Graph, user_id: str):
"""Checks all credentials for all nodes of the graph"""
for node in graph.nodes:
block = get_block(node.block_id)
@@ -894,9 +852,25 @@ class ExecutionManager(AppService):
f"Invalid credentials #{credentials.id} for node #{node.id}: "
"type/provider mismatch"
)
node_credentials[node.id] = credentials
return node_credentials
# ------- UTILITIES ------- #
def get_db_client() -> "DatabaseManager":
from backend.executor import DatabaseManager
return get_service_client(DatabaseManager, settings.config.database_api_port)
@contextmanager
def synchronized(key: str, timeout: int = 60):
lock: RedisLock = redis.get_redis().lock(f"lock:{key}", timeout=timeout)
try:
lock.acquire()
yield
finally:
lock.release()
def llprint(message: str):

View File

@@ -5,9 +5,16 @@ from datetime import datetime
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from backend.data import schedule as model
from backend.data.block import BlockInput
from backend.data.schedule import (
ExecutionSchedule,
add_schedule,
get_active_schedules,
get_schedules,
update_schedule,
)
from backend.executor.manager import ExecutionManager
from backend.util.cache import thread_cached_property
from backend.util.service import AppService, expose, get_service_client
from backend.util.settings import Config
@@ -19,14 +26,15 @@ def log(msg, **kwargs):
class ExecutionScheduler(AppService):
def __init__(self, refresh_interval=10):
super().__init__(port=Config().execution_scheduler_port)
self.use_db = True
self.last_check = datetime.min
self.refresh_interval = refresh_interval
@property
def execution_manager_client(self) -> ExecutionManager:
@thread_cached_property
def execution_client(self) -> ExecutionManager:
return get_service_client(ExecutionManager, Config().execution_manager_port)
def run_service(self):
@@ -37,7 +45,7 @@ class ExecutionScheduler(AppService):
time.sleep(self.refresh_interval)
def __refresh_jobs_from_db(self, scheduler: BackgroundScheduler):
schedules = self.run_and_wait(model.get_active_schedules(self.last_check))
schedules = self.run_and_wait(get_active_schedules(self.last_check))
for schedule in schedules:
if schedule.last_updated:
self.last_check = max(self.last_check, schedule.last_updated)
@@ -59,14 +67,13 @@ class ExecutionScheduler(AppService):
def __execute_graph(self, graph_id: str, input_data: dict, user_id: str):
try:
log(f"Executing recurring job for graph #{graph_id}")
execution_manager = self.execution_manager_client
execution_manager.add_execution(graph_id, input_data, user_id)
self.execution_client.add_execution(graph_id, input_data, user_id)
except Exception as e:
logger.exception(f"Error executing graph {graph_id}: {e}")
@expose
def update_schedule(self, schedule_id: str, is_enabled: bool, user_id: str) -> str:
self.run_and_wait(model.update_schedule(schedule_id, is_enabled, user_id))
self.run_and_wait(update_schedule(schedule_id, is_enabled, user_id))
return schedule_id
@expose
@@ -78,17 +85,16 @@ class ExecutionScheduler(AppService):
input_data: BlockInput,
user_id: str,
) -> str:
schedule = model.ExecutionSchedule(
schedule = ExecutionSchedule(
graph_id=graph_id,
user_id=user_id,
graph_version=graph_version,
schedule=cron,
input_data=input_data,
)
return self.run_and_wait(model.add_schedule(schedule)).id
return self.run_and_wait(add_schedule(schedule)).id
@expose
def get_execution_schedules(self, graph_id: str, user_id: str) -> dict[str, str]:
query = model.get_schedules(graph_id, user_id=user_id)
schedules: list[model.ExecutionSchedule] = self.run_and_wait(query)
schedules = self.run_and_wait(get_schedules(graph_id, user_id=user_id))
return {v.id: v.schedule for v in schedules}

View File

@@ -0,0 +1,172 @@
import logging
from contextlib import contextmanager
from datetime import datetime
from autogpt_libs.supabase_integration_credentials_store import (
Credentials,
SupabaseIntegrationCredentialsStore,
)
from autogpt_libs.utils.synchronize import RedisKeyedMutex
from redis.lock import Lock as RedisLock
from backend.data import redis
from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
from backend.util.settings import Settings
from ..server.integrations.utils import get_supabase
logger = logging.getLogger(__name__)
settings = Settings()
class IntegrationCredentialsManager:
"""
Handles the lifecycle of integration credentials.
- Automatically refreshes requested credentials if needed.
- Uses locking mechanisms to ensure system-wide consistency and
prevent invalidation of in-use tokens.
### ⚠️ Gotcha
With `acquire(..)`, credentials can only be in use in one place at a time (e.g. one
block execution).
### Locking mechanism
- Because *getting* credentials can result in a refresh (= *invalidation* +
*replacement*) of the stored credentials, *getting* is an operation that
potentially requires read/write access.
- Checking whether a token has to be refreshed is subject to an additional `refresh`
scoped lock to prevent unnecessary sequential refreshes when multiple executions
try to access the same credentials simultaneously.
- We MUST lock credentials while in use to prevent them from being invalidated while
they are in use, e.g. because they are being refreshed by a different part
of the system.
- The `!time_sensitive` lock in `acquire(..)` is part of a two-tier locking
mechanism in which *updating* gets priority over *getting* credentials.
This is to prevent a long queue of waiting *get* requests from blocking essential
credential refreshes or user-initiated updates.
It is possible to implement a reader/writer locking system where either multiple
readers or a single writer can have simultaneous access, but this would add a lot of
complexity to the mechanism. I don't expect the current ("simple") mechanism to
cause so much latency that it's worth implementing.
"""
def __init__(self):
redis_conn = redis.get_redis()
self._locks = RedisKeyedMutex(redis_conn)
self.store = SupabaseIntegrationCredentialsStore(get_supabase(), redis_conn)
def create(self, user_id: str, credentials: Credentials) -> None:
return self.store.add_creds(user_id, credentials)
def exists(self, user_id: str, credentials_id: str) -> bool:
return self.store.get_creds_by_id(user_id, credentials_id) is not None
def get(
self, user_id: str, credentials_id: str, lock: bool = True
) -> Credentials | None:
credentials = self.store.get_creds_by_id(user_id, credentials_id)
if not credentials:
return None
# Refresh OAuth credentials if needed
if credentials.type == "oauth2" and credentials.access_token_expires_at:
logger.debug(
f"Credentials #{credentials.id} expire at "
f"{datetime.fromtimestamp(credentials.access_token_expires_at)}; "
f"current time is {datetime.now()}"
)
with self._locked(user_id, credentials_id, "refresh"):
oauth_handler = _get_provider_oauth_handler(credentials.provider)
if oauth_handler.needs_refresh(credentials):
logger.debug(
f"Refreshing '{credentials.provider}' "
f"credentials #{credentials.id}"
)
_lock = None
if lock:
# Wait until the credentials are no longer in use anywhere
_lock = self._acquire_lock(user_id, credentials_id)
fresh_credentials = oauth_handler.refresh_tokens(credentials)
self.store.update_creds(user_id, fresh_credentials)
if _lock:
_lock.release()
credentials = fresh_credentials
else:
logger.debug(f"Credentials #{credentials.id} never expire")
return credentials
def acquire(
self, user_id: str, credentials_id: str
) -> tuple[Credentials, RedisLock]:
"""
⚠️ WARNING: this locks credentials system-wide and blocks both acquiring
and updating them elsewhere until the lock is released.
See the class docstring for more info.
"""
# Use a low-priority (!time_sensitive) locking queue on top of the general lock
# to allow priority access for refreshing/updating the tokens.
with self._locked(user_id, credentials_id, "!time_sensitive"):
lock = self._acquire_lock(user_id, credentials_id)
credentials = self.get(user_id, credentials_id, lock=False)
if not credentials:
raise ValueError(
f"Credentials #{credentials_id} for user #{user_id} not found"
)
return credentials, lock
def update(self, user_id: str, updated: Credentials) -> None:
with self._locked(user_id, updated.id):
self.store.update_creds(user_id, updated)
def delete(self, user_id: str, credentials_id: str) -> None:
with self._locked(user_id, credentials_id):
self.store.delete_creds_by_id(user_id, credentials_id)
# -- Locking utilities -- #
def _acquire_lock(self, user_id: str, credentials_id: str, *args: str) -> RedisLock:
key = (
self.store.supabase.supabase_url,
f"user:{user_id}",
f"credentials:{credentials_id}",
*args,
)
return self._locks.acquire(key)
@contextmanager
def _locked(self, user_id: str, credentials_id: str, *args: str):
lock = self._acquire_lock(user_id, credentials_id, *args)
try:
yield
finally:
lock.release()
def release_all_locks(self):
"""Call this on process termination to ensure all locks are released"""
self._locks.release_all_locks()
self.store.locks.release_all_locks()
def _get_provider_oauth_handler(provider_name: str) -> BaseOAuthHandler:
if provider_name not in HANDLERS_BY_NAME:
raise KeyError(f"Unknown provider '{provider_name}'")
client_id = getattr(settings.secrets, f"{provider_name}_client_id")
client_secret = getattr(settings.secrets, f"{provider_name}_client_secret")
if not (client_id and client_secret):
raise Exception( # TODO: ConfigError
f"Integration with provider '{provider_name}' is not configured",
)
handler_class = HANDLERS_BY_NAME[provider_name]
frontend_base_url = settings.config.frontend_base_url
return handler_class(
client_id=client_id,
client_secret=client_secret,
redirect_uri=f"{frontend_base_url}/auth/integrations/oauth_callback",
)

View File

@@ -3,6 +3,7 @@ from .github import GitHubOAuthHandler
from .google import GoogleOAuthHandler
from .notion import NotionOAuthHandler
# --8<-- [start:HANDLERS_BY_NAMEExample]
HANDLERS_BY_NAME: dict[str, type[BaseOAuthHandler]] = {
handler.PROVIDER_NAME: handler
for handler in [
@@ -11,5 +12,6 @@ HANDLERS_BY_NAME: dict[str, type[BaseOAuthHandler]] = {
NotionOAuthHandler,
]
}
# --8<-- [end:HANDLERS_BY_NAMEExample]
__all__ = ["HANDLERS_BY_NAME"]

View File

@@ -9,29 +9,48 @@ logger = logging.getLogger(__name__)
class BaseOAuthHandler(ABC):
# --8<-- [start:BaseOAuthHandler1]
PROVIDER_NAME: ClassVar[str]
DEFAULT_SCOPES: ClassVar[list[str]] = []
# --8<-- [end:BaseOAuthHandler1]
@abstractmethod
# --8<-- [start:BaseOAuthHandler2]
def __init__(self, client_id: str, client_secret: str, redirect_uri: str): ...
# --8<-- [end:BaseOAuthHandler2]
@abstractmethod
# --8<-- [start:BaseOAuthHandler3]
def get_login_url(self, scopes: list[str], state: str) -> str:
# --8<-- [end:BaseOAuthHandler3]
"""Constructs a login URL that the user can be redirected to"""
...
@abstractmethod
# --8<-- [start:BaseOAuthHandler4]
def exchange_code_for_tokens(
self, code: str, scopes: list[str]
) -> OAuth2Credentials:
# --8<-- [end:BaseOAuthHandler4]
"""Exchanges the acquired authorization code from login for a set of tokens"""
...
@abstractmethod
# --8<-- [start:BaseOAuthHandler5]
def _refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
# --8<-- [end:BaseOAuthHandler5]
"""Implements the token refresh mechanism"""
...
@abstractmethod
# --8<-- [start:BaseOAuthHandler6]
def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
# --8<-- [end:BaseOAuthHandler6]
"""Revokes the given token at provider,
returns False provider does not support it"""
...
def refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
if credentials.provider != self.PROVIDER_NAME:
raise ValueError(

View File

@@ -8,6 +8,7 @@ from autogpt_libs.supabase_integration_credentials_store import OAuth2Credential
from .base import BaseOAuthHandler
# --8<-- [start:GithubOAuthHandlerExample]
class GitHubOAuthHandler(BaseOAuthHandler):
"""
Based on the documentation at:
@@ -23,7 +24,6 @@ class GitHubOAuthHandler(BaseOAuthHandler):
""" # noqa
PROVIDER_NAME = "github"
EMAIL_ENDPOINT = "https://api.github.com/user/emails"
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
self.client_id = client_id
@@ -31,6 +31,7 @@ class GitHubOAuthHandler(BaseOAuthHandler):
self.redirect_uri = redirect_uri
self.auth_base_url = "https://github.com/login/oauth/authorize"
self.token_url = "https://github.com/login/oauth/access_token"
self.revoke_url = "https://api.github.com/applications/{client_id}/token"
def get_login_url(self, scopes: list[str], state: str) -> str:
params = {
@@ -46,6 +47,24 @@ class GitHubOAuthHandler(BaseOAuthHandler):
) -> OAuth2Credentials:
return self._request_tokens({"code": code, "redirect_uri": self.redirect_uri})
def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
if not credentials.access_token:
raise ValueError("No access token to revoke")
headers = {
"Accept": "application/vnd.github+json",
"X-GitHub-Api-Version": "2022-11-28",
}
response = requests.delete(
url=self.revoke_url.format(client_id=self.client_id),
auth=(self.client_id, self.client_secret),
headers=headers,
json={"access_token": credentials.access_token.get_secret_value()},
)
response.raise_for_status()
return True
def _refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
if not credentials.refresh_token:
return credentials
@@ -119,3 +138,6 @@ class GitHubOAuthHandler(BaseOAuthHandler):
# Get the login (username)
return response.json().get("login")
# --8<-- [end:GithubOAuthHandlerExample]

View File

@@ -14,6 +14,7 @@ from .base import BaseOAuthHandler
logger = logging.getLogger(__name__)
# --8<-- [start:GoogleOAuthHandlerExample]
class GoogleOAuthHandler(BaseOAuthHandler):
"""
Based on the documentation at https://developers.google.com/identity/protocols/oauth2/web-server
@@ -26,12 +27,14 @@ class GoogleOAuthHandler(BaseOAuthHandler):
"https://www.googleapis.com/auth/userinfo.profile",
"openid",
]
# --8<-- [end:GoogleOAuthHandlerExample]
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = redirect_uri
self.token_uri = "https://oauth2.googleapis.com/token"
self.revoke_uri = "https://oauth2.googleapis.com/revoke"
def get_login_url(self, scopes: list[str], state: str) -> str:
all_scopes = list(set(scopes + self.DEFAULT_SCOPES))
@@ -98,6 +101,16 @@ class GoogleOAuthHandler(BaseOAuthHandler):
return credentials
def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
session = AuthorizedSession(credentials)
response = session.post(
self.revoke_uri,
params={"token": credentials.access_token.get_secret_value()},
headers={"content-type": "application/x-www-form-urlencoded"},
)
response.raise_for_status()
return True
def _request_email(
self, creds: Credentials | ExternalAccountCredentials
) -> str | None:

View File

@@ -77,6 +77,10 @@ class NotionOAuthHandler(BaseOAuthHandler):
},
)
def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
# Notion doesn't support token revocation
return False
def _refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
# Notion doesn't support token refresh
return credentials

View File

@@ -1,6 +1,6 @@
from backend.app import run_processes
from backend.executor import ExecutionScheduler
from backend.server import AgentServer
from backend.server.rest_api import AgentServer
def main():

View File

@@ -1,4 +0,0 @@
from .rest_api import AgentServer
from .ws_api import WebsocketServer
__all__ = ["AgentServer", "WebsocketServer"]

View File

@@ -1,40 +1,25 @@
import logging
from typing import Annotated
from typing import Annotated, Literal
from autogpt_libs.supabase_integration_credentials_store import (
SupabaseIntegrationCredentialsStore,
)
from autogpt_libs.supabase_integration_credentials_store.types import (
APIKeyCredentials,
Credentials,
CredentialsType,
OAuth2Credentials,
)
from fastapi import (
APIRouter,
Body,
Depends,
HTTPException,
Path,
Query,
Request,
Response,
)
from pydantic import BaseModel, SecretStr
from supabase import Client
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
from pydantic import BaseModel, Field, SecretStr
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
from backend.util.settings import Settings
from ..utils import get_supabase, get_user_id
from ..utils import get_user_id
logger = logging.getLogger(__name__)
settings = Settings()
router = APIRouter()
def get_store(supabase: Client = Depends(get_supabase)):
return SupabaseIntegrationCredentialsStore(supabase)
creds_manager = IntegrationCredentialsManager()
class LoginResponse(BaseModel):
@@ -47,7 +32,6 @@ async def login(
provider: Annotated[str, Path(title="The provider to initiate an OAuth flow for")],
user_id: Annotated[str, Depends(get_user_id)],
request: Request,
store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
scopes: Annotated[
str, Query(title="Comma-separated list of authorization scopes")
] = "",
@@ -57,7 +41,9 @@ async def login(
requested_scopes = scopes.split(",") if scopes else []
# Generate and store a secure random state token along with the scopes
state_token = await store.store_state_token(user_id, provider, requested_scopes)
state_token = await creds_manager.store.store_state_token(
user_id, provider, requested_scopes
)
login_url = handler.get_login_url(requested_scopes, state_token)
@@ -77,7 +63,6 @@ async def callback(
provider: Annotated[str, Path(title="The target provider for this OAuth exchange")],
code: Annotated[str, Body(title="Authorization code acquired by user login")],
state_token: Annotated[str, Body(title="Anti-CSRF nonce")],
store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
user_id: Annotated[str, Depends(get_user_id)],
request: Request,
) -> CredentialsMetaResponse:
@@ -85,12 +70,12 @@ async def callback(
handler = _get_provider_oauth_handler(request, provider)
# Verify the state token
if not await store.verify_state_token(user_id, state_token, provider):
if not await creds_manager.store.verify_state_token(user_id, state_token, provider):
logger.warning(f"Invalid or expired state token for user {user_id}")
raise HTTPException(status_code=400, detail="Invalid or expired state token")
try:
scopes = await store.get_any_valid_scopes_from_state_token(
scopes = await creds_manager.store.get_any_valid_scopes_from_state_token(
user_id, state_token, provider
)
logger.debug(f"Retrieved scopes from state token: {scopes}")
@@ -114,7 +99,7 @@ async def callback(
)
# TODO: Allow specifying `title` to set on `credentials`
store.add_creds(user_id, credentials)
creds_manager.create(user_id, credentials)
logger.debug(
f"Successfully processed OAuth callback for user {user_id} and provider {provider}"
@@ -132,9 +117,8 @@ async def callback(
async def list_credentials(
provider: Annotated[str, Path(title="The provider to list credentials for")],
user_id: Annotated[str, Depends(get_user_id)],
store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
) -> list[CredentialsMetaResponse]:
credentials = store.get_creds_by_provider(user_id, provider)
credentials = creds_manager.store.get_creds_by_provider(user_id, provider)
return [
CredentialsMetaResponse(
id=cred.id,
@@ -152,9 +136,8 @@ async def get_credential(
provider: Annotated[str, Path(title="The provider to retrieve credentials for")],
cred_id: Annotated[str, Path(title="The ID of the credentials to retrieve")],
user_id: Annotated[str, Depends(get_user_id)],
store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
) -> Credentials:
credential = store.get_creds_by_id(user_id, cred_id)
credential = creds_manager.get(user_id, cred_id)
if not credential:
raise HTTPException(status_code=404, detail="Credentials not found")
if credential.provider != provider:
@@ -166,7 +149,6 @@ async def get_credential(
@router.post("/{provider}/credentials", status_code=201)
async def create_api_key_credentials(
store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
user_id: Annotated[str, Depends(get_user_id)],
provider: Annotated[str, Path(title="The provider to create credentials for")],
api_key: Annotated[str, Body(title="The API key to store")],
@@ -183,7 +165,7 @@ async def create_api_key_credentials(
)
try:
store.add_creds(user_id, new_credentials)
creds_manager.create(user_id, new_credentials)
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Failed to store credentials: {str(e)}"
@@ -191,14 +173,23 @@ async def create_api_key_credentials(
return new_credentials
@router.delete("/{provider}/credentials/{cred_id}", status_code=204)
async def delete_credential(
class CredentialsDeletionResponse(BaseModel):
deleted: Literal[True] = True
revoked: bool | None = Field(
description="Indicates whether the credentials were also revoked by their "
"provider. `None`/`null` if not applicable, e.g. when deleting "
"non-revocable credentials such as API keys."
)
@router.delete("/{provider}/credentials/{cred_id}")
async def delete_credentials(
request: Request,
provider: Annotated[str, Path(title="The provider to delete credentials for")],
cred_id: Annotated[str, Path(title="The ID of the credentials to delete")],
user_id: Annotated[str, Depends(get_user_id)],
store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
):
creds = store.get_creds_by_id(user_id, cred_id)
) -> CredentialsDeletionResponse:
creds = creds_manager.store.get_creds_by_id(user_id, cred_id)
if not creds:
raise HTTPException(status_code=404, detail="Credentials not found")
if creds.provider != provider:
@@ -206,8 +197,14 @@ async def delete_credential(
status_code=404, detail="Credentials do not match the specified provider"
)
store.delete_creds_by_id(user_id, cred_id)
return Response(status_code=204)
creds_manager.delete(user_id, cred_id)
tokens_revoked = None
if isinstance(creds, OAuth2Credentials):
handler = _get_provider_oauth_handler(request, provider)
tokens_revoked = handler.revoke_tokens(creds)
return CredentialsDeletionResponse(revoked=tokens_revoked)
# -------- UTILITIES --------- #

View File

@@ -0,0 +1,11 @@
from supabase import Client, create_client
from backend.util.settings import Settings
settings = Settings()
def get_supabase() -> Client:
return create_client(
settings.secrets.supabase_url, settings.secrets.supabase_service_role_key
)

View File

@@ -10,19 +10,20 @@ from autogpt_libs.auth.middleware import auth_middleware
from fastapi import APIRouter, Body, Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from typing_extensions import TypedDict
from backend.data import block, db
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data import user as user_db
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.credit import get_block_costs, get_user_credit_model
from backend.data.queue import RedisEventQueue
from backend.data.user import get_or_create_user
from backend.executor import ExecutionManager, ExecutionScheduler
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.server.model import CreateGraph, SetGraphActiveVersion
from backend.util.service import AppService, expose, get_service_client
from backend.util.settings import Config, Settings
from backend.util.cache import thread_cached_property
from backend.util.service import AppService, get_service_client
from backend.util.settings import AppEnvironment, Config, Settings
from .utils import get_user_id
@@ -31,26 +32,22 @@ logger = logging.getLogger(__name__)
class AgentServer(AppService):
use_queue = True
_test_dependency_overrides = {}
_user_credit_model = get_user_credit_model()
def __init__(self):
super().__init__(port=Config().agent_server_port)
self.event_queue = RedisEventQueue()
self.use_redis = True
@asynccontextmanager
async def lifespan(self, _: FastAPI):
await db.connect()
self.event_queue.connect()
await block.initialize_blocks()
if await user_db.create_default_user(settings.config.enable_auth):
await graph_db.import_packaged_templates()
yield
self.event_queue.close()
await db.disconnect()
def run_service(self):
docs_url = "/docs" if settings.config.app_env == AppEnvironment.LOCAL else None
app = FastAPI(
title="AutoGPT Agent Server",
description=(
@@ -60,6 +57,7 @@ class AgentServer(AppService):
summary="AutoGPT Agent Server",
version="0.1",
lifespan=self.lifespan,
docs_url=docs_url,
)
if self._test_dependency_overrides:
@@ -77,20 +75,29 @@ class AgentServer(AppService):
allow_headers=["*"], # Allows all headers
)
health_router = APIRouter()
health_router.add_api_route(
path="/health",
endpoint=self.health,
methods=["GET"],
tags=["health"],
)
# Define the API routes
api_router = APIRouter(prefix="/api")
api_router.dependencies.append(Depends(auth_middleware))
# Import & Attach sub-routers
import backend.server.integrations.router
import backend.server.routers.analytics
import backend.server.routers.integrations
api_router.include_router(
backend.server.routers.integrations.router,
backend.server.integrations.router.router,
prefix="/integrations",
tags=["integrations"],
dependencies=[Depends(auth_middleware)],
)
self.integration_creds_manager = IntegrationCredentialsManager()
api_router.include_router(
backend.server.routers.analytics.router,
@@ -166,6 +173,12 @@ class AgentServer(AppService):
methods=["PUT"],
tags=["templates", "graphs"],
)
api_router.add_api_route(
path="/graphs/{graph_id}",
endpoint=self.delete_graph,
methods=["DELETE"],
tags=["graphs"],
)
api_router.add_api_route(
path="/graphs/{graph_id}/versions",
endpoint=self.get_graph_all_versions,
@@ -254,6 +267,7 @@ class AgentServer(AppService):
app.add_exception_handler(500, self.handle_internal_http_error)
app.include_router(api_router)
app.include_router(health_router)
uvicorn.run(
app,
@@ -291,11 +305,11 @@ class AgentServer(AppService):
return wrapper
@property
@thread_cached_property
def execution_manager_client(self) -> ExecutionManager:
return get_service_client(ExecutionManager, Config().execution_manager_port)
@property
@thread_cached_property
def execution_scheduler_client(self) -> ExecutionScheduler:
return get_service_client(ExecutionScheduler, Config().execution_scheduler_port)
@@ -355,8 +369,11 @@ class AgentServer(AppService):
graph_id: str,
user_id: Annotated[str, Depends(get_user_id)],
version: int | None = None,
hide_credentials: bool = False,
) -> graph_db.Graph:
graph = await graph_db.get_graph(graph_id, version, user_id=user_id)
graph = await graph_db.get_graph(
graph_id, version, user_id=user_id, hide_credentials=hide_credentials
)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
return graph
@@ -393,6 +410,17 @@ class AgentServer(AppService):
) -> graph_db.Graph:
return await cls.create_graph(create_graph, is_template=True, user_id=user_id)
class DeleteGraphResponse(TypedDict):
version_counts: int
@classmethod
async def delete_graph(
cls, graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> DeleteGraphResponse:
return {
"version_counts": await graph_db.delete_graph(graph_id, user_id=user_id)
}
@classmethod
async def create_graph(
cls,
@@ -613,10 +641,8 @@ class AgentServer(AppService):
execution_scheduler = self.execution_scheduler_client
return execution_scheduler.get_execution_schedules(graph_id, user_id)
@expose
def send_execution_update(self, execution_result_dict: dict[Any, Any]):
execution_result = execution_db.ExecutionResult(**execution_result_dict)
self.event_queue.put(execution_result)
async def health(self):
return {"status": "healthy"}
@classmethod
def update_configuration(

View File

@@ -1,6 +1,5 @@
from autogpt_libs.auth.middleware import auth_middleware
from fastapi import Depends, HTTPException
from supabase import Client, create_client
from backend.data.user import DEFAULT_USER_ID
from backend.util.settings import Settings
@@ -17,9 +16,3 @@ def get_user_id(payload: dict = Depends(auth_middleware)) -> str:
if not user_id:
raise HTTPException(status_code=401, detail="User ID not found in token")
return user_id
def get_supabase() -> Client:
return create_client(
settings.secrets.supabase_url, settings.secrets.supabase_service_role_key
)

View File

@@ -7,12 +7,13 @@ from autogpt_libs.auth import parse_jwt_token
from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from backend.data import redis
from backend.data.queue import RedisEventQueue
from backend.data.user import DEFAULT_USER_ID
from backend.server.conn_manager import ConnectionManager
from backend.server.model import ExecutionSubscription, Methods, WsMessage
from backend.util.service import AppProcess
from backend.util.settings import Config, Settings
from backend.util.settings import AppEnvironment, Config, Settings
logger = logging.getLogger(__name__)
settings = Settings()
@@ -20,16 +21,14 @@ settings = Settings()
@asynccontextmanager
async def lifespan(app: FastAPI):
event_queue.connect()
manager = get_connection_manager()
fut = asyncio.create_task(event_broadcaster(manager))
fut.add_done_callback(lambda _: logger.info("Event broadcaster stopped"))
yield
event_queue.close()
docs_url = "/docs" if settings.config.app_env == AppEnvironment.LOCAL else None
app = FastAPI(lifespan=lifespan)
event_queue = RedisEventQueue()
_connection_manager = None
logger.info(f"CORS allow origins: {settings.config.backend_cors_allow_origins}")
@@ -50,12 +49,20 @@ def get_connection_manager():
async def event_broadcaster(manager: ConnectionManager):
while True:
event = event_queue.get()
if event is not None:
await manager.send_execution_result(event)
else:
await asyncio.sleep(0.1)
try:
redis.connect()
event_queue = RedisEventQueue()
while True:
event = event_queue.get()
if event:
await manager.send_execution_result(event)
else:
await asyncio.sleep(0.1)
except Exception as e:
logger.exception(f"Event broadcaster error: {e}")
raise
finally:
redis.disconnect()
async def authenticate_websocket(websocket: WebSocket) -> str:

View File

@@ -0,0 +1,21 @@
import threading
from functools import wraps
from typing import Callable, TypeVar
T = TypeVar("T")
R = TypeVar("R")
def thread_cached_property(func: Callable[[T], R]) -> property:
local_cache = threading.local()
@wraps(func)
def wrapper(self: T) -> R:
if not hasattr(local_cache, "cache"):
local_cache.cache = {}
key = id(self)
if key not in local_cache.cache:
local_cache.cache[key] = func(self)
return local_cache.cache[key]
return property(wrapper)

View File

@@ -1,4 +1,6 @@
import os
from backend.util.settings import AppEnvironment, BehaveAs, Settings
settings = Settings()
def configure_logging():
@@ -6,7 +8,10 @@ def configure_logging():
import autogpt_libs.logging.config
if os.getenv("APP_ENV") != "cloud":
if (
settings.config.behave_as == BehaveAs.LOCAL
or settings.config.app_env == AppEnvironment.LOCAL
):
autogpt_libs.logging.config.configure_logging(force_cloud_logging=False)
else:
autogpt_libs.logging.config.configure_logging(force_cloud_logging=True)

View File

@@ -17,6 +17,11 @@ def get_service_name():
return _SERVICE_NAME
def set_service_name(name: str):
global _SERVICE_NAME
_SERVICE_NAME = name
class AppProcess(ABC):
"""
A class to represent an object that can be executed in a background process.
@@ -63,9 +68,7 @@ class AppProcess(ABC):
sys.stdout = open(os.devnull, "w")
sys.stderr = open(os.devnull, "w")
global _SERVICE_NAME
_SERVICE_NAME = self.service_name
set_service_name(self.service_name)
logger.info(f"[{self.service_name}] Starting...")
self.run()
except (KeyboardInterrupt, SystemExit) as e:

View File

@@ -1,5 +1,6 @@
import logging
import os
from functools import wraps
from uuid import uuid4
from tenacity import retry, stop_after_attempt, wait_exponential
@@ -21,28 +22,33 @@ def _log_prefix(resource_name: str, conn_id: str):
def conn_retry(resource_name: str, action_name: str, max_retry: int = 5):
conn_id = str(uuid4())
def before_call(retry_state):
prefix = _log_prefix(resource_name, conn_id)
logger.info(f"{prefix} {action_name} started...")
def after_call(retry_state):
prefix = _log_prefix(resource_name, conn_id)
if retry_state.outcome.failed:
# Optionally, you can log something here if needed
pass
else:
logger.info(f"{prefix} {action_name} completed!")
def on_retry(retry_state):
prefix = _log_prefix(resource_name, conn_id)
exception = retry_state.outcome.exception()
logger.info(f"{prefix} {action_name} failed: {exception}. Retrying now...")
return retry(
stop=stop_after_attempt(max_retry + 1),
wait=wait_exponential(multiplier=1, min=1, max=30),
before=before_call,
after=after_call,
before_sleep=on_retry,
reraise=True,
)
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
prefix = _log_prefix(resource_name, conn_id)
logger.info(f"{prefix} {action_name} started...")
# Define the retrying strategy
retrying_func = retry(
stop=stop_after_attempt(max_retry + 1),
wait=wait_exponential(multiplier=1, min=1, max=30),
before_sleep=on_retry,
reraise=True,
)(func)
try:
result = retrying_func(*args, **kwargs)
logger.info(f"{prefix} {action_name} completed successfully.")
return result
except Exception as e:
logger.error(f"{prefix} {action_name} failed after retries: {e}")
raise
return wrapper
return decorator

View File

@@ -1,16 +1,36 @@
import asyncio
import builtins
import logging
import os
import threading
import time
from abc import abstractmethod
from typing import Any, Callable, Coroutine, Type, TypeVar, cast
import typing
from enum import Enum
from types import NoneType, UnionType
from typing import (
Annotated,
Any,
Callable,
Coroutine,
Dict,
FrozenSet,
Iterator,
List,
Set,
Tuple,
Type,
TypeVar,
Union,
cast,
get_args,
get_origin,
)
import Pyro5.api
from pydantic import BaseModel
from Pyro5 import api as pyro
from backend.data import db
from backend.data.queue import AbstractEventQueue, RedisEventQueue
from backend.data import db, redis
from backend.util.process import AppProcess
from backend.util.retry import conn_retry
from backend.util.settings import Config, Secrets
@@ -27,9 +47,8 @@ def expose(func: C) -> C:
Decorator to mark a method or class to be exposed for remote calls.
## ⚠️ Gotcha
The types on the exposed function signature are respected **as long as they are
fully picklable**. This is not the case for Pydantic models, so if you really need
to pass a model, try dumping the model and passing the resulting dict instead.
Aside from "simple" types, only Pydantic models are passed unscathed *if annotated*.
Any other passed or returned class objects are converted to dictionaries by Pyro.
"""
def wrapper(*args, **kwargs):
@@ -38,24 +57,59 @@ def expose(func: C) -> C:
except Exception as e:
msg = f"Error in {func.__name__}: {e.__str__()}"
logger.exception(msg)
raise Exception(msg, e)
raise
# Register custom serializers and deserializers for annotated Pydantic models
for name, annotation in func.__annotations__.items():
try:
pydantic_types = _pydantic_models_from_type_annotation(annotation)
except Exception as e:
raise TypeError(f"Error while exposing {func.__name__}: {e.__str__()}")
for model in pydantic_types:
logger.debug(
f"Registering Pyro (de)serializers for {func.__name__} annotation "
f"'{name}': {model.__qualname__}"
)
pyro.register_class_to_dict(model, _make_custom_serializer(model))
pyro.register_dict_to_class(
model.__qualname__, _make_custom_deserializer(model)
)
return pyro.expose(wrapper) # type: ignore
def _make_custom_serializer(model: Type[BaseModel]):
def custom_class_to_dict(obj):
data = {
"__class__": obj.__class__.__qualname__,
**obj.model_dump(),
}
logger.debug(f"Serializing {obj.__class__.__qualname__} with data: {data}")
return data
return custom_class_to_dict
def _make_custom_deserializer(model: Type[BaseModel]):
def custom_dict_to_class(qualname, data: dict):
logger.debug(f"Deserializing {model.__qualname__} from data: {data}")
return model(**data)
return custom_dict_to_class
class AppService(AppProcess):
shared_event_loop: asyncio.AbstractEventLoop
event_queue: AbstractEventQueue = RedisEventQueue()
use_db: bool = False
use_queue: bool = False
use_redis: bool = False
use_supabase: bool = False
def __init__(self, port):
self.port = port
self.uri = None
@abstractmethod
def run_service(self):
def run_service(self) -> None:
while True:
time.sleep(10)
@@ -70,8 +124,8 @@ class AppService(AppProcess):
self.shared_event_loop = asyncio.get_event_loop()
if self.use_db:
self.shared_event_loop.run_until_complete(db.connect())
if self.use_queue:
self.event_queue.connect()
if self.use_redis:
redis.connect()
if self.use_supabase:
from supabase import create_client
@@ -97,9 +151,9 @@ class AppService(AppProcess):
if self.use_db:
logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting DB...")
self.run_and_wait(db.disconnect())
if self.use_queue:
if self.use_redis:
logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting Redis...")
self.event_queue.close()
redis.disconnect()
@conn_retry("Pyro", "Starting Pyro Service")
def __start_pyro(self):
@@ -131,6 +185,53 @@ def get_service_client(service_type: Type[AS], port: int) -> AS:
logger.debug(f"Successfully connected to service [{service_name}]")
def __getattr__(self, name: str) -> Callable[..., Any]:
return getattr(self.proxy, name)
res = getattr(self.proxy, name)
return res
return cast(AS, DynamicClient())
# --------- UTILITIES --------- #
builtin_types = [*vars(builtins).values(), NoneType, Enum]
def _pydantic_models_from_type_annotation(annotation) -> Iterator[type[BaseModel]]:
# Peel Annotated parameters
if (origin := get_origin(annotation)) and origin is Annotated:
annotation = get_args(annotation)[0]
origin = get_origin(annotation)
args = get_args(annotation)
if origin in (
Union,
UnionType,
list,
List,
tuple,
Tuple,
set,
Set,
frozenset,
FrozenSet,
):
for arg in args:
yield from _pydantic_models_from_type_annotation(arg)
elif origin in (dict, Dict):
key_type, value_type = args
yield from _pydantic_models_from_type_annotation(key_type)
yield from _pydantic_models_from_type_annotation(value_type)
else:
annotype = annotation if origin is None else origin
# Exclude generic types and aliases
if (
annotype is not None
and not hasattr(typing, getattr(annotype, "__name__", ""))
and isinstance(annotype, type)
):
if issubclass(annotype, BaseModel):
yield annotype
elif annotype not in builtin_types and not issubclass(annotype, Enum):
raise TypeError(f"Unsupported type encountered: {annotype}")

View File

@@ -1,5 +1,6 @@
import json
import os
from enum import Enum
from typing import Any, Dict, Generic, List, Set, Tuple, Type, TypeVar
from pydantic import BaseModel, Field, PrivateAttr, field_validator
@@ -15,6 +16,17 @@ from backend.util.data import get_config_path, get_data_path, get_secrets_path
T = TypeVar("T", bound=BaseSettings)
class AppEnvironment(str, Enum):
LOCAL = "local"
DEVELOPMENT = "dev"
PRODUCTION = "prod"
class BehaveAs(str, Enum):
LOCAL = "local"
CLOUD = "cloud"
class UpdateTrackingModel(BaseModel, Generic[T]):
_updated_fields: Set[str] = PrivateAttr(default_factory=set)
@@ -105,6 +117,11 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="The port for agent server daemon to run on",
)
database_api_port: int = Field(
default=8005,
description="The port for database server API to run on",
)
agent_api_host: str = Field(
default="0.0.0.0",
description="The host for agent server API to run on",
@@ -121,6 +138,16 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
"This value is then used to generate redirect URLs for OAuth flows.",
)
app_env: AppEnvironment = Field(
default=AppEnvironment.LOCAL,
description="The name of the app environment: local or dev or prod",
)
behave_as: BehaveAs = Field(
default=BehaveAs.LOCAL,
description="What environment to behave as: local or cloud",
)
backend_cors_allow_origins: List[str] = Field(default_factory=list)
@field_validator("backend_cors_allow_origins")
@@ -177,10 +204,12 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
)
# OAuth server credentials for integrations
# --8<-- [start:OAuthServerCredentialsExample]
github_client_id: str = Field(default="", description="GitHub OAuth client ID")
github_client_secret: str = Field(
default="", description="GitHub OAuth client secret"
)
# --8<-- [end:OAuthServerCredentialsExample]
google_client_id: str = Field(default="", description="Google OAuth client ID")
google_client_secret: str = Field(
default="", description="Google OAuth client secret"

View File

@@ -5,15 +5,15 @@ from backend.data.block import Block, initialize_blocks
from backend.data.execution import ExecutionStatus
from backend.data.model import CREDENTIALS_FIELD_NAME
from backend.data.user import create_default_user
from backend.executor import ExecutionManager, ExecutionScheduler
from backend.server import AgentServer
from backend.server.rest_api import get_user_id
from backend.executor import DatabaseManager, ExecutionManager, ExecutionScheduler
from backend.server.rest_api import AgentServer, get_user_id
log = print
class SpinTestServer:
def __init__(self):
self.db_api = DatabaseManager()
self.exec_manager = ExecutionManager()
self.agent_server = AgentServer()
self.scheduler = ExecutionScheduler()
@@ -24,6 +24,7 @@ class SpinTestServer:
async def __aenter__(self):
self.setup_dependency_overrides()
self.db_api.__enter__()
self.agent_server.__enter__()
self.exec_manager.__enter__()
self.scheduler.__enter__()
@@ -40,6 +41,7 @@ class SpinTestServer:
self.scheduler.__exit__(exc_type, exc_val, exc_tb)
self.exec_manager.__exit__(exc_type, exc_val, exc_tb)
self.agent_server.__exit__(exc_type, exc_val, exc_tb)
self.db_api.__exit__(exc_type, exc_val, exc_tb)
def setup_dependency_overrides(self):
# Override get_user_id for testing

View File

@@ -0,0 +1,5 @@
-- DropForeignKey
ALTER TABLE "AgentGraph" DROP CONSTRAINT "AgentGraph_agentGraphParentId_version_fkey";
-- AddForeignKey
ALTER TABLE "AgentGraph" ADD CONSTRAINT "AgentGraph_agentGraphParentId_version_fkey" FOREIGN KEY ("agentGraphParentId", "version") REFERENCES "AgentGraph"("id", "version") ON DELETE CASCADE ON UPDATE CASCADE;

View File

@@ -293,6 +293,7 @@ develop = true
[package.dependencies]
colorama = "^0.4.6"
expiringdict = "^1.2.2"
google-cloud-logging = "^3.8.0"
pydantic = "^2.8.2"
pydantic-settings = "^2.5.2"
@@ -3667,4 +3668,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "3ab370b624b486517a2fbcdc17fb294fbd76b3ec6659c5b471c57bfd738e7277"
content-hash = "0962d61ced1a8154c64c6bbdb3f72aca558831adfbfda68eb66f39b535466f77"

View File

@@ -16,7 +16,6 @@ autogpt-libs = { path = "../autogpt_libs", develop = true }
click = "^8.1.7"
croniter = "^2.0.5"
discord-py = "^2.4.0"
expiringdict = "^1.2.2"
fastapi = "^0.109.0"
feedparser = "^6.0.11"
flake8 = "^7.0.0"

View File

@@ -53,7 +53,7 @@ model AgentGraph {
// All sub-graphs are defined within this 1-level depth list (even if it's a nested graph).
AgentSubGraphs AgentGraph[] @relation("AgentSubGraph")
agentGraphParentId String?
AgentGraphParent AgentGraph? @relation("AgentSubGraph", fields: [agentGraphParentId, version], references: [id, version])
AgentGraphParent AgentGraph? @relation("AgentSubGraph", fields: [agentGraphParentId, version], references: [id, version], onDelete: Cascade)
@@id(name: "graphVersionId", [id, version])
}
@@ -63,7 +63,7 @@ model AgentNode {
id String @id @default(uuid())
agentBlockId String
AgentBlock AgentBlock @relation(fields: [agentBlockId], references: [id])
AgentBlock AgentBlock @relation(fields: [agentBlockId], references: [id], onUpdate: Cascade)
agentGraphId String
agentGraphVersion Int @default(1)

View File

@@ -7,3 +7,28 @@ from backend.util.test import SpinTestServer
async def server():
async with SpinTestServer() as server:
yield server
@pytest.fixture(scope="session", autouse=True)
async def graph_cleanup(server):
created_graph_ids = []
original_create_graph = server.agent_server.create_graph
async def create_graph_wrapper(*args, **kwargs):
created_graph = await original_create_graph(*args, **kwargs)
# Extract user_id correctly
user_id = kwargs.get("user_id", args[2] if len(args) > 2 else None)
created_graph_ids.append((created_graph.id, user_id))
return created_graph
try:
server.agent_server.create_graph = create_graph_wrapper
yield # This runs the test function
finally:
server.agent_server.create_graph = original_create_graph
# Delete the created graphs and assert they were deleted
for graph_id, user_id in created_graph_ids:
resp = await server.agent_server.delete_graph(graph_id, user_id)
num_deleted = resp["version_counts"]
assert num_deleted > 0, f"Graph {graph_id} was not deleted."

View File

@@ -19,7 +19,7 @@ async def test_block_credit_usage(server: SpinTestServer):
spending_amount_1 = await user_credit.spend_credits(
DEFAULT_USER_ID,
current_credit,
AITextGeneratorBlock(),
AITextGeneratorBlock().id,
{"model": "gpt-4-turbo"},
0.0,
0.0,
@@ -30,7 +30,7 @@ async def test_block_credit_usage(server: SpinTestServer):
spending_amount_2 = await user_credit.spend_credits(
DEFAULT_USER_ID,
current_credit,
AITextGeneratorBlock(),
AITextGeneratorBlock().id,
{"model": "gpt-4-turbo", "api_key": "owned_api_key"},
0.0,
0.0,

View File

@@ -4,11 +4,16 @@ from prisma.models import User
from backend.blocks.basic import FindInDictionaryBlock, StoreValueBlock
from backend.blocks.maths import CalculatorBlock, Operation
from backend.data import execution, graph
from backend.server import AgentServer
from backend.server.model import CreateGraph
from backend.server.rest_api import AgentServer
from backend.usecases.sample import create_test_graph, create_test_user
from backend.util.test import SpinTestServer, wait_execution
async def create_graph(s: SpinTestServer, g: graph.Graph, u: User) -> graph.Graph:
return await s.agent_server.create_graph(CreateGraph(graph=g), False, u.id)
async def execute_graph(
agent_server: AgentServer,
test_graph: graph.Graph,
@@ -99,9 +104,8 @@ async def assert_sample_graph_executions(
@pytest.mark.asyncio(scope="session")
async def test_agent_execution(server: SpinTestServer):
test_graph = create_test_graph()
test_user = await create_test_user()
await graph.create_graph(test_graph, user_id=test_user.id)
test_graph = await create_graph(server, create_test_graph(), test_user)
data = {"input_1": "Hello", "input_2": "World"}
graph_exec_id = await execute_graph(
server.agent_server,
@@ -163,7 +167,7 @@ async def test_input_pin_always_waited(server: SpinTestServer):
links=links,
)
test_user = await create_test_user()
test_graph = await graph.create_graph(test_graph, user_id=test_user.id)
test_graph = await create_graph(server, test_graph, test_user)
graph_exec_id = await execute_graph(
server.agent_server, test_graph, test_user, {}, 3
)
@@ -244,7 +248,7 @@ async def test_static_input_link_on_graph(server: SpinTestServer):
links=links,
)
test_user = await create_test_user()
test_graph = await graph.create_graph(test_graph, user_id=test_user.id)
test_graph = await create_graph(server, test_graph, test_user)
graph_exec_id = await execute_graph(
server.agent_server, test_graph, test_user, {}, 8
)

View File

@@ -1,7 +1,8 @@
import pytest
from backend.data import db, graph
from backend.data import db
from backend.executor import ExecutionScheduler
from backend.server.model import CreateGraph
from backend.usecases.sample import create_test_graph, create_test_user
from backend.util.service import get_service_client
from backend.util.settings import Config
@@ -12,7 +13,11 @@ from backend.util.test import SpinTestServer
async def test_agent_schedule(server: SpinTestServer):
await db.connect()
test_user = await create_test_user()
test_graph = await graph.create_graph(create_test_graph(), user_id=test_user.id)
test_graph = await server.agent_server.create_graph(
create_graph=CreateGraph(graph=create_test_graph()),
is_template=False,
user_id=test_user.id,
)
scheduler = get_service_client(
ExecutionScheduler, Config().execution_scheduler_port

View File

@@ -2,13 +2,12 @@ import pytest
from backend.util.service import AppService, expose, get_service_client
TEST_SERVICE_PORT = 8765
class TestService(AppService):
class ServiceTest(AppService):
def __init__(self):
super().__init__(port=8005)
def run_service(self):
super().run_service()
super().__init__(port=TEST_SERVICE_PORT)
@expose
def add(self, a: int, b: int) -> int:
@@ -28,8 +27,8 @@ class TestService(AppService):
@pytest.mark.asyncio(scope="session")
async def test_service_creation(server):
with TestService():
client = get_service_client(TestService, 8005)
with ServiceTest():
client = get_service_client(ServiceTest, TEST_SERVICE_PORT)
assert client.add(5, 3) == 8
assert client.subtract(10, 4) == 6
assert client.fun_with_async(5, 3) == 8

View File

@@ -103,6 +103,7 @@ services:
- ENABLE_AUTH=true
- PYRO_HOST=0.0.0.0
- AGENTSERVER_HOST=rest_server
- DATABASEMANAGER_HOST=0.0.0.0
ports:
- "8002:8000"
networks:

View File

@@ -1,3 +1,3 @@
{
"extends": "next/core-web-vitals"
"extends": ["next/core-web-vitals", "plugin:storybook/recommended"]
}

View File

@@ -42,3 +42,6 @@ node_modules/
/playwright-report/
/blob-report/
/playwright/.cache/
*storybook.log
storybook-static

View File

@@ -0,0 +1,18 @@
import type { StorybookConfig } from "@storybook/nextjs";
const config: StorybookConfig = {
stories: ["../src/**/*.mdx", "../src/**/*.stories.@(js|jsx|mjs|ts|tsx)"],
addons: [
"@storybook/addon-onboarding",
"@storybook/addon-links",
"@storybook/addon-essentials",
"@chromatic-com/storybook",
"@storybook/addon-interactions",
],
framework: {
name: "@storybook/nextjs",
options: {},
},
staticDirs: ["../public"],
};
export default config;

View File

@@ -0,0 +1,15 @@
import type { Preview } from "@storybook/react";
import "../src/app/globals.css";
const preview: Preview = {
parameters: {
controls: {
matchers: {
color: /(background|color)$/i,
date: /Date$/i,
},
},
},
};
export default preview;

View File

@@ -14,7 +14,7 @@ CMD ["yarn", "run", "dev"]
# Build stage for prod
FROM base AS build
COPY autogpt_platform/frontend/ .
RUN npm run build
RUN yarn build
# Prod stage
FROM node:21-alpine AS prod
@@ -29,4 +29,4 @@ COPY --from=build /app/public ./public
COPY --from=build /app/next.config.mjs ./next.config.mjs
EXPOSE 3000
CMD ["npm", "start"]
CMD ["yarn", "start"]

View File

@@ -39,3 +39,50 @@ This project uses [`next/font`](https://nextjs.org/docs/basic-features/font-opti
## Deploy
TODO
## Storybook
Storybook is a powerful development environment for UI components. It allows you to build UI components in isolation, making it easier to develop, test, and document your components independently from your main application.
### Purpose in the Development Process
1. **Component Development**: Develop and test UI components in isolation.
2. **Visual Testing**: Easily spot visual regressions.
3. **Documentation**: Automatically document components and their props.
4. **Collaboration**: Share components with your team or stakeholders for feedback.
### How to Use Storybook
1. **Start Storybook**:
Run the following command to start the Storybook development server:
```bash
npm run storybook
```
This will start Storybook on port 6006. Open [http://localhost:6006](http://localhost:6006) in your browser to view your component library.
2. **Build Storybook**:
To build a static version of Storybook for deployment, use:
```bash
npm run build-storybook
```
3. **Running Storybook Tests**:
Storybook tests can be run using:
```bash
npm run test-storybook
```
For CI environments, use:
```bash
npm run test-storybook:ci
```
4. **Writing Stories**:
Create `.stories.tsx` files alongside your components to define different states and variations of your components.
By integrating Storybook into our development workflow, we can streamline UI development, improve component reusability, and maintain a consistent design system across the project.

View File

@@ -8,11 +8,15 @@
"dev:test": "export NODE_ENV=test && next dev",
"build": "next build",
"start": "next start",
"lint": "next lint",
"lint": "next lint && prettier --check .",
"format": "prettier --write .",
"test": "playwright test",
"test-ui": "playwright test --ui",
"gentests": "playwright codegen http://localhost:3000"
"gentests": "playwright codegen http://localhost:3000",
"storybook": "storybook dev -p 6006",
"build-storybook": "storybook build",
"test-storybook": "test-storybook",
"test-storybook:ci": "concurrently -k -s first -n \"SB,TEST\" -c \"magenta,blue\" \"npm run build-storybook -- --quiet && npx http-server storybook-static --port 6006 --silent\" \"wait-on tcp:6006 && npm run test-storybook\""
},
"browserslist": [
"defaults"
@@ -23,6 +27,7 @@
"@radix-ui/react-avatar": "^1.1.0",
"@radix-ui/react-checkbox": "^1.1.1",
"@radix-ui/react-collapsible": "^1.1.0",
"@radix-ui/react-context-menu": "^2.2.1",
"@radix-ui/react-dialog": "^1.1.1",
"@radix-ui/react-dropdown-menu": "^2.1.1",
"@radix-ui/react-icons": "^1.3.0",
@@ -39,7 +44,7 @@
"@supabase/ssr": "^0.4.0",
"@supabase/supabase-js": "^2.45.0",
"@tanstack/react-table": "^8.20.5",
"@xyflow/react": "^12.1.0",
"@xyflow/react": "^12.3.1",
"ajv": "^8.17.1",
"class-variance-authority": "^0.7.0",
"clsx": "^2.1.1",
@@ -65,17 +70,31 @@
"zod": "^3.23.8"
},
"devDependencies": {
"@chromatic-com/storybook": "^1.9.0",
"@playwright/test": "^1.47.1",
"@storybook/addon-essentials": "^8.3.5",
"@storybook/addon-interactions": "^8.3.5",
"@storybook/addon-links": "^8.3.5",
"@storybook/addon-onboarding": "^8.3.5",
"@storybook/blocks": "^8.3.5",
"@storybook/nextjs": "^8.3.5",
"@storybook/react": "^8.3.5",
"@storybook/test": "^8.3.5",
"@storybook/test-runner": "^0.19.1",
"@types/node": "^22.7.3",
"@types/react": "^18",
"@types/react-dom": "^18",
"@types/react-modal": "^3.16.3",
"concurrently": "^9.0.1",
"eslint": "^8",
"eslint-config-next": "14.2.4",
"eslint-plugin-storybook": "^0.9.0",
"postcss": "^8",
"prettier": "^3.3.3",
"prettier-plugin-tailwindcss": "^0.6.6",
"storybook": "^8.3.5",
"tailwindcss": "^3.4.1",
"typescript": "^5"
}
},
"packageManager": "yarn@1.22.22+sha512.a6b2f7906b721bba3d67d4aff083df04dad64c399707841b7acf00f6b133b7ac24255f2652fa22ae3534329dc6180534e98d17432037ff6fd140556e2bb3137e"
}

View File

@@ -10,7 +10,7 @@ async function AdminMarketplace() {
return (
<>
<AdminMarketplaceAgentList agents={reviewableAgents.agents} />
<AdminMarketplaceAgentList agents={reviewableAgents.items} />
<Separator className="my-4" />
<AdminFeaturedAgentsControl className="mt-4" />
</>

View File

@@ -27,7 +27,7 @@
--destructive: 0 84.2% 60.2%;
--destructive-foreground: 0 0% 98%;
--border: 240 5.9% 90%;
--input: 240 5.9% 90%;
--input: 240 5.9% 85%;
--ring: 240 5.9% 10%;
--radius: 0.5rem;
--chart-1: 12 76% 61%;
@@ -72,4 +72,12 @@
body {
@apply bg-background text-foreground;
}
.agpt-border-input {
@apply border-input focus-visible:border-gray-400 focus-visible:outline-none;
}
.agpt-shadow-input {
@apply shadow-sm focus-visible:shadow-md;
}
}

View File

@@ -114,7 +114,7 @@ export default function LoginPage() {
return (
<div className="flex h-[80vh] items-center justify-center">
<div className="w-full max-w-md space-y-6 rounded-lg p-8 shadow-md">
<div className="mb-6 space-y-2">
{/* <div className="mb-6 space-y-2">
<Button
className="w-full"
onClick={() => handleSignInWithProvider("google")}
@@ -145,7 +145,7 @@ export default function LoginPage() {
<FaDiscord className="mr-2 h-4 w-4" />
Sign in with Discord
</Button>
</div>
</div> */}
<Form {...form}>
<form onSubmit={form.handleSubmit(onLogin)}>
<FormField

Some files were not shown because too many files have changed in this diff Show More