Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-04-08 03:00:28 -04:00)

Compare commits: 90 commits (claude/iss...fix/vercel)
| Author | SHA1 | Date |
|---|---|---|
| | b1fbc1c88f | |
| | 0f13d14fd4 | |
| | fde3533943 | |
| | a789f87734 | |
| | 0b6e46d363 | |
| | 6ffe57c3df | |
| | 3ca0d04ea0 | |
| | c2eea593c0 | |
| | 36f5f24333 | |
| | 309114a727 | |
| | 4ffb99bfb0 | |
| | 5741331250 | |
| | 2fda8dfd32 | |
| | 22c76eab61 | |
| | 7688a9701e | |
| | 243400e128 | |
| | c77cb1fcfb | |
| | b3b5eefe2c | |
| | fe36ba55dd | |
| | 45c1ca1ca1 | |
| | b12507fb21 | |
| | ab4eb10c3d | |
| | 42e141012f | |
| | b7f9dcf419 | |
| | a4ff8402f1 | |
| | 2183c94c58 | |
| | 5ff6d2ca56 | |
| | 02d3b42745 | |
| | 4bf73f63f4 | |
| | 189c353c59 | |
| | 07461a88cf | |
| | daf05cb7bf | |
| | 29bdbf3650 | |
| | 171deea806 | |
| | 149bbd910a | |
| | c6741e7c14 | |
| | 6de1e470d9 | |
| | 67eefdd35c | |
| | 01950ccc42 | |
| | 358ce1d258 | |
| | a5691c0e89 | |
| | 0b35dff1e6 | |
| | 6cf9136cdd | |
| | 5d91a9c9b9 | |
| | e3d84d87f8 | |
| | 9fecbe2a31 | |
| | 4744e0f6b1 | |
| | 24b4ab9864 | |
| | 04e90da031 | |
| | d4646c249d | |
| | 095199bfa6 | |
| | 90fb223114 | |
| | b1f3122243 | |
| | f1cc2afbda | |
| | f394a0eabb | |
| | 311bcc7751 | |
| | e2bd727798 | |
| | 47f503f223 | |
| | 22d58367ec | |
| | d076d0175f | |
| | a33d58dd33 | |
| | 254bb6236f | |
| | 198b3d9f45 | |
| | 9a6ae90d12 | |
| | 89a5ba69e5 | |
| | b32ac898db | |
| | 4f6e66447f | |
| | fae927e2a7 | |
| | 6ef8119708 | |
| | a0a7129081 | |
| | f3202fa776 | |
| | b5c7f381c1 | |
| | 2dd366172e | |
| | 4d0db27d5e | |
| | 5421ccf86a | |
| | c4056cbae9 | |
| | c01beaf003 | |
| | cf560c5d65 | |
| | 77e99e9739 | |
| | 7f7c387156 | |
| | 21cf263eea | |
| | 500952a15f | |
| | b3c81fa9e2 | |
| | 59e96d4759 | |
| | e01dd94e36 | |
| | 2f12e8d731 | |
| | 83dbcd11e4 | |
| | f66b8f9c74 | |
| | 2af9d75dec | |
| | 0b37263092 | |
5 .github/workflows/platform-backend-ci.yml (vendored)

@@ -190,9 +190,9 @@ jobs:
       - name: Run pytest with coverage
         run: |
           if [[ "${{ runner.debug }}" == "1" ]]; then
-            poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG test
+            poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG
           else
-            poetry run pytest -s -vv test
+            poetry run pytest -s -vv
           fi
         if: success() || (failure() && steps.lint.outcome == 'failure')
         env:
@@ -205,6 +205,7 @@ jobs:
           REDIS_HOST: "localhost"
           REDIS_PORT: "6379"
           REDIS_PASSWORD: "testpassword"
+          ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!

     env:
       CI: true
86 .github/workflows/platform-frontend-ci.yml (vendored)

@@ -18,11 +18,14 @@ defaults:
     working-directory: autogpt_platform/frontend

 jobs:
-  lint:
+  setup:
     runs-on: ubuntu-latest
+    outputs:
+      cache-key: ${{ steps.cache-key.outputs.key }}

     steps:
-      - uses: actions/checkout@v4
+      - name: Checkout repository
+        uses: actions/checkout@v4

       - name: Set up Node.js
         uses: actions/setup-node@v4
@@ -32,6 +35,45 @@ jobs:
       - name: Enable corepack
         run: corepack enable

+      - name: Generate cache key
+        id: cache-key
+        run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('**/pnpm-lock.yaml') }}" >> $GITHUB_OUTPUT
+
+      - name: Cache dependencies
+        uses: actions/cache@v4
+        with:
+          path: ~/.pnpm-store
+          key: ${{ steps.cache-key.outputs.key }}
+          restore-keys: |
+            ${{ runner.os }}-pnpm-
+
+      - name: Install dependencies
+        run: pnpm install --frozen-lockfile
+
+  lint:
+    runs-on: ubuntu-latest
+    needs: setup
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - name: Set up Node.js
+        uses: actions/setup-node@v4
+        with:
+          node-version: "21"
+
+      - name: Enable corepack
+        run: corepack enable
+
+      - name: Restore dependencies cache
+        uses: actions/cache@v4
+        with:
+          path: ~/.pnpm-store
+          key: ${{ needs.setup.outputs.cache-key }}
+          restore-keys: |
+            ${{ runner.os }}-pnpm-
+
       - name: Install dependencies
         run: pnpm install --frozen-lockfile
@@ -40,9 +82,11 @@ jobs:

   type-check:
     runs-on: ubuntu-latest
+    needs: setup

     steps:
-      - uses: actions/checkout@v4
+      - name: Checkout repository
+        uses: actions/checkout@v4

       - name: Set up Node.js
         uses: actions/setup-node@v4
@@ -52,21 +96,29 @@ jobs:
       - name: Enable corepack
         run: corepack enable

+      - name: Restore dependencies cache
+        uses: actions/cache@v4
+        with:
+          path: ~/.pnpm-store
+          key: ${{ needs.setup.outputs.cache-key }}
+          restore-keys: |
+            ${{ runner.os }}-pnpm-
+
       - name: Install dependencies
         run: pnpm install --frozen-lockfile

       - name: Generate API client
         run: pnpm generate:api-client

       - name: Run tsc check
         run: pnpm type-check

   chromatic:
     runs-on: ubuntu-latest
+    needs: setup
     # Only run on dev branch pushes or PRs targeting dev
     if: github.ref == 'refs/heads/dev' || github.base_ref == 'dev'

     steps:
-      - uses: actions/checkout@v4
+      - name: Checkout repository
+        uses: actions/checkout@v4
         with:
           fetch-depth: 0
@@ -78,6 +130,14 @@ jobs:
       - name: Enable corepack
         run: corepack enable

+      - name: Restore dependencies cache
+        uses: actions/cache@v4
+        with:
+          path: ~/.pnpm-store
+          key: ${{ needs.setup.outputs.cache-key }}
+          restore-keys: |
+            ${{ runner.os }}-pnpm-
+
       - name: Install dependencies
         run: pnpm install --frozen-lockfile
@@ -88,9 +148,11 @@ jobs:
           onlyChanged: true
           workingDir: autogpt_platform/frontend
           token: ${{ secrets.GITHUB_TOKEN }}
+          exitOnceUploaded: true

   test:
     runs-on: ubuntu-latest
+    needs: setup
     strategy:
       fail-fast: false
       matrix:
@@ -128,6 +190,14 @@ jobs:
         run: |
           docker compose -f ../docker-compose.yml up -d

+      - name: Restore dependencies cache
+        uses: actions/cache@v4
+        with:
+          path: ~/.pnpm-store
+          key: ${{ needs.setup.outputs.cache-key }}
+          restore-keys: |
+            ${{ runner.os }}-pnpm-
+
       - name: Install dependencies
         run: pnpm install --frozen-lockfile
@@ -143,6 +213,8 @@ jobs:

       - name: Run Playwright tests
         run: pnpm test:no-build --project=${{ matrix.browser }}
+        env:
+          BROWSER_TYPE: ${{ matrix.browser }}

       - name: Print Final Docker Compose logs
         if: always()
3 .gitignore (vendored)

@@ -177,6 +177,3 @@ autogpt_platform/backend/settings.py
 *.ign.*
 .test-contents
 .claude/settings.local.json

 # Auto generated client
 autogpt_platform/frontend/src/api/__generated__
README.md:

@@ -1,8 +1,7 @@
 # AutoGPT: Build, Deploy, and Run AI Agents

-[](https://discord.gg/autogpt)
 [](https://discord.gg/autogpt)
 [](https://twitter.com/Auto_GPT)
 [](https://opensource.org/licenses/MIT)

 **AutoGPT** is a powerful platform that allows you to create, deploy, and manage continuous AI agents that automate complex workflows.
CLAUDE.md:

@@ -1,7 +1,6 @@
 # CLAUDE.md

 This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

 ## Repository Overview

 AutoGPT Platform is a monorepo containing:
@@ -121,6 +120,7 @@ Key models (defined in `/backend/schema.prisma`):
 3. Define input/output schemas
 4. Implement `run` method
 5. Register in block registry
+6. Generate the block uuid using `uuid.uuid4()`

 **Modifying the API:**
 1. Update route in `/backend/backend/server/routers/`
@@ -143,4 +143,4 @@ Key models (defined in `/backend/schema.prisma`):
 - Cacheable paths include: static assets (`/static/*`, `/_next/static/*`), health checks, public store pages, documentation
 - Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
 - To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
-- Applied to both main API server and external API applications
+- Applied to both main API server and external API applications
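The caching rules above reduce to a small allow-list check. A minimal sketch of that behaviour, not the platform's actual middleware; only the `CACHEABLE_PATHS` name comes from the text, and the prefixes and header values are illustrative:

```python
CACHEABLE_PATHS = ("/static/", "/_next/static/", "/health")  # illustrative

def cache_control_for(path: str) -> str:
    """Default-deny caching: only allow-listed paths may be cached, so
    auth tokens, API keys, and user data never end up in a shared cache."""
    if any(path.startswith(prefix) for prefix in CACHEABLE_PATHS):
        return "public, max-age=3600"
    return "no-store"

print(cache_control_for("/_next/static/app.js"))  # public, max-age=3600
print(cache_control_for("/api/user"))             # no-store
```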
48 autogpt_platform/autogpt_libs/poetry.lock (generated)

@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.

 [[package]]
 name = "aiohappyeyeballs"
@@ -177,7 +177,7 @@ files = [
     {file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"},
     {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"},
 ]
-markers = {main = "python_version == \"3.10\"", dev = "python_full_version < \"3.11.3\""}
+markers = {main = "python_version < \"3.11\"", dev = "python_full_version < \"3.11.3\""}

 [[package]]
 name = "attrs"
@@ -390,7 +390,7 @@ description = "Backport of PEP 654 (exception groups)"
 optional = false
 python-versions = ">=3.7"
 groups = ["main"]
-markers = "python_version == \"3.10\""
+markers = "python_version < \"3.11\""
 files = [
     {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"},
     {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"},
@@ -1667,30 +1667,30 @@ pyasn1 = ">=0.1.3"

 [[package]]
 name = "ruff"
-version = "0.11.10"
+version = "0.12.2"
 description = "An extremely fast Python linter and code formatter, written in Rust."
 optional = false
 python-versions = ">=3.7"
 groups = ["dev"]
 files = [
-    {file = "ruff-0.11.10-py3-none-linux_armv6l.whl", hash = "sha256:859a7bfa7bc8888abbea31ef8a2b411714e6a80f0d173c2a82f9041ed6b50f58"},
-    {file = "ruff-0.11.10-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:968220a57e09ea5e4fd48ed1c646419961a0570727c7e069842edd018ee8afed"},
-    {file = "ruff-0.11.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:1067245bad978e7aa7b22f67113ecc6eb241dca0d9b696144256c3a879663bca"},
-    {file = "ruff-0.11.10-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4854fd09c7aed5b1590e996a81aeff0c9ff51378b084eb5a0b9cd9518e6cff2"},
-    {file = "ruff-0.11.10-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8b4564e9f99168c0f9195a0fd5fa5928004b33b377137f978055e40008a082c5"},
-    {file = "ruff-0.11.10-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5b6a9cc5b62c03cc1fea0044ed8576379dbaf751d5503d718c973d5418483641"},
-    {file = "ruff-0.11.10-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:607ecbb6f03e44c9e0a93aedacb17b4eb4f3563d00e8b474298a201622677947"},
-    {file = "ruff-0.11.10-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7b3a522fa389402cd2137df9ddefe848f727250535c70dafa840badffb56b7a4"},
-    {file = "ruff-0.11.10-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2f071b0deed7e9245d5820dac235cbdd4ef99d7b12ff04c330a241ad3534319f"},
-    {file = "ruff-0.11.10-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a60e3a0a617eafba1f2e4186d827759d65348fa53708ca547e384db28406a0b"},
-    {file = "ruff-0.11.10-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:da8ec977eaa4b7bf75470fb575bea2cb41a0e07c7ea9d5a0a97d13dbca697bf2"},
-    {file = "ruff-0.11.10-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:ddf8967e08227d1bd95cc0851ef80d2ad9c7c0c5aab1eba31db49cf0a7b99523"},
-    {file = "ruff-0.11.10-py3-none-musllinux_1_2_i686.whl", hash = "sha256:5a94acf798a82db188f6f36575d80609072b032105d114b0f98661e1679c9125"},
-    {file = "ruff-0.11.10-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:3afead355f1d16d95630df28d4ba17fb2cb9c8dfac8d21ced14984121f639bad"},
-    {file = "ruff-0.11.10-py3-none-win32.whl", hash = "sha256:dc061a98d32a97211af7e7f3fa1d4ca2fcf919fb96c28f39551f35fc55bdbc19"},
-    {file = "ruff-0.11.10-py3-none-win_amd64.whl", hash = "sha256:5cc725fbb4d25b0f185cb42df07ab6b76c4489b4bfb740a175f3a59c70e8a224"},
-    {file = "ruff-0.11.10-py3-none-win_arm64.whl", hash = "sha256:ef69637b35fb8b210743926778d0e45e1bffa850a7c61e428c6b971549b5f5d1"},
-    {file = "ruff-0.11.10.tar.gz", hash = "sha256:d522fb204b4959909ecac47da02830daec102eeb100fb50ea9554818d47a5fa6"},
+    {file = "ruff-0.12.2-py3-none-linux_armv6l.whl", hash = "sha256:093ea2b221df1d2b8e7ad92fc6ffdca40a2cb10d8564477a987b44fd4008a7be"},
+    {file = "ruff-0.12.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:09e4cf27cc10f96b1708100fa851e0daf21767e9709e1649175355280e0d950e"},
+    {file = "ruff-0.12.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:8ae64755b22f4ff85e9c52d1f82644abd0b6b6b6deedceb74bd71f35c24044cc"},
+    {file = "ruff-0.12.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3eb3a6b2db4d6e2c77e682f0b988d4d61aff06860158fdb413118ca133d57922"},
+    {file = "ruff-0.12.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:73448de992d05517170fc37169cbca857dfeaeaa8c2b9be494d7bcb0d36c8f4b"},
+    {file = "ruff-0.12.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3b8b94317cbc2ae4a2771af641739f933934b03555e51515e6e021c64441532d"},
+    {file = "ruff-0.12.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:45fc42c3bf1d30d2008023a0a9a0cfb06bf9835b147f11fe0679f21ae86d34b1"},
+    {file = "ruff-0.12.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ce48f675c394c37e958bf229fb5c1e843e20945a6d962cf3ea20b7a107dcd9f4"},
+    {file = "ruff-0.12.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:793d8859445ea47591272021a81391350205a4af65a9392401f418a95dfb75c9"},
+    {file = "ruff-0.12.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6932323db80484dda89153da3d8e58164d01d6da86857c79f1961934354992da"},
+    {file = "ruff-0.12.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:6aa7e623a3a11538108f61e859ebf016c4f14a7e6e4eba1980190cacb57714ce"},
+    {file = "ruff-0.12.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2a4a20aeed74671b2def096bdf2eac610c7d8ffcbf4fb0e627c06947a1d7078d"},
+    {file = "ruff-0.12.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:71a4c550195612f486c9d1f2b045a600aeba851b298c667807ae933478fcef04"},
+    {file = "ruff-0.12.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:4987b8f4ceadf597c927beee65a5eaf994c6e2b631df963f86d8ad1bdea99342"},
+    {file = "ruff-0.12.2-py3-none-win32.whl", hash = "sha256:369ffb69b70cd55b6c3fc453b9492d98aed98062db9fec828cdfd069555f5f1a"},
+    {file = "ruff-0.12.2-py3-none-win_amd64.whl", hash = "sha256:dca8a3b6d6dc9810ed8f328d406516bf4d660c00caeaef36eb831cf4871b0639"},
+    {file = "ruff-0.12.2-py3-none-win_arm64.whl", hash = "sha256:48d6c6bfb4761df68bc05ae630e24f506755e702d4fb08f08460be778c7ccb12"},
+    {file = "ruff-0.12.2.tar.gz", hash = "sha256:d7b4f55cd6f325cb7621244f19c873c565a08aff5a4ba9c69aa7355f3f7afd3e"},
 ]

 [[package]]
@@ -1823,7 +1823,7 @@ description = "A lil' TOML parser"
 optional = false
 python-versions = ">=3.8"
 groups = ["main"]
-markers = "python_version == \"3.10\""
+markers = "python_version < \"3.11\""
 files = [
     {file = "tomli-2.1.0-py3-none-any.whl", hash = "sha256:a5c57c3d1c56f5ccdf89f6523458f60ef716e210fc47c4cfb188c5ba473e0391"},
     {file = "tomli-2.1.0.tar.gz", hash = "sha256:3f646cae2aec94e17d04973e4249548320197cfabdf130015d023de4b74d8ab8"},
@@ -2176,4 +2176,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.1"
 python-versions = ">=3.10,<4.0"
-content-hash = "d92143928a88ca3a56ac200c335910eafac938940022fed8bd0d17c95040b54f"
+content-hash = "574057127b05f28c2ae39f7b11aa0d7c52f857655e9223e23a27c9989b2ac10f"
autogpt_platform/autogpt_libs/pyproject.toml:

@@ -23,7 +23,7 @@ uvicorn = "^0.34.3"

 [tool.poetry.group.dev.dependencies]
 redis = "^5.2.1"
-ruff = "^0.11.10"
+ruff = "^0.12.2"

 [build-system]
 requires = ["poetry-core"]
@@ -199,9 +199,18 @@ ZEROBOUNCE_API_KEY=

 ## ===== OPTIONAL API KEYS END ===== ##

 # Block Error Rate Monitoring
 BLOCK_ERROR_RATE_THRESHOLD=0.5
 BLOCK_ERROR_RATE_CHECK_INTERVAL_SECS=86400

 # Logging Configuration
 LOG_LEVEL=INFO
 ENABLE_CLOUD_LOGGING=false
+ENABLE_FILE_LOGGING=false
+# Use to manually set the log directory
+# LOG_DIR=./logs
+
+# Example Blocks Configuration
+# Set to true to enable example blocks in development
+# These blocks are disabled by default in production
+ENABLE_EXAMPLE_BLOCKS=false
150 autogpt_platform/backend/backend/TEST_DATA_README.md (new file)

@@ -0,0 +1,150 @@
# Test Data Scripts

This directory contains scripts for creating and updating test data in the AutoGPT Platform database, specifically designed to test the materialized views for the store functionality.

## Scripts

### test_data_creator.py
Creates a comprehensive set of test data including:
- Users with profiles
- Agent graphs, nodes, and executions
- Store listings with multiple versions
- Reviews and ratings
- Library agents
- Integration webhooks
- Onboarding data
- Credit transactions

**Image/Video Domains Used:**
- Images: `picsum.photos` (for all image URLs)
- Videos: `youtube.com` (for store listing videos)

### test_data_updater.py
Updates existing test data to simulate real-world changes:
- Adds new agent graph executions
- Creates new store listing reviews
- Updates store listing versions
- Adds credit transactions
- Refreshes materialized views

### check_db.py
Tests and verifies materialized views functionality:
- Checks pg_cron job status (for automatic refresh)
- Displays current materialized view counts
- Adds test data (executions and reviews)
- Creates store listings if none exist
- Manually refreshes materialized views
- Compares before/after counts to verify updates
- Provides a summary of test results

## Materialized Views

The scripts test three key database views:

1. **mv_agent_run_counts**: Tracks execution counts by agent
2. **mv_review_stats**: Tracks review statistics (count, average rating) by store listing
3. **StoreAgent**: A view that combines store listing data with execution counts and ratings for display

The materialized views (mv_agent_run_counts and mv_review_stats) are automatically refreshed every 15 minutes via pg_cron, or can be manually refreshed using the `refresh_store_materialized_views()` function.
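To exercise the refresh path from a script, here is a minimal sketch assuming the Prisma Python client these scripts already use; `query_raw` is its raw-SQL escape hatch, and `cron.job` is pg_cron's own job table (absent when the extension is not installed):

```python
# Sketch: inspect the pg_cron schedule, then force a manual refresh.
# Assumes the Prisma Python client and a migrated platform database.
import asyncio
from prisma import Prisma

async def main() -> None:
    db = Prisma()
    await db.connect()
    try:
        try:
            jobs = await db.query_raw("SELECT jobname, schedule FROM cron.job;")
            print("pg_cron jobs:", jobs)
        except Exception as exc:
            print("pg_cron not available (normal in local dev):", exc)
        # Helper installed by the migrations; refreshes both
        # mv_agent_run_counts and mv_review_stats in one call.
        await db.query_raw("SELECT refresh_store_materialized_views();")
    finally:
        await db.disconnect()

asyncio.run(main())
```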
## Usage

### Prerequisites

1. Ensure the database is running:
   ```bash
   docker compose up -d
   # or for the test database:
   docker compose -f docker-compose.test.yaml --env-file ../.env up -d
   ```

2. Run database migrations:
   ```bash
   poetry run prisma migrate deploy
   ```

### Running the Scripts

#### Option 1: Use the helper script (from the backend directory)
```bash
poetry run python run_test_data.py
```

#### Option 2: Run individually
```bash
# From the backend/test directory:
# Create initial test data
poetry run python test_data_creator.py

# Update data to test materialized view changes
poetry run python test_data_updater.py

# From the backend directory:
# Test materialized views functionality
poetry run python check_db.py

# Check store data status
poetry run python check_store_data.py
```

#### Option 3: Use the shell script (from the backend directory)
```bash
./run_test_data_scripts.sh
```

### Manual Materialized View Refresh

To manually refresh the materialized views:
```sql
SELECT refresh_store_materialized_views();
```

## Configuration

The scripts use the database configuration from your `.env` file:
- `DATABASE_URL`: PostgreSQL connection string
- The database should have the platform schema

## Data Generation Limits

Configured in `test_data_creator.py`:
- 100 users
- 100 agent blocks
- 1-5 graphs per user
- 2-5 nodes per graph
- 1-5 presets per user
- 1-10 library agents per user
- 1-20 executions per graph
- 1-5 reviews per store listing version

## Notes

- All image URLs use `picsum.photos` for consistency with the Next.js image configuration
- The scripts create realistic relationships between entities
- Materialized views are refreshed at the end of each script
- Data is designed to test both happy paths and edge cases

## Troubleshooting

### Reviews and StoreAgent view showing 0

If `check_db.py` shows that reviews remain at 0 and the StoreAgent view shows 0 store agents:

1. **No store listings exist**: The script will automatically create test store listings if none exist
2. **No approved versions**: Store listings need approved versions to appear in the StoreAgent view
3. **Check with `check_store_data.py`**: This script provides detailed information about:
   - Total store listings
   - Store listing versions by status
   - Existing reviews
   - StoreAgent view contents
   - Agent graph executions

### pg_cron not installed

The warning "pg_cron extension is not installed" is normal in local development environments. The materialized views can still be refreshed manually using the `refresh_store_materialized_views()` function, which all of the scripts do automatically.

### Common Issues

- **Type errors with None values**: Fixed in the latest version of check_db.py by using `or 0` for nullable numeric fields (see the sketch below)
- **Missing relations**: Ensure you're using the correct field names (e.g., `StoreListing`, not `storeListing`, in includes)
- **Column name mismatches**: The database uses camelCase for column names (e.g., `agentGraphId`, not `agent_graph_id`)
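A hypothetical illustration of the `or 0` guard from the first bullet: aggregate columns read from the materialized views come back as `None` before the first refresh, so they are coerced before any arithmetic.

```python
# A freshly created listing has no aggregates yet; the view returns NULLs.
row = {"run_count": None, "avg_rating": None}

run_count = row["run_count"] or 0      # None -> 0
avg_rating = row["avg_rating"] or 0.0  # None -> 0.0
print(f"runs={run_count}, rating={avg_rating:.1f}")
```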
@@ -14,14 +14,27 @@ T = TypeVar("T")
 @functools.cache
 def load_all_blocks() -> dict[str, type["Block"]]:
     from backend.data.block import Block
+    from backend.util.settings import Config
+
+    # Check if example blocks should be loaded from settings
+    config = Config()
+    load_examples = config.enable_example_blocks

     # Dynamically load all modules under backend.blocks
     current_dir = Path(__file__).parent
-    modules = [
-        str(f.relative_to(current_dir))[:-3].replace(os.path.sep, ".")
-        for f in current_dir.rglob("*.py")
-        if f.is_file() and f.name != "__init__.py"
-    ]
+    modules = []
+    for f in current_dir.rglob("*.py"):
+        if not f.is_file() or f.name == "__init__.py" or f.name.startswith("test_"):
+            continue
+
+        # Skip examples directory if not enabled
+        relative_path = f.relative_to(current_dir)
+        if not load_examples and relative_path.parts[0] == "examples":
+            continue
+
+        module_path = str(relative_path)[:-3].replace(os.path.sep, ".")
+        modules.append(module_path)

     for module in modules:
         if not re.match("^[a-z0-9_.]+$", module):
             raise ValueError(
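The discovery rule introduced above, restated as a standalone runnable sketch (the directory argument is illustrative; the real loader walks `backend/blocks` and reads the flag from `Config().enable_example_blocks`):

```python
import os
from pathlib import Path

def discover_block_modules(root: Path, load_examples: bool) -> list[str]:
    """Mirror of the loader's filter: skip __init__.py and test_* files,
    and include the examples/ package only when explicitly enabled."""
    modules: list[str] = []
    for f in root.rglob("*.py"):
        if not f.is_file() or f.name == "__init__.py" or f.name.startswith("test_"):
            continue
        rel = f.relative_to(root)
        if not load_examples and rel.parts[0] == "examples":
            continue
        modules.append(str(rel)[:-3].replace(os.path.sep, "."))
    return modules

print(discover_block_modules(Path("."), load_examples=False))
```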
@@ -14,10 +14,10 @@ from backend.data.block import (
     get_block,
 )
 from backend.data.execution import ExecutionStatus
-from backend.data.model import SchemaField
-from backend.util import json
+from backend.data.model import NodeExecutionStats, SchemaField
+from backend.util import json, retry

-logger = logging.getLogger(__name__)
+_logger = logging.getLogger(__name__)


 class AgentExecutorBlock(Block):
@@ -77,27 +77,42 @@ class AgentExecutorBlock(Block):
             use_db_query=False,
         )

+        logger = execution_utils.LogMetadata(
+            logger=_logger,
+            user_id=input_data.user_id,
+            graph_eid=graph_exec.id,
+            graph_id=input_data.graph_id,
+            node_eid="*",
+            node_id="*",
+            block_name=self.name,
+        )
+
         try:
             async for name, data in self._run(
                 graph_id=input_data.graph_id,
                 graph_version=input_data.graph_version,
                 graph_exec_id=graph_exec.id,
                 user_id=input_data.user_id,
+                logger=logger,
             ):
                 yield name, data
         except asyncio.CancelledError:
-            logger.warning(
-                f"Execution of graph {input_data.graph_id} version {input_data.graph_version} was cancelled."
-            )
-            await execution_utils.stop_graph_execution(
-                graph_exec.id, use_db_query=False
-            )
+            await self._stop(
+                graph_exec_id=graph_exec.id,
+                user_id=input_data.user_id,
+                logger=logger,
+            )
+            logger.warning(
+                f"Execution of graph {input_data.graph_id}v{input_data.graph_version} was cancelled."
+            )
         except Exception as e:
-            logger.error(
-                f"Execution of graph {input_data.graph_id} version {input_data.graph_version} failed: {e}, stopping execution."
-            )
-            await execution_utils.stop_graph_execution(
-                graph_exec.id, use_db_query=False
-            )
+            await self._stop(
+                graph_exec_id=graph_exec.id,
+                user_id=input_data.user_id,
+                logger=logger,
+            )
+            logger.error(
+                f"Execution of graph {input_data.graph_id}v{input_data.graph_version} failed: {e}, execution is stopped."
+            )
             raise
@@ -107,6 +122,7 @@ class AgentExecutorBlock(Block):
         graph_version: int,
         graph_exec_id: str,
         user_id: str,
+        logger,
     ) -> BlockOutput:

         from backend.data.execution import ExecutionEventType
@@ -135,6 +151,12 @@ class AgentExecutorBlock(Block):
             if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
                 # If the graph execution is COMPLETED, TERMINATED, or FAILED,
                 # we can stop listening for further events.
+                self.merge_stats(
+                    NodeExecutionStats(
+                        extra_cost=event.stats.cost if event.stats else 0,
+                        extra_steps=event.stats.node_exec_count if event.stats else 0,
+                    )
+                )
                 break

             logger.debug(
@@ -159,3 +181,25 @@ class AgentExecutorBlock(Block):
                 f"Execution {log_id} produced {output_name}: {output_data}"
             )
             yield output_name, output_data
+
+    @retry.func_retry
+    async def _stop(
+        self,
+        graph_exec_id: str,
+        user_id: str,
+        logger,
+    ) -> None:
+        from backend.executor import utils as execution_utils
+
+        log_id = f"Graph exec-id: {graph_exec_id}"
+        logger.info(f"Stopping execution of {log_id}")
+
+        try:
+            await execution_utils.stop_graph_execution(
+                graph_exec_id=graph_exec_id,
+                user_id=user_id,
+                use_db_query=False,
+            )
+            logger.info(f"Execution {log_id} stopped successfully.")
+        except Exception as e:
+            logger.error(f"Failed to stop execution {log_id}: {e}")
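For context on the `@retry.func_retry` decorator used on `_stop`: its real implementation lives in `backend.util.retry` and may differ, but a retry wrapper of this shape typically looks like the following sketch (attempt count and delay are illustrative assumptions):

```python
import asyncio
import functools

def func_retry(fn, attempts: int = 3, delay: float = 1.0):
    """Retry an async call a few times before giving up, so that a
    best-effort cleanup like _stop survives transient failures."""
    @functools.wraps(fn)
    async def wrapper(*args, **kwargs):
        for attempt in range(attempts):
            try:
                return await fn(*args, **kwargs)
            except Exception:
                if attempt == attempts - 1:
                    raise
                await asyncio.sleep(delay)
    return wrapper
```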
@@ -53,6 +53,7 @@ class AudioTrack(str, Enum):
     REFRESHER = ("Refresher",)
     TOURIST = ("Tourist",)
     TWIN_TYCHES = ("Twin Tyches",)
+    DONT_STOP_ME_ABSTRACT_FUTURE_BASS = ("Dont Stop Me Abstract Future Bass",)

     @property
     def audio_url(self):
@@ -78,6 +79,7 @@ class AudioTrack(str, Enum):
             AudioTrack.REFRESHER: "https://cdn.tfrv.xyz/audio/refresher.mp3",
             AudioTrack.TOURIST: "https://cdn.tfrv.xyz/audio/tourist.mp3",
             AudioTrack.TWIN_TYCHES: "https://cdn.tfrv.xyz/audio/twin-tynches.mp3",
+            AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS: "https://cdn.revid.ai/audio/_dont-stop-me-abstract-future-bass.mp3",
         }
         return audio_urls[self]
@@ -105,6 +107,7 @@ class GenerationPreset(str, Enum):
     MOVIE = ("Movie",)
     STYLIZED_ILLUSTRATION = ("Stylized Illustration",)
     MANGA = ("Manga",)
+    DEFAULT = ("DEFAULT",)


 class Voice(str, Enum):
@@ -114,6 +117,7 @@ class Voice(str, Enum):
     JESSICA = "Jessica"
     CHARLOTTE = "Charlotte"
     CALLUM = "Callum"
+    EVA = "Eva"

     @property
     def voice_id(self):
@@ -124,6 +128,7 @@ class Voice(str, Enum):
             Voice.JESSICA: "cgSgspJ2msm6clMCkdW9",
             Voice.CHARLOTTE: "XB0fDUnXU5powFXDhCwa",
             Voice.CALLUM: "N2lVS1w4EtoT3dr4eOWO",
+            Voice.EVA: "FGY2WhTYpPnrIDTdsKH5",
         }
         return voice_id_map[self]
@@ -141,6 +146,8 @@ logger = logging.getLogger(__name__)


 class AIShortformVideoCreatorBlock(Block):
+    """Creates a short‑form text‑to‑video clip using stock or AI imagery."""
+
     class Input(BlockSchema):
         credentials: CredentialsMetaInput[
             Literal[ProviderName.REVID], Literal["api_key"]
@@ -184,40 +191,8 @@ class AIShortformVideoCreatorBlock(Block):
         video_url: str = SchemaField(description="The URL of the created video")
         error: str = SchemaField(description="Error message if the request failed")

-    def __init__(self):
-        super().__init__(
-            id="361697fb-0c4f-4feb-aed3-8320c88c771b",
-            description="Creates a shortform video using revid.ai",
-            categories={BlockCategory.SOCIAL, BlockCategory.AI},
-            input_schema=AIShortformVideoCreatorBlock.Input,
-            output_schema=AIShortformVideoCreatorBlock.Output,
-            test_input={
-                "credentials": TEST_CREDENTIALS_INPUT,
-                "script": "[close-up of a cat] Meow!",
-                "ratio": "9 / 16",
-                "resolution": "720p",
-                "frame_rate": 60,
-                "generation_preset": GenerationPreset.LEONARDO,
-                "background_music": AudioTrack.HIGHWAY_NOCTURNE,
-                "voice": Voice.LILY,
-                "video_style": VisualMediaType.STOCK_VIDEOS,
-            },
-            test_output=(
-                "video_url",
-                "https://example.com/video.mp4",
-            ),
-            test_mock={
-                "create_webhook": lambda: (
-                    "test_uuid",
-                    "https://webhook.site/test_uuid",
-                ),
-                "create_video": lambda api_key, payload: {"pid": "test_pid"},
-                "wait_for_video": lambda api_key, pid, webhook_token, max_wait_time=1000: "https://example.com/video.mp4",
-            },
-            test_credentials=TEST_CREDENTIALS,
-        )
-
-    async def create_webhook(self):
+    async def create_webhook(self) -> tuple[str, str]:
         """Create a new webhook URL for receiving notifications."""
         url = "https://webhook.site/token"
         headers = {"Accept": "application/json", "Content-Type": "application/json"}
         response = await Requests().post(url, headers=headers)
@@ -225,6 +200,7 @@ class AIShortformVideoCreatorBlock(Block):
         return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"

     async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
+        """Create a video using the Revid API."""
         url = "https://www.revid.ai/api/public/v2/render"
         headers = {"key": api_key.get_secret_value()}
         response = await Requests().post(url, json=payload, headers=headers)
@@ -234,6 +210,7 @@ class AIShortformVideoCreatorBlock(Block):
         return response.json()

     async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
+        """Check the status of a video creation job."""
         url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
         headers = {"key": api_key.get_secret_value()}
         response = await Requests().get(url, headers=headers)
@@ -243,9 +220,9 @@ class AIShortformVideoCreatorBlock(Block):
         self,
         api_key: SecretStr,
         pid: str,
-        webhook_token: str,
         max_wait_time: int = 1000,
     ) -> str:
+        """Wait for video creation to complete and return the video URL."""
         start_time = time.time()
         while time.time() - start_time < max_wait_time:
             status = await self.check_video_status(api_key, pid)
@@ -266,6 +243,40 @@ class AIShortformVideoCreatorBlock(Block):
         logger.error("Video creation timed out")
         raise TimeoutError("Video creation timed out")

+    def __init__(self):
+        super().__init__(
+            id="361697fb-0c4f-4feb-aed3-8320c88c771b",
+            description="Creates a shortform video using revid.ai",
+            categories={BlockCategory.SOCIAL, BlockCategory.AI},
+            input_schema=AIShortformVideoCreatorBlock.Input,
+            output_schema=AIShortformVideoCreatorBlock.Output,
+            test_input={
+                "credentials": TEST_CREDENTIALS_INPUT,
+                "script": "[close-up of a cat] Meow!",
+                "ratio": "9 / 16",
+                "resolution": "720p",
+                "frame_rate": 60,
+                "generation_preset": GenerationPreset.LEONARDO,
+                "background_music": AudioTrack.HIGHWAY_NOCTURNE,
+                "voice": Voice.LILY,
+                "video_style": VisualMediaType.STOCK_VIDEOS,
+            },
+            test_output=("video_url", "https://example.com/video.mp4"),
+            test_mock={
+                "create_webhook": lambda *args, **kwargs: (
+                    "test_uuid",
+                    "https://webhook.site/test_uuid",
+                ),
+                "create_video": lambda *args, **kwargs: {"pid": "test_pid"},
+                "check_video_status": lambda *args, **kwargs: {
+                    "status": "ready",
+                    "videoUrl": "https://example.com/video.mp4",
+                },
+                "wait_for_video": lambda *args, **kwargs: "https://example.com/video.mp4",
+            },
+            test_credentials=TEST_CREDENTIALS,
+        )

     async def run(
         self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
     ) -> BlockOutput:
@@ -273,20 +284,18 @@ class AIShortformVideoCreatorBlock(Block):
         webhook_token, webhook_url = await self.create_webhook()
         logger.debug(f"Webhook URL: {webhook_url}")

-        audio_url = input_data.background_music.audio_url
-
         payload = {
             "frameRate": input_data.frame_rate,
             "resolution": input_data.resolution,
             "frameDurationMultiplier": 18,
-            "webhook": webhook_url,
+            "webhook": None,
             "creationParams": {
                 "mediaType": input_data.video_style,
                 "captionPresetName": "Wrap 1",
                 "selectedVoice": input_data.voice.voice_id,
                 "hasEnhancedGeneration": True,
                 "generationPreset": input_data.generation_preset.name,
-                "selectedAudio": input_data.background_music,
+                "selectedAudio": input_data.background_music.value,
                 "origin": "/create",
                 "inputText": input_data.script,
                 "flowType": "text-to-video",
@@ -302,7 +311,7 @@ class AIShortformVideoCreatorBlock(Block):
                 "selectedStoryStyle": {"value": "custom", "label": "Custom"},
                 "hasToGenerateVideos": input_data.video_style
                 != VisualMediaType.STOCK_VIDEOS,
-                "audioUrl": audio_url,
+                "audioUrl": input_data.background_music.audio_url,
             },
         }
@@ -319,8 +328,370 @@ class AIShortformVideoCreatorBlock(Block):
         logger.debug(
             f"Video created with project ID: {pid}. Waiting for completion..."
         )
-        video_url = await self.wait_for_video(
-            credentials.api_key, pid, webhook_token
-        )
+        video_url = await self.wait_for_video(credentials.api_key, pid)
         logger.debug(f"Video ready: {video_url}")
         yield "video_url", video_url
+
+
+class AIAdMakerVideoCreatorBlock(Block):
+    """Generates a 30‑second vertical AI advert using optional user‑supplied imagery."""
+
+    class Input(BlockSchema):
+        credentials: CredentialsMetaInput[
+            Literal[ProviderName.REVID], Literal["api_key"]
+        ] = CredentialsField(
+            description="Credentials for Revid.ai API access.",
+        )
+        script: str = SchemaField(
+            description="Short advertising copy. Line breaks create new scenes.",
+            placeholder="Introducing Foobar – [show product photo] the gadget that does it all.",
+        )
+        ratio: str = SchemaField(description="Aspect ratio", default="9 / 16")
+        target_duration: int = SchemaField(
+            description="Desired length of the ad in seconds.", default=30
+        )
+        voice: Voice = SchemaField(
+            description="Narration voice", default=Voice.EVA, placeholder=Voice.EVA
+        )
+        background_music: AudioTrack = SchemaField(
+            description="Background track",
+            default=AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS,
+        )
+        input_media_urls: list[str] = SchemaField(
+            description="List of image URLs to feature in the advert.", default=[]
+        )
+        use_only_provided_media: bool = SchemaField(
+            description="Restrict visuals to supplied images only.", default=True
+        )
+
+    class Output(BlockSchema):
+        video_url: str = SchemaField(description="URL of the finished advert")
+        error: str = SchemaField(description="Error message on failure")
+
+    async def create_webhook(self) -> tuple[str, str]:
+        """Create a new webhook URL for receiving notifications."""
+        url = "https://webhook.site/token"
+        headers = {"Accept": "application/json", "Content-Type": "application/json"}
+        response = await Requests().post(url, headers=headers)
+        webhook_data = response.json()
+        return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"
+
+    async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
+        """Create a video using the Revid API."""
+        url = "https://www.revid.ai/api/public/v2/render"
+        headers = {"key": api_key.get_secret_value()}
+        response = await Requests().post(url, json=payload, headers=headers)
+        logger.debug(
+            f"API Response Status Code: {response.status}, Content: {response.text}"
+        )
+        return response.json()
+
+    async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
+        """Check the status of a video creation job."""
+        url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
+        headers = {"key": api_key.get_secret_value()}
+        response = await Requests().get(url, headers=headers)
+        return response.json()
+
+    async def wait_for_video(
+        self,
+        api_key: SecretStr,
+        pid: str,
+        max_wait_time: int = 1000,
+    ) -> str:
+        """Wait for video creation to complete and return the video URL."""
+        start_time = time.time()
+        while time.time() - start_time < max_wait_time:
+            status = await self.check_video_status(api_key, pid)
+            logger.debug(f"Video status: {status}")
+
+            if status.get("status") == "ready" and "videoUrl" in status:
+                return status["videoUrl"]
+            elif status.get("status") == "error":
+                error_message = status.get("error", "Unknown error occurred")
+                logger.error(f"Video creation failed: {error_message}")
+                raise ValueError(f"Video creation failed: {error_message}")
+            elif status.get("status") in ["FAILED", "CANCELED"]:
+                logger.error(f"Video creation failed: {status.get('message')}")
+                raise ValueError(f"Video creation failed: {status.get('message')}")
+
+            await asyncio.sleep(10)
+
+        logger.error("Video creation timed out")
+        raise TimeoutError("Video creation timed out")
+
+    def __init__(self):
+        super().__init__(
+            id="58bd2a19-115d-4fd1-8ca4-13b9e37fa6a0",
+            description="Creates an AI‑generated 30‑second advert (text + images)",
+            categories={BlockCategory.MARKETING, BlockCategory.AI},
+            input_schema=AIAdMakerVideoCreatorBlock.Input,
+            output_schema=AIAdMakerVideoCreatorBlock.Output,
+            test_input={
+                "credentials": TEST_CREDENTIALS_INPUT,
+                "script": "Test product launch!",
+                "input_media_urls": [
+                    "https://cdn.revid.ai/uploads/1747076315114-image.png",
+                ],
+            },
+            test_output=("video_url", "https://example.com/ad.mp4"),
+            test_mock={
+                "create_webhook": lambda *args, **kwargs: (
+                    "test_uuid",
+                    "https://webhook.site/test_uuid",
+                ),
+                "create_video": lambda *args, **kwargs: {"pid": "test_pid"},
+                "check_video_status": lambda *args, **kwargs: {
+                    "status": "ready",
+                    "videoUrl": "https://example.com/ad.mp4",
+                },
+                "wait_for_video": lambda *args, **kwargs: "https://example.com/ad.mp4",
+            },
+            test_credentials=TEST_CREDENTIALS,
+        )
+
+    async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
+        webhook_token, webhook_url = await self.create_webhook()
+
+        payload = {
+            "webhook": webhook_url,
+            "creationParams": {
+                "targetDuration": input_data.target_duration,
+                "ratio": input_data.ratio,
+                "mediaType": "aiVideo",
+                "inputText": input_data.script,
+                "flowType": "text-to-video",
+                "slug": "ai-ad-generator",
+                "slugNew": "",
+                "isCopiedFrom": False,
+                "hasToGenerateVoice": True,
+                "hasToTranscript": False,
+                "hasToSearchMedia": True,
+                "hasAvatar": False,
+                "hasWebsiteRecorder": False,
+                "hasTextSmallAtBottom": False,
+                "selectedAudio": input_data.background_music.value,
+                "selectedVoice": input_data.voice.voice_id,
+                "selectedAvatar": "https://cdn.revid.ai/avatars/young-woman.mp4",
+                "selectedAvatarType": "video/mp4",
+                "websiteToRecord": "",
+                "hasToGenerateCover": True,
+                "nbGenerations": 1,
+                "disableCaptions": False,
+                "mediaMultiplier": "medium",
+                "characters": [],
+                "captionPresetName": "Revid",
+                "sourceType": "contentScraping",
+                "selectedStoryStyle": {"value": "custom", "label": "General"},
+                "generationPreset": "DEFAULT",
+                "hasToGenerateMusic": False,
+                "isOptimizedForChinese": False,
+                "generationUserPrompt": "",
+                "enableNsfwFilter": False,
+                "addStickers": False,
+                "typeMovingImageAnim": "dynamic",
+                "hasToGenerateSoundEffects": False,
+                "forceModelType": "gpt-image-1",
+                "selectedCharacters": [],
+                "lang": "",
+                "voiceSpeed": 1,
+                "disableAudio": False,
+                "disableVoice": False,
+                "useOnlyProvidedMedia": input_data.use_only_provided_media,
+                "imageGenerationModel": "ultra",
+                "videoGenerationModel": "pro",
+                "hasEnhancedGeneration": True,
+                "hasEnhancedGenerationPro": True,
+                "inputMedias": [
+                    {"url": url, "title": "", "type": "image"}
+                    for url in input_data.input_media_urls
+                ],
+                "hasToGenerateVideos": True,
+                "audioUrl": input_data.background_music.audio_url,
+                "watermark": None,
+            },
+        }
+
+        response = await self.create_video(credentials.api_key, payload)
+        pid = response.get("pid")
+        if not pid:
+            raise RuntimeError("Failed to create video: No project ID returned")
+
+        video_url = await self.wait_for_video(credentials.api_key, pid)
+        yield "video_url", video_url
+
+
+class AIScreenshotToVideoAdBlock(Block):
+    """Creates an advert where the supplied screenshot is narrated by an AI avatar."""
+
+    class Input(BlockSchema):
+        credentials: CredentialsMetaInput[
+            Literal[ProviderName.REVID], Literal["api_key"]
+        ] = CredentialsField(description="Revid.ai API key")
+        script: str = SchemaField(
+            description="Narration that will accompany the screenshot.",
+            placeholder="Check out these amazing stats!",
+        )
+        screenshot_url: str = SchemaField(
+            description="Screenshot or image URL to showcase."
+        )
+        ratio: str = SchemaField(default="9 / 16")
+        target_duration: int = SchemaField(default=30)
+        voice: Voice = SchemaField(default=Voice.EVA)
+        background_music: AudioTrack = SchemaField(
+            default=AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS
+        )
+
+    class Output(BlockSchema):
+        video_url: str = SchemaField(description="Rendered video URL")
+        error: str = SchemaField(description="Error, if encountered")
+
+    async def create_webhook(self) -> tuple[str, str]:
+        """Create a new webhook URL for receiving notifications."""
+        url = "https://webhook.site/token"
+        headers = {"Accept": "application/json", "Content-Type": "application/json"}
+        response = await Requests().post(url, headers=headers)
+        webhook_data = response.json()
+        return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"
+
+    async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
+        """Create a video using the Revid API."""
+        url = "https://www.revid.ai/api/public/v2/render"
+        headers = {"key": api_key.get_secret_value()}
+        response = await Requests().post(url, json=payload, headers=headers)
+        logger.debug(
+            f"API Response Status Code: {response.status}, Content: {response.text}"
+        )
+        return response.json()
+
+    async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
+        """Check the status of a video creation job."""
+        url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
+        headers = {"key": api_key.get_secret_value()}
+        response = await Requests().get(url, headers=headers)
+        return response.json()
+
+    async def wait_for_video(
+        self,
+        api_key: SecretStr,
+        pid: str,
+        max_wait_time: int = 1000,
+    ) -> str:
+        """Wait for video creation to complete and return the video URL."""
+        start_time = time.time()
+        while time.time() - start_time < max_wait_time:
+            status = await self.check_video_status(api_key, pid)
+            logger.debug(f"Video status: {status}")
+
+            if status.get("status") == "ready" and "videoUrl" in status:
+                return status["videoUrl"]
+            elif status.get("status") == "error":
+                error_message = status.get("error", "Unknown error occurred")
+                logger.error(f"Video creation failed: {error_message}")
+                raise ValueError(f"Video creation failed: {error_message}")
+            elif status.get("status") in ["FAILED", "CANCELED"]:
+                logger.error(f"Video creation failed: {status.get('message')}")
+                raise ValueError(f"Video creation failed: {status.get('message')}")
+
+            await asyncio.sleep(10)
+
+        logger.error("Video creation timed out")
+        raise TimeoutError("Video creation timed out")
+
+    def __init__(self):
+        super().__init__(
+            id="0f3e4635-e810-43d9-9e81-49e6f4e83b7c",
+            description="Turns a screenshot into an engaging, avatar‑narrated video advert.",
+            categories={BlockCategory.AI, BlockCategory.MARKETING},
+            input_schema=AIScreenshotToVideoAdBlock.Input,
+            output_schema=AIScreenshotToVideoAdBlock.Output,
+            test_input={
+                "credentials": TEST_CREDENTIALS_INPUT,
+                "script": "Amazing numbers!",
+                "screenshot_url": "https://cdn.revid.ai/uploads/1747080376028-image.png",
+            },
+            test_output=("video_url", "https://example.com/screenshot.mp4"),
+            test_mock={
+                "create_webhook": lambda *args, **kwargs: (
+                    "test_uuid",
+                    "https://webhook.site/test_uuid",
+                ),
+                "create_video": lambda *args, **kwargs: {"pid": "test_pid"},
+                "check_video_status": lambda *args, **kwargs: {
+                    "status": "ready",
+                    "videoUrl": "https://example.com/screenshot.mp4",
+                },
+                "wait_for_video": lambda *args, **kwargs: "https://example.com/screenshot.mp4",
+            },
+            test_credentials=TEST_CREDENTIALS,
+        )
+
+    async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
+        webhook_token, webhook_url = await self.create_webhook()
+
+        payload = {
+            "webhook": webhook_url,
+            "creationParams": {
+                "targetDuration": input_data.target_duration,
+                "ratio": input_data.ratio,
+                "mediaType": "aiVideo",
+                "hasAvatar": True,
+                "removeAvatarBackground": True,
+                "inputText": input_data.script,
+                "flowType": "text-to-video",
+                "slug": "ai-ad-generator",
+                "slugNew": "screenshot-to-video-ad",
+                "isCopiedFrom": "ai-ad-generator",
+                "hasToGenerateVoice": True,
+                "hasToTranscript": False,
+                "hasToSearchMedia": True,
+                "hasWebsiteRecorder": False,
+                "hasTextSmallAtBottom": False,
+                "selectedAudio": input_data.background_music.value,
+                "selectedVoice": input_data.voice.voice_id,
+                "selectedAvatar": "https://cdn.revid.ai/avatars/young-woman.mp4",
+                "selectedAvatarType": "video/mp4",
+                "websiteToRecord": "",
+                "hasToGenerateCover": True,
+                "nbGenerations": 1,
+                "disableCaptions": False,
+                "mediaMultiplier": "medium",
+                "characters": [],
+                "captionPresetName": "Revid",
+                "sourceType": "contentScraping",
+                "selectedStoryStyle": {"value": "custom", "label": "General"},
+                "generationPreset": "DEFAULT",
+                "hasToGenerateMusic": False,
+                "isOptimizedForChinese": False,
+                "generationUserPrompt": "",
+                "enableNsfwFilter": False,
+                "addStickers": False,
+                "typeMovingImageAnim": "dynamic",
+                "hasToGenerateSoundEffects": False,
+                "forceModelType": "gpt-image-1",
+                "selectedCharacters": [],
+                "lang": "",
+                "voiceSpeed": 1,
+                "disableAudio": False,
+                "disableVoice": False,
+                "useOnlyProvidedMedia": True,
+                "imageGenerationModel": "ultra",
+                "videoGenerationModel": "ultra",
+                "hasEnhancedGeneration": True,
+                "hasEnhancedGenerationPro": True,
+                "inputMedias": [
+                    {"url": input_data.screenshot_url, "title": "", "type": "image"}
+                ],
+                "hasToGenerateVideos": True,
+                "audioUrl": input_data.background_music.audio_url,
+                "watermark": None,
+            },
+        }
+
+        response = await self.create_video(credentials.api_key, payload)
+        pid = response.get("pid")
+        if not pid:
+            raise RuntimeError("Failed to create video: No project ID returned")
+
+        video_url = await self.wait_for_video(credentials.api_key, pid)
+        yield "video_url", video_url
@@ -4,6 +4,7 @@ from typing import List
 from backend.blocks.apollo._auth import ApolloCredentials
 from backend.blocks.apollo.models import (
     Contact,
     EnrichPersonRequest,
     Organization,
     SearchOrganizationsRequest,
     SearchOrganizationsResponse,
@@ -110,3 +111,21 @@ class ApolloClient:
         return (
             organizations[: query.max_results] if query.max_results else organizations
         )
+
+    async def enrich_person(self, query: EnrichPersonRequest) -> Contact:
+        """Enrich a person's data including email & phone reveal"""
+        response = await self.requests.post(
+            f"{self.API_URL}/people/match",
+            headers=self._get_headers(),
+            json=query.model_dump(),
+            params={
+                "reveal_personal_emails": "true",
+            },
+        )
+        data = response.json()
+        if "person" not in data:
+            raise ValueError(f"Person not found or enrichment failed: {data}")
+
+        contact = Contact(**data["person"])
+        contact.email = contact.email or "-"
+        return contact
@@ -23,9 +23,9 @@ class BaseModel(OriginalBaseModel):
 class PrimaryPhone(BaseModel):
     """A primary phone in Apollo"""

-    number: str = ""
-    source: str = ""
-    sanitized_number: str = ""
+    number: Optional[str] = ""
+    source: Optional[str] = ""
+    sanitized_number: Optional[str] = ""
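The motivation for the `Optional[...]` migration is easiest to see in isolation: the Apollo API can return explicit nulls, and a plain `str = ""` field rejects them at validation time. A minimal pydantic sketch mirroring the `PrimaryPhone` fields from the diff:

```python
from typing import Optional
from pydantic import BaseModel

class PrimaryPhone(BaseModel):
    number: Optional[str] = ""
    source: Optional[str] = ""
    sanitized_number: Optional[str] = ""

# An explicit null from the API now validates instead of raising:
print(PrimaryPhone(number=None, source="crm"))
```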
class SenorityLevels(str, Enum):
|
||||
@@ -56,102 +56,102 @@ class ContactEmailStatuses(str, Enum):
 class RuleConfigStatus(BaseModel):
     """A rule config status in Apollo"""

-    _id: str = ""
-    created_at: str = ""
-    rule_action_config_id: str = ""
-    rule_config_id: str = ""
-    status_cd: str = ""
-    updated_at: str = ""
-    id: str = ""
-    key: str = ""
+    _id: Optional[str] = ""
+    created_at: Optional[str] = ""
+    rule_action_config_id: Optional[str] = ""
+    rule_config_id: Optional[str] = ""
+    status_cd: Optional[str] = ""
+    updated_at: Optional[str] = ""
+    id: Optional[str] = ""
+    key: Optional[str] = ""


 class ContactCampaignStatus(BaseModel):
     """A contact campaign status in Apollo"""

-    id: str = ""
-    emailer_campaign_id: str = ""
-    send_email_from_user_id: str = ""
-    inactive_reason: str = ""
-    status: str = ""
-    added_at: str = ""
-    added_by_user_id: str = ""
-    finished_at: str = ""
-    paused_at: str = ""
-    auto_unpause_at: str = ""
-    send_email_from_email_address: str = ""
-    send_email_from_email_account_id: str = ""
-    manually_set_unpause: str = ""
-    failure_reason: str = ""
-    current_step_id: str = ""
-    in_response_to_emailer_message_id: str = ""
-    cc_emails: str = ""
-    bcc_emails: str = ""
-    to_emails: str = ""
+    id: Optional[str] = ""
+    emailer_campaign_id: Optional[str] = ""
+    send_email_from_user_id: Optional[str] = ""
+    inactive_reason: Optional[str] = ""
+    status: Optional[str] = ""
+    added_at: Optional[str] = ""
+    added_by_user_id: Optional[str] = ""
+    finished_at: Optional[str] = ""
+    paused_at: Optional[str] = ""
+    auto_unpause_at: Optional[str] = ""
+    send_email_from_email_address: Optional[str] = ""
+    send_email_from_email_account_id: Optional[str] = ""
+    manually_set_unpause: Optional[str] = ""
+    failure_reason: Optional[str] = ""
+    current_step_id: Optional[str] = ""
+    in_response_to_emailer_message_id: Optional[str] = ""
+    cc_emails: Optional[str] = ""
+    bcc_emails: Optional[str] = ""
+    to_emails: Optional[str] = ""


 class Account(BaseModel):
     """An account in Apollo"""

-    id: str = ""
-    name: str = ""
-    website_url: str = ""
-    blog_url: str = ""
-    angellist_url: str = ""
-    linkedin_url: str = ""
-    twitter_url: str = ""
-    facebook_url: str = ""
-    primary_phone: PrimaryPhone = PrimaryPhone()
-    languages: list[str]
-    alexa_ranking: int = 0
-    phone: str = ""
-    linkedin_uid: str = ""
-    founded_year: int = 0
-    publicly_traded_symbol: str = ""
-    publicly_traded_exchange: str = ""
-    logo_url: str = ""
-    chrunchbase_url: str = ""
-    primary_domain: str = ""
-    domain: str = ""
-    team_id: str = ""
-    organization_id: str = ""
-    account_stage_id: str = ""
-    source: str = ""
-    original_source: str = ""
-    creator_id: str = ""
-    owner_id: str = ""
-    created_at: str = ""
-    phone_status: str = ""
-    hubspot_id: str = ""
-    salesforce_id: str = ""
-    crm_owner_id: str = ""
-    parent_account_id: str = ""
-    sanitized_phone: str = ""
+    id: Optional[str] = ""
+    name: Optional[str] = ""
+    website_url: Optional[str] = ""
+    blog_url: Optional[str] = ""
+    angellist_url: Optional[str] = ""
+    linkedin_url: Optional[str] = ""
+    twitter_url: Optional[str] = ""
+    facebook_url: Optional[str] = ""
+    primary_phone: Optional[PrimaryPhone] = PrimaryPhone()
+    languages: Optional[list[str]] = []
+    alexa_ranking: Optional[int] = 0
+    phone: Optional[str] = ""
+    linkedin_uid: Optional[str] = ""
+    founded_year: Optional[int] = 0
+    publicly_traded_symbol: Optional[str] = ""
+    publicly_traded_exchange: Optional[str] = ""
+    logo_url: Optional[str] = ""
+    chrunchbase_url: Optional[str] = ""
+    primary_domain: Optional[str] = ""
+    domain: Optional[str] = ""
+    team_id: Optional[str] = ""
+    organization_id: Optional[str] = ""
+    account_stage_id: Optional[str] = ""
+    source: Optional[str] = ""
+    original_source: Optional[str] = ""
+    creator_id: Optional[str] = ""
+    owner_id: Optional[str] = ""
+    created_at: Optional[str] = ""
+    phone_status: Optional[str] = ""
+    hubspot_id: Optional[str] = ""
+    salesforce_id: Optional[str] = ""
+    crm_owner_id: Optional[str] = ""
+    parent_account_id: Optional[str] = ""
+    sanitized_phone: Optional[str] = ""
     # no listed type on the API docs
-    account_playbook_statues: list[Any] = []
-    account_rule_config_statuses: list[RuleConfigStatus] = []
-    existence_level: str = ""
-    label_ids: list[str] = []
-    typed_custom_fields: Any
-    custom_field_errors: Any
-    modality: str = ""
-    source_display_name: str = ""
-    salesforce_record_id: str = ""
-    crm_record_url: str = ""
+    account_playbook_statues: Optional[list[Any]] = []
+    account_rule_config_statuses: Optional[list[RuleConfigStatus]] = []
+    existence_level: Optional[str] = ""
+    label_ids: Optional[list[str]] = []
+    typed_custom_fields: Optional[Any] = {}
+    custom_field_errors: Optional[Any] = {}
+    modality: Optional[str] = ""
+    source_display_name: Optional[str] = ""
+    salesforce_record_id: Optional[str] = ""
+    crm_record_url: Optional[str] = ""


 class ContactEmail(BaseModel):
     """A contact email in Apollo"""

-    email: str = ""
-    email_md5: str = ""
-    email_sha256: str = ""
-    email_status: str = ""
-    email_source: str = ""
-    extrapolated_email_confidence: str = ""
-    position: int = 0
-    email_from_customer: str = ""
-    free_domain: bool = True
+    email: Optional[str] = ""
+    email_md5: Optional[str] = ""
+    email_sha256: Optional[str] = ""
+    email_status: Optional[str] = ""
+    email_source: Optional[str] = ""
+    extrapolated_email_confidence: Optional[str] = ""
+    position: Optional[int] = 0
+    email_from_customer: Optional[str] = ""
+    free_domain: Optional[bool] = True

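A note on why this hunk loosens every field to Optional[...] while keeping a non-None default: Apollo payloads both omit fields and send explicit nulls, and under Pydantic v2 (which the model_dump/populate_by_name usage elsewhere in this diff implies) a plain id: str = "" still rejects an explicit null. A minimal sketch; Strict and Loose are hypothetical stand-ins for the before/after field shapes:

from typing import Optional
from pydantic import BaseModel, ValidationError

class Strict(BaseModel):    # field shape before this change
    id: str = ""

class Loose(BaseModel):     # field shape after this change
    id: Optional[str] = ""

Loose()           # ok: id == ""   (field omitted)
Loose(id=None)    # ok: id is None (explicit null accepted)
try:
    Strict(id=None)           # ValidationError: None is not a valid str
except ValidationError as e:
    print(e.error_count())    # 1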
 class EmploymentHistory(BaseModel):
@@ -164,40 +164,40 @@ class EmploymentHistory(BaseModel):
         populate_by_name=True,
     )

-    _id: Optional[str] = None
-    created_at: Optional[str] = None
-    current: Optional[bool] = None
-    degree: Optional[str] = None
-    description: Optional[str] = None
-    emails: Optional[str] = None
-    end_date: Optional[str] = None
-    grade_level: Optional[str] = None
-    kind: Optional[str] = None
-    major: Optional[str] = None
-    organization_id: Optional[str] = None
-    organization_name: Optional[str] = None
-    raw_address: Optional[str] = None
-    start_date: Optional[str] = None
-    title: Optional[str] = None
-    updated_at: Optional[str] = None
-    id: Optional[str] = None
-    key: Optional[str] = None
+    _id: Optional[str] = ""
+    created_at: Optional[str] = ""
+    current: Optional[bool] = False
+    degree: Optional[str] = ""
+    description: Optional[str] = ""
+    emails: Optional[str] = ""
+    end_date: Optional[str] = ""
+    grade_level: Optional[str] = ""
+    kind: Optional[str] = ""
+    major: Optional[str] = ""
+    organization_id: Optional[str] = ""
+    organization_name: Optional[str] = ""
+    raw_address: Optional[str] = ""
+    start_date: Optional[str] = ""
+    title: Optional[str] = ""
+    updated_at: Optional[str] = ""
+    id: Optional[str] = ""
+    key: Optional[str] = ""


 class Breadcrumb(BaseModel):
     """A breadcrumb in Apollo"""

-    label: Optional[str] = "N/A"
-    signal_field_name: Optional[str] = "N/A"
-    value: str | list | None = "N/A"
-    display_name: Optional[str] = "N/A"
+    label: Optional[str] = ""
+    signal_field_name: Optional[str] = ""
+    value: str | list | None = ""
+    display_name: Optional[str] = ""


 class TypedCustomField(BaseModel):
     """A typed custom field in Apollo"""

-    id: Optional[str] = "N/A"
-    value: Optional[str] = "N/A"
+    id: Optional[str] = ""
+    value: Optional[str] = ""


 class Pagination(BaseModel):
@@ -219,23 +219,23 @@ class Pagination(BaseModel):
 class DialerFlags(BaseModel):
     """A dialer flags in Apollo"""

-    country_name: str = ""
-    country_enabled: bool
-    high_risk_calling_enabled: bool
-    potential_high_risk_number: bool
+    country_name: Optional[str] = ""
+    country_enabled: Optional[bool] = True
+    high_risk_calling_enabled: Optional[bool] = True
+    potential_high_risk_number: Optional[bool] = True


 class PhoneNumber(BaseModel):
     """A phone number in Apollo"""

-    raw_number: str = ""
-    sanitized_number: str = ""
-    type: str = ""
-    position: int = 0
-    status: str = ""
-    dnc_status: str = ""
-    dnc_other_info: str = ""
-    dailer_flags: DialerFlags = DialerFlags(
+    raw_number: Optional[str] = ""
+    sanitized_number: Optional[str] = ""
+    type: Optional[str] = ""
+    position: Optional[int] = 0
+    status: Optional[str] = ""
+    dnc_status: Optional[str] = ""
+    dnc_other_info: Optional[str] = ""
+    dailer_flags: Optional[DialerFlags] = DialerFlags(
         country_name="",
         country_enabled=True,
         high_risk_calling_enabled=True,
@@ -253,33 +253,31 @@ class Organization(BaseModel):
         populate_by_name=True,
     )

-    id: Optional[str] = "N/A"
-    name: Optional[str] = "N/A"
-    website_url: Optional[str] = "N/A"
-    blog_url: Optional[str] = "N/A"
-    angellist_url: Optional[str] = "N/A"
-    linkedin_url: Optional[str] = "N/A"
-    twitter_url: Optional[str] = "N/A"
-    facebook_url: Optional[str] = "N/A"
-    primary_phone: Optional[PrimaryPhone] = PrimaryPhone(
-        number="N/A", source="N/A", sanitized_number="N/A"
-    )
-    languages: list[str] = []
+    id: Optional[str] = ""
+    name: Optional[str] = ""
+    website_url: Optional[str] = ""
+    blog_url: Optional[str] = ""
+    angellist_url: Optional[str] = ""
+    linkedin_url: Optional[str] = ""
+    twitter_url: Optional[str] = ""
+    facebook_url: Optional[str] = ""
+    primary_phone: Optional[PrimaryPhone] = PrimaryPhone()
+    languages: Optional[list[str]] = []
     alexa_ranking: Optional[int] = 0
-    phone: Optional[str] = "N/A"
-    linkedin_uid: Optional[str] = "N/A"
+    phone: Optional[str] = ""
+    linkedin_uid: Optional[str] = ""
     founded_year: Optional[int] = 0
-    publicly_traded_symbol: Optional[str] = "N/A"
-    publicly_traded_exchange: Optional[str] = "N/A"
-    logo_url: Optional[str] = "N/A"
-    chrunchbase_url: Optional[str] = "N/A"
-    primary_domain: Optional[str] = "N/A"
-    sanitized_phone: Optional[str] = "N/A"
-    owned_by_organization_id: Optional[str] = "N/A"
-    intent_strength: Optional[str] = "N/A"
-    show_intent: bool = True
+    publicly_traded_symbol: Optional[str] = ""
+    publicly_traded_exchange: Optional[str] = ""
+    logo_url: Optional[str] = ""
+    chrunchbase_url: Optional[str] = ""
+    primary_domain: Optional[str] = ""
+    sanitized_phone: Optional[str] = ""
+    owned_by_organization_id: Optional[str] = ""
+    intent_strength: Optional[str] = ""
+    show_intent: Optional[bool] = True
     has_intent_signal_account: Optional[bool] = True
-    intent_signal_account: Optional[str] = "N/A"
+    intent_signal_account: Optional[str] = ""

 class Contact(BaseModel):
@@ -292,95 +290,95 @@ class Contact(BaseModel):
         populate_by_name=True,
     )

-    contact_roles: list[Any] = []
-    id: Optional[str] = None
-    first_name: Optional[str] = None
-    last_name: Optional[str] = None
-    name: Optional[str] = None
-    linkedin_url: Optional[str] = None
-    title: Optional[str] = None
-    contact_stage_id: Optional[str] = None
-    owner_id: Optional[str] = None
-    creator_id: Optional[str] = None
-    person_id: Optional[str] = None
-    email_needs_tickling: bool = True
-    organization_name: Optional[str] = None
-    source: Optional[str] = None
-    original_source: Optional[str] = None
-    organization_id: Optional[str] = None
-    headline: Optional[str] = None
-    photo_url: Optional[str] = None
-    present_raw_address: Optional[str] = None
-    linkededin_uid: Optional[str] = None
-    extrapolated_email_confidence: Optional[float] = None
-    salesforce_id: Optional[str] = None
-    salesforce_lead_id: Optional[str] = None
-    salesforce_contact_id: Optional[str] = None
-    saleforce_account_id: Optional[str] = None
-    crm_owner_id: Optional[str] = None
-    created_at: Optional[str] = None
-    emailer_campaign_ids: list[str] = []
-    direct_dial_status: Optional[str] = None
-    direct_dial_enrichment_failed_at: Optional[str] = None
-    email_status: Optional[str] = None
-    email_source: Optional[str] = None
-    account_id: Optional[str] = None
-    last_activity_date: Optional[str] = None
-    hubspot_vid: Optional[str] = None
-    hubspot_company_id: Optional[str] = None
-    crm_id: Optional[str] = None
-    sanitized_phone: Optional[str] = None
-    merged_crm_ids: Optional[str] = None
-    updated_at: Optional[str] = None
-    queued_for_crm_push: bool = True
-    suggested_from_rule_engine_config_id: Optional[str] = None
-    email_unsubscribed: Optional[str] = None
-    label_ids: list[Any] = []
-    has_pending_email_arcgate_request: bool = True
-    has_email_arcgate_request: bool = True
-    existence_level: Optional[str] = None
-    email: Optional[str] = None
-    email_from_customer: Optional[str] = None
-    typed_custom_fields: list[TypedCustomField] = []
-    custom_field_errors: Any = None
-    salesforce_record_id: Optional[str] = None
-    crm_record_url: Optional[str] = None
-    email_status_unavailable_reason: Optional[str] = None
-    email_true_status: Optional[str] = None
-    updated_email_true_status: bool = True
-    contact_rule_config_statuses: list[RuleConfigStatus] = []
-    source_display_name: Optional[str] = None
-    twitter_url: Optional[str] = None
-    contact_campaign_statuses: list[ContactCampaignStatus] = []
-    state: Optional[str] = None
-    city: Optional[str] = None
-    country: Optional[str] = None
-    account: Optional[Account] = None
-    contact_emails: list[ContactEmail] = []
-    organization: Optional[Organization] = None
-    employment_history: list[EmploymentHistory] = []
-    time_zone: Optional[str] = None
-    intent_strength: Optional[str] = None
-    show_intent: bool = True
-    phone_numbers: list[PhoneNumber] = []
-    account_phone_note: Optional[str] = None
-    free_domain: bool = True
-    is_likely_to_engage: bool = True
-    email_domain_catchall: bool = True
-    contact_job_change_event: Optional[str] = None
+    contact_roles: Optional[list[Any]] = []
+    id: Optional[str] = ""
+    first_name: Optional[str] = ""
+    last_name: Optional[str] = ""
+    name: Optional[str] = ""
+    linkedin_url: Optional[str] = ""
+    title: Optional[str] = ""
+    contact_stage_id: Optional[str] = ""
+    owner_id: Optional[str] = ""
+    creator_id: Optional[str] = ""
+    person_id: Optional[str] = ""
+    email_needs_tickling: Optional[bool] = True
+    organization_name: Optional[str] = ""
+    source: Optional[str] = ""
+    original_source: Optional[str] = ""
+    organization_id: Optional[str] = ""
+    headline: Optional[str] = ""
+    photo_url: Optional[str] = ""
+    present_raw_address: Optional[str] = ""
+    linkededin_uid: Optional[str] = ""
+    extrapolated_email_confidence: Optional[float] = 0.0
+    salesforce_id: Optional[str] = ""
+    salesforce_lead_id: Optional[str] = ""
+    salesforce_contact_id: Optional[str] = ""
+    saleforce_account_id: Optional[str] = ""
+    crm_owner_id: Optional[str] = ""
+    created_at: Optional[str] = ""
+    emailer_campaign_ids: Optional[list[str]] = []
+    direct_dial_status: Optional[str] = ""
+    direct_dial_enrichment_failed_at: Optional[str] = ""
+    email_status: Optional[str] = ""
+    email_source: Optional[str] = ""
+    account_id: Optional[str] = ""
+    last_activity_date: Optional[str] = ""
+    hubspot_vid: Optional[str] = ""
+    hubspot_company_id: Optional[str] = ""
+    crm_id: Optional[str] = ""
+    sanitized_phone: Optional[str] = ""
+    merged_crm_ids: Optional[str] = ""
+    updated_at: Optional[str] = ""
+    queued_for_crm_push: Optional[bool] = True
+    suggested_from_rule_engine_config_id: Optional[str] = ""
+    email_unsubscribed: Optional[str] = ""
+    label_ids: Optional[list[Any]] = []
+    has_pending_email_arcgate_request: Optional[bool] = True
+    has_email_arcgate_request: Optional[bool] = True
+    existence_level: Optional[str] = ""
+    email: Optional[str] = ""
+    email_from_customer: Optional[str] = ""
+    typed_custom_fields: Optional[list[TypedCustomField]] = []
+    custom_field_errors: Optional[Any] = {}
+    salesforce_record_id: Optional[str] = ""
+    crm_record_url: Optional[str] = ""
+    email_status_unavailable_reason: Optional[str] = ""
+    email_true_status: Optional[str] = ""
+    updated_email_true_status: Optional[bool] = True
+    contact_rule_config_statuses: Optional[list[RuleConfigStatus]] = []
+    source_display_name: Optional[str] = ""
+    twitter_url: Optional[str] = ""
+    contact_campaign_statuses: Optional[list[ContactCampaignStatus]] = []
+    state: Optional[str] = ""
+    city: Optional[str] = ""
+    country: Optional[str] = ""
+    account: Optional[Account] = Account()
+    contact_emails: Optional[list[ContactEmail]] = []
+    organization: Optional[Organization] = Organization()
+    employment_history: Optional[list[EmploymentHistory]] = []
+    time_zone: Optional[str] = ""
+    intent_strength: Optional[str] = ""
+    show_intent: Optional[bool] = True
+    phone_numbers: Optional[list[PhoneNumber]] = []
+    account_phone_note: Optional[str] = ""
+    free_domain: Optional[bool] = True
+    is_likely_to_engage: Optional[bool] = True
+    email_domain_catchall: Optional[bool] = True
+    contact_job_change_event: Optional[str] = ""

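Taken together, the relaxed Contact model means a sparse or null-laden search hit now validates instead of raising. A quick sketch with illustrative values (not part of the diff):

payload = {"id": "123", "first_name": "Ada", "email": None, "account": None}
contact = Contact.model_validate(payload)   # no ValidationError
print(contact.email)       # None: the explicit null is kept
print(contact.last_name)   # "": a missing field falls back to its default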
 class SearchOrganizationsRequest(BaseModel):
     """Request for Apollo's search organizations API"""

-    organization_num_empoloyees_range: list[int] = SchemaField(
+    organization_num_employees_range: Optional[list[int]] = SchemaField(
         description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.

Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
         default=[0, 1000000],
     )

-    organization_locations: list[str] = SchemaField(
+    organization_locations: Optional[list[str]] = SchemaField(
         description="""The location of the company headquarters. You can search across cities, US states, and countries.

If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, any Boston-based companies will not appear in your search results, even if they match other parameters.
@@ -389,28 +387,30 @@ To exclude companies based on location, use the organization_not_locations param
 """,
         default_factory=list,
     )
-    organizations_not_locations: list[str] = SchemaField(
+    organizations_not_locations: Optional[list[str]] = SchemaField(
         description="""Exclude companies from search results based on the location of the company headquarters. You can use cities, US states, and countries as locations to exclude.

This parameter is useful for ensuring you do not prospect in an undesirable territory. For example, if you use ireland as a value, no Ireland-based companies will appear in your search results.
 """,
         default_factory=list,
     )
-    q_organization_keyword_tags: list[str] = SchemaField(
-        description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry."""
+    q_organization_keyword_tags: Optional[list[str]] = SchemaField(
+        description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry.""",
+        default_factory=list,
     )
-    q_organization_name: str = SchemaField(
+    q_organization_name: Optional[str] = SchemaField(
         description="""Filter search results to include a specific company name.

-If the value you enter for this parameter does not match with a company's name, the company will not appear in search results, even if it matches other parameters. Partial matches are accepted. For example, if you filter by the value marketing, a company called NY Marketing Unlimited would still be eligible as a search result, but NY Market Analysis would not be eligible."""
+If the value you enter for this parameter does not match with a company's name, the company will not appear in search results, even if it matches other parameters. Partial matches are accepted. For example, if you filter by the value marketing, a company called NY Marketing Unlimited would still be eligible as a search result, but NY Market Analysis would not be eligible.""",
+        default="",
     )
-    organization_ids: list[str] = SchemaField(
+    organization_ids: Optional[list[str]] = SchemaField(
         description="""The Apollo IDs for the companies you want to include in your search results. Each company in the Apollo database is assigned a unique ID.

To find IDs, identify the values for organization_id when you call this endpoint.""",
         default_factory=list,
     )
-    max_results: int = SchemaField(
+    max_results: Optional[int] = SchemaField(
         description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
         default=100,
         ge=1,
@@ -435,11 +435,11 @@ Use the page parameter to search the different pages of data.""",
 class SearchOrganizationsResponse(BaseModel):
     """Response from Apollo's search organizations API"""

-    breadcrumbs: list[Breadcrumb] = []
-    partial_results_only: bool = True
-    has_join: bool = True
-    disable_eu_prospecting: bool = True
-    partial_results_limit: int = 0
+    breadcrumbs: Optional[list[Breadcrumb]] = []
+    partial_results_only: Optional[bool] = True
+    has_join: Optional[bool] = True
+    disable_eu_prospecting: Optional[bool] = True
+    partial_results_limit: Optional[int] = 0
     pagination: Pagination = Pagination(
         page=0, per_page=0, total_entries=0, total_pages=0
     )
@@ -447,14 +447,14 @@ class SearchOrganizationsResponse(BaseModel):
     accounts: list[Any] = []
     organizations: list[Organization] = []
     models_ids: list[str] = []
-    num_fetch_result: Optional[str] = "N/A"
-    derived_params: Optional[str] = "N/A"
+    num_fetch_result: Optional[str] = ""
+    derived_params: Optional[str] = ""

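Since every filter on SearchOrganizationsRequest now carries a default, a request can be built from only the filters that matter and the rest fall back to empty values. A small usage sketch; the filter values are illustrative:

req = SearchOrganizationsRequest(
    organization_locations=["chicago"],
    q_organization_keyword_tags=["mining"],
    max_results=50,
)
print(req.q_organization_name)   # "" (unset filters keep their empty defaults)
print(req.organization_ids)      # []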
 class SearchPeopleRequest(BaseModel):
     """Request for Apollo's search people API"""

-    person_titles: list[str] = SchemaField(
+    person_titles: Optional[list[str]] = SchemaField(
         description="""Job titles held by the people you want to find. For a person to be included in search results, they only need to match 1 of the job titles you add. Adding more job titles expands your search results.

Results also include job titles with the same terms, even if they are not exact matches. For example, searching for marketing manager might return people with the job title content marketing manager.
@@ -464,13 +464,13 @@ Use this parameter in combination with the person_seniorities[] parameter to fin
         default_factory=list,
         placeholder="marketing manager",
     )
-    person_locations: list[str] = SchemaField(
+    person_locations: Optional[list[str]] = SchemaField(
         description="""The location where people live. You can search across cities, US states, and countries.

To find people based on the headquarters locations of their current employer, use the organization_locations parameter.""",
         default_factory=list,
     )
-    person_seniorities: list[SenorityLevels] = SchemaField(
+    person_seniorities: Optional[list[SenorityLevels]] = SchemaField(
         description="""The job seniority that people hold within their current employer. This enables you to find people that currently hold positions at certain reporting levels, such as Director level or senior IC level.

For a person to be included in search results, they only need to match 1 of the seniorities you add. Adding more seniorities expands your search results.
@@ -480,7 +480,7 @@ Searches only return results based on their current job title, so searching for
Use this parameter in combination with the person_titles[] parameter to find people based on specific job functions and seniority levels.""",
         default_factory=list,
     )
-    organization_locations: list[str] = SchemaField(
+    organization_locations: Optional[list[str]] = SchemaField(
         description="""The location of the company headquarters for a person's current employer. You can search across cities, US states, and countries.

If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, people that work for the Boston-based company will not appear in your results, even if they match other parameters.
@@ -488,7 +488,7 @@ If a company has several office locations, results are still based on the headqu
To find people based on their personal location, use the person_locations parameter.""",
         default_factory=list,
     )
-    q_organization_domains: list[str] = SchemaField(
+    q_organization_domains: Optional[list[str]] = SchemaField(
         description="""The domain name for the person's employer. This can be the current employer or a previous employer. Do not include www., the @ symbol, or similar.

You can add multiple domains to search across companies.
@@ -496,23 +496,23 @@ You can add multiple domains to search across companies.
Examples: apollo.io and microsoft.com""",
         default_factory=list,
     )
-    contact_email_statuses: list[ContactEmailStatuses] = SchemaField(
+    contact_email_statuses: Optional[list[ContactEmailStatuses]] = SchemaField(
         description="""The email statuses for the people you want to find. You can add multiple statuses to expand your search.""",
         default_factory=list,
     )
-    organization_ids: list[str] = SchemaField(
+    organization_ids: Optional[list[str]] = SchemaField(
         description="""The Apollo IDs for the companies (employers) you want to include in your search results. Each company in the Apollo database is assigned a unique ID.

To find IDs, call the Organization Search endpoint and identify the values for organization_id.""",
         default_factory=list,
     )
-    organization_num_empoloyees_range: list[int] = SchemaField(
+    organization_num_employees_range: Optional[list[int]] = SchemaField(
         description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.

Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
         default_factory=list,
     )
-    q_keywords: str = SchemaField(
+    q_keywords: Optional[str] = SchemaField(
         description="""A string of words over which we want to filter the results""",
         default="",
     )
@@ -528,7 +528,7 @@ Use this parameter in combination with the per_page parameter to make search res
Use the page parameter to search the different pages of data.""",
         default=100,
     )
-    max_results: int = SchemaField(
+    max_results: Optional[int] = SchemaField(
         description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
         default=100,
         ge=1,
@@ -547,16 +547,61 @@ class SearchPeopleResponse(BaseModel):
         populate_by_name=True,
     )

-    breadcrumbs: list[Breadcrumb] = []
-    partial_results_only: bool = True
-    has_join: bool = True
-    disable_eu_prospecting: bool = True
-    partial_results_limit: int = 0
+    breadcrumbs: Optional[list[Breadcrumb]] = []
+    partial_results_only: Optional[bool] = True
+    has_join: Optional[bool] = True
+    disable_eu_prospecting: Optional[bool] = True
+    partial_results_limit: Optional[int] = 0
     pagination: Pagination = Pagination(
         page=0, per_page=0, total_entries=0, total_pages=0
     )
     contacts: list[Contact] = []
     people: list[Contact] = []
     model_ids: list[str] = []
-    num_fetch_result: Optional[str] = "N/A"
-    derived_params: Optional[str] = "N/A"
+    num_fetch_result: Optional[str] = ""
+    derived_params: Optional[str] = ""


+class EnrichPersonRequest(BaseModel):
+    """Request for Apollo's person enrichment API"""
+
+    person_id: Optional[str] = SchemaField(
+        description="Apollo person ID to enrich (most accurate method)",
+        default="",
+    )
+    first_name: Optional[str] = SchemaField(
+        description="First name of the person to enrich",
+        default="",
+    )
+    last_name: Optional[str] = SchemaField(
+        description="Last name of the person to enrich",
+        default="",
+    )
+    name: Optional[str] = SchemaField(
+        description="Full name of the person to enrich",
+        default="",
+    )
+    email: Optional[str] = SchemaField(
+        description="Email address of the person to enrich",
+        default="",
+    )
+    domain: Optional[str] = SchemaField(
+        description="Company domain of the person to enrich",
+        default="",
+    )
+    company: Optional[str] = SchemaField(
+        description="Company name of the person to enrich",
+        default="",
+    )
+    linkedin_url: Optional[str] = SchemaField(
+        description="LinkedIn URL of the person to enrich",
+        default="",
+    )
+    organization_id: Optional[str] = SchemaField(
+        description="Apollo organization ID of the person's company",
+        default="",
+    )
+    title: Optional[str] = SchemaField(
+        description="Job title of the person to enrich",
+        default="",
+    )
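EnrichPersonRequest models the match hints for Apollo's person enrichment as optional strings that default to "". A sketch of building one and trimming the empty defaults before sending; exclude_defaults is standard Pydantic v2, while the trimming itself is an assumption about client behavior, not something shown in this diff:

query = EnrichPersonRequest(
    first_name="John",
    last_name="Doe",
    domain="example.com",
)
params = query.model_dump(exclude_defaults=True)
# {'first_name': 'John', 'last_name': 'Doe', 'domain': 'example.com'}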
@@ -11,14 +11,14 @@ from backend.blocks.apollo.models import (
     SearchOrganizationsRequest,
 )
 from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
-from backend.data.model import SchemaField
+from backend.data.model import CredentialsField, SchemaField


 class SearchOrganizationsBlock(Block):
     """Search for organizations in Apollo"""

     class Input(BlockSchema):
-        organization_num_empoloyees_range: list[int] = SchemaField(
+        organization_num_employees_range: list[int] = SchemaField(
             description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.

Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
@@ -65,7 +65,7 @@ To find IDs, identify the values for organization_id when you call this endpoint
             le=50000,
             advanced=True,
         )
-        credentials: ApolloCredentialsInput = SchemaField(
+        credentials: ApolloCredentialsInput = CredentialsField(
             description="Apollo credentials",
         )

@@ -1,3 +1,5 @@
+import asyncio
+
 from backend.blocks.apollo._api import ApolloClient
 from backend.blocks.apollo._auth import (
     TEST_CREDENTIALS,
@@ -8,11 +10,12 @@ from backend.blocks.apollo._auth import (
 from backend.blocks.apollo.models import (
     Contact,
     ContactEmailStatuses,
+    EnrichPersonRequest,
     SearchPeopleRequest,
     SenorityLevels,
 )
 from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
-from backend.data.model import SchemaField
+from backend.data.model import CredentialsField, SchemaField


 class SearchPeopleBlock(Block):
@@ -77,7 +80,7 @@ class SearchPeopleBlock(Block):
             default_factory=list,
             advanced=False,
         )
-        organization_num_empoloyees_range: list[int] = SchemaField(
+        organization_num_employees_range: list[int] = SchemaField(
            description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.

Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
@@ -90,14 +93,19 @@ class SearchPeopleBlock(Block):
             advanced=False,
         )
         max_results: int = SchemaField(
-            description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
-            default=100,
+            description="""The maximum number of results to return. If you don't specify this parameter, the default is 25. Limited to 500 to prevent overspending.""",
+            default=25,
             ge=1,
-            le=50000,
+            le=500,
             advanced=True,
         )
+        enrich_info: bool = SchemaField(
+            description="""Whether to enrich contacts with detailed information including real email addresses. This will double the search cost.""",
+            default=False,
+            advanced=True,
+        )

-        credentials: ApolloCredentialsInput = SchemaField(
+        credentials: ApolloCredentialsInput = CredentialsField(
             description="Apollo credentials",
         )

@@ -106,10 +114,6 @@ class SearchPeopleBlock(Block):
             description="List of people found",
             default_factory=list,
         )
-        person: Contact = SchemaField(
-            title="Person",
-            description="Each found person, one at a time",
-        )
         error: str = SchemaField(
             description="Error message if the search failed",
             default="",
@@ -125,87 +129,6 @@ class SearchPeopleBlock(Block):
             test_credentials=TEST_CREDENTIALS,
             test_input={"credentials": TEST_CREDENTIALS_INPUT},
             test_output=[
-                (
-                    "person",
-                    Contact(
-                        contact_roles=[],
-                        id="1",
-                        name="John Doe",
-                        first_name="John",
-                        last_name="Doe",
-                        linkedin_url="https://www.linkedin.com/in/johndoe",
-                        title="Software Engineer",
-                        organization_name="Google",
-                        organization_id="123456",
-                        contact_stage_id="1",
-                        owner_id="1",
-                        creator_id="1",
-                        person_id="1",
-                        email_needs_tickling=True,
-                        source="apollo",
-                        original_source="apollo",
-                        headline="Software Engineer",
-                        photo_url="https://www.linkedin.com/in/johndoe",
-                        present_raw_address="123 Main St, Anytown, USA",
-                        linkededin_uid="123456",
-                        extrapolated_email_confidence=0.8,
-                        salesforce_id="123456",
-                        salesforce_lead_id="123456",
-                        salesforce_contact_id="123456",
-                        saleforce_account_id="123456",
-                        crm_owner_id="123456",
-                        created_at="2021-01-01",
-                        emailer_campaign_ids=[],
-                        direct_dial_status="active",
-                        direct_dial_enrichment_failed_at="2021-01-01",
-                        email_status="active",
-                        email_source="apollo",
-                        account_id="123456",
-                        last_activity_date="2021-01-01",
-                        hubspot_vid="123456",
-                        hubspot_company_id="123456",
-                        crm_id="123456",
-                        sanitized_phone="123456",
-                        merged_crm_ids="123456",
-                        updated_at="2021-01-01",
-                        queued_for_crm_push=True,
-                        suggested_from_rule_engine_config_id="123456",
-                        email_unsubscribed=None,
-                        label_ids=[],
-                        has_pending_email_arcgate_request=True,
-                        has_email_arcgate_request=True,
-                        existence_level=None,
-                        email=None,
-                        email_from_customer=None,
-                        typed_custom_fields=[],
-                        custom_field_errors=None,
-                        salesforce_record_id=None,
-                        crm_record_url=None,
-                        email_status_unavailable_reason=None,
-                        email_true_status=None,
-                        updated_email_true_status=True,
-                        contact_rule_config_statuses=[],
-                        source_display_name=None,
-                        twitter_url=None,
-                        contact_campaign_statuses=[],
-                        state=None,
-                        city=None,
-                        country=None,
-                        account=None,
-                        contact_emails=[],
-                        organization=None,
-                        employment_history=[],
-                        time_zone=None,
-                        intent_strength=None,
-                        show_intent=True,
-                        phone_numbers=[],
-                        account_phone_note=None,
-                        free_domain=True,
-                        is_likely_to_engage=True,
-                        email_domain_catchall=True,
-                        contact_job_change_event=None,
-                    ),
-                ),
                 (
                     "people",
                     [
@@ -380,6 +303,34 @@ class SearchPeopleBlock(Block):
         client = ApolloClient(credentials)
         return await client.search_people(query)

+    @staticmethod
+    async def enrich_person(
+        query: EnrichPersonRequest, credentials: ApolloCredentials
+    ) -> Contact:
+        client = ApolloClient(credentials)
+        return await client.enrich_person(query)
+
+    @staticmethod
+    def merge_contact_data(original: Contact, enriched: Contact) -> Contact:
+        """
+        Merge contact data from original search with enriched data.
+        Enriched data complements original data, only filling in missing values.
+        """
+        merged_data = original.model_dump()
+        enriched_data = enriched.model_dump()
+
+        # Only update fields that are None, empty string, empty list, or default values in original
+        for key, enriched_value in enriched_data.items():
+            # Skip if enriched value is None, empty string, or empty list
+            if enriched_value is None or enriched_value == "" or enriched_value == []:
+                continue
+
+            # Update if original value is None, empty string, empty list, or zero
+            if enriched_value:
+                merged_data[key] = enriched_value
+
+        return Contact(**merged_data)
+
     async def run(
         self,
         input_data: Input,
@@ -390,6 +341,23 @@ class SearchPeopleBlock(Block):

         query = SearchPeopleRequest(**input_data.model_dump())
         people = await self.search_people(query, credentials)
-        for person in people:
-            yield "person", person
+
+        # Enrich with detailed info if requested
+        if input_data.enrich_info:
+
+            async def enrich_or_fallback(person: Contact):
+                try:
+                    enrich_query = EnrichPersonRequest(person_id=person.id)
+                    enriched_person = await self.enrich_person(
+                        enrich_query, credentials
+                    )
+                    # Merge enriched data with original data, complementing instead of replacing
+                    return self.merge_contact_data(person, enriched_person)
+                except Exception:
+                    return person  # If enrichment fails, use original person data
+
+            people = await asyncio.gather(
+                *(enrich_or_fallback(person) for person in people)
+            )

         yield "people", people
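To make the merge semantics above concrete: an enriched value wins whenever it is truthy, and empty enriched values leave the original untouched. A sketch with illustrative field values:

original = Contact(id="1", name="John Doe", email="")
enriched = Contact(id="1", name="", email="john@example.com")
merged = SearchPeopleBlock.merge_contact_data(original, enriched)
print(merged.email)   # "john@example.com": the empty original was filled in
print(merged.name)    # "John Doe": the empty enriched value was skipped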
autogpt_platform/backend/backend/blocks/apollo/person.py (new file, 138 lines)
@@ -0,0 +1,138 @@
+from backend.blocks.apollo._api import ApolloClient
+from backend.blocks.apollo._auth import (
+    TEST_CREDENTIALS,
+    TEST_CREDENTIALS_INPUT,
+    ApolloCredentials,
+    ApolloCredentialsInput,
+)
+from backend.blocks.apollo.models import Contact, EnrichPersonRequest
+from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
+from backend.data.model import CredentialsField, SchemaField
+
+
+class GetPersonDetailBlock(Block):
+    """Get detailed person data with Apollo API, including email reveal"""
+
+    class Input(BlockSchema):
+        person_id: str = SchemaField(
+            description="Apollo person ID to enrich (most accurate method)",
+            default="",
+            advanced=False,
+        )
+        first_name: str = SchemaField(
+            description="First name of the person to enrich",
+            default="",
+            advanced=False,
+        )
+        last_name: str = SchemaField(
+            description="Last name of the person to enrich",
+            default="",
+            advanced=False,
+        )
+        name: str = SchemaField(
+            description="Full name of the person to enrich (alternative to first_name + last_name)",
+            default="",
+            advanced=False,
+        )
+        email: str = SchemaField(
+            description="Known email address of the person (helps with matching)",
+            default="",
+            advanced=False,
+        )
+        domain: str = SchemaField(
+            description="Company domain of the person (e.g., 'google.com')",
+            default="",
+            advanced=False,
+        )
+        company: str = SchemaField(
+            description="Company name of the person",
+            default="",
+            advanced=False,
+        )
+        linkedin_url: str = SchemaField(
+            description="LinkedIn URL of the person",
+            default="",
+            advanced=False,
+        )
+        organization_id: str = SchemaField(
+            description="Apollo organization ID of the person's company",
+            default="",
+            advanced=True,
+        )
+        title: str = SchemaField(
+            description="Job title of the person to enrich",
+            default="",
+            advanced=True,
+        )
+        credentials: ApolloCredentialsInput = CredentialsField(
+            description="Apollo credentials",
+        )
+
+    class Output(BlockSchema):
+        contact: Contact = SchemaField(
+            description="Enriched contact information",
+        )
+        error: str = SchemaField(
+            description="Error message if enrichment failed",
+            default="",
+        )
+
+    def __init__(self):
+        super().__init__(
+            id="3b18d46c-3db6-42ae-a228-0ba441bdd176",
+            description="Get detailed person data with Apollo API, including email reveal",
+            categories={BlockCategory.SEARCH},
+            input_schema=GetPersonDetailBlock.Input,
+            output_schema=GetPersonDetailBlock.Output,
+            test_credentials=TEST_CREDENTIALS,
+            test_input={
+                "credentials": TEST_CREDENTIALS_INPUT,
+                "first_name": "John",
+                "last_name": "Doe",
+                "company": "Google",
+            },
+            test_output=[
+                (
+                    "contact",
+                    Contact(
+                        id="1",
+                        name="John Doe",
+                        first_name="John",
+                        last_name="Doe",
+                        email="john.doe@gmail.com",
+                        title="Software Engineer",
+                        organization_name="Google",
+                        linkedin_url="https://www.linkedin.com/in/johndoe",
+                    ),
+                ),
+            ],
+            test_mock={
+                "enrich_person": lambda query, credentials: Contact(
+                    id="1",
+                    name="John Doe",
+                    first_name="John",
+                    last_name="Doe",
+                    email="john.doe@gmail.com",
+                    title="Software Engineer",
+                    organization_name="Google",
+                    linkedin_url="https://www.linkedin.com/in/johndoe",
+                )
+            },
+        )
+
+    @staticmethod
+    async def enrich_person(
+        query: EnrichPersonRequest, credentials: ApolloCredentials
+    ) -> Contact:
+        client = ApolloClient(credentials)
+        return await client.enrich_person(query)
+
+    async def run(
+        self,
+        input_data: Input,
+        *,
+        credentials: ApolloCredentials,
+        **kwargs,
+    ) -> BlockOutput:
+        query = EnrichPersonRequest(**input_data.model_dump())
+        yield "contact", await self.enrich_person(query, credentials)
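A hedged usage sketch for the new block; the async-for driver below is an assumption inferred from the yield-based run() contract in this file, not an API documented here:

import asyncio

async def demo(credentials: ApolloCredentials):
    block = GetPersonDetailBlock()
    async for output_name, value in block.run(
        GetPersonDetailBlock.Input(
            credentials=TEST_CREDENTIALS_INPUT,
            first_name="John",
            last_name="Doe",
            company="Google",
        ),
        credentials=credentials,
    ):
        if output_name == "contact":
            print(value.email)

# asyncio.run(demo(my_credentials))  # my_credentials: an ApolloCredentials instance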
@@ -1,11 +1,9 @@
 import enum
-from typing import Any, List
+from typing import Any

 from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
 from backend.data.model import SchemaField
-from backend.util import json
 from backend.util.file import store_media_file
-from backend.util.mock import MockObject
 from backend.util.type import MediaFileType, convert

@@ -120,266 +118,6 @@ class PrintToConsoleBlock(Block):
         yield "status", "printed"


-class FindInDictionaryBlock(Block):
-    class Input(BlockSchema):
-        input: Any = SchemaField(description="Dictionary to lookup from")
-        key: str | int = SchemaField(description="Key to lookup in the dictionary")
-
-    class Output(BlockSchema):
-        output: Any = SchemaField(description="Value found for the given key")
-        missing: Any = SchemaField(
-            description="Value of the input that missing the key"
-        )
-
-    def __init__(self):
-        super().__init__(
-            id="0e50422c-6dee-4145-83d6-3a5a392f65de",
-            description="Lookup the given key in the input dictionary/object/list and return the value.",
-            input_schema=FindInDictionaryBlock.Input,
-            output_schema=FindInDictionaryBlock.Output,
-            test_input=[
-                {"input": {"apple": 1, "banana": 2, "cherry": 3}, "key": "banana"},
-                {"input": {"x": 10, "y": 20, "z": 30}, "key": "w"},
-                {"input": [1, 2, 3], "key": 1},
-                {"input": [1, 2, 3], "key": 3},
-                {"input": MockObject(value="!!", key="key"), "key": "key"},
-                {"input": [{"k1": "v1"}, {"k2": "v2"}, {"k1": "v3"}], "key": "k1"},
-            ],
-            test_output=[
-                ("output", 2),
-                ("missing", {"x": 10, "y": 20, "z": 30}),
-                ("output", 2),
-                ("missing", [1, 2, 3]),
-                ("output", "key"),
-                ("output", ["v1", "v3"]),
-            ],
-            categories={BlockCategory.BASIC},
-        )
-
-    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        obj = input_data.input
-        key = input_data.key
-
-        if isinstance(obj, str):
-            obj = json.loads(obj)
-
-        if isinstance(obj, dict) and key in obj:
-            yield "output", obj[key]
-        elif isinstance(obj, list) and isinstance(key, int) and 0 <= key < len(obj):
-            yield "output", obj[key]
-        elif isinstance(obj, list) and isinstance(key, str):
-            if len(obj) == 0:
-                yield "output", []
-            elif isinstance(obj[0], dict) and key in obj[0]:
-                yield "output", [item[key] for item in obj if key in item]
-            else:
-                yield "output", [getattr(val, key) for val in obj if hasattr(val, key)]
-        elif isinstance(obj, object) and isinstance(key, str) and hasattr(obj, key):
-            yield "output", getattr(obj, key)
-        else:
-            yield "missing", input_data.input
-
-
-class AddToDictionaryBlock(Block):
-    class Input(BlockSchema):
-        dictionary: dict[Any, Any] = SchemaField(
-            default_factory=dict,
-            description="The dictionary to add the entry to. If not provided, a new dictionary will be created.",
-        )
-        key: str = SchemaField(
-            default="",
-            description="The key for the new entry.",
-            placeholder="new_key",
-            advanced=False,
-        )
-        value: Any = SchemaField(
-            default=None,
-            description="The value for the new entry.",
-            placeholder="new_value",
-            advanced=False,
-        )
-        entries: dict[Any, Any] = SchemaField(
-            default_factory=dict,
-            description="The entries to add to the dictionary. This is the batch version of the `key` and `value` fields.",
-            advanced=True,
-        )
-
-    class Output(BlockSchema):
-        updated_dictionary: dict = SchemaField(
-            description="The dictionary with the new entry added."
-        )
-        error: str = SchemaField(description="Error message if the operation failed.")
-
-    def __init__(self):
-        super().__init__(
-            id="31d1064e-7446-4693-a7d4-65e5ca1180d1",
-            description="Adds a new key-value pair to a dictionary. If no dictionary is provided, a new one is created.",
-            categories={BlockCategory.BASIC},
-            input_schema=AddToDictionaryBlock.Input,
-            output_schema=AddToDictionaryBlock.Output,
-            test_input=[
-                {
-                    "dictionary": {"existing_key": "existing_value"},
-                    "key": "new_key",
-                    "value": "new_value",
-                },
-                {"key": "first_key", "value": "first_value"},
-                {
-                    "dictionary": {"existing_key": "existing_value"},
-                    "entries": {"new_key": "new_value", "first_key": "first_value"},
-                },
-            ],
-            test_output=[
-                (
-                    "updated_dictionary",
-                    {"existing_key": "existing_value", "new_key": "new_value"},
-                ),
-                ("updated_dictionary", {"first_key": "first_value"}),
-                (
-                    "updated_dictionary",
-                    {
-                        "existing_key": "existing_value",
-                        "new_key": "new_value",
-                        "first_key": "first_value",
-                    },
-                ),
-            ],
-        )
-
-    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        updated_dict = input_data.dictionary.copy()
-
-        if input_data.value is not None and input_data.key:
-            updated_dict[input_data.key] = input_data.value
-
-        for key, value in input_data.entries.items():
-            updated_dict[key] = value
-
-        yield "updated_dictionary", updated_dict
-
-
-class AddToListBlock(Block):
-    class Input(BlockSchema):
-        list: List[Any] = SchemaField(
-            default_factory=list,
-            advanced=False,
-            description="The list to add the entry to. If not provided, a new list will be created.",
-        )
-        entry: Any = SchemaField(
-            description="The entry to add to the list. Can be of any type (string, int, dict, etc.).",
-            advanced=False,
-            default=None,
-        )
-        entries: List[Any] = SchemaField(
-            default_factory=lambda: list(),
-            description="The entries to add to the list. This is the batch version of the `entry` field.",
-            advanced=True,
-        )
-        position: int | None = SchemaField(
-            default=None,
-            description="The position to insert the new entry. If not provided, the entry will be appended to the end of the list.",
-        )
-
-    class Output(BlockSchema):
-        updated_list: List[Any] = SchemaField(
-            description="The list with the new entry added."
-        )
-        error: str = SchemaField(description="Error message if the operation failed.")
-
-    def __init__(self):
-        super().__init__(
-            id="aeb08fc1-2fc1-4141-bc8e-f758f183a822",
-            description="Adds a new entry to a list. The entry can be of any type. If no list is provided, a new one is created.",
-            categories={BlockCategory.BASIC},
-            input_schema=AddToListBlock.Input,
-            output_schema=AddToListBlock.Output,
-            test_input=[
-                {
-                    "list": [1, "string", {"existing_key": "existing_value"}],
-                    "entry": {"new_key": "new_value"},
-                    "position": 1,
-                },
-                {"entry": "first_entry"},
-                {"list": ["a", "b", "c"], "entry": "d"},
-                {
-                    "entry": "e",
-                    "entries": ["f", "g"],
-                    "list": ["a", "b"],
-                    "position": 1,
-                },
-            ],
-            test_output=[
-                (
-                    "updated_list",
-                    [
-                        1,
-                        {"new_key": "new_value"},
-                        "string",
-                        {"existing_key": "existing_value"},
-                    ],
-                ),
-                ("updated_list", ["first_entry"]),
-                ("updated_list", ["a", "b", "c", "d"]),
-                ("updated_list", ["a", "f", "g", "e", "b"]),
-            ],
-        )
-
-    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        entries_added = input_data.entries.copy()
-        if input_data.entry:
-            entries_added.append(input_data.entry)
-
-        updated_list = input_data.list.copy()
-        if (pos := input_data.position) is not None:
-            updated_list = updated_list[:pos] + entries_added + updated_list[pos:]
-        else:
-            updated_list += entries_added
-
-        yield "updated_list", updated_list
-
-
-class FindInListBlock(Block):
-    class Input(BlockSchema):
-        list: List[Any] = SchemaField(description="The list to search in.")
-        value: Any = SchemaField(description="The value to search for.")
-
-    class Output(BlockSchema):
-        index: int = SchemaField(description="The index of the value in the list.")
-        found: bool = SchemaField(
-            description="Whether the value was found in the list."
-        )
-        not_found_value: Any = SchemaField(
-            description="The value that was not found in the list."
-        )
-
-    def __init__(self):
-        super().__init__(
-            id="5e2c6d0a-1e37-489f-b1d0-8e1812b23333",
-            description="Finds the index of the value in the list.",
-            categories={BlockCategory.BASIC},
-            input_schema=FindInListBlock.Input,
-            output_schema=FindInListBlock.Output,
-            test_input=[
-                {"list": [1, 2, 3, 4, 5], "value": 3},
-                {"list": [1, 2, 3, 4, 5], "value": 6},
-            ],
-            test_output=[
-                ("index", 2),
-                ("found", True),
-                ("found", False),
-                ("not_found_value", 6),
-            ],
-        )
-
-    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        try:
-            yield "index", input_data.list.index(input_data.value)
-            yield "found", True
-        except ValueError:
-            yield "found", False
-            yield "not_found_value", input_data.value
-
-
 class NoteBlock(Block):
     class Input(BlockSchema):
         text: str = SchemaField(description="The text to display in the sticky note.")
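One subtlety in the removed AddToListBlock worth spelling out (several of these blocks reappear verbatim in the new data_manipulation.py at the end of this diff, so they are being relocated rather than deleted): the single entry is appended after the batch entries, and the combined batch is spliced in at position. In plain Python, mirroring the last test case above:

entries_added = ["f", "g"] + ["e"]   # batch first, single entry last
lst = ["a", "b"]
pos = 1
print(lst[:pos] + entries_added + lst[pos:])   # ['a', 'f', 'g', 'e', 'b']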
@@ -405,110 +143,6 @@ class NoteBlock(Block):
         yield "output", input_data.text


-class CreateDictionaryBlock(Block):
-    class Input(BlockSchema):
-        values: dict[str, Any] = SchemaField(
-            description="Key-value pairs to create the dictionary with",
-            placeholder="e.g., {'name': 'Alice', 'age': 25}",
-        )
-
-    class Output(BlockSchema):
-        dictionary: dict[str, Any] = SchemaField(
-            description="The created dictionary containing the specified key-value pairs"
-        )
-        error: str = SchemaField(
-            description="Error message if dictionary creation failed"
-        )
-
-    def __init__(self):
-        super().__init__(
-            id="b924ddf4-de4f-4b56-9a85-358930dcbc91",
-            description="Creates a dictionary with the specified key-value pairs. Use this when you know all the values you want to add upfront.",
-            categories={BlockCategory.DATA},
-            input_schema=CreateDictionaryBlock.Input,
-            output_schema=CreateDictionaryBlock.Output,
-            test_input=[
-                {
-                    "values": {"name": "Alice", "age": 25, "city": "New York"},
-                },
-                {
-                    "values": {"numbers": [1, 2, 3], "active": True, "score": 95.5},
-                },
-            ],
-            test_output=[
-                (
-                    "dictionary",
-                    {"name": "Alice", "age": 25, "city": "New York"},
-                ),
-                (
-                    "dictionary",
-                    {"numbers": [1, 2, 3], "active": True, "score": 95.5},
-                ),
-            ],
-        )
-
-    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        try:
-            # The values are already validated by Pydantic schema
-            yield "dictionary", input_data.values
-        except Exception as e:
-            yield "error", f"Failed to create dictionary: {str(e)}"
-
-
-class CreateListBlock(Block):
-    class Input(BlockSchema):
-        values: List[Any] = SchemaField(
-            description="A list of values to be combined into a new list.",
-            placeholder="e.g., ['Alice', 25, True]",
-        )
-        max_size: int | None = SchemaField(
-            default=None,
-            description="Maximum size of the list. If provided, the list will be yielded in chunks of this size.",
-            advanced=True,
-        )
-
-    class Output(BlockSchema):
-        list: List[Any] = SchemaField(
-            description="The created list containing the specified values."
-        )
-        error: str = SchemaField(description="Error message if list creation failed.")
-
-    def __init__(self):
-        super().__init__(
-            id="a912d5c7-6e00-4542-b2a9-8034136930e4",
-            description="Creates a list with the specified values. Use this when you know all the values you want to add upfront.",
-            categories={BlockCategory.DATA},
-            input_schema=CreateListBlock.Input,
-            output_schema=CreateListBlock.Output,
-            test_input=[
-                {
-                    "values": ["Alice", 25, True],
-                },
-                {
-                    "values": [1, 2, 3, "four", {"key": "value"}],
-                },
-            ],
-            test_output=[
-                (
-                    "list",
-                    ["Alice", 25, True],
-                ),
-                (
-                    "list",
-                    [1, 2, 3, "four", {"key": "value"}],
-                ),
-            ],
-        )
-
-    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        try:
-            max_size = input_data.max_size or len(input_data.values)
-            for i in range(0, len(input_data.values), max_size):
-                yield "list", input_data.values[i : i + max_size]
-        except Exception as e:
-            yield "error", f"Failed to create list: {str(e)}"
-
-
 class TypeOptions(enum.Enum):
     STRING = "string"
     NUMBER = "number"
@@ -526,6 +160,7 @@ class UniversalTypeConverterBlock(Block):

     class Output(BlockSchema):
         value: Any = SchemaField(description="The converted value.")
+        error: str = SchemaField(description="Error message if conversion failed.")

     def __init__(self):
         super().__init__(
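The removed CreateListBlock's max_size behavior is easy to miss: the range/slice loop chunks the output in fixed strides, so one input can yield several lists (the block is presumably relocated alongside the dictionary blocks in the new data_manipulation.py below). For example:

values = [1, 2, 3, 4, 5]
max_size = 2
for i in range(0, len(values), max_size):
    print(values[i : i + max_size])   # [1, 2] then [3, 4] then [5]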
@@ -3,6 +3,7 @@ from typing import Any

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.type import convert


class ComparisonOperator(Enum):
@@ -181,7 +182,23 @@ class IfInputMatchesBlock(Block):
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        if input_data.input == input_data.value or input_data.input is input_data.value:

        # If input_data.value does not match input_data.input, convert value to the type of input
        if (
            input_data.input != input_data.value
            and input_data.input is not input_data.value
        ):
            try:
                # Only attempt conversion if input is not None and value is not None
                if input_data.input is not None and input_data.value is not None:
                    input_type = type(input_data.input)
                    # Avoid converting if input_type is Any or object
                    if input_type not in (Any, object):
                        input_data.value = convert(input_data.value, input_type)
            except Exception:
                pass  # If conversion fails, just leave value as is

        if input_data.input == input_data.value:
            yield "result", True
            yield "yes_output", input_data.yes_value
        else:

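The conversion step above exists because graph inputs often arrive as strings. A minimal standalone sketch of the idea, using a naive cast as a stand-in for backend.util.type.convert (the real helper handles more cases):

    # Naive stand-in for backend.util.type.convert, for illustration only.
    def convert(value, target_type):
        return target_type(value)

    input_value, match_value = 5, "5"
    if input_value != match_value and input_value is not match_value:
        if input_value is not None and match_value is not None:
            input_type = type(input_value)
            if input_type not in (object,):
                match_value = convert(match_value, input_type)

    assert input_value == match_value  # "5" coerced to 5, so the match succeeds
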
autogpt_platform/backend/backend/blocks/data_manipulation.py (new file, 683 lines)
@@ -0,0 +1,683 @@
from typing import Any, List

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.json import loads
from backend.util.mock import MockObject
from backend.util.prompt import estimate_token_count_str

# =============================================================================
# Dictionary Manipulation Blocks
# =============================================================================


class CreateDictionaryBlock(Block):
    class Input(BlockSchema):
        values: dict[str, Any] = SchemaField(
            description="Key-value pairs to create the dictionary with",
            placeholder="e.g., {'name': 'Alice', 'age': 25}",
        )

    class Output(BlockSchema):
        dictionary: dict[str, Any] = SchemaField(
            description="The created dictionary containing the specified key-value pairs"
        )
        error: str = SchemaField(
            description="Error message if dictionary creation failed"
        )

    def __init__(self):
        super().__init__(
            id="b924ddf4-de4f-4b56-9a85-358930dcbc91",
            description="Creates a dictionary with the specified key-value pairs. Use this when you know all the values you want to add upfront.",
            categories={BlockCategory.DATA},
            input_schema=CreateDictionaryBlock.Input,
            output_schema=CreateDictionaryBlock.Output,
            test_input=[
                {
                    "values": {"name": "Alice", "age": 25, "city": "New York"},
                },
                {
                    "values": {"numbers": [1, 2, 3], "active": True, "score": 95.5},
                },
            ],
            test_output=[
                (
                    "dictionary",
                    {"name": "Alice", "age": 25, "city": "New York"},
                ),
                (
                    "dictionary",
                    {"numbers": [1, 2, 3], "active": True, "score": 95.5},
                ),
            ],
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        try:
            # The values are already validated by Pydantic schema
            yield "dictionary", input_data.values
        except Exception as e:
            yield "error", f"Failed to create dictionary: {str(e)}"

class AddToDictionaryBlock(Block):
    class Input(BlockSchema):
        dictionary: dict[Any, Any] = SchemaField(
            default_factory=dict,
            description="The dictionary to add the entry to. If not provided, a new dictionary will be created.",
        )
        key: str = SchemaField(
            default="",
            description="The key for the new entry.",
            placeholder="new_key",
            advanced=False,
        )
        value: Any = SchemaField(
            default=None,
            description="The value for the new entry.",
            placeholder="new_value",
            advanced=False,
        )
        entries: dict[Any, Any] = SchemaField(
            default_factory=dict,
            description="The entries to add to the dictionary. This is the batch version of the `key` and `value` fields.",
            advanced=True,
        )

    class Output(BlockSchema):
        updated_dictionary: dict = SchemaField(
            description="The dictionary with the new entry added."
        )
        error: str = SchemaField(description="Error message if the operation failed.")

    def __init__(self):
        super().__init__(
            id="31d1064e-7446-4693-a7d4-65e5ca1180d1",
            description="Adds a new key-value pair to a dictionary. If no dictionary is provided, a new one is created.",
            categories={BlockCategory.BASIC},
            input_schema=AddToDictionaryBlock.Input,
            output_schema=AddToDictionaryBlock.Output,
            test_input=[
                {
                    "dictionary": {"existing_key": "existing_value"},
                    "key": "new_key",
                    "value": "new_value",
                },
                {"key": "first_key", "value": "first_value"},
                {
                    "dictionary": {"existing_key": "existing_value"},
                    "entries": {"new_key": "new_value", "first_key": "first_value"},
                },
            ],
            test_output=[
                (
                    "updated_dictionary",
                    {"existing_key": "existing_value", "new_key": "new_value"},
                ),
                ("updated_dictionary", {"first_key": "first_value"}),
                (
                    "updated_dictionary",
                    {
                        "existing_key": "existing_value",
                        "new_key": "new_value",
                        "first_key": "first_value",
                    },
                ),
            ],
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        updated_dict = input_data.dictionary.copy()

        if input_data.value is not None and input_data.key:
            updated_dict[input_data.key] = input_data.value

        for key, value in input_data.entries.items():
            updated_dict[key] = value

        yield "updated_dictionary", updated_dict

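# Illustrative sketch (not part of this module): the single key/value pair is
# applied before the batch `entries`, so on a key collision the batch entry
# wins.
#
#     updated = {"existing_key": "existing_value"}
#     updated["new_key"] = "from_single_field"        # key/value fields first
#     updated.update({"new_key": "from_entries"})     # batch entries last
#     assert updated["new_key"] == "from_entries"
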
class FindInDictionaryBlock(Block):
    class Input(BlockSchema):
        input: Any = SchemaField(description="Dictionary to lookup from")
        key: str | int = SchemaField(description="Key to lookup in the dictionary")

    class Output(BlockSchema):
        output: Any = SchemaField(description="Value found for the given key")
        missing: Any = SchemaField(
            description="Value of the input that is missing the key"
        )

    def __init__(self):
        super().__init__(
            id="0e50422c-6dee-4145-83d6-3a5a392f65de",
            description="Lookup the given key in the input dictionary/object/list and return the value.",
            input_schema=FindInDictionaryBlock.Input,
            output_schema=FindInDictionaryBlock.Output,
            test_input=[
                {"input": {"apple": 1, "banana": 2, "cherry": 3}, "key": "banana"},
                {"input": {"x": 10, "y": 20, "z": 30}, "key": "w"},
                {"input": [1, 2, 3], "key": 1},
                {"input": [1, 2, 3], "key": 3},
                {"input": MockObject(value="!!", key="key"), "key": "key"},
                {"input": [{"k1": "v1"}, {"k2": "v2"}, {"k1": "v3"}], "key": "k1"},
            ],
            test_output=[
                ("output", 2),
                ("missing", {"x": 10, "y": 20, "z": 30}),
                ("output", 2),
                ("missing", [1, 2, 3]),
                ("output", "key"),
                ("output", ["v1", "v3"]),
            ],
            categories={BlockCategory.BASIC},
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        obj = input_data.input
        key = input_data.key

        if isinstance(obj, str):
            obj = loads(obj)

        if isinstance(obj, dict) and key in obj:
            yield "output", obj[key]
        elif isinstance(obj, list) and isinstance(key, int) and 0 <= key < len(obj):
            yield "output", obj[key]
        elif isinstance(obj, list) and isinstance(key, str):
            if len(obj) == 0:
                yield "output", []
            elif isinstance(obj[0], dict) and key in obj[0]:
                yield "output", [item[key] for item in obj if key in item]
            else:
                yield "output", [getattr(val, key) for val in obj if hasattr(val, key)]
        elif isinstance(obj, object) and isinstance(key, str) and hasattr(obj, key):
            yield "output", getattr(obj, key)
        else:
            yield "missing", input_data.input

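# Illustrative sketch: for a list of dicts and a string key, the block above
# performs a projection, skipping items that lack the key instead of raising.
#
#     rows = [{"k1": "v1"}, {"k2": "v2"}, {"k1": "v3"}]
#     assert [r["k1"] for r in rows if "k1" in r] == ["v1", "v3"]
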
class RemoveFromDictionaryBlock(Block):
    class Input(BlockSchema):
        dictionary: dict[Any, Any] = SchemaField(
            description="The dictionary to modify."
        )
        key: str | int = SchemaField(description="Key to remove from the dictionary.")
        return_value: bool = SchemaField(
            default=False, description="Whether to return the removed value."
        )

    class Output(BlockSchema):
        updated_dictionary: dict[Any, Any] = SchemaField(
            description="The dictionary after removal."
        )
        removed_value: Any = SchemaField(description="The removed value if requested.")
        error: str = SchemaField(description="Error message if the operation failed.")

    def __init__(self):
        super().__init__(
            id="46afe2ea-c613-43f8-95ff-6692c3ef6876",
            description="Removes a key-value pair from a dictionary.",
            categories={BlockCategory.BASIC},
            input_schema=RemoveFromDictionaryBlock.Input,
            output_schema=RemoveFromDictionaryBlock.Output,
            test_input=[
                {
                    "dictionary": {"a": 1, "b": 2, "c": 3},
                    "key": "b",
                    "return_value": True,
                },
                {"dictionary": {"x": "hello", "y": "world"}, "key": "x"},
            ],
            test_output=[
                ("updated_dictionary", {"a": 1, "c": 3}),
                ("removed_value", 2),
                ("updated_dictionary", {"y": "world"}),
            ],
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        updated_dict = input_data.dictionary.copy()
        try:
            removed_value = updated_dict.pop(input_data.key)
            yield "updated_dictionary", updated_dict
            if input_data.return_value:
                yield "removed_value", removed_value
        except KeyError:
            yield "error", f"Key '{input_data.key}' not found in dictionary"


class ReplaceDictionaryValueBlock(Block):
    class Input(BlockSchema):
        dictionary: dict[Any, Any] = SchemaField(
            description="The dictionary to modify."
        )
        key: str | int = SchemaField(description="Key to replace the value for.")
        value: Any = SchemaField(description="The new value for the given key.")

    class Output(BlockSchema):
        updated_dictionary: dict[Any, Any] = SchemaField(
            description="The dictionary after replacement."
        )
        old_value: Any = SchemaField(description="The value that was replaced.")
        error: str = SchemaField(description="Error message if the operation failed.")

    def __init__(self):
        super().__init__(
            id="27e31876-18b6-44f3-ab97-f6226d8b3889",
            description="Replaces the value for a specified key in a dictionary.",
            categories={BlockCategory.BASIC},
            input_schema=ReplaceDictionaryValueBlock.Input,
            output_schema=ReplaceDictionaryValueBlock.Output,
            test_input=[
                {"dictionary": {"a": 1, "b": 2, "c": 3}, "key": "b", "value": 99},
                {
                    "dictionary": {"x": "hello", "y": "world"},
                    "key": "y",
                    "value": "universe",
                },
            ],
            test_output=[
                ("updated_dictionary", {"a": 1, "b": 99, "c": 3}),
                ("old_value", 2),
                ("updated_dictionary", {"x": "hello", "y": "universe"}),
                ("old_value", "world"),
            ],
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        updated_dict = input_data.dictionary.copy()
        try:
            old_value = updated_dict[input_data.key]
            updated_dict[input_data.key] = input_data.value
            yield "updated_dictionary", updated_dict
            yield "old_value", old_value
        except KeyError:
            yield "error", f"Key '{input_data.key}' not found in dictionary"


class DictionaryIsEmptyBlock(Block):
    class Input(BlockSchema):
        dictionary: dict[Any, Any] = SchemaField(description="The dictionary to check.")

    class Output(BlockSchema):
        is_empty: bool = SchemaField(description="True if the dictionary is empty.")

    def __init__(self):
        super().__init__(
            id="a3cf3f64-6bb9-4cc6-9900-608a0b3359b0",
            description="Checks if a dictionary is empty.",
            categories={BlockCategory.BASIC},
            input_schema=DictionaryIsEmptyBlock.Input,
            output_schema=DictionaryIsEmptyBlock.Output,
            test_input=[{"dictionary": {}}, {"dictionary": {"a": 1}}],
            test_output=[("is_empty", True), ("is_empty", False)],
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        yield "is_empty", len(input_data.dictionary) == 0


# =============================================================================
# List Manipulation Blocks
# =============================================================================

class CreateListBlock(Block):
    class Input(BlockSchema):
        values: List[Any] = SchemaField(
            description="A list of values to be combined into a new list.",
            placeholder="e.g., ['Alice', 25, True]",
        )
        max_size: int | None = SchemaField(
            default=None,
            description="Maximum size of the list. If provided, the list will be yielded in chunks of this size.",
            advanced=True,
        )
        max_tokens: int | None = SchemaField(
            default=None,
            description="Maximum tokens for the list. If provided, the list will be yielded in chunks that fit within this token limit.",
            advanced=True,
        )

    class Output(BlockSchema):
        list: List[Any] = SchemaField(
            description="The created list containing the specified values."
        )
        error: str = SchemaField(description="Error message if list creation failed.")

    def __init__(self):
        super().__init__(
            id="a912d5c7-6e00-4542-b2a9-8034136930e4",
            description="Creates a list with the specified values. Use this when you know all the values you want to add upfront. This block can also yield the list in batches based on a maximum size or token limit.",
            categories={BlockCategory.DATA},
            input_schema=CreateListBlock.Input,
            output_schema=CreateListBlock.Output,
            test_input=[
                {
                    "values": ["Alice", 25, True],
                },
                {
                    "values": [1, 2, 3, "four", {"key": "value"}],
                },
            ],
            test_output=[
                (
                    "list",
                    ["Alice", 25, True],
                ),
                (
                    "list",
                    [1, 2, 3, "four", {"key": "value"}],
                ),
            ],
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        chunk = []
        cur_tokens, max_tokens = 0, input_data.max_tokens
        cur_size, max_size = 0, input_data.max_size

        for value in input_data.values:
            if max_tokens:
                tokens = estimate_token_count_str(value)
            else:
                tokens = 0

            # Check if adding this value would exceed either limit
            if (max_tokens and (cur_tokens + tokens > max_tokens)) or (
                max_size and (cur_size + 1 > max_size)
            ):
                yield "list", chunk
                chunk = [value]
                cur_size, cur_tokens = 1, tokens
            else:
                chunk.append(value)
                cur_size, cur_tokens = cur_size + 1, cur_tokens + tokens

        # Yield final chunk if any
        if chunk or not input_data.values:
            yield "list", chunk

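# Illustrative sketch of the greedy chunking above, outside the block
# framework; the word-count estimator is a crude stand-in for
# estimate_token_count_str.
#
#     def chunk_values(values, max_size=None, max_tokens=None):
#         chunk, cur_size, cur_tokens = [], 0, 0
#         for v in values:
#             tokens = len(str(v).split()) if max_tokens else 0
#             if (max_tokens and cur_tokens + tokens > max_tokens) or (
#                 max_size and cur_size + 1 > max_size
#             ):
#                 yield chunk
#                 chunk, cur_size, cur_tokens = [v], 1, tokens
#             else:
#                 chunk.append(v)
#                 cur_size, cur_tokens = cur_size + 1, cur_tokens + tokens
#         if chunk or not values:
#             yield chunk
#
#     assert list(chunk_values([1, 2, 3, 4, 5], max_size=2)) == [[1, 2], [3, 4], [5]]
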
class AddToListBlock(Block):
    class Input(BlockSchema):
        list: List[Any] = SchemaField(
            default_factory=list,
            advanced=False,
            description="The list to add the entry to. If not provided, a new list will be created.",
        )
        entry: Any = SchemaField(
            description="The entry to add to the list. Can be of any type (string, int, dict, etc.).",
            advanced=False,
            default=None,
        )
        entries: List[Any] = SchemaField(
            default_factory=lambda: list(),
            description="The entries to add to the list. This is the batch version of the `entry` field.",
            advanced=True,
        )
        position: int | None = SchemaField(
            default=None,
            description="The position to insert the new entry. If not provided, the entry will be appended to the end of the list.",
        )

    class Output(BlockSchema):
        updated_list: List[Any] = SchemaField(
            description="The list with the new entry added."
        )
        error: str = SchemaField(description="Error message if the operation failed.")

    def __init__(self):
        super().__init__(
            id="aeb08fc1-2fc1-4141-bc8e-f758f183a822",
            description="Adds a new entry to a list. The entry can be of any type. If no list is provided, a new one is created.",
            categories={BlockCategory.BASIC},
            input_schema=AddToListBlock.Input,
            output_schema=AddToListBlock.Output,
            test_input=[
                {
                    "list": [1, "string", {"existing_key": "existing_value"}],
                    "entry": {"new_key": "new_value"},
                    "position": 1,
                },
                {"entry": "first_entry"},
                {"list": ["a", "b", "c"], "entry": "d"},
                {
                    "entry": "e",
                    "entries": ["f", "g"],
                    "list": ["a", "b"],
                    "position": 1,
                },
            ],
            test_output=[
                (
                    "updated_list",
                    [
                        1,
                        {"new_key": "new_value"},
                        "string",
                        {"existing_key": "existing_value"},
                    ],
                ),
                ("updated_list", ["first_entry"]),
                ("updated_list", ["a", "b", "c", "d"]),
                ("updated_list", ["a", "f", "g", "e", "b"]),
            ],
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        entries_added = input_data.entries.copy()
        if input_data.entry:
            entries_added.append(input_data.entry)

        updated_list = input_data.list.copy()
        if (pos := input_data.position) is not None:
            updated_list = updated_list[:pos] + entries_added + updated_list[pos:]
        else:
            updated_list += entries_added

        yield "updated_list", updated_list

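# Illustrative sketch: batch `entries` precede the single `entry`, and the
# whole group is spliced in at `position` (matching the last test case above).
#
#     lst, pos = ["a", "b"], 1
#     group = ["f", "g"] + ["e"]                   # entries first, entry last
#     assert lst[:pos] + group + lst[pos:] == ["a", "f", "g", "e", "b"]
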
class FindInListBlock(Block):
    class Input(BlockSchema):
        list: List[Any] = SchemaField(description="The list to search in.")
        value: Any = SchemaField(description="The value to search for.")

    class Output(BlockSchema):
        index: int = SchemaField(description="The index of the value in the list.")
        found: bool = SchemaField(
            description="Whether the value was found in the list."
        )
        not_found_value: Any = SchemaField(
            description="The value that was not found in the list."
        )

    def __init__(self):
        super().__init__(
            id="5e2c6d0a-1e37-489f-b1d0-8e1812b23333",
            description="Finds the index of the value in the list.",
            categories={BlockCategory.BASIC},
            input_schema=FindInListBlock.Input,
            output_schema=FindInListBlock.Output,
            test_input=[
                {"list": [1, 2, 3, 4, 5], "value": 3},
                {"list": [1, 2, 3, 4, 5], "value": 6},
            ],
            test_output=[
                ("index", 2),
                ("found", True),
                ("found", False),
                ("not_found_value", 6),
            ],
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        try:
            yield "index", input_data.list.index(input_data.value)
            yield "found", True
        except ValueError:
            yield "found", False
            yield "not_found_value", input_data.value


class GetListItemBlock(Block):
    class Input(BlockSchema):
        list: List[Any] = SchemaField(description="The list to get the item from.")
        index: int = SchemaField(
            description="The 0-based index of the item (supports negative indices)."
        )

    class Output(BlockSchema):
        item: Any = SchemaField(description="The item at the specified index.")
        error: str = SchemaField(description="Error message if the operation failed.")

    def __init__(self):
        super().__init__(
            id="262ca24c-1025-43cf-a578-534e23234e97",
            description="Returns the element at the given index.",
            categories={BlockCategory.BASIC},
            input_schema=GetListItemBlock.Input,
            output_schema=GetListItemBlock.Output,
            test_input=[
                {"list": [1, 2, 3], "index": 1},
                {"list": [1, 2, 3], "index": -1},
            ],
            test_output=[
                ("item", 2),
                ("item", 3),
            ],
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        try:
            yield "item", input_data.list[input_data.index]
        except IndexError:
            yield "error", "Index out of range"

class RemoveFromListBlock(Block):
    class Input(BlockSchema):
        list: List[Any] = SchemaField(description="The list to modify.")
        value: Any = SchemaField(
            default=None, description="Value to remove from the list."
        )
        index: int | None = SchemaField(
            default=None,
            description="Index of the item to pop (supports negative indices).",
        )
        return_item: bool = SchemaField(
            default=False, description="Whether to return the removed item."
        )

    class Output(BlockSchema):
        updated_list: List[Any] = SchemaField(description="The list after removal.")
        removed_item: Any = SchemaField(description="The removed item if requested.")
        error: str = SchemaField(description="Error message if the operation failed.")

    def __init__(self):
        super().__init__(
            id="d93c5a93-ac7e-41c1-ae5c-ef67e6e9b826",
            description="Removes an item from a list by value or index.",
            categories={BlockCategory.BASIC},
            input_schema=RemoveFromListBlock.Input,
            output_schema=RemoveFromListBlock.Output,
            test_input=[
                {"list": [1, 2, 3], "index": 1, "return_item": True},
                {"list": ["a", "b", "c"], "value": "b"},
            ],
            test_output=[
                ("updated_list", [1, 3]),
                ("removed_item", 2),
                ("updated_list", ["a", "c"]),
            ],
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        lst = input_data.list.copy()
        removed = None
        try:
            if input_data.index is not None:
                removed = lst.pop(input_data.index)
            elif input_data.value is not None:
                lst.remove(input_data.value)
                removed = input_data.value
            else:
                raise ValueError("No index or value provided for removal")
        except (IndexError, ValueError):
            yield "error", "Index or value not found"
            return

        yield "updated_list", lst
        if input_data.return_item:
            yield "removed_item", removed

class ReplaceListItemBlock(Block):
    class Input(BlockSchema):
        list: List[Any] = SchemaField(description="The list to modify.")
        index: int = SchemaField(
            description="Index of the item to replace (supports negative indices)."
        )
        value: Any = SchemaField(description="The new value for the given index.")

    class Output(BlockSchema):
        updated_list: List[Any] = SchemaField(description="The list after replacement.")
        old_item: Any = SchemaField(description="The item that was replaced.")
        error: str = SchemaField(description="Error message if the operation failed.")

    def __init__(self):
        super().__init__(
            id="fbf62922-bea1-4a3d-8bac-23587f810b38",
            description="Replaces an item at the specified index.",
            categories={BlockCategory.BASIC},
            input_schema=ReplaceListItemBlock.Input,
            output_schema=ReplaceListItemBlock.Output,
            test_input=[
                {"list": [1, 2, 3], "index": 1, "value": 99},
                {"list": ["a", "b"], "index": -1, "value": "c"},
            ],
            test_output=[
                ("updated_list", [1, 99, 3]),
                ("old_item", 2),
                ("updated_list", ["a", "c"]),
                ("old_item", "b"),
            ],
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        lst = input_data.list.copy()
        try:
            old = lst[input_data.index]
            lst[input_data.index] = input_data.value
        except IndexError:
            yield "error", "Index out of range"
            return

        yield "updated_list", lst
        yield "old_item", old


class ListIsEmptyBlock(Block):
    class Input(BlockSchema):
        list: List[Any] = SchemaField(description="The list to check.")

    class Output(BlockSchema):
        is_empty: bool = SchemaField(description="True if the list is empty.")

    def __init__(self):
        super().__init__(
            id="896ed73b-27d0-41be-813c-c1c1dc856c03",
            description="Checks if a list is empty.",
            categories={BlockCategory.BASIC},
            input_schema=ListIsEmptyBlock.Input,
            output_schema=ListIsEmptyBlock.Output,
            test_input=[{"list": []}, {"list": [1]}],
            test_output=[("is_empty", True), ("is_empty", False)],
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        yield "is_empty", len(input_data.list) == 0
@@ -1,32 +0,0 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
ExaCredentials = APIKeyCredentials
|
||||
ExaCredentialsInput = CredentialsMetaInput[
|
||||
Literal[ProviderName.EXA],
|
||||
Literal["api_key"],
|
||||
]
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="exa",
|
||||
api_key=SecretStr("mock-exa-api-key"),
|
||||
title="Mock Exa API key",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
def ExaCredentialsField() -> ExaCredentialsInput:
|
||||
"""Creates an Exa credentials input on a block."""
|
||||
return CredentialsField(description="The Exa integration requires an API Key.")
|
||||
autogpt_platform/backend/backend/blocks/exa/_config.py (new file, 16 lines)
@@ -0,0 +1,16 @@
"""
|
||||
Shared configuration for all Exa blocks using the new SDK pattern.
|
||||
"""
|
||||
|
||||
from backend.sdk import BlockCostType, ProviderBuilder
|
||||
|
||||
from ._webhook import ExaWebhookManager
|
||||
|
||||
# Configure the Exa provider once for all blocks
|
||||
exa = (
|
||||
ProviderBuilder("exa")
|
||||
.with_api_key("EXA_API_KEY", "Exa API Key")
|
||||
.with_webhook_manager(ExaWebhookManager)
|
||||
.with_base_cost(1, BlockCostType.RUN)
|
||||
.build()
|
||||
)
|
||||
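The builder keeps provider wiring in one place, so blocks only import the shared exa object. A hypothetical second provider configured the same way, using only the builder methods shown above ("acme" and ACME_API_KEY are illustrative, not real providers):

    acme = (
        ProviderBuilder("acme")
        .with_api_key("ACME_API_KEY", "Acme API Key")
        .with_base_cost(2, BlockCostType.RUN)
        .build()
    )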
autogpt_platform/backend/backend/blocks/exa/_webhook.py (new file, 134 lines)
@@ -0,0 +1,134 @@
"""
|
||||
Exa Webhook Manager implementation.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
from enum import Enum
|
||||
|
||||
from backend.data.model import Credentials
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
BaseWebhooksManager,
|
||||
ProviderName,
|
||||
Requests,
|
||||
Webhook,
|
||||
)
|
||||
|
||||
|
||||
class ExaWebhookType(str, Enum):
|
||||
"""Available webhook types for Exa."""
|
||||
|
||||
WEBSET = "webset"
|
||||
|
||||
|
||||
class ExaEventType(str, Enum):
|
||||
"""Available event types for Exa webhooks."""
|
||||
|
||||
WEBSET_CREATED = "webset.created"
|
||||
WEBSET_DELETED = "webset.deleted"
|
||||
WEBSET_PAUSED = "webset.paused"
|
||||
WEBSET_IDLE = "webset.idle"
|
||||
WEBSET_SEARCH_CREATED = "webset.search.created"
|
||||
WEBSET_SEARCH_CANCELED = "webset.search.canceled"
|
||||
WEBSET_SEARCH_COMPLETED = "webset.search.completed"
|
||||
WEBSET_SEARCH_UPDATED = "webset.search.updated"
|
||||
IMPORT_CREATED = "import.created"
|
||||
IMPORT_COMPLETED = "import.completed"
|
||||
IMPORT_PROCESSING = "import.processing"
|
||||
WEBSET_ITEM_CREATED = "webset.item.created"
|
||||
WEBSET_ITEM_ENRICHED = "webset.item.enriched"
|
||||
WEBSET_EXPORT_CREATED = "webset.export.created"
|
||||
WEBSET_EXPORT_COMPLETED = "webset.export.completed"
|
||||
|
||||
|
||||
class ExaWebhookManager(BaseWebhooksManager):
|
||||
"""Webhook manager for Exa API."""
|
||||
|
||||
PROVIDER_NAME = ProviderName("exa")
|
||||
|
||||
class WebhookType(str, Enum):
|
||||
WEBSET = "webset"
|
||||
|
||||
@classmethod
|
||||
async def validate_payload(cls, webhook: Webhook, request) -> tuple[dict, str]:
|
||||
"""Validate incoming webhook payload and signature."""
|
||||
payload = await request.json()
|
||||
|
||||
# Get event type from payload
|
||||
event_type = payload.get("eventType", "unknown")
|
||||
|
||||
# Verify webhook signature if secret is available
|
||||
if webhook.secret:
|
||||
signature = request.headers.get("X-Exa-Signature")
|
||||
if signature:
|
||||
# Compute expected signature
|
||||
body = await request.body()
|
||||
expected_signature = hmac.new(
|
||||
webhook.secret.encode(), body, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
# Compare signatures
|
||||
if not hmac.compare_digest(signature, expected_signature):
|
||||
raise ValueError("Invalid webhook signature")
|
||||
|
||||
return payload, event_type
|
||||
|
||||
async def _register_webhook(
|
||||
self,
|
||||
credentials: Credentials,
|
||||
webhook_type: str,
|
||||
resource: str,
|
||||
events: list[str],
|
||||
ingress_url: str,
|
||||
secret: str,
|
||||
) -> tuple[str, dict]:
|
||||
"""Register webhook with Exa API."""
|
||||
if not isinstance(credentials, APIKeyCredentials):
|
||||
raise ValueError("Exa webhooks require API key credentials")
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Create webhook via Exa API
|
||||
response = await Requests().post(
|
||||
"https://api.exa.ai/v0/webhooks",
|
||||
headers={"x-api-key": api_key},
|
||||
json={
|
||||
"url": ingress_url,
|
||||
"events": events,
|
||||
"metadata": {
|
||||
"resource": resource,
|
||||
"webhook_type": webhook_type,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
error_data = response.json()
|
||||
raise Exception(f"Failed to create Exa webhook: {error_data}")
|
||||
|
||||
webhook_data = response.json()
|
||||
|
||||
# Store the secret returned by Exa
|
||||
return webhook_data["id"], {
|
||||
"events": events,
|
||||
"resource": resource,
|
||||
"exa_secret": webhook_data.get("secret"),
|
||||
}
|
||||
|
||||
async def _deregister_webhook(
|
||||
self, webhook: Webhook, credentials: Credentials
|
||||
) -> None:
|
||||
"""Deregister webhook from Exa API."""
|
||||
if not isinstance(credentials, APIKeyCredentials):
|
||||
raise ValueError("Exa webhooks require API key credentials")
|
||||
api_key = credentials.api_key.get_secret_value()
|
||||
|
||||
# Delete webhook via Exa API
|
||||
response = await Requests().delete(
|
||||
f"https://api.exa.ai/v0/webhooks/{webhook.provider_webhook_id}",
|
||||
headers={"x-api-key": api_key},
|
||||
)
|
||||
|
||||
if not response.ok and response.status != 404:
|
||||
error_data = response.json()
|
||||
raise Exception(f"Failed to delete Exa webhook: {error_data}")
|
||||
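For reference, the sender's side of the signature scheme that validate_payload checks; a minimal sketch for simulating deliveries in tests (the secret and body are illustrative values, not real credentials):

    import hashlib
    import hmac

    secret = b"illustrative-webhook-secret"
    body = b'{"eventType": "webset.created"}'

    # Hex SHA-256 HMAC of the raw request body, sent as the X-Exa-Signature header.
    signature = hmac.new(secret, body, hashlib.sha256).hexdigest()

    # The receiver recomputes it and compares in constant time.
    assert hmac.compare_digest(
        signature, hmac.new(secret, body, hashlib.sha256).hexdigest()
    )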
autogpt_platform/backend/backend/blocks/exa/answers.py (new file, 124 lines)
@@ -0,0 +1,124 @@
from backend.sdk import (
    APIKeyCredentials,
    BaseModel,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchema,
    CredentialsMetaInput,
    Requests,
    SchemaField,
)

from ._config import exa


class CostBreakdown(BaseModel):
    keywordSearch: float
    neuralSearch: float
    contentText: float
    contentHighlight: float
    contentSummary: float


class SearchBreakdown(BaseModel):
    search: float
    contents: float
    breakdown: CostBreakdown


class PerRequestPrices(BaseModel):
    neuralSearch_1_25_results: float
    neuralSearch_26_100_results: float
    neuralSearch_100_plus_results: float
    keywordSearch_1_100_results: float
    keywordSearch_100_plus_results: float


class PerPagePrices(BaseModel):
    contentText: float
    contentHighlight: float
    contentSummary: float


class CostDollars(BaseModel):
    total: float
    breakDown: list[SearchBreakdown]
    perRequestPrices: PerRequestPrices
    perPagePrices: PerPagePrices


class ExaAnswerBlock(Block):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        query: str = SchemaField(
            description="The question or query to answer",
            placeholder="What is the latest valuation of SpaceX?",
        )
        text: bool = SchemaField(
            default=False,
            description="If true, the response includes full text content in the search results",
            advanced=True,
        )
        model: str = SchemaField(
            default="exa",
            description="The search model to use (exa or exa-pro)",
            placeholder="exa",
            advanced=True,
        )

    class Output(BlockSchema):
        answer: str = SchemaField(
            description="The generated answer based on search results"
        )
        citations: list[dict] = SchemaField(
            description="Search results used to generate the answer",
            default_factory=list,
        )
        cost_dollars: CostDollars = SchemaField(
            description="Cost breakdown of the request"
        )
        error: str = SchemaField(
            description="Error message if the request failed", default=""
        )

    def __init__(self):
        super().__init__(
            id="b79ca4cc-9d5e-47d1-9d4f-e3a2d7f28df5",
            description="Get an LLM answer to a question informed by Exa search results",
            categories={BlockCategory.SEARCH, BlockCategory.AI},
            input_schema=ExaAnswerBlock.Input,
            output_schema=ExaAnswerBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        url = "https://api.exa.ai/answer"
        headers = {
            "Content-Type": "application/json",
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        # Build the payload
        payload = {
            "query": input_data.query,
            "text": input_data.text,
            "model": input_data.model,
        }

        try:
            response = await Requests().post(url, headers=headers, json=payload)
            data = response.json()

            yield "answer", data.get("answer", "")
            yield "citations", data.get("citations", [])
            yield "cost_dollars", data.get("costDollars", {})

        except Exception as e:
            yield "error", str(e)
            yield "answer", ""
            yield "citations", []
            yield "cost_dollars", {}
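The same /answer call outside the block framework, as a sketch assuming the httpx client and an EXA_API_KEY environment variable (both are assumptions; the endpoint, headers, and payload shape come from the block above):

    import os

    import httpx

    response = httpx.post(
        "https://api.exa.ai/answer",
        headers={
            "Content-Type": "application/json",
            "x-api-key": os.environ["EXA_API_KEY"],
        },
        json={"query": "What is the latest valuation of SpaceX?", "text": False, "model": "exa"},
    )
    data = response.json()
    print(data.get("answer", ""), len(data.get("citations", [])))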
@@ -1,57 +1,39 @@
from typing import List

from pydantic import BaseModel

from backend.blocks.exa._auth import (
    ExaCredentials,
    ExaCredentialsField,
    ExaCredentialsInput,
from backend.sdk import (
    APIKeyCredentials,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchema,
    CredentialsMetaInput,
    Requests,
    SchemaField,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import Requests


class ContentRetrievalSettings(BaseModel):
    text: dict = SchemaField(
        description="Text content settings",
        default={"maxCharacters": 1000, "includeHtmlTags": False},
        advanced=True,
    )
    highlights: dict = SchemaField(
        description="Highlight settings",
        default={
            "numSentences": 3,
            "highlightsPerUrl": 3,
            "query": "",
        },
        advanced=True,
    )
    summary: dict = SchemaField(
        description="Summary settings",
        default={"query": ""},
        advanced=True,
    )
from ._config import exa
from .helpers import ContentSettings


class ExaContentsBlock(Block):
    class Input(BlockSchema):
        credentials: ExaCredentialsInput = ExaCredentialsField()
        ids: List[str] = SchemaField(
            description="Array of document IDs obtained from searches",
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        contents: ContentRetrievalSettings = SchemaField(
        ids: list[str] = SchemaField(
            description="Array of document IDs obtained from searches"
        )
        contents: ContentSettings = SchemaField(
            description="Content retrieval settings",
            default=ContentRetrievalSettings(),
            default=ContentSettings(),
            advanced=True,
        )

    class Output(BlockSchema):
        results: list = SchemaField(
            description="List of document contents",
            default_factory=list,
            description="List of document contents", default_factory=list
        )
        error: str = SchemaField(
            description="Error message if the request failed", default=""
        )
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
@@ -63,7 +45,7 @@ class ExaContentsBlock(Block):
        )

    async def run(
        self, input_data: Input, *, credentials: ExaCredentials, **kwargs
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        url = "https://api.exa.ai/contents"
        headers = {
@@ -71,6 +53,7 @@ class ExaContentsBlock(Block):
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        # Convert ContentSettings to API format
        payload = {
            "ids": input_data.ids,
            "text": input_data.contents.text,

@@ -1,8 +1,6 @@
from typing import Optional

from pydantic import BaseModel

from backend.data.model import SchemaField
from backend.sdk import BaseModel, SchemaField


class TextSettings(BaseModel):
@@ -42,13 +40,90 @@ class SummarySettings(BaseModel):
class ContentSettings(BaseModel):
    text: TextSettings = SchemaField(
        default=TextSettings(),
        description="Text content settings",
    )
    highlights: HighlightSettings = SchemaField(
        default=HighlightSettings(),
        description="Highlight settings",
    )
    summary: SummarySettings = SchemaField(
        default=SummarySettings(),
        description="Summary settings",
    )


# Websets Models
class WebsetEntitySettings(BaseModel):
    type: Optional[str] = SchemaField(
        default=None,
        description="Entity type (e.g., 'company', 'person')",
        placeholder="company",
    )


class WebsetCriterion(BaseModel):
    description: str = SchemaField(
        description="Description of the criterion",
        placeholder="Must be based in the US",
    )
    success_rate: Optional[int] = SchemaField(
        default=None,
        description="Success rate percentage",
        ge=0,
        le=100,
    )


class WebsetSearchConfig(BaseModel):
    query: str = SchemaField(
        description="Search query",
        placeholder="Marketing agencies based in the US",
    )
    count: int = SchemaField(
        default=10,
        description="Number of results to return",
        ge=1,
        le=100,
    )
    entity: Optional[WebsetEntitySettings] = SchemaField(
        default=None,
        description="Entity settings for the search",
    )
    criteria: Optional[list[WebsetCriterion]] = SchemaField(
        default=None,
        description="Search criteria",
    )
    behavior: Optional[str] = SchemaField(
        default="override",
        description="Behavior when updating results ('override' or 'append')",
        placeholder="override",
    )


class EnrichmentOption(BaseModel):
    label: str = SchemaField(
        description="Label for the enrichment option",
        placeholder="Option 1",
    )


class WebsetEnrichmentConfig(BaseModel):
    title: str = SchemaField(
        description="Title of the enrichment",
        placeholder="Company Details",
    )
    description: str = SchemaField(
        description="Description of what this enrichment does",
        placeholder="Extract company information",
    )
    format: str = SchemaField(
        default="text",
        description="Format of the enrichment result",
        placeholder="text",
    )
    instructions: Optional[str] = SchemaField(
        default=None,
        description="Instructions for the enrichment",
        placeholder="Extract key company metrics",
    )
    options: Optional[list[EnrichmentOption]] = SchemaField(
        default=None,
        description="Options for the enrichment",
    )

@@ -1,71 +1,61 @@
from datetime import datetime
from typing import List

from backend.blocks.exa._auth import (
    ExaCredentials,
    ExaCredentialsField,
    ExaCredentialsInput,
from backend.sdk import (
    APIKeyCredentials,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchema,
    CredentialsMetaInput,
    Requests,
    SchemaField,
)
from backend.blocks.exa.helpers import ContentSettings
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import Requests

from ._config import exa
from .helpers import ContentSettings


class ExaSearchBlock(Block):
    class Input(BlockSchema):
        credentials: ExaCredentialsInput = ExaCredentialsField()
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        query: str = SchemaField(description="The search query")
        use_auto_prompt: bool = SchemaField(
            description="Whether to use autoprompt",
            default=True,
            advanced=True,
        )
        type: str = SchemaField(
            description="Type of search",
            default="",
            advanced=True,
            description="Whether to use autoprompt", default=True, advanced=True
        )
        type: str = SchemaField(description="Type of search", default="", advanced=True)
        category: str = SchemaField(
            description="Category to search within",
            default="",
            advanced=True,
            description="Category to search within", default="", advanced=True
        )
        number_of_results: int = SchemaField(
            description="Number of results to return",
            default=10,
            advanced=True,
            description="Number of results to return", default=10, advanced=True
        )
        include_domains: List[str] = SchemaField(
            description="Domains to include in search",
            default_factory=list,
        include_domains: list[str] = SchemaField(
            description="Domains to include in search", default_factory=list
        )
        exclude_domains: List[str] = SchemaField(
        exclude_domains: list[str] = SchemaField(
            description="Domains to exclude from search",
            default_factory=list,
            advanced=True,
        )
        start_crawl_date: datetime = SchemaField(
            description="Start date for crawled content",
            description="Start date for crawled content"
        )
        end_crawl_date: datetime = SchemaField(
            description="End date for crawled content",
            description="End date for crawled content"
        )
        start_published_date: datetime = SchemaField(
            description="Start date for published content",
            description="Start date for published content"
        )
        end_published_date: datetime = SchemaField(
            description="End date for published content",
            description="End date for published content"
        )
        include_text: List[str] = SchemaField(
            description="Text patterns to include",
            default_factory=list,
            advanced=True,
        include_text: list[str] = SchemaField(
            description="Text patterns to include", default_factory=list, advanced=True
        )
        exclude_text: List[str] = SchemaField(
            description="Text patterns to exclude",
            default_factory=list,
            advanced=True,
        exclude_text: list[str] = SchemaField(
            description="Text patterns to exclude", default_factory=list, advanced=True
        )
        contents: ContentSettings = SchemaField(
            description="Content retrieval settings",
@@ -75,8 +65,7 @@ class ExaSearchBlock(Block):

    class Output(BlockSchema):
        results: list = SchemaField(
            description="List of search results",
            default_factory=list,
            description="List of search results", default_factory=list
        )
        error: str = SchemaField(
            description="Error message if the request failed",
@@ -92,7 +81,7 @@ class ExaSearchBlock(Block):
        )

    async def run(
        self, input_data: Input, *, credentials: ExaCredentials, **kwargs
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        url = "https://api.exa.ai/search"
        headers = {
@@ -104,7 +93,7 @@ class ExaSearchBlock(Block):
            "query": input_data.query,
            "useAutoprompt": input_data.use_auto_prompt,
            "numResults": input_data.number_of_results,
            "contents": input_data.contents.dict(),
            "contents": input_data.contents.model_dump(),
        }

        date_field_mapping = {

@@ -1,57 +1,60 @@
from datetime import datetime
from typing import Any, List
from typing import Any

from backend.blocks.exa._auth import (
    ExaCredentials,
    ExaCredentialsField,
    ExaCredentialsInput,
from backend.sdk import (
    APIKeyCredentials,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchema,
    CredentialsMetaInput,
    Requests,
    SchemaField,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import Requests

from ._config import exa
from .helpers import ContentSettings


class ExaFindSimilarBlock(Block):
    class Input(BlockSchema):
        credentials: ExaCredentialsInput = ExaCredentialsField()
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        url: str = SchemaField(
            description="The url for which you would like to find similar links"
        )
        number_of_results: int = SchemaField(
            description="Number of results to return",
            default=10,
            advanced=True,
            description="Number of results to return", default=10, advanced=True
        )
        include_domains: List[str] = SchemaField(
        include_domains: list[str] = SchemaField(
            description="Domains to include in search",
            default_factory=list,
            advanced=True,
        )
        exclude_domains: List[str] = SchemaField(
        exclude_domains: list[str] = SchemaField(
            description="Domains to exclude from search",
            default_factory=list,
            advanced=True,
        )
        start_crawl_date: datetime = SchemaField(
            description="Start date for crawled content",
            description="Start date for crawled content"
        )
        end_crawl_date: datetime = SchemaField(
            description="End date for crawled content",
            description="End date for crawled content"
        )
        start_published_date: datetime = SchemaField(
            description="Start date for published content",
            description="Start date for published content"
        )
        end_published_date: datetime = SchemaField(
            description="End date for published content",
            description="End date for published content"
        )
        include_text: List[str] = SchemaField(
        include_text: list[str] = SchemaField(
            description="Text patterns to include (max 1 string, up to 5 words)",
            default_factory=list,
            advanced=True,
        )
        exclude_text: List[str] = SchemaField(
        exclude_text: list[str] = SchemaField(
            description="Text patterns to exclude (max 1 string, up to 5 words)",
            default_factory=list,
            advanced=True,
@@ -63,11 +66,13 @@ class ExaFindSimilarBlock(Block):
        )

    class Output(BlockSchema):
        results: List[Any] = SchemaField(
        results: list[Any] = SchemaField(
            description="List of similar documents with title, URL, published date, author, and score",
            default_factory=list,
        )
        error: str = SchemaField(description="Error message if the request failed")
        error: str = SchemaField(
            description="Error message if the request failed", default=""
        )

    def __init__(self):
        super().__init__(
@@ -79,7 +84,7 @@ class ExaFindSimilarBlock(Block):
        )

    async def run(
        self, input_data: Input, *, credentials: ExaCredentials, **kwargs
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        url = "https://api.exa.ai/findSimilar"
        headers = {
@@ -90,7 +95,7 @@ class ExaFindSimilarBlock(Block):
        payload = {
            "url": input_data.url,
            "numResults": input_data.number_of_results,
            "contents": input_data.contents.dict(),
            "contents": input_data.contents.model_dump(),
        }

        optional_field_mapping = {

autogpt_platform/backend/backend/blocks/exa/webhook_blocks.py (new file, 201 lines)
@@ -0,0 +1,201 @@
"""
|
||||
Exa Webhook Blocks
|
||||
|
||||
These blocks handle webhook events from Exa's API for websets and other events.
|
||||
"""
|
||||
|
||||
from backend.sdk import (
|
||||
BaseModel,
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
BlockWebhookConfig,
|
||||
CredentialsMetaInput,
|
||||
Field,
|
||||
ProviderName,
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._config import exa
|
||||
from ._webhook import ExaEventType
|
||||
|
||||
|
||||
class WebsetEventFilter(BaseModel):
|
||||
"""Filter configuration for Exa webset events."""
|
||||
|
||||
webset_created: bool = Field(
|
||||
default=True, description="Receive notifications when websets are created"
|
||||
)
|
||||
webset_deleted: bool = Field(
|
||||
default=False, description="Receive notifications when websets are deleted"
|
||||
)
|
||||
webset_paused: bool = Field(
|
||||
default=False, description="Receive notifications when websets are paused"
|
||||
)
|
||||
webset_idle: bool = Field(
|
||||
default=False, description="Receive notifications when websets become idle"
|
||||
)
|
||||
search_created: bool = Field(
|
||||
default=True,
|
||||
description="Receive notifications when webset searches are created",
|
||||
)
|
||||
search_completed: bool = Field(
|
||||
default=True, description="Receive notifications when webset searches complete"
|
||||
)
|
||||
search_canceled: bool = Field(
|
||||
default=False,
|
||||
description="Receive notifications when webset searches are canceled",
|
||||
)
|
||||
search_updated: bool = Field(
|
||||
default=False,
|
||||
description="Receive notifications when webset searches are updated",
|
||||
)
|
||||
item_created: bool = Field(
|
||||
default=True, description="Receive notifications when webset items are created"
|
||||
)
|
||||
item_enriched: bool = Field(
|
||||
default=True, description="Receive notifications when webset items are enriched"
|
||||
)
|
||||
export_created: bool = Field(
|
||||
default=False,
|
||||
description="Receive notifications when webset exports are created",
|
||||
)
|
||||
export_completed: bool = Field(
|
||||
default=True, description="Receive notifications when webset exports complete"
|
||||
)
|
||||
import_created: bool = Field(
|
||||
default=False, description="Receive notifications when imports are created"
|
||||
)
|
||||
import_completed: bool = Field(
|
||||
default=True, description="Receive notifications when imports complete"
|
||||
)
|
||||
import_processing: bool = Field(
|
||||
default=False, description="Receive notifications when imports are processing"
|
||||
)
|
||||
|
||||
|
||||
class ExaWebsetWebhookBlock(Block):
|
||||
"""
|
||||
Receives webhook notifications for Exa webset events.
|
||||
|
||||
This block allows you to monitor various events related to Exa websets,
|
||||
including creation, updates, searches, and exports.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput = exa.credentials_field(
|
||||
description="Exa API credentials for webhook management"
|
||||
)
|
||||
webhook_url: str = SchemaField(
|
||||
description="URL to receive webhooks (auto-generated)",
|
||||
default="",
|
||||
hidden=True,
|
||||
)
|
||||
webset_id: str = SchemaField(
|
||||
description="The webset ID to monitor (optional, monitors all if empty)",
|
||||
default="",
|
||||
)
|
||||
event_filter: WebsetEventFilter = SchemaField(
|
||||
description="Configure which events to receive", default=WebsetEventFilter()
|
||||
)
|
||||
payload: dict = SchemaField(
|
||||
description="Webhook payload data", default={}, hidden=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
event_type: str = SchemaField(description="Type of event that occurred")
|
||||
event_id: str = SchemaField(description="Unique identifier for this event")
|
||||
webset_id: str = SchemaField(description="ID of the affected webset")
|
||||
data: dict = SchemaField(description="Event-specific data")
|
||||
timestamp: str = SchemaField(description="When the event occurred")
|
||||
metadata: dict = SchemaField(description="Additional event metadata")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d0204ed8-8b81-408d-8b8d-ed087a546228",
|
||||
description="Receive webhook notifications for Exa webset events",
|
||||
categories={BlockCategory.INPUT},
|
||||
input_schema=ExaWebsetWebhookBlock.Input,
|
||||
output_schema=ExaWebsetWebhookBlock.Output,
|
||||
block_type=BlockType.WEBHOOK,
|
||||
webhook_config=BlockWebhookConfig(
|
||||
provider=ProviderName("exa"),
|
||||
webhook_type="webset",
|
||||
event_filter_input="event_filter",
|
||||
resource_format="{webset_id}",
|
||||
),
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
"""Process incoming Exa webhook payload."""
|
||||
try:
|
||||
payload = input_data.payload
|
||||
|
||||
# Extract event details
|
||||
event_type = payload.get("eventType", "unknown")
|
||||
event_id = payload.get("eventId", "")
|
||||
|
||||
# Get webset ID from payload or input
|
||||
webset_id = payload.get("websetId", input_data.webset_id)
|
||||
|
||||
# Check if we should process this event based on filter
|
||||
should_process = self._should_process_event(
|
||||
event_type, input_data.event_filter
|
||||
)
|
||||
|
||||
if not should_process:
|
||||
# Skip events that don't match our filter
|
||||
return
|
||||
|
||||
# Extract event data
|
||||
event_data = payload.get("data", {})
|
||||
timestamp = payload.get("occurredAt", payload.get("createdAt", ""))
|
||||
metadata = payload.get("metadata", {})
|
||||
|
||||
yield "event_type", event_type
|
||||
yield "event_id", event_id
|
||||
yield "webset_id", webset_id
|
||||
yield "data", event_data
|
||||
yield "timestamp", timestamp
|
||||
yield "metadata", metadata
|
||||
|
||||
except Exception as e:
|
||||
# Handle errors gracefully
|
||||
yield "event_type", "error"
|
||||
yield "event_id", ""
|
||||
yield "webset_id", input_data.webset_id
|
||||
yield "data", {"error": str(e)}
|
||||
yield "timestamp", ""
|
||||
yield "metadata", {}
|
||||
|
||||
def _should_process_event(
|
||||
self, event_type: str, event_filter: WebsetEventFilter
|
||||
) -> bool:
|
||||
"""Check if an event should be processed based on the filter."""
|
||||
filter_mapping = {
|
||||
ExaEventType.WEBSET_CREATED: event_filter.webset_created,
|
||||
ExaEventType.WEBSET_DELETED: event_filter.webset_deleted,
|
||||
ExaEventType.WEBSET_PAUSED: event_filter.webset_paused,
|
||||
ExaEventType.WEBSET_IDLE: event_filter.webset_idle,
|
||||
ExaEventType.WEBSET_SEARCH_CREATED: event_filter.search_created,
|
||||
ExaEventType.WEBSET_SEARCH_COMPLETED: event_filter.search_completed,
|
||||
ExaEventType.WEBSET_SEARCH_CANCELED: event_filter.search_canceled,
|
||||
ExaEventType.WEBSET_SEARCH_UPDATED: event_filter.search_updated,
|
||||
ExaEventType.WEBSET_ITEM_CREATED: event_filter.item_created,
|
||||
ExaEventType.WEBSET_ITEM_ENRICHED: event_filter.item_enriched,
|
||||
ExaEventType.WEBSET_EXPORT_CREATED: event_filter.export_created,
|
||||
ExaEventType.WEBSET_EXPORT_COMPLETED: event_filter.export_completed,
|
||||
ExaEventType.IMPORT_CREATED: event_filter.import_created,
|
||||
ExaEventType.IMPORT_COMPLETED: event_filter.import_completed,
|
||||
ExaEventType.IMPORT_PROCESSING: event_filter.import_processing,
|
||||
}
|
||||
|
||||
# Try to convert string to ExaEventType enum
|
||||
try:
|
||||
event_type_enum = ExaEventType(event_type)
|
||||
return filter_mapping.get(event_type_enum, True)
|
||||
except ValueError:
|
||||
# If event_type is not a valid enum value, process it by default
|
||||
return True
|
||||
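
A minimal, self-contained sketch (editor's illustration, not part of the diff) of the enum-to-flag dispatch used by `_should_process_event` above; `EventType` and `EventFilter` here are hypothetical trimmed-down stand-ins for `ExaEventType` and `WebsetEventFilter`:

from dataclasses import dataclass
from enum import Enum


class EventType(Enum):  # hypothetical stand-in for ExaEventType
    WEBSET_CREATED = "webset.created"
    WEBSET_IDLE = "webset.idle"


@dataclass
class EventFilter:  # hypothetical stand-in for WebsetEventFilter
    webset_created: bool = True
    webset_idle: bool = False


def should_process(event_type: str, f: EventFilter) -> bool:
    mapping = {
        EventType.WEBSET_CREATED: f.webset_created,
        EventType.WEBSET_IDLE: f.webset_idle,
    }
    try:
        return mapping.get(EventType(event_type), True)
    except ValueError:
        return True  # unrecognized event strings pass through by default


assert should_process("webset.created", EventFilter()) is True
assert should_process("webset.idle", EventFilter()) is False
assert should_process("something.new", EventFilter()) is True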
autogpt_platform/backend/backend/blocks/exa/websets.py (new file, 456 lines)
@@ -0,0 +1,456 @@
from typing import Any, Optional

from backend.sdk import (
    APIKeyCredentials,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchema,
    CredentialsMetaInput,
    Requests,
    SchemaField,
)

from ._config import exa
from .helpers import WebsetEnrichmentConfig, WebsetSearchConfig


class ExaCreateWebsetBlock(Block):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        search: WebsetSearchConfig = SchemaField(
            description="Initial search configuration for the Webset"
        )
        enrichments: Optional[list[WebsetEnrichmentConfig]] = SchemaField(
            default=None,
            description="Enrichments to apply to Webset items",
            advanced=True,
        )
        external_id: Optional[str] = SchemaField(
            default=None,
            description="External identifier for the webset",
            placeholder="my-webset-123",
            advanced=True,
        )
        metadata: Optional[dict] = SchemaField(
            default=None,
            description="Key-value pairs to associate with this webset",
            advanced=True,
        )

    class Output(BlockSchema):
        webset_id: str = SchemaField(
            description="The unique identifier for the created webset"
        )
        status: str = SchemaField(description="The status of the webset")
        external_id: Optional[str] = SchemaField(
            description="The external identifier for the webset", default=None
        )
        created_at: str = SchemaField(
            description="The date and time the webset was created"
        )
        error: str = SchemaField(
            description="Error message if the request failed", default=""
        )

    def __init__(self):
        super().__init__(
            id="0cda29ff-c549-4a19-8805-c982b7d4ec34",
            description="Create a new Exa Webset for persistent web search collections",
            categories={BlockCategory.SEARCH},
            input_schema=ExaCreateWebsetBlock.Input,
            output_schema=ExaCreateWebsetBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        url = "https://api.exa.ai/websets/v0/websets"
        headers = {
            "Content-Type": "application/json",
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        # Build the payload
        payload: dict[str, Any] = {
            "search": input_data.search.model_dump(exclude_none=True),
        }

        # Convert enrichments to API format
        if input_data.enrichments:
            enrichments_data = []
            for enrichment in input_data.enrichments:
                enrichments_data.append(enrichment.model_dump(exclude_none=True))
            payload["enrichments"] = enrichments_data

        if input_data.external_id:
            payload["externalId"] = input_data.external_id

        if input_data.metadata:
            payload["metadata"] = input_data.metadata

        try:
            response = await Requests().post(url, headers=headers, json=payload)
            data = response.json()

            yield "webset_id", data.get("id", "")
            yield "status", data.get("status", "")
            yield "external_id", data.get("externalId")
            yield "created_at", data.get("createdAt", "")

        except Exception as e:
            yield "error", str(e)
            yield "webset_id", ""
            yield "status", ""
            yield "created_at", ""


class ExaUpdateWebsetBlock(Block):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset to update",
            placeholder="webset-id-or-external-id",
        )
        metadata: Optional[dict] = SchemaField(
            default=None,
            description="Key-value pairs to associate with this webset (set to null to clear)",
        )

    class Output(BlockSchema):
        webset_id: str = SchemaField(description="The unique identifier for the webset")
        status: str = SchemaField(description="The status of the webset")
        external_id: Optional[str] = SchemaField(
            description="The external identifier for the webset", default=None
        )
        metadata: dict = SchemaField(
            description="Updated metadata for the webset", default_factory=dict
        )
        updated_at: str = SchemaField(
            description="The date and time the webset was updated"
        )
        error: str = SchemaField(
            description="Error message if the request failed", default=""
        )

    def __init__(self):
        super().__init__(
            id="89ccd99a-3c2b-4fbf-9e25-0ffa398d0314",
            description="Update metadata for an existing Webset",
            categories={BlockCategory.SEARCH},
            input_schema=ExaUpdateWebsetBlock.Input,
            output_schema=ExaUpdateWebsetBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        url = f"https://api.exa.ai/websets/v0/websets/{input_data.webset_id}"
        headers = {
            "Content-Type": "application/json",
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        # Build the payload
        payload = {}
        if input_data.metadata is not None:
            payload["metadata"] = input_data.metadata

        try:
            response = await Requests().post(url, headers=headers, json=payload)
            data = response.json()

            yield "webset_id", data.get("id", "")
            yield "status", data.get("status", "")
            yield "external_id", data.get("externalId")
            yield "metadata", data.get("metadata", {})
            yield "updated_at", data.get("updatedAt", "")

        except Exception as e:
            yield "error", str(e)
            yield "webset_id", ""
            yield "status", ""
            yield "metadata", {}
            yield "updated_at", ""


class ExaListWebsetsBlock(Block):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        cursor: Optional[str] = SchemaField(
            default=None,
            description="Cursor for pagination through results",
            advanced=True,
        )
        limit: int = SchemaField(
            default=25,
            description="Number of websets to return (1-100)",
            ge=1,
            le=100,
            advanced=True,
        )

    class Output(BlockSchema):
        websets: list = SchemaField(description="List of websets", default_factory=list)
        has_more: bool = SchemaField(
            description="Whether there are more results to paginate through",
            default=False,
        )
        next_cursor: Optional[str] = SchemaField(
            description="Cursor for the next page of results", default=None
        )
        error: str = SchemaField(
            description="Error message if the request failed", default=""
        )

    def __init__(self):
        super().__init__(
            id="1dcd8fd6-c13f-4e6f-bd4c-654428fa4757",
            description="List all Websets with pagination support",
            categories={BlockCategory.SEARCH},
            input_schema=ExaListWebsetsBlock.Input,
            output_schema=ExaListWebsetsBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        url = "https://api.exa.ai/websets/v0/websets"
        headers = {
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        params: dict[str, Any] = {
            "limit": input_data.limit,
        }
        if input_data.cursor:
            params["cursor"] = input_data.cursor

        try:
            response = await Requests().get(url, headers=headers, params=params)
            data = response.json()

            yield "websets", data.get("data", [])
            yield "has_more", data.get("hasMore", False)
            yield "next_cursor", data.get("nextCursor")

        except Exception as e:
            yield "error", str(e)
            yield "websets", []
            yield "has_more", False


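A brief sketch of how a caller might drive this block's cursor pagination against the same endpoint; this is hypothetical client code using httpx, not part of the diff:

import httpx  # assumed dependency for this sketch only


def list_all_websets(api_key: str) -> list[dict]:
    """Follow hasMore/nextCursor until the listing is exhausted."""
    url = "https://api.exa.ai/websets/v0/websets"
    headers = {"x-api-key": api_key}
    items: list[dict] = []
    cursor: str | None = None
    while True:
        params: dict = {"limit": 100}  # the endpoint accepts 1-100 per page
        if cursor:
            params["cursor"] = cursor
        data = httpx.get(url, headers=headers, params=params).json()
        items.extend(data.get("data", []))
        if not data.get("hasMore"):
            return items
        cursor = data.get("nextCursor")
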
class ExaGetWebsetBlock(Block):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset to retrieve",
            placeholder="webset-id-or-external-id",
        )
        expand_items: bool = SchemaField(
            default=False, description="Include items in the response", advanced=True
        )

    class Output(BlockSchema):
        webset_id: str = SchemaField(description="The unique identifier for the webset")
        status: str = SchemaField(description="The status of the webset")
        external_id: Optional[str] = SchemaField(
            description="The external identifier for the webset", default=None
        )
        searches: list[dict] = SchemaField(
            description="The searches performed on the webset", default_factory=list
        )
        enrichments: list[dict] = SchemaField(
            description="The enrichments applied to the webset", default_factory=list
        )
        monitors: list[dict] = SchemaField(
            description="The monitors for the webset", default_factory=list
        )
        items: Optional[list[dict]] = SchemaField(
            description="The items in the webset (if expand_items is true)",
            default=None,
        )
        metadata: dict = SchemaField(
            description="Key-value pairs associated with the webset",
            default_factory=dict,
        )
        created_at: str = SchemaField(
            description="The date and time the webset was created"
        )
        updated_at: str = SchemaField(
            description="The date and time the webset was last updated"
        )
        error: str = SchemaField(
            description="Error message if the request failed", default=""
        )

    def __init__(self):
        super().__init__(
            id="6ab8e12a-132c-41bf-b5f3-d662620fa832",
            description="Retrieve a Webset by ID or external ID",
            categories={BlockCategory.SEARCH},
            input_schema=ExaGetWebsetBlock.Input,
            output_schema=ExaGetWebsetBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        url = f"https://api.exa.ai/websets/v0/websets/{input_data.webset_id}"
        headers = {
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        params = {}
        if input_data.expand_items:
            params["expand[]"] = "items"

        try:
            response = await Requests().get(url, headers=headers, params=params)
            data = response.json()

            yield "webset_id", data.get("id", "")
            yield "status", data.get("status", "")
            yield "external_id", data.get("externalId")
            yield "searches", data.get("searches", [])
            yield "enrichments", data.get("enrichments", [])
            yield "monitors", data.get("monitors", [])
            yield "items", data.get("items")
            yield "metadata", data.get("metadata", {})
            yield "created_at", data.get("createdAt", "")
            yield "updated_at", data.get("updatedAt", "")

        except Exception as e:
            yield "error", str(e)
            yield "webset_id", ""
            yield "status", ""
            yield "searches", []
            yield "enrichments", []
            yield "monitors", []
            yield "metadata", {}
            yield "created_at", ""
            yield "updated_at", ""


class ExaDeleteWebsetBlock(Block):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset to delete",
            placeholder="webset-id-or-external-id",
        )

    class Output(BlockSchema):
        webset_id: str = SchemaField(
            description="The unique identifier for the deleted webset"
        )
        external_id: Optional[str] = SchemaField(
            description="The external identifier for the deleted webset", default=None
        )
        status: str = SchemaField(description="The status of the deleted webset")
        success: str = SchemaField(
            description="Whether the deletion was successful", default="true"
        )
        error: str = SchemaField(
            description="Error message if the request failed", default=""
        )

    def __init__(self):
        super().__init__(
            id="aa6994a2-e986-421f-8d4c-7671d3be7b7e",
            description="Delete a Webset and all its items",
            categories={BlockCategory.SEARCH},
            input_schema=ExaDeleteWebsetBlock.Input,
            output_schema=ExaDeleteWebsetBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        url = f"https://api.exa.ai/websets/v0/websets/{input_data.webset_id}"
        headers = {
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        try:
            response = await Requests().delete(url, headers=headers)
            data = response.json()

            yield "webset_id", data.get("id", "")
            yield "external_id", data.get("externalId")
            yield "status", data.get("status", "")
            yield "success", "true"

        except Exception as e:
            yield "error", str(e)
            yield "webset_id", ""
            yield "status", ""
            yield "success", "false"


class ExaCancelWebsetBlock(Block):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = exa.credentials_field(
            description="The Exa integration requires an API Key."
        )
        webset_id: str = SchemaField(
            description="The ID or external ID of the Webset to cancel",
            placeholder="webset-id-or-external-id",
        )

    class Output(BlockSchema):
        webset_id: str = SchemaField(description="The unique identifier for the webset")
        status: str = SchemaField(
            description="The status of the webset after cancellation"
        )
        external_id: Optional[str] = SchemaField(
            description="The external identifier for the webset", default=None
        )
        success: str = SchemaField(
            description="Whether the cancellation was successful", default="true"
        )
        error: str = SchemaField(
            description="Error message if the request failed", default=""
        )

    def __init__(self):
        super().__init__(
            id="e40a6420-1db8-47bb-b00a-0e6aecd74176",
            description="Cancel all operations being performed on a Webset",
            categories={BlockCategory.SEARCH},
            input_schema=ExaCancelWebsetBlock.Input,
            output_schema=ExaCancelWebsetBlock.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        url = f"https://api.exa.ai/websets/v0/websets/{input_data.webset_id}/cancel"
        headers = {
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        try:
            response = await Requests().post(url, headers=headers)
            data = response.json()

            yield "webset_id", data.get("id", "")
            yield "status", data.get("status", "")
            yield "external_id", data.get("externalId")
            yield "success", "true"

        except Exception as e:
            yield "error", str(e)
            yield "webset_id", ""
            yield "status", ""
            yield "success", "false"
@@ -0,0 +1,9 @@
# Import the provider builder to ensure it's registered
from backend.sdk.registry import AutoRegistry

from .triggers import GenericWebhookTriggerBlock, generic_webhook

# Ensure the SDK registry is patched to include our webhook manager
AutoRegistry.patch_integrations()

__all__ = ["GenericWebhookTriggerBlock", "generic_webhook"]
@@ -3,10 +3,7 @@ import logging
from fastapi import Request
from strenum import StrEnum

from backend.data import integrations
from backend.integrations.providers import ProviderName

from ._manual_base import ManualWebhookManagerBase
from backend.sdk import ManualWebhookManagerBase, Webhook

logger = logging.getLogger(__name__)

@@ -16,12 +13,11 @@ class GenericWebhookType(StrEnum):


class GenericWebhooksManager(ManualWebhookManagerBase):
    PROVIDER_NAME = ProviderName.GENERIC_WEBHOOK
    WebhookType = GenericWebhookType

    @classmethod
    async def validate_payload(
        cls, webhook: integrations.Webhook, request: Request
        cls, webhook: Webhook, request: Request
    ) -> tuple[dict, str]:
        payload = await request.json()
        event_type = GenericWebhookType.PLAIN
@@ -1,13 +1,21 @@
from backend.data.block import (
from backend.sdk import (
    Block,
    BlockCategory,
    BlockManualWebhookConfig,
    BlockOutput,
    BlockSchema,
    ProviderBuilder,
    ProviderName,
    SchemaField,
)

from ._webhook import GenericWebhooksManager, GenericWebhookType

generic_webhook = (
    ProviderBuilder("generic_webhook")
    .with_webhook_manager(GenericWebhooksManager)
    .build()
)
from backend.data.model import SchemaField
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks.generic import GenericWebhookType


class GenericWebhookTriggerBlock(Block):
@@ -36,7 +44,7 @@ class GenericWebhookTriggerBlock(Block):
            input_schema=GenericWebhookTriggerBlock.Input,
            output_schema=GenericWebhookTriggerBlock.Output,
            webhook_config=BlockManualWebhookConfig(
                provider=ProviderName.GENERIC_WEBHOOK,
                provider=ProviderName(generic_webhook.name),
                webhook_type=GenericWebhookType.PLAIN,
            ),
            test_input={"constants": {"key": "value"}, "payload": self.example_payload},
@@ -498,6 +498,9 @@ class GithubListIssuesBlock(Block):
        issue: IssueItem = SchemaField(
            title="Issue", description="Issues with their title and URL"
        )
        issues: list[IssueItem] = SchemaField(
            description="List of issues with their title and URL"
        )
        error: str = SchemaField(description="Error message if listing issues failed")

    def __init__(self):
@@ -513,13 +516,22 @@ class GithubListIssuesBlock(Block):
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                (
                    "issues",
                    [
                        {
                            "title": "Issue 1",
                            "url": "https://github.com/owner/repo/issues/1",
                        }
                    ],
                ),
                (
                    "issue",
                    {
                        "title": "Issue 1",
                        "url": "https://github.com/owner/repo/issues/1",
                    },
                )
                ),
            ],
            test_mock={
                "list_issues": lambda *args, **kwargs: [
@@ -551,10 +563,12 @@ class GithubListIssuesBlock(Block):
        credentials: GithubCredentials,
        **kwargs,
    ) -> BlockOutput:
        for issue in await self.list_issues(
        issues = await self.list_issues(
            credentials,
            input_data.repo_url,
        ):
        )
        yield "issues", issues
        for issue in issues:
            yield "issue", issue


@@ -31,7 +31,12 @@ class GithubListPullRequestsBlock(Block):
        pull_request: PRItem = SchemaField(
            title="Pull Request", description="PRs with their title and URL"
        )
        error: str = SchemaField(description="Error message if listing issues failed")
        pull_requests: list[PRItem] = SchemaField(
            description="List of pull requests with their title and URL"
        )
        error: str = SchemaField(
            description="Error message if listing pull requests failed"
        )

    def __init__(self):
        super().__init__(
@@ -46,13 +51,22 @@ class GithubListPullRequestsBlock(Block):
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                (
                    "pull_requests",
                    [
                        {
                            "title": "Pull request 1",
                            "url": "https://github.com/owner/repo/pull/1",
                        }
                    ],
                ),
                (
                    "pull_request",
                    {
                        "title": "Pull request 1",
                        "url": "https://github.com/owner/repo/pull/1",
                    },
                )
                ),
            ],
            test_mock={
                "list_prs": lambda *args, **kwargs: [
@@ -88,6 +102,7 @@ class GithubListPullRequestsBlock(Block):
            credentials,
            input_data.repo_url,
        )
        yield "pull_requests", pull_requests
        for pr in pull_requests:
            yield "pull_request", pr

@@ -265,10 +280,26 @@ class GithubReadPullRequestBlock(Block):
        files = response.json()
        changes = []
        for file in files:
            filename = file.get("filename", "")
            status = file.get("status", "")
            changes.append(f"{filename}: {status}")
        return "\n".join(changes)
            status: str = file.get("status", "")
            diff: str = file.get("patch", "")
            if status != "removed":
                is_filename: str = file.get("filename", "")
                was_filename: str = (
                    file.get("previous_filename", is_filename)
                    if status != "added"
                    else ""
                )
            else:
                is_filename = ""
                was_filename: str = file.get("filename", "")

            patch_header = ""
            if was_filename:
                patch_header += f"--- {was_filename}\n"
            if is_filename:
                patch_header += f"+++ {is_filename}\n"
            changes.append(patch_header + diff)
        return "\n\n".join(changes)
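# Editor's note (illustrative): for a renamed file the loop above emits
#   --- old_path.py
#   +++ new_path.py
# followed by the file's patch; an added file gets only the "+++" line and a
# removed file only the "---" line, mirroring unified-diff header conventions.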
    async def run(
        self,
@@ -444,6 +475,9 @@ class GithubListPRReviewersBlock(Block):
            title="Reviewer",
            description="Reviewers with their username and profile URL",
        )
        reviewers: list[ReviewerItem] = SchemaField(
            description="List of reviewers with their username and profile URL"
        )
        error: str = SchemaField(
            description="Error message if listing reviewers failed"
        )
@@ -461,13 +495,22 @@ class GithubListPRReviewersBlock(Block):
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                (
                    "reviewers",
                    [
                        {
                            "username": "reviewer1",
                            "url": "https://github.com/reviewer1",
                        }
                    ],
                ),
                (
                    "reviewer",
                    {
                        "username": "reviewer1",
                        "url": "https://github.com/reviewer1",
                    },
                )
                ),
            ],
            test_mock={
                "list_reviewers": lambda *args, **kwargs: [
@@ -500,10 +543,12 @@ class GithubListPRReviewersBlock(Block):
        credentials: GithubCredentials,
        **kwargs,
    ) -> BlockOutput:
        for reviewer in await self.list_reviewers(
        reviewers = await self.list_reviewers(
            credentials,
            input_data.pr_url,
        ):
        )
        yield "reviewers", reviewers
        for reviewer in reviewers:
            yield "reviewer", reviewer


@@ -31,6 +31,9 @@ class GithubListTagsBlock(Block):
        tag: TagItem = SchemaField(
            title="Tag", description="Tags with their name and file tree browser URL"
        )
        tags: list[TagItem] = SchemaField(
            description="List of tags with their name and file tree browser URL"
        )
        error: str = SchemaField(description="Error message if listing tags failed")

    def __init__(self):
@@ -46,13 +49,22 @@ class GithubListTagsBlock(Block):
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                (
                    "tags",
                    [
                        {
                            "name": "v1.0.0",
                            "url": "https://github.com/owner/repo/tree/v1.0.0",
                        }
                    ],
                ),
                (
                    "tag",
                    {
                        "name": "v1.0.0",
                        "url": "https://github.com/owner/repo/tree/v1.0.0",
                    },
                )
                ),
            ],
            test_mock={
                "list_tags": lambda *args, **kwargs: [
@@ -93,6 +105,7 @@ class GithubListTagsBlock(Block):
            credentials,
            input_data.repo_url,
        )
        yield "tags", tags
        for tag in tags:
            yield "tag", tag

@@ -114,6 +127,9 @@ class GithubListBranchesBlock(Block):
            title="Branch",
            description="Branches with their name and file tree browser URL",
        )
        branches: list[BranchItem] = SchemaField(
            description="List of branches with their name and file tree browser URL"
        )
        error: str = SchemaField(description="Error message if listing branches failed")

    def __init__(self):
@@ -129,13 +145,22 @@ class GithubListBranchesBlock(Block):
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                (
                    "branches",
                    [
                        {
                            "name": "main",
                            "url": "https://github.com/owner/repo/tree/main",
                        }
                    ],
                ),
                (
                    "branch",
                    {
                        "name": "main",
                        "url": "https://github.com/owner/repo/tree/main",
                    },
                )
                ),
            ],
            test_mock={
                "list_branches": lambda *args, **kwargs: [
@@ -176,6 +201,7 @@ class GithubListBranchesBlock(Block):
            credentials,
            input_data.repo_url,
        )
        yield "branches", branches
        for branch in branches:
            yield "branch", branch

@@ -199,6 +225,9 @@ class GithubListDiscussionsBlock(Block):
        discussion: DiscussionItem = SchemaField(
            title="Discussion", description="Discussions with their title and URL"
        )
        discussions: list[DiscussionItem] = SchemaField(
            description="List of discussions with their title and URL"
        )
        error: str = SchemaField(
            description="Error message if listing discussions failed"
        )
@@ -217,13 +246,22 @@ class GithubListDiscussionsBlock(Block):
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                (
                    "discussions",
                    [
                        {
                            "title": "Discussion 1",
                            "url": "https://github.com/owner/repo/discussions/1",
                        }
                    ],
                ),
                (
                    "discussion",
                    {
                        "title": "Discussion 1",
                        "url": "https://github.com/owner/repo/discussions/1",
                    },
                )
                ),
            ],
            test_mock={
                "list_discussions": lambda *args, **kwargs: [
@@ -279,6 +317,7 @@ class GithubListDiscussionsBlock(Block):
            input_data.repo_url,
            input_data.num_discussions,
        )
        yield "discussions", discussions
        for discussion in discussions:
            yield "discussion", discussion

@@ -300,6 +339,9 @@ class GithubListReleasesBlock(Block):
            title="Release",
            description="Releases with their name and file tree browser URL",
        )
        releases: list[ReleaseItem] = SchemaField(
            description="List of releases with their name and file tree browser URL"
        )
        error: str = SchemaField(description="Error message if listing releases failed")

    def __init__(self):
@@ -315,13 +357,22 @@ class GithubListReleasesBlock(Block):
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                (
                    "releases",
                    [
                        {
                            "name": "v1.0.0",
                            "url": "https://github.com/owner/repo/releases/tag/v1.0.0",
                        }
                    ],
                ),
                (
                    "release",
                    {
                        "name": "v1.0.0",
                        "url": "https://github.com/owner/repo/releases/tag/v1.0.0",
                    },
                )
                ),
            ],
            test_mock={
                "list_releases": lambda *args, **kwargs: [
@@ -357,6 +408,7 @@ class GithubListReleasesBlock(Block):
            credentials,
            input_data.repo_url,
        )
        yield "releases", releases
        for release in releases:
            yield "release", release

@@ -1041,6 +1093,9 @@ class GithubListStargazersBlock(Block):
            title="Stargazer",
            description="Stargazers with their username and profile URL",
        )
        stargazers: list[StargazerItem] = SchemaField(
            description="List of stargazers with their username and profile URL"
        )
        error: str = SchemaField(
            description="Error message if listing stargazers failed"
        )
@@ -1058,13 +1113,22 @@ class GithubListStargazersBlock(Block):
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                (
                    "stargazers",
                    [
                        {
                            "username": "octocat",
                            "url": "https://github.com/octocat",
                        }
                    ],
                ),
                (
                    "stargazer",
                    {
                        "username": "octocat",
                        "url": "https://github.com/octocat",
                    },
                )
                ),
            ],
            test_mock={
                "list_stargazers": lambda *args, **kwargs: [
@@ -1104,5 +1168,6 @@ class GithubListStargazersBlock(Block):
            credentials,
            input_data.repo_url,
        )
        yield "stargazers", stargazers
        for stargazer in stargazers:
            yield "stargazer", stargazer

(File diff suppressed because it is too large.)
@@ -3,11 +3,19 @@ import logging
from enum import Enum
from io import BytesIO
from pathlib import Path
from typing import Literal

import aiofiles
from pydantic import SecretStr

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.data.model import (
    CredentialsField,
    CredentialsMetaInput,
    HostScopedCredentials,
    SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.file import (
    MediaFileType,
    get_exec_file_path,
@@ -19,6 +27,30 @@ from backend.util.request import Requests

logger = logging.getLogger(name=__name__)


# Host-scoped credentials for HTTP requests
HttpCredentials = CredentialsMetaInput[
    Literal[ProviderName.HTTP], Literal["host_scoped"]
]


TEST_CREDENTIALS = HostScopedCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="http",
    host="api.example.com",
    headers={
        "Authorization": SecretStr("Bearer test-token"),
        "X-API-Key": SecretStr("test-api-key"),
    },
    title="Mock HTTP Host-Scoped Credentials",
)
TEST_CREDENTIALS_INPUT = {
    "provider": TEST_CREDENTIALS.provider,
    "id": TEST_CREDENTIALS.id,
    "type": TEST_CREDENTIALS.type,
    "title": TEST_CREDENTIALS.title,
}


class HttpMethod(Enum):
    GET = "GET"
    POST = "POST"
@@ -169,3 +201,62 @@ class SendWebRequestBlock(Block):
            yield "client_error", result
        else:
            yield "server_error", result


class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
    class Input(SendWebRequestBlock.Input):
        credentials: HttpCredentials = CredentialsField(
            description="HTTP host-scoped credentials for automatic header injection",
            discriminator="url",
        )

    def __init__(self):
        Block.__init__(
            self,
            id="fff86bcd-e001-4bad-a7f6-2eae4720c8dc",
            description="Make an authenticated HTTP request with host-scoped credentials (JSON / form / multipart).",
            categories={BlockCategory.OUTPUT},
            input_schema=SendAuthenticatedWebRequestBlock.Input,
            output_schema=SendWebRequestBlock.Output,
            test_credentials=TEST_CREDENTIALS,
        )

    async def run(  # type: ignore[override]
        self,
        input_data: Input,
        *,
        graph_exec_id: str,
        credentials: HostScopedCredentials,
        **kwargs,
    ) -> BlockOutput:
        # Create SendWebRequestBlock.Input from our input (removing credentials field)
        base_input = SendWebRequestBlock.Input(
            url=input_data.url,
            method=input_data.method,
            headers=input_data.headers,
            json_format=input_data.json_format,
            body=input_data.body,
            files_name=input_data.files_name,
            files=input_data.files,
        )

        # Apply host-scoped credentials to headers
        extra_headers = {}
        if credentials.matches_url(input_data.url):
            logger.debug(
                f"Applying host-scoped credentials {credentials.id} for URL {input_data.url}"
            )
            extra_headers.update(credentials.get_headers_dict())
        else:
            logger.warning(
                f"Host-scoped credentials {credentials.id} do not match URL {input_data.url}"
            )

        # Merge with user-provided headers (user headers take precedence)
        base_input.headers = {**extra_headers, **input_data.headers}

        # Use parent class run method
        async for output_name, output_data in super().run(
            base_input, graph_exec_id=graph_exec_id, **kwargs
        ):
            yield output_name, output_data

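A minimal sketch (editor's illustration, not the platform's code) of the header-merge rule used above: injected credential headers are applied first, so user-supplied headers win on key conflicts in the dict merge.

injected = {"Authorization": "Bearer from-credentials", "X-API-Key": "secret"}
user_headers = {"Authorization": "Bearer user-override"}

# Later keys overwrite earlier ones, matching {**extra_headers, **input_data.headers}.
merged = {**injected, **user_headers}
assert merged["Authorization"] == "Bearer user-override"  # user header wins
assert merged["X-API-Key"] == "secret"  # injected header survives
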
autogpt_platform/backend/backend/blocks/linear/__init__.py (new file, 14 lines)
@@ -0,0 +1,14 @@
"""
Linear integration blocks for AutoGPT Platform.
"""

from .comment import LinearCreateCommentBlock
from .issues import LinearCreateIssueBlock, LinearSearchIssuesBlock
from .projects import LinearSearchProjectsBlock

__all__ = [
    "LinearCreateCommentBlock",
    "LinearCreateIssueBlock",
    "LinearSearchIssuesBlock",
    "LinearSearchProjectsBlock",
]
@@ -1,16 +1,11 @@
from __future__ import annotations

import json
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union

from backend.blocks.linear._auth import LinearCredentials
from backend.blocks.linear.models import (
    CreateCommentResponse,
    CreateIssueResponse,
    Issue,
    Project,
)
from backend.util.request import Requests
from backend.sdk import APIKeyCredentials, OAuth2Credentials, Requests

from .models import CreateCommentResponse, CreateIssueResponse, Issue, Project


class LinearAPIException(Exception):
@@ -29,13 +24,12 @@ class LinearClient:

    def __init__(
        self,
        credentials: LinearCredentials | None = None,
        credentials: Union[OAuth2Credentials, APIKeyCredentials, None] = None,
        custom_requests: Optional[Requests] = None,
    ):
        if custom_requests:
            self._requests = custom_requests
        else:

            headers: Dict[str, str] = {
                "Content-Type": "application/json",
            }

@@ -1,31 +1,19 @@
"""
Shared configuration for all Linear blocks using the new SDK pattern.
"""

import os
from enum import Enum
from typing import Literal

from pydantic import SecretStr

from backend.data.model import (
from backend.sdk import (
    APIKeyCredentials,
    CredentialsField,
    CredentialsMetaInput,
    BlockCostType,
    OAuth2Credentials,
)
from backend.integrations.providers import ProviderName
from backend.util.settings import Secrets

secrets = Secrets()
LINEAR_OAUTH_IS_CONFIGURED = bool(
    secrets.linear_client_id and secrets.linear_client_secret
    ProviderBuilder,
    SecretStr,
)

LinearCredentials = OAuth2Credentials | APIKeyCredentials
# LinearCredentialsInput = CredentialsMetaInput[
#     Literal[ProviderName.LINEAR],
#     Literal["oauth2", "api_key"] if LINEAR_OAUTH_IS_CONFIGURED else Literal["oauth2"],
# ]
LinearCredentialsInput = CredentialsMetaInput[
    Literal[ProviderName.LINEAR], Literal["oauth2"]
]

from ._oauth import LinearOAuthHandler

# (required) Comma separated list of scopes:

@@ -50,21 +38,35 @@ class LinearScope(str, Enum):
    ADMIN = "admin"


def LinearCredentialsField(scopes: list[LinearScope]) -> LinearCredentialsInput:
    """
    Creates a Linear credentials input on a block.
# Check if Linear OAuth is configured
client_id = os.getenv("LINEAR_CLIENT_ID")
client_secret = os.getenv("LINEAR_CLIENT_SECRET")
LINEAR_OAUTH_IS_CONFIGURED = bool(client_id and client_secret)

    Params:
        scope: The authorization scope needed for the block to work. ([list of available scopes](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/scopes-for-oauth-apps#available-scopes))
    """  # noqa
    return CredentialsField(
        required_scopes=set([LinearScope.READ.value]).union(
            set([scope.value for scope in scopes])
        ),
        description="The Linear integration can be used with OAuth, "
        "or any API key with sufficient permissions for the blocks it is used on.",
# Build the Linear provider
builder = (
    ProviderBuilder("linear")
    .with_api_key(env_var_name="LINEAR_API_KEY", title="Linear API Key")
    .with_base_cost(1, BlockCostType.RUN)
)

# Linear only supports OAuth authentication
if LINEAR_OAUTH_IS_CONFIGURED:
    builder = builder.with_oauth(
        LinearOAuthHandler,
        scopes=[
            LinearScope.READ,
            LinearScope.WRITE,
            LinearScope.ISSUES_CREATE,
            LinearScope.COMMENTS_CREATE,
        ],
        client_id_env_var="LINEAR_CLIENT_ID",
        client_secret_env_var="LINEAR_CLIENT_SECRET",
    )

# Build the provider
linear = builder.build()


TEST_CREDENTIALS_OAUTH = OAuth2Credentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
@@ -1,15 +1,27 @@
"""
Linear OAuth handler implementation.
"""

import json
from typing import Optional
from urllib.parse import urlencode

from pydantic import SecretStr
from backend.sdk import (
    APIKeyCredentials,
    BaseOAuthHandler,
    OAuth2Credentials,
    ProviderName,
    Requests,
    SecretStr,
)

from backend.blocks.linear._api import LinearAPIException
from backend.data.model import APIKeyCredentials, OAuth2Credentials
from backend.integrations.providers import ProviderName
from backend.util.request import Requests

from .base import BaseOAuthHandler
class LinearAPIException(Exception):
    """Exception for Linear API errors."""

    def __init__(self, message: str, status_code: int):
        super().__init__(message)
        self.status_code = status_code


class LinearOAuthHandler(BaseOAuthHandler):
@@ -17,7 +29,9 @@ class LinearOAuthHandler(BaseOAuthHandler):
    OAuth2 handler for Linear.
    """

    PROVIDER_NAME = ProviderName.LINEAR
    # Provider name will be set dynamically by the SDK when registered
    # We use a placeholder that will be replaced by AutoRegistry.register_provider()
    PROVIDER_NAME = ProviderName("linear")

    def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
        self.client_id = client_id
@@ -30,7 +44,6 @@ class LinearOAuthHandler(BaseOAuthHandler):
    def get_login_url(
        self, scopes: list[str], state: str, code_challenge: Optional[str]
    ) -> str:

        params = {
            "client_id": self.client_id,
            "redirect_uri": self.redirect_uri,
@@ -139,9 +152,10 @@ class LinearOAuthHandler(BaseOAuthHandler):

    async def _request_username(self, access_token: str) -> Optional[str]:
        # Use the LinearClient to fetch user details using GraphQL
        from backend.blocks.linear._api import LinearClient
        from ._api import LinearClient

        try:
            # Create a temporary OAuth2Credentials object for the LinearClient
            linear_client = LinearClient(
                APIKeyCredentials(
                    api_key=SecretStr(access_token),
@@ -1,24 +1,32 @@
from backend.blocks.linear._api import LinearAPIException, LinearClient
from backend.blocks.linear._auth import (
from backend.sdk import (
    APIKeyCredentials,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchema,
    CredentialsMetaInput,
    OAuth2Credentials,
    SchemaField,
)

from ._api import LinearAPIException, LinearClient
from ._config import (
    LINEAR_OAUTH_IS_CONFIGURED,
    TEST_CREDENTIALS_INPUT_OAUTH,
    TEST_CREDENTIALS_OAUTH,
    LinearCredentials,
    LinearCredentialsField,
    LinearCredentialsInput,
    LinearScope,
    linear,
)
from backend.blocks.linear.models import CreateCommentResponse
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from .models import CreateCommentResponse


class LinearCreateCommentBlock(Block):
    """Block for creating comments on Linear issues"""

    class Input(BlockSchema):
        credentials: LinearCredentialsInput = LinearCredentialsField(
            scopes=[LinearScope.COMMENTS_CREATE],
        credentials: CredentialsMetaInput = linear.credentials_field(
            description="Linear credentials with comment creation permissions",
            required_scopes={LinearScope.COMMENTS_CREATE},
        )
        issue_id: str = SchemaField(description="ID of the issue to comment on")
        comment: str = SchemaField(description="Comment text to add to the issue")
@@ -55,7 +63,7 @@ class LinearCreateCommentBlock(Block):

    @staticmethod
    async def create_comment(
        credentials: LinearCredentials, issue_id: str, comment: str
        credentials: OAuth2Credentials | APIKeyCredentials, issue_id: str, comment: str
    ) -> tuple[str, str]:
        client = LinearClient(credentials=credentials)
        response: CreateCommentResponse = await client.try_create_comment(
@@ -64,7 +72,11 @@ class LinearCreateCommentBlock(Block):
        return response.comment.id, response.comment.body

    async def run(
        self, input_data: Input, *, credentials: LinearCredentials, **kwargs
        self,
        input_data: Input,
        *,
        credentials: OAuth2Credentials | APIKeyCredentials,
        **kwargs,
    ) -> BlockOutput:
        """Execute the comment creation"""
        try:

@@ -1,24 +1,32 @@
from backend.blocks.linear._api import LinearAPIException, LinearClient
from backend.blocks.linear._auth import (
from backend.sdk import (
    APIKeyCredentials,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchema,
    CredentialsMetaInput,
    OAuth2Credentials,
    SchemaField,
)

from ._api import LinearAPIException, LinearClient
from ._config import (
    LINEAR_OAUTH_IS_CONFIGURED,
    TEST_CREDENTIALS_INPUT_OAUTH,
    TEST_CREDENTIALS_OAUTH,
    LinearCredentials,
    LinearCredentialsField,
    LinearCredentialsInput,
    LinearScope,
    linear,
)
from backend.blocks.linear.models import CreateIssueResponse, Issue
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from .models import CreateIssueResponse, Issue


class LinearCreateIssueBlock(Block):
    """Block for creating issues on Linear"""

    class Input(BlockSchema):
        credentials: LinearCredentialsInput = LinearCredentialsField(
            scopes=[LinearScope.ISSUES_CREATE],
        credentials: CredentialsMetaInput = linear.credentials_field(
            description="Linear credentials with issue creation permissions",
            required_scopes={LinearScope.ISSUES_CREATE},
        )
        title: str = SchemaField(description="Title of the issue")
        description: str | None = SchemaField(description="Description of the issue")
@@ -68,7 +76,7 @@ class LinearCreateIssueBlock(Block):

    @staticmethod
    async def create_issue(
        credentials: LinearCredentials,
        credentials: OAuth2Credentials | APIKeyCredentials,
        team_name: str,
        title: str,
        description: str | None = None,
@@ -94,7 +102,11 @@ class LinearCreateIssueBlock(Block):
        return response.issue.identifier, response.issue.title

    async def run(
        self, input_data: Input, *, credentials: LinearCredentials, **kwargs
        self,
        input_data: Input,
        *,
        credentials: OAuth2Credentials,
        **kwargs,
    ) -> BlockOutput:
        """Execute the issue creation"""
        try:
@@ -121,8 +133,9 @@ class LinearSearchIssuesBlock(Block):

    class Input(BlockSchema):
        term: str = SchemaField(description="Term to search for issues")
        credentials: LinearCredentialsInput = LinearCredentialsField(
            scopes=[LinearScope.READ],
        credentials: CredentialsMetaInput = linear.credentials_field(
            description="Linear credentials with read permissions",
            required_scopes={LinearScope.READ},
        )

    class Output(BlockSchema):
@@ -169,7 +182,7 @@ class LinearSearchIssuesBlock(Block):

    @staticmethod
    async def search_issues(
        credentials: LinearCredentials,
        credentials: OAuth2Credentials | APIKeyCredentials,
        term: str,
    ) -> list[Issue]:
        client = LinearClient(credentials=credentials)
@@ -177,7 +190,11 @@ class LinearSearchIssuesBlock(Block):
        return response

    async def run(
        self, input_data: Input, *, credentials: LinearCredentials, **kwargs
        self,
        input_data: Input,
        *,
        credentials: OAuth2Credentials | APIKeyCredentials,
        **kwargs,
    ) -> BlockOutput:
        """Execute the issue search"""
        try:

@@ -1,4 +1,4 @@
from pydantic import BaseModel
from backend.sdk import BaseModel


class Comment(BaseModel):

@@ -1,24 +1,32 @@
from backend.blocks.linear._api import LinearAPIException, LinearClient
from backend.blocks.linear._auth import (
from backend.sdk import (
    APIKeyCredentials,
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchema,
    CredentialsMetaInput,
    OAuth2Credentials,
    SchemaField,
)

from ._api import LinearAPIException, LinearClient
from ._config import (
    LINEAR_OAUTH_IS_CONFIGURED,
    TEST_CREDENTIALS_INPUT_OAUTH,
    TEST_CREDENTIALS_OAUTH,
    LinearCredentials,
    LinearCredentialsField,
    LinearCredentialsInput,
    LinearScope,
    linear,
)
from backend.blocks.linear.models import Project
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from .models import Project


class LinearSearchProjectsBlock(Block):
    """Block for searching projects on Linear"""

    class Input(BlockSchema):
        credentials: LinearCredentialsInput = LinearCredentialsField(
            scopes=[LinearScope.READ],
        credentials: CredentialsMetaInput = linear.credentials_field(
            description="Linear credentials with read permissions",
            required_scopes={LinearScope.READ},
        )
        term: str = SchemaField(description="Term to search for projects")

@@ -70,7 +78,7 @@ class LinearSearchProjectsBlock(Block):

    @staticmethod
    async def search_projects(
        credentials: LinearCredentials,
        credentials: OAuth2Credentials | APIKeyCredentials,
        term: str,
    ) -> list[Project]:
        client = LinearClient(credentials=credentials)
@@ -78,7 +86,11 @@ class LinearSearchProjectsBlock(Block):
        return response

    async def run(
        self, input_data: Input, *, credentials: LinearCredentials, **kwargs
        self,
        input_data: Input,
        *,
        credentials: OAuth2Credentials | APIKeyCredentials,
        **kwargs,
    ) -> BlockOutput:
        """Execute the project search"""
        try:

@@ -23,6 +23,7 @@ from backend.data.model import (
from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.logging import TruncatedLogger
from backend.util.prompt import compress_prompt, estimate_token_count
from backend.util.text import TextFormatter

logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
@@ -40,7 +41,7 @@ LLMProviderName = Literal[
AICredentials = CredentialsMetaInput[LLMProviderName, Literal["api_key"]]

TEST_CREDENTIALS = APIKeyCredentials(
    id="ed55ac19-356e-4243-a6cb-bc599e9b716f",
    id="769f6af7-820b-4d5d-9b7a-ab82bbc165f",
    provider="openai",
    api_key=SecretStr("mock-openai-api-key"),
    title="Mock OpenAI API key",
@@ -126,6 +127,9 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
    PERPLEXITY_LLAMA_3_1_SONAR_LARGE_128K_ONLINE = (
        "perplexity/llama-3.1-sonar-large-128k-online"
    )
    PERPLEXITY_SONAR = "perplexity/sonar"
    PERPLEXITY_SONAR_PRO = "perplexity/sonar-pro"
    PERPLEXITY_SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
    QWEN_QWQ_32B_PREVIEW = "qwen/qwq-32b-preview"
    NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
    NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
@@ -228,6 +232,13 @@ MODEL_METADATA = {
    LlmModel.PERPLEXITY_LLAMA_3_1_SONAR_LARGE_128K_ONLINE: ModelMetadata(
        "open_router", 127072, 127072
    ),
    LlmModel.PERPLEXITY_SONAR: ModelMetadata("open_router", 127000, 127000),
    LlmModel.PERPLEXITY_SONAR_PRO: ModelMetadata("open_router", 200000, 8000),
    LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: ModelMetadata(
        "open_router",
        128000,
        128000,
    ),
    LlmModel.QWEN_QWQ_32B_PREVIEW: ModelMetadata("open_router", 32768, 32768),
    LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: ModelMetadata(
        "open_router", 131000, 4096
@@ -272,6 +283,7 @@ class LLMResponse(BaseModel):
    tool_calls: Optional[List[ToolContentBlock]] | None
    prompt_tokens: int
    completion_tokens: int
    reasoning: Optional[str] = None


def convert_openai_tool_fmt_to_anthropic(
@@ -306,11 +318,44 @@ def convert_openai_tool_fmt_to_anthropic(
    return anthropic_tools


def estimate_token_count(prompt_messages: list[dict]) -> int:
    char_count = sum(len(str(msg.get("content", ""))) for msg in prompt_messages)
    message_overhead = len(prompt_messages) * 4
    estimated_tokens = (char_count // 4) + message_overhead
    return int(estimated_tokens * 1.2)
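# Editor's worked example of the heuristic above (illustrative numbers): two
# messages totalling 800 characters give 800 // 4 = 200 base tokens plus
# 2 * 4 = 8 tokens of per-message overhead; the 1.2 safety factor then yields
# int((200 + 8) * 1.2) = 249 estimated tokens.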
def extract_openai_reasoning(response) -> str | None:
    """Extract reasoning from an OpenAI-compatible response, if available.

    Note: this will likely not work, since the reasoning field is not
    present in other response APIs.
    """
    reasoning = None
    choice = response.choices[0]
    if hasattr(choice, "reasoning") and getattr(choice, "reasoning", None):
        reasoning = str(getattr(choice, "reasoning"))
    elif hasattr(response, "reasoning") and getattr(response, "reasoning", None):
        reasoning = str(getattr(response, "reasoning"))
    elif hasattr(choice.message, "reasoning") and getattr(
        choice.message, "reasoning", None
    ):
        reasoning = str(getattr(choice.message, "reasoning"))
    return reasoning


def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
    """Extract tool calls from OpenAI-compatible response."""
    if response.choices[0].message.tool_calls:
        return [
            ToolContentBlock(
                id=tool.id,
                type=tool.type,
                function=ToolCall(
                    name=tool.function.name,
                    arguments=tool.function.arguments,
                ),
            )
            for tool in response.choices[0].message.tool_calls
        ]
    return None


def get_parallel_tool_calls_param(llm_model: LlmModel, parallel_tool_calls):
    """Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
    if llm_model.startswith("o") or parallel_tool_calls is None:
        return openai.NOT_GIVEN
    return parallel_tool_calls
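# Editor's note (illustrative): model names starting with "o" (e.g. "o1")
# do not accept parallel_tool_calls, so the helper returns openai.NOT_GIVEN
# for them and for an unset value; any other model gets the caller's boolean
# passed through unchanged.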


async def llm_call(
@@ -321,7 +366,8 @@
    max_tokens: int | None,
    tools: list[dict] | None = None,
    ollama_host: str = "localhost:11434",
    parallel_tool_calls: bool | None = None,
    parallel_tool_calls=None,
    compress_prompt_to_fit: bool = True,
) -> LLMResponse:
    """
    Make a call to a language model.

@@ -344,10 +390,17 @@
    - completion_tokens: The number of tokens used in the completion.
    """
    provider = llm_model.metadata.provider
    context_window = llm_model.context_window

    if compress_prompt_to_fit:
        prompt = compress_prompt(
            messages=prompt,
            target_tokens=llm_model.context_window // 2,
            lossy_ok=True,
        )

    # Calculate available tokens based on context window and input length
    estimated_input_tokens = estimate_token_count(prompt)
    context_window = llm_model.context_window
    model_max_output = llm_model.max_output_tokens or int(2**15)
    user_max = max_tokens or model_max_output
    available_tokens = max(context_window - estimated_input_tokens, 0)
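    # Editor's worked example (hypothetical numbers): with a 128_000-token
    # context window and an estimated 100_000-token input,
    # available_tokens = max(128_000 - 100_000, 0) = 28_000 tokens remain for
    # the completion, alongside the model_max_output and user_max caps
    # computed above.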
@@ -358,14 +411,11 @@ async def llm_call(
|
||||
oai_client = openai.AsyncOpenAI(api_key=credentials.api_key.get_secret_value())
|
||||
response_format = None
|
||||
|
||||
if llm_model in [LlmModel.O1_MINI, LlmModel.O1_PREVIEW]:
|
||||
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
||||
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
|
||||
prompt = [
|
||||
{"role": "user", "content": "\n".join(sys_messages)},
|
||||
{"role": "user", "content": "\n".join(usr_messages)},
|
||||
]
|
||||
elif json_format:
|
||||
parallel_tool_calls = get_parallel_tool_calls_param(
|
||||
llm_model, parallel_tool_calls
|
||||
)
|
||||
|
||||
if json_format:
|
||||
response_format = {"type": "json_object"}
|
||||
|
||||
response = await oai_client.chat.completions.create(
|
||||
@@ -374,25 +424,11 @@ async def llm_call(
|
||||
response_format=response_format, # type: ignore
|
||||
max_completion_tokens=max_tokens,
|
||||
tools=tools_param, # type: ignore
|
||||
parallel_tool_calls=(
|
||||
openai.NOT_GIVEN if parallel_tool_calls is None else parallel_tool_calls
|
||||
),
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
)
|
||||
|
||||
if response.choices[0].message.tool_calls:
|
||||
tool_calls = [
|
||||
ToolContentBlock(
|
||||
id=tool.id,
|
||||
type=tool.type,
|
||||
function=ToolCall(
|
||||
name=tool.function.name,
|
||||
arguments=tool.function.arguments,
|
||||
),
|
||||
)
|
||||
for tool in response.choices[0].message.tool_calls
|
||||
]
|
||||
else:
|
||||
tool_calls = None
|
||||
tool_calls = extract_openai_tool_calls(response)
|
||||
reasoning = extract_openai_reasoning(response)
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=response.choices[0].message,
|
||||
@@ -401,6 +437,7 @@ async def llm_call(
|
||||
tool_calls=tool_calls,
|
||||
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
|
||||
completion_tokens=response.usage.completion_tokens if response.usage else 0,
|
||||
reasoning=reasoning,
|
||||
)
|
||||
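The token budgeting in the hunk above reduces to a few lines of arithmetic. A sketch with illustrative numbers; note the final clamp of the completion budget is not shown in this hunk, so the min() below is an assumption about how the computed values are combined:

context_window = 128_000          # llm_model.context_window
estimated_input_tokens = 100_000  # estimate_token_count(prompt)
model_max_output = 2**15          # fallback when max_output_tokens is unset
user_max = model_max_output       # max_tokens or model_max_output

available_tokens = max(context_window - estimated_input_tokens, 0)
# Presumably the effective completion budget cannot exceed what the
# window leaves over after the input.
effective_max = min(user_max, available_tokens)
assert effective_max == 28_000
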
elif provider == "anthropic":
|
||||
|
||||
@@ -462,6 +499,12 @@ async def llm_call(
|
||||
f"Tool use stop reason but no tool calls found in content. {resp}"
|
||||
)
|
||||
|
||||
reasoning = None
|
||||
for content_block in resp.content:
|
||||
if hasattr(content_block, "type") and content_block.type == "thinking":
|
||||
reasoning = content_block.thinking
|
||||
break
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=resp,
|
||||
prompt=prompt,
|
||||
@@ -473,6 +516,7 @@ async def llm_call(
|
||||
tool_calls=tool_calls,
|
||||
prompt_tokens=resp.usage.input_tokens,
|
||||
completion_tokens=resp.usage.output_tokens,
|
||||
reasoning=reasoning,
|
||||
)
|
||||
except anthropic.APIError as e:
|
||||
error_message = f"Anthropic API error: {str(e)}"
|
||||
@@ -497,6 +541,7 @@ async def llm_call(
|
||||
tool_calls=None,
|
||||
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
|
||||
completion_tokens=response.usage.completion_tokens if response.usage else 0,
|
||||
reasoning=None,
|
||||
)
|
||||
elif provider == "ollama":
|
||||
if tools:
|
||||
@@ -518,6 +563,7 @@ async def llm_call(
|
||||
tool_calls=None,
|
||||
prompt_tokens=response.get("prompt_eval_count") or 0,
|
||||
completion_tokens=response.get("eval_count") or 0,
|
||||
reasoning=None,
|
||||
)
|
||||
elif provider == "open_router":
|
||||
tools_param = tools if tools else openai.NOT_GIVEN
|
||||
@@ -526,6 +572,10 @@ async def llm_call(
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
parallel_tool_calls_param = get_parallel_tool_calls_param(
|
||||
llm_model, parallel_tool_calls
|
||||
)
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
extra_headers={
|
||||
"HTTP-Referer": "https://agpt.co",
|
||||
@@ -535,6 +585,7 @@ async def llm_call(
|
||||
messages=prompt, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
tools=tools_param, # type: ignore
|
||||
parallel_tool_calls=parallel_tool_calls_param,
|
||||
)
|
||||
|
||||
# If there's no response, raise an error
|
||||
@@ -544,19 +595,8 @@ async def llm_call(
|
||||
else:
|
||||
raise ValueError("No response from OpenRouter.")
|
||||
|
||||
if response.choices[0].message.tool_calls:
|
||||
tool_calls = [
|
||||
ToolContentBlock(
|
||||
id=tool.id,
|
||||
type=tool.type,
|
||||
function=ToolCall(
|
||||
name=tool.function.name, arguments=tool.function.arguments
|
||||
),
|
||||
)
|
||||
for tool in response.choices[0].message.tool_calls
|
||||
]
|
||||
else:
|
||||
tool_calls = None
|
||||
tool_calls = extract_openai_tool_calls(response)
|
||||
reasoning = extract_openai_reasoning(response)
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=response.choices[0].message,
|
||||
@@ -565,6 +605,7 @@ async def llm_call(
|
||||
tool_calls=tool_calls,
|
||||
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
|
||||
completion_tokens=response.usage.completion_tokens if response.usage else 0,
|
||||
reasoning=reasoning,
|
||||
)
|
||||
elif provider == "llama_api":
|
||||
tools_param = tools if tools else openai.NOT_GIVEN
|
||||
@@ -573,6 +614,10 @@ async def llm_call(
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
parallel_tool_calls_param = get_parallel_tool_calls_param(
|
||||
llm_model, parallel_tool_calls
|
||||
)
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
extra_headers={
|
||||
"HTTP-Referer": "https://agpt.co",
|
||||
@@ -582,9 +627,7 @@ async def llm_call(
|
||||
messages=prompt, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
tools=tools_param, # type: ignore
|
||||
parallel_tool_calls=(
|
||||
openai.NOT_GIVEN if parallel_tool_calls is None else parallel_tool_calls
|
||||
),
|
||||
parallel_tool_calls=parallel_tool_calls_param,
|
||||
)
|
||||
|
||||
# If there's no response, raise an error
|
||||
@@ -594,19 +637,8 @@ async def llm_call(
|
||||
else:
|
||||
raise ValueError("No response from Llama API.")
|
||||
|
||||
if response.choices[0].message.tool_calls:
|
||||
tool_calls = [
|
||||
ToolContentBlock(
|
||||
id=tool.id,
|
||||
type=tool.type,
|
||||
function=ToolCall(
|
||||
name=tool.function.name, arguments=tool.function.arguments
|
||||
),
|
||||
)
|
||||
for tool in response.choices[0].message.tool_calls
|
||||
]
|
||||
else:
|
||||
tool_calls = None
|
||||
tool_calls = extract_openai_tool_calls(response)
|
||||
reasoning = extract_openai_reasoning(response)
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=response.choices[0].message,
|
||||
@@ -615,6 +647,7 @@ async def llm_call(
|
||||
tool_calls=tool_calls,
|
||||
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
|
||||
completion_tokens=response.usage.completion_tokens if response.usage else 0,
|
||||
reasoning=reasoning,
|
||||
)
|
||||
elif provider == "aiml_api":
|
||||
client = openai.OpenAI(
|
||||
@@ -638,6 +671,7 @@ async def llm_call(
|
||||
completion_tokens=(
|
||||
completion.usage.completion_tokens if completion.usage else 0
|
||||
),
|
||||
reasoning=None,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
@@ -699,7 +733,11 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
            default=None,
            description="The maximum number of tokens to generate in the chat completion.",
        )

        compress_prompt_to_fit: bool = SchemaField(
            advanced=True,
            default=True,
            description="Whether to compress the prompt to fit within the model's context window.",
        )
        ollama_host: str = SchemaField(
            advanced=True,
            default="localhost:11434",
@@ -747,6 +785,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
                    tool_calls=None,
                    prompt_tokens=0,
                    completion_tokens=0,
                    reasoning=None,
                )
            },
        )
@@ -757,6 +796,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
        llm_model: LlmModel,
        prompt: list[dict],
        json_format: bool,
        compress_prompt_to_fit: bool,
        max_tokens: int | None,
        tools: list[dict] | None = None,
        ollama_host: str = "localhost:11434",
@@ -774,6 +814,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
            max_tokens=max_tokens,
            tools=tools,
            ollama_host=ollama_host,
            compress_prompt_to_fit=compress_prompt_to_fit,
        )

    async def run(
@@ -832,7 +873,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
            except JSONDecodeError as e:
                return f"JSON decode error: {e}"

        logger.info(f"LLM request: {prompt}")
        logger.debug(f"LLM request: {prompt}")
        retry_prompt = ""
        llm_model = input_data.model

@@ -842,6 +883,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
                    credentials=credentials,
                    llm_model=llm_model,
                    prompt=prompt,
                    compress_prompt_to_fit=input_data.compress_prompt_to_fit,
                    json_format=bool(input_data.expected_format),
                    ollama_host=input_data.ollama_host,
                    max_tokens=input_data.max_tokens,
@@ -853,7 +895,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
                        output_token_count=llm_response.completion_tokens,
                    )
                )
                logger.info(f"LLM attempt-{retry_count} response: {response_text}")
                logger.debug(f"LLM attempt-{retry_count} response: {response_text}")

                if input_data.expected_format:

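As a rough illustration of the compression target (half the context window) that the new compress_prompt_to_fit field feeds into llm_call, here is a hypothetical drop-oldest strategy standing in for the real compress_prompt, which lives elsewhere in the backend and may summarize rather than drop messages:

# Hypothetical stand-in; signature and strategy are assumptions.
def compress_prompt_sketch(messages, target_tokens, count_tokens):
    while messages and count_tokens(messages) > target_tokens:
        messages.pop(0)  # drop oldest first (illustrative, lossy)
    return messages

msgs = [{"role": "user", "content": "x" * 100} for _ in range(10)]
count = lambda ms: sum(len(m["content"]) // 4 for m in ms)  # ~4 chars/token
context_window = 200
kept = compress_prompt_sketch(msgs, target_tokens=context_window // 2, count_tokens=count)
assert count(kept) <= context_window // 2
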
@@ -13,7 +13,7 @@ from backend.data.model import (
from backend.integrations.providers import ProviderName

TEST_CREDENTIALS = APIKeyCredentials(
    id="ed55ac19-356e-4243-a6cb-bc599e9b716f",
    id="8cc8b2c5-d3e4-4b1c-84ad-e1e9fe2a0122",
    provider="mem0",
    api_key=SecretStr("mock-mem0-api-key"),
    title="Mock Mem0 API key",
@@ -67,17 +67,19 @@ class AddMemoryBlock(Block, Mem0Base):
        metadata: dict[str, Any] = SchemaField(
            description="Optional metadata for the memory", default_factory=dict
        )

        limit_memory_to_run: bool = SchemaField(
            description="Limit the memory to the run", default=False
        )
        limit_memory_to_agent: bool = SchemaField(
            description="Limit the memory to the agent", default=False
            description="Limit the memory to the agent", default=True
        )

    class Output(BlockSchema):
        action: str = SchemaField(description="Action of the operation")
        memory: str = SchemaField(description="Memory created")
        results: list[dict[str, str]] = SchemaField(
            description="List of all results from the operation"
        )
        error: str = SchemaField(description="Error message if operation fails")

    def __init__(self):
@@ -104,7 +106,14 @@ class AddMemoryBlock(Block, Mem0Base):
                    "credentials": TEST_CREDENTIALS_INPUT,
                },
            ],
            test_output=[("action", "NO_CHANGE"), ("action", "NO_CHANGE")],
            test_output=[
                ("results", [{"event": "CREATED", "memory": "test memory"}]),
                ("action", "CREATED"),
                ("memory", "test memory"),
                ("results", [{"event": "CREATED", "memory": "test memory"}]),
                ("action", "CREATED"),
                ("memory", "test memory"),
            ],
            test_credentials=TEST_CREDENTIALS,
            test_mock={"_get_client": lambda credentials: MockMemoryClient()},
        )
@@ -117,7 +126,7 @@ class AddMemoryBlock(Block, Mem0Base):
        user_id: str,
        graph_id: str,
        graph_exec_id: str,
        **kwargs
        **kwargs,
    ) -> BlockOutput:
        try:
            client = self._get_client(credentials)
@@ -146,8 +155,11 @@ class AddMemoryBlock(Block, Mem0Base):
                **params,
            )

            if len(result.get("results", [])) > 0:
                for result in result.get("results", []):
            results = result.get("results", [])
            yield "results", results

            if len(results) > 0:
                for result in results:
                    yield "action", result["event"]
                    yield "memory", result["memory"]
            else:
@@ -178,6 +190,10 @@ class SearchMemoryBlock(Block, Mem0Base):
            default_factory=list,
            advanced=True,
        )
        metadata_filter: Optional[dict[str, Any]] = SchemaField(
            description="Optional metadata filters to apply",
            default=None,
        )
        limit_memory_to_run: bool = SchemaField(
            description="Limit the memory to the run", default=False
        )
@@ -216,7 +232,7 @@ class SearchMemoryBlock(Block, Mem0Base):
        user_id: str,
        graph_id: str,
        graph_exec_id: str,
        **kwargs
        **kwargs,
    ) -> BlockOutput:
        try:
            client = self._get_client(credentials)
@@ -235,6 +251,8 @@ class SearchMemoryBlock(Block, Mem0Base):
                filters["AND"].append({"run_id": graph_exec_id})
            if input_data.limit_memory_to_agent:
                filters["AND"].append({"agent_id": graph_id})
            if input_data.metadata_filter:
                filters["AND"].append({"metadata": input_data.metadata_filter})

            result: list[dict[str, Any]] = client.search(
                input_data.query, version="v2", filters=filters
@@ -260,11 +278,15 @@ class GetAllMemoriesBlock(Block, Mem0Base):
        categories: Optional[list[str]] = SchemaField(
            description="Filter by categories", default=None
        )
        metadata_filter: Optional[dict[str, Any]] = SchemaField(
            description="Optional metadata filters to apply",
            default=None,
        )
        limit_memory_to_run: bool = SchemaField(
            description="Limit the memory to the run", default=False
        )
        limit_memory_to_agent: bool = SchemaField(
            description="Limit the memory to the agent", default=False
            description="Limit the memory to the agent", default=True
        )

    class Output(BlockSchema):
@@ -274,11 +296,11 @@ class GetAllMemoriesBlock(Block, Mem0Base):
    def __init__(self):
        super().__init__(
            id="45aee5bf-4767-45d1-a28b-e01c5aae9fc1",
            description="Retrieve all memories from Mem0 with pagination",
            description="Retrieve all memories from Mem0 with optional conversation filtering",
            input_schema=GetAllMemoriesBlock.Input,
            output_schema=GetAllMemoriesBlock.Output,
            test_input={
                "user_id": "test_user",
                "metadata_filter": {"type": "test"},
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_output=[
@@ -296,7 +318,7 @@ class GetAllMemoriesBlock(Block, Mem0Base):
        user_id: str,
        graph_id: str,
        graph_exec_id: str,
        **kwargs
        **kwargs,
    ) -> BlockOutput:
        try:
            client = self._get_client(credentials)
@@ -314,6 +336,8 @@ class GetAllMemoriesBlock(Block, Mem0Base):
                filters["AND"].append(
                    {"categories": {"contains": input_data.categories}}
                )
            if input_data.metadata_filter:
                filters["AND"].append({"metadata": input_data.metadata_filter})

            memories: list[dict[str, Any]] = client.get_all(
                filters=filters,
@@ -326,14 +350,116 @@ class GetAllMemoriesBlock(Block, Mem0Base):
            yield "error", str(e)
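All of the Mem0 blocks above build the same v2 filter shape: an "AND" list seeded with the user ID and extended per option. A sketch of the resulting payload, with illustrative IDs in place of the real execution context:

from typing import Any

user_id, graph_id, graph_exec_id = "user-1", "graph-1", "exec-1"
metadata_filter = {"type": "test"}

filters: dict[str, list[dict[str, Any]]] = {"AND": [{"user_id": user_id}]}
filters["AND"].append({"run_id": graph_exec_id})   # limit_memory_to_run
filters["AND"].append({"agent_id": graph_id})      # limit_memory_to_agent (now default True)
filters["AND"].append({"metadata": metadata_filter})

# client.search(query, version="v2", filters=filters) then receives:
assert filters == {"AND": [
    {"user_id": "user-1"}, {"run_id": "exec-1"},
    {"agent_id": "graph-1"}, {"metadata": {"type": "test"}},
]}
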
class GetLatestMemoryBlock(Block, Mem0Base):
    """Block for retrieving the latest memory from Mem0"""

    class Input(BlockSchema):
        credentials: CredentialsMetaInput[
            Literal[ProviderName.MEM0], Literal["api_key"]
        ] = CredentialsField(description="Mem0 API key credentials")
        trigger: bool = SchemaField(
            description="An unused field that is used to trigger the block when you have no other inputs",
            default=False,
            advanced=False,
        )
        categories: Optional[list[str]] = SchemaField(
            description="Filter by categories", default=None
        )
        conversation_id: Optional[str] = SchemaField(
            description="Optional conversation ID to retrieve the latest memory from (uses run_id)",
            default=None,
        )
        metadata_filter: Optional[dict[str, Any]] = SchemaField(
            description="Optional metadata filters to apply",
            default=None,
        )
        limit_memory_to_run: bool = SchemaField(
            description="Limit the memory to the run", default=False
        )
        limit_memory_to_agent: bool = SchemaField(
            description="Limit the memory to the agent", default=True
        )

    class Output(BlockSchema):
        memory: Optional[dict[str, Any]] = SchemaField(
            description="Latest memory if found"
        )
        found: bool = SchemaField(description="Whether a memory was found")
        error: str = SchemaField(description="Error message if operation fails")

    def __init__(self):
        super().__init__(
            id="0f9d81b5-a145-4c23-b87f-01d6bf37b677",
            description="Retrieve the latest memory from Mem0 with optional key filtering",
            input_schema=GetLatestMemoryBlock.Input,
            output_schema=GetLatestMemoryBlock.Output,
            test_input={
                "metadata_filter": {"type": "test"},
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_output=[
                ("memory", {"id": "test-memory", "content": "test content"}),
                ("found", True),
            ],
            test_credentials=TEST_CREDENTIALS,
            test_mock={"_get_client": lambda credentials: MockMemoryClient()},
        )

    async def run(
        self,
        input_data: Input,
        *,
        credentials: APIKeyCredentials,
        user_id: str,
        graph_id: str,
        graph_exec_id: str,
        **kwargs,
    ) -> BlockOutput:
        try:
            client = self._get_client(credentials)

            filters: Filter = {
                "AND": [
                    {"user_id": user_id},
                ]
            }
            if input_data.limit_memory_to_run:
                filters["AND"].append({"run_id": graph_exec_id})
            if input_data.limit_memory_to_agent:
                filters["AND"].append({"agent_id": graph_id})
            if input_data.categories:
                filters["AND"].append(
                    {"categories": {"contains": input_data.categories}}
                )
            if input_data.metadata_filter:
                filters["AND"].append({"metadata": input_data.metadata_filter})

            memories: list[dict[str, Any]] = client.get_all(
                filters=filters,
                version="v2",
            )

            if memories:
                # Return the latest memory (first in the list as they're sorted by recency)
                latest_memory = memories[0]
                yield "memory", latest_memory
                yield "found", True
            else:
                yield "memory", None
                yield "found", False

        except Exception as e:
            yield "error", str(e)


# Mock client for testing
class MockMemoryClient:
    """Mock Mem0 client for testing"""

    def add(self, *args, **kwargs):
        return {"memory_id": "test-memory-id", "status": "success"}
        return {"results": [{"event": "CREATED", "memory": "test memory"}]}

    def search(self, *args, **kwargs) -> list[dict[str, str]]:
    def search(self, *args, **kwargs) -> list[dict[str, Any]]:
        return [{"id": "test-memory", "content": "test content"}]

    def get_all(self, *args, **kwargs) -> list[dict[str, str]]:
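The updated mock's add() now mirrors the exact shape AddMemoryBlock consumes. A quick self-contained check of that contract (a restatement for illustration, not an import of the real class):

class MockMemoryClientSketch:
    def add(self, *args, **kwargs):
        return {"results": [{"event": "CREATED", "memory": "test memory"}]}

result = MockMemoryClientSketch().add("test memory")
results = result.get("results", [])
# AddMemoryBlock yields "results" first, then one action/memory pair per entry.
assert [(r["event"], r["memory"]) for r in results] == [("CREATED", "test memory")]
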
autogpt_platform/backend/backend/blocks/persistence.py (Normal file, 155 lines)
@@ -0,0 +1,155 @@
import logging
from typing import Any, Literal

from autogpt_libs.utils.cache import thread_cached

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField

logger = logging.getLogger(__name__)


@thread_cached
def get_database_manager_client():
    from backend.executor import DatabaseManagerAsyncClient
    from backend.util.service import get_service_client

    return get_service_client(DatabaseManagerAsyncClient, health_check=False)


StorageScope = Literal["within_agent", "across_agents"]


def get_storage_key(key: str, scope: StorageScope, graph_id: str) -> str:
    """Generate the storage key based on scope"""
    if scope == "across_agents":
        return f"global#{key}"
    else:
        return f"agent#{graph_id}#{key}"


class PersistInformationBlock(Block):
    """Block for persisting key-value data for the current user with configurable scope"""

    class Input(BlockSchema):
        key: str = SchemaField(description="Key to store the information under")
        value: Any = SchemaField(description="Value to store")
        scope: StorageScope = SchemaField(
            description="Scope of persistence: within_agent (shared across all runs of this agent) or across_agents (shared across all agents for this user)",
            default="within_agent",
        )

    class Output(BlockSchema):
        value: Any = SchemaField(description="Value that was stored")

    def __init__(self):
        super().__init__(
            id="1d055e55-a2b9-4547-8311-907d05b0304d",
            description="Persist key-value information for the current user",
            categories={BlockCategory.DATA},
            input_schema=PersistInformationBlock.Input,
            output_schema=PersistInformationBlock.Output,
            test_input={
                "key": "user_preference",
                "value": {"theme": "dark", "language": "en"},
                "scope": "within_agent",
            },
            test_output=[
                ("value", {"theme": "dark", "language": "en"}),
            ],
            test_mock={
                "_store_data": lambda *args, **kwargs: {
                    "theme": "dark",
                    "language": "en",
                }
            },
        )

    async def run(
        self,
        input_data: Input,
        *,
        user_id: str,
        graph_id: str,
        node_exec_id: str,
        **kwargs,
    ) -> BlockOutput:
        # Determine the storage key based on scope
        storage_key = get_storage_key(input_data.key, input_data.scope, graph_id)

        # Store the data
        yield "value", await self._store_data(
            user_id=user_id,
            node_exec_id=node_exec_id,
            key=storage_key,
            data=input_data.value,
        )

    async def _store_data(
        self, user_id: str, node_exec_id: str, key: str, data: Any
    ) -> Any | None:
        return await get_database_manager_client().set_execution_kv_data(
            user_id=user_id,
            node_exec_id=node_exec_id,
            key=key,
            data=data,
        )


class RetrieveInformationBlock(Block):
    """Block for retrieving key-value data for the current user with configurable scope"""

    class Input(BlockSchema):
        key: str = SchemaField(description="Key to retrieve the information for")
        scope: StorageScope = SchemaField(
            description="Scope of persistence: within_agent (shared across all runs of this agent) or across_agents (shared across all agents for this user)",
            default="within_agent",
        )
        default_value: Any = SchemaField(
            description="Default value to return if key is not found", default=None
        )

    class Output(BlockSchema):
        value: Any = SchemaField(description="Retrieved value or default value")

    def __init__(self):
        super().__init__(
            id="d8710fc9-6e29-481e-a7d5-165eb16f8471",
            description="Retrieve key-value information for the current user",
            categories={BlockCategory.DATA},
            input_schema=RetrieveInformationBlock.Input,
            output_schema=RetrieveInformationBlock.Output,
            test_input={
                "key": "user_preference",
                "scope": "within_agent",
                "default_value": {"theme": "light", "language": "en"},
            },
            test_output=[
                ("value", {"theme": "light", "language": "en"}),
            ],
            test_mock={"_retrieve_data": lambda *args, **kwargs: None},
            static_output=True,
        )

    async def run(
        self, input_data: Input, *, user_id: str, graph_id: str, **kwargs
    ) -> BlockOutput:
        # Determine the storage key based on scope
        storage_key = get_storage_key(input_data.key, input_data.scope, graph_id)

        # Retrieve the data
        stored_value = await self._retrieve_data(
            user_id=user_id,
            key=storage_key,
        )

        if stored_value is not None:
            yield "value", stored_value
        else:
            yield "value", input_data.default_value

    async def _retrieve_data(self, user_id: str, key: str) -> Any | None:
        return await get_database_manager_client().get_execution_kv_data(
            user_id=user_id,
            key=key,
        )
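Since this file is new, a short usage example makes the key scheme concrete; it simply restates the two branches of get_storage_key with illustrative values:

def get_storage_key(key: str, scope: str, graph_id: str) -> str:
    if scope == "across_agents":
        return f"global#{key}"
    return f"agent#{graph_id}#{key}"

# Agent-scoped keys are namespaced by graph ID; global keys are shared per user.
assert get_storage_key("user_preference", "within_agent", "g-123") == "agent#g-123#user_preference"
assert get_storage_key("user_preference", "across_agents", "g-123") == "global#user_preference"
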
@@ -96,6 +96,7 @@ class GetRedditPostsBlock(Block):

    class Output(BlockSchema):
        post: RedditPost = SchemaField(description="Reddit post")
        posts: list[RedditPost] = SchemaField(description="List of all Reddit posts")

    def __init__(self):
        super().__init__(
@@ -128,6 +129,23 @@ class GetRedditPostsBlock(Block):
                        id="id2", subreddit="subreddit", title="title2", body="body2"
                    ),
                ),
                (
                    "posts",
                    [
                        RedditPost(
                            id="id1",
                            subreddit="subreddit",
                            title="title1",
                            body="body1",
                        ),
                        RedditPost(
                            id="id2",
                            subreddit="subreddit",
                            title="title2",
                            body="body2",
                        ),
                    ],
                ),
            ],
            test_mock={
                "get_posts": lambda input_data, credentials: [
@@ -150,6 +168,7 @@ class GetRedditPostsBlock(Block):
        self, input_data: Input, *, credentials: RedditCredentials, **kwargs
    ) -> BlockOutput:
        current_time = datetime.now(tz=timezone.utc)
        all_posts = []
        for post in self.get_posts(input_data=input_data, credentials=credentials):
            if input_data.last_minutes:
                post_datetime = datetime.fromtimestamp(
@@ -162,12 +181,16 @@ class GetRedditPostsBlock(Block):
            if input_data.last_post and post.id == input_data.last_post:
                break

            yield "post", RedditPost(
            reddit_post = RedditPost(
                id=post.id,
                subreddit=input_data.subreddit,
                title=post.title,
                body=post.selftext,
            )
            all_posts.append(reddit_post)
            yield "post", reddit_post

        yield "posts", all_posts


class PostRedditCommentBlock(Block):
@@ -40,6 +40,7 @@ class ReadRSSFeedBlock(Block):

    class Output(BlockSchema):
        entry: RSSEntry = SchemaField(description="The RSS item")
        entries: list[RSSEntry] = SchemaField(description="List of all RSS entries")

    def __init__(self):
        super().__init__(
@@ -66,6 +67,21 @@ class ReadRSSFeedBlock(Block):
                        categories=["Technology", "News"],
                    ),
                ),
                (
                    "entries",
                    [
                        RSSEntry(
                            title="Example RSS Item",
                            link="https://example.com/article",
                            description="This is an example RSS item description.",
                            pub_date=datetime(
                                2023, 6, 23, 12, 30, 0, tzinfo=timezone.utc
                            ),
                            author="John Doe",
                            categories=["Technology", "News"],
                        ),
                    ],
                ),
            ],
            test_mock={
                "parse_feed": lambda *args, **kwargs: {
@@ -96,21 +112,22 @@ class ReadRSSFeedBlock(Block):
        keep_going = input_data.run_continuously

            feed = self.parse_feed(input_data.rss_url)
            all_entries = []

            for entry in feed["entries"]:
                pub_date = datetime(*entry["published_parsed"][:6], tzinfo=timezone.utc)

                if pub_date > start_time:
                    yield (
                        "entry",
                        RSSEntry(
                            title=entry["title"],
                            link=entry["link"],
                            description=entry.get("summary", ""),
                            pub_date=pub_date,
                            author=entry.get("author", ""),
                            categories=[tag["term"] for tag in entry.get("tags", [])],
                        ),
                    rss_entry = RSSEntry(
                        title=entry["title"],
                        link=entry["link"],
                        description=entry.get("summary", ""),
                        pub_date=pub_date,
                        author=entry.get("author", ""),
                        categories=[tag["term"] for tag in entry.get("tags", [])],
                    )
                    all_entries.append(rss_entry)
                    yield "entry", rss_entry

            yield "entries", all_entries
            await asyncio.sleep(input_data.polling_rate)
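The Reddit and RSS changes above follow the same pattern: collect each yielded item and emit the aggregate once at the end. A generic sketch of that shape (names are illustrative; the real blocks yield "post"/"posts" and "entry"/"entries"):

from typing import Any, Iterable, Iterator

def emit_items_and_aggregate(items: Iterable[Any]) -> Iterator[tuple[str, Any]]:
    collected = []
    for item in items:
        collected.append(item)
        yield "item", item       # per-item output, emitted as each arrives
    yield "items", collected     # aggregate output, emitted once at the end

outputs = list(emit_items_and_aggregate([1, 2]))
assert outputs == [("item", 1), ("item", 2), ("items", [1, 2])]
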
@@ -26,10 +26,10 @@ logger = logging.getLogger(__name__)

@thread_cached
def get_database_manager_client():
    from backend.executor import DatabaseManagerClient
    from backend.executor import DatabaseManagerAsyncClient
    from backend.util.service import get_service_client

    return get_service_client(DatabaseManagerClient)
    return get_service_client(DatabaseManagerAsyncClient, health_check=False)


def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
@@ -85,7 +85,7 @@ def _get_tool_responses(entry: dict[str, Any]) -> list[str]:
    return tool_call_ids


def _create_tool_response(call_id: str, output: dict[str, Any]) -> dict[str, Any]:
def _create_tool_response(call_id: str, output: Any) -> dict[str, Any]:
    """
    Create a tool response message for either OpenAI or Anthropic,
    based on the tool_id format.
@@ -212,6 +212,15 @@ class SmartDecisionMakerBlock(Block):
                    "link like the output of `StoreValue` or `AgentInput` block"
                )

        # Check that both conversation_history and last_tool_output are connected together
        if any(link.sink_name == "conversation_history" for link in links) != any(
            link.sink_name == "last_tool_output" for link in links
        ):
            raise ValueError(
                "Last Tool Output is needed when Conversation History is used, "
                "and vice versa. Please connect both inputs together."
            )

        return missing_links

    @classmethod
@@ -222,8 +231,15 @@ class SmartDecisionMakerBlock(Block):
        conversation_history = data.get("conversation_history", [])
        pending_tool_calls = get_pending_tool_calls(conversation_history)
        last_tool_output = data.get("last_tool_output")
        if not last_tool_output and pending_tool_calls:

        # A tool call is pending: wait for the tool output to be provided.
        if last_tool_output is None and pending_tool_calls:
            return {"last_tool_output"}

        # No tool call is pending: wait for the conversation history to be updated.
        if last_tool_output is not None and not pending_tool_calls:
            return {"conversation_history"}

        return set()
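A compact restatement of the new wait conditions, assuming pending_tool_calls is a mapping of call ID to count as the surrounding code suggests (function name here is illustrative):

def missing_inputs(last_tool_output, pending_tool_calls: dict[str, int]) -> set[str]:
    # A tool call is pending: block until its output arrives.
    if last_tool_output is None and pending_tool_calls:
        return {"last_tool_output"}
    # Output arrived but nothing is pending: wait for fresh history instead.
    if last_tool_output is not None and not pending_tool_calls:
        return {"conversation_history"}
    return set()

assert missing_inputs(None, {"call_1": 1}) == {"last_tool_output"}
assert missing_inputs({"ok": True}, {}) == {"conversation_history"}
assert missing_inputs({"ok": True}, {"call_1": 1}) == set()
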
    class Output(BlockSchema):
@@ -257,7 +273,7 @@ class SmartDecisionMakerBlock(Block):
        return re.sub(r"[^a-zA-Z0-9_-]", "_", s).lower()

    @staticmethod
    def _create_block_function_signature(
    async def _create_block_function_signature(
        sink_node: "Node", links: list["Link"]
    ) -> dict[str, Any]:
        """
@@ -296,7 +312,7 @@ class SmartDecisionMakerBlock(Block):
        return {"type": "function", "function": tool_function}

    @staticmethod
    def _create_agent_function_signature(
    async def _create_agent_function_signature(
        sink_node: "Node", links: list["Link"]
    ) -> dict[str, Any]:
        """
@@ -318,7 +334,7 @@ class SmartDecisionMakerBlock(Block):
            raise ValueError("Graph ID or Graph Version not found in sink node.")

        db_client = get_database_manager_client()
        sink_graph_meta = db_client.get_graph_metadata(graph_id, graph_version)
        sink_graph_meta = await db_client.get_graph_metadata(graph_id, graph_version)
        if not sink_graph_meta:
            raise ValueError(
                f"Sink graph metadata not found: {graph_id} {graph_version}"
@@ -358,7 +374,7 @@ class SmartDecisionMakerBlock(Block):
        return {"type": "function", "function": tool_function}

    @staticmethod
    def _create_function_signature(node_id: str) -> list[dict[str, Any]]:
    async def _create_function_signature(node_id: str) -> list[dict[str, Any]]:
        """
        Creates function signatures for tools linked to a specified node within a graph.

@@ -380,13 +396,13 @@ class SmartDecisionMakerBlock(Block):
        db_client = get_database_manager_client()
        tools = [
            (link, node)
            for link, node in db_client.get_connected_output_nodes(node_id)
            for link, node in await db_client.get_connected_output_nodes(node_id)
            if link.source_name.startswith("tools_^_") and link.source_id == node_id
        ]
        if not tools:
            raise ValueError("There is no next node to execute.")

        return_tool_functions = []
        return_tool_functions: list[dict[str, Any]] = []

        grouped_tool_links: dict[str, tuple["Node", list["Link"]]] = {}
        for link, node in tools:
@@ -401,13 +417,13 @@ class SmartDecisionMakerBlock(Block):

            if sink_node.block_id == AgentExecutorBlock().id:
                return_tool_functions.append(
                    SmartDecisionMakerBlock._create_agent_function_signature(
                    await SmartDecisionMakerBlock._create_agent_function_signature(
                        sink_node, links
                    )
                )
            else:
                return_tool_functions.append(
                    SmartDecisionMakerBlock._create_block_function_signature(
                    await SmartDecisionMakerBlock._create_block_function_signature(
                        sink_node, links
                    )
                )
@@ -426,38 +442,43 @@ class SmartDecisionMakerBlock(Block):
        user_id: str,
        **kwargs,
    ) -> BlockOutput:
        tool_functions = self._create_function_signature(node_id)
        tool_functions = await self._create_function_signature(node_id)
        yield "tool_functions", json.dumps(tool_functions)

        input_data.conversation_history = input_data.conversation_history or []
        prompt = [json.to_dict(p) for p in input_data.conversation_history if p]

        pending_tool_calls = get_pending_tool_calls(input_data.conversation_history)
        if pending_tool_calls and not input_data.last_tool_output:
        if pending_tool_calls and input_data.last_tool_output is None:
            raise ValueError(f"Tool call requires an output for {pending_tool_calls}")

        # Prefill all missing tool calls with the last tool output.
        # TODO: we need a better way to handle this.
        tool_output = [
            _create_tool_response(pending_call_id, input_data.last_tool_output)
            for pending_call_id, count in pending_tool_calls.items()
            for _ in range(count)
        ]

        # If the SDM block only calls 1 tool at a time, this should not happen.
        if len(tool_output) > 1:
            logger.warning(
                f"[SmartDecisionMakerBlock-node_exec_id={node_exec_id}] "
                f"Multiple pending tool calls are prefilled using a single output. "
                f"Execution may not be accurate."
        # Only assign the last tool output to the first pending tool call
        tool_output = []
        if pending_tool_calls and input_data.last_tool_output is not None:
            # Get the first pending tool call ID
            first_call_id = next(iter(pending_tool_calls.keys()))
            tool_output.append(
                _create_tool_response(first_call_id, input_data.last_tool_output)
            )

            # Add tool output to prompt right away
            prompt.extend(tool_output)

            # Check if there are still pending tool calls after handling the first one
            remaining_pending_calls = get_pending_tool_calls(prompt)

            # If there are still pending tool calls, yield the conversation and return early
            if remaining_pending_calls:
                yield "conversations", prompt
                return

        # Fall back to adding the tool output to the conversation history as a user prompt.
        if len(tool_output) == 0 and input_data.last_tool_output:
            logger.warning(
        elif input_data.last_tool_output:
            logger.error(
                f"[SmartDecisionMakerBlock-node_exec_id={node_exec_id}] "
                f"No pending tool calls found. This may indicate an issue with the "
                f"conversation history, or an LLM calling two tools at the same time."
                f"conversation history, or the tool responding more than once. "
                f"This should not happen! Please check the conversation history for any inconsistencies."
            )
            tool_output.append(
                {
@@ -465,8 +486,7 @@ class SmartDecisionMakerBlock(Block):
                    "content": f"Last tool output: {json.dumps(input_data.last_tool_output)}",
                }
            )

        prompt.extend(tool_output)
            prompt.extend(tool_output)
        if input_data.multiple_tool_calls:
            input_data.sys_prompt += "\nYou can call a tool (different tools) multiple times in a single response."
        else:
@@ -497,7 +517,7 @@ class SmartDecisionMakerBlock(Block):
            max_tokens=input_data.max_tokens,
            tools=tool_functions,
            ollama_host=input_data.ollama_host,
            parallel_tool_calls=True if input_data.multiple_tool_calls else None,
            parallel_tool_calls=input_data.multiple_tool_calls,
        )

        if not response.tool_calls:
@@ -534,5 +554,11 @@ class SmartDecisionMakerBlock(Block):
            else:
                yield f"tools_^_{tool_name}_~_{arg_name}", None

        response.prompt.append(response.raw_response)
        yield "conversations", response.prompt
        # Add reasoning to conversation history if available
        if response.reasoning:
            prompt.append(
                {"role": "assistant", "content": f"[Reasoning]: {response.reasoning}"}
            )

        prompt.append(response.raw_response)
        yield "conversations", prompt
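The run loop now pairs a single tool output with only the first pending call, leaving the rest pending so the block can be re-entered. A sketch of that selection, with a simplified tool-response shape in place of what _create_tool_response actually builds:

pending_tool_calls = {"call_1": 1, "call_2": 1}  # insertion-ordered: first is oldest
last_tool_output = {"status": "done"}

tool_output = []
if pending_tool_calls and last_tool_output is not None:
    first_call_id = next(iter(pending_tool_calls))
    # Simplified message; the real shape depends on the tool_id format (OpenAI vs Anthropic).
    tool_output.append({"role": "tool", "tool_call_id": first_call_id,
                        "content": str(last_tool_output)})

assert tool_output[0]["tool_call_id"] == "call_1"
# call_2 stays pending, so the block yields the conversation and returns early.
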
@@ -17,7 +17,7 @@ from backend.blocks.smartlead.models import (
    Sequence,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.data.model import CredentialsField, SchemaField


class CreateCampaignBlock(Block):
@@ -27,7 +27,7 @@ class CreateCampaignBlock(Block):
        name: str = SchemaField(
            description="The name of the campaign",
        )
        credentials: SmartLeadCredentialsInput = SchemaField(
        credentials: SmartLeadCredentialsInput = CredentialsField(
            description="SmartLead credentials",
        )

@@ -119,7 +119,7 @@ class AddLeadToCampaignBlock(Block):
            description="Settings for lead upload",
            default=LeadUploadSettings(),
        )
        credentials: SmartLeadCredentialsInput = SchemaField(
        credentials: SmartLeadCredentialsInput = CredentialsField(
            description="SmartLead credentials",
        )

@@ -251,7 +251,7 @@ class SaveCampaignSequencesBlock(Block):
            default_factory=list,
            advanced=False,
        )
        credentials: SmartLeadCredentialsInput = SchemaField(
        credentials: SmartLeadCredentialsInput = CredentialsField(
            description="SmartLead credentials",
        )

autogpt_platform/backend/backend/blocks/test/test_block.py (Normal file, 125 lines)
@@ -0,0 +1,125 @@
from typing import Type

import pytest

from backend.data.block import Block, get_blocks
from backend.util.test import execute_block_test


@pytest.mark.parametrize("block", get_blocks().values(), ids=lambda b: b.name)
async def test_available_blocks(block: Type[Block]):
    await execute_block_test(block())


@pytest.mark.parametrize("block", get_blocks().values(), ids=lambda b: b.name)
async def test_block_ids_valid(block: Type[Block]):
    # Check that every registered block ID is a valid UUID4
    import uuid

    # Skip list for blocks with known invalid UUIDs
    skip_blocks = {
        "GetWeatherInformationBlock",
        "CodeExecutionBlock",
        "CountdownTimerBlock",
        "TwitterGetListTweetsBlock",
        "TwitterRemoveListMemberBlock",
        "TwitterAddListMemberBlock",
        "TwitterGetListMembersBlock",
        "TwitterGetListMembershipsBlock",
        "TwitterUnfollowListBlock",
        "TwitterFollowListBlock",
        "TwitterUnpinListBlock",
        "TwitterPinListBlock",
        "TwitterGetPinnedListsBlock",
        "TwitterDeleteListBlock",
        "TwitterUpdateListBlock",
        "TwitterCreateListBlock",
        "TwitterGetListBlock",
        "TwitterGetOwnedListsBlock",
        "TwitterGetSpacesBlock",
        "TwitterGetSpaceByIdBlock",
        "TwitterGetSpaceBuyersBlock",
        "TwitterGetSpaceTweetsBlock",
        "TwitterSearchSpacesBlock",
        "TwitterGetUserMentionsBlock",
        "TwitterGetHomeTimelineBlock",
        "TwitterGetUserTweetsBlock",
        "TwitterGetTweetBlock",
        "TwitterGetTweetsBlock",
        "TwitterGetQuoteTweetsBlock",
        "TwitterLikeTweetBlock",
        "TwitterGetLikingUsersBlock",
        "TwitterGetLikedTweetsBlock",
        "TwitterUnlikeTweetBlock",
        "TwitterBookmarkTweetBlock",
        "TwitterGetBookmarkedTweetsBlock",
        "TwitterRemoveBookmarkTweetBlock",
        "TwitterRetweetBlock",
        "TwitterRemoveRetweetBlock",
        "TwitterGetRetweetersBlock",
        "TwitterHideReplyBlock",
        "TwitterUnhideReplyBlock",
        "TwitterPostTweetBlock",
        "TwitterDeleteTweetBlock",
        "TwitterSearchRecentTweetsBlock",
        "TwitterUnfollowUserBlock",
        "TwitterFollowUserBlock",
        "TwitterGetFollowersBlock",
        "TwitterGetFollowingBlock",
        "TwitterUnmuteUserBlock",
        "TwitterGetMutedUsersBlock",
        "TwitterMuteUserBlock",
        "TwitterGetBlockedUsersBlock",
        "TwitterGetUserBlock",
        "TwitterGetUsersBlock",
        "TodoistCreateLabelBlock",
        "TodoistListLabelsBlock",
        "TodoistGetLabelBlock",
        "TodoistUpdateLabelBlock",
        "TodoistDeleteLabelBlock",
        "TodoistGetSharedLabelsBlock",
        "TodoistRenameSharedLabelsBlock",
        "TodoistRemoveSharedLabelsBlock",
        "TodoistCreateTaskBlock",
        "TodoistGetTasksBlock",
        "TodoistGetTaskBlock",
        "TodoistUpdateTaskBlock",
        "TodoistCloseTaskBlock",
        "TodoistReopenTaskBlock",
        "TodoistDeleteTaskBlock",
        "TodoistListSectionsBlock",
        "TodoistGetSectionBlock",
        "TodoistDeleteSectionBlock",
        "TodoistCreateProjectBlock",
        "TodoistGetProjectBlock",
        "TodoistUpdateProjectBlock",
        "TodoistDeleteProjectBlock",
        "TodoistListCollaboratorsBlock",
        "TodoistGetCommentsBlock",
        "TodoistGetCommentBlock",
        "TodoistUpdateCommentBlock",
        "TodoistDeleteCommentBlock",
        "GithubListStargazersBlock",
        "Slant3DSlicerBlock",
    }

    block_instance = block()

    # Skip blocks with known invalid UUIDs
    if block_instance.__class__.__name__ in skip_blocks:
        pytest.skip(
            f"Skipping UUID check for {block_instance.__class__.__name__} - known invalid UUID"
        )

    # Check that the ID is not empty
    assert block_instance.id, f"Block {block.name} has empty ID"

    # Check that the ID is a valid UUID4
    try:
        parsed_uuid = uuid.UUID(block_instance.id)
        # Verify it's specifically UUID version 4
        assert (
            parsed_uuid.version == 4
        ), f"Block {block.name} ID is UUID version {parsed_uuid.version}, expected version 4"
    except ValueError:
        pytest.fail(f"Block {block.name} has invalid UUID format: {block_instance.id}")
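For block authors, a valid ID is simply a random UUID4; the version check above would accept any ID generated like this:

import uuid

block_id = str(uuid.uuid4())
parsed = uuid.UUID(block_id)
# The third group's leading digit encodes the version; uuid4() always produces 4.
assert parsed.version == 4
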
autogpt_platform/backend/backend/blocks/test/test_http.py (Normal file, 485 lines)
@@ -0,0 +1,485 @@
"""Comprehensive tests for HTTP block with HostScopedCredentials functionality."""

from typing import cast
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from pydantic import SecretStr

from backend.blocks.http import (
    HttpCredentials,
    HttpMethod,
    SendAuthenticatedWebRequestBlock,
)
from backend.data.model import HostScopedCredentials
from backend.util.request import Response


class TestHttpBlockWithHostScopedCredentials:
    """Test suite for HTTP block integration with HostScopedCredentials."""

    @pytest.fixture
    def http_block(self):
        """Create an HTTP block instance."""
        return SendAuthenticatedWebRequestBlock()

    @pytest.fixture
    def mock_response(self):
        """Mock a successful HTTP response."""
        response = MagicMock(spec=Response)
        response.status = 200
        response.headers = {"content-type": "application/json"}
        response.json.return_value = {"success": True, "data": "test"}
        return response

    @pytest.fixture
    def exact_match_credentials(self):
        """Create host-scoped credentials for exact domain matching."""
        return HostScopedCredentials(
            provider="http",
            host="api.example.com",
            headers={
                "Authorization": SecretStr("Bearer exact-match-token"),
                "X-API-Key": SecretStr("api-key-123"),
            },
            title="Exact Match API Credentials",
        )

    @pytest.fixture
    def wildcard_credentials(self):
        """Create host-scoped credentials with wildcard pattern."""
        return HostScopedCredentials(
            provider="http",
            host="*.github.com",
            headers={
                "Authorization": SecretStr("token ghp_wildcard123"),
            },
            title="GitHub Wildcard Credentials",
        )

    @pytest.fixture
    def non_matching_credentials(self):
        """Create credentials that don't match test URLs."""
        return HostScopedCredentials(
            provider="http",
            host="different.api.com",
            headers={
                "Authorization": SecretStr("Bearer non-matching-token"),
            },
            title="Non-matching Credentials",
        )

    @pytest.mark.asyncio
    @patch("backend.blocks.http.Requests")
    async def test_http_block_with_exact_host_match(
        self,
        mock_requests_class,
        http_block,
        exact_match_credentials,
        mock_response,
    ):
        """Test HTTP block with exact host matching credentials."""
        # Setup mocks
        mock_requests = AsyncMock()
        mock_requests.request.return_value = mock_response
        mock_requests_class.return_value = mock_requests

        # Prepare input data
        input_data = SendAuthenticatedWebRequestBlock.Input(
            url="https://api.example.com/data",
            method=HttpMethod.GET,
            headers={"User-Agent": "test-agent"},
            credentials=cast(
                HttpCredentials,
                {
                    "id": exact_match_credentials.id,
                    "provider": "http",
                    "type": "host_scoped",
                    "title": exact_match_credentials.title,
                },
            ),
        )

        # Execute with credentials provided by execution manager
        result = []
        async for output_name, output_data in http_block.run(
            input_data,
            credentials=exact_match_credentials,
            graph_exec_id="test-exec-id",
        ):
            result.append((output_name, output_data))

        # Verify request headers include both credential and user headers
        mock_requests.request.assert_called_once()
        call_args = mock_requests.request.call_args
        expected_headers = {
            "Authorization": "Bearer exact-match-token",
            "X-API-Key": "api-key-123",
            "User-Agent": "test-agent",
        }
        assert call_args.kwargs["headers"] == expected_headers

        # Verify response handling
        assert len(result) == 1
        assert result[0][0] == "response"
        assert result[0][1] == {"success": True, "data": "test"}
    @pytest.mark.asyncio
    @patch("backend.blocks.http.Requests")
    async def test_http_block_with_wildcard_host_match(
        self,
        mock_requests_class,
        http_block,
        wildcard_credentials,
        mock_response,
    ):
        """Test HTTP block with wildcard host pattern matching."""
        # Setup mocks
        mock_requests = AsyncMock()
        mock_requests.request.return_value = mock_response
        mock_requests_class.return_value = mock_requests

        # Test with subdomain that should match *.github.com
        input_data = SendAuthenticatedWebRequestBlock.Input(
            url="https://api.github.com/user",
            method=HttpMethod.GET,
            headers={},
            credentials=cast(
                HttpCredentials,
                {
                    "id": wildcard_credentials.id,
                    "provider": "http",
                    "type": "host_scoped",
                    "title": wildcard_credentials.title,
                },
            ),
        )

        # Execute with wildcard credentials
        result = []
        async for output_name, output_data in http_block.run(
            input_data,
            credentials=wildcard_credentials,
            graph_exec_id="test-exec-id",
        ):
            result.append((output_name, output_data))

        # Verify wildcard matching works
        mock_requests.request.assert_called_once()
        call_args = mock_requests.request.call_args
        expected_headers = {"Authorization": "token ghp_wildcard123"}
        assert call_args.kwargs["headers"] == expected_headers

    @pytest.mark.asyncio
    @patch("backend.blocks.http.Requests")
    async def test_http_block_with_non_matching_credentials(
        self,
        mock_requests_class,
        http_block,
        non_matching_credentials,
        mock_response,
    ):
        """Test HTTP block when credentials don't match the target URL."""
        # Setup mocks
        mock_requests = AsyncMock()
        mock_requests.request.return_value = mock_response
        mock_requests_class.return_value = mock_requests

        # Test with URL that doesn't match the credentials
        input_data = SendAuthenticatedWebRequestBlock.Input(
            url="https://api.example.com/data",
            method=HttpMethod.GET,
            headers={"User-Agent": "test-agent"},
            credentials=cast(
                HttpCredentials,
                {
                    "id": non_matching_credentials.id,
                    "provider": "http",
                    "type": "host_scoped",
                    "title": non_matching_credentials.title,
                },
            ),
        )

        # Execute with non-matching credentials
        result = []
        async for output_name, output_data in http_block.run(
            input_data,
            credentials=non_matching_credentials,
            graph_exec_id="test-exec-id",
        ):
            result.append((output_name, output_data))

        # Verify only user headers are included (no credential headers)
        mock_requests.request.assert_called_once()
        call_args = mock_requests.request.call_args
        expected_headers = {"User-Agent": "test-agent"}
        assert call_args.kwargs["headers"] == expected_headers

    @pytest.mark.asyncio
    @patch("backend.blocks.http.Requests")
    async def test_user_headers_override_credential_headers(
        self,
        mock_requests_class,
        http_block,
        exact_match_credentials,
        mock_response,
    ):
        """Test that user-provided headers take precedence over credential headers."""
        # Setup mocks
        mock_requests = AsyncMock()
        mock_requests.request.return_value = mock_response
        mock_requests_class.return_value = mock_requests

        # Test with user header that conflicts with credential header
        input_data = SendAuthenticatedWebRequestBlock.Input(
            url="https://api.example.com/data",
            method=HttpMethod.POST,
            headers={
                "Authorization": "Bearer user-override-token",  # Should override
                "Content-Type": "application/json",  # Additional user header
            },
            credentials=cast(
                HttpCredentials,
                {
                    "id": exact_match_credentials.id,
                    "provider": "http",
                    "type": "host_scoped",
                    "title": exact_match_credentials.title,
                },
            ),
        )

        # Execute with conflicting headers
        result = []
        async for output_name, output_data in http_block.run(
            input_data,
            credentials=exact_match_credentials,
            graph_exec_id="test-exec-id",
        ):
            result.append((output_name, output_data))

        # Verify user headers take precedence
        mock_requests.request.assert_called_once()
        call_args = mock_requests.request.call_args
        expected_headers = {
            "X-API-Key": "api-key-123",  # From credentials
            "Authorization": "Bearer user-override-token",  # User override
            "Content-Type": "application/json",  # User header
        }
        assert call_args.kwargs["headers"] == expected_headers
    @pytest.mark.asyncio
    @patch("backend.blocks.http.Requests")
    async def test_auto_discovered_credentials_flow(
        self,
        mock_requests_class,
        http_block,
        mock_response,
    ):
        """Test the auto-discovery flow where execution manager provides matching credentials."""
        # Create auto-discovered credentials
        auto_discovered_creds = HostScopedCredentials(
            provider="http",
            host="*.example.com",
            headers={
                "Authorization": SecretStr("Bearer auto-discovered-token"),
            },
            title="Auto-discovered Credentials",
        )

        # Setup mocks
        mock_requests = AsyncMock()
        mock_requests.request.return_value = mock_response
        mock_requests_class.return_value = mock_requests

        # Test with empty credentials field (triggers auto-discovery)
        input_data = SendAuthenticatedWebRequestBlock.Input(
            url="https://api.example.com/data",
            method=HttpMethod.GET,
            headers={},
            credentials=cast(
                HttpCredentials,
                {
                    "id": "",  # Empty ID triggers auto-discovery in execution manager
                    "provider": "http",
                    "type": "host_scoped",
                    "title": "",
                },
            ),
        )

        # Execute with auto-discovered credentials provided by execution manager
        result = []
        async for output_name, output_data in http_block.run(
            input_data,
            credentials=auto_discovered_creds,  # Execution manager found these
            graph_exec_id="test-exec-id",
        ):
            result.append((output_name, output_data))

        # Verify auto-discovered credentials were applied
        mock_requests.request.assert_called_once()
        call_args = mock_requests.request.call_args
        expected_headers = {"Authorization": "Bearer auto-discovered-token"}
        assert call_args.kwargs["headers"] == expected_headers

        # Verify response handling
        assert len(result) == 1
        assert result[0][0] == "response"
        assert result[0][1] == {"success": True, "data": "test"}

    @pytest.mark.asyncio
    @patch("backend.blocks.http.Requests")
    async def test_multiple_header_credentials(
        self,
        mock_requests_class,
        http_block,
        mock_response,
    ):
        """Test credentials with multiple headers are all applied."""
        # Create credentials with multiple headers
        multi_header_creds = HostScopedCredentials(
            provider="http",
            host="api.example.com",
            headers={
                "Authorization": SecretStr("Bearer multi-token"),
                "X-API-Key": SecretStr("api-key-456"),
                "X-Client-ID": SecretStr("client-789"),
                "X-Custom-Header": SecretStr("custom-value"),
            },
            title="Multi-Header Credentials",
        )

        # Setup mocks
        mock_requests = AsyncMock()
        mock_requests.request.return_value = mock_response
        mock_requests_class.return_value = mock_requests

        # Test with credentials containing multiple headers
        input_data = SendAuthenticatedWebRequestBlock.Input(
            url="https://api.example.com/data",
            method=HttpMethod.GET,
            headers={"User-Agent": "test-agent"},
            credentials=cast(
                HttpCredentials,
                {
                    "id": multi_header_creds.id,
                    "provider": "http",
                    "type": "host_scoped",
                    "title": multi_header_creds.title,
                },
            ),
        )

        # Execute with multi-header credentials
        result = []
        async for output_name, output_data in http_block.run(
            input_data,
            credentials=multi_header_creds,
            graph_exec_id="test-exec-id",
        ):
            result.append((output_name, output_data))

        # Verify all headers are included
        mock_requests.request.assert_called_once()
        call_args = mock_requests.request.call_args
        expected_headers = {
            "Authorization": "Bearer multi-token",
            "X-API-Key": "api-key-456",
            "X-Client-ID": "client-789",
            "X-Custom-Header": "custom-value",
            "User-Agent": "test-agent",
        }
        assert call_args.kwargs["headers"] == expected_headers

    @pytest.mark.asyncio
    @patch("backend.blocks.http.Requests")
    async def test_credentials_with_complex_url_patterns(
        self,
        mock_requests_class,
        http_block,
        mock_response,
    ):
        """Test credentials matching various URL patterns."""
        # Test cases for different URL patterns
        test_cases = [
            {
                "host_pattern": "api.example.com",
                "test_url": "https://api.example.com/v1/users",
                "should_match": True,
            },
            {
                "host_pattern": "*.example.com",
                "test_url": "https://api.example.com/v1/users",
                "should_match": True,
            },
            {
                "host_pattern": "*.example.com",
                "test_url": "https://subdomain.example.com/data",
                "should_match": True,
            },
            {
                "host_pattern": "api.example.com",
                "test_url": "https://api.different.com/data",
                "should_match": False,
            },
        ]

        # Setup mocks
        mock_requests = AsyncMock()
        mock_requests.request.return_value = mock_response
        mock_requests_class.return_value = mock_requests

        for case in test_cases:
            # Reset mock for each test case
            mock_requests.reset_mock()

            # Create credentials for this test case
            test_creds = HostScopedCredentials(
                provider="http",
                host=case["host_pattern"],
                headers={
                    "Authorization": SecretStr(f"Bearer {case['host_pattern']}-token"),
                },
                title=f"Credentials for {case['host_pattern']}",
            )

            input_data = SendAuthenticatedWebRequestBlock.Input(
                url=case["test_url"],
                method=HttpMethod.GET,
                headers={"User-Agent": "test-agent"},
                credentials=cast(
                    HttpCredentials,
                    {
                        "id": test_creds.id,
                        "provider": "http",
                        "type": "host_scoped",
                        "title": test_creds.title,
                    },
                ),
            )

            # Execute with test credentials
            result = []
            async for output_name, output_data in http_block.run(
                input_data,
                credentials=test_creds,
                graph_exec_id="test-exec-id",
            ):
                result.append((output_name, output_data))

            # Verify headers based on whether pattern should match
            mock_requests.request.assert_called_once()
            call_args = mock_requests.request.call_args
            headers = call_args.kwargs["headers"]

            if case["should_match"]:
                # Should include both user and credential headers
                expected_auth = f"Bearer {case['host_pattern']}-token"
                assert headers["Authorization"] == expected_auth
                assert headers["User-Agent"] == "test-agent"
            else:
                # Should only include user headers
                assert "Authorization" not in headers
                assert headers["User-Agent"] == "test-agent"
@@ -25,12 +25,7 @@ async def create_graph(s: SpinTestServer, g: graph.Graph, u: User) -> graph.Grap
|
||||
async def create_credentials(s: SpinTestServer, u: User):
|
||||
provider = ProviderName.OPENAI
|
||||
credentials = llm.TEST_CREDENTIALS
|
||||
try:
|
||||
await s.agent_server.test_create_credentials(u.id, provider, credentials)
|
||||
except Exception:
|
||||
# ValueErrors is raised trying to recreate the same credentials
|
||||
# so hidding the error
|
||||
pass
|
||||
return await s.agent_server.test_create_credentials(u.id, provider, credentials)
|
||||
|
||||
|
||||
async def execute_graph(
|
||||
@@ -60,19 +55,18 @@ async def execute_graph(
|
||||
return graph_exec_id
|
||||
|
||||
|
||||
@pytest.mark.skip()
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):
|
||||
test_user = await create_test_user()
|
||||
test_tool_graph = await create_graph(server, create_test_graph(), test_user)
|
||||
await create_credentials(server, test_user)
|
||||
creds = await create_credentials(server, test_user)
|
||||
|
||||
nodes = [
|
||||
graph.Node(
|
||||
block_id=SmartDecisionMakerBlock().id,
|
||||
input_default={
|
||||
"prompt": "Hello, World!",
|
||||
"credentials": llm.TEST_CREDENTIALS_INPUT,
|
||||
"credentials": creds,
|
||||
},
|
||||
),
|
||||
graph.Node(
|
||||
@@ -110,80 +104,18 @@ async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
|
||||
@pytest.mark.skip()
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_graph_validation_with_tool_nodes_raises_error(server: SpinTestServer):
|
||||
|
||||
test_user = await create_test_user()
|
||||
test_tool_graph = await create_graph(server, create_test_graph(), test_user)
|
||||
await create_credentials(server, test_user)
|
||||
|
||||
nodes = [
|
||||
graph.Node(
|
||||
block_id=SmartDecisionMakerBlock().id,
|
||||
input_default={
|
||||
"prompt": "Hello, World!",
|
||||
"credentials": llm.TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
),
|
||||
graph.Node(
|
||||
block_id=AgentExecutorBlock().id,
|
||||
input_default={
|
||||
"graph_id": test_tool_graph.id,
|
||||
"graph_version": test_tool_graph.version,
|
||||
"input_schema": test_tool_graph.input_schema,
|
||||
"output_schema": test_tool_graph.output_schema,
|
||||
},
|
||||
),
|
||||
graph.Node(
|
||||
block_id=StoreValueBlock().id,
|
||||
),
|
||||
]
|
||||
|
||||
links = [
|
||||
graph.Link(
|
||||
source_id=nodes[0].id,
|
||||
sink_id=nodes[1].id,
|
||||
source_name="tools_^_sample_tool_input_1",
|
||||
sink_name="input_1",
|
||||
),
|
||||
graph.Link(
|
||||
source_id=nodes[0].id,
|
||||
sink_id=nodes[1].id,
|
||||
source_name="tools_^_sample_tool_input_2",
|
||||
sink_name="input_2",
|
||||
),
|
||||
graph.Link(
|
||||
source_id=nodes[0].id,
|
||||
sink_id=nodes[2].id,
|
||||
source_name="tools_^_store_value_input",
|
||||
sink_name="input",
|
||||
),
|
||||
]
|
||||
|
||||
test_graph = graph.Graph(
|
||||
name="TestGraph",
|
||||
description="Test graph",
|
||||
nodes=nodes,
|
||||
links=links,
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
|
||||
@pytest.mark.skip()
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_smart_decision_maker_function_signature(server: SpinTestServer):
|
||||
test_user = await create_test_user()
|
||||
test_tool_graph = await create_graph(server, create_test_graph(), test_user)
|
||||
await create_credentials(server, test_user)
|
||||
creds = await create_credentials(server, test_user)
|
||||
|
||||
nodes = [
|
||||
graph.Node(
|
||||
block_id=SmartDecisionMakerBlock().id,
|
||||
input_default={
|
||||
"prompt": "Hello, World!",
|
||||
"credentials": llm.TEST_CREDENTIALS_INPUT,
|
||||
"credentials": creds,
|
||||
},
|
||||
),
|
||||
graph.Node(
|
||||
@@ -229,7 +161,7 @@ async def test_smart_decision_maker_function_signature(server: SpinTestServer):
|
||||
)
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
tool_functions = SmartDecisionMakerBlock._create_function_signature(
|
||||
tool_functions = await SmartDecisionMakerBlock._create_function_signature(
|
||||
test_graph.nodes[0].id
|
||||
)
|
||||
assert tool_functions is not None, "Tool functions should not be None"
|
||||
@@ -15,7 +15,7 @@ from backend.blocks.zerobounce._auth import (
|
||||
ZeroBounceCredentialsInput,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
|
||||
|
||||
class Response(BaseModel):
|
||||
@@ -90,7 +90,7 @@ class ValidateEmailsBlock(Block):
|
||||
description="IP address to validate",
|
||||
default="",
|
||||
)
|
||||
credentials: ZeroBounceCredentialsInput = SchemaField(
|
||||
credentials: ZeroBounceCredentialsInput = CredentialsField(
|
||||
description="ZeroBounce credentials",
|
||||
)
|
||||
|
||||
|
||||
autogpt_platform/backend/backend/check_db.py (new file, 359 lines)
@@ -0,0 +1,359 @@
import asyncio
import random
from datetime import datetime

from faker import Faker
from prisma import Prisma

faker = Faker()


async def check_cron_job(db):
    """Check if the pg_cron job for refreshing materialized views exists."""
    print("\n1. Checking pg_cron job...")
    print("-" * 40)

    try:
        # Check if pg_cron extension exists
        extension_check = await db.query_raw("CREATE EXTENSION pg_cron;")
        print(extension_check)
        extension_check = await db.query_raw(
            "SELECT COUNT(*) as count FROM pg_extension WHERE extname = 'pg_cron'"
        )
        if extension_check[0]["count"] == 0:
            print("⚠️ pg_cron extension is not installed")
            return False

        # Check if the refresh job exists
        job_check = await db.query_raw(
            """
            SELECT jobname, schedule, command
            FROM cron.job
            WHERE jobname = 'refresh-store-views'
            """
        )

        if job_check:
            job = job_check[0]
            print("✅ pg_cron job found:")
            print(f"   Name: {job['jobname']}")
            print(f"   Schedule: {job['schedule']} (every 15 minutes)")
            print(f"   Command: {job['command']}")
            return True
        else:
            print("⚠️ pg_cron job 'refresh-store-views' not found")
            return False

    except Exception as e:
        print(f"❌ Error checking pg_cron: {e}")
        return False


async def get_materialized_view_counts(db):
    """Get current counts from materialized views."""
    print("\n2. Getting current materialized view data...")
    print("-" * 40)

    # Get counts from mv_agent_run_counts
    agent_runs = await db.query_raw(
        """
        SELECT COUNT(*) as total_agents,
               SUM(run_count) as total_runs,
               MAX(run_count) as max_runs,
               MIN(run_count) as min_runs
        FROM mv_agent_run_counts
        """
    )

    # Get counts from mv_review_stats
    review_stats = await db.query_raw(
        """
        SELECT COUNT(*) as total_listings,
               SUM(review_count) as total_reviews,
               AVG(avg_rating) as overall_avg_rating
        FROM mv_review_stats
        """
    )

    # Get sample data from StoreAgent view
    store_agents = await db.query_raw(
        """
        SELECT COUNT(*) as total_store_agents,
               AVG(runs) as avg_runs,
               AVG(rating) as avg_rating
        FROM "StoreAgent"
        """
    )

    agent_run_data = agent_runs[0] if agent_runs else {}
    review_data = review_stats[0] if review_stats else {}
    store_data = store_agents[0] if store_agents else {}

    print("📊 mv_agent_run_counts:")
    print(f"   Total agents: {agent_run_data.get('total_agents', 0)}")
    print(f"   Total runs: {agent_run_data.get('total_runs', 0)}")
    print(f"   Max runs per agent: {agent_run_data.get('max_runs', 0)}")
    print(f"   Min runs per agent: {agent_run_data.get('min_runs', 0)}")

    print("\n📊 mv_review_stats:")
    print(f"   Total listings: {review_data.get('total_listings', 0)}")
    print(f"   Total reviews: {review_data.get('total_reviews', 0)}")
    print(f"   Overall avg rating: {review_data.get('overall_avg_rating') or 0:.2f}")

    print("\n📊 StoreAgent view:")
    print(f"   Total store agents: {store_data.get('total_store_agents', 0)}")
    print(f"   Average runs: {store_data.get('avg_runs') or 0:.2f}")
    print(f"   Average rating: {store_data.get('avg_rating') or 0:.2f}")

    return {
        "agent_runs": agent_run_data,
        "reviews": review_data,
        "store_agents": store_data,
    }


async def add_test_data(db):
    """Add some test data to verify materialized view updates."""
    print("\n3. Adding test data...")
    print("-" * 40)

    # Get some existing data
    users = await db.user.find_many(take=5)
    graphs = await db.agentgraph.find_many(take=5)

    if not users or not graphs:
        print("❌ No existing users or graphs found. Run test_data_creator.py first.")
        return False

    # Add new executions
    print("Adding new agent graph executions...")
    new_executions = 0
    for graph in graphs:
        for _ in range(random.randint(2, 5)):
            await db.agentgraphexecution.create(
                data={
                    "agentGraphId": graph.id,
                    "agentGraphVersion": graph.version,
                    "userId": random.choice(users).id,
                    "executionStatus": "COMPLETED",
                    "startedAt": datetime.now(),
                }
            )
            new_executions += 1

    print(f"✅ Added {new_executions} new executions")

    # Check if we need to create store listings first
    store_versions = await db.storelistingversion.find_many(
        where={"submissionStatus": "APPROVED"}, take=5
    )

    if not store_versions:
        print("\nNo approved store listings found. Creating test store listings...")

        # Create store listings for existing agent graphs
        for i, graph in enumerate(graphs[:3]):  # Create up to 3 store listings
            # Create a store listing
            listing = await db.storelisting.create(
                data={
                    "slug": f"test-agent-{graph.id[:8]}",
                    "agentGraphId": graph.id,
                    "agentGraphVersion": graph.version,
                    "hasApprovedVersion": True,
                    "owningUserId": graph.userId,
                }
            )

            # Create an approved version
            version = await db.storelistingversion.create(
                data={
                    "storeListingId": listing.id,
                    "agentGraphId": graph.id,
                    "agentGraphVersion": graph.version,
                    "name": f"Test Agent {i+1}",
                    "subHeading": faker.catch_phrase(),
                    "description": faker.paragraph(nb_sentences=5),
                    "imageUrls": [faker.image_url()],
                    "categories": ["productivity", "automation"],
                    "submissionStatus": "APPROVED",
                    "submittedAt": datetime.now(),
                }
            )

            # Update listing with active version
            await db.storelisting.update(
                where={"id": listing.id}, data={"activeVersionId": version.id}
            )

        print("✅ Created test store listings")

        # Re-fetch approved versions
        store_versions = await db.storelistingversion.find_many(
            where={"submissionStatus": "APPROVED"}, take=5
        )

    # Add new reviews
    print("\nAdding new store listing reviews...")
    new_reviews = 0
    for version in store_versions:
        # Find users who haven't reviewed this version
        existing_reviews = await db.storelistingreview.find_many(
            where={"storeListingVersionId": version.id}
        )
        reviewed_user_ids = {r.reviewByUserId for r in existing_reviews}
        available_users = [u for u in users if u.id not in reviewed_user_ids]

        if available_users:
            user = random.choice(available_users)
            await db.storelistingreview.create(
                data={
                    "storeListingVersionId": version.id,
                    "reviewByUserId": user.id,
                    "score": random.randint(3, 5),
                    "comments": faker.text(max_nb_chars=100),
                }
            )
            new_reviews += 1

    print(f"✅ Added {new_reviews} new reviews")

    return True


async def refresh_materialized_views(db):
    """Manually refresh the materialized views."""
    print("\n4. Manually refreshing materialized views...")
    print("-" * 40)

    try:
        await db.execute_raw("SELECT refresh_store_materialized_views();")
        print("✅ Materialized views refreshed successfully")
        return True
    except Exception as e:
        print(f"❌ Error refreshing views: {e}")
        return False


async def compare_counts(before, after):
    """Compare counts before and after refresh."""
    print("\n5. Comparing counts before and after refresh...")
    print("-" * 40)

    # Compare agent runs
    print("🔍 Agent run changes:")
    before_runs = before["agent_runs"].get("total_runs") or 0
    after_runs = after["agent_runs"].get("total_runs") or 0
    print(
        f"   Total runs: {before_runs} → {after_runs} (+{after_runs - before_runs})"
    )

    # Compare reviews
    print("\n🔍 Review changes:")
    before_reviews = before["reviews"].get("total_reviews") or 0
    after_reviews = after["reviews"].get("total_reviews") or 0
    print(
        f"   Total reviews: {before_reviews} → {after_reviews} "
        f"(+{after_reviews - before_reviews})"
    )

    # Compare store agents
    print("\n🔍 StoreAgent view changes:")
    before_avg_runs = before["store_agents"].get("avg_runs", 0) or 0
    after_avg_runs = after["store_agents"].get("avg_runs", 0) or 0
    print(
        f"   Average runs: {before_avg_runs:.2f} → {after_avg_runs:.2f} "
        f"(+{after_avg_runs - before_avg_runs:.2f})"
    )

    # Verify changes occurred
    runs_changed = (after["agent_runs"].get("total_runs") or 0) > (
        before["agent_runs"].get("total_runs") or 0
    )
    reviews_changed = (after["reviews"].get("total_reviews") or 0) > (
        before["reviews"].get("total_reviews") or 0
    )

    if runs_changed and reviews_changed:
        print("\n✅ Materialized views are updating correctly!")
        return True
    else:
        print("\n⚠️ Some materialized views may not have updated:")
        if not runs_changed:
            print("   - Agent run counts did not increase")
        if not reviews_changed:
            print("   - Review counts did not increase")
        return False


async def main():
    db = Prisma()
    await db.connect()

    print("=" * 60)
    print("Materialized Views Test")
    print("=" * 60)

    try:
        # Check if data exists
        user_count = await db.user.count()
        if user_count == 0:
            print("❌ No data in database. Please run test_data_creator.py first.")
            await db.disconnect()
            return

        # 1. Check cron job
        cron_exists = await check_cron_job(db)

        # 2. Get initial counts
        counts_before = await get_materialized_view_counts(db)

        # 3. Add test data
        data_added = await add_test_data(db)
        refresh_success = False

        if data_added:
            # Wait a moment for data to be committed
            print("\nWaiting for data to be committed...")
            await asyncio.sleep(2)

            # 4. Manually refresh views
            refresh_success = await refresh_materialized_views(db)

            if refresh_success:
                # 5. Get counts after refresh
                counts_after = await get_materialized_view_counts(db)

                # 6. Compare results
                await compare_counts(counts_before, counts_after)

        # Summary
        print("\n" + "=" * 60)
        print("Test Summary")
        print("=" * 60)
        print(f"✓ pg_cron job exists: {'Yes' if cron_exists else 'No'}")
        print(f"✓ Test data added: {'Yes' if data_added else 'No'}")
        print(f"✓ Manual refresh worked: {'Yes' if refresh_success else 'No'}")
        print(
            f"✓ Views updated correctly: {'Yes' if data_added and refresh_success else 'Cannot verify'}"
        )

        if cron_exists:
            print(
                "\n💡 The materialized views will also refresh automatically every 15 minutes via pg_cron."
            )
        else:
            print(
                "\n⚠️ Automatic refresh is not configured. Views must be refreshed manually."
            )

    except Exception as e:
        print(f"\n❌ Test failed with error: {e}")
        import traceback

        traceback.print_exc()

    await db.disconnect()


if __name__ == "__main__":
    asyncio.run(main())
autogpt_platform/backend/backend/check_store_data.py (new file, 159 lines)
@@ -0,0 +1,159 @@
#!/usr/bin/env python3
"""Check store-related data in the database."""

import asyncio

from prisma import Prisma


async def check_store_data(db):
    """Check what store data exists in the database."""

    print("============================================================")
    print("Store Data Check")
    print("============================================================")

    # Check store listings
    print("\n1. Store Listings:")
    print("-" * 40)
    listings = await db.storelisting.find_many()
    print(f"Total store listings: {len(listings)}")

    if listings:
        for listing in listings[:5]:
            print(f"\nListing ID: {listing.id}")
            print(f"  Name: {listing.name}")
            print(f"  Status: {listing.status}")
            print(f"  Slug: {listing.slug}")

    # Check store listing versions
    print("\n\n2. Store Listing Versions:")
    print("-" * 40)
    versions = await db.storelistingversion.find_many(include={"StoreListing": True})
    print(f"Total store listing versions: {len(versions)}")

    # Group by submission status
    status_counts = {}
    for version in versions:
        status = version.submissionStatus
        status_counts[status] = status_counts.get(status, 0) + 1

    print("\nVersions by status:")
    for status, count in status_counts.items():
        print(f"  {status}: {count}")

    # Show approved versions
    approved_versions = [v for v in versions if v.submissionStatus == "APPROVED"]
    print(f"\nApproved versions: {len(approved_versions)}")
    if approved_versions:
        for version in approved_versions[:5]:
            print(f"\n  Version ID: {version.id}")
            print(f"  Listing: {version.StoreListing.name}")
            print(f"  Version: {version.version}")

    # Check store listing reviews
    print("\n\n3. Store Listing Reviews:")
    print("-" * 40)
    reviews = await db.storelistingreview.find_many(
        include={"StoreListingVersion": {"include": {"StoreListing": True}}}
    )
    print(f"Total reviews: {len(reviews)}")

    if reviews:
        # Calculate average rating
        total_score = sum(r.score for r in reviews)
        avg_score = total_score / len(reviews) if reviews else 0
        print(f"Average rating: {avg_score:.2f}")

        # Show sample reviews
        print("\nSample reviews:")
        for review in reviews[:3]:
            print(f"\n  Review for: {review.StoreListingVersion.StoreListing.name}")
            print(f"  Score: {review.score}")
            print(f"  Comments: {review.comments[:100]}...")

    # Check StoreAgent view data
    print("\n\n4. StoreAgent View Data:")
    print("-" * 40)

    # Query the StoreAgent view
    query = """
    SELECT
        sa.listing_id,
        sa.slug,
        sa.agent_name,
        sa.description,
        sa.featured,
        sa.runs,
        sa.rating,
        sa.creator_username,
        sa.categories,
        sa.updated_at
    FROM "StoreAgent" sa
    LIMIT 10;
    """

    store_agents = await db.query_raw(query)
    print(f"Total store agents in view: {len(store_agents)}")

    if store_agents:
        for agent in store_agents[:5]:
            print(f"\nStore Agent: {agent['agent_name']}")
            print(f"  Slug: {agent['slug']}")
            print(f"  Runs: {agent['runs']}")
            print(f"  Rating: {agent['rating']}")
            print(f"  Creator: {agent['creator_username']}")

    # Check the underlying data that should populate StoreAgent
    print("\n\n5. Data that should populate StoreAgent view:")
    print("-" * 40)

    # Check for any APPROVED store listing versions
    query = """
    SELECT COUNT(*) as count
    FROM "StoreListingVersion"
    WHERE "submissionStatus" = 'APPROVED'
    """

    result = await db.query_raw(query)
    approved_count = result[0]["count"] if result else 0
    print(f"Approved store listing versions: {approved_count}")

    # Check for store listings with hasApprovedVersion = true
    query = """
    SELECT COUNT(*) as count
    FROM "StoreListing"
    WHERE "hasApprovedVersion" = true AND "isDeleted" = false
    """

    result = await db.query_raw(query)
    has_approved_count = result[0]["count"] if result else 0
    print(f"Store listings with approved versions: {has_approved_count}")

    # Check agent graph executions
    query = """
    SELECT COUNT(DISTINCT "agentGraphId") as unique_agents,
           COUNT(*) as total_executions
    FROM "AgentGraphExecution"
    """

    result = await db.query_raw(query)
    if result:
        print("\nAgent Graph Executions:")
        print(f"  Unique agents with executions: {result[0]['unique_agents']}")
        print(f"  Total executions: {result[0]['total_executions']}")


async def main():
    """Main function."""
    db = Prisma()
    await db.connect()

    try:
        await check_store_data(db)
    finally:
        await db.disconnect()


if __name__ == "__main__":
    asyncio.run(main())

@@ -6,6 +6,8 @@ from dotenv import load_dotenv

from backend.util.logging import configure_logging

+os.environ["ENABLE_AUTH"] = "false"
+
load_dotenv()

# NOTE: You can run tests with --log-cli-level=INFO to see the logs
autogpt_platform/backend/backend/data/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
from .graph import NodeModel
from .integrations import Webhook  # noqa: F401

# Resolve Webhook <- NodeModel forward reference
NodeModel.model_rebuild()

@@ -78,6 +78,7 @@ class BlockCategory(Enum):
    PRODUCTIVITY = "Block that helps with productivity"
    ISSUE_TRACKING = "Block that helps with issue tracking"
    MULTIMEDIA = "Block that interacts with multimedia content"
+    MARKETING = "Block that helps with marketing"

    def dict(self) -> dict[str, str]:
        return {"category": self.name, "description": self.value}

@@ -424,28 +425,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
        raise ValueError(f"{self.name} did not produce any output for {output}")

    def merge_stats(self, stats: NodeExecutionStats) -> NodeExecutionStats:
-        stats_dict = stats.model_dump()
-        current_stats = self.execution_stats.model_dump()
-
-        for key, value in stats_dict.items():
-            if key not in current_stats:
-                # Field doesn't exist yet; just set it. This probably won't
-                # happen, but is handled just in case; invalid values will
-                # raise when converting back below.
-                current_stats[key] = value
-            elif isinstance(value, dict) and isinstance(current_stats[key], dict):
-                current_stats[key].update(value)
-            elif isinstance(value, (int, float)) and isinstance(
-                current_stats[key], (int, float)
-            ):
-                current_stats[key] += value
-            elif isinstance(value, list) and isinstance(current_stats[key], list):
-                current_stats[key].extend(value)
-            else:
-                current_stats[key] = value
-
-        self.execution_stats = NodeExecutionStats(**current_stats)
-
+        self.execution_stats += stats
        return self.execution_stats
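For illustration, the dict-based merge above is collapsed into a single "+=". A minimal
sketch of the pattern, assuming the stats model defines __iadd__ with equivalent merge
semantics; the class below is a hypothetical stand-in, not the repo's NodeExecutionStats:

from dataclasses import dataclass, field

@dataclass
class _StatsSketch:
    # Hypothetical stand-in: numeric fields accumulate, list fields extend.
    tokens: int = 0
    cost: float = 0.0
    errors: list = field(default_factory=list)

    def __iadd__(self, other: "_StatsSketch") -> "_StatsSketch":
        self.tokens += other.tokens
        self.cost += other.cost
        self.errors.extend(other.errors)
        return self

total = _StatsSketch(tokens=10, cost=0.5)
total += _StatsSketch(tokens=5, cost=0.25, errors=["timeout"])
assert total.tokens == 15 and total.cost == 0.75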
@property
|
||||
@@ -512,6 +492,12 @@ def get_blocks() -> dict[str, Type[Block]]:
|
||||
|
||||
|
||||
async def initialize_blocks() -> None:
|
||||
# First, sync all provider costs to blocks
|
||||
# Imported here to avoid circular import
|
||||
from backend.sdk.cost_integration import sync_all_provider_costs
|
||||
|
||||
sync_all_provider_costs()
|
||||
|
||||
for cls in get_blocks().values():
|
||||
block = cls()
|
||||
existing_block = await AgentBlock.prisma().find_first(
|
||||
|
||||
@@ -4,6 +4,7 @@ from backend.blocks.ai_music_generator import AIMusicGeneratorBlock
|
||||
from backend.blocks.ai_shortform_video_block import AIShortformVideoCreatorBlock
|
||||
from backend.blocks.apollo.organization import SearchOrganizationsBlock
|
||||
from backend.blocks.apollo.people import SearchPeopleBlock
|
||||
from backend.blocks.apollo.person import GetPersonDetailBlock
|
||||
from backend.blocks.flux_kontext import AIImageEditorBlock, FluxKontextModelName
|
||||
from backend.blocks.ideogram import IdeogramModelBlock
|
||||
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
|
||||
@@ -84,6 +85,9 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.EVA_QWEN_2_5_32B: 1,
|
||||
LlmModel.DEEPSEEK_CHAT: 2,
|
||||
LlmModel.PERPLEXITY_LLAMA_3_1_SONAR_LARGE_128K_ONLINE: 1,
|
||||
LlmModel.PERPLEXITY_SONAR: 1,
|
||||
LlmModel.PERPLEXITY_SONAR_PRO: 5,
|
||||
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: 10,
|
||||
LlmModel.QWEN_QWQ_32B_PREVIEW: 2,
|
||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: 1,
|
||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: 1,
|
||||
@@ -362,7 +366,31 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
|
||||
],
|
||||
SearchPeopleBlock: [
|
||||
BlockCost(
|
||||
cost_amount=2,
|
||||
cost_amount=10,
|
||||
cost_filter={
|
||||
"enrich_info": False,
|
||||
"credentials": {
|
||||
"id": apollo_credentials.id,
|
||||
"provider": apollo_credentials.provider,
|
||||
"type": apollo_credentials.type,
|
||||
},
|
||||
},
|
||||
),
|
||||
BlockCost(
|
||||
cost_amount=20,
|
||||
cost_filter={
|
||||
"enrich_info": True,
|
||||
"credentials": {
|
||||
"id": apollo_credentials.id,
|
||||
"provider": apollo_credentials.provider,
|
||||
"type": apollo_credentials.type,
|
||||
},
|
||||
},
|
||||
),
|
||||
],
|
||||
GetPersonDetailBlock: [
|
||||
BlockCost(
|
||||
cost_amount=1,
|
||||
cost_filter={
|
||||
"credentials": {
|
||||
"id": apollo_credentials.id,
|
||||
|
||||
@@ -93,6 +93,28 @@ async def locked_transaction(key: str):
|
||||
yield tx
|
||||
|
||||
|
||||
def get_database_schema() -> str:
|
||||
"""Extract database schema from DATABASE_URL."""
|
||||
parsed_url = urlparse(DATABASE_URL)
|
||||
query_params = dict(parse_qsl(parsed_url.query))
|
||||
return query_params.get("schema", "public")
|
||||
|
||||
|
||||
async def query_raw_with_schema(query_template: str, *args) -> list[dict]:
|
||||
"""Execute raw SQL query with proper schema handling."""
|
||||
schema = get_database_schema()
|
||||
schema_prefix = f"{schema}." if schema != "public" else ""
|
||||
formatted_query = query_template.format(schema_prefix=schema_prefix)
|
||||
|
||||
import prisma as prisma_module
|
||||
|
||||
result = await prisma_module.get_client().query_raw(
|
||||
formatted_query, *args # type: ignore
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
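For illustration, a usage sketch for query_raw_with_schema: callers embed {schema_prefix}
in the template, and it expands to "<schema>." only when DATABASE_URL carries a
non-public ?schema=... parameter. The table and column names below are illustrative:

async def count_user_graphs(user_id: str) -> int:
    # With ?schema=platform this queries platform."AgentGraph"; otherwise "AgentGraph".
    rows = await query_raw_with_schema(
        'SELECT COUNT(*) AS count FROM {schema_prefix}"AgentGraph" WHERE "userId" = $1',
        user_id,
    )
    return int(rows[0]["count"])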
class BaseDbModel(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from pydantic import BaseModel
|
||||
from redis.asyncio.client import PubSub as AsyncPubSub
|
||||
from redis.client import PubSub
|
||||
|
||||
from backend.data import redis
|
||||
from backend.data import redis_client as redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -22,6 +22,7 @@ from prisma.models import (
|
||||
AgentGraphExecution,
|
||||
AgentNodeExecution,
|
||||
AgentNodeExecutionInputOutput,
|
||||
AgentNodeExecutionKeyValueData,
|
||||
)
|
||||
from prisma.types import (
|
||||
AgentGraphExecutionCreateInput,
|
||||
@@ -29,6 +30,7 @@ from prisma.types import (
|
||||
AgentGraphExecutionWhereInput,
|
||||
AgentNodeExecutionCreateInput,
|
||||
AgentNodeExecutionInputOutputCreateInput,
|
||||
AgentNodeExecutionKeyValueDataCreateInput,
|
||||
AgentNodeExecutionUpdateInput,
|
||||
AgentNodeExecutionWhereInput,
|
||||
)
|
||||
@@ -47,7 +49,8 @@ from .block import (
|
||||
get_io_block_ids,
|
||||
get_webhook_block_ids,
|
||||
)
|
||||
from .db import BaseDbModel
|
||||
from .db import BaseDbModel, query_raw_with_schema
|
||||
from .event_bus import AsyncRedisEventBus, RedisEventBus
|
||||
from .includes import (
|
||||
EXECUTION_RESULT_INCLUDE,
|
||||
EXECUTION_RESULT_ORDER,
|
||||
@@ -55,7 +58,6 @@ from .includes import (
|
||||
graph_execution_include,
|
||||
)
|
||||
from .model import GraphExecutionStats, NodeExecutionStats
|
||||
from .queue import AsyncRedisEventBus, RedisEventBus
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -66,6 +68,21 @@ config = Config()
|
||||
# -------------------------- Models -------------------------- #
|
||||
|
||||
|
||||
class BlockErrorStats(BaseModel):
|
||||
"""Typed data structure for block error statistics."""
|
||||
|
||||
block_id: str
|
||||
total_executions: int
|
||||
failed_executions: int
|
||||
|
||||
@property
|
||||
def error_rate(self) -> float:
|
||||
"""Calculate error rate as a percentage."""
|
||||
if self.total_executions == 0:
|
||||
return 0.0
|
||||
return (self.failed_executions / self.total_executions) * 100
|
||||
|
||||
|
||||
ExecutionStatus = AgentExecutionStatus
|
||||
|
||||
|
||||
@@ -347,6 +364,7 @@ class NodeExecutionResult(BaseModel):
|
||||
|
||||
|
||||
async def get_graph_executions(
|
||||
graph_exec_id: str | None = None,
|
||||
graph_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
statuses: list[ExecutionStatus] | None = None,
|
||||
@@ -354,9 +372,12 @@ async def get_graph_executions(
|
||||
created_time_lte: datetime | None = None,
|
||||
limit: int | None = None,
|
||||
) -> list[GraphExecutionMeta]:
|
||||
"""⚠️ **Optional `user_id` check**: MUST USE check in user-facing endpoints."""
|
||||
where_filter: AgentGraphExecutionWhereInput = {
|
||||
"isDeleted": False,
|
||||
}
|
||||
if graph_exec_id:
|
||||
where_filter["id"] = graph_exec_id
|
||||
if user_id:
|
||||
where_filter["userId"] = user_id
|
||||
if graph_id:
|
||||
@@ -588,12 +609,10 @@ async def update_graph_execution_start_time(
|
||||
|
||||
async def update_graph_execution_stats(
|
||||
graph_exec_id: str,
|
||||
status: ExecutionStatus,
|
||||
status: ExecutionStatus | None = None,
|
||||
stats: GraphExecutionStats | None = None,
|
||||
) -> GraphExecution | None:
|
||||
update_data: AgentGraphExecutionUpdateManyMutationInput = {
|
||||
"executionStatus": status
|
||||
}
|
||||
update_data: AgentGraphExecutionUpdateManyMutationInput = {}
|
||||
|
||||
if stats:
|
||||
stats_dict = stats.model_dump()
|
||||
@@ -601,6 +620,9 @@ async def update_graph_execution_stats(
|
||||
stats_dict["error"] = str(stats_dict["error"])
|
||||
update_data["stats"] = Json(stats_dict)
|
||||
|
||||
if status:
|
||||
update_data["executionStatus"] = status
|
||||
|
||||
updated_count = await AgentGraphExecution.prisma().update_many(
|
||||
where={
|
||||
"id": graph_exec_id,
|
||||
@@ -716,6 +738,7 @@ async def delete_graph_execution(
|
||||
|
||||
|
||||
async def get_node_execution(node_exec_id: str) -> NodeExecutionResult | None:
|
||||
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
|
||||
execution = await AgentNodeExecution.prisma().find_first(
|
||||
where={"id": node_exec_id},
|
||||
include=EXECUTION_RESULT_INCLUDE,
|
||||
@@ -726,15 +749,19 @@ async def get_node_execution(node_exec_id: str) -> NodeExecutionResult | None:
|
||||
|
||||
|
||||
async def get_node_executions(
|
||||
graph_exec_id: str,
|
||||
graph_exec_id: str | None = None,
|
||||
node_id: str | None = None,
|
||||
block_ids: list[str] | None = None,
|
||||
statuses: list[ExecutionStatus] | None = None,
|
||||
limit: int | None = None,
|
||||
created_time_gte: datetime | None = None,
|
||||
created_time_lte: datetime | None = None,
|
||||
include_exec_data: bool = True,
|
||||
) -> list[NodeExecutionResult]:
|
||||
where_clause: AgentNodeExecutionWhereInput = {
|
||||
"agentGraphExecutionId": graph_exec_id,
|
||||
}
|
||||
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
|
||||
where_clause: AgentNodeExecutionWhereInput = {}
|
||||
if graph_exec_id:
|
||||
where_clause["agentGraphExecutionId"] = graph_exec_id
|
||||
if node_id:
|
||||
where_clause["agentNodeId"] = node_id
|
||||
if block_ids:
|
||||
@@ -742,9 +769,19 @@ async def get_node_executions(
|
||||
if statuses:
|
||||
where_clause["OR"] = [{"executionStatus": status} for status in statuses]
|
||||
|
||||
if created_time_gte or created_time_lte:
|
||||
where_clause["addedTime"] = {
|
||||
"gte": created_time_gte or datetime.min.replace(tzinfo=timezone.utc),
|
||||
"lte": created_time_lte or datetime.max.replace(tzinfo=timezone.utc),
|
||||
}
|
||||
|
||||
executions = await AgentNodeExecution.prisma().find_many(
|
||||
where=where_clause,
|
||||
include=EXECUTION_RESULT_INCLUDE,
|
||||
include=(
|
||||
EXECUTION_RESULT_INCLUDE
|
||||
if include_exec_data
|
||||
else {"Node": True, "GraphExecution": True}
|
||||
),
|
||||
order=EXECUTION_RESULT_ORDER,
|
||||
take=limit,
|
||||
)
|
||||
@@ -755,6 +792,7 @@ async def get_node_executions(
|
||||
async def get_latest_node_execution(
|
||||
node_id: str, graph_eid: str
|
||||
) -> NodeExecutionResult | None:
|
||||
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
|
||||
execution = await AgentNodeExecution.prisma().find_first(
|
||||
where={
|
||||
"agentGraphExecutionId": graph_eid,
|
||||
@@ -903,3 +941,87 @@ class AsyncRedisExecutionEventBus(AsyncRedisEventBus[ExecutionEvent]):
|
||||
) -> AsyncGenerator[ExecutionEvent, None]:
|
||||
async for event in self.listen_events(f"{user_id}/{graph_id}/{graph_exec_id}"):
|
||||
yield event
|
||||
|
||||
|
||||
# --------------------- KV Data Functions --------------------- #
|
||||
|
||||
|
||||
async def get_execution_kv_data(user_id: str, key: str) -> Any | None:
|
||||
"""
|
||||
Get key-value data for a user and key.
|
||||
|
||||
Args:
|
||||
user_id: The id of the User.
|
||||
key: The key to retrieve data for.
|
||||
|
||||
Returns:
|
||||
The data associated with the key, or None if not found.
|
||||
"""
|
||||
kv_data = await AgentNodeExecutionKeyValueData.prisma().find_unique(
|
||||
where={"userId_key": {"userId": user_id, "key": key}}
|
||||
)
|
||||
return (
|
||||
type_utils.convert(kv_data.data, type[Any])
|
||||
if kv_data and kv_data.data
|
||||
else None
|
||||
)
|
||||
|
||||
|
||||
async def set_execution_kv_data(
|
||||
user_id: str, node_exec_id: str, key: str, data: Any
|
||||
) -> Any | None:
|
||||
"""
|
||||
Set key-value data for a user and key.
|
||||
|
||||
Args:
|
||||
user_id: The id of the User.
|
||||
node_exec_id: The id of the AgentNodeExecution.
|
||||
key: The key to store data under.
|
||||
data: The data to store.
|
||||
"""
|
||||
resp = await AgentNodeExecutionKeyValueData.prisma().upsert(
|
||||
where={"userId_key": {"userId": user_id, "key": key}},
|
||||
data={
|
||||
"create": AgentNodeExecutionKeyValueDataCreateInput(
|
||||
userId=user_id,
|
||||
agentNodeExecutionId=node_exec_id,
|
||||
key=key,
|
||||
data=Json(data) if data is not None else None,
|
||||
),
|
||||
"update": {
|
||||
"agentNodeExecutionId": node_exec_id,
|
||||
"data": Json(data) if data is not None else None,
|
||||
},
|
||||
},
|
||||
)
|
||||
return type_utils.convert(resp.data, type[Any]) if resp and resp.data else None
|
||||
|
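For illustration, a round-trip sketch for the KV helpers above. The upsert keys on
(userId, key), so a second set for the same key overwrites the stored value; the IDs
below are placeholders:

async def kv_roundtrip_example() -> None:
    user_id, node_exec_id = "user-123", "node-exec-456"  # placeholder IDs
    await set_execution_kv_data(user_id, node_exec_id, "counter", {"value": 1})
    assert await get_execution_kv_data(user_id, "counter") == {"value": 1}
    # Same (user_id, key): the upsert's update path replaces the value in place.
    await set_execution_kv_data(user_id, node_exec_id, "counter", {"value": 2})
    assert await get_execution_kv_data(user_id, "counter") == {"value": 2}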
+async def get_block_error_stats(
+    start_time: datetime, end_time: datetime
+) -> list[BlockErrorStats]:
+    """Get block execution stats using efficient SQL aggregation."""
+
+    query_template = """
+        SELECT
+            n."agentBlockId" as block_id,
+            COUNT(*) as total_executions,
+            SUM(CASE WHEN ne."executionStatus" = 'FAILED' THEN 1 ELSE 0 END) as failed_executions
+        FROM {schema_prefix}"AgentNodeExecution" ne
+        JOIN {schema_prefix}"AgentNode" n ON ne."agentNodeId" = n.id
+        WHERE ne."addedTime" >= $1::timestamp AND ne."addedTime" <= $2::timestamp
+        GROUP BY n."agentBlockId"
+        HAVING COUNT(*) >= 10
+    """
+
+    result = await query_raw_with_schema(query_template, start_time, end_time)
+
+    # Convert to typed data structures
+    return [
+        BlockErrorStats(
+            block_id=row["block_id"],
+            total_executions=int(row["total_executions"]),
+            failed_executions=int(row["failed_executions"]),
+        )
+        for row in result
+    ]
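For illustration, a consumption sketch for the stats query above, e.g. flagging blocks
whose failure rate crosses a threshold. The 24-hour window and 25% cutoff are
illustrative choices, not values from the repo:

from datetime import datetime, timedelta, timezone

async def find_flaky_blocks(threshold_pct: float = 25.0) -> list[str]:
    end = datetime.now(timezone.utc)
    stats = await get_block_error_stats(end - timedelta(hours=24), end)
    # error_rate is the computed percentage property on BlockErrorStats.
    return [s.block_id for s in stats if s.error_rate >= threshold_pct]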
@@ -3,7 +3,6 @@ import uuid
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Literal, Optional, cast

-import prisma
from prisma import Json
from prisma.enums import SubmissionStatus
from prisma.models import AgentGraph, AgentNode, AgentNodeLink, StoreListingVersion
@@ -14,7 +13,7 @@ from prisma.types import (
    AgentNodeLinkCreateInput,
    StoreListingVersionWhereInput,
)
-from pydantic import JsonValue, create_model
+from pydantic import Field, JsonValue, create_model
from pydantic.fields import computed_field

from backend.blocks.agent import AgentExecutorBlock
@@ -27,10 +26,11 @@ from backend.data.model import (
    CredentialsMetaInput,
    is_credentials_field_name,
)
+from backend.integrations.providers import ProviderName
from backend.util import type as type_utils

from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
-from .db import BaseDbModel, transaction
+from .db import BaseDbModel, query_raw_with_schema, transaction
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE

if TYPE_CHECKING:

@@ -188,6 +188,23 @@ class BaseGraph(BaseDbModel):
        )
    )

+    @computed_field
+    @property
+    def has_external_trigger(self) -> bool:
+        return self.webhook_input_node is not None
+
+    @property
+    def webhook_input_node(self) -> Node | None:
+        return next(
+            (
+                node
+                for node in self.nodes
+                if node.block.block_type
+                in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
+            ),
+            None,
+        )
+
    @staticmethod
    def _generate_schema(
        *props: tuple[type[AgentInputBlock.Input] | type[AgentOutputBlock.Input], dict],

@@ -243,6 +260,8 @@ class Graph(BaseGraph):
        for other_field, other_keys in list(graph_cred_fields)[i + 1 :]:
            if field.provider != other_field.provider:
                continue
+            if ProviderName.HTTP in field.provider:
+                continue

            # If this happens, that means a block implementation probably needs
            # to be updated.
@@ -264,6 +283,7 @@ class Graph(BaseGraph):
                required_scopes=set(field_info.required_scopes or []),
                discriminator=field_info.discriminator,
                discriminator_mapping=field_info.discriminator_mapping,
+                discriminator_values=field_info.discriminator_values,
            ),
        )
        for agg_field_key, (field_info, _) in graph_credentials_inputs.items()
@@ -282,48 +302,46 @@ class Graph(BaseGraph):
        Returns:
            dict[aggregated_field_key, tuple(
                CredentialsFieldInfo: A spec for one aggregated credentials field
+                    (now includes discriminator_values from matching nodes)
                set[(node_id, field_name)]: Node credentials fields that are
                    compatible with this aggregated field spec
            )]
        """
-        return {
-            "_".join(sorted(agg_field_info.provider))
-            + "_"
-            + "_".join(sorted(agg_field_info.supported_types))
-            + "_credentials": (agg_field_info, node_fields)
-            for agg_field_info, node_fields in CredentialsFieldInfo.combine(
-                *(
-                    (
-                        # Apply discrimination before aggregating credentials inputs
-                        (
-                            field_info.discriminate(
-                                node.input_default[field_info.discriminator]
-                            )
-                            if (
-                                field_info.discriminator
-                                and node.input_default.get(field_info.discriminator)
-                            )
-                            else field_info
-                        ),
-                        (node.id, field_name),
-                    )
-                    for graph in [self] + self.sub_graphs
-                    for node in graph.nodes
-                    for field_name, field_info in node.block.input_schema.get_credentials_fields_info().items()
-                )
-            )
-        }
+        # First collect all credential field data with input defaults
+        node_credential_data = []
+
+        for graph in [self] + self.sub_graphs:
+            for node in graph.nodes:
+                for (
+                    field_name,
+                    field_info,
+                ) in node.block.input_schema.get_credentials_fields_info().items():
+
+                    discriminator = field_info.discriminator
+                    if not discriminator:
+                        node_credential_data.append((field_info, (node.id, field_name)))
+                        continue
+
+                    discriminator_value = node.input_default.get(discriminator)
+                    if discriminator_value is None:
+                        node_credential_data.append((field_info, (node.id, field_name)))
+                        continue
+
+                    discriminated_info = field_info.discriminate(discriminator_value)
+                    discriminated_info.discriminator_values.add(discriminator_value)
+
+                    node_credential_data.append(
+                        (discriminated_info, (node.id, field_name))
+                    )
+
+        # Combine credential field info (this will merge discriminator_values automatically)
+        return CredentialsFieldInfo.combine(*node_credential_data)


class GraphModel(Graph):
    user_id: str
    nodes: list[NodeModel] = []  # type: ignore

-    @computed_field
-    @property
-    def has_webhook_trigger(self) -> bool:
-        return self.webhook_input_node is not None
-
    @property
    def starting_nodes(self) -> list[NodeModel]:
        outbound_nodes = {link.sink_id for link in self.links}
@@ -336,17 +354,12 @@ class GraphModel(Graph):
        if node.id not in outbound_nodes or node.id in input_nodes
    ]

-    @property
-    def webhook_input_node(self) -> NodeModel | None:
-        return next(
-            (
-                node
-                for node in self.nodes
-                if node.block.block_type
-                in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
-            ),
-            None,
-        )
+    def meta(self) -> "GraphMeta":
+        """
+        Returns a GraphMeta object with metadata about the graph.
+        This is used to return metadata about the graph without exposing nodes and links.
+        """
+        return GraphMeta.from_graph(self)

    def reassign_ids(self, user_id: str, reassign_graph_id: bool = False):
        """
@@ -382,8 +395,10 @@ class GraphModel(Graph):

        # Reassign Link IDs
        for link in graph.links:
-            link.source_id = id_map[link.source_id]
-            link.sink_id = id_map[link.sink_id]
+            if link.source_id in id_map:
+                link.source_id = id_map[link.source_id]
+            if link.sink_id in id_map:
+                link.sink_id = id_map[link.sink_id]

        # Reassign User IDs for agent blocks
        for node in graph.nodes:
@@ -391,7 +406,9 @@ class GraphModel(Graph):
                continue
            node.input_default["user_id"] = user_id
            node.input_default.setdefault("inputs", {})
-            if (graph_id := node.input_default.get("graph_id")) in graph_id_map:
+            if (
+                graph_id := node.input_default.get("graph_id")
+            ) and graph_id in graph_id_map:
                node.input_default["graph_id"] = graph_id_map[graph_id]

    def validate_graph(
@@ -601,6 +618,18 @@ class GraphModel(Graph):
        )


+class GraphMeta(Graph):
+    user_id: str
+
+    # Easy work-around to prevent exposing nodes and links in the API response
+    nodes: list[NodeModel] = Field(default=[], exclude=True)  # type: ignore
+    links: list[Link] = Field(default=[], exclude=True)
+
+    @staticmethod
+    def from_graph(graph: GraphModel) -> "GraphMeta":
+        return GraphMeta(**graph.model_dump())
+
+
# --------------------- CRUD functions --------------------- #

@@ -629,10 +658,10 @@ async def set_node_webhook(node_id: str, webhook_id: str | None) -> NodeModel:
    return NodeModel.from_db(node)


-async def get_graphs(
+async def list_graphs(
    user_id: str,
    filter_by: Literal["active"] | None = "active",
-) -> list[GraphModel]:
+) -> list[GraphMeta]:
    """
    Retrieves graph metadata objects.
    Default behaviour is to get all currently active graphs.
@@ -642,7 +671,7 @@ async def get_graphs(
        user_id: The ID of the user that owns the graph.

    Returns:
-        list[GraphModel]: A list of objects representing the retrieved graphs.
+        list[GraphMeta]: A list of objects representing the retrieved graphs.
    """
    where_clause: AgentGraphWhereInput = {"userId": user_id}

@@ -656,13 +685,13 @@ async def get_graphs(
        include=AGENT_GRAPH_INCLUDE,
    )

-    graph_models = []
+    graph_models: list[GraphMeta] = []
    for graph in graphs:
        try:
-            graph_model = GraphModel.from_db(graph)
-            # Trigger serialization to validate that the graph is well formed.
-            graph_model.model_dump()
-            graph_models.append(graph_model)
+            graph_meta = GraphModel.from_db(graph).meta()
+            # Trigger serialization to validate that the graph is well formed
+            graph_meta.model_dump()
+            graph_models.append(graph_meta)
        except Exception as e:
            logger.error(f"Error processing graph {graph.id}: {e}")
            continue

@@ -1029,13 +1058,13 @@ async def fix_llm_provider_credentials():

    broken_nodes = []
    try:
-        broken_nodes = await prisma.get_client().query_raw(
+        broken_nodes = await query_raw_with_schema(
            """
            SELECT graph."userId" user_id,
                   node.id node_id,
                   node."constantInput" node_preset_input
-            FROM platform."AgentNode" node
-            LEFT JOIN platform."AgentGraph" graph
+            FROM {schema_prefix}"AgentNode" node
+            LEFT JOIN {schema_prefix}"AgentGraph" graph
            ON node."agentGraphId" = graph.id
            WHERE node."constantInput"::jsonb->'credentials'->>'provider' = 'llm'
            ORDER BY graph."userId";

@@ -11,8 +11,8 @@ from prisma.types import (
)
from pydantic import Field, computed_field

+from backend.data.event_bus import AsyncRedisEventBus
from backend.data.includes import INTEGRATION_WEBHOOK_INCLUDE
-from backend.data.queue import AsyncRedisEventBus
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks.utils import webhook_ingress_url
from backend.server.v2.library.model import LibraryAgentPreset
@@ -77,10 +77,6 @@ class WebhookWithRelations(Webhook):
    )


-# Fix Webhook <- NodeModel relations
-NodeModel.model_rebuild()
-
-
# --------------------- CRUD functions --------------------- #

@@ -14,11 +14,12 @@ from typing import (
    Generic,
    Literal,
    Optional,
-    Sequence,
+    TypedDict,
    TypeVar,
    cast,
    get_args,
)
+from urllib.parse import urlparse
from uuid import uuid4

from prisma.enums import CreditTransactionType
@@ -41,6 +42,9 @@ from pydantic_core import (
from backend.integrations.providers import ProviderName
from backend.util.settings import Secrets

+# Type alias for any provider name (including custom ones)
+AnyProviderName = str  # Will be validated as ProviderName at runtime
+
if TYPE_CHECKING:
    from backend.data.block import BlockSchema

@@ -240,13 +244,65 @@ class UserPasswordCredentials(_BaseCredentials):
        return f"Basic {base64.b64encode(f'{self.username.get_secret_value()}:{self.password.get_secret_value()}'.encode()).decode()}"


+class HostScopedCredentials(_BaseCredentials):
+    type: Literal["host_scoped"] = "host_scoped"
+    host: str = Field(description="The host/URI pattern to match against request URLs")
+    headers: dict[str, SecretStr] = Field(
+        description="Key-value header map to add to matching requests",
+        default_factory=dict,
+    )
+
+    def _extract_headers(self, headers: dict[str, SecretStr]) -> dict[str, str]:
+        """Helper to extract secret values from headers."""
+        return {key: value.get_secret_value() for key, value in headers.items()}
+
+    @field_serializer("headers")
+    def serialize_headers(self, headers: dict[str, SecretStr]) -> dict[str, str]:
+        """Serialize headers by extracting secret values."""
+        return self._extract_headers(headers)
+
+    def get_headers_dict(self) -> dict[str, str]:
+        """Get headers with secret values extracted."""
+        return self._extract_headers(self.headers)
+
+    def auth_header(self) -> str:
+        """Get authorization header for backward compatibility."""
+        auth_headers = self.get_headers_dict()
+        if "Authorization" in auth_headers:
+            return auth_headers["Authorization"]
+        return ""
+
+    def matches_url(self, url: str) -> bool:
+        """Check if this credential should be applied to the given URL."""
+        parsed_url = urlparse(url)
+        # Extract hostname without port
+        request_host = parsed_url.hostname
+        if not request_host:
+            return False
+
+        # Simple host matching - exact match or wildcard subdomain match
+        if self.host == request_host:
+            return True
+
+        # Support wildcard matching (e.g., "*.example.com" matches "api.example.com")
+        if self.host.startswith("*."):
+            domain = self.host[2:]  # Remove "*."
+            return request_host.endswith(f".{domain}") or request_host == domain
+
+        return False
|
||||
OAuth2Credentials | APIKeyCredentials | UserPasswordCredentials,
|
||||
OAuth2Credentials
|
||||
| APIKeyCredentials
|
||||
| UserPasswordCredentials
|
||||
| HostScopedCredentials,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
CredentialsType = Literal["api_key", "oauth2", "user_password"]
|
||||
CredentialsType = Literal["api_key", "oauth2", "user_password", "host_scoped"]
|
||||
|
||||
|
||||
class OAuthState(BaseModel):
|
||||
@@ -288,7 +344,7 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
type: CT
|
||||
|
||||
@classmethod
|
||||
def allowed_providers(cls) -> tuple[ProviderName, ...]:
|
||||
def allowed_providers(cls) -> tuple[ProviderName, ...] | None:
|
||||
return get_args(cls.model_fields["provider"].annotation)
|
||||
|
||||
@classmethod
|
||||
@@ -313,22 +369,46 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
f"{field_schema}"
|
||||
) from e
|
||||
|
||||
if len(cls.allowed_providers()) > 1 and not schema_extra.discriminator:
|
||||
providers = cls.allowed_providers()
|
||||
if (
|
||||
providers is not None
|
||||
and len(providers) > 1
|
||||
and not schema_extra.discriminator
|
||||
):
|
||||
raise TypeError(
|
||||
f"Multi-provider CredentialsField '{field_name}' "
|
||||
"requires discriminator!"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _add_json_schema_extra(schema, cls: CredentialsMetaInput):
|
||||
schema["credentials_provider"] = cls.allowed_providers()
|
||||
schema["credentials_types"] = cls.allowed_cred_types()
|
||||
def _add_json_schema_extra(schema: dict, model_class: type):
|
||||
# Use model_class for allowed_providers/cred_types
|
||||
if hasattr(model_class, "allowed_providers") and hasattr(
|
||||
model_class, "allowed_cred_types"
|
||||
):
|
||||
allowed_providers = model_class.allowed_providers()
|
||||
# If no specific providers (None), allow any string
|
||||
if allowed_providers is None:
|
||||
schema["credentials_provider"] = ["string"] # Allow any string provider
|
||||
else:
|
||||
schema["credentials_provider"] = allowed_providers
|
||||
schema["credentials_types"] = model_class.allowed_cred_types()
|
||||
# Do not return anything, just mutate schema in place
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra=_add_json_schema_extra, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
def _extract_host_from_url(url: str) -> str:
|
||||
"""Extract host from URL for grouping host-scoped credentials."""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
return parsed.hostname or url
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
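For illustration, the helper's fallback behavior with made-up inputs; a string urlparse
cannot extract a hostname from is returned unchanged rather than raising:

assert _extract_host_from_url("https://api.example.com:8443/v1") == "api.example.com"
assert _extract_host_from_url("not a url") == "not a url"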
||||
class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
# TODO: move discrimination mechanism out of CredentialsField (frontend + backend)
|
||||
provider: frozenset[CP] = Field(..., alias="credentials_provider")
|
||||
@@ -336,11 +416,12 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
required_scopes: Optional[frozenset[str]] = Field(None, alias="credentials_scopes")
|
||||
discriminator: Optional[str] = None
|
||||
discriminator_mapping: Optional[dict[str, CP]] = None
|
||||
discriminator_values: set[Any] = Field(default_factory=set)
|
||||
|
||||
@classmethod
|
||||
def combine(
|
||||
cls, *fields: tuple[CredentialsFieldInfo[CP, CT], T]
|
||||
) -> Sequence[tuple[CredentialsFieldInfo[CP, CT], set[T]]]:
|
||||
) -> dict[str, tuple[CredentialsFieldInfo[CP, CT], set[T]]]:
|
||||
"""
|
||||
Combines multiple CredentialsFieldInfo objects into as few as possible.
|
||||
|
||||
@@ -358,22 +439,36 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
        the set of keys of the respective original items that were grouped together.
        """
        if not fields:
            return []
            return {}

        # Group fields by their provider and supported_types
        # For HTTP host-scoped credentials, also group by host
        grouped_fields: defaultdict[
            tuple[frozenset[CP], frozenset[CT]],
            list[tuple[T, CredentialsFieldInfo[CP, CT]]],
        ] = defaultdict(list)

        for field, key in fields:
            group_key = (frozenset(field.provider), frozenset(field.supported_types))
            if field.provider == frozenset([ProviderName.HTTP]):
                # HTTP host-scoped credentials can have different hosts that require different credential sets.
                # Group by host extracted from the URL
                providers = frozenset(
                    [cast(CP, "http")]
                    + [
                        cast(CP, _extract_host_from_url(str(value)))
                        for value in field.discriminator_values
                    ]
                )
            else:
                providers = frozenset(field.provider)

            group_key = (providers, frozenset(field.supported_types))
            grouped_fields[group_key].append((key, field))

        # Combine fields within each group
        result: list[tuple[CredentialsFieldInfo[CP, CT], set[T]]] = []
        result: dict[str, tuple[CredentialsFieldInfo[CP, CT], set[T]]] = {}

        for group in grouped_fields.values():
        for key, group in grouped_fields.items():
            # Start with the first field in the group
            _, combined = group[0]

@@ -386,18 +481,32 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
                if field.required_scopes:
                    all_scopes.update(field.required_scopes)

            # Create a new combined field
            result.append(
                (
                    CredentialsFieldInfo[CP, CT](
                        credentials_provider=combined.provider,
                        credentials_types=combined.supported_types,
                        credentials_scopes=frozenset(all_scopes) or None,
                        discriminator=combined.discriminator,
                        discriminator_mapping=combined.discriminator_mapping,
                    ),
                    combined_keys,
                )
            # Combine discriminator_values from all fields in the group (removing duplicates)
            all_discriminator_values = []
            for _, field in group:
                for value in field.discriminator_values:
                    if value not in all_discriminator_values:
                        all_discriminator_values.append(value)

            # Generate the key for the combined result
            providers_key, supported_types_key = key
            group_key = (
                "-".join(sorted(providers_key))
                + "_"
                + "-".join(sorted(supported_types_key))
                + "_credentials"
            )

            result[group_key] = (
                CredentialsFieldInfo[CP, CT](
                    credentials_provider=combined.provider,
                    credentials_types=combined.supported_types,
                    credentials_scopes=frozenset(all_scopes) or None,
                    discriminator=combined.discriminator,
                    discriminator_mapping=combined.discriminator_mapping,
                    discriminator_values=set(all_discriminator_values),
                ),
                combined_keys,
            )

        return result
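A minimal sketch of the new `combine()` contract (field values are illustrative, not from the diff): two HTTP host-scoped fields pointing at the same host collapse into one entry whose dict key encodes the grouped provider(s) and credential type(s).

```python
# Hypothetical usage of CredentialsFieldInfo.combine after this change.
a = CredentialsFieldInfo(
    credentials_provider=frozenset(["http"]),
    credentials_types=frozenset(["host_scoped"]),
    discriminator="url",
    discriminator_values={"https://api.example.com/v1"},
)
b = a.model_copy(update={"discriminator_values": {"https://api.example.com/v2"}})

combined = CredentialsFieldInfo.combine((a, "field_1"), (b, "field_2"))
# One group, keyed like "api.example.com-http_host_scoped_credentials",
# mapping to (merged_field_info, {"field_1", "field_2"}); the merged info
# carries both URLs in discriminator_values.
```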
@@ -406,11 +515,15 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
        if not (self.discriminator and self.discriminator_mapping):
            return self

        discriminator_value = self.discriminator_mapping[discriminator_value]
        return CredentialsFieldInfo(
            credentials_provider=frozenset([discriminator_value]),
            credentials_provider=frozenset(
                [self.discriminator_mapping[discriminator_value]]
            ),
            credentials_types=self.supported_types,
            credentials_scopes=self.required_scopes,
            discriminator=self.discriminator,
            discriminator_mapping=self.discriminator_mapping,
            discriminator_values=self.discriminator_values,
        )

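Sketch of `discriminate()` after this change (illustrative values, not from the diff): the mapping lookup is now applied inline, and `discriminator_values` is carried over to the narrowed result.

```python
# Hypothetical example: narrowing a multi-provider field by discriminator value.
info = CredentialsFieldInfo(
    credentials_provider=frozenset(["openai", "anthropic"]),
    credentials_types=frozenset(["api_key"]),
    discriminator="model",
    discriminator_mapping={"gpt-4": "openai", "claude-3": "anthropic"},
)
narrowed = info.discriminate("gpt-4")
assert narrowed.provider == frozenset(["openai"])
```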
@@ -419,6 +532,7 @@ def CredentialsField(
    *,
    discriminator: Optional[str] = None,
    discriminator_mapping: Optional[dict[str, Any]] = None,
    discriminator_values: Optional[set[Any]] = None,
    title: Optional[str] = None,
    description: Optional[str] = None,
    **kwargs,
@@ -434,10 +548,16 @@ def CredentialsField(
            "credentials_scopes": list(required_scopes) or None,
            "discriminator": discriminator,
            "discriminator_mapping": discriminator_mapping,
            "discriminator_values": discriminator_values,
        }.items()
        if v is not None
    }

    # Merge any json_schema_extra passed in kwargs
    if "json_schema_extra" in kwargs:
        extra_schema = kwargs.pop("json_schema_extra")
        field_schema_extra.update(extra_schema)

    return Field(
        title=title,
        description=description,
@@ -516,6 +636,35 @@ class NodeExecutionStats(BaseModel):
    llm_retry_count: int = 0
    input_token_count: int = 0
    output_token_count: int = 0
    extra_cost: int = 0
    extra_steps: int = 0

    def __iadd__(self, other: "NodeExecutionStats") -> "NodeExecutionStats":
        """Mutate this instance by adding another NodeExecutionStats."""
        if not isinstance(other, NodeExecutionStats):
            return NotImplemented

        stats_dict = other.model_dump()
        current_stats = self.model_dump()

        for key, value in stats_dict.items():
            if key not in current_stats:
                # Field doesn't exist yet, just set it
                setattr(self, key, value)
            elif isinstance(value, dict) and isinstance(current_stats[key], dict):
                current_stats[key].update(value)
                setattr(self, key, current_stats[key])
            elif isinstance(value, (int, float)) and isinstance(
                current_stats[key], (int, float)
            ):
                setattr(self, key, current_stats[key] + value)
            elif isinstance(value, list) and isinstance(current_stats[key], list):
                current_stats[key].extend(value)
                setattr(self, key, current_stats[key])
            else:
                setattr(self, key, value)

        return self


class GraphExecutionStats(BaseModel):
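How the new in-place addition is used (see `execute_node` further down, where `execution_stats += node_block.execution_stats` replaces the old `model_copy` overwrite). A sketch assuming the remaining fields default to zero:

```python
# Numeric fields are summed rather than overwritten.
a = NodeExecutionStats(input_token_count=10, extra_cost=5)
b = NodeExecutionStats(input_token_count=3, extra_steps=2)
a += b
assert a.input_token_count == 13  # 10 + 3
assert a.extra_cost == 5          # 5 + 0
assert a.extra_steps == 2         # 0 + 2
```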
autogpt_platform/backend/backend/data/model_test.py (new file, 143 lines)
@@ -0,0 +1,143 @@
import pytest
from pydantic import SecretStr

from backend.data.model import HostScopedCredentials


class TestHostScopedCredentials:
    def test_host_scoped_credentials_creation(self):
        """Test creating HostScopedCredentials with required fields."""
        creds = HostScopedCredentials(
            provider="custom",
            host="api.example.com",
            headers={
                "Authorization": SecretStr("Bearer secret-token"),
                "X-API-Key": SecretStr("api-key-123"),
            },
            title="Example API Credentials",
        )

        assert creds.type == "host_scoped"
        assert creds.provider == "custom"
        assert creds.host == "api.example.com"
        assert creds.title == "Example API Credentials"
        assert len(creds.headers) == 2
        assert "Authorization" in creds.headers
        assert "X-API-Key" in creds.headers

    def test_get_headers_dict(self):
        """Test getting headers with secret values extracted."""
        creds = HostScopedCredentials(
            provider="custom",
            host="api.example.com",
            headers={
                "Authorization": SecretStr("Bearer secret-token"),
                "X-Custom-Header": SecretStr("custom-value"),
            },
        )

        headers_dict = creds.get_headers_dict()

        assert headers_dict == {
            "Authorization": "Bearer secret-token",
            "X-Custom-Header": "custom-value",
        }

    def test_matches_url_exact_host(self):
        """Test URL matching with exact host match."""
        creds = HostScopedCredentials(
            provider="custom",
            host="api.example.com",
            headers={"Authorization": SecretStr("Bearer token")},
        )

        assert creds.matches_url("https://api.example.com/v1/data")
        assert creds.matches_url("http://api.example.com/endpoint")
        assert not creds.matches_url("https://other.example.com/v1/data")
        assert not creds.matches_url("https://subdomain.api.example.com/v1/data")

    def test_matches_url_wildcard_subdomain(self):
        """Test URL matching with wildcard subdomain pattern."""
        creds = HostScopedCredentials(
            provider="custom",
            host="*.example.com",
            headers={"Authorization": SecretStr("Bearer token")},
        )

        assert creds.matches_url("https://api.example.com/v1/data")
        assert creds.matches_url("https://subdomain.example.com/endpoint")
        assert creds.matches_url("https://deep.nested.example.com/path")
        assert creds.matches_url("https://example.com/path")  # Base domain should match
        assert not creds.matches_url("https://example.org/v1/data")
        assert not creds.matches_url("https://notexample.com/v1/data")

    def test_matches_url_with_port_and_path(self):
        """Test URL matching with ports and paths."""
        creds = HostScopedCredentials(
            provider="custom",
            host="localhost",
            headers={"Authorization": SecretStr("Bearer token")},
        )

        assert creds.matches_url("http://localhost:8080/api/v1")
        assert creds.matches_url("https://localhost:443/secure/endpoint")
        assert creds.matches_url("http://localhost/simple")

    def test_empty_headers_dict(self):
        """Test HostScopedCredentials with empty headers."""
        creds = HostScopedCredentials(
            provider="custom", host="api.example.com", headers={}
        )

        assert creds.get_headers_dict() == {}
        assert creds.matches_url("https://api.example.com/test")

    def test_credential_serialization(self):
        """Test that credentials can be serialized/deserialized properly."""
        original_creds = HostScopedCredentials(
            provider="custom",
            host="api.example.com",
            headers={
                "Authorization": SecretStr("Bearer secret-token"),
                "X-API-Key": SecretStr("api-key-123"),
            },
            title="Test Credentials",
        )

        # Serialize to dict (simulating storage)
        serialized = original_creds.model_dump()

        # Deserialize back
        restored_creds = HostScopedCredentials.model_validate(serialized)

        assert restored_creds.id == original_creds.id
        assert restored_creds.provider == original_creds.provider
        assert restored_creds.host == original_creds.host
        assert restored_creds.title == original_creds.title
        assert restored_creds.type == "host_scoped"

        # Check that headers are properly restored
        assert restored_creds.get_headers_dict() == original_creds.get_headers_dict()

    @pytest.mark.parametrize(
        "host,test_url,expected",
        [
            ("api.example.com", "https://api.example.com/test", True),
            ("api.example.com", "https://different.example.com/test", False),
            ("*.example.com", "https://api.example.com/test", True),
            ("*.example.com", "https://sub.api.example.com/test", True),
            ("*.example.com", "https://example.com/test", True),
            ("*.example.com", "https://example.org/test", False),
            ("localhost", "http://localhost:3000/test", True),
            ("localhost", "http://127.0.0.1:3000/test", False),
        ],
    )
    def test_url_matching_parametrized(self, host: str, test_url: str, expected: bool):
        """Parametrized test for various URL matching scenarios."""
        creds = HostScopedCredentials(
            provider="test",
            host=host,
            headers={"Authorization": SecretStr("Bearer token")},
        )

        assert creds.matches_url(test_url) == expected
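The tests above pin down the `matches_url` semantics: exact-host equality, `*.` wildcards that also cover the base domain, and port/path insensitivity. A minimal implementation consistent with them (a sketch only; the real method lives in `backend.data.model` and is not shown in this diff):

```python
from urllib.parse import urlparse

def matches_url(self, url: str) -> bool:
    hostname = urlparse(url).hostname  # strips scheme, port, and path
    if not hostname:
        return False
    if self.host.startswith("*."):
        base = self.host[2:]
        # "*.example.com" matches the base domain and any depth of subdomain
        return hostname == base or hostname.endswith("." + base)
    return hostname == self.host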
@@ -5,12 +5,15 @@ from backend.data import db
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.execution import (
    create_graph_execution,
    get_block_error_stats,
    get_execution_kv_data,
    get_graph_execution,
    get_graph_execution_meta,
    get_graph_executions,
    get_latest_node_execution,
    get_node_execution,
    get_node_executions,
    set_execution_kv_data,
    update_graph_execution_start_time,
    update_graph_execution_stats,
    update_node_execution_stats,
@@ -101,6 +104,9 @@ class DatabaseManager(AppService):
    update_node_execution_stats = _(update_node_execution_stats)
    upsert_execution_input = _(upsert_execution_input)
    upsert_execution_output = _(upsert_execution_output)
    get_execution_kv_data = _(get_execution_kv_data)
    set_execution_kv_data = _(set_execution_kv_data)
    get_block_error_stats = _(get_block_error_stats)

    # Graphs
    get_node = _(get_node)
@@ -159,6 +165,8 @@ class DatabaseManagerClient(AppServiceClient):
    update_node_execution_stats = _(d.update_node_execution_stats)
    upsert_execution_input = _(d.upsert_execution_input)
    upsert_execution_output = _(d.upsert_execution_output)
    get_execution_kv_data = _(d.get_execution_kv_data)
    set_execution_kv_data = _(d.set_execution_kv_data)

    # Graphs
    get_node = _(d.get_node)
@@ -193,6 +201,9 @@ class DatabaseManagerClient(AppServiceClient):
        d.get_user_notification_oldest_message_in_batch
    )

    # Block error monitoring
    get_block_error_stats = _(d.get_block_error_stats)


class DatabaseManagerAsyncClient(AppServiceClient):
    d = DatabaseManager
@@ -202,8 +213,11 @@ class DatabaseManagerAsyncClient(AppServiceClient):
        return DatabaseManager

    create_graph_execution = d.create_graph_execution
    get_connected_output_nodes = d.get_connected_output_nodes
    get_latest_node_execution = d.get_latest_node_execution
    get_graph = d.get_graph
    get_graph_metadata = d.get_graph_metadata
    get_graph_execution_meta = d.get_graph_execution_meta
    get_node = d.get_node
    get_node_execution = d.get_node_execution
    get_node_executions = d.get_node_executions
@@ -215,3 +229,6 @@ class DatabaseManagerAsyncClient(AppServiceClient):
    update_node_execution_status = d.update_node_execution_status
    update_node_execution_status_batch = d.update_node_execution_status_batch
    update_user_integrations = d.update_user_integrations
    get_execution_kv_data = d.get_execution_kv_data
    set_execution_kv_data = d.set_execution_kv_data
    get_block_error_stats = d.get_block_error_stats
@@ -24,7 +24,7 @@ from backend.data.notifications import (
    NotificationType,
)
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.utils import create_execution_queue_config
from backend.executor.utils import LogMetadata, create_execution_queue_config
from backend.notifications.notifications import queue_notification
from backend.util.exceptions import InsufficientBalanceError

@@ -35,7 +35,7 @@ from autogpt_libs.utils.cache import thread_cached
from prometheus_client import Gauge, start_http_server

from backend.blocks.agent import AgentExecutorBlock
from backend.data import redis
from backend.data import redis_client as redis
from backend.data.block import (
    BlockData,
    BlockInput,
@@ -98,35 +98,6 @@ utilization_gauge = Gauge(
)


class LogMetadata(TruncatedLogger):
    def __init__(
        self,
        user_id: str,
        graph_eid: str,
        graph_id: str,
        node_eid: str,
        node_id: str,
        block_name: str,
        max_length: int = 1000,
    ):
        metadata = {
            "component": "ExecutionManager",
            "user_id": user_id,
            "graph_eid": graph_eid,
            "graph_id": graph_id,
            "node_eid": node_eid,
            "node_id": node_id,
            "block_name": block_name,
        }
        prefix = f"[ExecutionManager|uid:{user_id}|gid:{graph_id}|nid:{node_id}]|geid:{graph_eid}|neid:{node_eid}|{block_name}]"
        super().__init__(
            _logger,
            max_length=max_length,
            prefix=prefix,
            metadata=metadata,
        )


T = TypeVar("T")
@@ -158,6 +129,7 @@ async def execute_node(
    node_block = node.block

    log_metadata = LogMetadata(
        logger=_logger,
        user_id=user_id,
        graph_eid=graph_exec_id,
        graph_id=graph_id,
@@ -235,9 +207,7 @@ async def execute_node(

    # Update execution stats
    if execution_stats is not None:
        execution_stats = execution_stats.model_copy(
            update=node_block.execution_stats.model_dump()
        )
        execution_stats += node_block.execution_stats
        execution_stats.input_size = input_size
        execution_stats.output_size = output_size

@@ -421,7 +391,7 @@ class Executor:
    """

    @classmethod
    @async_error_logged
    @async_error_logged(swallow=True)
    async def on_node_execution(
        cls,
        node_exec: NodeExecutionEntry,
@@ -429,6 +399,7 @@ class Executor:
        nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
    ) -> NodeExecutionStats:
        log_metadata = LogMetadata(
            logger=_logger,
            user_id=node_exec.user_id,
            graph_eid=node_exec.graph_exec_id,
            graph_id=node_exec.graph_id,
@@ -529,11 +500,12 @@ class Executor:
        logger.info(f"[GraphExecutor] {cls.pid} started")

    @classmethod
    @error_logged
    @error_logged(swallow=False)
    def on_graph_execution(
        cls, graph_exec: GraphExecutionEntry, cancel: threading.Event
    ):
        log_metadata = LogMetadata(
            logger=_logger,
            user_id=graph_exec.user_id,
            graph_eid=graph_exec.graph_exec_id,
            graph_id=graph_exec.graph_id,
@@ -581,6 +553,15 @@ class Executor:
        exec_stats.cputime += timing_info.cpu_time
        exec_stats.error = str(error) if error else exec_stats.error

        if status not in {
            ExecutionStatus.COMPLETED,
            ExecutionStatus.TERMINATED,
            ExecutionStatus.FAILED,
        }:
            raise RuntimeError(
                f"Graph Execution #{graph_exec.graph_exec_id} ended with unexpected status {status}"
            )

        if graph_exec_result := db_client.update_graph_execution_stats(
            graph_exec_id=graph_exec.graph_exec_id,
            status=status,
@@ -665,9 +646,10 @@ class Executor:
                return

            nonlocal execution_stats
            execution_stats.node_count += 1
            execution_stats.node_count += 1 + result.extra_steps
            execution_stats.nodes_cputime += result.cputime
            execution_stats.nodes_walltime += result.walltime
            execution_stats.cost += result.extra_cost
            if (err := result.error) and isinstance(err, Exception):
                execution_stats.node_error_count += 1
            update_node_execution_status(
@@ -684,7 +666,6 @@ class Executor:

            if _graph_exec := db_client.update_graph_execution_stats(
                graph_exec_id=graph_exec.graph_exec_id,
                status=execution_status,
                stats=execution_stats,
            ):
                send_execution_update(_graph_exec)
@@ -853,6 +834,7 @@ class Executor:
                f"Failed graph execution {graph_exec.graph_exec_id}: {error}"
            )
        finally:
            # Cancel and wait for all node executions to complete
            for node_id, inflight_exec in running_node_execution.items():
                if inflight_exec.is_done():
                    continue
@@ -865,6 +847,28 @@ class Executor:
                    log_metadata.info(f"Stopping node evaluation {node_id}")
                    inflight_eval.cancel()

            for node_id, inflight_exec in running_node_execution.items():
                if inflight_exec.is_done():
                    continue
                try:
                    inflight_exec.wait_for_cancellation(timeout=60.0)
                except TimeoutError:
                    log_metadata.exception(
                        f"Node execution #{node_id} did not stop in time, "
                        "it may be stuck or taking too long."
                    )

            for node_id, inflight_eval in running_node_evaluation.items():
                if inflight_eval.done():
                    continue
                try:
                    inflight_eval.result(timeout=60.0)
                except TimeoutError:
                    log_metadata.exception(
                        f"Node evaluation #{node_id} did not stop in time, "
                        "it may be stuck or taking too long."
                    )

            if execution_status in [ExecutionStatus.TERMINATED, ExecutionStatus.FAILED]:
                inflight_executions = db_client.get_node_executions(
                    graph_exec.graph_exec_id,
@@ -872,6 +876,7 @@ class Executor:
                        ExecutionStatus.QUEUED,
                        ExecutionStatus.RUNNING,
                    ],
                    include_exec_data=False,
                )
                db_client.update_node_execution_status_batch(
                    [node_exec.node_exec_id for node_exec in inflight_executions],
@@ -7,7 +7,8 @@ from prisma.models import User

import backend.server.v2.library.model
import backend.server.v2.store.model
from backend.blocks.basic import FindInDictionaryBlock, StoreValueBlock
from backend.blocks.basic import StoreValueBlock
from backend.blocks.data_manipulation import FindInDictionaryBlock
from backend.blocks.io import AgentInputBlock
from backend.blocks.maths import CalculatorBlock, Operation
from backend.data import execution, graph
@@ -1,8 +1,8 @@
import asyncio
import logging
import os
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Optional
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse

from apscheduler.events import EVENT_JOB_ERROR, EVENT_JOB_EXECUTED
@@ -13,22 +13,23 @@ from apscheduler.schedulers.blocking import BlockingScheduler
from apscheduler.triggers.cron import CronTrigger
from autogpt_libs.utils.cache import thread_cached
from dotenv import load_dotenv
from prisma.enums import NotificationType
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import MetaData, create_engine

from backend.data.block import BlockInput
from backend.data.execution import ExecutionStatus
from backend.data.execution import GraphExecutionWithNodes
from backend.data.model import CredentialsMetaInput
from backend.executor import utils as execution_utils
from backend.notifications.notifications import NotificationManagerClient
from backend.util.metrics import sentry_capture_error
from backend.util.service import (
    AppService,
    AppServiceClient,
    endpoint_to_async,
    expose,
    get_service_client,
from backend.monitoring import (
    NotificationJobArgs,
    process_existing_batches,
    process_weekly_summary,
    report_block_error_rates,
    report_late_executions,
)
from backend.util.exceptions import NotAuthorizedError, NotFoundError
from backend.util.logging import PrefixFilter
from backend.util.service import AppService, AppServiceClient, endpoint_to_async, expose
from backend.util.settings import Config


@@ -52,24 +53,19 @@ def _extract_schema_from_url(database_url) -> tuple[str, str]:


logger = logging.getLogger(__name__)
logger.addFilter(PrefixFilter("[Scheduler]"))
apscheduler_logger = logger.getChild("apscheduler")
apscheduler_logger.addFilter(PrefixFilter("[Scheduler] [APScheduler]"))

config = Config()


def log(msg, **kwargs):
    logger.info("[Scheduler] " + msg, **kwargs)


def job_listener(event):
    """Logs job execution outcomes for better monitoring."""
    if event.exception:
        log(f"Job {event.job_id} failed.")
        logger.error(f"Job {event.job_id} failed.")
    else:
        log(f"Job {event.job_id} completed successfully.")


@thread_cached
def get_notification_client():
    return get_service_client(NotificationManagerClient)
        logger.info(f"Job {event.job_id} completed successfully.")


@thread_cached
@@ -84,73 +80,23 @@ def execute_graph(**kwargs):
async def _execute_graph(**kwargs):
    args = GraphExecutionJobArgs(**kwargs)
    try:
        log(f"Executing recurring job for graph #{args.graph_id}")
        await execution_utils.add_graph_execution(
            graph_id=args.graph_id,
            inputs=args.input_data,
        logger.info(f"Executing recurring job for graph #{args.graph_id}")
        graph_exec: GraphExecutionWithNodes = await execution_utils.add_graph_execution(
            user_id=args.user_id,
            graph_id=args.graph_id,
            graph_version=args.graph_version,
            inputs=args.input_data,
            graph_credentials_inputs=args.input_credentials,
            use_db_query=False,
        )
    except Exception as e:
        logger.exception(f"Error executing graph {args.graph_id}: {e}")


class LateExecutionException(Exception):
    pass


def report_late_executions() -> str:
    late_executions = execution_utils.get_db_client().get_graph_executions(
        statuses=[ExecutionStatus.QUEUED],
        created_time_gte=datetime.now(timezone.utc)
        - timedelta(seconds=config.execution_late_notification_checkrange_secs),
        created_time_lte=datetime.now(timezone.utc)
        - timedelta(seconds=config.execution_late_notification_threshold_secs),
        limit=1000,
    )

    if not late_executions:
        return "No late executions detected."

    num_late_executions = len(late_executions)
    num_users = len(set([r.user_id for r in late_executions]))

    late_execution_details = [
        f"* `Execution ID: {exec.id}, Graph ID: {exec.graph_id}v{exec.graph_version}, User ID: {exec.user_id}, Created At: {exec.started_at.isoformat()}`"
        for exec in late_executions
    ]

    error = LateExecutionException(
        f"Late executions detected: {num_late_executions} late executions from {num_users} users "
        f"in the last {config.execution_late_notification_checkrange_secs} seconds. "
        f"Graph has been queued for more than {config.execution_late_notification_threshold_secs} seconds. "
        "Please check the executor status. Details:\n"
        + "\n".join(late_execution_details)
    )
    msg = str(error)
    sentry_capture_error(error)
    get_notification_client().discord_system_alert(msg)
    return msg


def process_existing_batches(**kwargs):
    args = NotificationJobArgs(**kwargs)
    try:
        log(
            f"Processing existing batches for notification type {args.notification_types}"
        logger.info(
            f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id}"
        )
        get_notification_client().process_existing_batches(args.notification_types)
    except Exception as e:
        logger.exception(f"Error processing existing batches: {e}")
        logger.error(f"Error executing graph {args.graph_id}: {e}")


def process_weekly_summary(**kwargs):
    try:
        log("Processing weekly summary")
        get_notification_client().queue_weekly_summary()
    except Exception as e:
        logger.exception(f"Error processing weekly summary: {e}")
# Monitoring functions are now imported from monitoring module


class Jobstores(Enum):
@@ -160,11 +106,12 @@ class Jobstores(Enum):


class GraphExecutionJobArgs(BaseModel):
    graph_id: str
    input_data: BlockInput
    user_id: str
    graph_id: str
    graph_version: int
    cron: str
    input_data: BlockInput
    input_credentials: dict[str, CredentialsMetaInput] = Field(default_factory=dict)


class GraphExecutionJobInfo(GraphExecutionJobArgs):
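Sketch of the expanded job payload (values illustrative, not from the diff): the scheduler now round-trips user, graph version, and credentials through APScheduler job kwargs.

```python
# Hypothetical construction of the new job args.
args = GraphExecutionJobArgs(
    user_id="user-123",
    graph_id="graph-abc",
    graph_version=3,
    cron="0 8 * * *",
    input_data={"topic": "daily news"},
    # input_credentials defaults to {} via Field(default_factory=dict)
)
job_kwargs = args.model_dump()                  # stored on the APScheduler job
restored = GraphExecutionJobArgs(**job_kwargs)  # rebuilt inside execute_graph
```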
@@ -184,11 +131,6 @@ class GraphExecutionJobInfo(GraphExecutionJobArgs):
        )


class NotificationJobArgs(BaseModel):
    notification_types: list[NotificationType]
    cron: str


class NotificationJobInfo(NotificationJobArgs):
    id: str
    name: str
@@ -247,7 +189,8 @@ class Scheduler(AppService):
                ),
                # These don't really need persistence
                Jobstores.WEEKLY_NOTIFICATIONS.value: MemoryJobStore(),
            }
            },
            logger=apscheduler_logger,
        )

        if self.register_system_tasks:
@@ -280,39 +223,55 @@ class Scheduler(AppService):
                jobstore=Jobstores.EXECUTION.value,
            )

            # Block Error Rate Monitoring
            self.scheduler.add_job(
                report_block_error_rates,
                id="report_block_error_rates",
                trigger="interval",
                replace_existing=True,
                seconds=config.block_error_rate_check_interval_secs,
                jobstore=Jobstores.EXECUTION.value,
            )

        self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
        self.scheduler.start()

    def cleanup(self):
        super().cleanup()
        logger.info(f"[{self.service_name}] ⏳ Shutting down scheduler...")
        logger.info("⏳ Shutting down scheduler...")
        if self.scheduler:
            self.scheduler.shutdown(wait=False)

    @expose
    def add_graph_execution_schedule(
        self,
        user_id: str,
        graph_id: str,
        graph_version: int,
        cron: str,
        input_data: BlockInput,
        user_id: str,
        input_credentials: dict[str, CredentialsMetaInput],
        name: Optional[str] = None,
    ) -> GraphExecutionJobInfo:
        job_args = GraphExecutionJobArgs(
            graph_id=graph_id,
            input_data=input_data,
            user_id=user_id,
            graph_id=graph_id,
            graph_version=graph_version,
            cron=cron,
            input_data=input_data,
            input_credentials=input_credentials,
        )
        job = self.scheduler.add_job(
            execute_graph,
            CronTrigger.from_crontab(cron),
            kwargs=job_args.model_dump(),
            replace_existing=True,
            name=name,
            trigger=CronTrigger.from_crontab(cron),
            jobstore=Jobstores.EXECUTION.value,
            replace_existing=True,
        )
        logger.info(
            f"Added job {job.id} with cron schedule '{cron}' input data: {input_data}"
        )
        log(f"Added job {job.id} with cron schedule '{cron}' input data: {input_data}")
        return GraphExecutionJobInfo.from_db(job_args, job)

    @expose
@@ -321,14 +280,13 @@ class Scheduler(AppService):
    ) -> GraphExecutionJobInfo:
        job = self.scheduler.get_job(schedule_id, jobstore=Jobstores.EXECUTION.value)
        if not job:
            log(f"Job {schedule_id} not found.")
            raise ValueError(f"Job #{schedule_id} not found.")
            raise NotFoundError(f"Job #{schedule_id} not found.")

        job_args = GraphExecutionJobArgs(**job.kwargs)
        if job_args.user_id != user_id:
            raise ValueError("User ID does not match the job's user ID.")
            raise NotAuthorizedError("User ID does not match the job's user ID")

        log(f"Deleting job {schedule_id}")
        logger.info(f"Deleting job {schedule_id}")
        job.remove()

        return GraphExecutionJobInfo.from_db(job_args, job)
@@ -367,6 +325,10 @@ class Scheduler(AppService):
    def execute_report_late_executions(self):
        return report_late_executions()

    @expose
    def execute_report_block_error_rates(self):
        return report_block_error_rates()


class SchedulerClient(AppServiceClient):
    @classmethod
@@ -27,6 +27,7 @@ async def test_agent_schedule(server: SpinTestServer):
        graph_version=1,
        cron="0 0 * * *",
        input_data={"input": "data"},
        input_credentials={},
    )
    assert schedule
@@ -1,5 +1,6 @@
import asyncio
import logging
import time
from collections import defaultdict
from concurrent.futures import Future
from typing import TYPE_CHECKING, Any, Callable, Optional, cast
@@ -7,6 +8,8 @@ from typing import TYPE_CHECKING, Any, Callable, Optional, cast
from autogpt_libs.utils.cache import thread_cached
from pydantic import BaseModel, JsonValue

from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data.block import (
    Block,
    BlockData,
@@ -23,12 +26,8 @@ from backend.data.execution import (
    GraphExecutionStats,
    GraphExecutionWithNodes,
    RedisExecutionEventBus,
    create_graph_execution,
    get_node_executions,
    update_graph_execution_stats,
    update_node_execution_status_batch,
)
from backend.data.graph import GraphModel, Node, get_graph
from backend.data.graph import GraphModel, Node
from backend.data.model import CredentialsMetaInput
from backend.data.rabbitmq import (
    AsyncRabbitMQ,
@@ -55,6 +54,36 @@ logger = TruncatedLogger(logging.getLogger(__name__), prefix="[GraphExecutorUtil
# ============ Resource Helpers ============ #


class LogMetadata(TruncatedLogger):
    def __init__(
        self,
        logger: logging.Logger,
        user_id: str,
        graph_eid: str,
        graph_id: str,
        node_eid: str,
        node_id: str,
        block_name: str,
        max_length: int = 1000,
    ):
        metadata = {
            "component": "ExecutionManager",
            "user_id": user_id,
            "graph_eid": graph_eid,
            "graph_id": graph_id,
            "node_eid": node_eid,
            "node_id": node_id,
            "block_name": block_name,
        }
        prefix = f"[ExecutionManager|uid:{user_id}|gid:{graph_id}|nid:{node_id}]|geid:{graph_eid}|neid:{node_eid}|{block_name}]"
        super().__init__(
            logger,
            max_length=max_length,
            prefix=prefix,
            metadata=metadata,
        )


@thread_cached
def get_execution_event_bus() -> RedisExecutionEventBus:
    return RedisExecutionEventBus()
@@ -653,8 +682,10 @@ def create_execution_queue_config() -> RabbitMQConfig:


async def stop_graph_execution(
    user_id: str,
    graph_exec_id: str,
    use_db_query: bool = True,
    wait_timeout: float = 60.0,
):
    """
    Mechanism:
@@ -664,66 +695,57 @@ async def stop_graph_execution(
    3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
    """
    queue_client = await get_async_execution_queue()
    db = execution_db if use_db_query else get_db_async_client()
    await queue_client.publish_message(
        routing_key="",
        message=CancelExecutionEvent(graph_exec_id=graph_exec_id).model_dump_json(),
        exchange=GRAPH_EXECUTION_CANCEL_EXCHANGE,
    )

    # Update the status of the graph execution
    if use_db_query:
        graph_execution = await update_graph_execution_stats(
            graph_exec_id,
            ExecutionStatus.TERMINATED,
        )
    else:
        graph_execution = await get_db_async_client().update_graph_execution_stats(
            graph_exec_id,
            ExecutionStatus.TERMINATED,
    if not wait_timeout:
        return

    start_time = time.time()
    while time.time() - start_time < wait_timeout:
        graph_exec = await db.get_graph_execution_meta(
            execution_id=graph_exec_id, user_id=user_id
        )

    if graph_execution:
        await get_async_execution_event_bus().publish(graph_execution)
    else:
        raise NotFoundError(
            f"Graph execution #{graph_exec_id} not found for termination."
        )
        if not graph_exec:
            raise NotFoundError(f"Graph execution #{graph_exec_id} not found.")

    # Update the status of the node executions
    if use_db_query:
        node_executions = await get_node_executions(
            graph_exec_id=graph_exec_id,
            statuses=[
                ExecutionStatus.QUEUED,
                ExecutionStatus.RUNNING,
                ExecutionStatus.INCOMPLETE,
            ],
        )
        await update_node_execution_status_batch(
            [v.node_exec_id for v in node_executions],
        if graph_exec.status in [
            ExecutionStatus.TERMINATED,
        )
    else:
        node_executions = await get_db_async_client().get_node_executions(
            graph_exec_id=graph_exec_id,
            statuses=[
                ExecutionStatus.QUEUED,
                ExecutionStatus.RUNNING,
                ExecutionStatus.INCOMPLETE,
            ],
        )
        await get_db_async_client().update_node_execution_status_batch(
            [v.node_exec_id for v in node_executions],
            ExecutionStatus.TERMINATED,
        )
            ExecutionStatus.COMPLETED,
            ExecutionStatus.FAILED,
        ]:
            # If graph execution is terminated/completed/failed, cancellation is complete
            return

    await asyncio.gather(
        *[
            get_async_execution_event_bus().publish(
                v.model_copy(update={"status": ExecutionStatus.TERMINATED})
        elif graph_exec.status in [
            ExecutionStatus.QUEUED,
            ExecutionStatus.INCOMPLETE,
        ]:
            # If the graph is still on the queue, we can prevent them from being executed
            # by setting the status to TERMINATED.
            node_execs = await db.get_node_executions(
                graph_exec_id=graph_exec_id,
                statuses=[ExecutionStatus.QUEUED, ExecutionStatus.INCOMPLETE],
                include_exec_data=False,
            )
            for v in node_executions
        ]
            await db.update_node_execution_status_batch(
                [node_exec.node_exec_id for node_exec in node_execs],
                ExecutionStatus.TERMINATED,
            )
            await db.update_graph_execution_stats(
                graph_exec_id=graph_exec_id,
                status=ExecutionStatus.TERMINATED,
            )

        await asyncio.sleep(1.0)

    raise TimeoutError(
        f"Timed out waiting for graph execution #{graph_exec_id} to terminate."
    )
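Sketch of the new call shape (IDs illustrative, not from the diff): cancellation is now a best-effort publish followed by a poll loop, so callers can opt out of waiting entirely.

```python
# Hypothetical caller, inside an async context.
await stop_graph_execution(
    user_id="user-123",
    graph_exec_id="exec-456",
    wait_timeout=30.0,  # a falsy timeout returns right after publishing the cancel event
)
```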
@@ -753,22 +775,16 @@ async def add_graph_execution(
        GraphExecutionEntry: The entry for the graph execution.
    Raises:
        ValueError: If the graph is not found or if there are validation errors.
    """  # noqa
    if use_db_query:
        graph: GraphModel | None = await get_graph(
            graph_id=graph_id,
            user_id=user_id,
            version=graph_version,
            include_subgraphs=True,
        )
    else:
        graph: GraphModel | None = await get_db_async_client().get_graph(
            graph_id=graph_id,
            user_id=user_id,
            version=graph_version,
            include_subgraphs=True,
        )
    """
    gdb = graph_db if use_db_query else get_db_async_client()
    edb = execution_db if use_db_query else get_db_async_client()

    graph: GraphModel | None = await gdb.get_graph(
        graph_id=graph_id,
        user_id=user_id,
        version=graph_version,
        include_subgraphs=True,
    )
    if not graph:
        raise NotFoundError(f"Graph #{graph_id} not found.")

@@ -787,22 +803,13 @@ async def add_graph_execution(
        nodes_input_masks=nodes_input_masks,
    )

    if use_db_query:
        graph_exec = await create_graph_execution(
            user_id=user_id,
            graph_id=graph_id,
            graph_version=graph.version,
            starting_nodes_input=starting_nodes_input,
            preset_id=preset_id,
        )
    else:
        graph_exec = await get_db_async_client().create_graph_execution(
            user_id=user_id,
            graph_id=graph_id,
            graph_version=graph.version,
            starting_nodes_input=starting_nodes_input,
            preset_id=preset_id,
        )
    graph_exec = await edb.create_graph_execution(
        user_id=user_id,
        graph_id=graph_id,
        graph_version=graph.version,
        starting_nodes_input=starting_nodes_input,
        preset_id=preset_id,
    )

    try:
        queue = await get_async_execution_queue()
@@ -821,28 +828,15 @@ async def add_graph_execution(
        return graph_exec
    except Exception as e:
        logger.error(f"Unable to publish graph #{graph_id} exec #{graph_exec.id}: {e}")

        if use_db_query:
            await update_node_execution_status_batch(
                [node_exec.node_exec_id for node_exec in graph_exec.node_executions],
                ExecutionStatus.FAILED,
            )
            await update_graph_execution_stats(
                graph_exec_id=graph_exec.id,
                status=ExecutionStatus.FAILED,
                stats=GraphExecutionStats(error=str(e)),
            )
        else:
            await get_db_async_client().update_node_execution_status_batch(
                [node_exec.node_exec_id for node_exec in graph_exec.node_executions],
                ExecutionStatus.FAILED,
            )
            await get_db_async_client().update_graph_execution_stats(
                graph_exec_id=graph_exec.id,
                status=ExecutionStatus.FAILED,
                stats=GraphExecutionStats(error=str(e)),
            )

        await edb.update_node_execution_status_batch(
            [node_exec.node_exec_id for node_exec in graph_exec.node_executions],
            ExecutionStatus.FAILED,
        )
        await edb.update_graph_execution_stats(
            graph_exec_id=graph_exec.id,
            status=ExecutionStatus.FAILED,
            stats=GraphExecutionStats(error=str(e)),
        )
        raise
@@ -897,14 +891,10 @@ class NodeExecutionProgress:
        try:
            self.tasks[exec_id].result(wait_time)
        except TimeoutError:
            print(
                ">>>>>>> -- Timeout, after waiting for",
                wait_time,
                "seconds for node_id",
                exec_id,
            )
            pass

        except Exception as e:
            logger.error(f"Task for exec ID {exec_id} failed with error: {str(e)}")
            pass
        return self.is_done(0)

    def stop(self) -> list[str]:
@@ -921,6 +911,25 @@ class NodeExecutionProgress:
                cancelled_ids.append(task_id)
        return cancelled_ids

    def wait_for_cancellation(self, timeout: float = 5.0):
        """
        Wait for all cancelled tasks to complete cancellation.

        Args:
            timeout: Maximum time to wait for cancellation in seconds
        """
        start_time = time.time()

        while time.time() - start_time < timeout:
            # Check if all tasks are done (either completed or cancelled)
            if all(task.done() for task in self.tasks.values()):
                return True
            time.sleep(0.1)  # Small delay to avoid busy waiting

        raise TimeoutError(
            f"Timeout waiting for cancellation of tasks: {list(self.tasks.keys())}"
        )

    def _pop_done_task(self, exec_id: str) -> bool:
        task = self.tasks.get(exec_id)
        if not task:
@@ -933,8 +942,10 @@ class NodeExecutionProgress:
            return False

        if task := self.tasks.pop(exec_id):
            self.on_done_task(exec_id, task.result())

            try:
                self.on_done_task(exec_id, task.result())
            except Exception as e:
                logger.error(f"Task for exec ID {exec_id} failed with error: {str(e)}")
        return True

    def _next_exec(self) -> str | None:
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional

from pydantic import SecretStr

from backend.data.redis import get_redis_async
from backend.data.redis_client import get_redis_async

if TYPE_CHECKING:
    from backend.executor.database import DatabaseManagerAsyncClient

@@ -7,7 +7,7 @@ from autogpt_libs.utils.synchronize import AsyncRedisKeyedMutex
from redis.asyncio.lock import Lock as AsyncRedisLock

from backend.data.model import Credentials, OAuth2Credentials
from backend.data.redis import get_redis_async
from backend.data.redis_client import get_redis_async
from backend.integrations.credentials_store import IntegrationCredentialsStore
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
@@ -1,29 +1,226 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

from pydantic import BaseModel

from backend.integrations.oauth.todoist import TodoistOAuthHandler

from .github import GitHubOAuthHandler
from .google import GoogleOAuthHandler
from .linear import LinearOAuthHandler
from .notion import NotionOAuthHandler
from .twitter import TwitterOAuthHandler

if TYPE_CHECKING:
    from ..providers import ProviderName
    from .base import BaseOAuthHandler

# --8<-- [start:HANDLERS_BY_NAMEExample]
HANDLERS_BY_NAME: dict["ProviderName", type["BaseOAuthHandler"]] = {
    handler.PROVIDER_NAME: handler
    for handler in [
        GitHubOAuthHandler,
        GoogleOAuthHandler,
        NotionOAuthHandler,
        TwitterOAuthHandler,
        LinearOAuthHandler,
        TodoistOAuthHandler,
    ]
# Build handlers dict with string keys for compatibility with SDK auto-registration
_ORIGINAL_HANDLERS = [
    GitHubOAuthHandler,
    GoogleOAuthHandler,
    NotionOAuthHandler,
    TwitterOAuthHandler,
    TodoistOAuthHandler,
]

# Start with original handlers
_handlers_dict = {
    (
        handler.PROVIDER_NAME.value
        if hasattr(handler.PROVIDER_NAME, "value")
        else str(handler.PROVIDER_NAME)
    ): handler
    for handler in _ORIGINAL_HANDLERS
}


class SDKAwareCredentials(BaseModel):
    """OAuth credentials configuration."""

    use_secrets: bool = True
    client_id_env_var: Optional[str] = None
    client_secret_env_var: Optional[str] = None


_credentials_by_provider = {}
# Add default credentials for original handlers
for handler in _ORIGINAL_HANDLERS:
    provider_name = (
        handler.PROVIDER_NAME.value
        if hasattr(handler.PROVIDER_NAME, "value")
        else str(handler.PROVIDER_NAME)
    )
    _credentials_by_provider[provider_name] = SDKAwareCredentials(
        use_secrets=True, client_id_env_var=None, client_secret_env_var=None
    )


# Create a custom dict class that includes SDK handlers
class SDKAwareHandlersDict(dict):
    """Dictionary that automatically includes SDK-registered OAuth handlers."""

    def __getitem__(self, key):
        # First try the original handlers
        if key in _handlers_dict:
            return _handlers_dict[key]

        # Then try SDK handlers
        try:
            from backend.sdk import AutoRegistry

            sdk_handlers = AutoRegistry.get_oauth_handlers()
            if key in sdk_handlers:
                return sdk_handlers[key]
        except ImportError:
            pass

        # If not found, raise KeyError
        raise KeyError(key)

    def get(self, key, default=None):
        try:
            return self[key]
        except KeyError:
            return default

    def __contains__(self, key):
        if key in _handlers_dict:
            return True
        try:
            from backend.sdk import AutoRegistry

            sdk_handlers = AutoRegistry.get_oauth_handlers()
            return key in sdk_handlers
        except ImportError:
            return False

    def keys(self):
        # Combine all keys into a single dict and return its keys view
        combined = dict(_handlers_dict)
        try:
            from backend.sdk import AutoRegistry

            sdk_handlers = AutoRegistry.get_oauth_handlers()
            combined.update(sdk_handlers)
        except ImportError:
            pass
        return combined.keys()

    def values(self):
        combined = dict(_handlers_dict)
        try:
            from backend.sdk import AutoRegistry

            sdk_handlers = AutoRegistry.get_oauth_handlers()
            combined.update(sdk_handlers)
        except ImportError:
            pass
        return combined.values()

    def items(self):
        combined = dict(_handlers_dict)
        try:
            from backend.sdk import AutoRegistry

            sdk_handlers = AutoRegistry.get_oauth_handlers()
            combined.update(sdk_handlers)
        except ImportError:
            pass
        return combined.items()


class SDKAwareCredentialsDict(dict):
    """Dictionary that automatically includes SDK-registered OAuth credentials."""

    def __getitem__(self, key):
        # First try the original handlers
        if key in _credentials_by_provider:
            return _credentials_by_provider[key]

        # Then try SDK credentials
        try:
            from backend.sdk import AutoRegistry

            sdk_credentials = AutoRegistry.get_oauth_credentials()
            if key in sdk_credentials:
                # Convert from SDKOAuthCredentials to SDKAwareCredentials
                sdk_cred = sdk_credentials[key]
                return SDKAwareCredentials(
                    use_secrets=sdk_cred.use_secrets,
                    client_id_env_var=sdk_cred.client_id_env_var,
                    client_secret_env_var=sdk_cred.client_secret_env_var,
                )
        except ImportError:
            pass

        # If not found, raise KeyError
        raise KeyError(key)

    def get(self, key, default=None):
        try:
            return self[key]
        except KeyError:
            return default

    def __contains__(self, key):
        if key in _credentials_by_provider:
            return True
        try:
            from backend.sdk import AutoRegistry

            sdk_credentials = AutoRegistry.get_oauth_credentials()
            return key in sdk_credentials
        except ImportError:
            return False

    def keys(self):
        # Combine all keys into a single dict and return its keys view
        combined = dict(_credentials_by_provider)
        try:
            from backend.sdk import AutoRegistry

            sdk_credentials = AutoRegistry.get_oauth_credentials()
            combined.update(sdk_credentials)
        except ImportError:
            pass
        return combined.keys()

    def values(self):
        combined = dict(_credentials_by_provider)
        try:
            from backend.sdk import AutoRegistry

            sdk_credentials = AutoRegistry.get_oauth_credentials()
            # Convert SDK credentials to SDKAwareCredentials
            for key, sdk_cred in sdk_credentials.items():
                combined[key] = SDKAwareCredentials(
                    use_secrets=sdk_cred.use_secrets,
                    client_id_env_var=sdk_cred.client_id_env_var,
                    client_secret_env_var=sdk_cred.client_secret_env_var,
                )
        except ImportError:
            pass
        return combined.values()

    def items(self):
        combined = dict(_credentials_by_provider)
        try:
            from backend.sdk import AutoRegistry

            sdk_credentials = AutoRegistry.get_oauth_credentials()
            # Convert SDK credentials to SDKAwareCredentials
            for key, sdk_cred in sdk_credentials.items():
                combined[key] = SDKAwareCredentials(
                    use_secrets=sdk_cred.use_secrets,
                    client_id_env_var=sdk_cred.client_id_env_var,
                    client_secret_env_var=sdk_cred.client_secret_env_var,
                )
        except ImportError:
            pass
        return combined.items()


HANDLERS_BY_NAME: dict[str, type["BaseOAuthHandler"]] = SDKAwareHandlersDict()
CREDENTIALS_BY_PROVIDER: dict[str, SDKAwareCredentials] = SDKAwareCredentialsDict()
# --8<-- [end:HANDLERS_BY_NAMEExample]

__all__ = ["HANDLERS_BY_NAME"]
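Effect of the wrapper (a sketch; `"my_sdk_provider"` is a hypothetical name, not from the diff): lookups hit the static table first and fall back to SDK-registered handlers at access time, so late registration still resolves.

```python
handler_cls = HANDLERS_BY_NAME.get("github")          # from _handlers_dict
custom_cls = HANDLERS_BY_NAME.get("my_sdk_provider")  # resolved via AutoRegistry, else None
if "github" in HANDLERS_BY_NAME:                      # __contains__ checks both sources
    handler_cls = HANDLERS_BY_NAME["github"]
```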
@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)

class BaseOAuthHandler(ABC):
    # --8<-- [start:BaseOAuthHandler1]
    PROVIDER_NAME: ClassVar[ProviderName]
    PROVIDER_NAME: ClassVar[ProviderName | str]
    DEFAULT_SCOPES: ClassVar[list[str]] = []
    # --8<-- [end:BaseOAuthHandler1]

@@ -81,8 +81,6 @@ class BaseOAuthHandler(ABC):
        """Handles the default scopes for the provider"""
        # If scopes are empty, use the default scopes for the provider
        if not scopes:
            logger.debug(
                f"Using default scopes for provider {self.PROVIDER_NAME.value}"
            )
            logger.debug(f"Using default scopes for provider {str(self.PROVIDER_NAME)}")
            scopes = self.DEFAULT_SCOPES
        return scopes
@@ -1,8 +1,16 @@
from enum import Enum
from typing import Any


# --8<-- [start:ProviderName]
class ProviderName(str, Enum):
    """
    Provider names for integrations.

    This enum extends str to accept any string value while maintaining
    backward compatibility with existing provider constants.
    """

    AIML_API = "aiml_api"
    ANTHROPIC = "anthropic"
    APOLLO = "apollo"
@@ -10,17 +18,15 @@ class ProviderName(str, Enum):
    DISCORD = "discord"
    D_ID = "d_id"
    E2B = "e2b"
    EXA = "exa"
    FAL = "fal"
    GENERIC_WEBHOOK = "generic_webhook"
    GITHUB = "github"
    GOOGLE = "google"
    GOOGLE_MAPS = "google_maps"
    GROQ = "groq"
    HTTP = "http"
    HUBSPOT = "hubspot"
    IDEOGRAM = "ideogram"
    JINA = "jina"
    LINEAR = "linear"
    LLAMA_API = "llama_api"
    MEDIUM = "medium"
    MEM0 = "mem0"
@@ -42,4 +48,57 @@ class ProviderName(str, Enum):
    TODOIST = "todoist"
    UNREAL_SPEECH = "unreal_speech"
    ZEROBOUNCE = "zerobounce"

    @classmethod
    def _missing_(cls, value: Any) -> "ProviderName":
        """
        Allow any string value to be used as a ProviderName.
        This enables SDK users to define custom providers without
        modifying the enum.
        """
        if isinstance(value, str):
            # Create a pseudo-member that behaves like an enum member
            pseudo_member = str.__new__(cls, value)
            pseudo_member._name_ = value.upper()
            pseudo_member._value_ = value
            return pseudo_member
        return None  # type: ignore

    @classmethod
    def __get_pydantic_json_schema__(cls, schema, handler):
        """
        Custom JSON schema generation that allows any string value,
        not just the predefined enum values.
        """
        # Get the default schema
        json_schema = handler(schema)

        # Remove the enum constraint to allow any string
        if "enum" in json_schema:
            del json_schema["enum"]

        # Keep the type as string
        json_schema["type"] = "string"

        # Update description to indicate custom providers are allowed
        json_schema["description"] = (
            "Provider name for integrations. "
            "Can be any string value, including custom provider names."
        )

        return json_schema

    @classmethod
    def __get_pydantic_core_schema__(cls, source_type, handler):
        """
        Pydantic v2 core schema that allows any string value.
        """
        from pydantic_core import core_schema

        # Create a string schema that validates any string
        return core_schema.no_info_after_validator_function(
            cls,
            core_schema.str_schema(),
        )

# --8<-- [end:ProviderName]
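With `_missing_` defined, unknown strings become pseudo-members instead of raising `ValueError`. A sketch (the custom provider name is illustrative):

```python
# Known values still resolve to the canonical member.
assert ProviderName("github") is ProviderName.GITHUB

# Unknown values produce a str-backed pseudo-member on the fly.
custom = ProviderName("my_custom_service")
assert custom.value == "my_custom_service"
assert isinstance(custom, str)  # still a str subclass, like real members
```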
@@ -12,7 +12,6 @@ def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]
    webhook_managers = {}

    from .compass import CompassWebhookManager
    from .generic import GenericWebhooksManager
    from .github import GithubWebhooksManager
    from .slant3d import Slant3DWebhooksManager

@@ -23,7 +22,6 @@ def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]
            CompassWebhookManager,
            GithubWebhooksManager,
            Slant3DWebhooksManager,
            GenericWebhooksManager,
        ]
    }
)
autogpt_platform/backend/backend/monitoring/__init__.py (new file, 24 lines)
@@ -0,0 +1,24 @@
"""Monitoring module for platform health and alerting."""

from .block_error_monitor import BlockErrorMonitor, report_block_error_rates
from .late_execution_monitor import (
    LateExecutionException,
    LateExecutionMonitor,
    report_late_executions,
)
from .notification_monitor import (
    NotificationJobArgs,
    process_existing_batches,
    process_weekly_summary,
)

__all__ = [
    "BlockErrorMonitor",
    "LateExecutionMonitor",
    "LateExecutionException",
    "NotificationJobArgs",
    "report_block_error_rates",
    "report_late_executions",
    "process_existing_batches",
    "process_weekly_summary",
]
@@ -0,0 +1,291 @@
"""Block error rate monitoring module."""

import logging
import re
from datetime import datetime, timedelta, timezone

from pydantic import BaseModel

from backend.data.block import get_block
from backend.data.execution import ExecutionStatus, NodeExecutionResult
from backend.executor import utils as execution_utils
from backend.notifications.notifications import NotificationManagerClient
from backend.util.metrics import sentry_capture_error
from backend.util.service import get_service_client
from backend.util.settings import Config

logger = logging.getLogger(__name__)
config = Config()


class BlockStatsWithSamples(BaseModel):
    """Enhanced block stats with error samples."""

    block_id: str
    block_name: str
    total_executions: int
    failed_executions: int
    error_samples: list[str] = []

    @property
    def error_rate(self) -> float:
        """Calculate error rate as a percentage."""
        if self.total_executions == 0:
            return 0.0
        return (self.failed_executions / self.total_executions) * 100


class BlockErrorMonitor:
    """Monitor block error rates and send alerts when thresholds are exceeded."""

    def __init__(self, include_top_blocks: int | None = None):
        self.config = config
        self.notification_client = get_service_client(NotificationManagerClient)
        self.include_top_blocks = (
            include_top_blocks
            if include_top_blocks is not None
            else config.block_error_include_top_blocks
        )

    def check_block_error_rates(self) -> str:
        """Check block error rates and send Discord alerts if thresholds are exceeded."""
        try:
            logger.info("Checking block error rates")

            # Get executions from the last 24 hours
            end_time = datetime.now(timezone.utc)
            start_time = end_time - timedelta(hours=24)

            # Use SQL aggregation to efficiently count totals and failures by block
            block_stats = self._get_block_stats_from_db(start_time, end_time)

            # For blocks with high error rates, fetch error samples
            threshold = self.config.block_error_rate_threshold
            for block_name, stats in block_stats.items():
                if stats.total_executions >= 10 and stats.error_rate >= threshold * 100:
                    # Only fetch error samples for blocks that exceed threshold
                    error_samples = self._get_error_samples_for_block(
                        stats.block_id, start_time, end_time, limit=3
                    )
                    stats.error_samples = error_samples

            # Check thresholds and send alerts
            critical_alerts = self._generate_critical_alerts(block_stats, threshold)

            if critical_alerts:
                msg = "Block Error Rate Alert:\n\n" + "\n\n".join(critical_alerts)
                self.notification_client.discord_system_alert(msg)
                logger.info(
                    f"Sent block error rate alert for {len(critical_alerts)} blocks"
                )
                return f"Alert sent for {len(critical_alerts)} blocks with high error rates"

            # If no critical alerts, check if we should show top blocks
            if self.include_top_blocks > 0:
                top_blocks_msg = self._generate_top_blocks_alert(
                    block_stats, start_time, end_time
                )
                if top_blocks_msg:
                    self.notification_client.discord_system_alert(top_blocks_msg)
                    logger.info("Sent top blocks summary")
                    return "Sent top blocks summary"

            logger.info("No blocks exceeded error rate threshold")
            return "No errors reported for today"

        except Exception as e:
            logger.exception(f"Error checking block error rates: {e}")

            error = Exception(f"Error checking block error rates: {e}")
            msg = str(error)
            sentry_capture_error(error)
            self.notification_client.discord_system_alert(msg)
            return msg

    def _get_block_stats_from_db(
        self, start_time: datetime, end_time: datetime
    ) -> dict[str, BlockStatsWithSamples]:
        """Get block execution stats using efficient SQL aggregation."""

        result = execution_utils.get_db_client().get_block_error_stats(
            start_time, end_time
        )

        block_stats = {}
        for stats in result:
            block_name = b.name if (b := get_block(stats.block_id)) else "Unknown"

            block_stats[block_name] = BlockStatsWithSamples(
                block_id=stats.block_id,
                block_name=block_name,
                total_executions=stats.total_executions,
                failed_executions=stats.failed_executions,
                error_samples=[],
            )

        return block_stats

    def _generate_critical_alerts(
        self, block_stats: dict[str, BlockStatsWithSamples], threshold: float
    ) -> list[str]:
        """Generate alerts for blocks that exceed the error rate threshold."""
        alerts = []

        for block_name, stats in block_stats.items():
            if stats.total_executions >= 10 and stats.error_rate >= threshold * 100:
                error_groups = self._group_similar_errors(stats.error_samples)

                alert_msg = (
                    f"🚨 Block '{block_name}' has {stats.error_rate:.1f}% error rate "
                    f"({stats.failed_executions}/{stats.total_executions}) in the last 24 hours"
                )

                if error_groups:
                    alert_msg += "\n\n📊 Error Types:"
                    for error_pattern, count in error_groups.items():
                        alert_msg += f"\n• {error_pattern} ({count}x)"

                alerts.append(alert_msg)

        return alerts

    def _generate_top_blocks_alert(
        self,
        block_stats: dict[str, BlockStatsWithSamples],
        start_time: datetime,
        end_time: datetime,
    ) -> str | None:
        """Generate top blocks summary when no critical alerts exist."""
        top_error_blocks = sorted(
            [
                (name, stats)
                for name, stats in block_stats.items()
                if stats.total_executions >= 10 and stats.failed_executions > 0
            ],
            key=lambda x: x[1].failed_executions,
            reverse=True,
        )[: self.include_top_blocks]

        if not top_error_blocks:
            return "✅ No errors reported for today - all blocks are running smoothly!"

        # Get error samples for top blocks
        for block_name, stats in top_error_blocks:
            if not stats.error_samples:
                stats.error_samples = self._get_error_samples_for_block(
                    stats.block_id, start_time, end_time, limit=2
                )

        count_text = (
            f"top {self.include_top_blocks}" if self.include_top_blocks > 1 else "top"
        )
        alert_msg = f"📊 Daily Error Summary - {count_text} blocks with most errors:"
        for block_name, stats in top_error_blocks:
            alert_msg += f"\n• {block_name}: {stats.failed_executions} errors ({stats.error_rate:.1f}% of {stats.total_executions})"

            if stats.error_samples:
                error_groups = self._group_similar_errors(stats.error_samples)
                if error_groups:
                    # Show most common error
                    most_common_error = next(iter(error_groups.items()))
                    alert_msg += f"\n  └ Most common: {most_common_error[0]}"

        return alert_msg

    def _get_error_samples_for_block(
        self, block_id: str, start_time: datetime, end_time: datetime, limit: int = 3
    ) -> list[str]:
        """Get error samples for a specific block - just a few recent ones."""
        # Only fetch a small number of recent failed executions for this specific block
        executions = execution_utils.get_db_client().get_node_executions(
            block_ids=[block_id],
            statuses=[ExecutionStatus.FAILED],
            created_time_gte=start_time,
            created_time_lte=end_time,
            limit=limit,  # Just get the limit we need
        )

        error_samples = []
        for execution in executions:
            if error_message := self._extract_error_message(execution):
                masked_error = self._mask_sensitive_data(error_message)
                error_samples.append(masked_error)

                if len(error_samples) >= limit:  # Stop once we have enough samples
                    break

        return error_samples

    def _extract_error_message(self, execution: NodeExecutionResult) -> str | None:
        """Extract error message from execution output."""
        try:
            if execution.output_data and (
                error_msg := execution.output_data.get("error")
            ):
                return str(error_msg[0])
            return None
        except Exception:
            return None

    def _mask_sensitive_data(self, error_message):
        """Mask sensitive data in error messages to enable grouping."""
        if not error_message:
            return ""

        # Convert to string if not already
        error_str = str(error_message)

        # Mask numbers (replace with X)
        error_str = re.sub(r"\d+", "X", error_str)

        # Mask all caps words (likely constants/IDs)
        error_str = re.sub(r"\b[A-Z_]{3,}\b", "MASKED", error_str)

        # Mask words with underscores (likely internal variables)
        error_str = re.sub(r"\b\w*_\w*\b", "MASKED", error_str)

        # Mask UUIDs and long alphanumeric strings
        error_str = re.sub(
            r"\b[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}\b",
            "UUID",
            error_str,
        )
        error_str = re.sub(r"\b[a-f0-9]{20,}\b", "HASH", error_str)

        # Mask file paths
        error_str = re.sub(r"(/[^/\s]+)+", "/MASKED/path", error_str)

        # Mask URLs
        error_str = re.sub(r"https?://[^\s]+", "URL", error_str)

        # Mask email addresses
        error_str = re.sub(
            r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b", "EMAIL", error_str
        )

        # Truncate if too long
        if len(error_str) > 100:
            error_str = error_str[:97] + "..."

        return error_str.strip()

    def _group_similar_errors(self, error_samples):
        """Group similar error messages and return counts."""
        if not error_samples:
            return {}

        error_groups = {}
        for error in error_samples:
            if error in error_groups:
                error_groups[error] += 1
            else:
                error_groups[error] = 1

        # Sort by frequency, most common first
        return dict(sorted(error_groups.items(), key=lambda x: x[1], reverse=True))


def report_block_error_rates(include_top_blocks: int | None = None):
    """Check block error rates and send Discord alerts if thresholds are exceeded."""
    monitor = BlockErrorMonitor(include_top_blocks=include_top_blocks)
    return monitor.check_block_error_rates()
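Since _mask_sensitive_data normalizes numbers, identifiers, paths, URLs, and emails before _group_similar_errors counts duplicates, errors that differ only in those details collapse into one group. A reduced standalone sketch of the same regex pipeline (only a few of the rules, with invented sample strings):

import re

def mask(error_str: str) -> str:
    # Same rule ordering as BlockErrorMonitor._mask_sensitive_data:
    # numbers first, then path masking, then URL masking.
    error_str = re.sub(r"\d+", "X", error_str)
    error_str = re.sub(r"(/[^/\s]+)+", "/MASKED/path", error_str)
    error_str = re.sub(r"https?://[^\s]+", "URL", error_str)
    return error_str.strip()

samples = [
    "Timeout after 30s calling https://api.example.com/v1/run",
    "Timeout after 45s calling https://api.example.com/v2/run",
]
masked = [mask(s) for s in samples]
# Both collapse to "Timeout after Xs calling https:/MASKED/path", so they
# group as one pattern with count 2. Note that because path masking runs
# before URL masking, full URLs are mostly consumed by the path rule.
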
@@ -0,0 +1,71 @@
"""Late execution monitoring module."""

import logging
from datetime import datetime, timedelta, timezone

from backend.data.execution import ExecutionStatus
from backend.executor import utils as execution_utils
from backend.notifications.notifications import NotificationManagerClient
from backend.util.metrics import sentry_capture_error
from backend.util.service import get_service_client
from backend.util.settings import Config

logger = logging.getLogger(__name__)
config = Config()


class LateExecutionException(Exception):
    """Exception raised when late executions are detected."""

    pass


class LateExecutionMonitor:
    """Monitor late executions and send alerts when thresholds are exceeded."""

    def __init__(self):
        self.config = config
        self.notification_client = get_service_client(NotificationManagerClient)

    def check_late_executions(self) -> str:
        """Check for late executions and send alerts if found."""
        late_executions = execution_utils.get_db_client().get_graph_executions(
            statuses=[ExecutionStatus.QUEUED],
            created_time_gte=datetime.now(timezone.utc)
            - timedelta(
                seconds=self.config.execution_late_notification_checkrange_secs
            ),
            created_time_lte=datetime.now(timezone.utc)
            - timedelta(seconds=self.config.execution_late_notification_threshold_secs),
            limit=1000,
        )

        if not late_executions:
            return "No late executions detected."

        num_late_executions = len(late_executions)
        num_users = len(set([r.user_id for r in late_executions]))

        late_execution_details = [
            f"* `Execution ID: {exec.id}, Graph ID: {exec.graph_id}v{exec.graph_version}, User ID: {exec.user_id}, Created At: {exec.started_at.isoformat()}`"
            for exec in late_executions
        ]

        error = LateExecutionException(
            f"Late executions detected: {num_late_executions} late executions from {num_users} users "
            f"in the last {self.config.execution_late_notification_checkrange_secs} seconds. "
            f"Graph has been queued for more than {self.config.execution_late_notification_threshold_secs} seconds. "
            "Please check the executor status. Details:\n"
            + "\n".join(late_execution_details)
        )
        msg = str(error)

        sentry_capture_error(error)
        self.notification_client.discord_system_alert(msg)
        return msg


def report_late_executions() -> str:
    """Check for late executions and send Discord alerts if found."""
    monitor = LateExecutionMonitor()
    return monitor.check_late_executions()
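The query window is the subtle part: an execution counts as late when it is still QUEUED, was created within the last checkrange seconds, and is older than threshold seconds. A standalone sketch of the same window arithmetic (the 3600/300 values are illustrative, not the platform defaults):

from datetime import datetime, timedelta, timezone

CHECKRANGE_SECS = 3600  # how far back to look at all (illustrative)
THRESHOLD_SECS = 300    # how long QUEUED counts as "late" (illustrative)

now = datetime.now(timezone.utc)
window_start = now - timedelta(seconds=CHECKRANGE_SECS)
window_end = now - timedelta(seconds=THRESHOLD_SECS)

def is_late(created_at: datetime, status: str) -> bool:
    # Mirrors the get_graph_executions filter above: QUEUED and created
    # inside [now - checkrange, now - threshold].
    return status == "QUEUED" and window_start <= created_at <= window_end
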
@@ -0,0 +1,39 @@
"""Notification processing monitoring module."""

import logging

from prisma.enums import NotificationType
from pydantic import BaseModel

from backend.notifications.notifications import NotificationManagerClient
from backend.util.service import get_service_client

logger = logging.getLogger(__name__)


class NotificationJobArgs(BaseModel):
    notification_types: list[NotificationType]
    cron: str


def process_existing_batches(**kwargs):
    """Process existing notification batches."""
    args = NotificationJobArgs(**kwargs)
    try:
        logger.info(
            f"Processing existing batches for notification type {args.notification_types}"
        )
        get_service_client(NotificationManagerClient).process_existing_batches(
            args.notification_types
        )
    except Exception as e:
        logger.exception(f"Error processing existing batches: {e}")


def process_weekly_summary(**kwargs):
    """Process weekly summary notifications."""
    try:
        logger.info("Processing weekly summary")
        get_service_client(NotificationManagerClient).queue_weekly_summary()
    except Exception as e:
        logger.exception(f"Error processing weekly summary: {e}")
169
autogpt_platform/backend/backend/sdk/__init__.py
Normal file
@@ -0,0 +1,169 @@
"""
|
||||
AutoGPT Platform Block Development SDK
|
||||
|
||||
Complete re-export of all dependencies needed for block development.
|
||||
Usage: from backend.sdk import *
|
||||
|
||||
This module provides:
|
||||
- All block base classes and types
|
||||
- All credential and authentication components
|
||||
- All cost tracking components
|
||||
- All webhook components
|
||||
- All utility functions
|
||||
- Auto-registration decorators
|
||||
"""
|
||||
|
||||
# Third-party imports
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
# === CORE BLOCK SYSTEM ===
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockManualWebhookConfig,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
BlockWebhookConfig,
|
||||
)
|
||||
from backend.data.integrations import Webhook
|
||||
from backend.data.model import APIKeyCredentials, Credentials, CredentialsField
|
||||
from backend.data.model import CredentialsMetaInput as _CredentialsMetaInput
|
||||
from backend.data.model import (
|
||||
NodeExecutionStats,
|
||||
OAuth2Credentials,
|
||||
SchemaField,
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
|
||||
# === INTEGRATIONS ===
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.sdk.builder import ProviderBuilder
|
||||
from backend.sdk.cost_integration import cost
|
||||
from backend.sdk.provider import Provider
|
||||
|
||||
# === NEW SDK COMPONENTS (imported early for patches) ===
|
||||
from backend.sdk.registry import AutoRegistry, BlockConfiguration
|
||||
|
||||
# === UTILITIES ===
|
||||
from backend.util import json
|
||||
from backend.util.request import Requests
|
||||
|
||||
# === OPTIONAL IMPORTS WITH TRY/EXCEPT ===
|
||||
# Webhooks
|
||||
try:
|
||||
from backend.integrations.webhooks._base import BaseWebhooksManager
|
||||
except ImportError:
|
||||
BaseWebhooksManager = None
|
||||
|
||||
try:
|
||||
from backend.integrations.webhooks._manual_base import ManualWebhookManagerBase
|
||||
except ImportError:
|
||||
ManualWebhookManagerBase = None
|
||||
|
||||
# Cost System
|
||||
try:
|
||||
from backend.data.cost import BlockCost, BlockCostType
|
||||
except ImportError:
|
||||
from backend.data.block_cost_config import BlockCost, BlockCostType
|
||||
|
||||
try:
|
||||
from backend.data.credit import UsageTransactionMetadata
|
||||
except ImportError:
|
||||
UsageTransactionMetadata = None
|
||||
|
||||
try:
|
||||
from backend.executor.utils import block_usage_cost
|
||||
except ImportError:
|
||||
block_usage_cost = None
|
||||
|
||||
# Utilities
|
||||
try:
|
||||
from backend.util.file import store_media_file
|
||||
except ImportError:
|
||||
store_media_file = None
|
||||
|
||||
try:
|
||||
from backend.util.type import MediaFileType, convert
|
||||
except ImportError:
|
||||
MediaFileType = None
|
||||
convert = None
|
||||
|
||||
try:
|
||||
from backend.util.text import TextFormatter
|
||||
except ImportError:
|
||||
TextFormatter = None
|
||||
|
||||
try:
|
||||
from backend.util.logging import TruncatedLogger
|
||||
except ImportError:
|
||||
TruncatedLogger = None
|
||||
|
||||
|
||||
# OAuth handlers
|
||||
try:
|
||||
from backend.integrations.oauth.base import BaseOAuthHandler
|
||||
except ImportError:
|
||||
BaseOAuthHandler = None
|
||||
|
||||
|
||||
# Credential type with proper provider name
|
||||
from typing import Literal as _Literal
|
||||
|
||||
CredentialsMetaInput = _CredentialsMetaInput[
|
||||
ProviderName, _Literal["api_key", "oauth2", "user_password"]
|
||||
]
|
||||
|
||||
|
||||
# === COMPREHENSIVE __all__ EXPORT ===
|
||||
__all__ = [
|
||||
# Core Block System
|
||||
"Block",
|
||||
"BlockCategory",
|
||||
"BlockOutput",
|
||||
"BlockSchema",
|
||||
"BlockType",
|
||||
"BlockWebhookConfig",
|
||||
"BlockManualWebhookConfig",
|
||||
# Schema and Model Components
|
||||
"SchemaField",
|
||||
"Credentials",
|
||||
"CredentialsField",
|
||||
"CredentialsMetaInput",
|
||||
"APIKeyCredentials",
|
||||
"OAuth2Credentials",
|
||||
"UserPasswordCredentials",
|
||||
"NodeExecutionStats",
|
||||
# Cost System
|
||||
"BlockCost",
|
||||
"BlockCostType",
|
||||
"UsageTransactionMetadata",
|
||||
"block_usage_cost",
|
||||
# Integrations
|
||||
"ProviderName",
|
||||
"BaseWebhooksManager",
|
||||
"ManualWebhookManagerBase",
|
||||
"Webhook",
|
||||
# Provider-Specific (when available)
|
||||
"BaseOAuthHandler",
|
||||
# Utilities
|
||||
"json",
|
||||
"store_media_file",
|
||||
"MediaFileType",
|
||||
"convert",
|
||||
"TextFormatter",
|
||||
"TruncatedLogger",
|
||||
"BaseModel",
|
||||
"Field",
|
||||
"SecretStr",
|
||||
"Requests",
|
||||
# SDK Components
|
||||
"AutoRegistry",
|
||||
"BlockConfiguration",
|
||||
"Provider",
|
||||
"ProviderBuilder",
|
||||
"cost",
|
||||
]
|
||||
|
||||
# Remove None values from __all__
|
||||
__all__ = [name for name in __all__ if globals().get(name) is not None]
|
||||
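With these re-exports, a block module only needs a single import source. A rough sketch of what a block written against this SDK could look like (EchoBlock and its fields are invented; the exact constructor arguments follow the platform's Block base class and may differ):

from backend.sdk import Block, BlockCategory, BlockOutput, BlockSchema, SchemaField


class EchoBlock(Block):  # hypothetical block, for illustration only
    class Input(BlockSchema):
        text: str = SchemaField(description="Text to echo back")

    class Output(BlockSchema):
        result: str = SchemaField(description="The echoed text")

    def __init__(self):
        super().__init__(
            id="00000000-0000-0000-0000-000000000000",  # placeholder UUID
            description="Echoes its input",
            categories={BlockCategory.BASIC},
            input_schema=EchoBlock.Input,
            output_schema=EchoBlock.Output,
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:
        yield "result", input_data.text
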
161
autogpt_platform/backend/backend/sdk/builder.py
Normal file
@@ -0,0 +1,161 @@
"""
|
||||
Builder class for creating provider configurations with a fluent API.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Callable, List, Optional, Type
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.cost import BlockCost, BlockCostType
|
||||
from backend.data.model import APIKeyCredentials, Credentials, UserPasswordCredentials
|
||||
from backend.integrations.oauth.base import BaseOAuthHandler
|
||||
from backend.integrations.webhooks._base import BaseWebhooksManager
|
||||
from backend.sdk.provider import OAuthConfig, Provider
|
||||
from backend.sdk.registry import AutoRegistry
|
||||
from backend.util.settings import Settings
|
||||
|
||||
|
||||
class ProviderBuilder:
|
||||
"""Builder for creating provider configurations."""
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self._oauth_config: Optional[OAuthConfig] = None
|
||||
self._webhook_manager: Optional[Type[BaseWebhooksManager]] = None
|
||||
self._default_credentials: List[Credentials] = []
|
||||
self._base_costs: List[BlockCost] = []
|
||||
self._supported_auth_types: set = set()
|
||||
self._api_client_factory: Optional[Callable] = None
|
||||
self._error_handler: Optional[Callable[[Exception], str]] = None
|
||||
self._default_scopes: Optional[List[str]] = None
|
||||
self._client_id_env_var: Optional[str] = None
|
||||
self._client_secret_env_var: Optional[str] = None
|
||||
self._extra_config: dict = {}
|
||||
|
||||
def with_oauth(
|
||||
self,
|
||||
handler_class: Type[BaseOAuthHandler],
|
||||
scopes: Optional[List[str]] = None,
|
||||
client_id_env_var: Optional[str] = None,
|
||||
client_secret_env_var: Optional[str] = None,
|
||||
) -> "ProviderBuilder":
|
||||
"""Add OAuth support."""
|
||||
self._oauth_config = OAuthConfig(
|
||||
oauth_handler=handler_class,
|
||||
scopes=scopes,
|
||||
client_id_env_var=client_id_env_var,
|
||||
client_secret_env_var=client_secret_env_var,
|
||||
)
|
||||
self._supported_auth_types.add("oauth2")
|
||||
return self
|
||||
|
||||
def with_api_key(self, env_var_name: str, title: str) -> "ProviderBuilder":
|
||||
"""Add API key support with environment variable name."""
|
||||
self._supported_auth_types.add("api_key")
|
||||
|
||||
# Register the API key mapping
|
||||
AutoRegistry.register_api_key(self.name, env_var_name)
|
||||
|
||||
# Check if API key exists in environment
|
||||
api_key = os.getenv(env_var_name)
|
||||
if api_key:
|
||||
self._default_credentials.append(
|
||||
APIKeyCredentials(
|
||||
id=f"{self.name}-default",
|
||||
provider=self.name,
|
||||
api_key=SecretStr(api_key),
|
||||
title=title,
|
||||
)
|
||||
)
|
||||
return self
|
||||
|
||||
def with_api_key_from_settings(
|
||||
self, settings_attr: str, title: str
|
||||
) -> "ProviderBuilder":
|
||||
"""Use existing API key from settings."""
|
||||
self._supported_auth_types.add("api_key")
|
||||
|
||||
# Try to get the API key from settings
|
||||
settings = Settings()
|
||||
api_key = getattr(settings.secrets, settings_attr, None)
|
||||
if api_key:
|
||||
self._default_credentials.append(
|
||||
APIKeyCredentials(
|
||||
id=f"{self.name}-default",
|
||||
provider=self.name,
|
||||
api_key=api_key,
|
||||
title=title,
|
||||
)
|
||||
)
|
||||
return self
|
||||
|
||||
def with_user_password(
|
||||
self, username_env_var: str, password_env_var: str, title: str
|
||||
) -> "ProviderBuilder":
|
||||
"""Add username/password support with environment variable names."""
|
||||
self._supported_auth_types.add("user_password")
|
||||
|
||||
# Check if credentials exist in environment
|
||||
username = os.getenv(username_env_var)
|
||||
password = os.getenv(password_env_var)
|
||||
if username and password:
|
||||
self._default_credentials.append(
|
||||
UserPasswordCredentials(
|
||||
id=f"{self.name}-default",
|
||||
provider=self.name,
|
||||
username=SecretStr(username),
|
||||
password=SecretStr(password),
|
||||
title=title,
|
||||
)
|
||||
)
|
||||
return self
|
||||
|
||||
def with_webhook_manager(
|
||||
self, manager_class: Type[BaseWebhooksManager]
|
||||
) -> "ProviderBuilder":
|
||||
"""Register webhook manager for this provider."""
|
||||
self._webhook_manager = manager_class
|
||||
return self
|
||||
|
||||
def with_base_cost(
|
||||
self, amount: int, cost_type: BlockCostType
|
||||
) -> "ProviderBuilder":
|
||||
"""Set base cost for all blocks using this provider."""
|
||||
self._base_costs.append(BlockCost(cost_amount=amount, cost_type=cost_type))
|
||||
return self
|
||||
|
||||
def with_api_client(self, factory: Callable) -> "ProviderBuilder":
|
||||
"""Register API client factory."""
|
||||
self._api_client_factory = factory
|
||||
return self
|
||||
|
||||
def with_error_handler(
|
||||
self, handler: Callable[[Exception], str]
|
||||
) -> "ProviderBuilder":
|
||||
"""Register error handler for provider-specific errors."""
|
||||
self._error_handler = handler
|
||||
return self
|
||||
|
||||
def with_config(self, **kwargs) -> "ProviderBuilder":
|
||||
"""Add additional configuration options."""
|
||||
self._extra_config.update(kwargs)
|
||||
return self
|
||||
|
||||
def build(self) -> Provider:
|
||||
"""Build and register the provider configuration."""
|
||||
provider = Provider(
|
||||
name=self.name,
|
||||
oauth_config=self._oauth_config,
|
||||
webhook_manager=self._webhook_manager,
|
||||
default_credentials=self._default_credentials,
|
||||
base_costs=self._base_costs,
|
||||
supported_auth_types=self._supported_auth_types,
|
||||
api_client_factory=self._api_client_factory,
|
||||
error_handler=self._error_handler,
|
||||
**self._extra_config,
|
||||
)
|
||||
|
||||
# Auto-registration happens here
|
||||
AutoRegistry.register_provider(provider)
|
||||
return provider
|
||||
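The builder is meant to be the one-liner that wires a provider into the registry; build() is what triggers AutoRegistry.register_provider. A hedged sketch of typical usage (the provider name, env var, and cost amount are illustrative):

from backend.data.cost import BlockCostType
from backend.sdk.builder import ProviderBuilder

# Hypothetical provider: reads MY_SERVICE_API_KEY from the environment if
# set, registers a default credential, and attaches a flat per-run base cost.
my_service = (
    ProviderBuilder("my_service")
    .with_api_key("MY_SERVICE_API_KEY", "My Service API Key")
    .with_base_cost(5, BlockCostType.RUN)
    .build()
)
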
163
autogpt_platform/backend/backend/sdk/cost_integration.py
Normal file
@@ -0,0 +1,163 @@
"""
|
||||
Integration between SDK provider costs and the execution cost system.
|
||||
|
||||
This module provides the glue between provider-defined base costs and the
|
||||
BLOCK_COSTS configuration used by the execution system.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Type
|
||||
|
||||
from backend.data.block import Block
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.cost import BlockCost
|
||||
from backend.sdk.registry import AutoRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def register_provider_costs_for_block(block_class: Type[Block]) -> None:
|
||||
"""
|
||||
Register provider base costs for a specific block in BLOCK_COSTS.
|
||||
|
||||
This function checks if the block uses credentials from a provider that has
|
||||
base costs defined, and automatically registers those costs for the block.
|
||||
|
||||
Args:
|
||||
block_class: The block class to register costs for
|
||||
"""
|
||||
# Skip if block already has custom costs defined
|
||||
if block_class in BLOCK_COSTS:
|
||||
logger.debug(
|
||||
f"Block {block_class.__name__} already has costs defined, skipping provider costs"
|
||||
)
|
||||
return
|
||||
|
||||
# Get the block's input schema
|
||||
# We need to instantiate the block to get its input schema
|
||||
try:
|
||||
block_instance = block_class()
|
||||
input_schema = block_instance.input_schema
|
||||
except Exception as e:
|
||||
logger.debug(f"Block {block_class.__name__} cannot be instantiated: {e}")
|
||||
return
|
||||
|
||||
# Look for credentials fields
|
||||
# The cost system works of filtering on credentials fields,
|
||||
# without credentials fields, we can not apply costs
|
||||
# TODO: Improve cost system to allow for costs witout a provider
|
||||
credentials_fields = input_schema.get_credentials_fields()
|
||||
if not credentials_fields:
|
||||
logger.debug(f"Block {block_class.__name__} has no credentials fields")
|
||||
return
|
||||
|
||||
# Get provider information from credentials fields
|
||||
for field_name, field_info in credentials_fields.items():
|
||||
# Get the field schema to extract provider information
|
||||
field_schema = input_schema.get_field_schema(field_name)
|
||||
|
||||
# Extract provider names from json_schema_extra
|
||||
providers = field_schema.get("credentials_provider", [])
|
||||
if not providers:
|
||||
continue
|
||||
|
||||
# For each provider, check if it has base costs
|
||||
block_costs: List[BlockCost] = []
|
||||
for provider_name in providers:
|
||||
provider = AutoRegistry.get_provider(provider_name)
|
||||
if not provider:
|
||||
logger.debug(f"Provider {provider_name} not found in registry")
|
||||
continue
|
||||
|
||||
# Add provider's base costs to the block
|
||||
if provider.base_costs:
|
||||
logger.info(
|
||||
f"Registering {len(provider.base_costs)} base costs from provider {provider_name} for block {block_class.__name__}"
|
||||
)
|
||||
block_costs.extend(provider.base_costs)
|
||||
|
||||
# Register costs if any were found
|
||||
if block_costs:
|
||||
BLOCK_COSTS[block_class] = block_costs
|
||||
logger.info(
|
||||
f"Registered {len(block_costs)} total costs for block {block_class.__name__}"
|
||||
)
|
||||
|
||||
|
||||
def sync_all_provider_costs() -> None:
|
||||
"""
|
||||
Sync all provider base costs to blocks that use them.
|
||||
|
||||
This should be called after all providers and blocks are registered,
|
||||
typically during application startup.
|
||||
"""
|
||||
from backend.blocks import load_all_blocks
|
||||
|
||||
logger.info("Syncing provider costs to blocks...")
|
||||
|
||||
blocks_with_costs = 0
|
||||
total_costs = 0
|
||||
|
||||
for block_id, block_class in load_all_blocks().items():
|
||||
initial_count = len(BLOCK_COSTS.get(block_class, []))
|
||||
register_provider_costs_for_block(block_class)
|
||||
final_count = len(BLOCK_COSTS.get(block_class, []))
|
||||
|
||||
if final_count > initial_count:
|
||||
blocks_with_costs += 1
|
||||
total_costs += final_count - initial_count
|
||||
|
||||
logger.info(f"Synced {total_costs} costs to {blocks_with_costs} blocks")
|
||||
|
||||
|
||||
def get_block_costs(block_class: Type[Block]) -> List[BlockCost]:
|
||||
"""
|
||||
Get all costs for a block, including both explicit and provider costs.
|
||||
|
||||
Args:
|
||||
block_class: The block class to get costs for
|
||||
|
||||
Returns:
|
||||
List of BlockCost objects for the block
|
||||
"""
|
||||
# First ensure provider costs are registered
|
||||
register_provider_costs_for_block(block_class)
|
||||
|
||||
# Return all costs for the block
|
||||
return BLOCK_COSTS.get(block_class, [])
|
||||
|
||||
|
||||
def cost(*costs: BlockCost):
|
||||
"""
|
||||
Decorator to set custom costs for a block.
|
||||
|
||||
This decorator allows blocks to define their own costs, which will override
|
||||
any provider base costs. Multiple costs can be specified with different
|
||||
filters for different pricing tiers (e.g., different models).
|
||||
|
||||
Example:
|
||||
@cost(
|
||||
BlockCost(cost_type=BlockCostType.RUN, cost_amount=10),
|
||||
BlockCost(
|
||||
cost_type=BlockCostType.RUN,
|
||||
cost_amount=20,
|
||||
cost_filter={"model": "premium"}
|
||||
)
|
||||
)
|
||||
class MyBlock(Block):
|
||||
...
|
||||
|
||||
Args:
|
||||
*costs: Variable number of BlockCost objects
|
||||
"""
|
||||
|
||||
def decorator(block_class: Type[Block]) -> Type[Block]:
|
||||
# Register the costs for this block
|
||||
if costs:
|
||||
BLOCK_COSTS[block_class] = list(costs)
|
||||
logger.info(
|
||||
f"Registered {len(costs)} custom costs for block {block_class.__name__}"
|
||||
)
|
||||
return block_class
|
||||
|
||||
return decorator
|
||||
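The precedence rule falls out of the early return above: a block decorated with @cost is already present in BLOCK_COSTS, so register_provider_costs_for_block skips it and provider base costs never apply. A small sketch of that interaction (MyBlock and the amount are illustrative):

from backend.data.block import Block
from backend.data.cost import BlockCost, BlockCostType
from backend.sdk.cost_integration import cost, get_block_costs


@cost(BlockCost(cost_type=BlockCostType.RUN, cost_amount=25))
class MyBlock(Block):  # hypothetical block whose provider also has base costs
    ...


# get_block_costs() first calls register_provider_costs_for_block(), which
# returns early because MyBlock is already in BLOCK_COSTS, so only the
# explicit 25-credit cost is reported.
assert [c.cost_amount for c in get_block_costs(MyBlock)] == [25]
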
114
autogpt_platform/backend/backend/sdk/provider.py
Normal file
@@ -0,0 +1,114 @@
"""
|
||||
Provider configuration class that holds all provider-related settings.
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, List, Optional, Set, Type
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.cost import BlockCost
|
||||
from backend.data.model import Credentials, CredentialsField, CredentialsMetaInput
|
||||
from backend.integrations.oauth.base import BaseOAuthHandler
|
||||
from backend.integrations.webhooks._base import BaseWebhooksManager
|
||||
|
||||
|
||||
class OAuthConfig(BaseModel):
|
||||
"""Configuration for OAuth authentication."""
|
||||
|
||||
oauth_handler: Type[BaseOAuthHandler]
|
||||
scopes: Optional[List[str]] = None
|
||||
client_id_env_var: Optional[str] = None
|
||||
client_secret_env_var: Optional[str] = None
|
||||
|
||||
|
||||
class Provider:
|
||||
"""A configured provider that blocks can use.
|
||||
|
||||
A Provider represents a service or platform that blocks can integrate with, like Linear, OpenAI, etc.
|
||||
It contains configuration for:
|
||||
- Authentication (OAuth, API keys)
|
||||
- Default credentials
|
||||
- Base costs for using the provider
|
||||
- Webhook handling
|
||||
- Error handling
|
||||
- API client factory
|
||||
|
||||
Blocks use Provider instances to handle authentication, make API calls, and manage service-specific logic.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
oauth_config: Optional[OAuthConfig] = None,
|
||||
webhook_manager: Optional[Type[BaseWebhooksManager]] = None,
|
||||
default_credentials: Optional[List[Credentials]] = None,
|
||||
base_costs: Optional[List[BlockCost]] = None,
|
||||
supported_auth_types: Optional[Set[str]] = None,
|
||||
api_client_factory: Optional[Callable] = None,
|
||||
error_handler: Optional[Callable[[Exception], str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.name = name
|
||||
self.oauth_config = oauth_config
|
||||
self.webhook_manager = webhook_manager
|
||||
self.default_credentials = default_credentials or []
|
||||
self.base_costs = base_costs or []
|
||||
self.supported_auth_types = supported_auth_types or set()
|
||||
self._api_client_factory = api_client_factory
|
||||
self._error_handler = error_handler
|
||||
|
||||
# Store any additional configuration
|
||||
self._extra_config = kwargs
|
||||
|
||||
def credentials_field(self, **kwargs) -> CredentialsMetaInput:
|
||||
"""Return a CredentialsField configured for this provider."""
|
||||
# Extract known CredentialsField parameters
|
||||
title = kwargs.pop("title", None)
|
||||
description = kwargs.pop("description", f"{self.name.title()} credentials")
|
||||
required_scopes = kwargs.pop("required_scopes", set())
|
||||
discriminator = kwargs.pop("discriminator", None)
|
||||
discriminator_mapping = kwargs.pop("discriminator_mapping", None)
|
||||
discriminator_values = kwargs.pop("discriminator_values", None)
|
||||
|
||||
# Create json_schema_extra with provider information
|
||||
json_schema_extra = {
|
||||
"credentials_provider": [self.name],
|
||||
"credentials_types": (
|
||||
list(self.supported_auth_types)
|
||||
if self.supported_auth_types
|
||||
else ["api_key"]
|
||||
),
|
||||
}
|
||||
|
||||
# Merge any existing json_schema_extra
|
||||
if "json_schema_extra" in kwargs:
|
||||
json_schema_extra.update(kwargs.pop("json_schema_extra"))
|
||||
|
||||
# Add json_schema_extra to kwargs
|
||||
kwargs["json_schema_extra"] = json_schema_extra
|
||||
|
||||
return CredentialsField(
|
||||
required_scopes=required_scopes,
|
||||
discriminator=discriminator,
|
||||
discriminator_mapping=discriminator_mapping,
|
||||
discriminator_values=discriminator_values,
|
||||
title=title,
|
||||
description=description,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_api(self, credentials: Credentials) -> Any:
|
||||
"""Get API client instance for the given credentials."""
|
||||
if self._api_client_factory:
|
||||
return self._api_client_factory(credentials)
|
||||
raise NotImplementedError(f"No API client factory registered for {self.name}")
|
||||
|
||||
def handle_error(self, error: Exception) -> str:
|
||||
"""Handle provider-specific errors."""
|
||||
if self._error_handler:
|
||||
return self._error_handler(error)
|
||||
return str(error)
|
||||
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
"""Get additional configuration value."""
|
||||
return self._extra_config.get(key, default)
|
||||
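In practice credentials_field is used inside a block's input schema so the field advertises the right provider and auth types in its JSON schema. A hedged sketch, reusing the hypothetical my_service provider from the builder example above:

from backend.sdk import BlockSchema, CredentialsMetaInput, SchemaField


class Input(BlockSchema):  # illustrative input schema for a block
    text: str = SchemaField(description="Text to process")
    # Emits json_schema_extra with credentials_provider=["my_service"]
    # and the auth types the provider was built with (here: api_key).
    credentials: CredentialsMetaInput = my_service.credentials_field(
        description="My Service API key"
    )
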
220
autogpt_platform/backend/backend/sdk/registry.py
Normal file
@@ -0,0 +1,220 @@
"""
|
||||
Auto-registration system for blocks, providers, and their configurations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
from backend.blocks.basic import Block
|
||||
from backend.data.model import APIKeyCredentials, Credentials
|
||||
from backend.integrations.oauth.base import BaseOAuthHandler
|
||||
from backend.integrations.webhooks._base import BaseWebhooksManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.sdk.provider import Provider
|
||||
|
||||
|
||||
class SDKOAuthCredentials(BaseModel):
|
||||
"""OAuth credentials configuration for SDK providers."""
|
||||
|
||||
use_secrets: bool = False
|
||||
client_id_env_var: Optional[str] = None
|
||||
client_secret_env_var: Optional[str] = None
|
||||
|
||||
|
||||
class BlockConfiguration:
|
||||
"""Configuration associated with a block."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider: str,
|
||||
costs: List[Any],
|
||||
default_credentials: List[Credentials],
|
||||
webhook_manager: Optional[Type[BaseWebhooksManager]] = None,
|
||||
oauth_handler: Optional[Type[BaseOAuthHandler]] = None,
|
||||
):
|
||||
self.provider = provider
|
||||
self.costs = costs
|
||||
self.default_credentials = default_credentials
|
||||
self.webhook_manager = webhook_manager
|
||||
self.oauth_handler = oauth_handler
|
||||
|
||||
|
||||
class AutoRegistry:
|
||||
"""Central registry for all block-related configurations."""
|
||||
|
||||
_lock = threading.Lock()
|
||||
_providers: Dict[str, "Provider"] = {}
|
||||
_default_credentials: List[Credentials] = []
|
||||
_oauth_handlers: Dict[str, Type[BaseOAuthHandler]] = {}
|
||||
_oauth_credentials: Dict[str, SDKOAuthCredentials] = {}
|
||||
_webhook_managers: Dict[str, Type[BaseWebhooksManager]] = {}
|
||||
_block_configurations: Dict[Type[Block], BlockConfiguration] = {}
|
||||
_api_key_mappings: Dict[str, str] = {} # provider -> env_var_name
|
||||
|
||||
@classmethod
|
||||
def register_provider(cls, provider: "Provider") -> None:
|
||||
"""Auto-register provider and all its configurations."""
|
||||
with cls._lock:
|
||||
cls._providers[provider.name] = provider
|
||||
|
||||
# Register OAuth handler if provided
|
||||
if provider.oauth_config:
|
||||
# Dynamically set PROVIDER_NAME if not already set
|
||||
if (
|
||||
not hasattr(provider.oauth_config.oauth_handler, "PROVIDER_NAME")
|
||||
or provider.oauth_config.oauth_handler.PROVIDER_NAME is None
|
||||
):
|
||||
# Import ProviderName to create dynamic enum value
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
# This works because ProviderName has _missing_ method
|
||||
provider.oauth_config.oauth_handler.PROVIDER_NAME = ProviderName(
|
||||
provider.name
|
||||
)
|
||||
cls._oauth_handlers[provider.name] = provider.oauth_config.oauth_handler
|
||||
|
||||
# Register OAuth credentials configuration
|
||||
oauth_creds = SDKOAuthCredentials(
|
||||
use_secrets=False, # SDK providers use custom env vars
|
||||
client_id_env_var=provider.oauth_config.client_id_env_var,
|
||||
client_secret_env_var=provider.oauth_config.client_secret_env_var,
|
||||
)
|
||||
cls._oauth_credentials[provider.name] = oauth_creds
|
||||
|
||||
# Register webhook manager if provided
|
||||
if provider.webhook_manager:
|
||||
# Dynamically set PROVIDER_NAME if not already set
|
||||
if (
|
||||
not hasattr(provider.webhook_manager, "PROVIDER_NAME")
|
||||
or provider.webhook_manager.PROVIDER_NAME is None
|
||||
):
|
||||
# Import ProviderName to create dynamic enum value
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
# This works because ProviderName has _missing_ method
|
||||
provider.webhook_manager.PROVIDER_NAME = ProviderName(provider.name)
|
||||
cls._webhook_managers[provider.name] = provider.webhook_manager
|
||||
|
||||
# Register default credentials
|
||||
cls._default_credentials.extend(provider.default_credentials)
|
||||
|
||||
@classmethod
|
||||
def register_api_key(cls, provider: str, env_var_name: str) -> None:
|
||||
"""Register an environment variable as an API key for a provider."""
|
||||
with cls._lock:
|
||||
cls._api_key_mappings[provider] = env_var_name
|
||||
|
||||
# Dynamically check if the env var exists and create credential
|
||||
import os
|
||||
|
||||
api_key = os.getenv(env_var_name)
|
||||
if api_key:
|
||||
credential = APIKeyCredentials(
|
||||
id=f"{provider}-default",
|
||||
provider=provider,
|
||||
api_key=SecretStr(api_key),
|
||||
title=f"Default {provider} credentials",
|
||||
)
|
||||
# Check if credential already exists to avoid duplicates
|
||||
if not any(c.id == credential.id for c in cls._default_credentials):
|
||||
cls._default_credentials.append(credential)
|
||||
|
||||
@classmethod
|
||||
def get_all_credentials(cls) -> List[Credentials]:
|
||||
"""Replace hardcoded get_all_creds() in credentials_store.py."""
|
||||
with cls._lock:
|
||||
return cls._default_credentials.copy()
|
||||
|
||||
@classmethod
|
||||
def get_oauth_handlers(cls) -> Dict[str, Type[BaseOAuthHandler]]:
|
||||
"""Replace HANDLERS_BY_NAME in oauth/__init__.py."""
|
||||
with cls._lock:
|
||||
return cls._oauth_handlers.copy()
|
||||
|
||||
@classmethod
|
||||
def get_oauth_credentials(cls) -> Dict[str, SDKOAuthCredentials]:
|
||||
"""Get OAuth credentials configuration for SDK providers."""
|
||||
with cls._lock:
|
||||
return cls._oauth_credentials.copy()
|
||||
|
||||
@classmethod
|
||||
def get_webhook_managers(cls) -> Dict[str, Type[BaseWebhooksManager]]:
|
||||
"""Replace load_webhook_managers() in webhooks/__init__.py."""
|
||||
with cls._lock:
|
||||
return cls._webhook_managers.copy()
|
||||
|
||||
@classmethod
|
||||
def register_block_configuration(
|
||||
cls, block_class: Type[Block], config: BlockConfiguration
|
||||
) -> None:
|
||||
"""Register configuration for a specific block class."""
|
||||
with cls._lock:
|
||||
cls._block_configurations[block_class] = config
|
||||
|
||||
@classmethod
|
||||
def get_provider(cls, name: str) -> Optional["Provider"]:
|
||||
"""Get a registered provider by name."""
|
||||
with cls._lock:
|
||||
return cls._providers.get(name)
|
||||
|
||||
@classmethod
|
||||
def get_all_provider_names(cls) -> List[str]:
|
||||
"""Get all registered provider names."""
|
||||
with cls._lock:
|
||||
return list(cls._providers.keys())
|
||||
|
||||
@classmethod
|
||||
def clear(cls) -> None:
|
||||
"""Clear all registrations (useful for testing)."""
|
||||
with cls._lock:
|
||||
cls._providers.clear()
|
||||
cls._default_credentials.clear()
|
||||
cls._oauth_handlers.clear()
|
||||
cls._webhook_managers.clear()
|
||||
cls._block_configurations.clear()
|
||||
cls._api_key_mappings.clear()
|
||||
|
||||
@classmethod
|
||||
def patch_integrations(cls) -> None:
|
||||
"""Patch existing integration points to use AutoRegistry."""
|
||||
# OAuth handlers are handled by SDKAwareHandlersDict in oauth/__init__.py
|
||||
# No patching needed for OAuth handlers
|
||||
|
||||
# Patch webhook managers
|
||||
try:
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
# Get the module from sys.modules to respect mocking
|
||||
if "backend.integrations.webhooks" in sys.modules:
|
||||
webhooks: Any = sys.modules["backend.integrations.webhooks"]
|
||||
else:
|
||||
import backend.integrations.webhooks
|
||||
|
||||
webhooks: Any = backend.integrations.webhooks
|
||||
|
||||
if hasattr(webhooks, "load_webhook_managers"):
|
||||
original_load = webhooks.load_webhook_managers
|
||||
|
||||
def patched_load():
|
||||
# Get original managers
|
||||
managers = original_load()
|
||||
# Add SDK-registered managers
|
||||
sdk_managers = cls.get_webhook_managers()
|
||||
if isinstance(sdk_managers, dict):
|
||||
# Import ProviderName for conversion
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
# Convert string keys to ProviderName for consistency
|
||||
for provider_str, manager in sdk_managers.items():
|
||||
provider_name = ProviderName(provider_str)
|
||||
managers[provider_name] = manager
|
||||
return managers
|
||||
|
||||
webhooks.load_webhook_managers = patched_load
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to patch webhook managers: {e}")
|
||||
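Because the registry is class-level state guarded by a lock, tests can exercise it directly and reset it between cases. A brief sketch (the "acme" provider name is invented):

from backend.sdk.builder import ProviderBuilder
from backend.sdk.registry import AutoRegistry

AutoRegistry.clear()  # start from a clean slate, e.g. in a test fixture
ProviderBuilder("acme").build()  # hypothetical provider

assert "acme" in AutoRegistry.get_all_provider_names()
assert AutoRegistry.get_provider("acme") is not None
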
@@ -1,6 +1,6 @@
import logging
from collections import defaultdict
from typing import Annotated, Any, Dict, List, Optional, Sequence
from typing import Annotated, Any, Optional, Sequence

from fastapi import APIRouter, Body, Depends, HTTPException
from prisma.enums import AgentExecutionStatus, APIKeyPermission
@@ -11,7 +11,6 @@ from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data.api_key import APIKey
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.execution import NodeExecutionResult
from backend.executor.utils import add_graph_execution
from backend.server.external.middleware import require_permission
from backend.util.settings import Settings
@@ -30,30 +29,19 @@ class NodeOutput(TypedDict):
class ExecutionNode(TypedDict):
    node_id: str
    input: Any
    output: Dict[str, Any]
    output: dict[str, Any]


class ExecutionNodeOutput(TypedDict):
    node_id: str
    outputs: List[NodeOutput]
    outputs: list[NodeOutput]


class GraphExecutionResult(TypedDict):
    execution_id: str
    status: str
    nodes: List[ExecutionNode]
    output: Optional[List[Dict[str, str]]]


def get_outputs_with_names(results: list[NodeExecutionResult]) -> list[dict[str, str]]:
    outputs = []
    for result in results:
        if "output" in result.output_data:
            output_value = result.output_data["output"][0]
            name = result.output_data.get("name", [None])[0]
            if output_value and name:
                outputs.append({name: output_value})
    return outputs
    nodes: list[ExecutionNode]
    output: Optional[list[dict[str, str]]]


@v1_router.get(
@@ -122,23 +110,34 @@ async def get_graph_execution_results(
    if not graph:
        raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")

    results = await execution_db.get_node_executions(graph_exec_id)
    last_result = results[-1] if results else None
    execution_status = (
        last_result.status if last_result else AgentExecutionStatus.INCOMPLETE
    graph_exec = await execution_db.get_graph_execution(
        user_id=api_key.user_id,
        execution_id=graph_exec_id,
        include_node_executions=True,
    )
    outputs = get_outputs_with_names(results)
    if not graph_exec:
        raise HTTPException(
            status_code=404, detail=f"Graph execution #{graph_exec_id} not found."
        )

    return GraphExecutionResult(
        execution_id=graph_exec_id,
        status=execution_status,
        status=graph_exec.status.value,
        nodes=[
            ExecutionNode(
                node_id=result.node_id,
                input=result.input_data.get("value", result.input_data),
                output={k: v for k, v in result.output_data.items()},
                node_id=node_exec.node_id,
                input=node_exec.input_data.get("value", node_exec.input_data),
                output={k: v for k, v in node_exec.output_data.items()},
            )
            for result in results
            for node_exec in graph_exec.node_executions
        ],
        output=outputs if execution_status == AgentExecutionStatus.COMPLETED else None,
        output=(
            [
                {name: value}
                for name, values in graph_exec.outputs.items()
                for value in values
            ]
            if graph_exec.status == AgentExecutionStatus.COMPLETED
            else None
        ),
    )

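After the rewrite the endpoint reads everything from one get_graph_execution call, and the top-level output is flattened from the execution's outputs mapping. An illustrative response payload under those semantics (IDs and values invented):

# Illustrative GraphExecutionResult payload:
example_response = {
    "execution_id": "00000000-0000-0000-0000-000000000000",
    "status": "COMPLETED",
    "nodes": [
        {"node_id": "node-1", "input": "hello", "output": {"result": ["HELLO"]}},
    ],
    # Flattened from graph_exec.outputs == {"greeting": ["HELLO"]}
    "output": [{"greeting": "HELLO"}],
}
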
@@ -0,0 +1,74 @@
"""
Models for integration-related data structures that need to be exposed in the OpenAPI schema.

This module provides models that will be included in the OpenAPI schema generation,
allowing frontend code generators like Orval to create corresponding TypeScript types.
"""

from pydantic import BaseModel, Field

from backend.integrations.providers import ProviderName
from backend.sdk.registry import AutoRegistry


def get_all_provider_names() -> list[str]:
    """
    Collect all provider names from both ProviderName enum and AutoRegistry.

    This function should be called at runtime to ensure we get all
    dynamically registered providers.

    Returns:
        A sorted list of unique provider names.
    """
    # Get static providers from enum
    static_providers = [member.value for member in ProviderName]

    # Get dynamic providers from registry
    dynamic_providers = AutoRegistry.get_all_provider_names()

    # Combine and deduplicate
    all_providers = list(set(static_providers + dynamic_providers))
    all_providers.sort()

    return all_providers


# Note: We don't create a static enum here because providers are registered dynamically.
# Instead, we expose provider names through API endpoints that can be fetched at runtime.


class ProviderNamesResponse(BaseModel):
    """Response containing list of all provider names."""

    providers: list[str] = Field(
        description="List of all available provider names",
        default_factory=get_all_provider_names,
    )


class ProviderConstants(BaseModel):
    """
    Model that exposes all provider names as a constant in the OpenAPI schema.
    This is designed to be converted by Orval into a TypeScript constant.
    """

    PROVIDER_NAMES: dict[str, str] = Field(
        description="All available provider names as a constant mapping",
        default_factory=lambda: {
            name.upper().replace("-", "_"): name for name in get_all_provider_names()
        },
    )

    class Config:
        schema_extra = {
            "example": {
                "PROVIDER_NAMES": {
                    "OPENAI": "openai",
                    "ANTHROPIC": "anthropic",
                    "EXA": "exa",
                    "GEM": "gem",
                    "EXAMPLE_SERVICE": "example-service",
                }
            }
        }
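The constant mapping upper-cases each name and swaps hyphens for underscores so the keys are valid TypeScript identifiers. A quick sketch of the transform (names illustrative):

names = ["openai", "example-service"]  # illustrative provider names
constants = {name.upper().replace("-", "_"): name for name in names}
# constants == {"OPENAI": "openai", "EXAMPLE_SERVICE": "example-service"}
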
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Annotated, Awaitable, Literal
|
||||
from typing import TYPE_CHECKING, Annotated, Awaitable, List, Literal
|
||||
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
@@ -22,12 +22,22 @@ from backend.data.integrations import (
|
||||
publish_webhook_event,
|
||||
wait_for_webhook_event,
|
||||
)
|
||||
from backend.data.model import Credentials, CredentialsType, OAuth2Credentials
|
||||
from backend.data.model import (
|
||||
Credentials,
|
||||
CredentialsType,
|
||||
HostScopedCredentials,
|
||||
OAuth2Credentials,
|
||||
)
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.oauth import HANDLERS_BY_NAME
|
||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
from backend.server.integrations.models import (
|
||||
ProviderConstants,
|
||||
ProviderNamesResponse,
|
||||
get_all_provider_names,
|
||||
)
|
||||
from backend.server.v2.library.db import set_preset_webhook, update_preset
|
||||
from backend.util.exceptions import NeedConfirmation, NotFoundError
|
||||
from backend.util.settings import Settings
|
||||
@@ -82,6 +92,9 @@ class CredentialsMetaResponse(BaseModel):
|
||||
title: str | None
|
||||
scopes: list[str] | None
|
||||
username: str | None
|
||||
host: str | None = Field(
|
||||
default=None, description="Host pattern for host-scoped credentials"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{provider}/callback")
|
||||
@@ -156,6 +169,9 @@ async def callback(
|
||||
title=credentials.title,
|
||||
scopes=credentials.scopes,
|
||||
username=credentials.username,
|
||||
host=(
|
||||
credentials.host if isinstance(credentials, HostScopedCredentials) else None
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -172,6 +188,7 @@ async def list_credentials(
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
||||
)
|
||||
for cred in credentials
|
||||
]
|
||||
@@ -193,6 +210,7 @@ async def list_credentials_by_provider(
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
||||
)
|
||||
for cred in credentials
|
||||
]
|
||||
@@ -459,14 +477,49 @@ async def remove_all_webhooks_for_credentials(
|
||||
def _get_provider_oauth_handler(
|
||||
req: Request, provider_name: ProviderName
|
||||
) -> "BaseOAuthHandler":
|
||||
if provider_name not in HANDLERS_BY_NAME:
|
||||
# Ensure blocks are loaded so SDK providers are available
|
||||
try:
|
||||
from backend.blocks import load_all_blocks
|
||||
|
||||
load_all_blocks() # This is cached, so it only runs once
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load blocks: {e}")
|
||||
|
||||
# Convert provider_name to string for lookup
|
||||
provider_key = (
|
||||
provider_name.value if hasattr(provider_name, "value") else str(provider_name)
|
||||
)
|
||||
|
||||
if provider_key not in HANDLERS_BY_NAME:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Provider '{provider_name.value}' does not support OAuth",
|
||||
detail=f"Provider '{provider_key}' does not support OAuth",
|
||||
)
|
||||
|
||||
# Check if this provider has custom OAuth credentials
|
||||
oauth_credentials = CREDENTIALS_BY_PROVIDER.get(provider_key)
|
||||
|
||||
if oauth_credentials and not oauth_credentials.use_secrets:
|
||||
# SDK provider with custom env vars
|
||||
import os
|
||||
|
||||
client_id = (
|
||||
os.getenv(oauth_credentials.client_id_env_var)
|
||||
if oauth_credentials.client_id_env_var
|
||||
else None
|
||||
)
|
||||
client_secret = (
|
||||
os.getenv(oauth_credentials.client_secret_env_var)
|
||||
if oauth_credentials.client_secret_env_var
|
||||
else None
|
||||
)
|
||||
else:
|
||||
# Original provider using settings.secrets
|
||||
client_id = getattr(settings.secrets, f"{provider_name.value}_client_id", None)
|
||||
client_secret = getattr(
|
||||
settings.secrets, f"{provider_name.value}_client_secret", None
|
||||
)
|
||||
|
||||
client_id = getattr(settings.secrets, f"{provider_name.value}_client_id")
|
||||
client_secret = getattr(settings.secrets, f"{provider_name.value}_client_secret")
|
||||
if not (client_id and client_secret):
|
||||
logger.error(
|
||||
f"Attempt to use unconfigured {provider_name.value} OAuth integration"
|
||||
@@ -479,14 +532,84 @@ def _get_provider_oauth_handler(
         },
     )

-    handler_class = HANDLERS_BY_NAME[provider_name]
-    frontend_base_url = (
-        settings.config.frontend_base_url
-        or settings.config.platform_base_url
-        or str(req.base_url)
-    )
+    handler_class = HANDLERS_BY_NAME[provider_key]
+    frontend_base_url = settings.config.frontend_base_url
+
+    if not frontend_base_url:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Frontend base URL is not configured",
+        )
+
     return handler_class(
         client_id=client_id,
         client_secret=client_secret,
         redirect_uri=f"{frontend_base_url}/auth/integrations/oauth_callback",
     )
+
+
+# === PROVIDER DISCOVERY ENDPOINTS ===
+
+
+@router.get("/providers", response_model=List[str])
+async def list_providers() -> List[str]:
+    """
+    Get a list of all available provider names.
+
+    Returns both statically defined providers (from ProviderName enum)
+    and dynamically registered providers (from SDK decorators).
+
+    Note: The complete list of provider names is also available as a constant
+    in the generated TypeScript client via PROVIDER_NAMES.
+    """
+    # Get all providers at runtime
+    all_providers = get_all_provider_names()
+    return all_providers
+
+
+@router.get("/providers/names", response_model=ProviderNamesResponse)
+async def get_provider_names() -> ProviderNamesResponse:
+    """
+    Get all provider names in a structured format.
+
+    This endpoint is specifically designed to expose the provider names
+    in the OpenAPI schema so that code generators like Orval can create
+    appropriate TypeScript constants.
+    """
+    return ProviderNamesResponse()
+
+
+@router.get("/providers/constants", response_model=ProviderConstants)
+async def get_provider_constants() -> ProviderConstants:
+    """
+    Get provider names as constants.
+
+    This endpoint returns a model with provider names as constants,
+    specifically designed for OpenAPI code generation tools to create
+    TypeScript constants.
+    """
+    return ProviderConstants()
+
+
+class ProviderEnumResponse(BaseModel):
+    """Response containing a provider from the enum."""
+
+    provider: str = Field(
+        description="A provider name from the complete list of providers"
+    )
+
+
+@router.get("/providers/enum-example", response_model=ProviderEnumResponse)
+async def get_provider_enum_example() -> ProviderEnumResponse:
+    """
+    Example endpoint that uses the CompleteProviderNames enum.
+
+    This endpoint exists to ensure that the CompleteProviderNames enum is included
+    in the OpenAPI schema, which will cause Orval to generate it as a
+    TypeScript enum/constant.
+    """
+    # Return the first provider as an example
+    all_providers = get_all_provider_names()
+    return ProviderEnumResponse(
+        provider=all_providers[0] if all_providers else "openai"
+    )
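The discovery endpoints above exist mainly so code generators see the provider names in the OpenAPI schema, but they can also be queried directly. A quick sketch with httpx; the base URL and route prefix are assumptions about a local deployment, not values taken from this diff:

import httpx

# Hypothetical local backend URL and router prefix; adjust to your deployment.
BASE = "http://localhost:8006/api/integrations"

providers = httpx.get(f"{BASE}/providers").json()
print(providers)  # e.g. ["github", "google", "notion", ...]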
@@ -62,6 +62,10 @@ def launch_darkly_context():
 async def lifespan_context(app: fastapi.FastAPI):
     await backend.data.db.connect()
     await backend.data.block.initialize_blocks()
+
+    # SDK auto-registration is now handled by AutoRegistry.patch_integrations()
+    # which is called when the SDK module is imported
+
     await backend.data.user.migrate_and_encrypt_user_integrations()
     await backend.data.graph.fix_llm_provider_credentials()
     await backend.data.graph.migrate_llm_models(LlmModel.GPT4O)
@@ -215,13 +219,22 @@ async def health():

 class AgentServer(backend.util.service.AppProcess):
     def run(self):
-        server_app = starlette.middleware.cors.CORSMiddleware(
-            app=app,
-            allow_origins=settings.config.backend_cors_allow_origins,
-            allow_credentials=True,
-            allow_methods=["*"],  # Allows all methods
-            allow_headers=["*"],  # Allows all headers
-        )
+        cors_kwargs = {
+            "app": app,
+            "allow_origins": settings.config.backend_cors_allow_origins,
+            "allow_credentials": True,
+            "allow_methods": ["*"],
+            "allow_headers": ["*"],
+        }
+
+        # Use regex pattern if configured (for dynamic domains like Vercel previews)
+        # Only enable in non-production environments for security
+        if (settings.config.backend_cors_allow_origin_regex and
+                settings.config.app_env.value != "prod"):
+            cors_kwargs["allow_origin_regex"] = settings.config.backend_cors_allow_origin_regex
+            logger.info(f"CORS regex enabled for {settings.config.app_env.value}: {settings.config.backend_cors_allow_origin_regex}")
+
+        server_app = starlette.middleware.cors.CORSMiddleware(**cors_kwargs)
         uvicorn.run(
             server_app,
             host=backend.util.settings.Config().agent_api_host,
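Because allow_origin_regex is matched against the entire Origin header (Starlette uses fullmatch semantics), the pattern needs no explicit anchors and partial matches are rejected. A small check with a hypothetical Vercel preview pattern; the real value comes from backend_cors_allow_origin_regex and is deployment-specific:

import re

# Hypothetical pattern for Vercel preview deployments of a team's project.
pattern = re.compile(r"https://autogpt-[a-z0-9-]+-myteam\.vercel\.app")

# Full-string matching: a preview origin passes, an embedded lookalike does not.
assert pattern.fullmatch("https://autogpt-git-fix-vercel-myteam.vercel.app")
assert not pattern.fullmatch("https://evil.example/?https://autogpt-x-myteam.vercel.app")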
@@ -357,11 +370,22 @@ class AgentServer(backend.util.service.AppProcess):
         provider: ProviderName,
         credentials: Credentials,
     ) -> Credentials:
-        from backend.server.integrations.router import create_credentials
-
-        return await create_credentials(
-            user_id=user_id, provider=provider, credentials=credentials
+        from backend.server.integrations.router import (
+            create_credentials,
+            get_credential,
         )
+
+        try:
+            return await create_credentials(
+                user_id=user_id, provider=provider, credentials=credentials
+            )
+        except Exception as e:
+            logger.error(f"Error creating credentials: {e}")
+            return await get_credential(
+                provider=provider,
+                user_id=user_id,
+                cred_id=credentials.id,
+            )

     def set_test_dependency_overrides(self, overrides: dict):
         app.dependency_overrides.update(overrides)
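The fallback above treats any create failure as "credentials already exist" and fetches the stored record instead, which keeps the call idempotent at the cost of masking unrelated errors. A self-contained sketch of that create-or-fetch pattern, with an in-memory store standing in for the real credentials storage:

# Catching a narrower "already exists" error would avoid hiding real failures.
class Store:
    def __init__(self):
        self._by_id: dict[str, dict] = {}

    def create(self, cred: dict) -> dict:
        if cred["id"] in self._by_id:
            raise ValueError("credential already exists")
        self._by_id[cred["id"]] = cred
        return cred

    def get(self, cred_id: str) -> dict:
        return self._by_id[cred_id]

def create_or_fetch(store: Store, cred: dict) -> dict:
    try:
        return store.create(cred)
    except ValueError:
        # Fall back to the existing record, keeping the call idempotent.
        return store.get(cred["id"])

s = Store()
c = {"id": "abc", "provider": "github"}
assert create_or_fetch(s, c) is c           # first call creates
assert create_or_fetch(s, dict(c))["id"] == "abc"  # retry fetches existing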