Compare commits

..

26 Commits

Author SHA1 Message Date
Lluis Agusti
8958357343 chore: changes ... 2025-06-22 19:58:59 +04:00
Lluis Agusti
0b1b29a9bb chore: changes... 2025-06-20 19:14:27 +04:00
Lluis Agusti
d5dfc40263 chore: wip 2025-06-20 19:11:09 +04:00
Lluis Agusti
d9f9f80346 chore: don't set cookie settings on tests... 2025-06-20 17:00:07 +04:00
Lluis Agusti
a47e1916fb chore: bypass cookie settings on tests 2025-06-20 16:38:57 +04:00
Ubbe
49b22576b5 Merge branch 'dev' into fix/cookie-config 2025-06-19 19:28:07 +04:00
Nicholas Tindle
db3d62eaa0 fix(frontend): specify path for lax 2025-06-18 15:01:56 -05:00
Nicholas Tindle
46da6a1c5f Merge branch 'dev' into fix/cookie-config 2025-06-18 11:42:39 -05:00
Nicholas Tindle
f1471377c3 Merge branch 'dev' into fix/cookie-config 2025-06-17 14:22:09 -05:00
Lluis Agusti
13e5f6bf8e Merge 'dev' into 'fix/cookie-config' 2025-06-17 18:45:41 +04:00
Lluis Agusti
add32b8449 chore: working cookie settings 2025-06-17 18:43:41 +04:00
Nicholas Tindle
2f11dade70 Merge branch 'dev' into fix/cookie-config 2025-06-17 09:11:00 -05:00
Nicholas Tindle
8f1ebfc696 Merge branch 'dev' into fix/cookie-config 2025-06-16 10:37:11 -05:00
Nicholas Tindle
fc975e9e17 Merge branch 'fix/untrusted-origins' into fix/cookie-config 2025-06-13 15:15:45 -05:00
Nicholas Tindle
7985da3e8e fix: correct header + don't trust anyone! 2025-06-13 14:15:44 -05:00
Nicholas Tindle
34184f7cc0 fix: don't let other poeple look at our cookies 2025-06-13 13:49:44 -05:00
Nicholas Tindle
ade66f3d27 fix: lockfile 2025-06-13 13:15:53 -05:00
Nicholas Tindle
9b221ff931 fix: remove logging of secret things 2025-06-13 13:06:35 -05:00
Nicholas Tindle
0955cfb869 fix: formatting + crypto comparisons 2025-06-13 11:49:32 -05:00
Nicholas Tindle
bf26b8f14a Merge branch 'dev' into fix/untrusted-origins 2025-06-13 11:02:27 -05:00
Swifty
82f6687646 Merge branch 'dev' into fix/untrusted-origins 2025-06-11 11:07:56 +02:00
Nicholas Tindle
10efb1772e fix(backend): don't trust external orgins 2025-06-06 15:10:25 -05:00
Nicholas Tindle
692c6defce fix: prevent invalid json uploads 2025-06-05 16:24:22 -05:00
Nicholas Tindle
08c56a337b fix: DoS attack prevention 2025-06-05 16:23:20 -05:00
Nicholas Tindle
41ebd5fe5d fix: don't allow open redirects 2025-06-05 16:13:18 -05:00
Nicholas Tindle
e8657ed711 feat: use expected trusted sources for each 2025-06-05 16:00:01 -05:00
617 changed files with 9566 additions and 46876 deletions

View File

@@ -50,23 +50,6 @@ jobs:
env:
RABBITMQ_DEFAULT_USER: ${{ env.RABBITMQ_DEFAULT_USER }}
RABBITMQ_DEFAULT_PASS: ${{ env.RABBITMQ_DEFAULT_PASS }}
clamav:
image: clamav/clamav-debian:latest
ports:
- 3310:3310
env:
CLAMAV_NO_FRESHCLAMD: false
CLAMD_CONF_StreamMaxLength: 50M
CLAMD_CONF_MaxFileSize: 100M
CLAMD_CONF_MaxScanSize: 100M
CLAMD_CONF_MaxThreads: 4
CLAMD_CONF_ReadTimeout: 300
options: >-
--health-cmd "clamdscan --version || exit 1"
--health-interval 30s
--health-timeout 10s
--health-retries 5
--health-start-period 180s
steps:
- name: Checkout repository
@@ -148,35 +131,6 @@ jobs:
# outputs:
# DB_URL, API_URL, GRAPHQL_URL, ANON_KEY, SERVICE_ROLE_KEY, JWT_SECRET
- name: Wait for ClamAV to be ready
run: |
echo "Waiting for ClamAV daemon to start..."
max_attempts=60
attempt=0
until nc -z localhost 3310 || [ $attempt -eq $max_attempts ]; do
echo "ClamAV is unavailable - sleeping (attempt $((attempt+1))/$max_attempts)"
sleep 5
attempt=$((attempt+1))
done
if [ $attempt -eq $max_attempts ]; then
echo "ClamAV failed to start after $((max_attempts*5)) seconds"
echo "Checking ClamAV service logs..."
docker logs $(docker ps -q --filter "ancestor=clamav/clamav-debian:latest") 2>&1 | tail -50 || echo "No ClamAV container found"
exit 1
fi
echo "ClamAV is ready!"
# Verify ClamAV is responsive
echo "Testing ClamAV connection..."
timeout 10 bash -c 'echo "PING" | nc localhost 3310' || {
echo "ClamAV is not responding to PING"
docker logs $(docker ps -q --filter "ancestor=clamav/clamav-debian:latest") 2>&1 | tail -50 || echo "No ClamAV container found"
exit 1
}
- name: Run Database Migrations
run: poetry run prisma migrate dev --name updates
env:
@@ -190,9 +144,9 @@ jobs:
- name: Run pytest with coverage
run: |
if [[ "${{ runner.debug }}" == "1" ]]; then
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG test
else
poetry run pytest -s -vv
poetry run pytest -s -vv test
fi
if: success() || (failure() && steps.lint.outcome == 'failure')
env:
@@ -205,7 +159,6 @@ jobs:
REDIS_HOST: "localhost"
REDIS_PORT: "6379"
REDIS_PASSWORD: "testpassword"
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
env:
CI: true

View File

@@ -18,45 +18,11 @@ defaults:
working-directory: autogpt_platform/frontend
jobs:
setup:
runs-on: ubuntu-latest
outputs:
cache-key: ${{ steps.cache-key.outputs.key }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "21"
- name: Enable corepack
run: corepack enable
- name: Generate cache key
id: cache-key
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('**/pnpm-lock.yaml') }}" >> $GITHUB_OUTPUT
- name: Cache dependencies
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ steps.cache-key.outputs.key }}
restore-keys: |
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
lint:
runs-on: ubuntu-latest
needs: setup
steps:
- name: Checkout repository
uses: actions/checkout@v4
- uses: actions/checkout@v4
- name: Set up Node.js
uses: actions/setup-node@v4
@@ -66,14 +32,6 @@ jobs:
- name: Enable corepack
run: corepack enable
- name: Restore dependencies cache
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
restore-keys: |
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
@@ -82,11 +40,9 @@ jobs:
type-check:
runs-on: ubuntu-latest
needs: setup
steps:
- name: Checkout repository
uses: actions/checkout@v4
- uses: actions/checkout@v4
- name: Set up Node.js
uses: actions/setup-node@v4
@@ -96,62 +52,14 @@ jobs:
- name: Enable corepack
run: corepack enable
- name: Restore dependencies cache
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
restore-keys: |
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Run tsc check
run: pnpm type-check
chromatic:
runs-on: ubuntu-latest
needs: setup
# Only run on dev branch pushes or PRs targeting dev
if: github.ref == 'refs/heads/dev' || github.base_ref == 'dev'
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "21"
- name: Enable corepack
run: corepack enable
- name: Restore dependencies cache
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
restore-keys: |
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Run Chromatic
uses: chromaui/action@latest
with:
projectToken: chpt_9e7c1a76478c9c8
onlyChanged: true
workingDir: autogpt_platform/frontend
token: ${{ secrets.GITHUB_TOKEN }}
test:
runs-on: ubuntu-latest
needs: setup
strategy:
fail-fast: false
matrix:
@@ -189,14 +97,6 @@ jobs:
run: |
docker compose -f ../docker-compose.yml up -d
- name: Restore dependencies cache
uses: actions/cache@v4
with:
path: ~/.pnpm-store
key: ${{ needs.setup.outputs.cache-key }}
restore-keys: |
${{ runner.os }}-pnpm-
- name: Install dependencies
run: pnpm install --frozen-lockfile
@@ -212,8 +112,6 @@ jobs:
- name: Run Playwright tests
run: pnpm test:no-build --project=${{ matrix.browser }}
env:
BROWSER_TYPE: ${{ matrix.browser }}
- name: Print Final Docker Compose logs
if: always()

2
.gitignore vendored
View File

@@ -165,7 +165,7 @@ package-lock.json
# Allow for locally private items
# private
pri*
pri*
# ignore
ig*
.github_access_token

View File

@@ -19,7 +19,7 @@ cd backend && poetry install
# Run database migrations
poetry run prisma migrate dev
# Start all services (database, redis, rabbitmq, clamav)
# Start all services (database, redis, rabbitmq)
docker compose up -d
# Run the backend server
@@ -32,7 +32,6 @@ poetry run test
poetry run pytest path/to/test_file.py::test_function_name
# Lint and format
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
poetry run format # Black + isort
poetry run lint # ruff
```
@@ -78,7 +77,6 @@ npm run type-check
- **Queue System**: RabbitMQ for async task processing
- **Execution Engine**: Separate executor service processes agent workflows
- **Authentication**: JWT-based with Supabase integration
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
### Frontend Architecture
- **Framework**: Next.js App Router with React Server Components
@@ -92,7 +90,6 @@ npm run type-check
2. **Blocks**: Reusable components in `/backend/blocks/` that perform specific tasks
3. **Integrations**: OAuth and API connections stored per user
4. **Store**: Marketplace for sharing agent templates
5. **Virus Scanning**: ClamAV integration for file upload security
### Testing Approach
- Backend uses pytest with snapshot testing for API responses
@@ -121,7 +118,6 @@ Key models (defined in `/backend/schema.prisma`):
3. Define input/output schemas
4. Implement `run` method
5. Register in block registry
6. Generate the block uuid using `uuid.uuid4()`
**Modifying the API:**
1. Update route in `/backend/backend/server/routers/`
@@ -133,15 +129,4 @@ Key models (defined in `/backend/schema.prisma`):
1. Components go in `/frontend/src/components/`
2. Use existing UI components from `/frontend/src/components/ui/`
3. Add Storybook stories for new components
4. Test with Playwright if user-facing
### Security Implementation
**Cache Protection Middleware:**
- Located in `/backend/backend/server/middleware/security.py`
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses an allow list approach - only explicitly permitted paths can be cached
- Cacheable paths include: static assets (`/static/*`, `/_next/static/*`), health checks, public store pages, documentation
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
- Applied to both main API server and external API applications
4. Test with Playwright if user-facing

View File

@@ -62,12 +62,6 @@ To run the AutoGPT Platform, follow these steps:
pnpm i
```
Generate the API client (this step is required before running the frontend):
```
pnpm generate:api-client
```
Then start the frontend application in development mode:
```
@@ -170,27 +164,3 @@ To persist data for PostgreSQL and Redis, you can modify the `docker-compose.yml
3. Save the file and run `docker compose up -d` to apply the changes.
This configuration will create named volumes for PostgreSQL and Redis, ensuring that your data persists across container restarts.
### API Client Generation
The platform includes scripts for generating and managing the API client:
- `pnpm fetch:openapi`: Fetches the OpenAPI specification from the backend service (requires backend to be running on port 8006)
- `pnpm generate:api-client`: Generates the TypeScript API client from the OpenAPI specification using Orval
- `pnpm generate:api-all`: Runs both fetch and generate commands in sequence
#### Manual API Client Updates
If you need to update the API client after making changes to the backend API:
1. Ensure the backend services are running:
```
docker compose up -d
```
2. Generate the updated API client:
```
pnpm generate:api-all
```
This will fetch the latest OpenAPI specification and regenerate the TypeScript client code.

View File

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand.
[[package]]
name = "aiohappyeyeballs"
@@ -177,7 +177,7 @@ files = [
{file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"},
{file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"},
]
markers = {main = "python_version < \"3.11\"", dev = "python_full_version < \"3.11.3\""}
markers = {main = "python_version == \"3.10\"", dev = "python_full_version < \"3.11.3\""}
[[package]]
name = "attrs"
@@ -390,7 +390,7 @@ description = "Backport of PEP 654 (exception groups)"
optional = false
python-versions = ">=3.7"
groups = ["main"]
markers = "python_version < \"3.11\""
markers = "python_version == \"3.10\""
files = [
{file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"},
{file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"},
@@ -1667,30 +1667,30 @@ pyasn1 = ">=0.1.3"
[[package]]
name = "ruff"
version = "0.12.2"
version = "0.11.10"
description = "An extremely fast Python linter and code formatter, written in Rust."
optional = false
python-versions = ">=3.7"
groups = ["dev"]
files = [
{file = "ruff-0.12.2-py3-none-linux_armv6l.whl", hash = "sha256:093ea2b221df1d2b8e7ad92fc6ffdca40a2cb10d8564477a987b44fd4008a7be"},
{file = "ruff-0.12.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:09e4cf27cc10f96b1708100fa851e0daf21767e9709e1649175355280e0d950e"},
{file = "ruff-0.12.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:8ae64755b22f4ff85e9c52d1f82644abd0b6b6b6deedceb74bd71f35c24044cc"},
{file = "ruff-0.12.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3eb3a6b2db4d6e2c77e682f0b988d4d61aff06860158fdb413118ca133d57922"},
{file = "ruff-0.12.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:73448de992d05517170fc37169cbca857dfeaeaa8c2b9be494d7bcb0d36c8f4b"},
{file = "ruff-0.12.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3b8b94317cbc2ae4a2771af641739f933934b03555e51515e6e021c64441532d"},
{file = "ruff-0.12.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:45fc42c3bf1d30d2008023a0a9a0cfb06bf9835b147f11fe0679f21ae86d34b1"},
{file = "ruff-0.12.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ce48f675c394c37e958bf229fb5c1e843e20945a6d962cf3ea20b7a107dcd9f4"},
{file = "ruff-0.12.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:793d8859445ea47591272021a81391350205a4af65a9392401f418a95dfb75c9"},
{file = "ruff-0.12.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6932323db80484dda89153da3d8e58164d01d6da86857c79f1961934354992da"},
{file = "ruff-0.12.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:6aa7e623a3a11538108f61e859ebf016c4f14a7e6e4eba1980190cacb57714ce"},
{file = "ruff-0.12.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2a4a20aeed74671b2def096bdf2eac610c7d8ffcbf4fb0e627c06947a1d7078d"},
{file = "ruff-0.12.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:71a4c550195612f486c9d1f2b045a600aeba851b298c667807ae933478fcef04"},
{file = "ruff-0.12.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:4987b8f4ceadf597c927beee65a5eaf994c6e2b631df963f86d8ad1bdea99342"},
{file = "ruff-0.12.2-py3-none-win32.whl", hash = "sha256:369ffb69b70cd55b6c3fc453b9492d98aed98062db9fec828cdfd069555f5f1a"},
{file = "ruff-0.12.2-py3-none-win_amd64.whl", hash = "sha256:dca8a3b6d6dc9810ed8f328d406516bf4d660c00caeaef36eb831cf4871b0639"},
{file = "ruff-0.12.2-py3-none-win_arm64.whl", hash = "sha256:48d6c6bfb4761df68bc05ae630e24f506755e702d4fb08f08460be778c7ccb12"},
{file = "ruff-0.12.2.tar.gz", hash = "sha256:d7b4f55cd6f325cb7621244f19c873c565a08aff5a4ba9c69aa7355f3f7afd3e"},
{file = "ruff-0.11.10-py3-none-linux_armv6l.whl", hash = "sha256:859a7bfa7bc8888abbea31ef8a2b411714e6a80f0d173c2a82f9041ed6b50f58"},
{file = "ruff-0.11.10-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:968220a57e09ea5e4fd48ed1c646419961a0570727c7e069842edd018ee8afed"},
{file = "ruff-0.11.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:1067245bad978e7aa7b22f67113ecc6eb241dca0d9b696144256c3a879663bca"},
{file = "ruff-0.11.10-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4854fd09c7aed5b1590e996a81aeff0c9ff51378b084eb5a0b9cd9518e6cff2"},
{file = "ruff-0.11.10-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8b4564e9f99168c0f9195a0fd5fa5928004b33b377137f978055e40008a082c5"},
{file = "ruff-0.11.10-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5b6a9cc5b62c03cc1fea0044ed8576379dbaf751d5503d718c973d5418483641"},
{file = "ruff-0.11.10-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:607ecbb6f03e44c9e0a93aedacb17b4eb4f3563d00e8b474298a201622677947"},
{file = "ruff-0.11.10-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7b3a522fa389402cd2137df9ddefe848f727250535c70dafa840badffb56b7a4"},
{file = "ruff-0.11.10-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2f071b0deed7e9245d5820dac235cbdd4ef99d7b12ff04c330a241ad3534319f"},
{file = "ruff-0.11.10-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a60e3a0a617eafba1f2e4186d827759d65348fa53708ca547e384db28406a0b"},
{file = "ruff-0.11.10-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:da8ec977eaa4b7bf75470fb575bea2cb41a0e07c7ea9d5a0a97d13dbca697bf2"},
{file = "ruff-0.11.10-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:ddf8967e08227d1bd95cc0851ef80d2ad9c7c0c5aab1eba31db49cf0a7b99523"},
{file = "ruff-0.11.10-py3-none-musllinux_1_2_i686.whl", hash = "sha256:5a94acf798a82db188f6f36575d80609072b032105d114b0f98661e1679c9125"},
{file = "ruff-0.11.10-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:3afead355f1d16d95630df28d4ba17fb2cb9c8dfac8d21ced14984121f639bad"},
{file = "ruff-0.11.10-py3-none-win32.whl", hash = "sha256:dc061a98d32a97211af7e7f3fa1d4ca2fcf919fb96c28f39551f35fc55bdbc19"},
{file = "ruff-0.11.10-py3-none-win_amd64.whl", hash = "sha256:5cc725fbb4d25b0f185cb42df07ab6b76c4489b4bfb740a175f3a59c70e8a224"},
{file = "ruff-0.11.10-py3-none-win_arm64.whl", hash = "sha256:ef69637b35fb8b210743926778d0e45e1bffa850a7c61e428c6b971549b5f5d1"},
{file = "ruff-0.11.10.tar.gz", hash = "sha256:d522fb204b4959909ecac47da02830daec102eeb100fb50ea9554818d47a5fa6"},
]
[[package]]
@@ -1823,7 +1823,7 @@ description = "A lil' TOML parser"
optional = false
python-versions = ">=3.8"
groups = ["main"]
markers = "python_version < \"3.11\""
markers = "python_version == \"3.10\""
files = [
{file = "tomli-2.1.0-py3-none-any.whl", hash = "sha256:a5c57c3d1c56f5ccdf89f6523458f60ef716e210fc47c4cfb188c5ba473e0391"},
{file = "tomli-2.1.0.tar.gz", hash = "sha256:3f646cae2aec94e17d04973e4249548320197cfabdf130015d023de4b74d8ab8"},
@@ -2176,4 +2176,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<4.0"
content-hash = "574057127b05f28c2ae39f7b11aa0d7c52f857655e9223e23a27c9989b2ac10f"
content-hash = "d92143928a88ca3a56ac200c335910eafac938940022fed8bd0d17c95040b54f"

View File

@@ -23,7 +23,7 @@ uvicorn = "^0.34.3"
[tool.poetry.group.dev.dependencies]
redis = "^5.2.1"
ruff = "^0.12.2"
ruff = "^0.11.10"
[build-system]
requires = ["poetry-core"]

View File

@@ -20,7 +20,7 @@ def load_all_blocks() -> dict[str, type["Block"]]:
modules = [
str(f.relative_to(current_dir))[:-3].replace(os.path.sep, ".")
for f in current_dir.rglob("*.py")
if f.is_file() and f.name != "__init__.py" and not f.name.startswith("test_")
if f.is_file() and f.name != "__init__.py"
]
for module in modules:
if not re.match("^[a-z0-9_.]+$", module):

View File

@@ -2,8 +2,6 @@ import asyncio
import logging
from typing import Any, Optional
from pydantic import JsonValue
from backend.data.block import (
Block,
BlockCategory,
@@ -14,10 +12,10 @@ from backend.data.block import (
get_block,
)
from backend.data.execution import ExecutionStatus
from backend.data.model import SchemaField
from backend.util import json, retry
from backend.data.model import CredentialsMetaInput, SchemaField
from backend.util import json
_logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
class AgentExecutorBlock(Block):
@@ -30,9 +28,9 @@ class AgentExecutorBlock(Block):
input_schema: dict = SchemaField(description="Input schema for the graph")
output_schema: dict = SchemaField(description="Output schema for the graph")
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = SchemaField(
default=None, hidden=True
)
node_credentials_input_map: Optional[
dict[str, dict[str, CredentialsMetaInput]]
] = SchemaField(default=None, hidden=True)
@classmethod
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
@@ -73,46 +71,31 @@ class AgentExecutorBlock(Block):
graph_version=input_data.graph_version,
user_id=input_data.user_id,
inputs=input_data.inputs,
nodes_input_masks=input_data.nodes_input_masks,
node_credentials_input_map=input_data.node_credentials_input_map,
use_db_query=False,
)
logger = execution_utils.LogMetadata(
logger=_logger,
user_id=input_data.user_id,
graph_eid=graph_exec.id,
graph_id=input_data.graph_id,
node_eid="*",
node_id="*",
block_name=self.name,
)
try:
async for name, data in self._run(
graph_id=input_data.graph_id,
graph_version=input_data.graph_version,
graph_exec_id=graph_exec.id,
user_id=input_data.user_id,
logger=logger,
):
yield name, data
except asyncio.CancelledError:
await self._stop(
graph_exec_id=graph_exec.id,
user_id=input_data.user_id,
logger=logger,
)
logger.warning(
f"Execution of graph {input_data.graph_id}v{input_data.graph_version} was cancelled."
f"Execution of graph {input_data.graph_id} version {input_data.graph_version} was cancelled."
)
await execution_utils.stop_graph_execution(
graph_exec.id, use_db_query=False
)
except Exception as e:
await self._stop(
graph_exec_id=graph_exec.id,
user_id=input_data.user_id,
logger=logger,
)
logger.error(
f"Execution of graph {input_data.graph_id}v{input_data.graph_version} failed: {e}, execution is stopped."
f"Execution of graph {input_data.graph_id} version {input_data.graph_version} failed: {e}, stopping execution."
)
await execution_utils.stop_graph_execution(
graph_exec.id, use_db_query=False
)
raise
@@ -122,7 +105,6 @@ class AgentExecutorBlock(Block):
graph_version: int,
graph_exec_id: str,
user_id: str,
logger,
) -> BlockOutput:
from backend.data.execution import ExecutionEventType
@@ -175,25 +157,3 @@ class AgentExecutorBlock(Block):
f"Execution {log_id} produced {output_name}: {output_data}"
)
yield output_name, output_data
@retry.func_retry
async def _stop(
self,
graph_exec_id: str,
user_id: str,
logger,
) -> None:
from backend.executor import utils as execution_utils
log_id = f"Graph exec-id: {graph_exec_id}"
logger.info(f"Stopping execution of {log_id}")
try:
await execution_utils.stop_graph_execution(
graph_exec_id=graph_exec_id,
user_id=user_id,
use_db_query=False,
)
logger.info(f"Execution {log_id} stopped successfully.")
except Exception as e:
logger.error(f"Failed to stop execution {log_id}: {e}")

View File

@@ -53,7 +53,6 @@ class AudioTrack(str, Enum):
REFRESHER = ("Refresher",)
TOURIST = ("Tourist",)
TWIN_TYCHES = ("Twin Tyches",)
DONT_STOP_ME_ABSTRACT_FUTURE_BASS = ("Dont Stop Me Abstract Future Bass",)
@property
def audio_url(self):
@@ -79,7 +78,6 @@ class AudioTrack(str, Enum):
AudioTrack.REFRESHER: "https://cdn.tfrv.xyz/audio/refresher.mp3",
AudioTrack.TOURIST: "https://cdn.tfrv.xyz/audio/tourist.mp3",
AudioTrack.TWIN_TYCHES: "https://cdn.tfrv.xyz/audio/twin-tynches.mp3",
AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS: "https://cdn.revid.ai/audio/_dont-stop-me-abstract-future-bass.mp3",
}
return audio_urls[self]
@@ -107,7 +105,6 @@ class GenerationPreset(str, Enum):
MOVIE = ("Movie",)
STYLIZED_ILLUSTRATION = ("Stylized Illustration",)
MANGA = ("Manga",)
DEFAULT = ("DEFAULT",)
class Voice(str, Enum):
@@ -117,7 +114,6 @@ class Voice(str, Enum):
JESSICA = "Jessica"
CHARLOTTE = "Charlotte"
CALLUM = "Callum"
EVA = "Eva"
@property
def voice_id(self):
@@ -128,7 +124,6 @@ class Voice(str, Enum):
Voice.JESSICA: "cgSgspJ2msm6clMCkdW9",
Voice.CHARLOTTE: "XB0fDUnXU5powFXDhCwa",
Voice.CALLUM: "N2lVS1w4EtoT3dr4eOWO",
Voice.EVA: "FGY2WhTYpPnrIDTdsKH5",
}
return voice_id_map[self]
@@ -146,8 +141,6 @@ logger = logging.getLogger(__name__)
class AIShortformVideoCreatorBlock(Block):
"""Creates a shortform texttovideo clip using stock or AI imagery."""
class Input(BlockSchema):
credentials: CredentialsMetaInput[
Literal[ProviderName.REVID], Literal["api_key"]
@@ -191,58 +184,6 @@ class AIShortformVideoCreatorBlock(Block):
video_url: str = SchemaField(description="The URL of the created video")
error: str = SchemaField(description="Error message if the request failed")
async def create_webhook(self) -> tuple[str, str]:
"""Create a new webhook URL for receiving notifications."""
url = "https://webhook.site/token"
headers = {"Accept": "application/json", "Content-Type": "application/json"}
response = await Requests().post(url, headers=headers)
webhook_data = response.json()
return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"
async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
"""Create a video using the Revid API."""
url = "https://www.revid.ai/api/public/v2/render"
headers = {"key": api_key.get_secret_value()}
response = await Requests().post(url, json=payload, headers=headers)
logger.debug(
f"API Response Status Code: {response.status}, Content: {response.text}"
)
return response.json()
async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
"""Check the status of a video creation job."""
url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
headers = {"key": api_key.get_secret_value()}
response = await Requests().get(url, headers=headers)
return response.json()
async def wait_for_video(
self,
api_key: SecretStr,
pid: str,
max_wait_time: int = 1000,
) -> str:
"""Wait for video creation to complete and return the video URL."""
start_time = time.time()
while time.time() - start_time < max_wait_time:
status = await self.check_video_status(api_key, pid)
logger.debug(f"Video status: {status}")
if status.get("status") == "ready" and "videoUrl" in status:
return status["videoUrl"]
elif status.get("status") == "error":
error_message = status.get("error", "Unknown error occurred")
logger.error(f"Video creation failed: {error_message}")
raise ValueError(f"Video creation failed: {error_message}")
elif status.get("status") in ["FAILED", "CANCELED"]:
logger.error(f"Video creation failed: {status.get('message')}")
raise ValueError(f"Video creation failed: {status.get('message')}")
await asyncio.sleep(10)
logger.error("Video creation timed out")
raise TimeoutError("Video creation timed out")
def __init__(self):
super().__init__(
id="361697fb-0c4f-4feb-aed3-8320c88c771b",
@@ -261,22 +202,70 @@ class AIShortformVideoCreatorBlock(Block):
"voice": Voice.LILY,
"video_style": VisualMediaType.STOCK_VIDEOS,
},
test_output=("video_url", "https://example.com/video.mp4"),
test_output=(
"video_url",
"https://example.com/video.mp4",
),
test_mock={
"create_webhook": lambda *args, **kwargs: (
"create_webhook": lambda: (
"test_uuid",
"https://webhook.site/test_uuid",
),
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
"check_video_status": lambda *args, **kwargs: {
"status": "ready",
"videoUrl": "https://example.com/video.mp4",
},
"wait_for_video": lambda *args, **kwargs: "https://example.com/video.mp4",
"create_video": lambda api_key, payload: {"pid": "test_pid"},
"wait_for_video": lambda api_key, pid, webhook_token, max_wait_time=1000: "https://example.com/video.mp4",
},
test_credentials=TEST_CREDENTIALS,
)
async def create_webhook(self):
url = "https://webhook.site/token"
headers = {"Accept": "application/json", "Content-Type": "application/json"}
response = await Requests().post(url, headers=headers)
webhook_data = response.json()
return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"
async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
url = "https://www.revid.ai/api/public/v2/render"
headers = {"key": api_key.get_secret_value()}
response = await Requests().post(url, json=payload, headers=headers)
logger.debug(
f"API Response Status Code: {response.status}, Content: {response.text}"
)
return response.json()
async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
headers = {"key": api_key.get_secret_value()}
response = await Requests().get(url, headers=headers)
return response.json()
async def wait_for_video(
self,
api_key: SecretStr,
pid: str,
webhook_token: str,
max_wait_time: int = 1000,
) -> str:
start_time = time.time()
while time.time() - start_time < max_wait_time:
status = await self.check_video_status(api_key, pid)
logger.debug(f"Video status: {status}")
if status.get("status") == "ready" and "videoUrl" in status:
return status["videoUrl"]
elif status.get("status") == "error":
error_message = status.get("error", "Unknown error occurred")
logger.error(f"Video creation failed: {error_message}")
raise ValueError(f"Video creation failed: {error_message}")
elif status.get("status") in ["FAILED", "CANCELED"]:
logger.error(f"Video creation failed: {status.get('message')}")
raise ValueError(f"Video creation failed: {status.get('message')}")
await asyncio.sleep(10)
logger.error("Video creation timed out")
raise TimeoutError("Video creation timed out")
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
@@ -284,18 +273,20 @@ class AIShortformVideoCreatorBlock(Block):
webhook_token, webhook_url = await self.create_webhook()
logger.debug(f"Webhook URL: {webhook_url}")
audio_url = input_data.background_music.audio_url
payload = {
"frameRate": input_data.frame_rate,
"resolution": input_data.resolution,
"frameDurationMultiplier": 18,
"webhook": None,
"webhook": webhook_url,
"creationParams": {
"mediaType": input_data.video_style,
"captionPresetName": "Wrap 1",
"selectedVoice": input_data.voice.voice_id,
"hasEnhancedGeneration": True,
"generationPreset": input_data.generation_preset.name,
"selectedAudio": input_data.background_music.value,
"selectedAudio": input_data.background_music,
"origin": "/create",
"inputText": input_data.script,
"flowType": "text-to-video",
@@ -311,7 +302,7 @@ class AIShortformVideoCreatorBlock(Block):
"selectedStoryStyle": {"value": "custom", "label": "Custom"},
"hasToGenerateVideos": input_data.video_style
!= VisualMediaType.STOCK_VIDEOS,
"audioUrl": input_data.background_music.audio_url,
"audioUrl": audio_url,
},
}
@@ -328,370 +319,8 @@ class AIShortformVideoCreatorBlock(Block):
logger.debug(
f"Video created with project ID: {pid}. Waiting for completion..."
)
video_url = await self.wait_for_video(credentials.api_key, pid)
video_url = await self.wait_for_video(
credentials.api_key, pid, webhook_token
)
logger.debug(f"Video ready: {video_url}")
yield "video_url", video_url
class AIAdMakerVideoCreatorBlock(Block):
"""Generates a 30second vertical AI advert using optional usersupplied imagery."""
class Input(BlockSchema):
credentials: CredentialsMetaInput[
Literal[ProviderName.REVID], Literal["api_key"]
] = CredentialsField(
description="Credentials for Revid.ai API access.",
)
script: str = SchemaField(
description="Short advertising copy. Line breaks create new scenes.",
placeholder="Introducing Foobar [show product photo] the gadget that does it all.",
)
ratio: str = SchemaField(description="Aspect ratio", default="9 / 16")
target_duration: int = SchemaField(
description="Desired length of the ad in seconds.", default=30
)
voice: Voice = SchemaField(
description="Narration voice", default=Voice.EVA, placeholder=Voice.EVA
)
background_music: AudioTrack = SchemaField(
description="Background track",
default=AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS,
)
input_media_urls: list[str] = SchemaField(
description="List of image URLs to feature in the advert.", default=[]
)
use_only_provided_media: bool = SchemaField(
description="Restrict visuals to supplied images only.", default=True
)
class Output(BlockSchema):
video_url: str = SchemaField(description="URL of the finished advert")
error: str = SchemaField(description="Error message on failure")
async def create_webhook(self) -> tuple[str, str]:
"""Create a new webhook URL for receiving notifications."""
url = "https://webhook.site/token"
headers = {"Accept": "application/json", "Content-Type": "application/json"}
response = await Requests().post(url, headers=headers)
webhook_data = response.json()
return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"
async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
"""Create a video using the Revid API."""
url = "https://www.revid.ai/api/public/v2/render"
headers = {"key": api_key.get_secret_value()}
response = await Requests().post(url, json=payload, headers=headers)
logger.debug(
f"API Response Status Code: {response.status}, Content: {response.text}"
)
return response.json()
async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
"""Check the status of a video creation job."""
url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
headers = {"key": api_key.get_secret_value()}
response = await Requests().get(url, headers=headers)
return response.json()
async def wait_for_video(
self,
api_key: SecretStr,
pid: str,
max_wait_time: int = 1000,
) -> str:
"""Wait for video creation to complete and return the video URL."""
start_time = time.time()
while time.time() - start_time < max_wait_time:
status = await self.check_video_status(api_key, pid)
logger.debug(f"Video status: {status}")
if status.get("status") == "ready" and "videoUrl" in status:
return status["videoUrl"]
elif status.get("status") == "error":
error_message = status.get("error", "Unknown error occurred")
logger.error(f"Video creation failed: {error_message}")
raise ValueError(f"Video creation failed: {error_message}")
elif status.get("status") in ["FAILED", "CANCELED"]:
logger.error(f"Video creation failed: {status.get('message')}")
raise ValueError(f"Video creation failed: {status.get('message')}")
await asyncio.sleep(10)
logger.error("Video creation timed out")
raise TimeoutError("Video creation timed out")
def __init__(self):
super().__init__(
id="58bd2a19-115d-4fd1-8ca4-13b9e37fa6a0",
description="Creates an AIgenerated 30second advert (text + images)",
categories={BlockCategory.MARKETING, BlockCategory.AI},
input_schema=AIAdMakerVideoCreatorBlock.Input,
output_schema=AIAdMakerVideoCreatorBlock.Output,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"script": "Test product launch!",
"input_media_urls": [
"https://cdn.revid.ai/uploads/1747076315114-image.png",
],
},
test_output=("video_url", "https://example.com/ad.mp4"),
test_mock={
"create_webhook": lambda *args, **kwargs: (
"test_uuid",
"https://webhook.site/test_uuid",
),
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
"check_video_status": lambda *args, **kwargs: {
"status": "ready",
"videoUrl": "https://example.com/ad.mp4",
},
"wait_for_video": lambda *args, **kwargs: "https://example.com/ad.mp4",
},
test_credentials=TEST_CREDENTIALS,
)
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
webhook_token, webhook_url = await self.create_webhook()
payload = {
"webhook": webhook_url,
"creationParams": {
"targetDuration": input_data.target_duration,
"ratio": input_data.ratio,
"mediaType": "aiVideo",
"inputText": input_data.script,
"flowType": "text-to-video",
"slug": "ai-ad-generator",
"slugNew": "",
"isCopiedFrom": False,
"hasToGenerateVoice": True,
"hasToTranscript": False,
"hasToSearchMedia": True,
"hasAvatar": False,
"hasWebsiteRecorder": False,
"hasTextSmallAtBottom": False,
"selectedAudio": input_data.background_music.value,
"selectedVoice": input_data.voice.voice_id,
"selectedAvatar": "https://cdn.revid.ai/avatars/young-woman.mp4",
"selectedAvatarType": "video/mp4",
"websiteToRecord": "",
"hasToGenerateCover": True,
"nbGenerations": 1,
"disableCaptions": False,
"mediaMultiplier": "medium",
"characters": [],
"captionPresetName": "Revid",
"sourceType": "contentScraping",
"selectedStoryStyle": {"value": "custom", "label": "General"},
"generationPreset": "DEFAULT",
"hasToGenerateMusic": False,
"isOptimizedForChinese": False,
"generationUserPrompt": "",
"enableNsfwFilter": False,
"addStickers": False,
"typeMovingImageAnim": "dynamic",
"hasToGenerateSoundEffects": False,
"forceModelType": "gpt-image-1",
"selectedCharacters": [],
"lang": "",
"voiceSpeed": 1,
"disableAudio": False,
"disableVoice": False,
"useOnlyProvidedMedia": input_data.use_only_provided_media,
"imageGenerationModel": "ultra",
"videoGenerationModel": "pro",
"hasEnhancedGeneration": True,
"hasEnhancedGenerationPro": True,
"inputMedias": [
{"url": url, "title": "", "type": "image"}
for url in input_data.input_media_urls
],
"hasToGenerateVideos": True,
"audioUrl": input_data.background_music.audio_url,
"watermark": None,
},
}
response = await self.create_video(credentials.api_key, payload)
pid = response.get("pid")
if not pid:
raise RuntimeError("Failed to create video: No project ID returned")
video_url = await self.wait_for_video(credentials.api_key, pid)
yield "video_url", video_url
class AIScreenshotToVideoAdBlock(Block):
"""Creates an advert where the supplied screenshot is narrated by an AI avatar."""
class Input(BlockSchema):
credentials: CredentialsMetaInput[
Literal[ProviderName.REVID], Literal["api_key"]
] = CredentialsField(description="Revid.ai API key")
script: str = SchemaField(
description="Narration that will accompany the screenshot.",
placeholder="Check out these amazing stats!",
)
screenshot_url: str = SchemaField(
description="Screenshot or image URL to showcase."
)
ratio: str = SchemaField(default="9 / 16")
target_duration: int = SchemaField(default=30)
voice: Voice = SchemaField(default=Voice.EVA)
background_music: AudioTrack = SchemaField(
default=AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS
)
class Output(BlockSchema):
video_url: str = SchemaField(description="Rendered video URL")
error: str = SchemaField(description="Error, if encountered")
async def create_webhook(self) -> tuple[str, str]:
"""Create a new webhook URL for receiving notifications."""
url = "https://webhook.site/token"
headers = {"Accept": "application/json", "Content-Type": "application/json"}
response = await Requests().post(url, headers=headers)
webhook_data = response.json()
return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"
async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
"""Create a video using the Revid API."""
url = "https://www.revid.ai/api/public/v2/render"
headers = {"key": api_key.get_secret_value()}
response = await Requests().post(url, json=payload, headers=headers)
logger.debug(
f"API Response Status Code: {response.status}, Content: {response.text}"
)
return response.json()
async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
"""Check the status of a video creation job."""
url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
headers = {"key": api_key.get_secret_value()}
response = await Requests().get(url, headers=headers)
return response.json()
async def wait_for_video(
self,
api_key: SecretStr,
pid: str,
max_wait_time: int = 1000,
) -> str:
"""Wait for video creation to complete and return the video URL."""
start_time = time.time()
while time.time() - start_time < max_wait_time:
status = await self.check_video_status(api_key, pid)
logger.debug(f"Video status: {status}")
if status.get("status") == "ready" and "videoUrl" in status:
return status["videoUrl"]
elif status.get("status") == "error":
error_message = status.get("error", "Unknown error occurred")
logger.error(f"Video creation failed: {error_message}")
raise ValueError(f"Video creation failed: {error_message}")
elif status.get("status") in ["FAILED", "CANCELED"]:
logger.error(f"Video creation failed: {status.get('message')}")
raise ValueError(f"Video creation failed: {status.get('message')}")
await asyncio.sleep(10)
logger.error("Video creation timed out")
raise TimeoutError("Video creation timed out")
def __init__(self):
super().__init__(
id="0f3e4635-e810-43d9-9e81-49e6f4e83b7c",
description="Turns a screenshot into an engaging, avatarnarrated video advert.",
categories={BlockCategory.AI, BlockCategory.MARKETING},
input_schema=AIScreenshotToVideoAdBlock.Input,
output_schema=AIScreenshotToVideoAdBlock.Output,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"script": "Amazing numbers!",
"screenshot_url": "https://cdn.revid.ai/uploads/1747080376028-image.png",
},
test_output=("video_url", "https://example.com/screenshot.mp4"),
test_mock={
"create_webhook": lambda *args, **kwargs: (
"test_uuid",
"https://webhook.site/test_uuid",
),
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
"check_video_status": lambda *args, **kwargs: {
"status": "ready",
"videoUrl": "https://example.com/screenshot.mp4",
},
"wait_for_video": lambda *args, **kwargs: "https://example.com/screenshot.mp4",
},
test_credentials=TEST_CREDENTIALS,
)
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
webhook_token, webhook_url = await self.create_webhook()
payload = {
"webhook": webhook_url,
"creationParams": {
"targetDuration": input_data.target_duration,
"ratio": input_data.ratio,
"mediaType": "aiVideo",
"hasAvatar": True,
"removeAvatarBackground": True,
"inputText": input_data.script,
"flowType": "text-to-video",
"slug": "ai-ad-generator",
"slugNew": "screenshot-to-video-ad",
"isCopiedFrom": "ai-ad-generator",
"hasToGenerateVoice": True,
"hasToTranscript": False,
"hasToSearchMedia": True,
"hasWebsiteRecorder": False,
"hasTextSmallAtBottom": False,
"selectedAudio": input_data.background_music.value,
"selectedVoice": input_data.voice.voice_id,
"selectedAvatar": "https://cdn.revid.ai/avatars/young-woman.mp4",
"selectedAvatarType": "video/mp4",
"websiteToRecord": "",
"hasToGenerateCover": True,
"nbGenerations": 1,
"disableCaptions": False,
"mediaMultiplier": "medium",
"characters": [],
"captionPresetName": "Revid",
"sourceType": "contentScraping",
"selectedStoryStyle": {"value": "custom", "label": "General"},
"generationPreset": "DEFAULT",
"hasToGenerateMusic": False,
"isOptimizedForChinese": False,
"generationUserPrompt": "",
"enableNsfwFilter": False,
"addStickers": False,
"typeMovingImageAnim": "dynamic",
"hasToGenerateSoundEffects": False,
"forceModelType": "gpt-image-1",
"selectedCharacters": [],
"lang": "",
"voiceSpeed": 1,
"disableAudio": False,
"disableVoice": False,
"useOnlyProvidedMedia": True,
"imageGenerationModel": "ultra",
"videoGenerationModel": "ultra",
"hasEnhancedGeneration": True,
"hasEnhancedGenerationPro": True,
"inputMedias": [
{"url": input_data.screenshot_url, "title": "", "type": "image"}
],
"hasToGenerateVideos": True,
"audioUrl": input_data.background_music.audio_url,
"watermark": None,
},
}
response = await self.create_video(credentials.api_key, payload)
pid = response.get("pid")
if not pid:
raise RuntimeError("Failed to create video: No project ID returned")
video_url = await self.wait_for_video(credentials.api_key, pid)
yield "video_url", video_url

View File

@@ -4,7 +4,6 @@ from typing import List
from backend.blocks.apollo._auth import ApolloCredentials
from backend.blocks.apollo.models import (
Contact,
EnrichPersonRequest,
Organization,
SearchOrganizationsRequest,
SearchOrganizationsResponse,
@@ -30,10 +29,10 @@ class ApolloClient:
async def search_people(self, query: SearchPeopleRequest) -> List[Contact]:
"""Search for people in Apollo"""
response = await self.requests.post(
response = await self.requests.get(
f"{self.API_URL}/mixed_people/search",
headers=self._get_headers(),
json=query.model_dump(exclude={"max_results"}),
params=query.model_dump(exclude={"credentials", "max_results"}),
)
data = response.json()
parsed_response = SearchPeopleResponse(**data)
@@ -54,10 +53,10 @@ class ApolloClient:
and len(parsed_response.people) > 0
):
query.page += 1
response = await self.requests.post(
response = await self.requests.get(
f"{self.API_URL}/mixed_people/search",
headers=self._get_headers(),
json=query.model_dump(exclude={"max_results"}),
params=query.model_dump(exclude={"credentials", "max_results"}),
)
data = response.json()
parsed_response = SearchPeopleResponse(**data)
@@ -70,10 +69,10 @@ class ApolloClient:
self, query: SearchOrganizationsRequest
) -> List[Organization]:
"""Search for organizations in Apollo"""
response = await self.requests.post(
response = await self.requests.get(
f"{self.API_URL}/mixed_companies/search",
headers=self._get_headers(),
json=query.model_dump(exclude={"max_results"}),
params=query.model_dump(exclude={"credentials", "max_results"}),
)
data = response.json()
parsed_response = SearchOrganizationsResponse(**data)
@@ -94,10 +93,10 @@ class ApolloClient:
and len(parsed_response.organizations) > 0
):
query.page += 1
response = await self.requests.post(
response = await self.requests.get(
f"{self.API_URL}/mixed_companies/search",
headers=self._get_headers(),
json=query.model_dump(exclude={"max_results"}),
params=query.model_dump(exclude={"credentials", "max_results"}),
)
data = response.json()
parsed_response = SearchOrganizationsResponse(**data)
@@ -111,21 +110,3 @@ class ApolloClient:
return (
organizations[: query.max_results] if query.max_results else organizations
)
async def enrich_person(self, query: EnrichPersonRequest) -> Contact:
"""Enrich a person's data including email & phone reveal"""
response = await self.requests.post(
f"{self.API_URL}/people/match",
headers=self._get_headers(),
json=query.model_dump(),
params={
"reveal_personal_emails": "true",
},
)
data = response.json()
if "person" not in data:
raise ValueError(f"Person not found or enrichment failed: {data}")
contact = Contact(**data["person"])
contact.email = contact.email or "-"
return contact

View File

@@ -1,31 +1,17 @@
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel as OriginalBaseModel
from pydantic import ConfigDict
from pydantic import BaseModel, ConfigDict
from backend.data.model import SchemaField
class BaseModel(OriginalBaseModel):
def model_dump(self, *args, exclude: set[str] | None = None, **kwargs):
if exclude is None:
exclude = set("credentials")
else:
exclude.add("credentials")
kwargs.setdefault("exclude_none", True)
kwargs.setdefault("exclude_unset", True)
kwargs.setdefault("exclude_defaults", True)
return super().model_dump(*args, exclude=exclude, **kwargs)
class PrimaryPhone(BaseModel):
"""A primary phone in Apollo"""
number: Optional[str] = ""
source: Optional[str] = ""
sanitized_number: Optional[str] = ""
number: str
source: str
sanitized_number: str
class SenorityLevels(str, Enum):
@@ -56,102 +42,102 @@ class ContactEmailStatuses(str, Enum):
class RuleConfigStatus(BaseModel):
"""A rule config status in Apollo"""
_id: Optional[str] = ""
created_at: Optional[str] = ""
rule_action_config_id: Optional[str] = ""
rule_config_id: Optional[str] = ""
status_cd: Optional[str] = ""
updated_at: Optional[str] = ""
id: Optional[str] = ""
key: Optional[str] = ""
_id: str
created_at: str
rule_action_config_id: str
rule_config_id: str
status_cd: str
updated_at: str
id: str
key: str
class ContactCampaignStatus(BaseModel):
"""A contact campaign status in Apollo"""
id: Optional[str] = ""
emailer_campaign_id: Optional[str] = ""
send_email_from_user_id: Optional[str] = ""
inactive_reason: Optional[str] = ""
status: Optional[str] = ""
added_at: Optional[str] = ""
added_by_user_id: Optional[str] = ""
finished_at: Optional[str] = ""
paused_at: Optional[str] = ""
auto_unpause_at: Optional[str] = ""
send_email_from_email_address: Optional[str] = ""
send_email_from_email_account_id: Optional[str] = ""
manually_set_unpause: Optional[str] = ""
failure_reason: Optional[str] = ""
current_step_id: Optional[str] = ""
in_response_to_emailer_message_id: Optional[str] = ""
cc_emails: Optional[str] = ""
bcc_emails: Optional[str] = ""
to_emails: Optional[str] = ""
id: str
emailer_campaign_id: str
send_email_from_user_id: str
inactive_reason: str
status: str
added_at: str
added_by_user_id: str
finished_at: str
paused_at: str
auto_unpause_at: str
send_email_from_email_address: str
send_email_from_email_account_id: str
manually_set_unpause: str
failure_reason: str
current_step_id: str
in_response_to_emailer_message_id: str
cc_emails: str
bcc_emails: str
to_emails: str
class Account(BaseModel):
"""An account in Apollo"""
id: Optional[str] = ""
name: Optional[str] = ""
website_url: Optional[str] = ""
blog_url: Optional[str] = ""
angellist_url: Optional[str] = ""
linkedin_url: Optional[str] = ""
twitter_url: Optional[str] = ""
facebook_url: Optional[str] = ""
primary_phone: Optional[PrimaryPhone] = PrimaryPhone()
languages: Optional[list[str]] = []
alexa_ranking: Optional[int] = 0
phone: Optional[str] = ""
linkedin_uid: Optional[str] = ""
founded_year: Optional[int] = 0
publicly_traded_symbol: Optional[str] = ""
publicly_traded_exchange: Optional[str] = ""
logo_url: Optional[str] = ""
chrunchbase_url: Optional[str] = ""
primary_domain: Optional[str] = ""
domain: Optional[str] = ""
team_id: Optional[str] = ""
organization_id: Optional[str] = ""
account_stage_id: Optional[str] = ""
source: Optional[str] = ""
original_source: Optional[str] = ""
creator_id: Optional[str] = ""
owner_id: Optional[str] = ""
created_at: Optional[str] = ""
phone_status: Optional[str] = ""
hubspot_id: Optional[str] = ""
salesforce_id: Optional[str] = ""
crm_owner_id: Optional[str] = ""
parent_account_id: Optional[str] = ""
sanitized_phone: Optional[str] = ""
id: str
name: str
website_url: str
blog_url: str
angellist_url: str
linkedin_url: str
twitter_url: str
facebook_url: str
primary_phone: PrimaryPhone
languages: list[str]
alexa_ranking: int
phone: str
linkedin_uid: str
founded_year: int
publicly_traded_symbol: str
publicly_traded_exchange: str
logo_url: str
chrunchbase_url: str
primary_domain: str
domain: str
team_id: str
organization_id: str
account_stage_id: str
source: str
original_source: str
creator_id: str
owner_id: str
created_at: str
phone_status: str
hubspot_id: str
salesforce_id: str
crm_owner_id: str
parent_account_id: str
sanitized_phone: str
# no listed type on the API docs
account_playbook_statues: Optional[list[Any]] = []
account_rule_config_statuses: Optional[list[RuleConfigStatus]] = []
existence_level: Optional[str] = ""
label_ids: Optional[list[str]] = []
typed_custom_fields: Optional[Any] = {}
custom_field_errors: Optional[Any] = {}
modality: Optional[str] = ""
source_display_name: Optional[str] = ""
salesforce_record_id: Optional[str] = ""
crm_record_url: Optional[str] = ""
account_playbook_statues: list[Any]
account_rule_config_statuses: list[RuleConfigStatus]
existence_level: str
label_ids: list[str]
typed_custom_fields: Any
custom_field_errors: Any
modality: str
source_display_name: str
salesforce_record_id: str
crm_record_url: str
class ContactEmail(BaseModel):
"""A contact email in Apollo"""
email: Optional[str] = ""
email_md5: Optional[str] = ""
email_sha256: Optional[str] = ""
email_status: Optional[str] = ""
email_source: Optional[str] = ""
extrapolated_email_confidence: Optional[str] = ""
position: Optional[int] = 0
email_from_customer: Optional[str] = ""
free_domain: Optional[bool] = True
email: str = ""
email_md5: str = ""
email_sha256: str = ""
email_status: str = ""
email_source: str = ""
extrapolated_email_confidence: str = ""
position: int = 0
email_from_customer: str = ""
free_domain: bool = True
class EmploymentHistory(BaseModel):
@@ -164,40 +150,40 @@ class EmploymentHistory(BaseModel):
populate_by_name=True,
)
_id: Optional[str] = ""
created_at: Optional[str] = ""
current: Optional[bool] = False
degree: Optional[str] = ""
description: Optional[str] = ""
emails: Optional[str] = ""
end_date: Optional[str] = ""
grade_level: Optional[str] = ""
kind: Optional[str] = ""
major: Optional[str] = ""
organization_id: Optional[str] = ""
organization_name: Optional[str] = ""
raw_address: Optional[str] = ""
start_date: Optional[str] = ""
title: Optional[str] = ""
updated_at: Optional[str] = ""
id: Optional[str] = ""
key: Optional[str] = ""
_id: Optional[str] = None
created_at: Optional[str] = None
current: Optional[bool] = None
degree: Optional[str] = None
description: Optional[str] = None
emails: Optional[str] = None
end_date: Optional[str] = None
grade_level: Optional[str] = None
kind: Optional[str] = None
major: Optional[str] = None
organization_id: Optional[str] = None
organization_name: Optional[str] = None
raw_address: Optional[str] = None
start_date: Optional[str] = None
title: Optional[str] = None
updated_at: Optional[str] = None
id: Optional[str] = None
key: Optional[str] = None
class Breadcrumb(BaseModel):
"""A breadcrumb in Apollo"""
label: Optional[str] = ""
signal_field_name: Optional[str] = ""
value: str | list | None = ""
display_name: Optional[str] = ""
label: Optional[str] = "N/A"
signal_field_name: Optional[str] = "N/A"
value: str | list | None = "N/A"
display_name: Optional[str] = "N/A"
class TypedCustomField(BaseModel):
"""A typed custom field in Apollo"""
id: Optional[str] = ""
value: Optional[str] = ""
id: Optional[str] = "N/A"
value: Optional[str] = "N/A"
class Pagination(BaseModel):
@@ -219,23 +205,23 @@ class Pagination(BaseModel):
class DialerFlags(BaseModel):
"""A dialer flags in Apollo"""
country_name: Optional[str] = ""
country_enabled: Optional[bool] = True
high_risk_calling_enabled: Optional[bool] = True
potential_high_risk_number: Optional[bool] = True
country_name: str
country_enabled: bool
high_risk_calling_enabled: bool
potential_high_risk_number: bool
class PhoneNumber(BaseModel):
"""A phone number in Apollo"""
raw_number: Optional[str] = ""
sanitized_number: Optional[str] = ""
type: Optional[str] = ""
position: Optional[int] = 0
status: Optional[str] = ""
dnc_status: Optional[str] = ""
dnc_other_info: Optional[str] = ""
dailer_flags: Optional[DialerFlags] = DialerFlags(
raw_number: str = ""
sanitized_number: str = ""
type: str = ""
position: int = 0
status: str = ""
dnc_status: str = ""
dnc_other_info: str = ""
dailer_flags: DialerFlags = DialerFlags(
country_name="",
country_enabled=True,
high_risk_calling_enabled=True,
@@ -253,31 +239,33 @@ class Organization(BaseModel):
populate_by_name=True,
)
id: Optional[str] = ""
name: Optional[str] = ""
website_url: Optional[str] = ""
blog_url: Optional[str] = ""
angellist_url: Optional[str] = ""
linkedin_url: Optional[str] = ""
twitter_url: Optional[str] = ""
facebook_url: Optional[str] = ""
primary_phone: Optional[PrimaryPhone] = PrimaryPhone()
languages: Optional[list[str]] = []
id: Optional[str] = "N/A"
name: Optional[str] = "N/A"
website_url: Optional[str] = "N/A"
blog_url: Optional[str] = "N/A"
angellist_url: Optional[str] = "N/A"
linkedin_url: Optional[str] = "N/A"
twitter_url: Optional[str] = "N/A"
facebook_url: Optional[str] = "N/A"
primary_phone: Optional[PrimaryPhone] = PrimaryPhone(
number="N/A", source="N/A", sanitized_number="N/A"
)
languages: list[str] = []
alexa_ranking: Optional[int] = 0
phone: Optional[str] = ""
linkedin_uid: Optional[str] = ""
phone: Optional[str] = "N/A"
linkedin_uid: Optional[str] = "N/A"
founded_year: Optional[int] = 0
publicly_traded_symbol: Optional[str] = ""
publicly_traded_exchange: Optional[str] = ""
logo_url: Optional[str] = ""
chrunchbase_url: Optional[str] = ""
primary_domain: Optional[str] = ""
sanitized_phone: Optional[str] = ""
owned_by_organization_id: Optional[str] = ""
intent_strength: Optional[str] = ""
show_intent: Optional[bool] = True
publicly_traded_symbol: Optional[str] = "N/A"
publicly_traded_exchange: Optional[str] = "N/A"
logo_url: Optional[str] = "N/A"
chrunchbase_url: Optional[str] = "N/A"
primary_domain: Optional[str] = "N/A"
sanitized_phone: Optional[str] = "N/A"
owned_by_organization_id: Optional[str] = "N/A"
intent_strength: Optional[str] = "N/A"
show_intent: bool = True
has_intent_signal_account: Optional[bool] = True
intent_signal_account: Optional[str] = ""
intent_signal_account: Optional[str] = "N/A"
class Contact(BaseModel):
@@ -290,95 +278,95 @@ class Contact(BaseModel):
populate_by_name=True,
)
contact_roles: Optional[list[Any]] = []
id: Optional[str] = ""
first_name: Optional[str] = ""
last_name: Optional[str] = ""
name: Optional[str] = ""
linkedin_url: Optional[str] = ""
title: Optional[str] = ""
contact_stage_id: Optional[str] = ""
owner_id: Optional[str] = ""
creator_id: Optional[str] = ""
person_id: Optional[str] = ""
email_needs_tickling: Optional[bool] = True
organization_name: Optional[str] = ""
source: Optional[str] = ""
original_source: Optional[str] = ""
organization_id: Optional[str] = ""
headline: Optional[str] = ""
photo_url: Optional[str] = ""
present_raw_address: Optional[str] = ""
linkededin_uid: Optional[str] = ""
extrapolated_email_confidence: Optional[float] = 0.0
salesforce_id: Optional[str] = ""
salesforce_lead_id: Optional[str] = ""
salesforce_contact_id: Optional[str] = ""
saleforce_account_id: Optional[str] = ""
crm_owner_id: Optional[str] = ""
created_at: Optional[str] = ""
emailer_campaign_ids: Optional[list[str]] = []
direct_dial_status: Optional[str] = ""
direct_dial_enrichment_failed_at: Optional[str] = ""
email_status: Optional[str] = ""
email_source: Optional[str] = ""
account_id: Optional[str] = ""
last_activity_date: Optional[str] = ""
hubspot_vid: Optional[str] = ""
hubspot_company_id: Optional[str] = ""
crm_id: Optional[str] = ""
sanitized_phone: Optional[str] = ""
merged_crm_ids: Optional[str] = ""
updated_at: Optional[str] = ""
queued_for_crm_push: Optional[bool] = True
suggested_from_rule_engine_config_id: Optional[str] = ""
email_unsubscribed: Optional[str] = ""
label_ids: Optional[list[Any]] = []
has_pending_email_arcgate_request: Optional[bool] = True
has_email_arcgate_request: Optional[bool] = True
existence_level: Optional[str] = ""
email: Optional[str] = ""
email_from_customer: Optional[str] = ""
typed_custom_fields: Optional[list[TypedCustomField]] = []
custom_field_errors: Optional[Any] = {}
salesforce_record_id: Optional[str] = ""
crm_record_url: Optional[str] = ""
email_status_unavailable_reason: Optional[str] = ""
email_true_status: Optional[str] = ""
updated_email_true_status: Optional[bool] = True
contact_rule_config_statuses: Optional[list[RuleConfigStatus]] = []
source_display_name: Optional[str] = ""
twitter_url: Optional[str] = ""
contact_campaign_statuses: Optional[list[ContactCampaignStatus]] = []
state: Optional[str] = ""
city: Optional[str] = ""
country: Optional[str] = ""
account: Optional[Account] = Account()
contact_emails: Optional[list[ContactEmail]] = []
organization: Optional[Organization] = Organization()
employment_history: Optional[list[EmploymentHistory]] = []
time_zone: Optional[str] = ""
intent_strength: Optional[str] = ""
show_intent: Optional[bool] = True
phone_numbers: Optional[list[PhoneNumber]] = []
account_phone_note: Optional[str] = ""
free_domain: Optional[bool] = True
is_likely_to_engage: Optional[bool] = True
email_domain_catchall: Optional[bool] = True
contact_job_change_event: Optional[str] = ""
contact_roles: list[Any] = []
id: Optional[str] = None
first_name: Optional[str] = None
last_name: Optional[str] = None
name: Optional[str] = None
linkedin_url: Optional[str] = None
title: Optional[str] = None
contact_stage_id: Optional[str] = None
owner_id: Optional[str] = None
creator_id: Optional[str] = None
person_id: Optional[str] = None
email_needs_tickling: bool = True
organization_name: Optional[str] = None
source: Optional[str] = None
original_source: Optional[str] = None
organization_id: Optional[str] = None
headline: Optional[str] = None
photo_url: Optional[str] = None
present_raw_address: Optional[str] = None
linkededin_uid: Optional[str] = None
extrapolated_email_confidence: Optional[float] = None
salesforce_id: Optional[str] = None
salesforce_lead_id: Optional[str] = None
salesforce_contact_id: Optional[str] = None
saleforce_account_id: Optional[str] = None
crm_owner_id: Optional[str] = None
created_at: Optional[str] = None
emailer_campaign_ids: list[str] = []
direct_dial_status: Optional[str] = None
direct_dial_enrichment_failed_at: Optional[str] = None
email_status: Optional[str] = None
email_source: Optional[str] = None
account_id: Optional[str] = None
last_activity_date: Optional[str] = None
hubspot_vid: Optional[str] = None
hubspot_company_id: Optional[str] = None
crm_id: Optional[str] = None
sanitized_phone: Optional[str] = None
merged_crm_ids: Optional[str] = None
updated_at: Optional[str] = None
queued_for_crm_push: bool = True
suggested_from_rule_engine_config_id: Optional[str] = None
email_unsubscribed: Optional[str] = None
label_ids: list[Any] = []
has_pending_email_arcgate_request: bool = True
has_email_arcgate_request: bool = True
existence_level: Optional[str] = None
email: Optional[str] = None
email_from_customer: Optional[str] = None
typed_custom_fields: list[TypedCustomField] = []
custom_field_errors: Any = None
salesforce_record_id: Optional[str] = None
crm_record_url: Optional[str] = None
email_status_unavailable_reason: Optional[str] = None
email_true_status: Optional[str] = None
updated_email_true_status: bool = True
contact_rule_config_statuses: list[RuleConfigStatus] = []
source_display_name: Optional[str] = None
twitter_url: Optional[str] = None
contact_campaign_statuses: list[ContactCampaignStatus] = []
state: Optional[str] = None
city: Optional[str] = None
country: Optional[str] = None
account: Optional[Account] = None
contact_emails: list[ContactEmail] = []
organization: Optional[Organization] = None
employment_history: list[EmploymentHistory] = []
time_zone: Optional[str] = None
intent_strength: Optional[str] = None
show_intent: bool = True
phone_numbers: list[PhoneNumber] = []
account_phone_note: Optional[str] = None
free_domain: bool = True
is_likely_to_engage: bool = True
email_domain_catchall: bool = True
contact_job_change_event: Optional[str] = None
class SearchOrganizationsRequest(BaseModel):
"""Request for Apollo's search organizations API"""
organization_num_employees_range: Optional[list[int]] = SchemaField(
organization_num_empoloyees_range: list[int] = SchemaField(
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
default=[0, 1000000],
)
organization_locations: Optional[list[str]] = SchemaField(
organization_locations: list[str] = SchemaField(
description="""The location of the company headquarters. You can search across cities, US states, and countries.
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, any Boston-based companies will not appearch in your search results, even if they match other parameters.
@@ -387,30 +375,28 @@ To exclude companies based on location, use the organization_not_locations param
""",
default_factory=list,
)
organizations_not_locations: Optional[list[str]] = SchemaField(
organizations_not_locations: list[str] = SchemaField(
description="""Exclude companies from search results based on the location of the company headquarters. You can use cities, US states, and countries as locations to exclude.
This parameter is useful for ensuring you do not prospect in an undesirable territory. For example, if you use ireland as a value, no Ireland-based companies will appear in your search results.
""",
default_factory=list,
)
q_organization_keyword_tags: Optional[list[str]] = SchemaField(
description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry.""",
default_factory=list,
q_organization_keyword_tags: list[str] = SchemaField(
description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry."""
)
q_organization_name: Optional[str] = SchemaField(
q_organization_name: str = SchemaField(
description="""Filter search results to include a specific company name.
If the value you enter for this parameter does not match with a company's name, the company will not appear in search results, even if it matches other parameters. Partial matches are accepted. For example, if you filter by the value marketing, a company called NY Marketing Unlimited would still be eligible as a search result, but NY Market Analysis would not be eligible.""",
default="",
If the value you enter for this parameter does not match with a company's name, the company will not appear in search results, even if it matches other parameters. Partial matches are accepted. For example, if you filter by the value marketing, a company called NY Marketing Unlimited would still be eligible as a search result, but NY Market Analysis would not be eligible."""
)
organization_ids: Optional[list[str]] = SchemaField(
organization_ids: list[str] = SchemaField(
description="""The Apollo IDs for the companies you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
To find IDs, identify the values for organization_id when you call this endpoint.""",
default_factory=list,
)
max_results: Optional[int] = SchemaField(
max_results: int = SchemaField(
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
default=100,
ge=1,
@@ -435,11 +421,11 @@ Use the page parameter to search the different pages of data.""",
class SearchOrganizationsResponse(BaseModel):
"""Response from Apollo's search organizations API"""
breadcrumbs: Optional[list[Breadcrumb]] = []
partial_results_only: Optional[bool] = True
has_join: Optional[bool] = True
disable_eu_prospecting: Optional[bool] = True
partial_results_limit: Optional[int] = 0
breadcrumbs: list[Breadcrumb] = []
partial_results_only: bool = True
has_join: bool = True
disable_eu_prospecting: bool = True
partial_results_limit: int = 0
pagination: Pagination = Pagination(
page=0, per_page=0, total_entries=0, total_pages=0
)
@@ -447,14 +433,14 @@ class SearchOrganizationsResponse(BaseModel):
accounts: list[Any] = []
organizations: list[Organization] = []
models_ids: list[str] = []
num_fetch_result: Optional[str] = ""
derived_params: Optional[str] = ""
num_fetch_result: Optional[str] = "N/A"
derived_params: Optional[str] = "N/A"
class SearchPeopleRequest(BaseModel):
"""Request for Apollo's search people API"""
person_titles: Optional[list[str]] = SchemaField(
person_titles: list[str] = SchemaField(
description="""Job titles held by the people you want to find. For a person to be included in search results, they only need to match 1 of the job titles you add. Adding more job titles expands your search results.
Results also include job titles with the same terms, even if they are not exact matches. For example, searching for marketing manager might return people with the job title content marketing manager.
@@ -464,13 +450,13 @@ Use this parameter in combination with the person_seniorities[] parameter to fin
default_factory=list,
placeholder="marketing manager",
)
person_locations: Optional[list[str]] = SchemaField(
person_locations: list[str] = SchemaField(
description="""The location where people live. You can search across cities, US states, and countries.
To find people based on the headquarters locations of their current employer, use the organization_locations parameter.""",
default_factory=list,
)
person_seniorities: Optional[list[SenorityLevels]] = SchemaField(
person_seniorities: list[SenorityLevels] = SchemaField(
description="""The job seniority that people hold within their current employer. This enables you to find people that currently hold positions at certain reporting levels, such as Director level or senior IC level.
For a person to be included in search results, they only need to match 1 of the seniorities you add. Adding more seniorities expands your search results.
@@ -480,7 +466,7 @@ Searches only return results based on their current job title, so searching for
Use this parameter in combination with the person_titles[] parameter to find people based on specific job functions and seniority levels.""",
default_factory=list,
)
organization_locations: Optional[list[str]] = SchemaField(
organization_locations: list[str] = SchemaField(
description="""The location of the company headquarters for a person's current employer. You can search across cities, US states, and countries.
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, people that work for the Boston-based company will not appear in your results, even if they match other parameters.
@@ -488,7 +474,7 @@ If a company has several office locations, results are still based on the headqu
To find people based on their personal location, use the person_locations parameter.""",
default_factory=list,
)
q_organization_domains: Optional[list[str]] = SchemaField(
q_organization_domains: list[str] = SchemaField(
description="""The domain name for the person's employer. This can be the current employer or a previous employer. Do not include www., the @ symbol, or similar.
You can add multiple domains to search across companies.
@@ -496,23 +482,23 @@ You can add multiple domains to search across companies.
Examples: apollo.io and microsoft.com""",
default_factory=list,
)
contact_email_statuses: Optional[list[ContactEmailStatuses]] = SchemaField(
contact_email_statuses: list[ContactEmailStatuses] = SchemaField(
description="""The email statuses for the people you want to find. You can add multiple statuses to expand your search.""",
default_factory=list,
)
organization_ids: Optional[list[str]] = SchemaField(
organization_ids: list[str] = SchemaField(
description="""The Apollo IDs for the companies (employers) you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
To find IDs, call the Organization Search endpoint and identify the values for organization_id.""",
default_factory=list,
)
organization_num_employees_range: Optional[list[int]] = SchemaField(
organization_num_empoloyees_range: list[int] = SchemaField(
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
default_factory=list,
)
q_keywords: Optional[str] = SchemaField(
q_keywords: str = SchemaField(
description="""A string of words over which we want to filter the results""",
default="",
)
@@ -528,7 +514,7 @@ Use this parameter in combination with the per_page parameter to make search res
Use the page parameter to search the different pages of data.""",
default=100,
)
max_results: Optional[int] = SchemaField(
max_results: int = SchemaField(
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
default=100,
ge=1,
@@ -547,61 +533,16 @@ class SearchPeopleResponse(BaseModel):
populate_by_name=True,
)
breadcrumbs: Optional[list[Breadcrumb]] = []
partial_results_only: Optional[bool] = True
has_join: Optional[bool] = True
disable_eu_prospecting: Optional[bool] = True
partial_results_limit: Optional[int] = 0
breadcrumbs: list[Breadcrumb] = []
partial_results_only: bool = True
has_join: bool = True
disable_eu_prospecting: bool = True
partial_results_limit: int = 0
pagination: Pagination = Pagination(
page=0, per_page=0, total_entries=0, total_pages=0
)
contacts: list[Contact] = []
people: list[Contact] = []
model_ids: list[str] = []
num_fetch_result: Optional[str] = ""
derived_params: Optional[str] = ""
class EnrichPersonRequest(BaseModel):
"""Request for Apollo's person enrichment API"""
person_id: Optional[str] = SchemaField(
description="Apollo person ID to enrich (most accurate method)",
default="",
)
first_name: Optional[str] = SchemaField(
description="First name of the person to enrich",
default="",
)
last_name: Optional[str] = SchemaField(
description="Last name of the person to enrich",
default="",
)
name: Optional[str] = SchemaField(
description="Full name of the person to enrich",
default="",
)
email: Optional[str] = SchemaField(
description="Email address of the person to enrich",
default="",
)
domain: Optional[str] = SchemaField(
description="Company domain of the person to enrich",
default="",
)
company: Optional[str] = SchemaField(
description="Company name of the person to enrich",
default="",
)
linkedin_url: Optional[str] = SchemaField(
description="LinkedIn URL of the person to enrich",
default="",
)
organization_id: Optional[str] = SchemaField(
description="Apollo organization ID of the person's company",
default="",
)
title: Optional[str] = SchemaField(
description="Job title of the person to enrich",
default="",
)
num_fetch_result: Optional[str] = "N/A"
derived_params: Optional[str] = "N/A"

View File

@@ -11,14 +11,14 @@ from backend.blocks.apollo.models import (
SearchOrganizationsRequest,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import CredentialsField, SchemaField
from backend.data.model import SchemaField
class SearchOrganizationsBlock(Block):
"""Search for organizations in Apollo"""
class Input(BlockSchema):
organization_num_employees_range: list[int] = SchemaField(
organization_num_empoloyees_range: list[int] = SchemaField(
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
@@ -65,7 +65,7 @@ To find IDs, identify the values for organization_id when you call this endpoint
le=50000,
advanced=True,
)
credentials: ApolloCredentialsInput = CredentialsField(
credentials: ApolloCredentialsInput = SchemaField(
description="Apollo credentials",
)
@@ -210,7 +210,9 @@ To find IDs, identify the values for organization_id when you call this endpoint
async def run(
self, input_data: Input, *, credentials: ApolloCredentials, **kwargs
) -> BlockOutput:
query = SearchOrganizationsRequest(**input_data.model_dump())
query = SearchOrganizationsRequest(
**input_data.model_dump(exclude={"credentials"})
)
organizations = await self.search_organizations(query, credentials)
for organization in organizations:
yield "organization", organization

View File

@@ -1,5 +1,3 @@
import asyncio
from backend.blocks.apollo._api import ApolloClient
from backend.blocks.apollo._auth import (
TEST_CREDENTIALS,
@@ -10,12 +8,11 @@ from backend.blocks.apollo._auth import (
from backend.blocks.apollo.models import (
Contact,
ContactEmailStatuses,
EnrichPersonRequest,
SearchPeopleRequest,
SenorityLevels,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import CredentialsField, SchemaField
from backend.data.model import SchemaField
class SearchPeopleBlock(Block):
@@ -80,7 +77,7 @@ class SearchPeopleBlock(Block):
default_factory=list,
advanced=False,
)
organization_num_employees_range: list[int] = SchemaField(
organization_num_empoloyees_range: list[int] = SchemaField(
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
@@ -93,19 +90,14 @@ class SearchPeopleBlock(Block):
advanced=False,
)
max_results: int = SchemaField(
description="""The maximum number of results to return. If you don't specify this parameter, the default is 25. Limited to 500 to prevent overspending.""",
default=25,
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
default=100,
ge=1,
le=500,
advanced=True,
)
enrich_info: bool = SchemaField(
description="""Whether to enrich contacts with detailed information including real email addresses. This will double the search cost.""",
default=False,
le=50000,
advanced=True,
)
credentials: ApolloCredentialsInput = CredentialsField(
credentials: ApolloCredentialsInput = SchemaField(
description="Apollo credentials",
)
@@ -114,6 +106,9 @@ class SearchPeopleBlock(Block):
description="List of people found",
default_factory=list,
)
person: Contact = SchemaField(
description="Each found person, one at a time",
)
error: str = SchemaField(
description="Error message if the search failed",
default="",
@@ -129,6 +124,87 @@ class SearchPeopleBlock(Block):
test_credentials=TEST_CREDENTIALS,
test_input={"credentials": TEST_CREDENTIALS_INPUT},
test_output=[
(
"person",
Contact(
contact_roles=[],
id="1",
name="John Doe",
first_name="John",
last_name="Doe",
linkedin_url="https://www.linkedin.com/in/johndoe",
title="Software Engineer",
organization_name="Google",
organization_id="123456",
contact_stage_id="1",
owner_id="1",
creator_id="1",
person_id="1",
email_needs_tickling=True,
source="apollo",
original_source="apollo",
headline="Software Engineer",
photo_url="https://www.linkedin.com/in/johndoe",
present_raw_address="123 Main St, Anytown, USA",
linkededin_uid="123456",
extrapolated_email_confidence=0.8,
salesforce_id="123456",
salesforce_lead_id="123456",
salesforce_contact_id="123456",
saleforce_account_id="123456",
crm_owner_id="123456",
created_at="2021-01-01",
emailer_campaign_ids=[],
direct_dial_status="active",
direct_dial_enrichment_failed_at="2021-01-01",
email_status="active",
email_source="apollo",
account_id="123456",
last_activity_date="2021-01-01",
hubspot_vid="123456",
hubspot_company_id="123456",
crm_id="123456",
sanitized_phone="123456",
merged_crm_ids="123456",
updated_at="2021-01-01",
queued_for_crm_push=True,
suggested_from_rule_engine_config_id="123456",
email_unsubscribed=None,
label_ids=[],
has_pending_email_arcgate_request=True,
has_email_arcgate_request=True,
existence_level=None,
email=None,
email_from_customer=None,
typed_custom_fields=[],
custom_field_errors=None,
salesforce_record_id=None,
crm_record_url=None,
email_status_unavailable_reason=None,
email_true_status=None,
updated_email_true_status=True,
contact_rule_config_statuses=[],
source_display_name=None,
twitter_url=None,
contact_campaign_statuses=[],
state=None,
city=None,
country=None,
account=None,
contact_emails=[],
organization=None,
employment_history=[],
time_zone=None,
intent_strength=None,
show_intent=True,
phone_numbers=[],
account_phone_note=None,
free_domain=True,
is_likely_to_engage=True,
email_domain_catchall=True,
contact_job_change_event=None,
),
),
(
"people",
[
@@ -303,34 +379,6 @@ class SearchPeopleBlock(Block):
client = ApolloClient(credentials)
return await client.search_people(query)
@staticmethod
async def enrich_person(
query: EnrichPersonRequest, credentials: ApolloCredentials
) -> Contact:
client = ApolloClient(credentials)
return await client.enrich_person(query)
@staticmethod
def merge_contact_data(original: Contact, enriched: Contact) -> Contact:
"""
Merge contact data from original search with enriched data.
Enriched data complements original data, only filling in missing values.
"""
merged_data = original.model_dump()
enriched_data = enriched.model_dump()
# Only update fields that are None, empty string, empty list, or default values in original
for key, enriched_value in enriched_data.items():
# Skip if enriched value is None, empty string, or empty list
if enriched_value is None or enriched_value == "" or enriched_value == []:
continue
# Update if original value is None, empty string, empty list, or zero
if enriched_value:
merged_data[key] = enriched_value
return Contact(**merged_data)
async def run(
self,
input_data: Input,
@@ -339,25 +387,8 @@ class SearchPeopleBlock(Block):
**kwargs,
) -> BlockOutput:
query = SearchPeopleRequest(**input_data.model_dump())
query = SearchPeopleRequest(**input_data.model_dump(exclude={"credentials"}))
people = await self.search_people(query, credentials)
# Enrich with detailed info if requested
if input_data.enrich_info:
async def enrich_or_fallback(person: Contact):
try:
enrich_query = EnrichPersonRequest(person_id=person.id)
enriched_person = await self.enrich_person(
enrich_query, credentials
)
# Merge enriched data with original data, complementing instead of replacing
return self.merge_contact_data(person, enriched_person)
except Exception:
return person # If enrichment fails, use original person data
people = await asyncio.gather(
*(enrich_or_fallback(person) for person in people)
)
for person in people:
yield "person", person
yield "people", people

View File

@@ -1,138 +0,0 @@
from backend.blocks.apollo._api import ApolloClient
from backend.blocks.apollo._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
ApolloCredentials,
ApolloCredentialsInput,
)
from backend.blocks.apollo.models import Contact, EnrichPersonRequest
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import CredentialsField, SchemaField
class GetPersonDetailBlock(Block):
"""Get detailed person data with Apollo API, including email reveal"""
class Input(BlockSchema):
person_id: str = SchemaField(
description="Apollo person ID to enrich (most accurate method)",
default="",
advanced=False,
)
first_name: str = SchemaField(
description="First name of the person to enrich",
default="",
advanced=False,
)
last_name: str = SchemaField(
description="Last name of the person to enrich",
default="",
advanced=False,
)
name: str = SchemaField(
description="Full name of the person to enrich (alternative to first_name + last_name)",
default="",
advanced=False,
)
email: str = SchemaField(
description="Known email address of the person (helps with matching)",
default="",
advanced=False,
)
domain: str = SchemaField(
description="Company domain of the person (e.g., 'google.com')",
default="",
advanced=False,
)
company: str = SchemaField(
description="Company name of the person",
default="",
advanced=False,
)
linkedin_url: str = SchemaField(
description="LinkedIn URL of the person",
default="",
advanced=False,
)
organization_id: str = SchemaField(
description="Apollo organization ID of the person's company",
default="",
advanced=True,
)
title: str = SchemaField(
description="Job title of the person to enrich",
default="",
advanced=True,
)
credentials: ApolloCredentialsInput = CredentialsField(
description="Apollo credentials",
)
class Output(BlockSchema):
contact: Contact = SchemaField(
description="Enriched contact information",
)
error: str = SchemaField(
description="Error message if enrichment failed",
default="",
)
def __init__(self):
super().__init__(
id="3b18d46c-3db6-42ae-a228-0ba441bdd176",
description="Get detailed person data with Apollo API, including email reveal",
categories={BlockCategory.SEARCH},
input_schema=GetPersonDetailBlock.Input,
output_schema=GetPersonDetailBlock.Output,
test_credentials=TEST_CREDENTIALS,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"first_name": "John",
"last_name": "Doe",
"company": "Google",
},
test_output=[
(
"contact",
Contact(
id="1",
name="John Doe",
first_name="John",
last_name="Doe",
email="john.doe@gmail.com",
title="Software Engineer",
organization_name="Google",
linkedin_url="https://www.linkedin.com/in/johndoe",
),
),
],
test_mock={
"enrich_person": lambda query, credentials: Contact(
id="1",
name="John Doe",
first_name="John",
last_name="Doe",
email="john.doe@gmail.com",
title="Software Engineer",
organization_name="Google",
linkedin_url="https://www.linkedin.com/in/johndoe",
)
},
)
@staticmethod
async def enrich_person(
query: EnrichPersonRequest, credentials: ApolloCredentials
) -> Contact:
client = ApolloClient(credentials)
return await client.enrich_person(query)
async def run(
self,
input_data: Input,
*,
credentials: ApolloCredentials,
**kwargs,
) -> BlockOutput:
query = EnrichPersonRequest(**input_data.model_dump())
yield "contact", await self.enrich_person(query, credentials)

View File

@@ -1,9 +1,11 @@
import enum
from typing import Any
from typing import Any, List
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
from backend.data.model import SchemaField
from backend.util import json
from backend.util.file import store_media_file
from backend.util.mock import MockObject
from backend.util.type import MediaFileType, convert
@@ -12,12 +14,6 @@ class FileStoreBlock(Block):
file_in: MediaFileType = SchemaField(
description="The file to store in the temporary directory, it can be a URL, data URI, or local path."
)
base_64: bool = SchemaField(
description="Whether produce an output in base64 format (not recommended, you can pass the string path just fine accross blocks).",
default=False,
advanced=True,
title="Produce Base64 Output",
)
class Output(BlockSchema):
file_out: MediaFileType = SchemaField(
@@ -41,11 +37,12 @@ class FileStoreBlock(Block):
graph_exec_id: str,
**kwargs,
) -> BlockOutput:
yield "file_out", await store_media_file(
file_path = await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.file_in,
return_content=input_data.base_64,
return_content=False,
)
yield "file_out", file_path
class StoreValueBlock(Block):
@@ -118,6 +115,266 @@ class PrintToConsoleBlock(Block):
yield "status", "printed"
class FindInDictionaryBlock(Block):
class Input(BlockSchema):
input: Any = SchemaField(description="Dictionary to lookup from")
key: str | int = SchemaField(description="Key to lookup in the dictionary")
class Output(BlockSchema):
output: Any = SchemaField(description="Value found for the given key")
missing: Any = SchemaField(
description="Value of the input that missing the key"
)
def __init__(self):
super().__init__(
id="0e50422c-6dee-4145-83d6-3a5a392f65de",
description="Lookup the given key in the input dictionary/object/list and return the value.",
input_schema=FindInDictionaryBlock.Input,
output_schema=FindInDictionaryBlock.Output,
test_input=[
{"input": {"apple": 1, "banana": 2, "cherry": 3}, "key": "banana"},
{"input": {"x": 10, "y": 20, "z": 30}, "key": "w"},
{"input": [1, 2, 3], "key": 1},
{"input": [1, 2, 3], "key": 3},
{"input": MockObject(value="!!", key="key"), "key": "key"},
{"input": [{"k1": "v1"}, {"k2": "v2"}, {"k1": "v3"}], "key": "k1"},
],
test_output=[
("output", 2),
("missing", {"x": 10, "y": 20, "z": 30}),
("output", 2),
("missing", [1, 2, 3]),
("output", "key"),
("output", ["v1", "v3"]),
],
categories={BlockCategory.BASIC},
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
obj = input_data.input
key = input_data.key
if isinstance(obj, str):
obj = json.loads(obj)
if isinstance(obj, dict) and key in obj:
yield "output", obj[key]
elif isinstance(obj, list) and isinstance(key, int) and 0 <= key < len(obj):
yield "output", obj[key]
elif isinstance(obj, list) and isinstance(key, str):
if len(obj) == 0:
yield "output", []
elif isinstance(obj[0], dict) and key in obj[0]:
yield "output", [item[key] for item in obj if key in item]
else:
yield "output", [getattr(val, key) for val in obj if hasattr(val, key)]
elif isinstance(obj, object) and isinstance(key, str) and hasattr(obj, key):
yield "output", getattr(obj, key)
else:
yield "missing", input_data.input
class AddToDictionaryBlock(Block):
class Input(BlockSchema):
dictionary: dict[Any, Any] = SchemaField(
default_factory=dict,
description="The dictionary to add the entry to. If not provided, a new dictionary will be created.",
)
key: str = SchemaField(
default="",
description="The key for the new entry.",
placeholder="new_key",
advanced=False,
)
value: Any = SchemaField(
default=None,
description="The value for the new entry.",
placeholder="new_value",
advanced=False,
)
entries: dict[Any, Any] = SchemaField(
default_factory=dict,
description="The entries to add to the dictionary. This is the batch version of the `key` and `value` fields.",
advanced=True,
)
class Output(BlockSchema):
updated_dictionary: dict = SchemaField(
description="The dictionary with the new entry added."
)
error: str = SchemaField(description="Error message if the operation failed.")
def __init__(self):
super().__init__(
id="31d1064e-7446-4693-a7d4-65e5ca1180d1",
description="Adds a new key-value pair to a dictionary. If no dictionary is provided, a new one is created.",
categories={BlockCategory.BASIC},
input_schema=AddToDictionaryBlock.Input,
output_schema=AddToDictionaryBlock.Output,
test_input=[
{
"dictionary": {"existing_key": "existing_value"},
"key": "new_key",
"value": "new_value",
},
{"key": "first_key", "value": "first_value"},
{
"dictionary": {"existing_key": "existing_value"},
"entries": {"new_key": "new_value", "first_key": "first_value"},
},
],
test_output=[
(
"updated_dictionary",
{"existing_key": "existing_value", "new_key": "new_value"},
),
("updated_dictionary", {"first_key": "first_value"}),
(
"updated_dictionary",
{
"existing_key": "existing_value",
"new_key": "new_value",
"first_key": "first_value",
},
),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
updated_dict = input_data.dictionary.copy()
if input_data.value is not None and input_data.key:
updated_dict[input_data.key] = input_data.value
for key, value in input_data.entries.items():
updated_dict[key] = value
yield "updated_dictionary", updated_dict
class AddToListBlock(Block):
class Input(BlockSchema):
list: List[Any] = SchemaField(
default_factory=list,
advanced=False,
description="The list to add the entry to. If not provided, a new list will be created.",
)
entry: Any = SchemaField(
description="The entry to add to the list. Can be of any type (string, int, dict, etc.).",
advanced=False,
default=None,
)
entries: List[Any] = SchemaField(
default_factory=lambda: list(),
description="The entries to add to the list. This is the batch version of the `entry` field.",
advanced=True,
)
position: int | None = SchemaField(
default=None,
description="The position to insert the new entry. If not provided, the entry will be appended to the end of the list.",
)
class Output(BlockSchema):
updated_list: List[Any] = SchemaField(
description="The list with the new entry added."
)
error: str = SchemaField(description="Error message if the operation failed.")
def __init__(self):
super().__init__(
id="aeb08fc1-2fc1-4141-bc8e-f758f183a822",
description="Adds a new entry to a list. The entry can be of any type. If no list is provided, a new one is created.",
categories={BlockCategory.BASIC},
input_schema=AddToListBlock.Input,
output_schema=AddToListBlock.Output,
test_input=[
{
"list": [1, "string", {"existing_key": "existing_value"}],
"entry": {"new_key": "new_value"},
"position": 1,
},
{"entry": "first_entry"},
{"list": ["a", "b", "c"], "entry": "d"},
{
"entry": "e",
"entries": ["f", "g"],
"list": ["a", "b"],
"position": 1,
},
],
test_output=[
(
"updated_list",
[
1,
{"new_key": "new_value"},
"string",
{"existing_key": "existing_value"},
],
),
("updated_list", ["first_entry"]),
("updated_list", ["a", "b", "c", "d"]),
("updated_list", ["a", "f", "g", "e", "b"]),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
entries_added = input_data.entries.copy()
if input_data.entry:
entries_added.append(input_data.entry)
updated_list = input_data.list.copy()
if (pos := input_data.position) is not None:
updated_list = updated_list[:pos] + entries_added + updated_list[pos:]
else:
updated_list += entries_added
yield "updated_list", updated_list
class FindInListBlock(Block):
class Input(BlockSchema):
list: List[Any] = SchemaField(description="The list to search in.")
value: Any = SchemaField(description="The value to search for.")
class Output(BlockSchema):
index: int = SchemaField(description="The index of the value in the list.")
found: bool = SchemaField(
description="Whether the value was found in the list."
)
not_found_value: Any = SchemaField(
description="The value that was not found in the list."
)
def __init__(self):
super().__init__(
id="5e2c6d0a-1e37-489f-b1d0-8e1812b23333",
description="Finds the index of the value in the list.",
categories={BlockCategory.BASIC},
input_schema=FindInListBlock.Input,
output_schema=FindInListBlock.Output,
test_input=[
{"list": [1, 2, 3, 4, 5], "value": 3},
{"list": [1, 2, 3, 4, 5], "value": 6},
],
test_output=[
("index", 2),
("found", True),
("found", False),
("not_found_value", 6),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
try:
yield "index", input_data.list.index(input_data.value)
yield "found", True
except ValueError:
yield "found", False
yield "not_found_value", input_data.value
class NoteBlock(Block):
class Input(BlockSchema):
text: str = SchemaField(description="The text to display in the sticky note.")
@@ -143,6 +400,104 @@ class NoteBlock(Block):
yield "output", input_data.text
class CreateDictionaryBlock(Block):
class Input(BlockSchema):
values: dict[str, Any] = SchemaField(
description="Key-value pairs to create the dictionary with",
placeholder="e.g., {'name': 'Alice', 'age': 25}",
)
class Output(BlockSchema):
dictionary: dict[str, Any] = SchemaField(
description="The created dictionary containing the specified key-value pairs"
)
error: str = SchemaField(
description="Error message if dictionary creation failed"
)
def __init__(self):
super().__init__(
id="b924ddf4-de4f-4b56-9a85-358930dcbc91",
description="Creates a dictionary with the specified key-value pairs. Use this when you know all the values you want to add upfront.",
categories={BlockCategory.DATA},
input_schema=CreateDictionaryBlock.Input,
output_schema=CreateDictionaryBlock.Output,
test_input=[
{
"values": {"name": "Alice", "age": 25, "city": "New York"},
},
{
"values": {"numbers": [1, 2, 3], "active": True, "score": 95.5},
},
],
test_output=[
(
"dictionary",
{"name": "Alice", "age": 25, "city": "New York"},
),
(
"dictionary",
{"numbers": [1, 2, 3], "active": True, "score": 95.5},
),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
try:
# The values are already validated by Pydantic schema
yield "dictionary", input_data.values
except Exception as e:
yield "error", f"Failed to create dictionary: {str(e)}"
class CreateListBlock(Block):
class Input(BlockSchema):
values: List[Any] = SchemaField(
description="A list of values to be combined into a new list.",
placeholder="e.g., ['Alice', 25, True]",
)
class Output(BlockSchema):
list: List[Any] = SchemaField(
description="The created list containing the specified values."
)
error: str = SchemaField(description="Error message if list creation failed.")
def __init__(self):
super().__init__(
id="a912d5c7-6e00-4542-b2a9-8034136930e4",
description="Creates a list with the specified values. Use this when you know all the values you want to add upfront.",
categories={BlockCategory.DATA},
input_schema=CreateListBlock.Input,
output_schema=CreateListBlock.Output,
test_input=[
{
"values": ["Alice", 25, True],
},
{
"values": [1, 2, 3, "four", {"key": "value"}],
},
],
test_output=[
(
"list",
["Alice", 25, True],
),
(
"list",
[1, 2, 3, "four", {"key": "value"}],
),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
try:
# The values are already validated by Pydantic schema
yield "list", input_data.values
except Exception as e:
yield "error", f"Failed to create list: {str(e)}"
class TypeOptions(enum.Enum):
STRING = "string"
NUMBER = "number"
@@ -160,7 +515,6 @@ class UniversalTypeConverterBlock(Block):
class Output(BlockSchema):
value: Any = SchemaField(description="The converted value.")
error: str = SchemaField(description="Error message if conversion failed.")
def __init__(self):
super().__init__(

View File

@@ -3,7 +3,6 @@ from typing import Any
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.type import convert
class ComparisonOperator(Enum):
@@ -164,7 +163,7 @@ class IfInputMatchesBlock(Block):
},
{
"input": 10,
"value": "None",
"value": None,
"yes_value": "Yes",
"no_value": "No",
},
@@ -182,23 +181,7 @@ class IfInputMatchesBlock(Block):
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
# If input_data.value is not matching input_data.input, convert value to type of input
if (
input_data.input != input_data.value
and input_data.input is not input_data.value
):
try:
# Only attempt conversion if input is not None and value is not None
if input_data.input is not None and input_data.value is not None:
input_type = type(input_data.input)
# Avoid converting if input_type is Any or object
if input_type not in (Any, object):
input_data.value = convert(input_data.value, input_type)
except Exception:
pass # If conversion fails, just leave value as is
if input_data.input == input_data.value:
if input_data.input == input_data.value or input_data.input is input_data.value:
yield "result", True
yield "yes_output", input_data.yes_value
else:

View File

@@ -1,683 +0,0 @@
from typing import Any, List
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.json import loads
from backend.util.mock import MockObject
from backend.util.prompt import estimate_token_count_str
# =============================================================================
# Dictionary Manipulation Blocks
# =============================================================================
class CreateDictionaryBlock(Block):
class Input(BlockSchema):
values: dict[str, Any] = SchemaField(
description="Key-value pairs to create the dictionary with",
placeholder="e.g., {'name': 'Alice', 'age': 25}",
)
class Output(BlockSchema):
dictionary: dict[str, Any] = SchemaField(
description="The created dictionary containing the specified key-value pairs"
)
error: str = SchemaField(
description="Error message if dictionary creation failed"
)
def __init__(self):
super().__init__(
id="b924ddf4-de4f-4b56-9a85-358930dcbc91",
description="Creates a dictionary with the specified key-value pairs. Use this when you know all the values you want to add upfront.",
categories={BlockCategory.DATA},
input_schema=CreateDictionaryBlock.Input,
output_schema=CreateDictionaryBlock.Output,
test_input=[
{
"values": {"name": "Alice", "age": 25, "city": "New York"},
},
{
"values": {"numbers": [1, 2, 3], "active": True, "score": 95.5},
},
],
test_output=[
(
"dictionary",
{"name": "Alice", "age": 25, "city": "New York"},
),
(
"dictionary",
{"numbers": [1, 2, 3], "active": True, "score": 95.5},
),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
try:
# The values are already validated by Pydantic schema
yield "dictionary", input_data.values
except Exception as e:
yield "error", f"Failed to create dictionary: {str(e)}"
class AddToDictionaryBlock(Block):
class Input(BlockSchema):
dictionary: dict[Any, Any] = SchemaField(
default_factory=dict,
description="The dictionary to add the entry to. If not provided, a new dictionary will be created.",
)
key: str = SchemaField(
default="",
description="The key for the new entry.",
placeholder="new_key",
advanced=False,
)
value: Any = SchemaField(
default=None,
description="The value for the new entry.",
placeholder="new_value",
advanced=False,
)
entries: dict[Any, Any] = SchemaField(
default_factory=dict,
description="The entries to add to the dictionary. This is the batch version of the `key` and `value` fields.",
advanced=True,
)
class Output(BlockSchema):
updated_dictionary: dict = SchemaField(
description="The dictionary with the new entry added."
)
error: str = SchemaField(description="Error message if the operation failed.")
def __init__(self):
super().__init__(
id="31d1064e-7446-4693-a7d4-65e5ca1180d1",
description="Adds a new key-value pair to a dictionary. If no dictionary is provided, a new one is created.",
categories={BlockCategory.BASIC},
input_schema=AddToDictionaryBlock.Input,
output_schema=AddToDictionaryBlock.Output,
test_input=[
{
"dictionary": {"existing_key": "existing_value"},
"key": "new_key",
"value": "new_value",
},
{"key": "first_key", "value": "first_value"},
{
"dictionary": {"existing_key": "existing_value"},
"entries": {"new_key": "new_value", "first_key": "first_value"},
},
],
test_output=[
(
"updated_dictionary",
{"existing_key": "existing_value", "new_key": "new_value"},
),
("updated_dictionary", {"first_key": "first_value"}),
(
"updated_dictionary",
{
"existing_key": "existing_value",
"new_key": "new_value",
"first_key": "first_value",
},
),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
updated_dict = input_data.dictionary.copy()
if input_data.value is not None and input_data.key:
updated_dict[input_data.key] = input_data.value
for key, value in input_data.entries.items():
updated_dict[key] = value
yield "updated_dictionary", updated_dict
class FindInDictionaryBlock(Block):
class Input(BlockSchema):
input: Any = SchemaField(description="Dictionary to lookup from")
key: str | int = SchemaField(description="Key to lookup in the dictionary")
class Output(BlockSchema):
output: Any = SchemaField(description="Value found for the given key")
missing: Any = SchemaField(
description="Value of the input that missing the key"
)
def __init__(self):
super().__init__(
id="0e50422c-6dee-4145-83d6-3a5a392f65de",
description="Lookup the given key in the input dictionary/object/list and return the value.",
input_schema=FindInDictionaryBlock.Input,
output_schema=FindInDictionaryBlock.Output,
test_input=[
{"input": {"apple": 1, "banana": 2, "cherry": 3}, "key": "banana"},
{"input": {"x": 10, "y": 20, "z": 30}, "key": "w"},
{"input": [1, 2, 3], "key": 1},
{"input": [1, 2, 3], "key": 3},
{"input": MockObject(value="!!", key="key"), "key": "key"},
{"input": [{"k1": "v1"}, {"k2": "v2"}, {"k1": "v3"}], "key": "k1"},
],
test_output=[
("output", 2),
("missing", {"x": 10, "y": 20, "z": 30}),
("output", 2),
("missing", [1, 2, 3]),
("output", "key"),
("output", ["v1", "v3"]),
],
categories={BlockCategory.BASIC},
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
obj = input_data.input
key = input_data.key
if isinstance(obj, str):
obj = loads(obj)
if isinstance(obj, dict) and key in obj:
yield "output", obj[key]
elif isinstance(obj, list) and isinstance(key, int) and 0 <= key < len(obj):
yield "output", obj[key]
elif isinstance(obj, list) and isinstance(key, str):
if len(obj) == 0:
yield "output", []
elif isinstance(obj[0], dict) and key in obj[0]:
yield "output", [item[key] for item in obj if key in item]
else:
yield "output", [getattr(val, key) for val in obj if hasattr(val, key)]
elif isinstance(obj, object) and isinstance(key, str) and hasattr(obj, key):
yield "output", getattr(obj, key)
else:
yield "missing", input_data.input
class RemoveFromDictionaryBlock(Block):
class Input(BlockSchema):
dictionary: dict[Any, Any] = SchemaField(
description="The dictionary to modify."
)
key: str | int = SchemaField(description="Key to remove from the dictionary.")
return_value: bool = SchemaField(
default=False, description="Whether to return the removed value."
)
class Output(BlockSchema):
updated_dictionary: dict[Any, Any] = SchemaField(
description="The dictionary after removal."
)
removed_value: Any = SchemaField(description="The removed value if requested.")
error: str = SchemaField(description="Error message if the operation failed.")
def __init__(self):
super().__init__(
id="46afe2ea-c613-43f8-95ff-6692c3ef6876",
description="Removes a key-value pair from a dictionary.",
categories={BlockCategory.BASIC},
input_schema=RemoveFromDictionaryBlock.Input,
output_schema=RemoveFromDictionaryBlock.Output,
test_input=[
{
"dictionary": {"a": 1, "b": 2, "c": 3},
"key": "b",
"return_value": True,
},
{"dictionary": {"x": "hello", "y": "world"}, "key": "x"},
],
test_output=[
("updated_dictionary", {"a": 1, "c": 3}),
("removed_value", 2),
("updated_dictionary", {"y": "world"}),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
updated_dict = input_data.dictionary.copy()
try:
removed_value = updated_dict.pop(input_data.key)
yield "updated_dictionary", updated_dict
if input_data.return_value:
yield "removed_value", removed_value
except KeyError:
yield "error", f"Key '{input_data.key}' not found in dictionary"
class ReplaceDictionaryValueBlock(Block):
class Input(BlockSchema):
dictionary: dict[Any, Any] = SchemaField(
description="The dictionary to modify."
)
key: str | int = SchemaField(description="Key to replace the value for.")
value: Any = SchemaField(description="The new value for the given key.")
class Output(BlockSchema):
updated_dictionary: dict[Any, Any] = SchemaField(
description="The dictionary after replacement."
)
old_value: Any = SchemaField(description="The value that was replaced.")
error: str = SchemaField(description="Error message if the operation failed.")
def __init__(self):
super().__init__(
id="27e31876-18b6-44f3-ab97-f6226d8b3889",
description="Replaces the value for a specified key in a dictionary.",
categories={BlockCategory.BASIC},
input_schema=ReplaceDictionaryValueBlock.Input,
output_schema=ReplaceDictionaryValueBlock.Output,
test_input=[
{"dictionary": {"a": 1, "b": 2, "c": 3}, "key": "b", "value": 99},
{
"dictionary": {"x": "hello", "y": "world"},
"key": "y",
"value": "universe",
},
],
test_output=[
("updated_dictionary", {"a": 1, "b": 99, "c": 3}),
("old_value", 2),
("updated_dictionary", {"x": "hello", "y": "universe"}),
("old_value", "world"),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
updated_dict = input_data.dictionary.copy()
try:
old_value = updated_dict[input_data.key]
updated_dict[input_data.key] = input_data.value
yield "updated_dictionary", updated_dict
yield "old_value", old_value
except KeyError:
yield "error", f"Key '{input_data.key}' not found in dictionary"
class DictionaryIsEmptyBlock(Block):
class Input(BlockSchema):
dictionary: dict[Any, Any] = SchemaField(description="The dictionary to check.")
class Output(BlockSchema):
is_empty: bool = SchemaField(description="True if the dictionary is empty.")
def __init__(self):
super().__init__(
id="a3cf3f64-6bb9-4cc6-9900-608a0b3359b0",
description="Checks if a dictionary is empty.",
categories={BlockCategory.BASIC},
input_schema=DictionaryIsEmptyBlock.Input,
output_schema=DictionaryIsEmptyBlock.Output,
test_input=[{"dictionary": {}}, {"dictionary": {"a": 1}}],
test_output=[("is_empty", True), ("is_empty", False)],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
yield "is_empty", len(input_data.dictionary) == 0
# =============================================================================
# List Manipulation Blocks
# =============================================================================
class CreateListBlock(Block):
class Input(BlockSchema):
values: List[Any] = SchemaField(
description="A list of values to be combined into a new list.",
placeholder="e.g., ['Alice', 25, True]",
)
max_size: int | None = SchemaField(
default=None,
description="Maximum size of the list. If provided, the list will be yielded in chunks of this size.",
advanced=True,
)
max_tokens: int | None = SchemaField(
default=None,
description="Maximum tokens for the list. If provided, the list will be yielded in chunks that fit within this token limit.",
advanced=True,
)
class Output(BlockSchema):
list: List[Any] = SchemaField(
description="The created list containing the specified values."
)
error: str = SchemaField(description="Error message if list creation failed.")
def __init__(self):
super().__init__(
id="a912d5c7-6e00-4542-b2a9-8034136930e4",
description="Creates a list with the specified values. Use this when you know all the values you want to add upfront. This block can also yield the list in batches based on a maximum size or token limit.",
categories={BlockCategory.DATA},
input_schema=CreateListBlock.Input,
output_schema=CreateListBlock.Output,
test_input=[
{
"values": ["Alice", 25, True],
},
{
"values": [1, 2, 3, "four", {"key": "value"}],
},
],
test_output=[
(
"list",
["Alice", 25, True],
),
(
"list",
[1, 2, 3, "four", {"key": "value"}],
),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
chunk = []
cur_tokens, max_tokens = 0, input_data.max_tokens
cur_size, max_size = 0, input_data.max_size
for value in input_data.values:
if max_tokens:
tokens = estimate_token_count_str(value)
else:
tokens = 0
# Check if adding this value would exceed either limit
if (max_tokens and (cur_tokens + tokens > max_tokens)) or (
max_size and (cur_size + 1 > max_size)
):
yield "list", chunk
chunk = [value]
cur_size, cur_tokens = 1, tokens
else:
chunk.append(value)
cur_size, cur_tokens = cur_size + 1, cur_tokens + tokens
# Yield final chunk if any
if chunk or not input_data.values:
yield "list", chunk
class AddToListBlock(Block):
class Input(BlockSchema):
list: List[Any] = SchemaField(
default_factory=list,
advanced=False,
description="The list to add the entry to. If not provided, a new list will be created.",
)
entry: Any = SchemaField(
description="The entry to add to the list. Can be of any type (string, int, dict, etc.).",
advanced=False,
default=None,
)
entries: List[Any] = SchemaField(
default_factory=lambda: list(),
description="The entries to add to the list. This is the batch version of the `entry` field.",
advanced=True,
)
position: int | None = SchemaField(
default=None,
description="The position to insert the new entry. If not provided, the entry will be appended to the end of the list.",
)
class Output(BlockSchema):
updated_list: List[Any] = SchemaField(
description="The list with the new entry added."
)
error: str = SchemaField(description="Error message if the operation failed.")
def __init__(self):
super().__init__(
id="aeb08fc1-2fc1-4141-bc8e-f758f183a822",
description="Adds a new entry to a list. The entry can be of any type. If no list is provided, a new one is created.",
categories={BlockCategory.BASIC},
input_schema=AddToListBlock.Input,
output_schema=AddToListBlock.Output,
test_input=[
{
"list": [1, "string", {"existing_key": "existing_value"}],
"entry": {"new_key": "new_value"},
"position": 1,
},
{"entry": "first_entry"},
{"list": ["a", "b", "c"], "entry": "d"},
{
"entry": "e",
"entries": ["f", "g"],
"list": ["a", "b"],
"position": 1,
},
],
test_output=[
(
"updated_list",
[
1,
{"new_key": "new_value"},
"string",
{"existing_key": "existing_value"},
],
),
("updated_list", ["first_entry"]),
("updated_list", ["a", "b", "c", "d"]),
("updated_list", ["a", "f", "g", "e", "b"]),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
entries_added = input_data.entries.copy()
if input_data.entry:
entries_added.append(input_data.entry)
updated_list = input_data.list.copy()
if (pos := input_data.position) is not None:
updated_list = updated_list[:pos] + entries_added + updated_list[pos:]
else:
updated_list += entries_added
yield "updated_list", updated_list
class FindInListBlock(Block):
class Input(BlockSchema):
list: List[Any] = SchemaField(description="The list to search in.")
value: Any = SchemaField(description="The value to search for.")
class Output(BlockSchema):
index: int = SchemaField(description="The index of the value in the list.")
found: bool = SchemaField(
description="Whether the value was found in the list."
)
not_found_value: Any = SchemaField(
description="The value that was not found in the list."
)
def __init__(self):
super().__init__(
id="5e2c6d0a-1e37-489f-b1d0-8e1812b23333",
description="Finds the index of the value in the list.",
categories={BlockCategory.BASIC},
input_schema=FindInListBlock.Input,
output_schema=FindInListBlock.Output,
test_input=[
{"list": [1, 2, 3, 4, 5], "value": 3},
{"list": [1, 2, 3, 4, 5], "value": 6},
],
test_output=[
("index", 2),
("found", True),
("found", False),
("not_found_value", 6),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
try:
yield "index", input_data.list.index(input_data.value)
yield "found", True
except ValueError:
yield "found", False
yield "not_found_value", input_data.value
class GetListItemBlock(Block):
class Input(BlockSchema):
list: List[Any] = SchemaField(description="The list to get the item from.")
index: int = SchemaField(
description="The 0-based index of the item (supports negative indices)."
)
class Output(BlockSchema):
item: Any = SchemaField(description="The item at the specified index.")
error: str = SchemaField(description="Error message if the operation failed.")
def __init__(self):
super().__init__(
id="262ca24c-1025-43cf-a578-534e23234e97",
description="Returns the element at the given index.",
categories={BlockCategory.BASIC},
input_schema=GetListItemBlock.Input,
output_schema=GetListItemBlock.Output,
test_input=[
{"list": [1, 2, 3], "index": 1},
{"list": [1, 2, 3], "index": -1},
],
test_output=[
("item", 2),
("item", 3),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
try:
yield "item", input_data.list[input_data.index]
except IndexError:
yield "error", "Index out of range"
class RemoveFromListBlock(Block):
class Input(BlockSchema):
list: List[Any] = SchemaField(description="The list to modify.")
value: Any = SchemaField(
default=None, description="Value to remove from the list."
)
index: int | None = SchemaField(
default=None,
description="Index of the item to pop (supports negative indices).",
)
return_item: bool = SchemaField(
default=False, description="Whether to return the removed item."
)
class Output(BlockSchema):
updated_list: List[Any] = SchemaField(description="The list after removal.")
removed_item: Any = SchemaField(description="The removed item if requested.")
error: str = SchemaField(description="Error message if the operation failed.")
def __init__(self):
super().__init__(
id="d93c5a93-ac7e-41c1-ae5c-ef67e6e9b826",
description="Removes an item from a list by value or index.",
categories={BlockCategory.BASIC},
input_schema=RemoveFromListBlock.Input,
output_schema=RemoveFromListBlock.Output,
test_input=[
{"list": [1, 2, 3], "index": 1, "return_item": True},
{"list": ["a", "b", "c"], "value": "b"},
],
test_output=[
("updated_list", [1, 3]),
("removed_item", 2),
("updated_list", ["a", "c"]),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
lst = input_data.list.copy()
removed = None
try:
if input_data.index is not None:
removed = lst.pop(input_data.index)
elif input_data.value is not None:
lst.remove(input_data.value)
removed = input_data.value
else:
raise ValueError("No index or value provided for removal")
except (IndexError, ValueError):
yield "error", "Index or value not found"
return
yield "updated_list", lst
if input_data.return_item:
yield "removed_item", removed
class ReplaceListItemBlock(Block):
class Input(BlockSchema):
list: List[Any] = SchemaField(description="The list to modify.")
index: int = SchemaField(
description="Index of the item to replace (supports negative indices)."
)
value: Any = SchemaField(description="The new value for the given index.")
class Output(BlockSchema):
updated_list: List[Any] = SchemaField(description="The list after replacement.")
old_item: Any = SchemaField(description="The item that was replaced.")
error: str = SchemaField(description="Error message if the operation failed.")
def __init__(self):
super().__init__(
id="fbf62922-bea1-4a3d-8bac-23587f810b38",
description="Replaces an item at the specified index.",
categories={BlockCategory.BASIC},
input_schema=ReplaceListItemBlock.Input,
output_schema=ReplaceListItemBlock.Output,
test_input=[
{"list": [1, 2, 3], "index": 1, "value": 99},
{"list": ["a", "b"], "index": -1, "value": "c"},
],
test_output=[
("updated_list", [1, 99, 3]),
("old_item", 2),
("updated_list", ["a", "c"]),
("old_item", "b"),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
lst = input_data.list.copy()
try:
old = lst[input_data.index]
lst[input_data.index] = input_data.value
except IndexError:
yield "error", "Index out of range"
return
yield "updated_list", lst
yield "old_item", old
class ListIsEmptyBlock(Block):
class Input(BlockSchema):
list: List[Any] = SchemaField(description="The list to check.")
class Output(BlockSchema):
is_empty: bool = SchemaField(description="True if the list is empty.")
def __init__(self):
super().__init__(
id="896ed73b-27d0-41be-813c-c1c1dc856c03",
description="Checks if a list is empty.",
categories={BlockCategory.BASIC},
input_schema=ListIsEmptyBlock.Input,
output_schema=ListIsEmptyBlock.Output,
test_input=[{"list": []}, {"list": [1]}],
test_output=[("is_empty", True), ("is_empty", False)],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
yield "is_empty", len(input_data.list) == 0

View File

@@ -13,7 +13,7 @@ from backend.data.model import (
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.file import MediaFileType, store_media_file
from backend.util.file import MediaFileType
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
@@ -108,7 +108,7 @@ class AIImageEditorBlock(Block):
output_schema=AIImageEditorBlock.Output,
test_input={
"prompt": "Add a hat to the cat",
"input_image": "",
"input_image": "https://example.com/cat.png",
"aspect_ratio": AspectRatio.MATCH_INPUT_IMAGE,
"seed": None,
"model": FluxKontextModelName.PRO,
@@ -128,22 +128,13 @@ class AIImageEditorBlock(Block):
input_data: Input,
*,
credentials: APIKeyCredentials,
graph_exec_id: str,
**kwargs,
) -> BlockOutput:
result = await self.run_model(
api_key=credentials.api_key,
model_name=input_data.model.api_name,
prompt=input_data.prompt,
input_image_b64=(
await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.input_image,
return_content=True,
)
if input_data.input_image
else None
),
input_image=input_data.input_image,
aspect_ratio=input_data.aspect_ratio.value,
seed=input_data.seed,
)
@@ -154,14 +145,14 @@ class AIImageEditorBlock(Block):
api_key: SecretStr,
model_name: str,
prompt: str,
input_image_b64: Optional[str],
input_image: Optional[MediaFileType],
aspect_ratio: str,
seed: Optional[int],
) -> MediaFileType:
client = ReplicateClient(api_token=api_key.get_secret_value())
input_params = {
"prompt": prompt,
"input_image": input_image_b64,
"input_image": input_image,
"aspect_ratio": aspect_ratio,
**({"seed": seed} if seed is not None else {}),
}

View File

@@ -498,9 +498,6 @@ class GithubListIssuesBlock(Block):
issue: IssueItem = SchemaField(
title="Issue", description="Issues with their title and URL"
)
issues: list[IssueItem] = SchemaField(
description="List of issues with their title and URL"
)
error: str = SchemaField(description="Error message if listing issues failed")
def __init__(self):
@@ -516,22 +513,13 @@ class GithubListIssuesBlock(Block):
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"issues",
[
{
"title": "Issue 1",
"url": "https://github.com/owner/repo/issues/1",
}
],
),
(
"issue",
{
"title": "Issue 1",
"url": "https://github.com/owner/repo/issues/1",
},
),
)
],
test_mock={
"list_issues": lambda *args, **kwargs: [
@@ -563,12 +551,10 @@ class GithubListIssuesBlock(Block):
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
issues = await self.list_issues(
for issue in await self.list_issues(
credentials,
input_data.repo_url,
)
yield "issues", issues
for issue in issues:
):
yield "issue", issue

View File

@@ -31,12 +31,7 @@ class GithubListPullRequestsBlock(Block):
pull_request: PRItem = SchemaField(
title="Pull Request", description="PRs with their title and URL"
)
pull_requests: list[PRItem] = SchemaField(
description="List of pull requests with their title and URL"
)
error: str = SchemaField(
description="Error message if listing pull requests failed"
)
error: str = SchemaField(description="Error message if listing issues failed")
def __init__(self):
super().__init__(
@@ -51,22 +46,13 @@ class GithubListPullRequestsBlock(Block):
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"pull_requests",
[
{
"title": "Pull request 1",
"url": "https://github.com/owner/repo/pull/1",
}
],
),
(
"pull_request",
{
"title": "Pull request 1",
"url": "https://github.com/owner/repo/pull/1",
},
),
)
],
test_mock={
"list_prs": lambda *args, **kwargs: [
@@ -102,7 +88,6 @@ class GithubListPullRequestsBlock(Block):
credentials,
input_data.repo_url,
)
yield "pull_requests", pull_requests
for pr in pull_requests:
yield "pull_request", pr
@@ -280,26 +265,10 @@ class GithubReadPullRequestBlock(Block):
files = response.json()
changes = []
for file in files:
status: str = file.get("status", "")
diff: str = file.get("patch", "")
if status != "removed":
is_filename: str = file.get("filename", "")
was_filename: str = (
file.get("previous_filename", is_filename)
if status != "added"
else ""
)
else:
is_filename = ""
was_filename: str = file.get("filename", "")
patch_header = ""
if was_filename:
patch_header += f"--- {was_filename}\n"
if is_filename:
patch_header += f"+++ {is_filename}\n"
changes.append(patch_header + diff)
return "\n\n".join(changes)
filename = file.get("filename", "")
status = file.get("status", "")
changes.append(f"{filename}: {status}")
return "\n".join(changes)
async def run(
self,
@@ -475,9 +444,6 @@ class GithubListPRReviewersBlock(Block):
title="Reviewer",
description="Reviewers with their username and profile URL",
)
reviewers: list[ReviewerItem] = SchemaField(
description="List of reviewers with their username and profile URL"
)
error: str = SchemaField(
description="Error message if listing reviewers failed"
)
@@ -495,22 +461,13 @@ class GithubListPRReviewersBlock(Block):
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"reviewers",
[
{
"username": "reviewer1",
"url": "https://github.com/reviewer1",
}
],
),
(
"reviewer",
{
"username": "reviewer1",
"url": "https://github.com/reviewer1",
},
),
)
],
test_mock={
"list_reviewers": lambda *args, **kwargs: [
@@ -543,12 +500,10 @@ class GithubListPRReviewersBlock(Block):
credentials: GithubCredentials,
**kwargs,
) -> BlockOutput:
reviewers = await self.list_reviewers(
for reviewer in await self.list_reviewers(
credentials,
input_data.pr_url,
)
yield "reviewers", reviewers
for reviewer in reviewers:
):
yield "reviewer", reviewer

View File

@@ -31,9 +31,6 @@ class GithubListTagsBlock(Block):
tag: TagItem = SchemaField(
title="Tag", description="Tags with their name and file tree browser URL"
)
tags: list[TagItem] = SchemaField(
description="List of tags with their name and file tree browser URL"
)
error: str = SchemaField(description="Error message if listing tags failed")
def __init__(self):
@@ -49,22 +46,13 @@ class GithubListTagsBlock(Block):
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"tags",
[
{
"name": "v1.0.0",
"url": "https://github.com/owner/repo/tree/v1.0.0",
}
],
),
(
"tag",
{
"name": "v1.0.0",
"url": "https://github.com/owner/repo/tree/v1.0.0",
},
),
)
],
test_mock={
"list_tags": lambda *args, **kwargs: [
@@ -105,7 +93,6 @@ class GithubListTagsBlock(Block):
credentials,
input_data.repo_url,
)
yield "tags", tags
for tag in tags:
yield "tag", tag
@@ -127,9 +114,6 @@ class GithubListBranchesBlock(Block):
title="Branch",
description="Branches with their name and file tree browser URL",
)
branches: list[BranchItem] = SchemaField(
description="List of branches with their name and file tree browser URL"
)
error: str = SchemaField(description="Error message if listing branches failed")
def __init__(self):
@@ -145,22 +129,13 @@ class GithubListBranchesBlock(Block):
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"branches",
[
{
"name": "main",
"url": "https://github.com/owner/repo/tree/main",
}
],
),
(
"branch",
{
"name": "main",
"url": "https://github.com/owner/repo/tree/main",
},
),
)
],
test_mock={
"list_branches": lambda *args, **kwargs: [
@@ -201,7 +176,6 @@ class GithubListBranchesBlock(Block):
credentials,
input_data.repo_url,
)
yield "branches", branches
for branch in branches:
yield "branch", branch
@@ -225,9 +199,6 @@ class GithubListDiscussionsBlock(Block):
discussion: DiscussionItem = SchemaField(
title="Discussion", description="Discussions with their title and URL"
)
discussions: list[DiscussionItem] = SchemaField(
description="List of discussions with their title and URL"
)
error: str = SchemaField(
description="Error message if listing discussions failed"
)
@@ -246,22 +217,13 @@ class GithubListDiscussionsBlock(Block):
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"discussions",
[
{
"title": "Discussion 1",
"url": "https://github.com/owner/repo/discussions/1",
}
],
),
(
"discussion",
{
"title": "Discussion 1",
"url": "https://github.com/owner/repo/discussions/1",
},
),
)
],
test_mock={
"list_discussions": lambda *args, **kwargs: [
@@ -317,7 +279,6 @@ class GithubListDiscussionsBlock(Block):
input_data.repo_url,
input_data.num_discussions,
)
yield "discussions", discussions
for discussion in discussions:
yield "discussion", discussion
@@ -339,9 +300,6 @@ class GithubListReleasesBlock(Block):
title="Release",
description="Releases with their name and file tree browser URL",
)
releases: list[ReleaseItem] = SchemaField(
description="List of releases with their name and file tree browser URL"
)
error: str = SchemaField(description="Error message if listing releases failed")
def __init__(self):
@@ -357,22 +315,13 @@ class GithubListReleasesBlock(Block):
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"releases",
[
{
"name": "v1.0.0",
"url": "https://github.com/owner/repo/releases/tag/v1.0.0",
}
],
),
(
"release",
{
"name": "v1.0.0",
"url": "https://github.com/owner/repo/releases/tag/v1.0.0",
},
),
)
],
test_mock={
"list_releases": lambda *args, **kwargs: [
@@ -408,7 +357,6 @@ class GithubListReleasesBlock(Block):
credentials,
input_data.repo_url,
)
yield "releases", releases
for release in releases:
yield "release", release
@@ -1093,9 +1041,6 @@ class GithubListStargazersBlock(Block):
title="Stargazer",
description="Stargazers with their username and profile URL",
)
stargazers: list[StargazerItem] = SchemaField(
description="List of stargazers with their username and profile URL"
)
error: str = SchemaField(
description="Error message if listing stargazers failed"
)
@@ -1113,22 +1058,13 @@ class GithubListStargazersBlock(Block):
},
test_credentials=TEST_CREDENTIALS,
test_output=[
(
"stargazers",
[
{
"username": "octocat",
"url": "https://github.com/octocat",
}
],
),
(
"stargazer",
{
"username": "octocat",
"url": "https://github.com/octocat",
},
),
)
],
test_mock={
"list_stargazers": lambda *args, **kwargs: [
@@ -1168,6 +1104,5 @@ class GithubListStargazersBlock(Block):
credentials,
input_data.repo_url,
)
yield "stargazers", stargazers
for stargazer in stargazers:
yield "stargazer", stargazer

File diff suppressed because it is too large Load Diff

View File

@@ -3,19 +3,11 @@ import logging
from enum import Enum
from io import BytesIO
from pathlib import Path
from typing import Literal
import aiofiles
from pydantic import SecretStr
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import (
CredentialsField,
CredentialsMetaInput,
HostScopedCredentials,
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.data.model import SchemaField
from backend.util.file import (
MediaFileType,
get_exec_file_path,
@@ -27,30 +19,6 @@ from backend.util.request import Requests
logger = logging.getLogger(name=__name__)
# Host-scoped credentials for HTTP requests
HttpCredentials = CredentialsMetaInput[
Literal[ProviderName.HTTP], Literal["host_scoped"]
]
TEST_CREDENTIALS = HostScopedCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="http",
host="api.example.com",
headers={
"Authorization": SecretStr("Bearer test-token"),
"X-API-Key": SecretStr("test-api-key"),
},
title="Mock HTTP Host-Scoped Credentials",
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}
class HttpMethod(Enum):
GET = "GET"
POST = "POST"
@@ -201,62 +169,3 @@ class SendWebRequestBlock(Block):
yield "client_error", result
else:
yield "server_error", result
class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
class Input(SendWebRequestBlock.Input):
credentials: HttpCredentials = CredentialsField(
description="HTTP host-scoped credentials for automatic header injection",
discriminator="url",
)
def __init__(self):
Block.__init__(
self,
id="fff86bcd-e001-4bad-a7f6-2eae4720c8dc",
description="Make an authenticated HTTP request with host-scoped credentials (JSON / form / multipart).",
categories={BlockCategory.OUTPUT},
input_schema=SendAuthenticatedWebRequestBlock.Input,
output_schema=SendWebRequestBlock.Output,
test_credentials=TEST_CREDENTIALS,
)
async def run( # type: ignore[override]
self,
input_data: Input,
*,
graph_exec_id: str,
credentials: HostScopedCredentials,
**kwargs,
) -> BlockOutput:
# Create SendWebRequestBlock.Input from our input (removing credentials field)
base_input = SendWebRequestBlock.Input(
url=input_data.url,
method=input_data.method,
headers=input_data.headers,
json_format=input_data.json_format,
body=input_data.body,
files_name=input_data.files_name,
files=input_data.files,
)
# Apply host-scoped credentials to headers
extra_headers = {}
if credentials.matches_url(input_data.url):
logger.debug(
f"Applying host-scoped credentials {credentials.id} for URL {input_data.url}"
)
extra_headers.update(credentials.get_headers_dict())
else:
logger.warning(
f"Host-scoped credentials {credentials.id} do not match URL {input_data.url}"
)
# Merge with user-provided headers (user headers take precedence)
base_input.headers = {**extra_headers, **input_data.headers}
# Use parent class run method
async for output_name, output_data in super().run(
base_input, graph_exec_id=graph_exec_id, **kwargs
):
yield output_name, output_data

View File

@@ -413,12 +413,6 @@ class AgentFileInputBlock(AgentInputBlock):
advanced=False,
title="Default Value",
)
base_64: bool = SchemaField(
description="Whether produce an output in base64 format (not recommended, you can pass the string path just fine accross blocks).",
default=False,
advanced=True,
title="Produce Base64 Output",
)
class Output(AgentInputBlock.Output):
result: str = SchemaField(description="File reference/path result.")
@@ -452,11 +446,12 @@ class AgentFileInputBlock(AgentInputBlock):
if not input_data.value:
return
yield "result", await store_media_file(
file_path = await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.value,
return_content=input_data.base_64,
return_content=False,
)
yield "result", file_path
class AgentDropdownInputBlock(AgentInputBlock):

View File

@@ -23,7 +23,6 @@ from backend.data.model import (
from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.logging import TruncatedLogger
from backend.util.prompt import compress_prompt, estimate_token_count
from backend.util.text import TextFormatter
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
@@ -41,7 +40,7 @@ LLMProviderName = Literal[
AICredentials = CredentialsMetaInput[LLMProviderName, Literal["api_key"]]
TEST_CREDENTIALS = APIKeyCredentials(
id="769f6af7-820b-4d5d-9b7a-ab82bbc165f",
id="ed55ac19-356e-4243-a6cb-bc599e9b716f",
provider="openai",
api_key=SecretStr("mock-openai-api-key"),
title="Mock OpenAI API key",
@@ -127,9 +126,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
PERPLEXITY_LLAMA_3_1_SONAR_LARGE_128K_ONLINE = (
"perplexity/llama-3.1-sonar-large-128k-online"
)
PERPLEXITY_SONAR = "perplexity/sonar"
PERPLEXITY_SONAR_PRO = "perplexity/sonar-pro"
PERPLEXITY_SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
QWEN_QWQ_32B_PREVIEW = "qwen/qwq-32b-preview"
NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
@@ -232,13 +228,6 @@ MODEL_METADATA = {
LlmModel.PERPLEXITY_LLAMA_3_1_SONAR_LARGE_128K_ONLINE: ModelMetadata(
"open_router", 127072, 127072
),
LlmModel.PERPLEXITY_SONAR: ModelMetadata("open_router", 127000, 127000),
LlmModel.PERPLEXITY_SONAR_PRO: ModelMetadata("open_router", 200000, 8000),
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: ModelMetadata(
"open_router",
128000,
128000,
),
LlmModel.QWEN_QWQ_32B_PREVIEW: ModelMetadata("open_router", 32768, 32768),
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: ModelMetadata(
"open_router", 131000, 4096
@@ -283,7 +272,6 @@ class LLMResponse(BaseModel):
tool_calls: Optional[List[ToolContentBlock]] | None
prompt_tokens: int
completion_tokens: int
reasoning: Optional[str] = None
def convert_openai_tool_fmt_to_anthropic(
@@ -318,44 +306,11 @@ def convert_openai_tool_fmt_to_anthropic(
return anthropic_tools
def extract_openai_reasoning(response) -> str | None:
"""Extract reasoning from OpenAI-compatible response if available."""
"""Note: This will likely not working since the reasoning is not present in another Response API"""
reasoning = None
choice = response.choices[0]
if hasattr(choice, "reasoning") and getattr(choice, "reasoning", None):
reasoning = str(getattr(choice, "reasoning"))
elif hasattr(response, "reasoning") and getattr(response, "reasoning", None):
reasoning = str(getattr(response, "reasoning"))
elif hasattr(choice.message, "reasoning") and getattr(
choice.message, "reasoning", None
):
reasoning = str(getattr(choice.message, "reasoning"))
return reasoning
def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
"""Extract tool calls from OpenAI-compatible response."""
if response.choices[0].message.tool_calls:
return [
ToolContentBlock(
id=tool.id,
type=tool.type,
function=ToolCall(
name=tool.function.name,
arguments=tool.function.arguments,
),
)
for tool in response.choices[0].message.tool_calls
]
return None
def get_parallel_tool_calls_param(llm_model: LlmModel, parallel_tool_calls):
"""Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
if llm_model.startswith("o") or parallel_tool_calls is None:
return openai.NOT_GIVEN
return parallel_tool_calls
def estimate_token_count(prompt_messages: list[dict]) -> int:
char_count = sum(len(str(msg.get("content", ""))) for msg in prompt_messages)
message_overhead = len(prompt_messages) * 4
estimated_tokens = (char_count // 4) + message_overhead
return int(estimated_tokens * 1.2)
async def llm_call(
@@ -366,8 +321,7 @@ async def llm_call(
max_tokens: int | None,
tools: list[dict] | None = None,
ollama_host: str = "localhost:11434",
parallel_tool_calls=None,
compress_prompt_to_fit: bool = True,
parallel_tool_calls: bool | None = None,
) -> LLMResponse:
"""
Make a call to a language model.
@@ -390,32 +344,28 @@ async def llm_call(
- completion_tokens: The number of tokens used in the completion.
"""
provider = llm_model.metadata.provider
context_window = llm_model.context_window
if compress_prompt_to_fit:
prompt = compress_prompt(
messages=prompt,
target_tokens=llm_model.context_window // 2,
lossy_ok=True,
)
# Calculate available tokens based on context window and input length
estimated_input_tokens = estimate_token_count(prompt)
model_max_output = llm_model.max_output_tokens or int(2**15)
context_window = llm_model.context_window
model_max_output = llm_model.max_output_tokens or 4096
user_max = max_tokens or model_max_output
available_tokens = max(context_window - estimated_input_tokens, 0)
max_tokens = max(min(available_tokens, model_max_output, user_max), 1)
max_tokens = max(min(available_tokens, model_max_output, user_max), 0)
if provider == "openai":
tools_param = tools if tools else openai.NOT_GIVEN
oai_client = openai.AsyncOpenAI(api_key=credentials.api_key.get_secret_value())
response_format = None
parallel_tool_calls = get_parallel_tool_calls_param(
llm_model, parallel_tool_calls
)
if json_format:
if llm_model in [LlmModel.O1_MINI, LlmModel.O1_PREVIEW]:
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
prompt = [
{"role": "user", "content": "\n".join(sys_messages)},
{"role": "user", "content": "\n".join(usr_messages)},
]
elif json_format:
response_format = {"type": "json_object"}
response = await oai_client.chat.completions.create(
@@ -424,11 +374,25 @@ async def llm_call(
response_format=response_format, # type: ignore
max_completion_tokens=max_tokens,
tools=tools_param, # type: ignore
parallel_tool_calls=parallel_tool_calls,
parallel_tool_calls=(
openai.NOT_GIVEN if parallel_tool_calls is None else parallel_tool_calls
),
)
tool_calls = extract_openai_tool_calls(response)
reasoning = extract_openai_reasoning(response)
if response.choices[0].message.tool_calls:
tool_calls = [
ToolContentBlock(
id=tool.id,
type=tool.type,
function=ToolCall(
name=tool.function.name,
arguments=tool.function.arguments,
),
)
for tool in response.choices[0].message.tool_calls
]
else:
tool_calls = None
return LLMResponse(
raw_response=response.choices[0].message,
@@ -437,7 +401,6 @@ async def llm_call(
tool_calls=tool_calls,
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
completion_tokens=response.usage.completion_tokens if response.usage else 0,
reasoning=reasoning,
)
elif provider == "anthropic":
@@ -499,12 +462,6 @@ async def llm_call(
f"Tool use stop reason but no tool calls found in content. {resp}"
)
reasoning = None
for content_block in resp.content:
if hasattr(content_block, "type") and content_block.type == "thinking":
reasoning = content_block.thinking
break
return LLMResponse(
raw_response=resp,
prompt=prompt,
@@ -516,7 +473,6 @@ async def llm_call(
tool_calls=tool_calls,
prompt_tokens=resp.usage.input_tokens,
completion_tokens=resp.usage.output_tokens,
reasoning=reasoning,
)
except anthropic.APIError as e:
error_message = f"Anthropic API error: {str(e)}"
@@ -541,7 +497,6 @@ async def llm_call(
tool_calls=None,
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
completion_tokens=response.usage.completion_tokens if response.usage else 0,
reasoning=None,
)
elif provider == "ollama":
if tools:
@@ -563,7 +518,6 @@ async def llm_call(
tool_calls=None,
prompt_tokens=response.get("prompt_eval_count") or 0,
completion_tokens=response.get("eval_count") or 0,
reasoning=None,
)
elif provider == "open_router":
tools_param = tools if tools else openai.NOT_GIVEN
@@ -572,10 +526,6 @@ async def llm_call(
api_key=credentials.api_key.get_secret_value(),
)
parallel_tool_calls_param = get_parallel_tool_calls_param(
llm_model, parallel_tool_calls
)
response = await client.chat.completions.create(
extra_headers={
"HTTP-Referer": "https://agpt.co",
@@ -585,7 +535,6 @@ async def llm_call(
messages=prompt, # type: ignore
max_tokens=max_tokens,
tools=tools_param, # type: ignore
parallel_tool_calls=parallel_tool_calls_param,
)
# If there's no response, raise an error
@@ -595,8 +544,19 @@ async def llm_call(
else:
raise ValueError("No response from OpenRouter.")
tool_calls = extract_openai_tool_calls(response)
reasoning = extract_openai_reasoning(response)
if response.choices[0].message.tool_calls:
tool_calls = [
ToolContentBlock(
id=tool.id,
type=tool.type,
function=ToolCall(
name=tool.function.name, arguments=tool.function.arguments
),
)
for tool in response.choices[0].message.tool_calls
]
else:
tool_calls = None
return LLMResponse(
raw_response=response.choices[0].message,
@@ -605,7 +565,6 @@ async def llm_call(
tool_calls=tool_calls,
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
completion_tokens=response.usage.completion_tokens if response.usage else 0,
reasoning=reasoning,
)
elif provider == "llama_api":
tools_param = tools if tools else openai.NOT_GIVEN
@@ -614,10 +573,6 @@ async def llm_call(
api_key=credentials.api_key.get_secret_value(),
)
parallel_tool_calls_param = get_parallel_tool_calls_param(
llm_model, parallel_tool_calls
)
response = await client.chat.completions.create(
extra_headers={
"HTTP-Referer": "https://agpt.co",
@@ -627,7 +582,9 @@ async def llm_call(
messages=prompt, # type: ignore
max_tokens=max_tokens,
tools=tools_param, # type: ignore
parallel_tool_calls=parallel_tool_calls_param,
parallel_tool_calls=(
openai.NOT_GIVEN if parallel_tool_calls is None else parallel_tool_calls
),
)
# If there's no response, raise an error
@@ -637,8 +594,19 @@ async def llm_call(
else:
raise ValueError("No response from Llama API.")
tool_calls = extract_openai_tool_calls(response)
reasoning = extract_openai_reasoning(response)
if response.choices[0].message.tool_calls:
tool_calls = [
ToolContentBlock(
id=tool.id,
type=tool.type,
function=ToolCall(
name=tool.function.name, arguments=tool.function.arguments
),
)
for tool in response.choices[0].message.tool_calls
]
else:
tool_calls = None
return LLMResponse(
raw_response=response.choices[0].message,
@@ -647,7 +615,6 @@ async def llm_call(
tool_calls=tool_calls,
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
completion_tokens=response.usage.completion_tokens if response.usage else 0,
reasoning=reasoning,
)
elif provider == "aiml_api":
client = openai.OpenAI(
@@ -671,7 +638,6 @@ async def llm_call(
completion_tokens=(
completion.usage.completion_tokens if completion.usage else 0
),
reasoning=None,
)
else:
raise ValueError(f"Unsupported LLM provider: {provider}")
@@ -697,11 +663,6 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
description="Expected format of the response. If provided, the response will be validated against this format. "
"The keys should be the expected fields in the response, and the values should be the description of the field.",
)
list_result: bool = SchemaField(
title="List Result",
default=False,
description="Whether the response should be a list of objects in the expected format.",
)
model: LlmModel = SchemaField(
title="LLM Model",
default=LlmModel.GPT4O,
@@ -733,11 +694,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
default=None,
description="The maximum number of tokens to generate in the chat completion.",
)
compress_prompt_to_fit: bool = SchemaField(
advanced=True,
default=True,
description="Whether to compress the prompt to fit within the model's context window.",
)
ollama_host: str = SchemaField(
advanced=True,
default="localhost:11434",
@@ -745,7 +702,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
)
class Output(BlockSchema):
response: dict[str, Any] | list[dict[str, Any]] = SchemaField(
response: dict[str, Any] = SchemaField(
description="The response object generated by the language model."
)
prompt: list = SchemaField(description="The prompt sent to the language model.")
@@ -785,7 +742,6 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
tool_calls=None,
prompt_tokens=0,
completion_tokens=0,
reasoning=None,
)
},
)
@@ -796,7 +752,6 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
llm_model: LlmModel,
prompt: list[dict],
json_format: bool,
compress_prompt_to_fit: bool,
max_tokens: int | None,
tools: list[dict] | None = None,
ollama_host: str = "localhost:11434",
@@ -814,7 +769,6 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
max_tokens=max_tokens,
tools=tools,
ollama_host=ollama_host,
compress_prompt_to_fit=compress_prompt_to_fit,
)
async def run(
@@ -839,22 +793,13 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
expected_format = [
f'"{k}": "{v}"' for k, v in input_data.expected_format.items()
]
if input_data.list_result:
format_prompt = (
f'"results": [\n {{\n {", ".join(expected_format)}\n }}\n]'
)
else:
format_prompt = "\n ".join(expected_format)
format_prompt = ",\n ".join(expected_format)
sys_prompt = trim_prompt(
f"""
|Reply strictly only in the following JSON format:
|{{
| {format_prompt}
|}}
|
|Ensure the response is valid JSON. Do not include any additional text outside of the JSON.
|If you cannot provide all the keys, provide an empty string for the values you cannot answer.
"""
)
prompt.append({"role": "system", "content": sys_prompt})
@@ -862,18 +807,19 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
if input_data.prompt:
prompt.append({"role": "user", "content": input_data.prompt})
def validate_response(parsed: object) -> str | None:
def parse_response(resp: str) -> tuple[dict[str, Any], str | None]:
try:
parsed = json.loads(resp)
if not isinstance(parsed, dict):
return f"Expected a dictionary, but got {type(parsed)}"
return {}, f"Expected a dictionary, but got {type(parsed)}"
miss_keys = set(input_data.expected_format.keys()) - set(parsed.keys())
if miss_keys:
return f"Missing keys: {miss_keys}"
return None
return parsed, f"Missing keys: {miss_keys}"
return parsed, None
except JSONDecodeError as e:
return f"JSON decode error: {e}"
return {}, f"JSON decode error: {e}"
logger.debug(f"LLM request: {prompt}")
logger.info(f"LLM request: {prompt}")
retry_prompt = ""
llm_model = input_data.model
@@ -883,7 +829,6 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
credentials=credentials,
llm_model=llm_model,
prompt=prompt,
compress_prompt_to_fit=input_data.compress_prompt_to_fit,
json_format=bool(input_data.expected_format),
ollama_host=input_data.ollama_host,
max_tokens=input_data.max_tokens,
@@ -895,32 +840,21 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
output_token_count=llm_response.completion_tokens,
)
)
logger.debug(f"LLM attempt-{retry_count} response: {response_text}")
logger.info(f"LLM attempt-{retry_count} response: {response_text}")
if input_data.expected_format:
response_obj = json.loads(response_text)
if input_data.list_result and isinstance(response_obj, dict):
if "results" in response_obj:
response_obj = response_obj.get("results", [])
elif len(response_obj) == 1:
response_obj = list(response_obj.values())
response_error = "\n".join(
[
validation_error
for response_item in (
response_obj
if isinstance(response_obj, list)
else [response_obj]
parsed_dict, parsed_error = parse_response(response_text)
if not parsed_error:
yield "response", {
k: (
json.loads(v)
if isinstance(v, str)
and v.startswith("[")
and v.endswith("]")
else (", ".join(v) if isinstance(v, list) else v)
)
if (validation_error := validate_response(response_item))
]
)
if not response_error:
yield "response", response_obj
for k, v in parsed_dict.items()
}
yield "prompt", self.prompt
return
else:
@@ -937,7 +871,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
|And this is the error:
|--
|{response_error}
|{parsed_error}
|--
"""
)

View File

@@ -13,7 +13,7 @@ from backend.data.model import (
from backend.integrations.providers import ProviderName
TEST_CREDENTIALS = APIKeyCredentials(
id="8cc8b2c5-d3e4-4b1c-84ad-e1e9fe2a0122",
id="ed55ac19-356e-4243-a6cb-bc599e9b716f",
provider="mem0",
api_key=SecretStr("mock-mem0-api-key"),
title="Mock Mem0 API key",
@@ -67,19 +67,17 @@ class AddMemoryBlock(Block, Mem0Base):
metadata: dict[str, Any] = SchemaField(
description="Optional metadata for the memory", default_factory=dict
)
limit_memory_to_run: bool = SchemaField(
description="Limit the memory to the run", default=False
)
limit_memory_to_agent: bool = SchemaField(
description="Limit the memory to the agent", default=True
description="Limit the memory to the agent", default=False
)
class Output(BlockSchema):
action: str = SchemaField(description="Action of the operation")
memory: str = SchemaField(description="Memory created")
results: list[dict[str, str]] = SchemaField(
description="List of all results from the operation"
)
error: str = SchemaField(description="Error message if operation fails")
def __init__(self):
@@ -106,14 +104,7 @@ class AddMemoryBlock(Block, Mem0Base):
"credentials": TEST_CREDENTIALS_INPUT,
},
],
test_output=[
("results", [{"event": "CREATED", "memory": "test memory"}]),
("action", "CREATED"),
("memory", "test memory"),
("results", [{"event": "CREATED", "memory": "test memory"}]),
("action", "CREATED"),
("memory", "test memory"),
],
test_output=[("action", "NO_CHANGE"), ("action", "NO_CHANGE")],
test_credentials=TEST_CREDENTIALS,
test_mock={"_get_client": lambda credentials: MockMemoryClient()},
)
@@ -126,7 +117,7 @@ class AddMemoryBlock(Block, Mem0Base):
user_id: str,
graph_id: str,
graph_exec_id: str,
**kwargs,
**kwargs
) -> BlockOutput:
try:
client = self._get_client(credentials)
@@ -155,11 +146,8 @@ class AddMemoryBlock(Block, Mem0Base):
**params,
)
results = result.get("results", [])
yield "results", results
if len(results) > 0:
for result in results:
if len(result.get("results", [])) > 0:
for result in result.get("results", []):
yield "action", result["event"]
yield "memory", result["memory"]
else:
@@ -190,10 +178,6 @@ class SearchMemoryBlock(Block, Mem0Base):
default_factory=list,
advanced=True,
)
metadata_filter: Optional[dict[str, Any]] = SchemaField(
description="Optional metadata filters to apply",
default=None,
)
limit_memory_to_run: bool = SchemaField(
description="Limit the memory to the run", default=False
)
@@ -232,7 +216,7 @@ class SearchMemoryBlock(Block, Mem0Base):
user_id: str,
graph_id: str,
graph_exec_id: str,
**kwargs,
**kwargs
) -> BlockOutput:
try:
client = self._get_client(credentials)
@@ -251,8 +235,6 @@ class SearchMemoryBlock(Block, Mem0Base):
filters["AND"].append({"run_id": graph_exec_id})
if input_data.limit_memory_to_agent:
filters["AND"].append({"agent_id": graph_id})
if input_data.metadata_filter:
filters["AND"].append({"metadata": input_data.metadata_filter})
result: list[dict[str, Any]] = client.search(
input_data.query, version="v2", filters=filters
@@ -278,15 +260,11 @@ class GetAllMemoriesBlock(Block, Mem0Base):
categories: Optional[list[str]] = SchemaField(
description="Filter by categories", default=None
)
metadata_filter: Optional[dict[str, Any]] = SchemaField(
description="Optional metadata filters to apply",
default=None,
)
limit_memory_to_run: bool = SchemaField(
description="Limit the memory to the run", default=False
)
limit_memory_to_agent: bool = SchemaField(
description="Limit the memory to the agent", default=True
description="Limit the memory to the agent", default=False
)
class Output(BlockSchema):
@@ -296,11 +274,11 @@ class GetAllMemoriesBlock(Block, Mem0Base):
def __init__(self):
super().__init__(
id="45aee5bf-4767-45d1-a28b-e01c5aae9fc1",
description="Retrieve all memories from Mem0 with optional conversation filtering",
description="Retrieve all memories from Mem0 with pagination",
input_schema=GetAllMemoriesBlock.Input,
output_schema=GetAllMemoriesBlock.Output,
test_input={
"metadata_filter": {"type": "test"},
"user_id": "test_user",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
@@ -318,7 +296,7 @@ class GetAllMemoriesBlock(Block, Mem0Base):
user_id: str,
graph_id: str,
graph_exec_id: str,
**kwargs,
**kwargs
) -> BlockOutput:
try:
client = self._get_client(credentials)
@@ -336,8 +314,6 @@ class GetAllMemoriesBlock(Block, Mem0Base):
filters["AND"].append(
{"categories": {"contains": input_data.categories}}
)
if input_data.metadata_filter:
filters["AND"].append({"metadata": input_data.metadata_filter})
memories: list[dict[str, Any]] = client.get_all(
filters=filters,
@@ -350,116 +326,14 @@ class GetAllMemoriesBlock(Block, Mem0Base):
yield "error", str(e)
class GetLatestMemoryBlock(Block, Mem0Base):
"""Block for retrieving the latest memory from Mem0"""
class Input(BlockSchema):
credentials: CredentialsMetaInput[
Literal[ProviderName.MEM0], Literal["api_key"]
] = CredentialsField(description="Mem0 API key credentials")
trigger: bool = SchemaField(
description="An unused field that is used to trigger the block when you have no other inputs",
default=False,
advanced=False,
)
categories: Optional[list[str]] = SchemaField(
description="Filter by categories", default=None
)
conversation_id: Optional[str] = SchemaField(
description="Optional conversation ID to retrieve the latest memory from (uses run_id)",
default=None,
)
metadata_filter: Optional[dict[str, Any]] = SchemaField(
description="Optional metadata filters to apply",
default=None,
)
limit_memory_to_run: bool = SchemaField(
description="Limit the memory to the run", default=False
)
limit_memory_to_agent: bool = SchemaField(
description="Limit the memory to the agent", default=True
)
class Output(BlockSchema):
memory: Optional[dict[str, Any]] = SchemaField(
description="Latest memory if found"
)
found: bool = SchemaField(description="Whether a memory was found")
error: str = SchemaField(description="Error message if operation fails")
def __init__(self):
super().__init__(
id="0f9d81b5-a145-4c23-b87f-01d6bf37b677",
description="Retrieve the latest memory from Mem0 with optional key filtering",
input_schema=GetLatestMemoryBlock.Input,
output_schema=GetLatestMemoryBlock.Output,
test_input={
"metadata_filter": {"type": "test"},
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
("memory", {"id": "test-memory", "content": "test content"}),
("found", True),
],
test_credentials=TEST_CREDENTIALS,
test_mock={"_get_client": lambda credentials: MockMemoryClient()},
)
async def run(
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
user_id: str,
graph_id: str,
graph_exec_id: str,
**kwargs,
) -> BlockOutput:
try:
client = self._get_client(credentials)
filters: Filter = {
"AND": [
{"user_id": user_id},
]
}
if input_data.limit_memory_to_run:
filters["AND"].append({"run_id": graph_exec_id})
if input_data.limit_memory_to_agent:
filters["AND"].append({"agent_id": graph_id})
if input_data.categories:
filters["AND"].append(
{"categories": {"contains": input_data.categories}}
)
if input_data.metadata_filter:
filters["AND"].append({"metadata": input_data.metadata_filter})
memories: list[dict[str, Any]] = client.get_all(
filters=filters,
version="v2",
)
if memories:
# Return the latest memory (first in the list as they're sorted by recency)
latest_memory = memories[0]
yield "memory", latest_memory
yield "found", True
else:
yield "memory", None
yield "found", False
except Exception as e:
yield "error", str(e)
# Mock client for testing
class MockMemoryClient:
"""Mock Mem0 client for testing"""
def add(self, *args, **kwargs):
return {"results": [{"event": "CREATED", "memory": "test memory"}]}
return {"memory_id": "test-memory-id", "status": "success"}
def search(self, *args, **kwargs) -> list[dict[str, Any]]:
def search(self, *args, **kwargs) -> list[dict[str, str]]:
return [{"id": "test-memory", "content": "test content"}]
def get_all(self, *args, **kwargs) -> list[dict[str, str]]:

View File

@@ -1,155 +0,0 @@
import logging
from typing import Any, Literal
from autogpt_libs.utils.cache import thread_cached
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
logger = logging.getLogger(__name__)
@thread_cached
def get_database_manager_client():
from backend.executor import DatabaseManagerAsyncClient
from backend.util.service import get_service_client
return get_service_client(DatabaseManagerAsyncClient, health_check=False)
StorageScope = Literal["within_agent", "across_agents"]
def get_storage_key(key: str, scope: StorageScope, graph_id: str) -> str:
"""Generate the storage key based on scope"""
if scope == "across_agents":
return f"global#{key}"
else:
return f"agent#{graph_id}#{key}"
class PersistInformationBlock(Block):
"""Block for persisting key-value data for the current user with configurable scope"""
class Input(BlockSchema):
key: str = SchemaField(description="Key to store the information under")
value: Any = SchemaField(description="Value to store")
scope: StorageScope = SchemaField(
description="Scope of persistence: within_agent (shared across all runs of this agent) or across_agents (shared across all agents for this user)",
default="within_agent",
)
class Output(BlockSchema):
value: Any = SchemaField(description="Value that was stored")
def __init__(self):
super().__init__(
id="1d055e55-a2b9-4547-8311-907d05b0304d",
description="Persist key-value information for the current user",
categories={BlockCategory.DATA},
input_schema=PersistInformationBlock.Input,
output_schema=PersistInformationBlock.Output,
test_input={
"key": "user_preference",
"value": {"theme": "dark", "language": "en"},
"scope": "within_agent",
},
test_output=[
("value", {"theme": "dark", "language": "en"}),
],
test_mock={
"_store_data": lambda *args, **kwargs: {
"theme": "dark",
"language": "en",
}
},
)
async def run(
self,
input_data: Input,
*,
user_id: str,
graph_id: str,
node_exec_id: str,
**kwargs,
) -> BlockOutput:
# Determine the storage key based on scope
storage_key = get_storage_key(input_data.key, input_data.scope, graph_id)
# Store the data
yield "value", await self._store_data(
user_id=user_id,
node_exec_id=node_exec_id,
key=storage_key,
data=input_data.value,
)
async def _store_data(
self, user_id: str, node_exec_id: str, key: str, data: Any
) -> Any | None:
return await get_database_manager_client().set_execution_kv_data(
user_id=user_id,
node_exec_id=node_exec_id,
key=key,
data=data,
)
class RetrieveInformationBlock(Block):
"""Block for retrieving key-value data for the current user with configurable scope"""
class Input(BlockSchema):
key: str = SchemaField(description="Key to retrieve the information for")
scope: StorageScope = SchemaField(
description="Scope of persistence: within_agent (shared across all runs of this agent) or across_agents (shared across all agents for this user)",
default="within_agent",
)
default_value: Any = SchemaField(
description="Default value to return if key is not found", default=None
)
class Output(BlockSchema):
value: Any = SchemaField(description="Retrieved value or default value")
def __init__(self):
super().__init__(
id="d8710fc9-6e29-481e-a7d5-165eb16f8471",
description="Retrieve key-value information for the current user",
categories={BlockCategory.DATA},
input_schema=RetrieveInformationBlock.Input,
output_schema=RetrieveInformationBlock.Output,
test_input={
"key": "user_preference",
"scope": "within_agent",
"default_value": {"theme": "light", "language": "en"},
},
test_output=[
("value", {"theme": "light", "language": "en"}),
],
test_mock={"_retrieve_data": lambda *args, **kwargs: None},
static_output=True,
)
async def run(
self, input_data: Input, *, user_id: str, graph_id: str, **kwargs
) -> BlockOutput:
# Determine the storage key based on scope
storage_key = get_storage_key(input_data.key, input_data.scope, graph_id)
# Retrieve the data
stored_value = await self._retrieve_data(
user_id=user_id,
key=storage_key,
)
if stored_value is not None:
yield "value", stored_value
else:
yield "value", input_data.default_value
async def _retrieve_data(self, user_id: str, key: str) -> Any | None:
return await get_database_manager_client().get_execution_kv_data(
user_id=user_id,
key=key,
)

View File

@@ -96,7 +96,6 @@ class GetRedditPostsBlock(Block):
class Output(BlockSchema):
post: RedditPost = SchemaField(description="Reddit post")
posts: list[RedditPost] = SchemaField(description="List of all Reddit posts")
def __init__(self):
super().__init__(
@@ -129,23 +128,6 @@ class GetRedditPostsBlock(Block):
id="id2", subreddit="subreddit", title="title2", body="body2"
),
),
(
"posts",
[
RedditPost(
id="id1",
subreddit="subreddit",
title="title1",
body="body1",
),
RedditPost(
id="id2",
subreddit="subreddit",
title="title2",
body="body2",
),
],
),
],
test_mock={
"get_posts": lambda input_data, credentials: [
@@ -168,7 +150,6 @@ class GetRedditPostsBlock(Block):
self, input_data: Input, *, credentials: RedditCredentials, **kwargs
) -> BlockOutput:
current_time = datetime.now(tz=timezone.utc)
all_posts = []
for post in self.get_posts(input_data=input_data, credentials=credentials):
if input_data.last_minutes:
post_datetime = datetime.fromtimestamp(
@@ -181,16 +162,12 @@ class GetRedditPostsBlock(Block):
if input_data.last_post and post.id == input_data.last_post:
break
reddit_post = RedditPost(
yield "post", RedditPost(
id=post.id,
subreddit=input_data.subreddit,
title=post.title,
body=post.selftext,
)
all_posts.append(reddit_post)
yield "post", reddit_post
yield "posts", all_posts
class PostRedditCommentBlock(Block):

View File

@@ -40,7 +40,6 @@ class ReadRSSFeedBlock(Block):
class Output(BlockSchema):
entry: RSSEntry = SchemaField(description="The RSS item")
entries: list[RSSEntry] = SchemaField(description="List of all RSS entries")
def __init__(self):
super().__init__(
@@ -67,21 +66,6 @@ class ReadRSSFeedBlock(Block):
categories=["Technology", "News"],
),
),
(
"entries",
[
RSSEntry(
title="Example RSS Item",
link="https://example.com/article",
description="This is an example RSS item description.",
pub_date=datetime(
2023, 6, 23, 12, 30, 0, tzinfo=timezone.utc
),
author="John Doe",
categories=["Technology", "News"],
),
],
),
],
test_mock={
"parse_feed": lambda *args, **kwargs: {
@@ -112,22 +96,21 @@ class ReadRSSFeedBlock(Block):
keep_going = input_data.run_continuously
feed = self.parse_feed(input_data.rss_url)
all_entries = []
for entry in feed["entries"]:
pub_date = datetime(*entry["published_parsed"][:6], tzinfo=timezone.utc)
if pub_date > start_time:
rss_entry = RSSEntry(
title=entry["title"],
link=entry["link"],
description=entry.get("summary", ""),
pub_date=pub_date,
author=entry.get("author", ""),
categories=[tag["term"] for tag in entry.get("tags", [])],
yield (
"entry",
RSSEntry(
title=entry["title"],
link=entry["link"],
description=entry.get("summary", ""),
pub_date=pub_date,
author=entry.get("author", ""),
categories=[tag["term"] for tag in entry.get("tags", [])],
),
)
all_entries.append(rss_entry)
yield "entry", rss_entry
yield "entries", all_entries
await asyncio.sleep(input_data.polling_rate)

View File

@@ -26,10 +26,10 @@ logger = logging.getLogger(__name__)
@thread_cached
def get_database_manager_client():
from backend.executor import DatabaseManagerAsyncClient
from backend.executor import DatabaseManagerClient
from backend.util.service import get_service_client
return get_service_client(DatabaseManagerAsyncClient, health_check=False)
return get_service_client(DatabaseManagerClient)
def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
@@ -85,7 +85,7 @@ def _get_tool_responses(entry: dict[str, Any]) -> list[str]:
return tool_call_ids
def _create_tool_response(call_id: str, output: Any) -> dict[str, Any]:
def _create_tool_response(call_id: str, output: dict[str, Any]) -> dict[str, Any]:
"""
Create a tool response message for either OpenAI or Anthropics,
based on the tool_id format.
@@ -142,12 +142,6 @@ class SmartDecisionMakerBlock(Block):
advanced=False,
)
credentials: llm.AICredentials = llm.AICredentialsField()
multiple_tool_calls: bool = SchemaField(
title="Multiple Tool Calls",
default=False,
description="Whether to allow multiple tool calls in a single response.",
advanced=True,
)
sys_prompt: str = SchemaField(
title="System Prompt",
default="Thinking carefully step by step decide which function to call. "
@@ -156,7 +150,7 @@ class SmartDecisionMakerBlock(Block):
"matching the required jsonschema signature, no missing argument is allowed. "
"If you have already completed the task objective, you can end the task "
"by providing the end result of your work as a finish message. "
"Function parameters that has no default value and not optional typed has to be provided. ",
"Only provide EXACTLY one function call, multiple tool calls is strictly prohibited.",
description="The system prompt to provide additional context to the model.",
)
conversation_history: list[dict] = SchemaField(
@@ -212,15 +206,6 @@ class SmartDecisionMakerBlock(Block):
"link like the output of `StoreValue` or `AgentInput` block"
)
# Check that both conversation_history and last_tool_output are connected together
if any(link.sink_name == "conversation_history" for link in links) != any(
link.sink_name == "last_tool_output" for link in links
):
raise ValueError(
"Last Tool Output is needed when Conversation History is used, "
"and vice versa. Please connect both inputs together."
)
return missing_links
@classmethod
@@ -231,15 +216,8 @@ class SmartDecisionMakerBlock(Block):
conversation_history = data.get("conversation_history", [])
pending_tool_calls = get_pending_tool_calls(conversation_history)
last_tool_output = data.get("last_tool_output")
# Tool call is pending, wait for the tool output to be provided.
if last_tool_output is None and pending_tool_calls:
if not last_tool_output and pending_tool_calls:
return {"last_tool_output"}
# No tool call is pending, wait for the conversation history to be updated.
if last_tool_output is not None and not pending_tool_calls:
return {"conversation_history"}
return set()
class Output(BlockSchema):
@@ -273,7 +251,7 @@ class SmartDecisionMakerBlock(Block):
return re.sub(r"[^a-zA-Z0-9_-]", "_", s).lower()
@staticmethod
async def _create_block_function_signature(
def _create_block_function_signature(
sink_node: "Node", links: list["Link"]
) -> dict[str, Any]:
"""
@@ -295,24 +273,35 @@ class SmartDecisionMakerBlock(Block):
"name": SmartDecisionMakerBlock.cleanup(block.name),
"description": block.description,
}
sink_block_input_schema = block.input_schema
properties = {}
required = []
for link in links:
sink_name = SmartDecisionMakerBlock.cleanup(link.sink_name)
properties[sink_name] = sink_block_input_schema.get_field_schema(
link.sink_name
sink_block_input_schema = block.input_schema
description = (
sink_block_input_schema.model_fields[link.sink_name].description
if link.sink_name in sink_block_input_schema.model_fields
and sink_block_input_schema.model_fields[link.sink_name].description
else f"The {link.sink_name} of the tool"
)
properties[SmartDecisionMakerBlock.cleanup(link.sink_name)] = {
"type": "string",
"description": description,
}
tool_function["parameters"] = {
**block.input_schema.jsonschema(),
"type": "object",
"properties": properties,
"required": required,
"additionalProperties": False,
"strict": True,
}
return {"type": "function", "function": tool_function}
@staticmethod
async def _create_agent_function_signature(
def _create_agent_function_signature(
sink_node: "Node", links: list["Link"]
) -> dict[str, Any]:
"""
@@ -334,7 +323,7 @@ class SmartDecisionMakerBlock(Block):
raise ValueError("Graph ID or Graph Version not found in sink node.")
db_client = get_database_manager_client()
sink_graph_meta = await db_client.get_graph_metadata(graph_id, graph_version)
sink_graph_meta = db_client.get_graph_metadata(graph_id, graph_version)
if not sink_graph_meta:
raise ValueError(
f"Sink graph metadata not found: {graph_id} {graph_version}"
@@ -346,27 +335,25 @@ class SmartDecisionMakerBlock(Block):
}
properties = {}
required = []
for link in links:
sink_block_input_schema = sink_node.input_default["input_schema"]
sink_block_properties = sink_block_input_schema.get("properties", {}).get(
link.sink_name, {}
)
sink_name = SmartDecisionMakerBlock.cleanup(link.sink_name)
description = (
sink_block_properties["description"]
if "description" in sink_block_properties
sink_block_input_schema["properties"][link.sink_name]["description"]
if "description"
in sink_block_input_schema["properties"][link.sink_name]
else f"The {link.sink_name} of the tool"
)
properties[sink_name] = {
properties[SmartDecisionMakerBlock.cleanup(link.sink_name)] = {
"type": "string",
"description": description,
"default": json.dumps(sink_block_properties.get("default", None)),
}
tool_function["parameters"] = {
"type": "object",
"properties": properties,
"required": required,
"additionalProperties": False,
"strict": True,
}
@@ -374,7 +361,7 @@ class SmartDecisionMakerBlock(Block):
return {"type": "function", "function": tool_function}
@staticmethod
async def _create_function_signature(node_id: str) -> list[dict[str, Any]]:
def _create_function_signature(node_id: str) -> list[dict[str, Any]]:
"""
Creates function signatures for tools linked to a specified node within a graph.
@@ -396,13 +383,13 @@ class SmartDecisionMakerBlock(Block):
db_client = get_database_manager_client()
tools = [
(link, node)
for link, node in await db_client.get_connected_output_nodes(node_id)
for link, node in db_client.get_connected_output_nodes(node_id)
if link.source_name.startswith("tools_^_") and link.source_id == node_id
]
if not tools:
raise ValueError("There is no next node to execute.")
return_tool_functions: list[dict[str, Any]] = []
return_tool_functions = []
grouped_tool_links: dict[str, tuple["Node", list["Link"]]] = {}
for link, node in tools:
@@ -417,13 +404,13 @@ class SmartDecisionMakerBlock(Block):
if sink_node.block_id == AgentExecutorBlock().id:
return_tool_functions.append(
await SmartDecisionMakerBlock._create_agent_function_signature(
SmartDecisionMakerBlock._create_agent_function_signature(
sink_node, links
)
)
else:
return_tool_functions.append(
await SmartDecisionMakerBlock._create_block_function_signature(
SmartDecisionMakerBlock._create_block_function_signature(
sink_node, links
)
)
@@ -442,43 +429,37 @@ class SmartDecisionMakerBlock(Block):
user_id: str,
**kwargs,
) -> BlockOutput:
tool_functions = await self._create_function_signature(node_id)
yield "tool_functions", json.dumps(tool_functions)
tool_functions = self._create_function_signature(node_id)
input_data.conversation_history = input_data.conversation_history or []
prompt = [json.to_dict(p) for p in input_data.conversation_history if p]
pending_tool_calls = get_pending_tool_calls(input_data.conversation_history)
if pending_tool_calls and input_data.last_tool_output is None:
if pending_tool_calls and not input_data.last_tool_output:
raise ValueError(f"Tool call requires an output for {pending_tool_calls}")
# Only assign the last tool output to the first pending tool call
tool_output = []
if pending_tool_calls and input_data.last_tool_output is not None:
# Get the first pending tool call ID
first_call_id = next(iter(pending_tool_calls.keys()))
tool_output.append(
_create_tool_response(first_call_id, input_data.last_tool_output)
# Prefill all missing tool calls with the last tool output/
# TODO: we need a better way to handle this.
tool_output = [
_create_tool_response(pending_call_id, input_data.last_tool_output)
for pending_call_id, count in pending_tool_calls.items()
for _ in range(count)
]
# If the SDM block only calls 1 tool at a time, this should not happen.
if len(tool_output) > 1:
logger.warning(
f"[SmartDecisionMakerBlock-node_exec_id={node_exec_id}] "
f"Multiple pending tool calls are prefilled using a single output. "
f"Execution may not be accurate."
)
# Add tool output to prompt right away
prompt.extend(tool_output)
# Check if there are still pending tool calls after handling the first one
remaining_pending_calls = get_pending_tool_calls(prompt)
# If there are still pending tool calls, yield the conversation and return early
if remaining_pending_calls:
yield "conversations", prompt
return
# Fallback on adding tool output in the conversation history as user prompt.
elif input_data.last_tool_output:
logger.error(
if len(tool_output) == 0 and input_data.last_tool_output:
logger.warning(
f"[SmartDecisionMakerBlock-node_exec_id={node_exec_id}] "
f"No pending tool calls found. This may indicate an issue with the "
f"conversation history, or the tool giving response more than once."
f"This should not happen! Please check the conversation history for any inconsistencies."
f"conversation history, or an LLM calling two tools at the same time."
)
tool_output.append(
{
@@ -486,11 +467,8 @@ class SmartDecisionMakerBlock(Block):
"content": f"Last tool output: {json.dumps(input_data.last_tool_output)}",
}
)
prompt.extend(tool_output)
if input_data.multiple_tool_calls:
input_data.sys_prompt += "\nYou can call a tool (different tools) multiple times in a single response."
else:
input_data.sys_prompt += "\nOnly provide EXACTLY one function call, multiple tool calls is strictly prohibited."
prompt.extend(tool_output)
values = input_data.prompt_values
if values:
@@ -517,7 +495,7 @@ class SmartDecisionMakerBlock(Block):
max_tokens=input_data.max_tokens,
tools=tool_functions,
ollama_host=input_data.ollama_host,
parallel_tool_calls=input_data.multiple_tool_calls,
parallel_tool_calls=False,
)
if not response.tool_calls:
@@ -528,37 +506,8 @@ class SmartDecisionMakerBlock(Block):
tool_name = tool_call.function.name
tool_args = json.loads(tool_call.function.arguments)
# Find the tool definition to get the expected arguments
tool_def = next(
(
tool
for tool in tool_functions
if tool["function"]["name"] == tool_name
),
None,
)
for arg_name, arg_value in tool_args.items():
yield f"tools_^_{tool_name}_~_{arg_name}", arg_value
if (
tool_def
and "function" in tool_def
and "parameters" in tool_def["function"]
):
expected_args = tool_def["function"]["parameters"].get("properties", {})
else:
expected_args = tool_args.keys()
# Yield provided arguments and None for missing ones
for arg_name in expected_args:
if arg_name in tool_args:
yield f"tools_^_{tool_name}_~_{arg_name}", tool_args[arg_name]
else:
yield f"tools_^_{tool_name}_~_{arg_name}", None
# Add reasoning to conversation history if available
if response.reasoning:
prompt.append(
{"role": "assistant", "content": f"[Reasoning]: {response.reasoning}"}
)
prompt.append(response.raw_response)
yield "conversations", prompt
response.prompt.append(response.raw_response)
yield "conversations", response.prompt

View File

@@ -17,7 +17,7 @@ from backend.blocks.smartlead.models import (
Sequence,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import CredentialsField, SchemaField
from backend.data.model import SchemaField
class CreateCampaignBlock(Block):
@@ -27,7 +27,7 @@ class CreateCampaignBlock(Block):
name: str = SchemaField(
description="The name of the campaign",
)
credentials: SmartLeadCredentialsInput = CredentialsField(
credentials: SmartLeadCredentialsInput = SchemaField(
description="SmartLead credentials",
)
@@ -119,7 +119,7 @@ class AddLeadToCampaignBlock(Block):
description="Settings for lead upload",
default=LeadUploadSettings(),
)
credentials: SmartLeadCredentialsInput = CredentialsField(
credentials: SmartLeadCredentialsInput = SchemaField(
description="SmartLead credentials",
)
@@ -251,7 +251,7 @@ class SaveCampaignSequencesBlock(Block):
default_factory=list,
advanced=False,
)
credentials: SmartLeadCredentialsInput = CredentialsField(
credentials: SmartLeadCredentialsInput = SchemaField(
description="SmartLead credentials",
)

View File

@@ -1,485 +0,0 @@
"""Comprehensive tests for HTTP block with HostScopedCredentials functionality."""
from typing import cast
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import SecretStr
from backend.blocks.http import (
HttpCredentials,
HttpMethod,
SendAuthenticatedWebRequestBlock,
)
from backend.data.model import HostScopedCredentials
from backend.util.request import Response
class TestHttpBlockWithHostScopedCredentials:
"""Test suite for HTTP block integration with HostScopedCredentials."""
@pytest.fixture
def http_block(self):
"""Create an HTTP block instance."""
return SendAuthenticatedWebRequestBlock()
@pytest.fixture
def mock_response(self):
"""Mock a successful HTTP response."""
response = MagicMock(spec=Response)
response.status = 200
response.headers = {"content-type": "application/json"}
response.json.return_value = {"success": True, "data": "test"}
return response
@pytest.fixture
def exact_match_credentials(self):
"""Create host-scoped credentials for exact domain matching."""
return HostScopedCredentials(
provider="http",
host="api.example.com",
headers={
"Authorization": SecretStr("Bearer exact-match-token"),
"X-API-Key": SecretStr("api-key-123"),
},
title="Exact Match API Credentials",
)
@pytest.fixture
def wildcard_credentials(self):
"""Create host-scoped credentials with wildcard pattern."""
return HostScopedCredentials(
provider="http",
host="*.github.com",
headers={
"Authorization": SecretStr("token ghp_wildcard123"),
},
title="GitHub Wildcard Credentials",
)
@pytest.fixture
def non_matching_credentials(self):
"""Create credentials that don't match test URLs."""
return HostScopedCredentials(
provider="http",
host="different.api.com",
headers={
"Authorization": SecretStr("Bearer non-matching-token"),
},
title="Non-matching Credentials",
)
@pytest.mark.asyncio
@patch("backend.blocks.http.Requests")
async def test_http_block_with_exact_host_match(
self,
mock_requests_class,
http_block,
exact_match_credentials,
mock_response,
):
"""Test HTTP block with exact host matching credentials."""
# Setup mocks
mock_requests = AsyncMock()
mock_requests.request.return_value = mock_response
mock_requests_class.return_value = mock_requests
# Prepare input data
input_data = SendAuthenticatedWebRequestBlock.Input(
url="https://api.example.com/data",
method=HttpMethod.GET,
headers={"User-Agent": "test-agent"},
credentials=cast(
HttpCredentials,
{
"id": exact_match_credentials.id,
"provider": "http",
"type": "host_scoped",
"title": exact_match_credentials.title,
},
),
)
# Execute with credentials provided by execution manager
result = []
async for output_name, output_data in http_block.run(
input_data,
credentials=exact_match_credentials,
graph_exec_id="test-exec-id",
):
result.append((output_name, output_data))
# Verify request headers include both credential and user headers
mock_requests.request.assert_called_once()
call_args = mock_requests.request.call_args
expected_headers = {
"Authorization": "Bearer exact-match-token",
"X-API-Key": "api-key-123",
"User-Agent": "test-agent",
}
assert call_args.kwargs["headers"] == expected_headers
# Verify response handling
assert len(result) == 1
assert result[0][0] == "response"
assert result[0][1] == {"success": True, "data": "test"}
@pytest.mark.asyncio
@patch("backend.blocks.http.Requests")
async def test_http_block_with_wildcard_host_match(
self,
mock_requests_class,
http_block,
wildcard_credentials,
mock_response,
):
"""Test HTTP block with wildcard host pattern matching."""
# Setup mocks
mock_requests = AsyncMock()
mock_requests.request.return_value = mock_response
mock_requests_class.return_value = mock_requests
# Test with subdomain that should match *.github.com
input_data = SendAuthenticatedWebRequestBlock.Input(
url="https://api.github.com/user",
method=HttpMethod.GET,
headers={},
credentials=cast(
HttpCredentials,
{
"id": wildcard_credentials.id,
"provider": "http",
"type": "host_scoped",
"title": wildcard_credentials.title,
},
),
)
# Execute with wildcard credentials
result = []
async for output_name, output_data in http_block.run(
input_data,
credentials=wildcard_credentials,
graph_exec_id="test-exec-id",
):
result.append((output_name, output_data))
# Verify wildcard matching works
mock_requests.request.assert_called_once()
call_args = mock_requests.request.call_args
expected_headers = {"Authorization": "token ghp_wildcard123"}
assert call_args.kwargs["headers"] == expected_headers
@pytest.mark.asyncio
@patch("backend.blocks.http.Requests")
async def test_http_block_with_non_matching_credentials(
self,
mock_requests_class,
http_block,
non_matching_credentials,
mock_response,
):
"""Test HTTP block when credentials don't match the target URL."""
# Setup mocks
mock_requests = AsyncMock()
mock_requests.request.return_value = mock_response
mock_requests_class.return_value = mock_requests
# Test with URL that doesn't match the credentials
input_data = SendAuthenticatedWebRequestBlock.Input(
url="https://api.example.com/data",
method=HttpMethod.GET,
headers={"User-Agent": "test-agent"},
credentials=cast(
HttpCredentials,
{
"id": non_matching_credentials.id,
"provider": "http",
"type": "host_scoped",
"title": non_matching_credentials.title,
},
),
)
# Execute with non-matching credentials
result = []
async for output_name, output_data in http_block.run(
input_data,
credentials=non_matching_credentials,
graph_exec_id="test-exec-id",
):
result.append((output_name, output_data))
# Verify only user headers are included (no credential headers)
mock_requests.request.assert_called_once()
call_args = mock_requests.request.call_args
expected_headers = {"User-Agent": "test-agent"}
assert call_args.kwargs["headers"] == expected_headers
@pytest.mark.asyncio
@patch("backend.blocks.http.Requests")
async def test_user_headers_override_credential_headers(
self,
mock_requests_class,
http_block,
exact_match_credentials,
mock_response,
):
"""Test that user-provided headers take precedence over credential headers."""
# Setup mocks
mock_requests = AsyncMock()
mock_requests.request.return_value = mock_response
mock_requests_class.return_value = mock_requests
# Test with user header that conflicts with credential header
input_data = SendAuthenticatedWebRequestBlock.Input(
url="https://api.example.com/data",
method=HttpMethod.POST,
headers={
"Authorization": "Bearer user-override-token", # Should override
"Content-Type": "application/json", # Additional user header
},
credentials=cast(
HttpCredentials,
{
"id": exact_match_credentials.id,
"provider": "http",
"type": "host_scoped",
"title": exact_match_credentials.title,
},
),
)
# Execute with conflicting headers
result = []
async for output_name, output_data in http_block.run(
input_data,
credentials=exact_match_credentials,
graph_exec_id="test-exec-id",
):
result.append((output_name, output_data))
# Verify user headers take precedence
mock_requests.request.assert_called_once()
call_args = mock_requests.request.call_args
expected_headers = {
"X-API-Key": "api-key-123", # From credentials
"Authorization": "Bearer user-override-token", # User override
"Content-Type": "application/json", # User header
}
assert call_args.kwargs["headers"] == expected_headers
@pytest.mark.asyncio
@patch("backend.blocks.http.Requests")
async def test_auto_discovered_credentials_flow(
self,
mock_requests_class,
http_block,
mock_response,
):
"""Test the auto-discovery flow where execution manager provides matching credentials."""
# Create auto-discovered credentials
auto_discovered_creds = HostScopedCredentials(
provider="http",
host="*.example.com",
headers={
"Authorization": SecretStr("Bearer auto-discovered-token"),
},
title="Auto-discovered Credentials",
)
# Setup mocks
mock_requests = AsyncMock()
mock_requests.request.return_value = mock_response
mock_requests_class.return_value = mock_requests
# Test with empty credentials field (triggers auto-discovery)
input_data = SendAuthenticatedWebRequestBlock.Input(
url="https://api.example.com/data",
method=HttpMethod.GET,
headers={},
credentials=cast(
HttpCredentials,
{
"id": "", # Empty ID triggers auto-discovery in execution manager
"provider": "http",
"type": "host_scoped",
"title": "",
},
),
)
# Execute with auto-discovered credentials provided by execution manager
result = []
async for output_name, output_data in http_block.run(
input_data,
credentials=auto_discovered_creds, # Execution manager found these
graph_exec_id="test-exec-id",
):
result.append((output_name, output_data))
# Verify auto-discovered credentials were applied
mock_requests.request.assert_called_once()
call_args = mock_requests.request.call_args
expected_headers = {"Authorization": "Bearer auto-discovered-token"}
assert call_args.kwargs["headers"] == expected_headers
# Verify response handling
assert len(result) == 1
assert result[0][0] == "response"
assert result[0][1] == {"success": True, "data": "test"}
@pytest.mark.asyncio
@patch("backend.blocks.http.Requests")
async def test_multiple_header_credentials(
self,
mock_requests_class,
http_block,
mock_response,
):
"""Test credentials with multiple headers are all applied."""
# Create credentials with multiple headers
multi_header_creds = HostScopedCredentials(
provider="http",
host="api.example.com",
headers={
"Authorization": SecretStr("Bearer multi-token"),
"X-API-Key": SecretStr("api-key-456"),
"X-Client-ID": SecretStr("client-789"),
"X-Custom-Header": SecretStr("custom-value"),
},
title="Multi-Header Credentials",
)
# Setup mocks
mock_requests = AsyncMock()
mock_requests.request.return_value = mock_response
mock_requests_class.return_value = mock_requests
# Test with credentials containing multiple headers
input_data = SendAuthenticatedWebRequestBlock.Input(
url="https://api.example.com/data",
method=HttpMethod.GET,
headers={"User-Agent": "test-agent"},
credentials=cast(
HttpCredentials,
{
"id": multi_header_creds.id,
"provider": "http",
"type": "host_scoped",
"title": multi_header_creds.title,
},
),
)
# Execute with multi-header credentials
result = []
async for output_name, output_data in http_block.run(
input_data,
credentials=multi_header_creds,
graph_exec_id="test-exec-id",
):
result.append((output_name, output_data))
# Verify all headers are included
mock_requests.request.assert_called_once()
call_args = mock_requests.request.call_args
expected_headers = {
"Authorization": "Bearer multi-token",
"X-API-Key": "api-key-456",
"X-Client-ID": "client-789",
"X-Custom-Header": "custom-value",
"User-Agent": "test-agent",
}
assert call_args.kwargs["headers"] == expected_headers
@pytest.mark.asyncio
@patch("backend.blocks.http.Requests")
async def test_credentials_with_complex_url_patterns(
self,
mock_requests_class,
http_block,
mock_response,
):
"""Test credentials matching various URL patterns."""
# Test cases for different URL patterns
test_cases = [
{
"host_pattern": "api.example.com",
"test_url": "https://api.example.com/v1/users",
"should_match": True,
},
{
"host_pattern": "*.example.com",
"test_url": "https://api.example.com/v1/users",
"should_match": True,
},
{
"host_pattern": "*.example.com",
"test_url": "https://subdomain.example.com/data",
"should_match": True,
},
{
"host_pattern": "api.example.com",
"test_url": "https://api.different.com/data",
"should_match": False,
},
]
# Setup mocks
mock_requests = AsyncMock()
mock_requests.request.return_value = mock_response
mock_requests_class.return_value = mock_requests
for case in test_cases:
# Reset mock for each test case
mock_requests.reset_mock()
# Create credentials for this test case
test_creds = HostScopedCredentials(
provider="http",
host=case["host_pattern"],
headers={
"Authorization": SecretStr(f"Bearer {case['host_pattern']}-token"),
},
title=f"Credentials for {case['host_pattern']}",
)
input_data = SendAuthenticatedWebRequestBlock.Input(
url=case["test_url"],
method=HttpMethod.GET,
headers={"User-Agent": "test-agent"},
credentials=cast(
HttpCredentials,
{
"id": test_creds.id,
"provider": "http",
"type": "host_scoped",
"title": test_creds.title,
},
),
)
# Execute with test credentials
result = []
async for output_name, output_data in http_block.run(
input_data,
credentials=test_creds,
graph_exec_id="test-exec-id",
):
result.append((output_name, output_data))
# Verify headers based on whether pattern should match
mock_requests.request.assert_called_once()
call_args = mock_requests.request.call_args
headers = call_args.kwargs["headers"]
if case["should_match"]:
# Should include both user and credential headers
expected_auth = f"Bearer {case['host_pattern']}-token"
assert headers["Authorization"] == expected_auth
assert headers["User-Agent"] == "test-agent"
else:
# Should only include user headers
assert "Authorization" not in headers
assert headers["User-Agent"] == "test-agent"

View File

@@ -15,7 +15,7 @@ from backend.blocks.zerobounce._auth import (
ZeroBounceCredentialsInput,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import CredentialsField, SchemaField
from backend.data.model import SchemaField
class Response(BaseModel):
@@ -90,7 +90,7 @@ class ValidateEmailsBlock(Block):
description="IP address to validate",
default="",
)
credentials: ZeroBounceCredentialsInput = CredentialsField(
credentials: ZeroBounceCredentialsInput = SchemaField(
description="ZeroBounce credentials",
)

View File

@@ -1,5 +0,0 @@
from .graph import NodeModel
from .integrations import Webhook # noqa: F401
# Resolve Webhook <- NodeModel forward reference
NodeModel.model_rebuild()

View File

@@ -78,7 +78,6 @@ class BlockCategory(Enum):
PRODUCTIVITY = "Block that helps with productivity"
ISSUE_TRACKING = "Block that helps with issue tracking"
MULTIMEDIA = "Block that interacts with multimedia content"
MARKETING = "Block that helps with marketing"
def dict(self) -> dict[str, str]:
return {"category": self.name, "description": self.value}
@@ -119,10 +118,7 @@ class BlockSchema(BaseModel):
@classmethod
def validate_data(cls, data: BlockInput) -> str | None:
return json.validate_with_jsonschema(
schema=cls.jsonschema(),
data={k: v for k, v in data.items() if v is not None},
)
return json.validate_with_jsonschema(schema=cls.jsonschema(), data=data)
@classmethod
def get_mismatch_error(cls, data: BlockInput) -> str | None:
@@ -475,8 +471,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
)
async for output_name, output_data in self.run(
self.input_schema(**{k: v for k, v in input_data.items() if v is not None}),
**kwargs,
self.input_schema(**input_data), **kwargs
):
if output_name == "error":
raise RuntimeError(output_data)
@@ -486,22 +481,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
raise ValueError(f"Block produced an invalid output data: {error}")
yield output_name, output_data
def is_triggered_by_event_type(
self, trigger_config: dict[str, Any], event_type: str
) -> bool:
if not self.webhook_config:
raise TypeError("This method can't be used on non-trigger blocks")
if not self.webhook_config.event_filter_input:
return True
event_filter = trigger_config.get(self.webhook_config.event_filter_input)
if not event_filter:
raise ValueError("Event filter is not configured on trigger")
return event_type in [
self.webhook_config.event_format.format(event=k)
for k in event_filter
if event_filter[k] is True
]
# ======================= Block Helper Functions ======================= #

View File

@@ -2,9 +2,6 @@ from typing import Type
from backend.blocks.ai_music_generator import AIMusicGeneratorBlock
from backend.blocks.ai_shortform_video_block import AIShortformVideoCreatorBlock
from backend.blocks.apollo.organization import SearchOrganizationsBlock
from backend.blocks.apollo.people import SearchPeopleBlock
from backend.blocks.apollo.person import GetPersonDetailBlock
from backend.blocks.flux_kontext import AIImageEditorBlock, FluxKontextModelName
from backend.blocks.ideogram import IdeogramModelBlock
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
@@ -27,7 +24,6 @@ from backend.data.cost import BlockCost, BlockCostType
from backend.integrations.credentials_store import (
aiml_api_credentials,
anthropic_credentials,
apollo_credentials,
did_credentials,
groq_credentials,
ideogram_credentials,
@@ -85,9 +81,6 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.EVA_QWEN_2_5_32B: 1,
LlmModel.DEEPSEEK_CHAT: 2,
LlmModel.PERPLEXITY_LLAMA_3_1_SONAR_LARGE_128K_ONLINE: 1,
LlmModel.PERPLEXITY_SONAR: 1,
LlmModel.PERPLEXITY_SONAR_PRO: 5,
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: 10,
LlmModel.QWEN_QWQ_32B_PREVIEW: 2,
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: 1,
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: 1,
@@ -352,52 +345,4 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
)
],
SmartDecisionMakerBlock: LLM_COST,
SearchOrganizationsBlock: [
BlockCost(
cost_amount=2,
cost_filter={
"credentials": {
"id": apollo_credentials.id,
"provider": apollo_credentials.provider,
"type": apollo_credentials.type,
}
},
)
],
SearchPeopleBlock: [
BlockCost(
cost_amount=10,
cost_filter={
"enrich_info": False,
"credentials": {
"id": apollo_credentials.id,
"provider": apollo_credentials.provider,
"type": apollo_credentials.type,
},
},
),
BlockCost(
cost_amount=20,
cost_filter={
"enrich_info": True,
"credentials": {
"id": apollo_credentials.id,
"provider": apollo_credentials.provider,
"type": apollo_credentials.type,
},
},
),
],
GetPersonDetailBlock: [
BlockCost(
cost_amount=1,
cost_filter={
"credentials": {
"id": apollo_credentials.id,
"provider": apollo_credentials.provider,
"type": apollo_credentials.type,
}
},
)
],
}

View File

@@ -22,7 +22,6 @@ from prisma.models import (
AgentGraphExecution,
AgentNodeExecution,
AgentNodeExecutionInputOutput,
AgentNodeExecutionKeyValueData,
)
from prisma.types import (
AgentGraphExecutionCreateInput,
@@ -30,11 +29,10 @@ from prisma.types import (
AgentGraphExecutionWhereInput,
AgentNodeExecutionCreateInput,
AgentNodeExecutionInputOutputCreateInput,
AgentNodeExecutionKeyValueDataCreateInput,
AgentNodeExecutionUpdateInput,
AgentNodeExecutionWhereInput,
)
from pydantic import BaseModel, ConfigDict, JsonValue
from pydantic import BaseModel, ConfigDict
from pydantic.fields import Field
from backend.server.v2.store.exceptions import DatabaseError
@@ -50,14 +48,14 @@ from .block import (
get_webhook_block_ids,
)
from .db import BaseDbModel
from .event_bus import AsyncRedisEventBus, RedisEventBus
from .includes import (
EXECUTION_RESULT_INCLUDE,
EXECUTION_RESULT_ORDER,
GRAPH_EXECUTION_INCLUDE_WITH_NODES,
graph_execution_include,
)
from .model import GraphExecutionStats, NodeExecutionStats
from .model import CredentialsMetaInput, GraphExecutionStats, NodeExecutionStats
from .queue import AsyncRedisEventBus, RedisEventBus
T = TypeVar("T")
@@ -273,7 +271,7 @@ class GraphExecutionWithNodes(GraphExecution):
graph_id=self.graph_id,
graph_version=self.graph_version or 0,
graph_exec_id=self.id,
nodes_input_masks={}, # FIXME: store credentials on AgentGraphExecution
node_credentials_input_map={}, # FIXME
)
@@ -349,7 +347,6 @@ class NodeExecutionResult(BaseModel):
async def get_graph_executions(
graph_exec_id: str | None = None,
graph_id: str | None = None,
user_id: str | None = None,
statuses: list[ExecutionStatus] | None = None,
@@ -360,8 +357,6 @@ async def get_graph_executions(
where_filter: AgentGraphExecutionWhereInput = {
"isDeleted": False,
}
if graph_exec_id:
where_filter["id"] = graph_exec_id
if user_id:
where_filter["userId"] = user_id
if graph_id:
@@ -561,18 +556,18 @@ async def upsert_execution_input(
async def upsert_execution_output(
node_exec_id: str,
output_name: str,
output_data: Any | None,
output_data: Any,
) -> None:
"""
Insert AgentNodeExecutionInputOutput record for as one of AgentNodeExecution.Output.
"""
data = AgentNodeExecutionInputOutputCreateInput(
name=output_name,
referencedByOutputExecId=node_exec_id,
await AgentNodeExecutionInputOutput.prisma().create(
data=AgentNodeExecutionInputOutputCreateInput(
name=output_name,
data=Json(output_data),
referencedByOutputExecId=node_exec_id,
)
)
if output_data is not None:
data["data"] = Json(output_data)
await AgentNodeExecutionInputOutput.prisma().create(data=data)
async def update_graph_execution_start_time(
@@ -593,10 +588,12 @@ async def update_graph_execution_start_time(
async def update_graph_execution_stats(
graph_exec_id: str,
status: ExecutionStatus | None = None,
status: ExecutionStatus,
stats: GraphExecutionStats | None = None,
) -> GraphExecution | None:
update_data: AgentGraphExecutionUpdateManyMutationInput = {}
update_data: AgentGraphExecutionUpdateManyMutationInput = {
"executionStatus": status
}
if stats:
stats_dict = stats.model_dump()
@@ -604,9 +601,6 @@ async def update_graph_execution_stats(
stats_dict["error"] = str(stats_dict["error"])
update_data["stats"] = Json(stats_dict)
if status:
update_data["executionStatus"] = status
updated_count = await AgentGraphExecution.prisma().update_many(
where={
"id": graph_exec_id,
@@ -789,7 +783,7 @@ class GraphExecutionEntry(BaseModel):
graph_exec_id: str
graph_id: str
graph_version: int
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None
node_credentials_input_map: Optional[dict[str, dict[str, CredentialsMetaInput]]]
class NodeExecutionEntry(BaseModel):
@@ -909,57 +903,3 @@ class AsyncRedisExecutionEventBus(AsyncRedisEventBus[ExecutionEvent]):
) -> AsyncGenerator[ExecutionEvent, None]:
async for event in self.listen_events(f"{user_id}/{graph_id}/{graph_exec_id}"):
yield event
# --------------------- KV Data Functions --------------------- #
async def get_execution_kv_data(user_id: str, key: str) -> Any | None:
"""
Get key-value data for a user and key.
Args:
user_id: The id of the User.
key: The key to retrieve data for.
Returns:
The data associated with the key, or None if not found.
"""
kv_data = await AgentNodeExecutionKeyValueData.prisma().find_unique(
where={"userId_key": {"userId": user_id, "key": key}}
)
return (
type_utils.convert(kv_data.data, type[Any])
if kv_data and kv_data.data
else None
)
async def set_execution_kv_data(
user_id: str, node_exec_id: str, key: str, data: Any
) -> Any | None:
"""
Set key-value data for a user and key.
Args:
user_id: The id of the User.
node_exec_id: The id of the AgentNodeExecution.
key: The key to store data under.
data: The data to store.
"""
resp = await AgentNodeExecutionKeyValueData.prisma().upsert(
where={"userId_key": {"userId": user_id, "key": key}},
data={
"create": AgentNodeExecutionKeyValueDataCreateInput(
userId=user_id,
agentNodeExecutionId=node_exec_id,
key=key,
data=Json(data) if data is not None else None,
),
"update": {
"agentNodeExecutionId": node_exec_id,
"data": Json(data) if data is not None else None,
},
},
)
return type_utils.convert(resp.data, type[Any]) if resp and resp.data else None

View File

@@ -1,7 +1,7 @@
import logging
import uuid
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
from typing import Any, Literal, Optional, cast
import prisma
from prisma import Json
@@ -14,7 +14,7 @@ from prisma.types import (
AgentNodeLinkCreateInput,
StoreListingVersionWhereInput,
)
from pydantic import JsonValue, create_model
from pydantic import create_model
from pydantic.fields import computed_field
from backend.blocks.agent import AgentExecutorBlock
@@ -27,15 +27,12 @@ from backend.data.model import (
CredentialsMetaInput,
is_credentials_field_name,
)
from backend.integrations.providers import ProviderName
from backend.util import type as type_utils
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
from .db import BaseDbModel, transaction
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE
if TYPE_CHECKING:
from .integrations import Webhook
from .integrations import Webhook
logger = logging.getLogger(__name__)
@@ -84,12 +81,10 @@ class NodeModel(Node):
graph_version: int
webhook_id: Optional[str] = None
webhook: Optional["Webhook"] = None
webhook: Optional[Webhook] = None
@staticmethod
def from_db(node: AgentNode, for_export: bool = False) -> "NodeModel":
from .integrations import Webhook
obj = NodeModel(
id=node.id,
block_id=node.agentBlockId,
@@ -107,7 +102,19 @@ class NodeModel(Node):
return obj
def is_triggered_by_event_type(self, event_type: str) -> bool:
return self.block.is_triggered_by_event_type(self.input_default, event_type)
block = self.block
if not block.webhook_config:
raise TypeError("This method can't be used on non-webhook blocks")
if not block.webhook_config.event_filter_input:
return True
event_filter = self.input_default.get(block.webhook_config.event_filter_input)
if not event_filter:
raise ValueError(f"Event filter is not configured on node #{self.id}")
return event_type in [
block.webhook_config.event_format.format(event=k)
for k in event_filter
if event_filter[k] is True
]
def stripped_for_export(self) -> "NodeModel":
"""
@@ -155,6 +162,10 @@ class NodeModel(Node):
return result
# Fix 2-way reference Node <-> Webhook
Webhook.model_rebuild()
class BaseGraph(BaseDbModel):
version: int = 1
is_active: bool = True
@@ -244,8 +255,6 @@ class Graph(BaseGraph):
for other_field, other_keys in list(graph_cred_fields)[i + 1 :]:
if field.provider != other_field.provider:
continue
if ProviderName.HTTP in field.provider:
continue
# If this happens, that means a block implementation probably needs
# to be updated.
@@ -267,7 +276,6 @@ class Graph(BaseGraph):
required_scopes=set(field_info.required_scopes or []),
discriminator=field_info.discriminator,
discriminator_mapping=field_info.discriminator_mapping,
discriminator_values=field_info.discriminator_values,
),
)
for agg_field_key, (field_info, _) in graph_credentials_inputs.items()
@@ -286,40 +294,37 @@ class Graph(BaseGraph):
Returns:
dict[aggregated_field_key, tuple(
CredentialsFieldInfo: A spec for one aggregated credentials field
(now includes discriminator_values from matching nodes)
set[(node_id, field_name)]: Node credentials fields that are
compatible with this aggregated field spec
)]
"""
# First collect all credential field data with input defaults
node_credential_data = []
for graph in [self] + self.sub_graphs:
for node in graph.nodes:
for (
field_name,
field_info,
) in node.block.input_schema.get_credentials_fields_info().items():
discriminator = field_info.discriminator
if not discriminator:
node_credential_data.append((field_info, (node.id, field_name)))
continue
discriminator_value = node.input_default.get(discriminator)
if discriminator_value is None:
node_credential_data.append((field_info, (node.id, field_name)))
continue
discriminated_info = field_info.discriminate(discriminator_value)
discriminated_info.discriminator_values.add(discriminator_value)
node_credential_data.append(
(discriminated_info, (node.id, field_name))
return {
"_".join(sorted(agg_field_info.provider))
+ "_"
+ "_".join(sorted(agg_field_info.supported_types))
+ "_credentials": (agg_field_info, node_fields)
for agg_field_info, node_fields in CredentialsFieldInfo.combine(
*(
(
# Apply discrimination before aggregating credentials inputs
(
field_info.discriminate(
node.input_default[field_info.discriminator]
)
if (
field_info.discriminator
and node.input_default.get(field_info.discriminator)
)
else field_info
),
(node.id, field_name),
)
# Combine credential field info (this will merge discriminator_values automatically)
return CredentialsFieldInfo.combine(*node_credential_data)
for graph in [self] + self.sub_graphs
for node in graph.nodes
for field_name, field_info in node.block.input_schema.get_credentials_fields_info().items()
)
)
}
class GraphModel(Graph):
@@ -389,10 +394,8 @@ class GraphModel(Graph):
# Reassign Link IDs
for link in graph.links:
if link.source_id in id_map:
link.source_id = id_map[link.source_id]
if link.sink_id in id_map:
link.sink_id = id_map[link.sink_id]
link.source_id = id_map[link.source_id]
link.sink_id = id_map[link.sink_id]
# Reassign User IDs for agent blocks
for node in graph.nodes:
@@ -400,26 +403,16 @@ class GraphModel(Graph):
continue
node.input_default["user_id"] = user_id
node.input_default.setdefault("inputs", {})
if (
graph_id := node.input_default.get("graph_id")
) and graph_id in graph_id_map:
if (graph_id := node.input_default.get("graph_id")) in graph_id_map:
node.input_default["graph_id"] = graph_id_map[graph_id]
def validate_graph(
self,
for_run: bool = False,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
):
self._validate_graph(self, for_run, nodes_input_masks)
def validate_graph(self, for_run: bool = False):
self._validate_graph(self, for_run)
for sub_graph in self.sub_graphs:
self._validate_graph(sub_graph, for_run, nodes_input_masks)
self._validate_graph(sub_graph, for_run)
@staticmethod
def _validate_graph(
graph: BaseGraph,
for_run: bool = False,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
):
def _validate_graph(graph: BaseGraph, for_run: bool = False):
def is_tool_pin(name: str) -> bool:
return name.startswith("tools_^_")
@@ -446,18 +439,20 @@ class GraphModel(Graph):
if (block := nodes_block.get(node.id)) is None:
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
node_input_mask = (
nodes_input_masks.get(node.id, {}) if nodes_input_masks else {}
)
provided_inputs = set(
[sanitize(name) for name in node.input_default]
+ [sanitize(link.sink_name) for link in input_links.get(node.id, [])]
+ ([name for name in node_input_mask] if node_input_mask else [])
)
InputSchema = block.input_schema
for name in (required_fields := InputSchema.get_required_fields()):
if (
name not in provided_inputs
# Webhook payload is passed in by ExecutionManager
and not (
name == "payload"
and block.block_type
in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
)
# Checking availability of credentials is done by ExecutionManager
and name not in InputSchema.get_credentials_fields()
# Validate only I/O nodes, or validate everything when executing
@@ -490,18 +485,10 @@ class GraphModel(Graph):
def has_value(node: Node, name: str):
return (
(
name in node.input_default
and node.input_default[name] is not None
and str(node.input_default[name]).strip() != ""
)
or (name in input_fields and input_fields[name].default is not None)
or (
name in node_input_mask
and node_input_mask[name] is not None
and str(node_input_mask[name]).strip() != ""
)
)
name in node.input_default
and node.input_default[name] is not None
and str(node.input_default[name]).strip() != ""
) or (name in input_fields and input_fields[name].default is not None)
# Validate dependencies between fields
for field_name in input_fields.keys():
@@ -587,7 +574,7 @@ class GraphModel(Graph):
graph: AgentGraph,
for_export: bool = False,
sub_graphs: list[AgentGraph] | None = None,
) -> "GraphModel":
):
return GraphModel(
id=graph.id,
user_id=graph.userId if not for_export else "",
@@ -616,7 +603,6 @@ class GraphModel(Graph):
async def get_node(node_id: str) -> NodeModel:
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
node = await AgentNode.prisma().find_unique_or_raise(
where={"id": node_id},
include=AGENT_NODE_INCLUDE,
@@ -625,7 +611,6 @@ async def get_node(node_id: str) -> NodeModel:
async def set_node_webhook(node_id: str, webhook_id: str | None) -> NodeModel:
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
node = await AgentNode.prisma().update(
where={"id": node_id},
data=(

View File

@@ -60,8 +60,7 @@ def graph_execution_include(
INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
"AgentNodes": {"include": AGENT_NODE_INCLUDE},
"AgentPresets": {"include": {"InputPresets": True}},
"AgentNodes": {"include": AGENT_NODE_INCLUDE}
}

View File

@@ -1,25 +1,21 @@
import logging
from typing import AsyncGenerator, Literal, Optional, overload
from typing import TYPE_CHECKING, AsyncGenerator, Optional
from prisma import Json
from prisma.models import IntegrationWebhook
from prisma.types import (
IntegrationWebhookCreateInput,
IntegrationWebhookUpdateInput,
IntegrationWebhookWhereInput,
Serializable,
)
from prisma.types import IntegrationWebhookCreateInput
from pydantic import Field, computed_field
from backend.data.event_bus import AsyncRedisEventBus
from backend.data.includes import INTEGRATION_WEBHOOK_INCLUDE
from backend.data.queue import AsyncRedisEventBus
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks.utils import webhook_ingress_url
from backend.server.v2.library.model import LibraryAgentPreset
from backend.util.exceptions import NotFoundError
from .db import BaseDbModel
from .graph import NodeModel
if TYPE_CHECKING:
from .graph import NodeModel
logger = logging.getLogger(__name__)
@@ -36,6 +32,8 @@ class Webhook(BaseDbModel):
provider_webhook_id: str
attached_nodes: Optional[list["NodeModel"]] = None
@computed_field
@property
def url(self) -> str:
@@ -43,6 +41,8 @@ class Webhook(BaseDbModel):
@staticmethod
def from_db(webhook: IntegrationWebhook):
from .graph import NodeModel
return Webhook(
id=webhook.id,
user_id=webhook.userId,
@@ -54,26 +54,11 @@ class Webhook(BaseDbModel):
config=dict(webhook.config),
secret=webhook.secret,
provider_webhook_id=webhook.providerWebhookId,
)
class WebhookWithRelations(Webhook):
triggered_nodes: list[NodeModel]
triggered_presets: list[LibraryAgentPreset]
@staticmethod
def from_db(webhook: IntegrationWebhook):
if webhook.AgentNodes is None or webhook.AgentPresets is None:
raise ValueError(
"AgentNodes and AgentPresets must be included in "
"IntegrationWebhook query with relations"
)
return WebhookWithRelations(
**Webhook.from_db(webhook).model_dump(),
triggered_nodes=[NodeModel.from_db(node) for node in webhook.AgentNodes],
triggered_presets=[
LibraryAgentPreset.from_db(preset) for preset in webhook.AgentPresets
],
attached_nodes=(
[NodeModel.from_db(node) for node in webhook.AgentNodes]
if webhook.AgentNodes is not None
else None
),
)
@@ -98,19 +83,7 @@ async def create_webhook(webhook: Webhook) -> Webhook:
return Webhook.from_db(created_webhook)
@overload
async def get_webhook(
webhook_id: str, *, include_relations: Literal[True]
) -> WebhookWithRelations: ...
@overload
async def get_webhook(
webhook_id: str, *, include_relations: Literal[False] = False
) -> Webhook: ...
async def get_webhook(
webhook_id: str, *, include_relations: bool = False
) -> Webhook | WebhookWithRelations:
async def get_webhook(webhook_id: str) -> Webhook:
"""
⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints.
@@ -119,113 +92,73 @@ async def get_webhook(
"""
webhook = await IntegrationWebhook.prisma().find_unique(
where={"id": webhook_id},
include=INTEGRATION_WEBHOOK_INCLUDE if include_relations else None,
include=INTEGRATION_WEBHOOK_INCLUDE,
)
if not webhook:
raise NotFoundError(f"Webhook #{webhook_id} not found")
return (WebhookWithRelations if include_relations else Webhook).from_db(webhook)
return Webhook.from_db(webhook)
@overload
async def get_all_webhooks_by_creds(
user_id: str, credentials_id: str, *, include_relations: Literal[True]
) -> list[WebhookWithRelations]: ...
@overload
async def get_all_webhooks_by_creds(
user_id: str, credentials_id: str, *, include_relations: Literal[False] = False
) -> list[Webhook]: ...
async def get_all_webhooks_by_creds(
user_id: str, credentials_id: str, *, include_relations: bool = False
) -> list[Webhook] | list[WebhookWithRelations]:
async def get_all_webhooks_by_creds(credentials_id: str) -> list[Webhook]:
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
if not credentials_id:
raise ValueError("credentials_id must not be empty")
webhooks = await IntegrationWebhook.prisma().find_many(
where={"userId": user_id, "credentialsId": credentials_id},
include=INTEGRATION_WEBHOOK_INCLUDE if include_relations else None,
where={"credentialsId": credentials_id},
include=INTEGRATION_WEBHOOK_INCLUDE,
)
return [
(WebhookWithRelations if include_relations else Webhook).from_db(webhook)
for webhook in webhooks
]
return [Webhook.from_db(webhook) for webhook in webhooks]
async def find_webhook_by_credentials_and_props(
user_id: str,
credentials_id: str,
webhook_type: str,
resource: str,
events: list[str],
credentials_id: str, webhook_type: str, resource: str, events: list[str]
) -> Webhook | None:
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
webhook = await IntegrationWebhook.prisma().find_first(
where={
"userId": user_id,
"credentialsId": credentials_id,
"webhookType": webhook_type,
"resource": resource,
"events": {"has_every": events},
},
include=INTEGRATION_WEBHOOK_INCLUDE,
)
return Webhook.from_db(webhook) if webhook else None
async def find_webhook_by_graph_and_props(
user_id: str,
provider: str,
webhook_type: str,
graph_id: Optional[str] = None,
preset_id: Optional[str] = None,
graph_id: str, provider: str, webhook_type: str, events: list[str]
) -> Webhook | None:
"""Either `graph_id` or `preset_id` must be provided."""
where_clause: IntegrationWebhookWhereInput = {
"userId": user_id,
"provider": provider,
"webhookType": webhook_type,
}
if preset_id:
where_clause["AgentPresets"] = {"some": {"id": preset_id}}
elif graph_id:
where_clause["AgentNodes"] = {"some": {"agentGraphId": graph_id}}
else:
raise ValueError("Either graph_id or preset_id must be provided")
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
webhook = await IntegrationWebhook.prisma().find_first(
where=where_clause,
where={
"provider": provider,
"webhookType": webhook_type,
"events": {"has_every": events},
"AgentNodes": {"some": {"agentGraphId": graph_id}},
},
include=INTEGRATION_WEBHOOK_INCLUDE,
)
return Webhook.from_db(webhook) if webhook else None
async def update_webhook(
webhook_id: str,
config: Optional[dict[str, Serializable]] = None,
events: Optional[list[str]] = None,
) -> Webhook:
async def update_webhook_config(webhook_id: str, updated_config: dict) -> Webhook:
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
data: IntegrationWebhookUpdateInput = {}
if config is not None:
data["config"] = Json(config)
if events is not None:
data["events"] = events
if not data:
raise ValueError("Empty update query")
_updated_webhook = await IntegrationWebhook.prisma().update(
where={"id": webhook_id},
data=data,
data={"config": Json(updated_config)},
include=INTEGRATION_WEBHOOK_INCLUDE,
)
if _updated_webhook is None:
raise NotFoundError(f"Webhook #{webhook_id} not found")
raise ValueError(f"Webhook #{webhook_id} not found")
return Webhook.from_db(_updated_webhook)
async def delete_webhook(user_id: str, webhook_id: str) -> None:
deleted = await IntegrationWebhook.prisma().delete_many(
where={"id": webhook_id, "userId": user_id}
)
if deleted < 1:
raise NotFoundError(f"Webhook #{webhook_id} not found")
async def delete_webhook(webhook_id: str) -> None:
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
deleted = await IntegrationWebhook.prisma().delete(where={"id": webhook_id})
if not deleted:
raise ValueError(f"Webhook #{webhook_id} not found")
# --------------------- WEBHOOK EVENTS --------------------- #

View File

@@ -14,12 +14,11 @@ from typing import (
Generic,
Literal,
Optional,
Sequence,
TypedDict,
TypeVar,
cast,
get_args,
)
from urllib.parse import urlparse
from uuid import uuid4
from prisma.enums import CreditTransactionType
@@ -241,65 +240,13 @@ class UserPasswordCredentials(_BaseCredentials):
return f"Basic {base64.b64encode(f'{self.username.get_secret_value()}:{self.password.get_secret_value()}'.encode()).decode()}"
class HostScopedCredentials(_BaseCredentials):
type: Literal["host_scoped"] = "host_scoped"
host: str = Field(description="The host/URI pattern to match against request URLs")
headers: dict[str, SecretStr] = Field(
description="Key-value header map to add to matching requests",
default_factory=dict,
)
def _extract_headers(self, headers: dict[str, SecretStr]) -> dict[str, str]:
"""Helper to extract secret values from headers."""
return {key: value.get_secret_value() for key, value in headers.items()}
@field_serializer("headers")
def serialize_headers(self, headers: dict[str, SecretStr]) -> dict[str, str]:
"""Serialize headers by extracting secret values."""
return self._extract_headers(headers)
def get_headers_dict(self) -> dict[str, str]:
"""Get headers with secret values extracted."""
return self._extract_headers(self.headers)
def auth_header(self) -> str:
"""Get authorization header for backward compatibility."""
auth_headers = self.get_headers_dict()
if "Authorization" in auth_headers:
return auth_headers["Authorization"]
return ""
def matches_url(self, url: str) -> bool:
"""Check if this credential should be applied to the given URL."""
parsed_url = urlparse(url)
# Extract hostname without port
request_host = parsed_url.hostname
if not request_host:
return False
# Simple host matching - exact match or wildcard subdomain match
if self.host == request_host:
return True
# Support wildcard matching (e.g., "*.example.com" matches "api.example.com")
if self.host.startswith("*."):
domain = self.host[2:] # Remove "*."
return request_host.endswith(f".{domain}") or request_host == domain
return False
Credentials = Annotated[
OAuth2Credentials
| APIKeyCredentials
| UserPasswordCredentials
| HostScopedCredentials,
OAuth2Credentials | APIKeyCredentials | UserPasswordCredentials,
Field(discriminator="type"),
]
CredentialsType = Literal["api_key", "oauth2", "user_password", "host_scoped"]
CredentialsType = Literal["api_key", "oauth2", "user_password"]
class OAuthState(BaseModel):
@@ -373,29 +320,15 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
)
@staticmethod
def _add_json_schema_extra(schema: dict, model_class: type):
# Use model_class for allowed_providers/cred_types
if hasattr(model_class, "allowed_providers") and hasattr(
model_class, "allowed_cred_types"
):
schema["credentials_provider"] = model_class.allowed_providers()
schema["credentials_types"] = model_class.allowed_cred_types()
# Do not return anything, just mutate schema in place
def _add_json_schema_extra(schema, cls: CredentialsMetaInput):
schema["credentials_provider"] = cls.allowed_providers()
schema["credentials_types"] = cls.allowed_cred_types()
model_config = ConfigDict(
json_schema_extra=_add_json_schema_extra, # type: ignore
)
def _extract_host_from_url(url: str) -> str:
"""Extract host from URL for grouping host-scoped credentials."""
try:
parsed = urlparse(url)
return parsed.hostname or url
except Exception:
return ""
class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
# TODO: move discrimination mechanism out of CredentialsField (frontend + backend)
provider: frozenset[CP] = Field(..., alias="credentials_provider")
@@ -403,12 +336,11 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
required_scopes: Optional[frozenset[str]] = Field(None, alias="credentials_scopes")
discriminator: Optional[str] = None
discriminator_mapping: Optional[dict[str, CP]] = None
discriminator_values: set[Any] = Field(default_factory=set)
@classmethod
def combine(
cls, *fields: tuple[CredentialsFieldInfo[CP, CT], T]
) -> dict[str, tuple[CredentialsFieldInfo[CP, CT], set[T]]]:
) -> Sequence[tuple[CredentialsFieldInfo[CP, CT], set[T]]]:
"""
Combines multiple CredentialsFieldInfo objects into as few as possible.
@@ -426,36 +358,22 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
the set of keys of the respective original items that were grouped together.
"""
if not fields:
return {}
return []
# Group fields by their provider and supported_types
# For HTTP host-scoped credentials, also group by host
grouped_fields: defaultdict[
tuple[frozenset[CP], frozenset[CT]],
list[tuple[T, CredentialsFieldInfo[CP, CT]]],
] = defaultdict(list)
for field, key in fields:
if field.provider == frozenset([ProviderName.HTTP]):
# HTTP host-scoped credentials can have different hosts that reqires different credential sets.
# Group by host extracted from the URL
providers = frozenset(
[cast(CP, "http")]
+ [
cast(CP, _extract_host_from_url(str(value)))
for value in field.discriminator_values
]
)
else:
providers = frozenset(field.provider)
group_key = (providers, frozenset(field.supported_types))
group_key = (frozenset(field.provider), frozenset(field.supported_types))
grouped_fields[group_key].append((key, field))
# Combine fields within each group
result: dict[str, tuple[CredentialsFieldInfo[CP, CT], set[T]]] = {}
result: list[tuple[CredentialsFieldInfo[CP, CT], set[T]]] = []
for key, group in grouped_fields.items():
for group in grouped_fields.values():
# Start with the first field in the group
_, combined = group[0]
@@ -468,32 +386,18 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
if field.required_scopes:
all_scopes.update(field.required_scopes)
# Combine discriminator_values from all fields in the group (removing duplicates)
all_discriminator_values = []
for _, field in group:
for value in field.discriminator_values:
if value not in all_discriminator_values:
all_discriminator_values.append(value)
# Generate the key for the combined result
providers_key, supported_types_key = key
group_key = (
"-".join(sorted(providers_key))
+ "_"
+ "-".join(sorted(supported_types_key))
+ "_credentials"
)
result[group_key] = (
CredentialsFieldInfo[CP, CT](
credentials_provider=combined.provider,
credentials_types=combined.supported_types,
credentials_scopes=frozenset(all_scopes) or None,
discriminator=combined.discriminator,
discriminator_mapping=combined.discriminator_mapping,
discriminator_values=set(all_discriminator_values),
),
combined_keys,
# Create a new combined field
result.append(
(
CredentialsFieldInfo[CP, CT](
credentials_provider=combined.provider,
credentials_types=combined.supported_types,
credentials_scopes=frozenset(all_scopes) or None,
discriminator=combined.discriminator,
discriminator_mapping=combined.discriminator_mapping,
),
combined_keys,
)
)
return result
@@ -502,15 +406,11 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
if not (self.discriminator and self.discriminator_mapping):
return self
discriminator_value = self.discriminator_mapping[discriminator_value]
return CredentialsFieldInfo(
credentials_provider=frozenset(
[self.discriminator_mapping[discriminator_value]]
),
credentials_provider=frozenset([discriminator_value]),
credentials_types=self.supported_types,
credentials_scopes=self.required_scopes,
discriminator=self.discriminator,
discriminator_mapping=self.discriminator_mapping,
discriminator_values=self.discriminator_values,
)
@@ -519,7 +419,6 @@ def CredentialsField(
*,
discriminator: Optional[str] = None,
discriminator_mapping: Optional[dict[str, Any]] = None,
discriminator_values: Optional[set[Any]] = None,
title: Optional[str] = None,
description: Optional[str] = None,
**kwargs,
@@ -535,7 +434,6 @@ def CredentialsField(
"credentials_scopes": list(required_scopes) or None,
"discriminator": discriminator,
"discriminator_mapping": discriminator_mapping,
"discriminator_values": discriminator_values,
}.items()
if v is not None
}

View File

@@ -1,143 +0,0 @@
import pytest
from pydantic import SecretStr
from backend.data.model import HostScopedCredentials
class TestHostScopedCredentials:
def test_host_scoped_credentials_creation(self):
"""Test creating HostScopedCredentials with required fields."""
creds = HostScopedCredentials(
provider="custom",
host="api.example.com",
headers={
"Authorization": SecretStr("Bearer secret-token"),
"X-API-Key": SecretStr("api-key-123"),
},
title="Example API Credentials",
)
assert creds.type == "host_scoped"
assert creds.provider == "custom"
assert creds.host == "api.example.com"
assert creds.title == "Example API Credentials"
assert len(creds.headers) == 2
assert "Authorization" in creds.headers
assert "X-API-Key" in creds.headers
def test_get_headers_dict(self):
"""Test getting headers with secret values extracted."""
creds = HostScopedCredentials(
provider="custom",
host="api.example.com",
headers={
"Authorization": SecretStr("Bearer secret-token"),
"X-Custom-Header": SecretStr("custom-value"),
},
)
headers_dict = creds.get_headers_dict()
assert headers_dict == {
"Authorization": "Bearer secret-token",
"X-Custom-Header": "custom-value",
}
def test_matches_url_exact_host(self):
"""Test URL matching with exact host match."""
creds = HostScopedCredentials(
provider="custom",
host="api.example.com",
headers={"Authorization": SecretStr("Bearer token")},
)
assert creds.matches_url("https://api.example.com/v1/data")
assert creds.matches_url("http://api.example.com/endpoint")
assert not creds.matches_url("https://other.example.com/v1/data")
assert not creds.matches_url("https://subdomain.api.example.com/v1/data")
def test_matches_url_wildcard_subdomain(self):
"""Test URL matching with wildcard subdomain pattern."""
creds = HostScopedCredentials(
provider="custom",
host="*.example.com",
headers={"Authorization": SecretStr("Bearer token")},
)
assert creds.matches_url("https://api.example.com/v1/data")
assert creds.matches_url("https://subdomain.example.com/endpoint")
assert creds.matches_url("https://deep.nested.example.com/path")
assert creds.matches_url("https://example.com/path") # Base domain should match
assert not creds.matches_url("https://example.org/v1/data")
assert not creds.matches_url("https://notexample.com/v1/data")
def test_matches_url_with_port_and_path(self):
"""Test URL matching with ports and paths."""
creds = HostScopedCredentials(
provider="custom",
host="localhost",
headers={"Authorization": SecretStr("Bearer token")},
)
assert creds.matches_url("http://localhost:8080/api/v1")
assert creds.matches_url("https://localhost:443/secure/endpoint")
assert creds.matches_url("http://localhost/simple")
def test_empty_headers_dict(self):
"""Test HostScopedCredentials with empty headers."""
creds = HostScopedCredentials(
provider="custom", host="api.example.com", headers={}
)
assert creds.get_headers_dict() == {}
assert creds.matches_url("https://api.example.com/test")
def test_credential_serialization(self):
"""Test that credentials can be serialized/deserialized properly."""
original_creds = HostScopedCredentials(
provider="custom",
host="api.example.com",
headers={
"Authorization": SecretStr("Bearer secret-token"),
"X-API-Key": SecretStr("api-key-123"),
},
title="Test Credentials",
)
# Serialize to dict (simulating storage)
serialized = original_creds.model_dump()
# Deserialize back
restored_creds = HostScopedCredentials.model_validate(serialized)
assert restored_creds.id == original_creds.id
assert restored_creds.provider == original_creds.provider
assert restored_creds.host == original_creds.host
assert restored_creds.title == original_creds.title
assert restored_creds.type == "host_scoped"
# Check that headers are properly restored
assert restored_creds.get_headers_dict() == original_creds.get_headers_dict()
@pytest.mark.parametrize(
"host,test_url,expected",
[
("api.example.com", "https://api.example.com/test", True),
("api.example.com", "https://different.example.com/test", False),
("*.example.com", "https://api.example.com/test", True),
("*.example.com", "https://sub.api.example.com/test", True),
("*.example.com", "https://example.com/test", True),
("*.example.com", "https://example.org/test", False),
("localhost", "http://localhost:3000/test", True),
("localhost", "http://127.0.0.1:3000/test", False),
],
)
def test_url_matching_parametrized(self, host: str, test_url: str, expected: bool):
"""Parametrized test for various URL matching scenarios."""
creds = HostScopedCredentials(
provider="test",
host=host,
headers={"Authorization": SecretStr("Bearer token")},
)
assert creds.matches_url(test_url) == expected

View File

@@ -7,7 +7,7 @@ from pydantic import BaseModel
from redis.asyncio.client import PubSub as AsyncPubSub
from redis.client import PubSub
from backend.data import redis_client as redis
from backend.data import redis
logger = logging.getLogger(__name__)

View File

@@ -5,14 +5,12 @@ from backend.data import db
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.execution import (
create_graph_execution,
get_execution_kv_data,
get_graph_execution,
get_graph_execution_meta,
get_graph_executions,
get_latest_node_execution,
get_node_execution,
get_node_executions,
set_execution_kv_data,
update_graph_execution_start_time,
update_graph_execution_stats,
update_node_execution_stats,
@@ -103,8 +101,6 @@ class DatabaseManager(AppService):
update_node_execution_stats = _(update_node_execution_stats)
upsert_execution_input = _(upsert_execution_input)
upsert_execution_output = _(upsert_execution_output)
get_execution_kv_data = _(get_execution_kv_data)
set_execution_kv_data = _(set_execution_kv_data)
# Graphs
get_node = _(get_node)
@@ -163,8 +159,6 @@ class DatabaseManagerClient(AppServiceClient):
update_node_execution_stats = _(d.update_node_execution_stats)
upsert_execution_input = _(d.upsert_execution_input)
upsert_execution_output = _(d.upsert_execution_output)
get_execution_kv_data = _(d.get_execution_kv_data)
set_execution_kv_data = _(d.set_execution_kv_data)
# Graphs
get_node = _(d.get_node)
@@ -208,11 +202,8 @@ class DatabaseManagerAsyncClient(AppServiceClient):
return DatabaseManager
create_graph_execution = d.create_graph_execution
get_connected_output_nodes = d.get_connected_output_nodes
get_latest_node_execution = d.get_latest_node_execution
get_graph = d.get_graph
get_graph_metadata = d.get_graph_metadata
get_graph_execution_meta = d.get_graph_execution_meta
get_node = d.get_node
get_node_execution = d.get_node_execution
get_node_executions = d.get_node_executions
@@ -224,5 +215,3 @@ class DatabaseManagerAsyncClient(AppServiceClient):
update_node_execution_status = d.update_node_execution_status
update_node_execution_status_batch = d.update_node_execution_status_batch
update_user_integrations = d.update_user_integrations
get_execution_kv_data = d.get_execution_kv_data
set_execution_kv_data = d.set_execution_kv_data

View File

@@ -12,11 +12,14 @@ from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
from pika.adapters.blocking_connection import BlockingChannel
from pika.spec import Basic, BasicProperties
from pydantic import JsonValue
from redis.asyncio.lock import Lock as RedisLock
from backend.blocks.io import AgentOutputBlock
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.data.model import (
CredentialsMetaInput,
GraphExecutionStats,
NodeExecutionStats,
)
from backend.data.notifications import (
AgentRunData,
LowBalanceData,
@@ -24,7 +27,7 @@ from backend.data.notifications import (
NotificationType,
)
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.utils import LogMetadata, create_execution_queue_config
from backend.executor.utils import create_execution_queue_config
from backend.notifications.notifications import queue_notification
from backend.util.exceptions import InsufficientBalanceError
@@ -35,7 +38,7 @@ from autogpt_libs.utils.cache import thread_cached
from prometheus_client import Gauge, start_http_server
from backend.blocks.agent import AgentExecutorBlock
from backend.data import redis_client as redis
from backend.data import redis
from backend.data.block import (
BlockData,
BlockInput,
@@ -98,6 +101,35 @@ utilization_gauge = Gauge(
)
class LogMetadata(TruncatedLogger):
def __init__(
self,
user_id: str,
graph_eid: str,
graph_id: str,
node_eid: str,
node_id: str,
block_name: str,
max_length: int = 1000,
):
metadata = {
"component": "ExecutionManager",
"user_id": user_id,
"graph_eid": graph_eid,
"graph_id": graph_id,
"node_eid": node_eid,
"node_id": node_id,
"block_name": block_name,
}
prefix = f"[ExecutionManager|uid:{user_id}|gid:{graph_id}|nid:{node_id}]|geid:{graph_eid}|neid:{node_eid}|{block_name}]"
super().__init__(
_logger,
max_length=max_length,
prefix=prefix,
metadata=metadata,
)
T = TypeVar("T")
@@ -106,7 +138,9 @@ async def execute_node(
creds_manager: IntegrationCredentialsManager,
data: NodeExecutionEntry,
execution_stats: NodeExecutionStats | None = None,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
node_credentials_input_map: Optional[
dict[str, dict[str, CredentialsMetaInput]]
] = None,
) -> BlockOutput:
"""
Execute a node in the graph. This will trigger a block execution on a node,
@@ -129,7 +163,6 @@ async def execute_node(
node_block = node.block
log_metadata = LogMetadata(
logger=_logger,
user_id=user_id,
graph_eid=graph_exec_id,
graph_id=graph_id,
@@ -150,8 +183,8 @@ async def execute_node(
if isinstance(node_block, AgentExecutorBlock):
_input_data = AgentExecutorBlock.Input(**node.input_default)
_input_data.inputs = input_data
if nodes_input_masks:
_input_data.nodes_input_masks = nodes_input_masks
if node_credentials_input_map:
_input_data.node_credentials_input_map = node_credentials_input_map
input_data = _input_data.model_dump()
data.inputs = input_data
@@ -222,7 +255,7 @@ async def _enqueue_next_nodes(
graph_exec_id: str,
graph_id: str,
log_metadata: LogMetadata,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
node_credentials_input_map: Optional[dict[str, dict[str, CredentialsMetaInput]]],
) -> list[NodeExecutionEntry]:
async def add_enqueued_execution(
node_exec_id: str, node_id: str, block_id: str, data: BlockInput
@@ -256,9 +289,8 @@ async def _enqueue_next_nodes(
next_input_name = node_link.sink_name
next_node_id = node_link.sink_id
output_name, _ = output
next_data = parse_execution_output(output, next_output_name)
if next_data is None and output_name != next_output_name:
if next_data is None:
return enqueued_executions
next_node = await db_client.get_node(next_node_id)
@@ -293,12 +325,14 @@ async def _enqueue_next_nodes(
for name in static_link_names:
next_node_input[name] = latest_execution.input_data.get(name)
# Apply node input overrides
node_input_mask = None
if nodes_input_masks and (
node_input_mask := nodes_input_masks.get(next_node.id)
# Apply node credentials overrides
node_credentials = None
if node_credentials_input_map and (
node_credentials := node_credentials_input_map.get(next_node.id)
):
next_node_input.update(node_input_mask)
next_node_input.update(
{k: v.model_dump() for k, v in node_credentials.items()}
)
# Validate the input data for the next node.
next_node_input, validation_msg = validate_exec(next_node, next_node_input)
@@ -342,9 +376,11 @@ async def _enqueue_next_nodes(
for input_name in static_link_names:
idata[input_name] = next_node_input[input_name]
# Apply node input overrides
if node_input_mask:
idata.update(node_input_mask)
# Apply node credentials overrides
if node_credentials:
idata.update(
{k: v.model_dump() for k, v in node_credentials.items()}
)
idata, msg = validate_exec(next_node, idata)
suffix = f"{next_output_name}>{next_input_name}~{ineid}:{msg}"
@@ -393,15 +429,16 @@ class Executor:
"""
@classmethod
@async_error_logged(swallow=True)
@async_error_logged
async def on_node_execution(
cls,
node_exec: NodeExecutionEntry,
node_exec_progress: NodeExecutionProgress,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
node_credentials_input_map: Optional[
dict[str, dict[str, CredentialsMetaInput]]
] = None,
) -> NodeExecutionStats:
log_metadata = LogMetadata(
logger=_logger,
user_id=node_exec.user_id,
graph_eid=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
@@ -420,7 +457,7 @@ class Executor:
db_client=db_client,
log_metadata=log_metadata,
stats=execution_stats,
nodes_input_masks=nodes_input_masks,
node_credentials_input_map=node_credentials_input_map,
)
execution_stats.walltime = timing_info.wall_time
execution_stats.cputime = timing_info.cpu_time
@@ -443,7 +480,9 @@ class Executor:
db_client: "DatabaseManagerAsyncClient",
log_metadata: LogMetadata,
stats: NodeExecutionStats | None = None,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
node_credentials_input_map: Optional[
dict[str, dict[str, CredentialsMetaInput]]
] = None,
):
try:
log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
@@ -458,7 +497,7 @@ class Executor:
creds_manager=cls.creds_manager,
data=node_exec,
execution_stats=stats,
nodes_input_masks=nodes_input_masks,
node_credentials_input_map=node_credentials_input_map,
):
node_exec_progress.add_output(
ExecutionOutputEntry(
@@ -502,12 +541,11 @@ class Executor:
logger.info(f"[GraphExecutor] {cls.pid} started")
@classmethod
@error_logged(swallow=False)
@error_logged
def on_graph_execution(
cls, graph_exec: GraphExecutionEntry, cancel: threading.Event
):
log_metadata = LogMetadata(
logger=_logger,
user_id=graph_exec.user_id,
graph_eid=graph_exec.graph_exec_id,
graph_id=graph_exec.graph_id,
@@ -555,15 +593,6 @@ class Executor:
exec_stats.cputime += timing_info.cpu_time
exec_stats.error = str(error) if error else exec_stats.error
if status not in {
ExecutionStatus.COMPLETED,
ExecutionStatus.TERMINATED,
ExecutionStatus.FAILED,
}:
raise RuntimeError(
f"Graph Execution #{graph_exec.graph_exec_id} ended with unexpected status {status}"
)
if graph_exec_result := db_client.update_graph_execution_stats(
graph_exec_id=graph_exec.graph_exec_id,
status=status,
@@ -667,6 +696,7 @@ class Executor:
if _graph_exec := db_client.update_graph_execution_stats(
graph_exec_id=graph_exec.graph_exec_id,
status=execution_status,
stats=execution_stats,
):
send_execution_update(_graph_exec)
@@ -748,19 +778,24 @@ class Executor:
)
raise
# Add input overrides -----------------------------
# Add credential overrides -----------------------------
node_id = queued_node_exec.node_id
if (nodes_input_masks := graph_exec.nodes_input_masks) and (
node_input_mask := nodes_input_masks.get(node_id)
if (node_creds_map := graph_exec.node_credentials_input_map) and (
node_field_creds_map := node_creds_map.get(node_id)
):
queued_node_exec.inputs.update(node_input_mask)
queued_node_exec.inputs.update(
{
field_name: creds_meta.model_dump()
for field_name, creds_meta in node_field_creds_map.items()
}
)
# Kick off async node execution -------------------------
node_execution_task = asyncio.run_coroutine_threadsafe(
cls.on_node_execution(
node_exec=queued_node_exec,
node_exec_progress=running_node_execution[node_id],
nodes_input_masks=nodes_input_masks,
node_credentials_input_map=node_creds_map,
),
cls.node_execution_loop,
)
@@ -804,7 +839,7 @@ class Executor:
node_id=node_id,
graph_exec=graph_exec,
log_metadata=log_metadata,
nodes_input_masks=nodes_input_masks,
node_creds_map=node_creds_map,
execution_queue=execution_queue,
),
cls.node_evaluation_loop,
@@ -835,7 +870,6 @@ class Executor:
f"Failed graph execution {graph_exec.graph_exec_id}: {error}"
)
finally:
# Cancel and wait for all node executions to complete
for node_id, inflight_exec in running_node_execution.items():
if inflight_exec.is_done():
continue
@@ -848,28 +882,6 @@ class Executor:
log_metadata.info(f"Stopping node evaluation {node_id}")
inflight_eval.cancel()
for node_id, inflight_exec in running_node_execution.items():
if inflight_exec.is_done():
continue
try:
inflight_exec.wait_for_cancellation(timeout=60.0)
except TimeoutError:
log_metadata.exception(
f"Node execution #{node_id} did not stop in time, "
"it may be stuck or taking too long."
)
for node_id, inflight_eval in running_node_evaluation.items():
if inflight_eval.done():
continue
try:
inflight_eval.result(timeout=60.0)
except TimeoutError:
log_metadata.exception(
f"Node evaluation #{node_id} did not stop in time, "
"it may be stuck or taking too long."
)
if execution_status in [ExecutionStatus.TERMINATED, ExecutionStatus.FAILED]:
inflight_executions = db_client.get_node_executions(
graph_exec.graph_exec_id,
@@ -897,7 +909,7 @@ class Executor:
node_id: str,
graph_exec: GraphExecutionEntry,
log_metadata: LogMetadata,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
node_creds_map: Optional[dict[str, dict[str, CredentialsMetaInput]]],
execution_queue: ExecutionQueue[NodeExecutionEntry],
) -> None:
"""Process a node's output, update its status, and enqueue next nodes.
@@ -907,7 +919,7 @@ class Executor:
node_id: The ID of the node that produced the output
graph_exec: The graph execution entry
log_metadata: Logger metadata for consistent logging
nodes_input_masks: Optional map of node input overrides
node_creds_map: Optional map of node credentials
execution_queue: Queue to add next executions to
"""
db_client = get_db_async_client()
@@ -931,7 +943,7 @@ class Executor:
graph_exec_id=graph_exec.graph_exec_id,
graph_id=graph_exec.graph_id,
log_metadata=log_metadata,
nodes_input_masks=nodes_input_masks,
node_credentials_input_map=node_creds_map,
):
execution_queue.add(next_execution)
except Exception as e:

View File

@@ -3,7 +3,6 @@ import logging
import os
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Optional
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
from apscheduler.events import EVENT_JOB_ERROR, EVENT_JOB_EXECUTED
@@ -15,16 +14,13 @@ from apscheduler.triggers.cron import CronTrigger
from autogpt_libs.utils.cache import thread_cached
from dotenv import load_dotenv
from prisma.enums import NotificationType
from pydantic import BaseModel, Field, ValidationError
from pydantic import BaseModel, ValidationError
from sqlalchemy import MetaData, create_engine
from backend.data.block import BlockInput
from backend.data.execution import ExecutionStatus
from backend.data.model import CredentialsMetaInput
from backend.executor import utils as execution_utils
from backend.notifications.notifications import NotificationManagerClient
from backend.util.exceptions import NotAuthorizedError, NotFoundError
from backend.util.logging import PrefixFilter
from backend.util.metrics import sentry_capture_error
from backend.util.service import (
AppService,
@@ -56,19 +52,19 @@ def _extract_schema_from_url(database_url) -> tuple[str, str]:
logger = logging.getLogger(__name__)
logger.addFilter(PrefixFilter("[Scheduler]"))
apscheduler_logger = logger.getChild("apscheduler")
apscheduler_logger.addFilter(PrefixFilter("[Scheduler] [APScheduler]"))
config = Config()
def log(msg, **kwargs):
logger.info("[Scheduler] " + msg, **kwargs)
def job_listener(event):
"""Logs job execution outcomes for better monitoring."""
if event.exception:
logger.error(f"Job {event.job_id} failed.")
log(f"Job {event.job_id} failed.")
else:
logger.info(f"Job {event.job_id} completed successfully.")
log(f"Job {event.job_id} completed successfully.")
@thread_cached
@@ -88,17 +84,16 @@ def execute_graph(**kwargs):
async def _execute_graph(**kwargs):
args = GraphExecutionJobArgs(**kwargs)
try:
logger.info(f"Executing recurring job for graph #{args.graph_id}")
log(f"Executing recurring job for graph #{args.graph_id}")
await execution_utils.add_graph_execution(
user_id=args.user_id,
graph_id=args.graph_id,
graph_version=args.graph_version,
inputs=args.input_data,
graph_credentials_inputs=args.input_credentials,
user_id=args.user_id,
graph_version=args.graph_version,
use_db_query=False,
)
except Exception as e:
logger.error(f"Error executing graph {args.graph_id}: {e}")
logger.exception(f"Error executing graph {args.graph_id}: {e}")
class LateExecutionException(Exception):
@@ -142,20 +137,20 @@ def report_late_executions() -> str:
def process_existing_batches(**kwargs):
args = NotificationJobArgs(**kwargs)
try:
logger.info(
log(
f"Processing existing batches for notification type {args.notification_types}"
)
get_notification_client().process_existing_batches(args.notification_types)
except Exception as e:
logger.error(f"Error processing existing batches: {e}")
logger.exception(f"Error processing existing batches: {e}")
def process_weekly_summary(**kwargs):
try:
logger.info("Processing weekly summary")
log("Processing weekly summary")
get_notification_client().queue_weekly_summary()
except Exception as e:
logger.error(f"Error processing weekly summary: {e}")
logger.exception(f"Error processing weekly summary: {e}")
class Jobstores(Enum):
@@ -165,12 +160,11 @@ class Jobstores(Enum):
class GraphExecutionJobArgs(BaseModel):
user_id: str
graph_id: str
input_data: BlockInput
user_id: str
graph_version: int
cron: str
input_data: BlockInput
input_credentials: dict[str, CredentialsMetaInput] = Field(default_factory=dict)
class GraphExecutionJobInfo(GraphExecutionJobArgs):
@@ -253,8 +247,7 @@ class Scheduler(AppService):
),
# These don't really need persistence
Jobstores.WEEKLY_NOTIFICATIONS.value: MemoryJobStore(),
},
logger=apscheduler_logger,
}
)
if self.register_system_tasks:
@@ -292,40 +285,34 @@ class Scheduler(AppService):
def cleanup(self):
super().cleanup()
logger.info("⏳ Shutting down scheduler...")
logger.info(f"[{self.service_name}] ⏳ Shutting down scheduler...")
if self.scheduler:
self.scheduler.shutdown(wait=False)
@expose
def add_graph_execution_schedule(
self,
user_id: str,
graph_id: str,
graph_version: int,
cron: str,
input_data: BlockInput,
input_credentials: dict[str, CredentialsMetaInput],
name: Optional[str] = None,
user_id: str,
) -> GraphExecutionJobInfo:
job_args = GraphExecutionJobArgs(
user_id=user_id,
graph_id=graph_id,
input_data=input_data,
user_id=user_id,
graph_version=graph_version,
cron=cron,
input_data=input_data,
input_credentials=input_credentials,
)
job = self.scheduler.add_job(
execute_graph,
CronTrigger.from_crontab(cron),
kwargs=job_args.model_dump(),
name=name,
trigger=CronTrigger.from_crontab(cron),
jobstore=Jobstores.EXECUTION.value,
replace_existing=True,
jobstore=Jobstores.EXECUTION.value,
)
logger.info(
f"Added job {job.id} with cron schedule '{cron}' input data: {input_data}"
)
log(f"Added job {job.id} with cron schedule '{cron}' input data: {input_data}")
return GraphExecutionJobInfo.from_db(job_args, job)
@expose
@@ -334,13 +321,14 @@ class Scheduler(AppService):
) -> GraphExecutionJobInfo:
job = self.scheduler.get_job(schedule_id, jobstore=Jobstores.EXECUTION.value)
if not job:
raise NotFoundError(f"Job #{schedule_id} not found.")
log(f"Job {schedule_id} not found.")
raise ValueError(f"Job #{schedule_id} not found.")
job_args = GraphExecutionJobArgs(**job.kwargs)
if job_args.user_id != user_id:
raise NotAuthorizedError("User ID does not match the job's user ID")
raise ValueError("User ID does not match the job's user ID.")
logger.info(f"Deleting job {schedule_id}")
log(f"Deleting job {schedule_id}")
job.remove()
return GraphExecutionJobInfo.from_db(job_args, job)

View File

@@ -1,15 +1,12 @@
import asyncio
import logging
import time
from collections import defaultdict
from concurrent.futures import Future
from typing import TYPE_CHECKING, Any, Callable, Optional, cast
from autogpt_libs.utils.cache import thread_cached
from pydantic import BaseModel, JsonValue
from pydantic import BaseModel
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data.block import (
Block,
BlockData,
@@ -26,8 +23,12 @@ from backend.data.execution import (
GraphExecutionStats,
GraphExecutionWithNodes,
RedisExecutionEventBus,
create_graph_execution,
get_node_executions,
update_graph_execution_stats,
update_node_execution_status_batch,
)
from backend.data.graph import GraphModel, Node
from backend.data.graph import GraphModel, Node, get_graph
from backend.data.model import CredentialsMetaInput
from backend.data.rabbitmq import (
AsyncRabbitMQ,
@@ -54,36 +55,6 @@ logger = TruncatedLogger(logging.getLogger(__name__), prefix="[GraphExecutorUtil
# ============ Resource Helpers ============ #
class LogMetadata(TruncatedLogger):
def __init__(
self,
logger: logging.Logger,
user_id: str,
graph_eid: str,
graph_id: str,
node_eid: str,
node_id: str,
block_name: str,
max_length: int = 1000,
):
metadata = {
"component": "ExecutionManager",
"user_id": user_id,
"graph_eid": graph_eid,
"graph_id": graph_id,
"node_eid": node_eid,
"node_id": node_id,
"block_name": block_name,
}
prefix = f"[ExecutionManager|uid:{user_id}|gid:{graph_id}|nid:{node_id}]|geid:{graph_eid}|neid:{node_eid}|{block_name}]"
super().__init__(
logger,
max_length=max_length,
prefix=prefix,
metadata=metadata,
)
@thread_cached
def get_execution_event_bus() -> RedisExecutionEventBus:
return RedisExecutionEventBus()
@@ -431,6 +402,12 @@ def validate_exec(
return None, f"Block for {node.block_id} not found."
schema = node_block.input_schema
# Convert non-matching data types to the expected input schema.
for name, data_type in schema.__annotations__.items():
value = data.get(name)
if (value is not None) and (type(value) is not data_type):
data[name] = convert(value, data_type)
# Input data (without default values) should contain all required fields.
error_prefix = f"Input data missing or mismatch for `{node_block.name}`:"
if missing_links := schema.get_missing_links(data, node.input_links):
@@ -442,12 +419,6 @@ def validate_exec(
if resolve_input:
data = merge_execution_input(data)
# Convert non-matching data types to the expected input schema.
for name, data_type in schema.__annotations__.items():
value = data.get(name)
if (value is not None) and (type(value) is not data_type):
data[name] = convert(value, data_type)
# Input data post-merge should contain all required fields from the schema.
if missing_input := schema.get_missing_input(data):
return None, f"{error_prefix} missing input {missing_input}"
@@ -464,7 +435,9 @@ def validate_exec(
async def _validate_node_input_credentials(
graph: GraphModel,
user_id: str,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
node_credentials_input_map: Optional[
dict[str, dict[str, CredentialsMetaInput]]
] = None,
):
"""Checks all credentials for all nodes of the graph"""
@@ -480,13 +453,11 @@ async def _validate_node_input_credentials(
for field_name, credentials_meta_type in credentials_fields.items():
if (
nodes_input_masks
and (node_input_mask := nodes_input_masks.get(node.id))
and field_name in node_input_mask
node_credentials_input_map
and (node_credentials_inputs := node_credentials_input_map.get(node.id))
and field_name in node_credentials_inputs
):
credentials_meta = credentials_meta_type.model_validate(
node_input_mask[field_name]
)
credentials_meta = node_credentials_input_map[node.id][field_name]
elif field_name in node.input_default:
credentials_meta = credentials_meta_type.model_validate(
node.input_default[field_name]
@@ -525,7 +496,7 @@ async def _validate_node_input_credentials(
def make_node_credentials_input_map(
graph: GraphModel,
graph_credentials_input: dict[str, CredentialsMetaInput],
) -> dict[str, dict[str, JsonValue]]:
) -> dict[str, dict[str, CredentialsMetaInput]]:
"""
Maps credentials for an execution to the correct nodes.
@@ -534,9 +505,9 @@ def make_node_credentials_input_map(
graph_credentials_input: A (graph_input_name, credentials_meta) map.
Returns:
dict[node_id, dict[field_name, CredentialsMetaRaw]]: Node credentials input map.
dict[node_id, dict[field_name, CredentialsMetaInput]]: Node credentials input map.
"""
result: dict[str, dict[str, JsonValue]] = {}
result: dict[str, dict[str, CredentialsMetaInput]] = {}
# Get aggregated credentials fields for the graph
graph_cred_inputs = graph.aggregate_credentials_inputs()
@@ -550,9 +521,7 @@ def make_node_credentials_input_map(
for node_id, node_field_name in compatible_node_fields:
if node_id not in result:
result[node_id] = {}
result[node_id][node_field_name] = graph_credentials_input[
graph_input_name
].model_dump(exclude_none=True)
result[node_id][node_field_name] = graph_credentials_input[graph_input_name]
return result
@@ -561,7 +530,9 @@ async def construct_node_execution_input(
graph: GraphModel,
user_id: str,
graph_inputs: BlockInput,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
node_credentials_input_map: Optional[
dict[str, dict[str, CredentialsMetaInput]]
] = None,
) -> list[tuple[str, BlockInput]]:
"""
Validates and prepares the input data for executing a graph.
@@ -579,8 +550,8 @@ async def construct_node_execution_input(
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID and
the corresponding input data for that node.
"""
graph.validate_graph(for_run=True, nodes_input_masks=nodes_input_masks)
await _validate_node_input_credentials(graph, user_id, nodes_input_masks)
graph.validate_graph(for_run=True)
await _validate_node_input_credentials(graph, user_id, node_credentials_input_map)
nodes_input = []
for node in graph.starting_nodes:
@@ -597,9 +568,23 @@ async def construct_node_execution_input(
if input_name and input_name in graph_inputs:
input_data = {"value": graph_inputs[input_name]}
# Apply node input overrides
if nodes_input_masks and (node_input_mask := nodes_input_masks.get(node.id)):
input_data.update(node_input_mask)
# Extract webhook payload, and assign it to the input pin
webhook_payload_key = f"webhook_{node.webhook_id}_payload"
if (
block.block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
and node.webhook_id
):
if webhook_payload_key not in graph_inputs:
raise ValueError(
f"Node {block.name} #{node.id} webhook payload is missing"
)
input_data = {"payload": graph_inputs[webhook_payload_key]}
# Apply node credentials overrides
if node_credentials_input_map and (
node_credentials := node_credentials_input_map.get(node.id)
):
input_data.update({k: v.model_dump() for k, v in node_credentials.items()})
input_data, error = validate_exec(node, input_data)
if input_data is None:
@@ -615,20 +600,6 @@ async def construct_node_execution_input(
return nodes_input
def _merge_nodes_input_masks(
overrides_map_1: dict[str, dict[str, JsonValue]],
overrides_map_2: dict[str, dict[str, JsonValue]],
) -> dict[str, dict[str, JsonValue]]:
"""Perform a per-node merge of input overrides"""
result = overrides_map_1.copy()
for node_id, overrides2 in overrides_map_2.items():
if node_id in result:
result[node_id] = {**result[node_id], **overrides2}
else:
result[node_id] = overrides2
return result
# ============ Execution Queue Helpers ============ #
@@ -682,10 +653,8 @@ def create_execution_queue_config() -> RabbitMQConfig:
async def stop_graph_execution(
user_id: str,
graph_exec_id: str,
use_db_query: bool = True,
wait_timeout: float = 60.0,
):
"""
Mechanism:
@@ -695,67 +664,79 @@ async def stop_graph_execution(
3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
"""
queue_client = await get_async_execution_queue()
db = execution_db if use_db_query else get_db_async_client()
await queue_client.publish_message(
routing_key="",
message=CancelExecutionEvent(graph_exec_id=graph_exec_id).model_dump_json(),
exchange=GRAPH_EXECUTION_CANCEL_EXCHANGE,
)
if not wait_timeout:
return
start_time = time.time()
while time.time() - start_time < wait_timeout:
graph_exec = await db.get_graph_execution_meta(
execution_id=graph_exec_id, user_id=user_id
# Update the status of the graph execution
if use_db_query:
graph_execution = await update_graph_execution_stats(
graph_exec_id,
ExecutionStatus.TERMINATED,
)
else:
graph_execution = await get_db_async_client().update_graph_execution_stats(
graph_exec_id,
ExecutionStatus.TERMINATED,
)
if not graph_exec:
raise NotFoundError(f"Graph execution #{graph_exec_id} not found.")
if graph_execution:
await get_async_execution_event_bus().publish(graph_execution)
else:
raise NotFoundError(
f"Graph execution #{graph_exec_id} not found for termination."
)
if graph_exec.status in [
# Update the status of the node executions
if use_db_query:
node_executions = await get_node_executions(
graph_exec_id=graph_exec_id,
statuses=[
ExecutionStatus.QUEUED,
ExecutionStatus.RUNNING,
ExecutionStatus.INCOMPLETE,
],
)
await update_node_execution_status_batch(
[v.node_exec_id for v in node_executions],
ExecutionStatus.TERMINATED,
ExecutionStatus.COMPLETED,
ExecutionStatus.FAILED,
]:
# If graph execution is terminated/completed/failed, cancellation is complete
return
)
else:
node_executions = await get_db_async_client().get_node_executions(
graph_exec_id=graph_exec_id,
statuses=[
ExecutionStatus.QUEUED,
ExecutionStatus.RUNNING,
ExecutionStatus.INCOMPLETE,
],
)
await get_db_async_client().update_node_execution_status_batch(
[v.node_exec_id for v in node_executions],
ExecutionStatus.TERMINATED,
)
elif graph_exec.status in [
ExecutionStatus.QUEUED,
ExecutionStatus.INCOMPLETE,
]:
# If the graph is still on the queue, we can prevent them from being executed
# by setting the status to TERMINATED.
node_execs = await db.get_node_executions(
graph_exec_id=graph_exec_id,
statuses=[ExecutionStatus.QUEUED, ExecutionStatus.INCOMPLETE],
await asyncio.gather(
*[
get_async_execution_event_bus().publish(
v.model_copy(update={"status": ExecutionStatus.TERMINATED})
)
await db.update_node_execution_status_batch(
[node_exec.node_exec_id for node_exec in node_execs],
ExecutionStatus.TERMINATED,
)
await db.update_graph_execution_stats(
graph_exec_id=graph_exec_id,
status=ExecutionStatus.TERMINATED,
)
await asyncio.sleep(1.0)
raise TimeoutError(
f"Timed out waiting for graph execution #{graph_exec_id} to terminate."
for v in node_executions
]
)
async def add_graph_execution(
graph_id: str,
user_id: str,
inputs: Optional[BlockInput] = None,
inputs: BlockInput,
preset_id: Optional[str] = None,
graph_version: Optional[int] = None,
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
node_credentials_input_map: Optional[
dict[str, dict[str, CredentialsMetaInput]]
] = None,
use_db_query: bool = True,
) -> GraphExecutionWithNodes:
"""
@@ -769,52 +750,68 @@ async def add_graph_execution(
graph_version: The version of the graph to execute.
graph_credentials_inputs: Credentials inputs to use in the execution.
Keys should map to the keys generated by `GraphModel.aggregate_credentials_inputs`.
nodes_input_masks: Node inputs to use in the execution.
node_credentials_input_map: Credentials inputs to use in the execution, mapped to specific nodes.
Returns:
GraphExecutionEntry: The entry for the graph execution.
Raises:
ValueError: If the graph is not found or if there are validation errors.
"""
gdb = graph_db if use_db_query else get_db_async_client()
edb = execution_db if use_db_query else get_db_async_client()
""" # noqa
if use_db_query:
graph: GraphModel | None = await get_graph(
graph_id=graph_id,
user_id=user_id,
version=graph_version,
include_subgraphs=True,
)
else:
graph: GraphModel | None = await get_db_async_client().get_graph(
graph_id=graph_id,
user_id=user_id,
version=graph_version,
include_subgraphs=True,
)
graph: GraphModel | None = await gdb.get_graph(
graph_id=graph_id,
user_id=user_id,
version=graph_version,
include_subgraphs=True,
)
if not graph:
raise NotFoundError(f"Graph #{graph_id} not found.")
nodes_input_masks = _merge_nodes_input_masks(
(
make_node_credentials_input_map(graph, graph_credentials_inputs)
if graph_credentials_inputs
else {}
),
nodes_input_masks or {},
)
starting_nodes_input = await construct_node_execution_input(
graph=graph,
user_id=user_id,
graph_inputs=inputs or {},
nodes_input_masks=nodes_input_masks,
node_credentials_input_map = node_credentials_input_map or (
make_node_credentials_input_map(graph, graph_credentials_inputs)
if graph_credentials_inputs
else None
)
graph_exec = await edb.create_graph_execution(
user_id=user_id,
graph_id=graph_id,
graph_version=graph.version,
starting_nodes_input=starting_nodes_input,
preset_id=preset_id,
)
if use_db_query:
graph_exec = await create_graph_execution(
user_id=user_id,
graph_id=graph_id,
graph_version=graph.version,
starting_nodes_input=await construct_node_execution_input(
graph=graph,
user_id=user_id,
graph_inputs=inputs,
node_credentials_input_map=node_credentials_input_map,
),
preset_id=preset_id,
)
else:
graph_exec = await get_db_async_client().create_graph_execution(
user_id=user_id,
graph_id=graph_id,
graph_version=graph.version,
starting_nodes_input=await construct_node_execution_input(
graph=graph,
user_id=user_id,
graph_inputs=inputs,
node_credentials_input_map=node_credentials_input_map,
),
preset_id=preset_id,
)
try:
queue = await get_async_execution_queue()
graph_exec_entry = graph_exec.to_graph_execution_entry()
if nodes_input_masks:
graph_exec_entry.nodes_input_masks = nodes_input_masks
if node_credentials_input_map:
graph_exec_entry.node_credentials_input_map = node_credentials_input_map
await queue.publish_message(
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
message=graph_exec_entry.model_dump_json(),
@@ -827,15 +824,28 @@ async def add_graph_execution(
return graph_exec
except Exception as e:
logger.error(f"Unable to publish graph #{graph_id} exec #{graph_exec.id}: {e}")
await edb.update_node_execution_status_batch(
[node_exec.node_exec_id for node_exec in graph_exec.node_executions],
ExecutionStatus.FAILED,
)
await edb.update_graph_execution_stats(
graph_exec_id=graph_exec.id,
status=ExecutionStatus.FAILED,
stats=GraphExecutionStats(error=str(e)),
)
if use_db_query:
await update_node_execution_status_batch(
[node_exec.node_exec_id for node_exec in graph_exec.node_executions],
ExecutionStatus.FAILED,
)
await update_graph_execution_stats(
graph_exec_id=graph_exec.id,
status=ExecutionStatus.FAILED,
stats=GraphExecutionStats(error=str(e)),
)
else:
await get_db_async_client().update_node_execution_status_batch(
[node_exec.node_exec_id for node_exec in graph_exec.node_executions],
ExecutionStatus.FAILED,
)
await get_db_async_client().update_graph_execution_stats(
graph_exec_id=graph_exec.id,
status=ExecutionStatus.FAILED,
stats=GraphExecutionStats(error=str(e)),
)
raise
@@ -890,10 +900,14 @@ class NodeExecutionProgress:
try:
self.tasks[exec_id].result(wait_time)
except TimeoutError:
print(
">>>>>>> -- Timeout, after waiting for",
wait_time,
"seconds for node_id",
exec_id,
)
pass
except Exception as e:
logger.error(f"Task for exec ID {exec_id} failed with error: {str(e)}")
pass
return self.is_done(0)
def stop(self) -> list[str]:
@@ -910,25 +924,6 @@ class NodeExecutionProgress:
cancelled_ids.append(task_id)
return cancelled_ids
def wait_for_cancellation(self, timeout: float = 5.0):
"""
Wait for all cancelled tasks to complete cancellation.
Args:
timeout: Maximum time to wait for cancellation in seconds
"""
start_time = time.time()
while time.time() - start_time < timeout:
# Check if all tasks are done (either completed or cancelled)
if all(task.done() for task in self.tasks.values()):
return True
time.sleep(0.1) # Small delay to avoid busy waiting
raise TimeoutError(
f"Timeout waiting for cancellation of tasks: {list(self.tasks.keys())}"
)
def _pop_done_task(self, exec_id: str) -> bool:
task = self.tasks.get(exec_id)
if not task:
@@ -941,10 +936,8 @@ class NodeExecutionProgress:
return False
if task := self.tasks.pop(exec_id):
try:
self.on_done_task(exec_id, task.result())
except Exception as e:
logger.error(f"Task for exec ID {exec_id} failed with error: {str(e)}")
self.on_done_task(exec_id, task.result())
return True
def _next_exec(self) -> str | None:

View File

@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional
from pydantic import SecretStr
from backend.data.redis_client import get_redis_async
from backend.data.redis import get_redis_async
if TYPE_CHECKING:
from backend.executor.database import DatabaseManagerAsyncClient

View File

@@ -7,7 +7,7 @@ from autogpt_libs.utils.synchronize import AsyncRedisKeyedMutex
from redis.asyncio.lock import Lock as AsyncRedisLock
from backend.data.model import Credentials, OAuth2Credentials
from backend.data.redis_client import get_redis_async
from backend.data.redis import get_redis_async
from backend.integrations.credentials_store import IntegrationCredentialsStore
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName

View File

@@ -17,7 +17,6 @@ class ProviderName(str, Enum):
GOOGLE = "google"
GOOGLE_MAPS = "google_maps"
GROQ = "groq"
HTTP = "http"
HUBSPOT = "hubspot"
IDEOGRAM = "ideogram"
JINA = "jina"

View File

@@ -1,22 +1,23 @@
import functools
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from ..providers import ProviderName
from ._base import BaseWebhooksManager
_WEBHOOK_MANAGERS: dict["ProviderName", type["BaseWebhooksManager"]] = {}
# --8<-- [start:load_webhook_managers]
@functools.cache
def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]:
webhook_managers = {}
if _WEBHOOK_MANAGERS:
return _WEBHOOK_MANAGERS
from .compass import CompassWebhookManager
from .generic import GenericWebhooksManager
from .github import GithubWebhooksManager
from .slant3d import Slant3DWebhooksManager
webhook_managers.update(
_WEBHOOK_MANAGERS.update(
{
handler.PROVIDER_NAME: handler
for handler in [
@@ -27,7 +28,7 @@ def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]
]
}
)
return webhook_managers
return _WEBHOOK_MANAGERS
# --8<-- [end:load_webhook_managers]

View File

@@ -7,14 +7,13 @@ from uuid import uuid4
from fastapi import Request
from strenum import StrEnum
import backend.data.integrations as integrations
from backend.data import integrations
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks.utils import webhook_ingress_url
from backend.util.exceptions import MissingConfigError
from backend.util.settings import Config
from .utils import webhook_ingress_url
logger = logging.getLogger(__name__)
app_config = Config()
@@ -42,74 +41,44 @@ class BaseWebhooksManager(ABC, Generic[WT]):
)
if webhook := await integrations.find_webhook_by_credentials_and_props(
user_id=user_id,
credentials_id=credentials.id,
webhook_type=webhook_type,
resource=resource,
events=events,
credentials.id, webhook_type, resource, events
):
return webhook
return await self._create_webhook(
user_id=user_id,
webhook_type=webhook_type,
events=events,
resource=resource,
credentials=credentials,
user_id, webhook_type, events, resource, credentials
)
async def get_manual_webhook(
self,
user_id: str,
graph_id: str,
webhook_type: WT,
events: list[str],
graph_id: Optional[str] = None,
preset_id: Optional[str] = None,
) -> integrations.Webhook:
"""
Tries to find an existing webhook tied to `graph_id`/`preset_id`,
or creates a new webhook if none exists.
Existing webhooks are matched by `user_id`, `webhook_type`,
and `graph_id`/`preset_id`.
If an existing webhook is found, we check if the events match and update them
if necessary. We do this rather than creating a new webhook
to avoid changing the webhook URL for existing manual webhooks.
"""
if (graph_id or preset_id) and (
current_webhook := await integrations.find_webhook_by_graph_and_props(
user_id=user_id,
provider=self.PROVIDER_NAME.value,
webhook_type=webhook_type,
graph_id=graph_id,
preset_id=preset_id,
)
):
if current_webhook := await integrations.find_webhook_by_graph_and_props(
graph_id, self.PROVIDER_NAME, webhook_type, events
):
if set(current_webhook.events) != set(events):
current_webhook = await integrations.update_webhook(
current_webhook.id, events=events
)
return current_webhook
return await self._create_webhook(
user_id=user_id,
webhook_type=webhook_type,
events=events,
user_id,
webhook_type,
events,
register=False,
)
async def prune_webhook_if_dangling(
self, user_id: str, webhook_id: str, credentials: Optional[Credentials]
self, webhook_id: str, credentials: Optional[Credentials]
) -> bool:
webhook = await integrations.get_webhook(webhook_id, include_relations=True)
if webhook.triggered_nodes or webhook.triggered_presets:
webhook = await integrations.get_webhook(webhook_id)
if webhook.attached_nodes is None:
raise ValueError("Error retrieving webhook including attached nodes")
if webhook.attached_nodes:
# Don't prune webhook if in use
return False
if credentials:
await self._deregister_webhook(webhook, credentials)
await integrations.delete_webhook(user_id, webhook.id)
await integrations.delete_webhook(webhook.id)
return True
# --8<-- [start:BaseWebhooksManager3]

View File

@@ -1,12 +1,10 @@
import logging
from typing import TYPE_CHECKING, Optional, cast
from backend.data.block import BlockSchema
from backend.data.block import BlockSchema, BlockWebhookConfig
from backend.data.graph import set_node_webhook
from backend.integrations.creds_manager import IntegrationCredentialsManager
from . import get_webhook_manager, supports_webhooks
from .utils import setup_webhook_for_block
from backend.integrations.webhooks import get_webhook_manager, supports_webhooks
if TYPE_CHECKING:
from backend.data.graph import GraphModel, NodeModel
@@ -83,9 +81,7 @@ async def on_graph_deactivate(graph: "GraphModel", user_id: str):
f"credentials #{creds_meta['id']}"
)
updated_node = await on_node_deactivate(
user_id, node, credentials=node_credentials
)
updated_node = await on_node_deactivate(node, credentials=node_credentials)
updated_nodes.append(updated_node)
graph.nodes = updated_nodes
@@ -100,25 +96,105 @@ async def on_node_activate(
) -> "NodeModel":
"""Hook to be called when the node is activated/created"""
if node.block.webhook_config:
new_webhook, feedback = await setup_webhook_for_block(
user_id=user_id,
trigger_block=node.block,
trigger_config=node.input_default,
for_graph_id=node.graph_id,
block = node.block
if not block.webhook_config:
return node
provider = block.webhook_config.provider
if not supports_webhooks(provider):
raise ValueError(
f"Block #{block.id} has webhook_config for provider {provider} "
"which does not support webhooks"
)
if new_webhook:
node = await set_node_webhook(node.id, new_webhook.id)
else:
logger.debug(
f"Node #{node.id} does not have everything for a webhook: {feedback}"
logger.debug(
f"Activating webhook node #{node.id} with config {block.webhook_config}"
)
webhooks_manager = get_webhook_manager(provider)
if auto_setup_webhook := isinstance(block.webhook_config, BlockWebhookConfig):
try:
resource = block.webhook_config.resource_format.format(**node.input_default)
except KeyError:
resource = None
logger.debug(
f"Constructed resource string {resource} from input {node.input_default}"
)
else:
resource = "" # not relevant for manual webhooks
block_input_schema = cast(BlockSchema, block.input_schema)
credentials_field_name = next(iter(block_input_schema.get_credentials_fields()), "")
credentials_meta = (
node.input_default.get(credentials_field_name)
if credentials_field_name
else None
)
event_filter_input_name = block.webhook_config.event_filter_input
has_everything_for_webhook = (
resource is not None
and (credentials_meta or not credentials_field_name)
and (
not event_filter_input_name
or (
event_filter_input_name in node.input_default
and any(
is_on
for is_on in node.input_default[event_filter_input_name].values()
)
)
)
)
if has_everything_for_webhook and resource is not None:
logger.debug(f"Node #{node} has everything for a webhook!")
if credentials_meta and not credentials:
raise ValueError(
f"Cannot set up webhook for node #{node.id}: "
f"credentials #{credentials_meta['id']} not available"
)
if event_filter_input_name:
# Shape of the event filter is enforced in Block.__init__
event_filter = cast(dict, node.input_default[event_filter_input_name])
events = [
block.webhook_config.event_format.format(event=event)
for event, enabled in event_filter.items()
if enabled is True
]
logger.debug(f"Webhook events to subscribe to: {', '.join(events)}")
else:
events = []
# Find/make and attach a suitable webhook to the node
if auto_setup_webhook:
assert credentials is not None
new_webhook = await webhooks_manager.get_suitable_auto_webhook(
user_id,
credentials,
block.webhook_config.webhook_type,
resource,
events,
)
else:
# Manual webhook -> no credentials -> don't register but do create
new_webhook = await webhooks_manager.get_manual_webhook(
user_id,
node.graph_id,
block.webhook_config.webhook_type,
events,
)
logger.debug(f"Acquired webhook: {new_webhook}")
return await set_node_webhook(node.id, new_webhook.id)
else:
logger.debug(f"Node #{node.id} does not have everything for a webhook")
return node
async def on_node_deactivate(
user_id: str,
node: "NodeModel",
*,
credentials: Optional["Credentials"] = None,
@@ -157,9 +233,7 @@ async def on_node_deactivate(
f"Pruning{' and deregistering' if credentials else ''} "
f"webhook #{webhook.id}"
)
await webhooks_manager.prune_webhook_if_dangling(
user_id, webhook.id, credentials
)
await webhooks_manager.prune_webhook_if_dangling(webhook.id, credentials)
if (
cast(BlockSchema, block.input_schema).get_credentials_fields()
and not credentials

View File

@@ -1,22 +1,7 @@
import logging
from typing import TYPE_CHECKING, Optional, cast
from pydantic import JsonValue
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
from backend.util.settings import Config
from . import get_webhook_manager, supports_webhooks
if TYPE_CHECKING:
from backend.data.block import Block, BlockSchema
from backend.data.integrations import Webhook
from backend.data.model import Credentials
logger = logging.getLogger(__name__)
app_config = Config()
credentials_manager = IntegrationCredentialsManager()
# TODO: add test to assert this matches the actual API route
@@ -25,122 +10,3 @@ def webhook_ingress_url(provider_name: ProviderName, webhook_id: str) -> str:
f"{app_config.platform_base_url}/api/integrations/{provider_name.value}"
f"/webhooks/{webhook_id}/ingress"
)
async def setup_webhook_for_block(
user_id: str,
trigger_block: "Block[BlockSchema, BlockSchema]",
trigger_config: dict[str, JsonValue], # = Trigger block inputs
for_graph_id: Optional[str] = None,
for_preset_id: Optional[str] = None,
credentials: Optional["Credentials"] = None,
) -> tuple["Webhook", None] | tuple[None, str]:
"""
Utility function to create (and auto-setup if possible) a webhook for a given provider.
Returns:
Webhook: The created or found webhook object, if successful.
str: A feedback message, if any required inputs are missing.
"""
from backend.data.block import BlockWebhookConfig
if not (trigger_base_config := trigger_block.webhook_config):
raise ValueError(f"Block #{trigger_block.id} does not have a webhook_config")
provider = trigger_base_config.provider
if not supports_webhooks(provider):
raise NotImplementedError(
f"Block #{trigger_block.id} has webhook_config for provider {provider} "
"for which we do not have a WebhooksManager"
)
logger.debug(
f"Setting up webhook for block #{trigger_block.id} with config {trigger_config}"
)
# Check & parse the event filter input, if any
events: list[str] = []
if event_filter_input_name := trigger_base_config.event_filter_input:
if not (event_filter := trigger_config.get(event_filter_input_name)):
return None, (
f"Cannot set up {provider.value} webhook without event filter input: "
f"missing input for '{event_filter_input_name}'"
)
elif not (
# Shape of the event filter is enforced in Block.__init__
any((event_filter := cast(dict[str, bool], event_filter)).values())
):
return None, (
f"Cannot set up {provider.value} webhook without any enabled events "
f"in event filter input '{event_filter_input_name}'"
)
events = [
trigger_base_config.event_format.format(event=event)
for event, enabled in event_filter.items()
if enabled is True
]
logger.debug(f"Webhook events to subscribe to: {', '.join(events)}")
# Check & process prerequisites for auto-setup webhooks
if auto_setup_webhook := isinstance(trigger_base_config, BlockWebhookConfig):
try:
resource = trigger_base_config.resource_format.format(**trigger_config)
except KeyError as missing_key:
return None, (
f"Cannot auto-setup {provider.value} webhook without resource: "
f"missing input for '{missing_key}'"
)
logger.debug(
f"Constructed resource string {resource} from input {trigger_config}"
)
creds_field_name = next(
# presence of this field is enforced in Block.__init__
iter(trigger_block.input_schema.get_credentials_fields())
)
if not (
credentials_meta := cast(dict, trigger_config.get(creds_field_name, None))
):
return None, f"Cannot set up {provider.value} webhook without credentials"
elif not (
credentials := credentials
or await credentials_manager.get(user_id, credentials_meta["id"])
):
raise ValueError(
f"Cannot set up {provider.value} webhook without credentials: "
f"credentials #{credentials_meta['id']} not found for user #{user_id}"
)
elif credentials.provider != provider:
raise ValueError(
f"Credentials #{credentials.id} do not match provider {provider.value}"
)
else:
# not relevant for manual webhooks:
resource = ""
credentials = None
webhooks_manager = get_webhook_manager(provider)
# Find/make and attach a suitable webhook to the node
if auto_setup_webhook:
assert credentials is not None
webhook = await webhooks_manager.get_suitable_auto_webhook(
user_id=user_id,
credentials=credentials,
webhook_type=trigger_base_config.webhook_type,
resource=resource,
events=events,
)
else:
# Manual webhook -> no credentials -> don't register but do create
webhook = await webhooks_manager.get_manual_webhook(
user_id=user_id,
webhook_type=trigger_base_config.webhook_type,
events=events,
graph_id=for_graph_id,
preset_id=for_preset_id,
)
logger.debug(f"Acquired webhook: {webhook}")
return webhook, None

View File

@@ -1,7 +1,5 @@
from fastapi import FastAPI
from backend.server.middleware.security import SecurityHeadersMiddleware
from .routes.v1 import v1_router
external_app = FastAPI(
@@ -10,6 +8,4 @@ external_app = FastAPI(
docs_url="/docs",
version="1.0",
)
external_app.add_middleware(SecurityHeadersMiddleware)
external_app.include_router(v1_router, prefix="/v1")

View File

@@ -2,19 +2,11 @@ import asyncio
import logging
from typing import TYPE_CHECKING, Annotated, Awaitable, Literal
from fastapi import (
APIRouter,
Body,
Depends,
HTTPException,
Path,
Query,
Request,
status,
)
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
from pydantic import BaseModel, Field
from starlette.status import HTTP_404_NOT_FOUND
from backend.data.graph import get_graph, set_node_webhook
from backend.data.graph import set_node_webhook
from backend.data.integrations import (
WebhookEvent,
get_all_webhooks_by_creds,
@@ -22,18 +14,12 @@ from backend.data.integrations import (
publish_webhook_event,
wait_for_webhook_event,
)
from backend.data.model import (
Credentials,
CredentialsType,
HostScopedCredentials,
OAuth2Credentials,
)
from backend.data.model import Credentials, CredentialsType, OAuth2Credentials
from backend.executor.utils import add_graph_execution
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import get_webhook_manager
from backend.server.v2.library.db import set_preset_webhook, update_preset
from backend.util.exceptions import NeedConfirmation, NotFoundError
from backend.util.settings import Settings
@@ -87,9 +73,6 @@ class CredentialsMetaResponse(BaseModel):
title: str | None
scopes: list[str] | None
username: str | None
host: str | None = Field(
default=None, description="Host pattern for host-scoped credentials"
)
@router.post("/{provider}/callback")
@@ -112,10 +95,7 @@ async def callback(
if not valid_state:
logger.warning(f"Invalid or expired state token for user {user_id}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid or expired state token",
)
raise HTTPException(status_code=400, detail="Invalid or expired state token")
try:
scopes = valid_state.scopes
logger.debug(f"Retrieved scopes from state token: {scopes}")
@@ -142,12 +122,17 @@ async def callback(
)
except Exception as e:
logger.error(
f"OAuth2 Code->Token exchange failed for provider {provider.value}: {e}"
logger.exception(
"OAuth callback for provider %s failed during code exchange: %s. Confirm provider credentials.",
provider.value,
e,
)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"OAuth2 callback failed to exchange code for tokens: {str(e)}",
status_code=400,
detail={
"message": str(e),
"hint": "Verify OAuth configuration and try again.",
},
)
# TODO: Allow specifying `title` to set on `credentials`
@@ -164,9 +149,6 @@ async def callback(
title=credentials.title,
scopes=credentials.scopes,
username=credentials.username,
host=(
credentials.host if isinstance(credentials, HostScopedCredentials) else None
),
)
@@ -183,7 +165,6 @@ async def list_credentials(
title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
)
for cred in credentials
]
@@ -205,7 +186,6 @@ async def list_credentials_by_provider(
title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
)
for cred in credentials
]
@@ -221,13 +201,10 @@ async def get_credential(
) -> Credentials:
credential = await creds_manager.get(user_id, cred_id)
if not credential:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
raise HTTPException(status_code=404, detail="Credentials not found")
if credential.provider != provider:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Credentials do not match the specified provider",
status_code=404, detail="Credentials do not match the specified provider"
)
return credential
@@ -245,8 +222,7 @@ async def create_credentials(
await creds_manager.create(user_id, credentials)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to store credentials: {str(e)}",
status_code=500, detail=f"Failed to store credentials: {str(e)}"
)
return credentials
@@ -280,17 +256,14 @@ async def delete_credentials(
) -> CredentialsDeletionResponse | CredentialsDeletionNeedsConfirmationResponse:
creds = await creds_manager.store.get_creds_by_id(user_id, cred_id)
if not creds:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
raise HTTPException(status_code=404, detail="Credentials not found")
if creds.provider != provider:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Credentials do not match the specified provider",
status_code=404, detail="Credentials do not match the specified provider"
)
try:
await remove_all_webhooks_for_credentials(user_id, creds, force)
await remove_all_webhooks_for_credentials(creds, force)
except NeedConfirmation as e:
return CredentialsDeletionNeedsConfirmationResponse(message=str(e))
@@ -321,10 +294,16 @@ async def webhook_ingress_generic(
logger.debug(f"Received {provider.value} webhook ingress for ID {webhook_id}")
webhook_manager = get_webhook_manager(provider)
try:
webhook = await get_webhook(webhook_id, include_relations=True)
webhook = await get_webhook(webhook_id)
except NotFoundError as e:
logger.warning(f"Webhook payload received for unknown webhook #{webhook_id}")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
logger.warning(
"Webhook payload received for unknown webhook %s. Confirm the webhook ID.",
webhook_id,
)
raise HTTPException(
status_code=HTTP_404_NOT_FOUND,
detail={"message": str(e), "hint": "Check if the webhook ID is correct."},
) from e
logger.debug(f"Webhook #{webhook_id}: {webhook}")
payload, event_type = await webhook_manager.validate_payload(webhook, request)
logger.debug(
@@ -341,11 +320,11 @@ async def webhook_ingress_generic(
await publish_webhook_event(webhook_event)
logger.debug(f"Webhook event published: {webhook_event}")
if not (webhook.triggered_nodes or webhook.triggered_presets):
if not webhook.attached_nodes:
return
executions: list[Awaitable] = []
for node in webhook.triggered_nodes:
for node in webhook.attached_nodes:
logger.debug(f"Webhook-attached node: {node}")
if not node.is_triggered_by_event_type(event_type):
logger.debug(f"Node #{node.id} doesn't trigger on event {event_type}")
@@ -356,48 +335,7 @@ async def webhook_ingress_generic(
user_id=webhook.user_id,
graph_id=node.graph_id,
graph_version=node.graph_version,
nodes_input_masks={node.id: {"payload": payload}},
)
)
for preset in webhook.triggered_presets:
logger.debug(f"Webhook-attached preset: {preset}")
if not preset.is_active:
logger.debug(f"Preset #{preset.id} is inactive")
continue
graph = await get_graph(preset.graph_id, preset.graph_version, webhook.user_id)
if not graph:
logger.error(
f"User #{webhook.user_id} has preset #{preset.id} for graph "
f"#{preset.graph_id} v{preset.graph_version}, "
"but no access to the graph itself."
)
logger.info(f"Automatically deactivating broken preset #{preset.id}")
await update_preset(preset.user_id, preset.id, is_active=False)
continue
if not (trigger_node := graph.webhook_input_node):
# NOTE: this should NEVER happen, but we log and handle it gracefully
logger.error(
f"Preset #{preset.id} is triggered by webhook #{webhook.id}, but graph "
f"#{preset.graph_id} v{preset.graph_version} has no webhook input node"
)
await set_preset_webhook(preset.user_id, preset.id, None)
continue
if not trigger_node.block.is_triggered_by_event_type(preset.inputs, event_type):
logger.debug(f"Preset #{preset.id} doesn't trigger on event {event_type}")
continue
logger.debug(f"Executing preset #{preset.id} for webhook #{webhook.id}")
executions.append(
add_graph_execution(
user_id=webhook.user_id,
graph_id=preset.graph_id,
preset_id=preset.id,
graph_version=preset.graph_version,
graph_credentials_inputs=preset.credentials,
nodes_input_masks={
trigger_node.id: {**preset.inputs, "payload": payload}
},
inputs={f"webhook_{webhook_id}_payload": payload},
)
)
asyncio.gather(*executions)
@@ -422,9 +360,7 @@ async def webhook_ping(
return False
if not await wait_for_webhook_event(webhook_id, event_type="ping", timeout=10):
raise HTTPException(
status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail="Webhook ping timed out"
)
raise HTTPException(status_code=504, detail="Webhook ping timed out")
return True
@@ -433,37 +369,32 @@ async def webhook_ping(
async def remove_all_webhooks_for_credentials(
user_id: str, credentials: Credentials, force: bool = False
credentials: Credentials, force: bool = False
) -> None:
"""
Remove and deregister all webhooks that were registered using the given credentials.
Params:
user_id: The ID of the user who owns the credentials and webhooks.
credentials: The credentials for which to remove the associated webhooks.
force: Whether to proceed if any of the webhooks are still in use.
Raises:
NeedConfirmation: If any of the webhooks are still in use and `force` is `False`
"""
webhooks = await get_all_webhooks_by_creds(
user_id, credentials.id, include_relations=True
)
if any(w.triggered_nodes or w.triggered_presets for w in webhooks) and not force:
webhooks = await get_all_webhooks_by_creds(credentials.id)
if any(w.attached_nodes for w in webhooks) and not force:
raise NeedConfirmation(
"Some webhooks linked to these credentials are still in use by an agent"
)
for webhook in webhooks:
# Unlink all nodes & presets
for node in webhook.triggered_nodes:
# Unlink all nodes
for node in webhook.attached_nodes or []:
await set_node_webhook(node.id, None)
for preset in webhook.triggered_presets:
await set_preset_webhook(user_id, preset.id, None)
# Prune the webhook
webhook_manager = get_webhook_manager(ProviderName(credentials.provider))
success = await webhook_manager.prune_webhook_if_dangling(
user_id, webhook.id, credentials
webhook.id, credentials
)
if not success:
logger.warning(f"Webhook #{webhook.id} failed to prune")
@@ -474,7 +405,7 @@ def _get_provider_oauth_handler(
) -> "BaseOAuthHandler":
if provider_name not in HANDLERS_BY_NAME:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
status_code=404,
detail=f"Provider '{provider_name.value}' does not support OAuth",
)
@@ -482,13 +413,14 @@ def _get_provider_oauth_handler(
client_secret = getattr(settings.secrets, f"{provider_name.value}_client_secret")
if not (client_id and client_secret):
logger.error(
f"Attempt to use unconfigured {provider_name.value} OAuth integration"
"OAuth credentials for provider %s are missing. Check environment configuration.",
provider_name.value,
)
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
status_code=501,
detail={
"message": f"Integration with provider '{provider_name.value}' is not configured.",
"hint": "Set client ID and secret in the application's deployment environment",
"message": f"Integration with provider '{provider_name.value}' is not configured",
"hint": "Set client ID and secret in the environment.",
},
)

View File

@@ -1,93 +0,0 @@
import re
from typing import Set
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
"""
Middleware to add security headers to responses, with cache control
disabled by default for all endpoints except those explicitly allowed.
"""
CACHEABLE_PATHS: Set[str] = {
# Static assets
"/static",
"/_next/static",
"/assets",
"/images",
"/css",
"/js",
"/fonts",
# Public API endpoints
"/api/health",
"/api/v1/health",
"/api/status",
# Public store/marketplace pages (read-only)
"/api/store/agents",
"/api/v1/store/agents",
"/api/store/categories",
"/api/v1/store/categories",
"/api/store/featured",
"/api/v1/store/featured",
# Public graph templates (read-only, no user data)
"/api/graphs/templates",
"/api/v1/graphs/templates",
# Documentation endpoints
"/api/docs",
"/api/v1/docs",
"/docs",
"/swagger",
"/openapi.json",
# Favicon and manifest
"/favicon.ico",
"/manifest.json",
"/robots.txt",
"/sitemap.xml",
}
def __init__(self, app: ASGIApp):
super().__init__(app)
# Compile regex patterns for wildcard matching
self.cacheable_patterns = [
re.compile(pattern.replace("*", "[^/]+"))
for pattern in self.CACHEABLE_PATHS
if "*" in pattern
]
self.exact_paths = {path for path in self.CACHEABLE_PATHS if "*" not in path}
def is_cacheable_path(self, path: str) -> bool:
"""Check if the given path is allowed to be cached."""
# Check exact matches first
for cacheable_path in self.exact_paths:
if path.startswith(cacheable_path):
return True
# Check pattern matches
for pattern in self.cacheable_patterns:
if pattern.match(path):
return True
return False
async def dispatch(self, request: Request, call_next):
response: Response = await call_next(request)
# Add general security headers
response.headers["X-Content-Type-Options"] = "nosniff"
response.headers["X-Frame-Options"] = "DENY"
response.headers["X-XSS-Protection"] = "1; mode=block"
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
# Default: Disable caching for all endpoints
# Only allow caching for explicitly permitted paths
if not self.is_cacheable_path(request.url.path):
response.headers["Cache-Control"] = (
"no-store, no-cache, must-revalidate, private"
)
response.headers["Pragma"] = "no-cache"
response.headers["Expires"] = "0"
return response

View File

@@ -1,143 +0,0 @@
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from starlette.applications import Starlette
from backend.server.middleware.security import SecurityHeadersMiddleware
@pytest.fixture
def app():
"""Create a test FastAPI app with security middleware."""
app = FastAPI()
app.add_middleware(SecurityHeadersMiddleware)
@app.get("/api/auth/user")
def get_user():
return {"user": "test"}
@app.get("/api/v1/integrations/oauth/google")
def oauth_endpoint():
return {"oauth": "data"}
@app.get("/api/graphs/123/execute")
def execute_graph():
return {"execution": "data"}
@app.get("/api/integrations/credentials")
def get_credentials():
return {"credentials": "sensitive"}
@app.get("/api/store/agents")
def store_agents():
return {"agents": "public list"}
@app.get("/api/health")
def health_check():
return {"status": "ok"}
@app.get("/static/logo.png")
def static_file():
return {"static": "content"}
return app
@pytest.fixture
def client(app):
"""Create a test client."""
return TestClient(app)
def test_non_cacheable_endpoints_have_cache_control_headers(client):
"""Test that non-cacheable endpoints (most endpoints) have proper cache control headers."""
non_cacheable_endpoints = [
"/api/auth/user",
"/api/v1/integrations/oauth/google",
"/api/graphs/123/execute",
"/api/integrations/credentials",
]
for endpoint in non_cacheable_endpoints:
response = client.get(endpoint)
# Check cache control headers are present (default behavior)
assert (
response.headers["Cache-Control"]
== "no-store, no-cache, must-revalidate, private"
)
assert response.headers["Pragma"] == "no-cache"
assert response.headers["Expires"] == "0"
# Check general security headers
assert response.headers["X-Content-Type-Options"] == "nosniff"
assert response.headers["X-Frame-Options"] == "DENY"
assert response.headers["X-XSS-Protection"] == "1; mode=block"
assert response.headers["Referrer-Policy"] == "strict-origin-when-cross-origin"
def test_cacheable_endpoints_dont_have_cache_control_headers(client):
"""Test that explicitly cacheable endpoints don't have restrictive cache control headers."""
cacheable_endpoints = [
"/api/store/agents",
"/api/health",
"/static/logo.png",
]
for endpoint in cacheable_endpoints:
response = client.get(endpoint)
# Should NOT have restrictive cache control headers
assert (
"Cache-Control" not in response.headers
or "no-store" not in response.headers.get("Cache-Control", "")
)
assert (
"Pragma" not in response.headers
or response.headers.get("Pragma") != "no-cache"
)
assert (
"Expires" not in response.headers or response.headers.get("Expires") != "0"
)
# Should still have general security headers
assert response.headers["X-Content-Type-Options"] == "nosniff"
assert response.headers["X-Frame-Options"] == "DENY"
assert response.headers["X-XSS-Protection"] == "1; mode=block"
assert response.headers["Referrer-Policy"] == "strict-origin-when-cross-origin"
def test_is_cacheable_path_detection():
"""Test the path detection logic."""
middleware = SecurityHeadersMiddleware(Starlette())
# Test cacheable paths (allow list)
assert middleware.is_cacheable_path("/api/health")
assert middleware.is_cacheable_path("/api/v1/health")
assert middleware.is_cacheable_path("/static/image.png")
assert middleware.is_cacheable_path("/api/store/agents")
assert middleware.is_cacheable_path("/docs")
assert middleware.is_cacheable_path("/favicon.ico")
# Test non-cacheable paths (everything else)
assert not middleware.is_cacheable_path("/api/auth/user")
assert not middleware.is_cacheable_path("/api/v1/integrations/oauth/callback")
assert not middleware.is_cacheable_path("/api/integrations/credentials/123")
assert not middleware.is_cacheable_path("/api/graphs/abc123/execute")
assert not middleware.is_cacheable_path("/api/store/xyz/submissions")
def test_path_prefix_matching():
"""Test that path prefix matching works correctly."""
middleware = SecurityHeadersMiddleware(Starlette())
# Test that paths starting with cacheable prefixes are cacheable
assert middleware.is_cacheable_path("/static/css/style.css")
assert middleware.is_cacheable_path("/static/js/app.js")
assert middleware.is_cacheable_path("/assets/images/logo.png")
assert middleware.is_cacheable_path("/_next/static/chunks/main.js")
# Test that other API paths are not cacheable by default
assert not middleware.is_cacheable_path("/api/users/profile")
assert not middleware.is_cacheable_path("/api/v1/private/data")
assert not middleware.is_cacheable_path("/api/billing/subscription")

View File

@@ -1,6 +1,5 @@
import contextlib
import logging
from enum import Enum
from typing import Any, Optional
import autogpt_libs.auth.models
@@ -15,7 +14,6 @@ from autogpt_libs.feature_flag.client import (
)
from autogpt_libs.logging.utils import generate_uvicorn_config
from fastapi.exceptions import RequestValidationError
from fastapi.routing import APIRoute
import backend.data.block
import backend.data.db
@@ -38,7 +36,6 @@ from backend.blocks.llm import LlmModel
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.server.external.api import external_app
from backend.server.middleware.security import SecurityHeadersMiddleware
settings = backend.util.settings.Settings()
logger = logging.getLogger(__name__)
@@ -70,33 +67,6 @@ async def lifespan_context(app: fastapi.FastAPI):
await backend.data.db.disconnect()
def custom_generate_unique_id(route: APIRoute):
"""Generate clean operation IDs for OpenAPI spec following the format:
{method}{tag}{summary}
"""
if not route.tags or not route.methods:
return f"{route.name}"
method = list(route.methods)[0].lower()
first_tag = route.tags[0]
if isinstance(first_tag, Enum):
tag_str = first_tag.name
else:
tag_str = str(first_tag)
tag = "".join(word.capitalize() for word in tag_str.split("_")) # v1/v2
summary = (
route.summary if route.summary else route.name
) # need to be unique, a different version could have the same summary
summary = "".join(word.capitalize() for word in str(summary).split("_"))
if tag:
return f"{method}{tag}{summary}"
else:
return f"{method}{summary}"
docs_url = (
"/docs"
if settings.config.app_env == backend.util.settings.AppEnvironment.LOCAL
@@ -112,11 +82,8 @@ app = fastapi.FastAPI(
version="0.1",
lifespan=lifespan_context,
docs_url=docs_url,
generate_unique_id_function=custom_generate_unique_id,
)
app.add_middleware(SecurityHeadersMiddleware)
def handle_internal_http_error(status_code: int = 500, log_error: bool = True):
def handler(request: fastapi.Request, exc: Exception):
@@ -191,12 +158,10 @@ app.include_router(
backend.server.v2.library.routes.router, tags=["v2"], prefix="/api/library"
)
app.include_router(
backend.server.v2.otto.routes.router, tags=["v2", "otto"], prefix="/api/otto"
backend.server.v2.otto.routes.router, tags=["v2"], prefix="/api/otto"
)
app.include_router(
backend.server.v2.turnstile.routes.router,
tags=["v2", "turnstile"],
prefix="/api/turnstile",
backend.server.v2.turnstile.routes.router, tags=["v2"], prefix="/api/turnstile"
)
app.include_router(
@@ -279,7 +244,6 @@ class AgentServer(backend.util.service.AppProcess):
@staticmethod
async def test_delete_graph(graph_id: str, user_id: str):
"""Used for clean-up after a test run"""
await backend.server.v2.library.db.delete_library_agent_by_graph_id(
graph_id=graph_id, user_id=user_id
)
@@ -324,14 +288,18 @@ class AgentServer(backend.util.service.AppProcess):
@staticmethod
async def test_execute_preset(
graph_id: str,
graph_version: int,
preset_id: str,
user_id: str,
inputs: Optional[dict[str, Any]] = None,
node_input: Optional[dict[str, Any]] = None,
):
return await backend.server.v2.library.routes.presets.execute_preset(
graph_id=graph_id,
graph_version=graph_version,
preset_id=preset_id,
node_input=node_input or {},
user_id=user_id,
inputs=inputs or {},
)
@staticmethod
@@ -357,22 +325,11 @@ class AgentServer(backend.util.service.AppProcess):
provider: ProviderName,
credentials: Credentials,
) -> Credentials:
from backend.server.integrations.router import (
create_credentials,
get_credential,
)
from backend.server.integrations.router import create_credentials
try:
return await create_credentials(
user_id=user_id, provider=provider, credentials=credentials
)
except Exception as e:
logger.error(f"Error creating credentials: {e}")
return await get_credential(
provider=provider,
user_id=user_id,
cred_id=credentials.id,
)
return await create_credentials(
user_id=user_id, provider=provider, credentials=credentials
)
def set_test_dependency_overrides(self, overrides: dict):
app.dependency_overrides.update(overrides)

View File

@@ -4,7 +4,6 @@ import logging
from typing import Annotated
import fastapi
import pydantic
import backend.data.analytics
from backend.server.utils import get_user_id
@@ -13,28 +12,24 @@ router = fastapi.APIRouter()
logger = logging.getLogger(__name__)
class LogRawMetricRequest(pydantic.BaseModel):
metric_name: str = pydantic.Field(..., min_length=1)
metric_value: float = pydantic.Field(..., allow_inf_nan=False)
data_string: str = pydantic.Field(..., min_length=1)
@router.post(path="/log_raw_metric")
async def log_raw_metric(
user_id: Annotated[str, fastapi.Depends(get_user_id)],
request: LogRawMetricRequest,
metric_name: Annotated[str, fastapi.Body(..., embed=True)],
metric_value: Annotated[float, fastapi.Body(..., embed=True)],
data_string: Annotated[str, fastapi.Body(..., embed=True)],
):
try:
result = await backend.data.analytics.log_raw_metric(
user_id=user_id,
metric_name=request.metric_name,
metric_value=request.metric_value,
data_string=request.data_string,
metric_name=metric_name,
metric_value=metric_value,
data_string=data_string,
)
return result.id
except Exception as e:
logger.exception(
"Failed to log metric %s for user %s: %s", request.metric_name, user_id, e
"Failed to log metric %s for user %s: %s", metric_name, user_id, e
)
raise fastapi.HTTPException(
status_code=500,

View File

@@ -97,17 +97,8 @@ def test_log_raw_metric_invalid_request_improved() -> None:
assert "data_string" in error_fields, "Should report missing data_string"
def test_log_raw_metric_type_validation_improved(
mocker: pytest_mock.MockFixture,
) -> None:
def test_log_raw_metric_type_validation_improved() -> None:
"""Test metric type validation with improved assertions."""
# Mock the analytics function to avoid event loop issues
mocker.patch(
"backend.data.analytics.log_raw_metric",
new_callable=AsyncMock,
return_value=Mock(id="test-id"),
)
invalid_requests = [
{
"data": {
@@ -128,10 +119,10 @@ def test_log_raw_metric_type_validation_improved(
{
"data": {
"metric_name": "test",
"metric_value": 123, # Valid number
"data_string": "", # Empty data_string
"metric_value": float("inf"), # Infinity
"data_string": "test",
},
"expected_error": "String should have at least 1 character",
"expected_error": "ensure this value is finite",
},
]

View File

@@ -93,18 +93,10 @@ def test_log_raw_metric_values_parametrized(
],
)
def test_log_raw_metric_invalid_requests_parametrized(
mocker: pytest_mock.MockFixture,
invalid_data: dict,
expected_error: str,
) -> None:
"""Test invalid metric requests with parametrize."""
# Mock the analytics function to avoid event loop issues
mocker.patch(
"backend.data.analytics.log_raw_metric",
new_callable=AsyncMock,
return_value=Mock(id="test-id"),
)
response = client.post("/log_raw_metric", json=invalid_data)
assert response.status_code == 422

View File

@@ -34,7 +34,7 @@ router = APIRouter()
logger = logging.getLogger(__name__)
@router.post("/unsubscribe", summary="One Click Email Unsubscribe")
@router.post("/unsubscribe")
async def unsubscribe_via_one_click(token: Annotated[str, Query()]):
logger.info("Received unsubscribe request from One Click Unsubscribe")
try:
@@ -48,11 +48,7 @@ async def unsubscribe_via_one_click(token: Annotated[str, Query()]):
return JSONResponse(status_code=200, content={"status": "ok"})
@router.post(
"/",
dependencies=[Depends(postmark_validator.get_dependency())],
summary="Handle Postmark Email Webhooks",
)
@router.post("/", dependencies=[Depends(postmark_validator.get_dependency())])
async def postmark_webhook_handler(
webhook: Annotated[
PostmarkWebhook,

View File

@@ -9,7 +9,7 @@ import stripe
from autogpt_libs.auth.middleware import auth_middleware
from autogpt_libs.feature_flag.client import feature_flag
from autogpt_libs.utils.cache import thread_cached
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Request, Response
from fastapi import APIRouter, Body, Depends, HTTPException, Request, Response
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
from typing_extensions import Optional, TypedDict
@@ -72,7 +72,6 @@ from backend.server.model import (
UpdatePermissionsRequest,
)
from backend.server.utils import get_user_id
from backend.util.exceptions import NotFoundError
from backend.util.service import get_service_client
from backend.util.settings import Settings
@@ -114,22 +113,14 @@ v1_router.include_router(
########################################################
@v1_router.post(
"/auth/user",
summary="Get or create user",
tags=["auth"],
dependencies=[Depends(auth_middleware)],
)
@v1_router.post("/auth/user", tags=["auth"], dependencies=[Depends(auth_middleware)])
async def get_or_create_user_route(user_data: dict = Depends(auth_middleware)):
user = await get_or_create_user(user_data)
return user.model_dump()
@v1_router.post(
"/auth/user/email",
summary="Update user email",
tags=["auth"],
dependencies=[Depends(auth_middleware)],
"/auth/user/email", tags=["auth"], dependencies=[Depends(auth_middleware)]
)
async def update_user_email_route(
user_id: Annotated[str, Depends(get_user_id)], email: str = Body(...)
@@ -141,7 +132,6 @@ async def update_user_email_route(
@v1_router.get(
"/auth/user/preferences",
summary="Get notification preferences",
tags=["auth"],
dependencies=[Depends(auth_middleware)],
)
@@ -154,7 +144,6 @@ async def get_preferences(
@v1_router.post(
"/auth/user/preferences",
summary="Update notification preferences",
tags=["auth"],
dependencies=[Depends(auth_middleware)],
)
@@ -172,20 +161,14 @@ async def update_preferences(
@v1_router.get(
"/onboarding",
summary="Get onboarding status",
tags=["onboarding"],
dependencies=[Depends(auth_middleware)],
"/onboarding", tags=["onboarding"], dependencies=[Depends(auth_middleware)]
)
async def get_onboarding(user_id: Annotated[str, Depends(get_user_id)]):
return await get_user_onboarding(user_id)
@v1_router.patch(
"/onboarding",
summary="Update onboarding progress",
tags=["onboarding"],
dependencies=[Depends(auth_middleware)],
"/onboarding", tags=["onboarding"], dependencies=[Depends(auth_middleware)]
)
async def update_onboarding(
user_id: Annotated[str, Depends(get_user_id)], data: UserOnboardingUpdate
@@ -195,7 +178,6 @@ async def update_onboarding(
@v1_router.get(
"/onboarding/agents",
summary="Get recommended agents",
tags=["onboarding"],
dependencies=[Depends(auth_middleware)],
)
@@ -207,7 +189,6 @@ async def get_onboarding_agents(
@v1_router.get(
"/onboarding/enabled",
summary="Check onboarding enabled",
tags=["onboarding", "public"],
dependencies=[Depends(auth_middleware)],
)
@@ -220,12 +201,7 @@ async def is_onboarding_enabled():
########################################################
@v1_router.get(
path="/blocks",
summary="List available blocks",
tags=["blocks"],
dependencies=[Depends(auth_middleware)],
)
@v1_router.get(path="/blocks", tags=["blocks"], dependencies=[Depends(auth_middleware)])
def get_graph_blocks() -> Sequence[dict[Any, Any]]:
blocks = [block() for block in get_blocks().values()]
costs = get_block_costs()
@@ -236,7 +212,6 @@ def get_graph_blocks() -> Sequence[dict[Any, Any]]:
@v1_router.post(
path="/blocks/{block_id}/execute",
summary="Execute graph block",
tags=["blocks"],
dependencies=[Depends(auth_middleware)],
)
@@ -256,12 +231,7 @@ async def execute_graph_block(block_id: str, data: BlockInput) -> CompletedBlock
########################################################
@v1_router.get(
path="/credits",
tags=["credits"],
summary="Get user credits",
dependencies=[Depends(auth_middleware)],
)
@v1_router.get(path="/credits", dependencies=[Depends(auth_middleware)])
async def get_user_credits(
user_id: Annotated[str, Depends(get_user_id)],
) -> dict[str, int]:
@@ -269,10 +239,7 @@ async def get_user_credits(
@v1_router.post(
path="/credits",
summary="Request credit top up",
tags=["credits"],
dependencies=[Depends(auth_middleware)],
path="/credits", tags=["credits"], dependencies=[Depends(auth_middleware)]
)
async def request_top_up(
request: RequestTopUp, user_id: Annotated[str, Depends(get_user_id)]
@@ -285,7 +252,6 @@ async def request_top_up(
@v1_router.post(
path="/credits/{transaction_key}/refund",
summary="Refund credit transaction",
tags=["credits"],
dependencies=[Depends(auth_middleware)],
)
@@ -298,10 +264,7 @@ async def refund_top_up(
@v1_router.patch(
path="/credits",
summary="Fulfill checkout session",
tags=["credits"],
dependencies=[Depends(auth_middleware)],
path="/credits", tags=["credits"], dependencies=[Depends(auth_middleware)]
)
async def fulfill_checkout(user_id: Annotated[str, Depends(get_user_id)]):
await _user_credit_model.fulfill_checkout(user_id=user_id)
@@ -310,7 +273,6 @@ async def fulfill_checkout(user_id: Annotated[str, Depends(get_user_id)]):
@v1_router.post(
path="/credits/auto-top-up",
summary="Configure auto top up",
tags=["credits"],
dependencies=[Depends(auth_middleware)],
)
@@ -339,7 +301,6 @@ async def configure_user_auto_top_up(
@v1_router.get(
path="/credits/auto-top-up",
summary="Get auto top up",
tags=["credits"],
dependencies=[Depends(auth_middleware)],
)
@@ -349,9 +310,7 @@ async def get_user_auto_top_up(
return await get_auto_top_up(user_id)
@v1_router.post(
path="/credits/stripe_webhook", summary="Handle Stripe webhooks", tags=["credits"]
)
@v1_router.post(path="/credits/stripe_webhook", tags=["credits"])
async def stripe_webhook(request: Request):
# Get the raw request body
payload = await request.body()
@@ -386,24 +345,14 @@ async def stripe_webhook(request: Request):
return Response(status_code=200)
@v1_router.get(
path="/credits/manage",
tags=["credits"],
summary="Manage payment methods",
dependencies=[Depends(auth_middleware)],
)
@v1_router.get(path="/credits/manage", dependencies=[Depends(auth_middleware)])
async def manage_payment_method(
user_id: Annotated[str, Depends(get_user_id)],
) -> dict[str, str]:
return {"url": await _user_credit_model.create_billing_portal_session(user_id)}
@v1_router.get(
path="/credits/transactions",
tags=["credits"],
summary="Get credit history",
dependencies=[Depends(auth_middleware)],
)
@v1_router.get(path="/credits/transactions", dependencies=[Depends(auth_middleware)])
async def get_credit_history(
user_id: Annotated[str, Depends(get_user_id)],
transaction_time: datetime | None = None,
@@ -421,12 +370,7 @@ async def get_credit_history(
)
@v1_router.get(
path="/credits/refunds",
tags=["credits"],
summary="Get refund requests",
dependencies=[Depends(auth_middleware)],
)
@v1_router.get(path="/credits/refunds", dependencies=[Depends(auth_middleware)])
async def get_refund_requests(
user_id: Annotated[str, Depends(get_user_id)],
) -> list[RefundRequest]:
@@ -442,12 +386,7 @@ class DeleteGraphResponse(TypedDict):
version_counts: int
@v1_router.get(
path="/graphs",
summary="List user graphs",
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
@v1_router.get(path="/graphs", tags=["graphs"], dependencies=[Depends(auth_middleware)])
async def get_graphs(
user_id: Annotated[str, Depends(get_user_id)],
) -> Sequence[graph_db.GraphModel]:
@@ -455,14 +394,10 @@ async def get_graphs(
@v1_router.get(
path="/graphs/{graph_id}",
summary="Get specific graph",
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
path="/graphs/{graph_id}", tags=["graphs"], dependencies=[Depends(auth_middleware)]
)
@v1_router.get(
path="/graphs/{graph_id}/versions/{version}",
summary="Get graph version",
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
@@ -486,7 +421,6 @@ async def get_graph(
@v1_router.get(
path="/graphs/{graph_id}/versions",
summary="Get all graph versions",
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
@@ -500,10 +434,7 @@ async def get_graph_all_versions(
@v1_router.post(
path="/graphs",
summary="Create new graph",
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
path="/graphs", tags=["graphs"], dependencies=[Depends(auth_middleware)]
)
async def create_new_graph(
create_graph: CreateGraph,
@@ -526,10 +457,7 @@ async def create_new_graph(
@v1_router.delete(
path="/graphs/{graph_id}",
summary="Delete graph permanently",
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
path="/graphs/{graph_id}", tags=["graphs"], dependencies=[Depends(auth_middleware)]
)
async def delete_graph(
graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
@@ -541,10 +469,7 @@ async def delete_graph(
@v1_router.put(
path="/graphs/{graph_id}",
summary="Update graph version",
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
path="/graphs/{graph_id}", tags=["graphs"], dependencies=[Depends(auth_middleware)]
)
async def update_graph(
graph_id: str,
@@ -590,7 +515,6 @@ async def update_graph(
@v1_router.put(
path="/graphs/{graph_id}/versions/active",
summary="Set active graph version",
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
@@ -629,7 +553,6 @@ async def set_graph_active_version(
@v1_router.post(
path="/graphs/{graph_id}/execute/{graph_version}",
summary="Execute graph agent",
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
@@ -663,65 +586,33 @@ async def execute_graph(
@v1_router.post(
path="/graphs/{graph_id}/executions/{graph_exec_id}/stop",
summary="Stop graph execution",
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
async def stop_graph_run(
graph_id: str, graph_exec_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> execution_db.GraphExecutionMeta | None:
res = await _stop_graph_run(
user_id=user_id,
graph_id=graph_id,
graph_exec_id=graph_exec_id,
) -> execution_db.GraphExecution:
if not await execution_db.get_graph_execution_meta(
user_id=user_id, execution_id=graph_exec_id
):
raise HTTPException(404, detail=f"Agent execution #{graph_exec_id} not found")
await execution_utils.stop_graph_execution(graph_exec_id)
# Retrieve & return canceled graph execution in its final state
result = await execution_db.get_graph_execution(
execution_id=graph_exec_id, user_id=user_id
)
if not res:
return None
return res[0]
@v1_router.post(
path="/executions",
summary="Stop graph executions",
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
async def stop_graph_runs(
graph_id: str, graph_exec_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> list[execution_db.GraphExecutionMeta]:
return await _stop_graph_run(
user_id=user_id,
graph_id=graph_id,
graph_exec_id=graph_exec_id,
)
async def _stop_graph_run(
user_id: str,
graph_id: Optional[str] = None,
graph_exec_id: Optional[str] = None,
) -> list[execution_db.GraphExecutionMeta]:
graph_execs = await execution_db.get_graph_executions(
user_id=user_id,
graph_id=graph_id,
graph_exec_id=graph_exec_id,
statuses=[
execution_db.ExecutionStatus.INCOMPLETE,
execution_db.ExecutionStatus.QUEUED,
execution_db.ExecutionStatus.RUNNING,
],
)
stopped_execs = [
execution_utils.stop_graph_execution(graph_exec_id=exec.id, user_id=user_id)
for exec in graph_execs
]
await asyncio.gather(*stopped_execs)
return graph_execs
if not result:
raise HTTPException(
500,
detail=f"Could not fetch graph execution #{graph_exec_id} after stopping",
)
return result
@v1_router.get(
path="/executions",
summary="Get all executions",
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
@@ -733,7 +624,6 @@ async def get_graphs_executions(
@v1_router.get(
path="/graphs/{graph_id}/executions",
summary="Get graph executions",
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
@@ -746,7 +636,6 @@ async def get_graph_executions(
@v1_router.get(
path="/graphs/{graph_id}/executions/{graph_exec_id}",
summary="Get execution details",
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
@@ -776,7 +665,6 @@ async def get_graph_execution(
@v1_router.delete(
path="/executions/{graph_exec_id}",
summary="Delete graph execution",
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
status_code=HTTP_204_NO_CONTENT,
@@ -796,55 +684,60 @@ async def delete_graph_execution(
class ScheduleCreationRequest(pydantic.BaseModel):
graph_version: Optional[int] = None
name: str
cron: str
inputs: dict[str, Any]
credentials: dict[str, CredentialsMetaInput] = pydantic.Field(default_factory=dict)
input_data: dict[Any, Any]
graph_id: str
graph_version: int
@v1_router.post(
path="/graphs/{graph_id}/schedules",
summary="Create execution schedule",
path="/schedules",
tags=["schedules"],
dependencies=[Depends(auth_middleware)],
)
async def create_graph_execution_schedule(
async def create_schedule(
user_id: Annotated[str, Depends(get_user_id)],
graph_id: str = Path(..., description="ID of the graph to schedule"),
schedule_params: ScheduleCreationRequest = Body(),
schedule: ScheduleCreationRequest,
) -> scheduler.GraphExecutionJobInfo:
graph = await graph_db.get_graph(
graph_id=graph_id,
version=schedule_params.graph_version,
user_id=user_id,
schedule.graph_id, schedule.graph_version, user_id=user_id
)
if not graph:
raise HTTPException(
status_code=404,
detail=f"Graph #{graph_id} v{schedule_params.graph_version} not found.",
detail=f"Graph #{schedule.graph_id} v.{schedule.graph_version} not found.",
)
return await execution_scheduler_client().add_execution_schedule(
user_id=user_id,
graph_id=graph_id,
graph_id=schedule.graph_id,
graph_version=graph.version,
name=schedule_params.name,
cron=schedule_params.cron,
input_data=schedule_params.inputs,
input_credentials=schedule_params.credentials,
cron=schedule.cron,
input_data=schedule.input_data,
user_id=user_id,
)
@v1_router.get(
path="/graphs/{graph_id}/schedules",
summary="List execution schedules for a graph",
@v1_router.delete(
path="/schedules/{schedule_id}",
tags=["schedules"],
dependencies=[Depends(auth_middleware)],
)
async def list_graph_execution_schedules(
async def delete_schedule(
schedule_id: str,
user_id: Annotated[str, Depends(get_user_id)],
graph_id: str = Path(),
) -> dict[Any, Any]:
await execution_scheduler_client().delete_schedule(schedule_id, user_id=user_id)
return {"id": schedule_id}
@v1_router.get(
path="/schedules",
tags=["schedules"],
dependencies=[Depends(auth_middleware)],
)
async def get_execution_schedules(
user_id: Annotated[str, Depends(get_user_id)],
graph_id: str | None = None,
) -> list[scheduler.GraphExecutionJobInfo]:
return await execution_scheduler_client().get_execution_schedules(
user_id=user_id,
@@ -852,38 +745,6 @@ async def list_graph_execution_schedules(
)
@v1_router.get(
path="/schedules",
summary="List execution schedules for a user",
tags=["schedules"],
dependencies=[Depends(auth_middleware)],
)
async def list_all_graphs_execution_schedules(
user_id: Annotated[str, Depends(get_user_id)],
) -> list[scheduler.GraphExecutionJobInfo]:
return await execution_scheduler_client().get_execution_schedules(user_id=user_id)
@v1_router.delete(
path="/schedules/{schedule_id}",
summary="Delete execution schedule",
tags=["schedules"],
dependencies=[Depends(auth_middleware)],
)
async def delete_graph_execution_schedule(
user_id: Annotated[str, Depends(get_user_id)],
schedule_id: str = Path(..., description="ID of the schedule to delete"),
) -> dict[str, Any]:
try:
await execution_scheduler_client().delete_schedule(schedule_id, user_id=user_id)
except NotFoundError:
raise HTTPException(
status_code=HTTP_404_NOT_FOUND,
detail=f"Schedule #{schedule_id} not found",
)
return {"id": schedule_id}
########################################################
##################### API KEY ##############################
########################################################
@@ -891,7 +752,6 @@ async def delete_graph_execution_schedule(
@v1_router.post(
"/api-keys",
summary="Create new API key",
response_model=CreateAPIKeyResponse,
tags=["api-keys"],
dependencies=[Depends(auth_middleware)],
@@ -922,7 +782,6 @@ async def create_api_key(
@v1_router.get(
"/api-keys",
summary="List user API keys",
response_model=list[APIKeyWithoutHash] | dict[str, str],
tags=["api-keys"],
dependencies=[Depends(auth_middleware)],
@@ -943,7 +802,6 @@ async def get_api_keys(
@v1_router.get(
"/api-keys/{key_id}",
summary="Get specific API key",
response_model=APIKeyWithoutHash,
tags=["api-keys"],
dependencies=[Depends(auth_middleware)],
@@ -967,7 +825,6 @@ async def get_api_key(
@v1_router.delete(
"/api-keys/{key_id}",
summary="Revoke API key",
response_model=APIKeyWithoutHash,
tags=["api-keys"],
dependencies=[Depends(auth_middleware)],
@@ -996,7 +853,6 @@ async def delete_api_key(
@v1_router.post(
"/api-keys/{key_id}/suspend",
summary="Suspend API key",
response_model=APIKeyWithoutHash,
tags=["api-keys"],
dependencies=[Depends(auth_middleware)],
@@ -1022,7 +878,6 @@ async def suspend_key(
@v1_router.put(
"/api-keys/{key_id}/permissions",
summary="Update key permissions",
response_model=APIKeyWithoutHash,
tags=["api-keys"],
dependencies=[Depends(auth_middleware)],

View File

@@ -108,16 +108,11 @@ class TestDatabaseIsolation:
where={"email": {"contains": "@test.example"}}
)
@pytest.fixture(scope="session")
async def test_create_user(self, test_db_connection):
"""Test that demonstrates proper isolation."""
# This test has access to a clean database
user = await test_db_connection.user.create(
data={
"id": "test-user-id",
"email": "test@test.example",
"name": "Test User",
}
data={"email": "test@test.example", "name": "Test User"}
)
assert user.email == "test@test.example"
# User will be cleaned up automatically

View File

@@ -22,9 +22,7 @@ router = APIRouter(
)
@router.post(
"/add_credits", response_model=AddUserCreditsResponse, summary="Add Credits to User"
)
@router.post("/add_credits", response_model=AddUserCreditsResponse)
async def add_user_credits(
user_id: typing.Annotated[str, Body()],
amount: typing.Annotated[int, Body()],
@@ -51,7 +49,6 @@ async def add_user_credits(
@router.get(
"/users_history",
response_model=UserHistoryResponse,
summary="Get All Users History",
)
async def admin_get_all_user_history(
admin_user: typing.Annotated[

View File

@@ -19,7 +19,6 @@ router = fastapi.APIRouter(prefix="/admin", tags=["store", "admin"])
@router.get(
"/listings",
summary="Get Admin Listings History",
response_model=backend.server.v2.store.model.StoreListingsWithVersionsResponse,
dependencies=[fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user)],
)
@@ -64,7 +63,6 @@ async def get_admin_listings_with_versions(
@router.post(
"/submissions/{store_listing_version_id}/review",
summary="Review Store Submission",
response_model=backend.server.v2.store.model.StoreSubmission,
dependencies=[fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user)],
)
@@ -106,7 +104,6 @@ async def review_submission(
@router.get(
"/submissions/download/{store_listing_version_id}",
summary="Admin Download Agent File",
tags=["store", "admin"],
dependencies=[fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user)],
)

View File

@@ -1,5 +1,5 @@
import logging
from typing import Literal, Optional
from typing import Optional
import fastapi
import prisma.errors
@@ -7,17 +7,17 @@ import prisma.fields
import prisma.models
import prisma.types
import backend.data.graph as graph_db
import backend.data.graph
import backend.server.model
import backend.server.v2.library.model as library_model
import backend.server.v2.store.exceptions as store_exceptions
import backend.server.v2.store.image_gen as store_image_gen
import backend.server.v2.store.media as store_media
from backend.data.block import BlockInput
from backend.data.db import locked_transaction, transaction
from backend.data import db
from backend.data import graph as graph_db
from backend.data.db import locked_transaction
from backend.data.execution import get_graph_execution
from backend.data.includes import library_agent_include
from backend.data.model import CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
from backend.util.exceptions import NotFoundError
@@ -122,7 +122,7 @@ async def list_library_agents(
except Exception as e:
# Skip this agent if there was an error
logger.error(
f"Error parsing LibraryAgent #{agent.id} from DB item: {e}"
f"Error parsing LibraryAgent when getting library agents from db: {e}"
)
continue
@@ -168,16 +168,9 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
)
if not library_agent:
raise NotFoundError(f"Library agent #{id} not found")
raise store_exceptions.AgentNotFoundError(f"Library agent #{id} not found")
return library_model.LibraryAgent.from_db(
library_agent,
sub_graphs=(
await graph_db.get_sub_graphs(library_agent.AgentGraph)
if library_agent.AgentGraph
else None
),
)
return library_model.LibraryAgent.from_db(library_agent)
except prisma.errors.PrismaError as e:
logger.error(f"Database error fetching library agent: {e}")
@@ -222,34 +215,8 @@ async def get_library_agent_by_store_version_id(
return None
async def get_library_agent_by_graph_id(
user_id: str,
graph_id: str,
graph_version: Optional[int] = None,
) -> library_model.LibraryAgent | None:
try:
filter: prisma.types.LibraryAgentWhereInput = {
"agentGraphId": graph_id,
"userId": user_id,
"isDeleted": False,
}
if graph_version is not None:
filter["agentGraphVersion"] = graph_version
agent = await prisma.models.LibraryAgent.prisma().find_first(
where=filter,
include=library_agent_include(user_id),
)
if not agent:
return None
return library_model.LibraryAgent.from_db(agent)
except prisma.errors.PrismaError as e:
logger.error(f"Database error fetching library agent by graph ID: {e}")
raise store_exceptions.DatabaseError("Failed to fetch library agent") from e
async def add_generated_agent_image(
graph: graph_db.GraphModel,
graph: backend.data.graph.GraphModel,
library_agent_id: str,
) -> Optional[prisma.models.LibraryAgent]:
"""
@@ -282,7 +249,7 @@ async def add_generated_agent_image(
async def create_library_agent(
graph: graph_db.GraphModel,
graph: backend.data.graph.GraphModel,
user_id: str,
) -> library_model.LibraryAgent:
"""
@@ -379,8 +346,8 @@ async def update_library_agent(
auto_update_version: Optional[bool] = None,
is_favorite: Optional[bool] = None,
is_archived: Optional[bool] = None,
is_deleted: Optional[Literal[False]] = None,
) -> library_model.LibraryAgent:
is_deleted: Optional[bool] = None,
) -> None:
"""
Updates the specified LibraryAgent record.
@@ -390,18 +357,15 @@ async def update_library_agent(
auto_update_version: Whether the agent should auto-update to active version.
is_favorite: Whether this agent is marked as a favorite.
is_archived: Whether this agent is archived.
Returns:
The updated LibraryAgent.
is_deleted: Whether this agent is deleted.
Raises:
NotFoundError: If the specified LibraryAgent does not exist.
DatabaseError: If there's an error in the update operation.
"""
logger.debug(
f"Updating library agent {library_agent_id} for user {user_id} with "
f"auto_update_version={auto_update_version}, is_favorite={is_favorite}, "
f"is_archived={is_archived}"
f"is_archived={is_archived}, is_deleted={is_deleted}"
)
update_fields: prisma.types.LibraryAgentUpdateManyMutationInput = {}
if auto_update_version is not None:
@@ -411,46 +375,17 @@ async def update_library_agent(
if is_archived is not None:
update_fields["isArchived"] = is_archived
if is_deleted is not None:
if is_deleted is True:
raise RuntimeError(
"Use delete_library_agent() to (soft-)delete library agents"
)
update_fields["isDeleted"] = is_deleted
if not update_fields:
raise ValueError("No values were passed to update")
try:
n_updated = await prisma.models.LibraryAgent.prisma().update_many(
where={"id": library_agent_id, "userId": user_id},
data=update_fields,
)
if n_updated < 1:
raise NotFoundError(f"Library agent {library_agent_id} not found")
return await get_library_agent(
id=library_agent_id,
user_id=user_id,
await prisma.models.LibraryAgent.prisma().update_many(
where={"id": library_agent_id, "userId": user_id}, data=update_fields
)
except prisma.errors.PrismaError as e:
logger.error(f"Database error updating library agent: {str(e)}")
raise store_exceptions.DatabaseError("Failed to update library agent") from e
async def delete_library_agent(
library_agent_id: str, user_id: str, soft_delete: bool = True
) -> None:
if soft_delete:
deleted_count = await prisma.models.LibraryAgent.prisma().update_many(
where={"id": library_agent_id, "userId": user_id}, data={"isDeleted": True}
)
else:
deleted_count = await prisma.models.LibraryAgent.prisma().delete_many(
where={"id": library_agent_id, "userId": user_id}
)
if deleted_count < 1:
raise NotFoundError(f"Library agent #{library_agent_id} not found")
async def delete_library_agent_by_graph_id(graph_id: str, user_id: str) -> None:
"""
Deletes a library agent for the given user
@@ -590,10 +525,7 @@ async def list_presets(
)
raise store_exceptions.DatabaseError("Invalid pagination parameters")
query_filter: prisma.types.AgentPresetWhereInput = {
"userId": user_id,
"isDeleted": False,
}
query_filter: prisma.types.AgentPresetWhereInput = {"userId": user_id}
if graph_id:
query_filter["agentGraphId"] = graph_id
@@ -649,7 +581,7 @@ async def get_preset(
where={"id": preset_id},
include={"InputPresets": True},
)
if not preset or preset.userId != user_id or preset.isDeleted:
if not preset or preset.userId != user_id:
return None
return library_model.LibraryAgentPreset.from_db(preset)
except prisma.errors.PrismaError as e:
@@ -686,19 +618,12 @@ async def create_preset(
agentGraphId=preset.graph_id,
agentGraphVersion=preset.graph_version,
isActive=preset.is_active,
webhookId=preset.webhook_id,
InputPresets={
"create": [
prisma.types.AgentNodeExecutionInputOutputCreateWithoutRelationsInput( # noqa
name=name, data=prisma.fields.Json(data)
)
for name, data in {
**preset.inputs,
**{
key: creds_meta.model_dump(exclude_none=True)
for key, creds_meta in preset.credentials.items()
},
}.items()
for name, data in preset.inputs.items()
]
},
),
@@ -739,7 +664,6 @@ async def create_preset_from_graph_execution(
user_id=user_id,
preset=library_model.LibraryAgentPresetCreatable(
inputs=graph_execution.inputs,
credentials={}, # FIXME
graph_id=graph_execution.graph_id,
graph_version=graph_execution.graph_version,
name=create_request.name,
@@ -752,11 +676,7 @@ async def create_preset_from_graph_execution(
async def update_preset(
user_id: str,
preset_id: str,
inputs: Optional[BlockInput] = None,
credentials: Optional[dict[str, CredentialsMetaInput]] = None,
name: Optional[str] = None,
description: Optional[str] = None,
is_active: Optional[bool] = None,
preset: library_model.LibraryAgentPresetUpdatable,
) -> library_model.LibraryAgentPreset:
"""
Updates an existing AgentPreset for a user.
@@ -764,95 +684,49 @@ async def update_preset(
Args:
user_id: The ID of the user updating the preset.
preset_id: The ID of the preset to update.
inputs: New inputs object to set on the preset.
credentials: New credentials to set on the preset.
name: New name for the preset.
description: New description for the preset.
is_active: New active status for the preset.
preset: The preset data used for the update.
Returns:
The updated LibraryAgentPreset.
Raises:
DatabaseError: If there's a database error in updating the preset.
NotFoundError: If attempting to update a non-existent preset.
ValueError: If attempting to update a non-existent preset.
"""
current = await get_preset(user_id, preset_id) # assert ownership
if not current:
raise NotFoundError(f"Preset #{preset_id} not found for user #{user_id}")
logger.debug(
f"Updating preset #{preset_id} ({repr(current.name)}) for user #{user_id}",
f"Updating preset #{preset_id} ({repr(preset.name)}) for user #{user_id}",
)
try:
async with transaction() as tx:
update_data: prisma.types.AgentPresetUpdateInput = {}
if name:
update_data["name"] = name
if description:
update_data["description"] = description
if is_active is not None:
update_data["isActive"] = is_active
if inputs or credentials:
if not (inputs and credentials):
raise ValueError(
"Preset inputs and credentials must be provided together"
update_data: prisma.types.AgentPresetUpdateInput = {}
if preset.name:
update_data["name"] = preset.name
if preset.description:
update_data["description"] = preset.description
if preset.inputs:
update_data["InputPresets"] = {
"create": [
prisma.types.AgentNodeExecutionInputOutputCreateWithoutRelationsInput( # noqa
name=name, data=prisma.fields.Json(data)
)
update_data["InputPresets"] = {
"create": [
prisma.types.AgentNodeExecutionInputOutputCreateWithoutRelationsInput( # noqa
name=name, data=prisma.fields.Json(data)
)
for name, data in {
**inputs,
**{
key: creds_meta.model_dump(exclude_none=True)
for key, creds_meta in credentials.items()
},
}.items()
],
}
# Existing InputPresets must be deleted, in a separate query
await prisma.models.AgentNodeExecutionInputOutput.prisma(
tx
).delete_many(where={"agentPresetId": preset_id})
for name, data in preset.inputs.items()
]
}
if preset.is_active:
update_data["isActive"] = preset.is_active
updated = await prisma.models.AgentPreset.prisma(tx).update(
where={"id": preset_id},
data=update_data,
include={"InputPresets": True},
)
updated = await prisma.models.AgentPreset.prisma().update(
where={"id": preset_id},
data=update_data,
include={"InputPresets": True},
)
if not updated:
raise RuntimeError(f"AgentPreset #{preset_id} vanished while updating")
raise ValueError(f"AgentPreset #{preset_id} not found")
return library_model.LibraryAgentPreset.from_db(updated)
except prisma.errors.PrismaError as e:
logger.error(f"Database error updating preset: {e}")
raise store_exceptions.DatabaseError("Failed to update preset") from e
async def set_preset_webhook(
user_id: str, preset_id: str, webhook_id: str | None
) -> library_model.LibraryAgentPreset:
current = await prisma.models.AgentPreset.prisma().find_unique(
where={"id": preset_id},
include={"InputPresets": True},
)
if not current or current.userId != user_id:
raise NotFoundError(f"Preset #{preset_id} not found")
updated = await prisma.models.AgentPreset.prisma().update(
where={"id": preset_id},
data=(
{"Webhook": {"connect": {"id": webhook_id}}}
if webhook_id
else {"Webhook": {"disconnect": True}}
),
include={"InputPresets": True},
)
if not updated:
raise RuntimeError(f"AgentPreset #{preset_id} vanished while updating")
return library_model.LibraryAgentPreset.from_db(updated)
async def delete_preset(user_id: str, preset_id: str) -> None:
"""
Soft-deletes a preset by marking it as isDeleted = True.
@@ -864,7 +738,7 @@ async def delete_preset(user_id: str, preset_id: str) -> None:
Raises:
DatabaseError: If there's a database error during deletion.
"""
logger.debug(f"Setting preset #{preset_id} for user #{user_id} to deleted")
logger.info(f"Deleting preset {preset_id} for user {user_id}")
try:
await prisma.models.AgentPreset.prisma().update_many(
where={"id": preset_id, "userId": user_id},
@@ -891,7 +765,7 @@ async def fork_library_agent(library_agent_id: str, user_id: str):
"""
logger.debug(f"Forking library agent {library_agent_id} for user {user_id}")
try:
async with locked_transaction(f"usr_trx_{user_id}-fork_agent"):
async with db.locked_transaction(f"usr_trx_{user_id}-fork_agent"):
# Fetch the original agent
original_agent = await get_library_agent(library_agent_id, user_id)

View File

@@ -143,7 +143,7 @@ async def test_add_agent_to_library(mocker):
)
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
mock_library_agent.return_value.find_unique = mocker.AsyncMock(return_value=None)
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
mock_library_agent.return_value.create = mocker.AsyncMock(
return_value=mock_library_agent_data
)
@@ -159,24 +159,21 @@ async def test_add_agent_to_library(mocker):
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
where={"id": "version123"}, include={"AgentGraph": True}
)
mock_library_agent.return_value.find_unique.assert_called_once_with(
mock_library_agent.return_value.find_first.assert_called_once_with(
where={
"userId_agentGraphId_agentGraphVersion": {
"userId": "test-user",
"agentGraphId": "agent1",
"agentGraphVersion": 1,
}
"userId": "test-user",
"agentGraphId": "agent1",
"agentGraphVersion": 1,
},
include={"AgentGraph": True},
include=library_agent_include("test-user"),
)
mock_library_agent.return_value.create.assert_called_once_with(
data={
"User": {"connect": {"id": "test-user"}},
"AgentGraph": {
"connect": {"graphVersionId": {"id": "agent1", "version": 1}}
},
"isCreatedByUser": False,
},
data=prisma.types.LibraryAgentCreateInput(
userId="test-user",
agentGraphId="agent1",
agentGraphVersion=1,
isCreatedByUser=False,
),
include=library_agent_include("test-user"),
)

View File

@@ -9,8 +9,6 @@ import pydantic
import backend.data.block as block_model
import backend.data.graph as graph_model
import backend.server.model as server_model
from backend.data.model import CredentialsMetaInput, is_credentials_field_name
from backend.integrations.providers import ProviderName
class LibraryAgentStatus(str, Enum):
@@ -20,14 +18,6 @@ class LibraryAgentStatus(str, Enum):
ERROR = "ERROR" # Agent is in an error state
class LibraryAgentTriggerInfo(pydantic.BaseModel):
provider: ProviderName
config_schema: dict[str, Any] = pydantic.Field(
description="Input schema for the trigger block"
)
credentials_input_name: Optional[str]
class LibraryAgent(pydantic.BaseModel):
"""
Represents an agent in the library, including metadata for display and
@@ -50,15 +40,8 @@ class LibraryAgent(pydantic.BaseModel):
name: str
description: str
# Made input_schema and output_schema match GraphMeta's type
input_schema: dict[str, Any] # Should be BlockIOObjectSubSchema in frontend
credentials_input_schema: dict[str, Any] | None = pydantic.Field(
description="Input schema for credentials required by the agent",
)
has_external_trigger: bool = pydantic.Field(
description="Whether the agent has an external trigger (e.g. webhook) node"
)
trigger_setup_info: Optional[LibraryAgentTriggerInfo] = None
# Indicates whether there's a new output (based on recent runs)
new_output: bool
@@ -70,10 +53,7 @@ class LibraryAgent(pydantic.BaseModel):
is_latest_version: bool
@staticmethod
def from_db(
agent: prisma.models.LibraryAgent,
sub_graphs: Optional[list[prisma.models.AgentGraph]] = None,
) -> "LibraryAgent":
def from_db(agent: prisma.models.LibraryAgent) -> "LibraryAgent":
"""
Factory method that constructs a LibraryAgent from a Prisma LibraryAgent
model instance.
@@ -81,7 +61,7 @@ class LibraryAgent(pydantic.BaseModel):
if not agent.AgentGraph:
raise ValueError("Associated Agent record is required.")
graph = graph_model.GraphModel.from_db(agent.AgentGraph, sub_graphs=sub_graphs)
graph = graph_model.GraphModel.from_db(agent.AgentGraph)
agent_updated_at = agent.AgentGraph.updatedAt
lib_agent_updated_at = agent.updatedAt
@@ -126,34 +106,6 @@ class LibraryAgent(pydantic.BaseModel):
name=graph.name,
description=graph.description,
input_schema=graph.input_schema,
credentials_input_schema=(
graph.credentials_input_schema if sub_graphs else None
),
has_external_trigger=graph.has_webhook_trigger,
trigger_setup_info=(
LibraryAgentTriggerInfo(
provider=trigger_block.webhook_config.provider,
config_schema={
**(json_schema := trigger_block.input_schema.jsonschema()),
"properties": {
pn: sub_schema
for pn, sub_schema in json_schema["properties"].items()
if not is_credentials_field_name(pn)
},
"required": [
pn
for pn in json_schema.get("required", [])
if not is_credentials_field_name(pn)
],
},
credentials_input_name=next(
iter(trigger_block.input_schema.get_credentials_fields()), None
),
)
if graph.webhook_input_node
and (trigger_block := graph.webhook_input_node.block).webhook_config
else None
),
new_output=new_output,
can_access_graph=can_access_graph,
is_latest_version=is_latest_version,
@@ -225,15 +177,12 @@ class LibraryAgentPresetCreatable(pydantic.BaseModel):
graph_version: int
inputs: block_model.BlockInput
credentials: dict[str, CredentialsMetaInput]
name: str
description: str
is_active: bool = True
webhook_id: Optional[str] = None
class LibraryAgentPresetCreatableFromGraphExecution(pydantic.BaseModel):
"""
@@ -254,7 +203,6 @@ class LibraryAgentPresetUpdatable(pydantic.BaseModel):
"""
inputs: Optional[block_model.BlockInput] = None
credentials: Optional[dict[str, CredentialsMetaInput]] = None
name: Optional[str] = None
description: Optional[str] = None
@@ -266,28 +214,20 @@ class LibraryAgentPreset(LibraryAgentPresetCreatable):
"""Represents a preset configuration for a library agent."""
id: str
user_id: str
updated_at: datetime.datetime
@classmethod
def from_db(cls, preset: prisma.models.AgentPreset) -> "LibraryAgentPreset":
if preset.InputPresets is None:
raise ValueError("InputPresets must be included in AgentPreset query")
raise ValueError("Input values must be included in object")
input_data: block_model.BlockInput = {}
input_credentials: dict[str, CredentialsMetaInput] = {}
for preset_input in preset.InputPresets:
if not is_credentials_field_name(preset_input.name):
input_data[preset_input.name] = preset_input.data
else:
input_credentials[preset_input.name] = (
CredentialsMetaInput.model_validate(preset_input.data)
)
input_data[preset_input.name] = preset_input.data
return cls(
id=preset.id,
user_id=preset.userId,
updated_at=preset.updatedAt,
graph_id=preset.agentGraphId,
graph_version=preset.agentGraphVersion,
@@ -295,8 +235,6 @@ class LibraryAgentPreset(LibraryAgentPresetCreatable):
description=preset.description,
is_active=preset.isActive,
inputs=input_data,
credentials=input_credentials,
webhook_id=preset.webhookId,
)
@@ -338,3 +276,6 @@ class LibraryAgentUpdateRequest(pydantic.BaseModel):
is_archived: Optional[bool] = pydantic.Field(
default=None, description="Archive the agent"
)
is_deleted: Optional[bool] = pydantic.Field(
default=None, description="Delete the agent"
)

View File

@@ -1,19 +1,13 @@
import logging
from typing import Any, Optional
from typing import Optional
import autogpt_libs.auth as autogpt_auth_lib
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, status
from fastapi.responses import Response
from pydantic import BaseModel, Field
from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
from fastapi.responses import JSONResponse
import backend.server.v2.library.db as library_db
import backend.server.v2.library.model as library_model
import backend.server.v2.store.exceptions as store_exceptions
from backend.data.graph import get_graph
from backend.data.model import CredentialsMetaInput
from backend.executor.utils import make_node_credentials_input_map
from backend.integrations.webhooks.utils import setup_webhook_for_block
from backend.util.exceptions import NotFoundError
logger = logging.getLogger(__name__)
@@ -26,7 +20,6 @@ router = APIRouter(
@router.get(
"",
summary="List Library Agents",
responses={
500: {"description": "Server error", "content": {"application/json": {}}},
},
@@ -77,14 +70,14 @@ async def list_library_agents(
page_size=page_size,
)
except Exception as e:
logger.error(f"Could not list library agents for user #{user_id}: {e}")
logger.exception("Listing library agents failed for user %s: %s", user_id, e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
detail={"message": str(e), "hint": "Inspect database connectivity."},
) from e
@router.get("/{library_agent_id}", summary="Get Library Agent")
@router.get("/{library_agent_id}")
async def get_library_agent(
library_agent_id: str,
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
@@ -92,26 +85,8 @@ async def get_library_agent(
return await library_db.get_library_agent(id=library_agent_id, user_id=user_id)
@router.get("/by-graph/{graph_id}")
async def get_library_agent_by_graph_id(
graph_id: str,
version: Optional[int] = Query(default=None),
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
) -> library_model.LibraryAgent:
library_agent = await library_db.get_library_agent_by_graph_id(
user_id, graph_id, version
)
if not library_agent:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Library agent for graph #{graph_id} and user #{user_id} not found",
)
return library_agent
@router.get(
"/marketplace/{store_listing_version_id}",
summary="Get Agent By Store ID",
tags=["store, library"],
response_model=library_model.LibraryAgent | None,
)
@@ -126,22 +101,23 @@ async def get_library_agent_by_store_listing_version_id(
return await library_db.get_library_agent_by_store_version_id(
store_listing_version_id, user_id
)
except NotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
)
except Exception as e:
logger.error(f"Could not fetch library agent from store version ID: {e}")
logger.exception(
"Retrieving library agent by store version failed for user %s: %s",
user_id,
e,
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
detail={
"message": str(e),
"hint": "Check if the store listing ID is valid.",
},
) from e
@router.post(
"",
summary="Add Marketplace Agent",
status_code=status.HTTP_201_CREATED,
responses={
201: {"description": "Agent added successfully"},
@@ -173,20 +149,26 @@ async def add_marketplace_agent_to_library(
user_id=user_id,
)
except store_exceptions.AgentNotFoundError as e:
except store_exceptions.AgentNotFoundError:
logger.warning(
f"Could not find store listing version {store_listing_version_id} "
"to add to library"
"Store listing version %s not found when adding to library",
store_listing_version_id,
)
raise HTTPException(
status_code=404,
detail={
"message": f"Store listing version {store_listing_version_id} not found",
"hint": "Confirm the ID provided.",
},
)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
except store_exceptions.DatabaseError as e:
logger.error(f"Database error while adding agent to library: {e}", e)
logger.exception("Database error whilst adding agent to library: %s", e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"message": str(e), "hint": "Inspect DB logs for details."},
) from e
except Exception as e:
logger.error(f"Unexpected error while adding agent to library: {e}")
logger.exception("Unexpected error while adding agent to library: %s", e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
@@ -196,11 +178,11 @@ async def add_marketplace_agent_to_library(
) from e
@router.patch(
@router.put(
"/{library_agent_id}",
summary="Update Library Agent",
status_code=status.HTTP_204_NO_CONTENT,
responses={
200: {"description": "Agent updated successfully"},
204: {"description": "Agent updated successfully"},
500: {"description": "Server error"},
},
)
@@ -208,7 +190,7 @@ async def update_library_agent(
library_agent_id: str,
payload: library_model.LibraryAgentUpdateRequest,
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
) -> library_model.LibraryAgent:
) -> JSONResponse:
"""
Update the library agent with the given fields.
@@ -217,76 +199,40 @@ async def update_library_agent(
payload: Fields to update (auto_update_version, is_favorite, etc.).
user_id: ID of the authenticated user.
Returns:
204 (No Content) on success.
Raises:
HTTPException(500): If a server/database error occurs.
"""
try:
return await library_db.update_library_agent(
await library_db.update_library_agent(
library_agent_id=library_agent_id,
user_id=user_id,
auto_update_version=payload.auto_update_version,
is_favorite=payload.is_favorite,
is_archived=payload.is_archived,
is_deleted=payload.is_deleted,
)
return JSONResponse(
status_code=status.HTTP_204_NO_CONTENT,
content={"message": "Agent updated successfully"},
)
except NotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
) from e
except store_exceptions.DatabaseError as e:
logger.error(f"Database error while updating library agent: {e}")
logger.exception("Database error while updating library agent: %s", e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"message": str(e), "hint": "Verify DB connection."},
) from e
except Exception as e:
logger.error(f"Unexpected error while updating library agent: {e}")
logger.exception("Unexpected error while updating library agent: %s", e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"message": str(e), "hint": "Check server logs."},
) from e
@router.delete(
"/{library_agent_id}",
summary="Delete Library Agent",
responses={
204: {"description": "Agent deleted successfully"},
404: {"description": "Agent not found"},
500: {"description": "Server error"},
},
)
async def delete_library_agent(
library_agent_id: str,
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
) -> Response:
"""
Soft-delete the specified library agent.
Args:
library_agent_id: ID of the library agent to delete.
user_id: ID of the authenticated user.
Returns:
204 No Content if successful.
Raises:
HTTPException(404): If the agent does not exist.
HTTPException(500): If a server/database error occurs.
"""
try:
await library_db.delete_library_agent(
library_agent_id=library_agent_id, user_id=user_id
)
return Response(status_code=status.HTTP_204_NO_CONTENT)
except NotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
) from e
@router.post("/{library_agent_id}/fork", summary="Fork Library Agent")
@router.post("/{library_agent_id}/fork")
async def fork_library_agent(
library_agent_id: str,
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
@@ -295,81 +241,3 @@ async def fork_library_agent(
library_agent_id=library_agent_id,
user_id=user_id,
)
class TriggeredPresetSetupParams(BaseModel):
name: str
description: str = ""
trigger_config: dict[str, Any]
agent_credentials: dict[str, CredentialsMetaInput] = Field(default_factory=dict)
@router.post("/{library_agent_id}/setup-trigger")
async def setup_trigger(
library_agent_id: str = Path(..., description="ID of the library agent"),
params: TriggeredPresetSetupParams = Body(),
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
) -> library_model.LibraryAgentPreset:
"""
Sets up a webhook-triggered `LibraryAgentPreset` for a `LibraryAgent`.
Returns the correspondingly created `LibraryAgentPreset` with `webhook_id` set.
"""
library_agent = await library_db.get_library_agent(
id=library_agent_id, user_id=user_id
)
if not library_agent:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Library agent #{library_agent_id} not found",
)
graph = await get_graph(
library_agent.graph_id, version=library_agent.graph_version, user_id=user_id
)
if not graph:
raise HTTPException(
status.HTTP_410_GONE,
f"Graph #{library_agent.graph_id} not accessible (anymore)",
)
if not (trigger_node := graph.webhook_input_node):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Graph #{library_agent.graph_id} does not have a webhook node",
)
trigger_config_with_credentials = {
**params.trigger_config,
**(
make_node_credentials_input_map(graph, params.agent_credentials).get(
trigger_node.id
)
or {}
),
}
new_webhook, feedback = await setup_webhook_for_block(
user_id=user_id,
trigger_block=trigger_node.block,
trigger_config=trigger_config_with_credentials,
)
if not new_webhook:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Could not set up webhook: {feedback}",
)
new_preset = await library_db.create_preset(
user_id=user_id,
preset=library_model.LibraryAgentPresetCreatable(
graph_id=library_agent.graph_id,
graph_version=library_agent.graph_version,
name=params.name,
description=params.description,
inputs=trigger_config_with_credentials,
credentials=params.agent_credentials,
webhook_id=new_webhook.id,
is_active=True,
),
)
return new_preset

View File

@@ -1,23 +1,17 @@
import logging
from typing import Any, Optional
from typing import Annotated, Any, Optional
import autogpt_libs.auth as autogpt_auth_lib
from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
import backend.server.v2.library.db as db
import backend.server.v2.library.model as models
from backend.data.graph import get_graph
from backend.data.integrations import get_webhook
from backend.executor.utils import add_graph_execution, make_node_credentials_input_map
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks import get_webhook_manager
from backend.integrations.webhooks.utils import setup_webhook_for_block
from backend.executor.utils import add_graph_execution
from backend.util.exceptions import NotFoundError
logger = logging.getLogger(__name__)
credentials_manager = IntegrationCredentialsManager()
router = APIRouter(tags=["presets"])
router = APIRouter()
@router.get(
@@ -55,7 +49,11 @@ async def list_presets(
except Exception as e:
logger.exception("Failed to list presets for user %s: %s", user_id, e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"message": str(e),
"hint": "Ensure the presets DB table is accessible.",
},
)
@@ -83,21 +81,21 @@ async def get_preset(
"""
try:
preset = await db.get_preset(user_id, preset_id)
if not preset:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Preset {preset_id} not found",
)
return preset
except Exception as e:
logger.exception(
"Error retrieving preset %s for user %s: %s", preset_id, user_id, e
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"message": str(e), "hint": "Validate preset ID and retry."},
)
if not preset:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Preset #{preset_id} not found",
)
return preset
@router.post(
"/presets",
@@ -134,7 +132,8 @@ async def create_preset(
except Exception as e:
logger.exception("Preset creation failed for user %s: %s", user_id, e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"message": str(e), "hint": "Check preset payload format."},
)
@@ -162,85 +161,17 @@ async def update_preset(
Raises:
HTTPException: If an error occurs while updating the preset.
"""
current = await get_preset(preset_id, user_id=user_id)
if not current:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Preset #{preset_id} not found")
graph = await get_graph(
current.graph_id,
current.graph_version,
user_id=user_id,
)
if not graph:
raise HTTPException(
status.HTTP_410_GONE,
f"Graph #{current.graph_id} not accessible (anymore)",
)
trigger_inputs_updated, new_webhook, feedback = False, None, None
if (trigger_node := graph.webhook_input_node) and (
preset.inputs is not None and preset.credentials is not None
):
trigger_config_with_credentials = {
**preset.inputs,
**(
make_node_credentials_input_map(graph, preset.credentials).get(
trigger_node.id
)
or {}
),
}
new_webhook, feedback = await setup_webhook_for_block(
user_id=user_id,
trigger_block=graph.webhook_input_node.block,
trigger_config=trigger_config_with_credentials,
for_preset_id=preset_id,
)
trigger_inputs_updated = True
if not new_webhook:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Could not update trigger configuration: {feedback}",
)
try:
updated = await db.update_preset(
user_id=user_id,
preset_id=preset_id,
inputs=preset.inputs,
credentials=preset.credentials,
name=preset.name,
description=preset.description,
is_active=preset.is_active,
return await db.update_preset(
user_id=user_id, preset_id=preset_id, preset=preset
)
except Exception as e:
logger.exception("Preset update failed for user %s: %s", user_id, e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"message": str(e), "hint": "Check preset data and try again."},
)
# Update the webhook as well, if necessary
if trigger_inputs_updated:
updated = await db.set_preset_webhook(
user_id, preset_id, new_webhook.id if new_webhook else None
)
# Clean up webhook if it is now unused
if current.webhook_id and (
current.webhook_id != (new_webhook.id if new_webhook else None)
):
current_webhook = await get_webhook(current.webhook_id)
credentials = (
await credentials_manager.get(user_id, current_webhook.credentials_id)
if current_webhook.credentials_id
else None
)
await get_webhook_manager(
current_webhook.provider
).prune_webhook_if_dangling(user_id, current_webhook.id, credentials)
return updated
@router.delete(
"/presets/{preset_id}",
@@ -262,28 +193,6 @@ async def delete_preset(
Raises:
HTTPException: If an error occurs while deleting the preset.
"""
preset = await db.get_preset(user_id, preset_id)
if not preset:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Preset #{preset_id} not found for user #{user_id}",
)
# Detach and clean up the attached webhook, if any
if preset.webhook_id:
webhook = await get_webhook(preset.webhook_id)
await db.set_preset_webhook(user_id, preset_id, None)
# Clean up webhook if it is now unused
credentials = (
await credentials_manager.get(user_id, webhook.credentials_id)
if webhook.credentials_id
else None
)
await get_webhook_manager(webhook.provider).prune_webhook_if_dangling(
user_id, webhook.id, credentials
)
try:
await db.delete_preset(user_id, preset_id)
except Exception as e:
@@ -292,7 +201,7 @@ async def delete_preset(
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
detail={"message": str(e), "hint": "Ensure preset exists before deleting."},
)
@@ -303,20 +212,24 @@ async def delete_preset(
description="Execute a preset with the given graph and node input for the current user.",
)
async def execute_preset(
graph_id: str,
graph_version: int,
preset_id: str,
node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
inputs: dict[str, Any] = Body(..., embed=True, default_factory=dict),
) -> dict[str, Any]: # FIXME: add proper return type
"""
Execute a preset given graph parameters, returning the execution ID on success.
Args:
graph_id (str): ID of the graph to execute.
graph_version (int): Version of the graph to execute.
preset_id (str): ID of the preset to execute.
node_input (Dict[Any, Any]): Input data for the node.
user_id (str): ID of the authenticated user.
inputs (dict[str, Any]): Optionally, additional input data for the graph execution.
Returns:
{id: graph_exec_id}: A response containing the execution ID.
Dict[str, Any]: A response containing the execution ID.
Raises:
HTTPException: If the preset is not found or an error occurs while executing the preset.
@@ -326,18 +239,18 @@ async def execute_preset(
if not preset:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Preset #{preset_id} not found",
detail="Preset not found",
)
# Merge input overrides with preset inputs
merged_node_input = preset.inputs | inputs
merged_node_input = preset.inputs | node_input
execution = await add_graph_execution(
graph_id=graph_id,
user_id=user_id,
graph_id=preset.graph_id,
graph_version=preset.graph_version,
preset_id=preset_id,
inputs=merged_node_input,
preset_id=preset_id,
graph_version=graph_version,
)
logger.debug(f"Execution added: {execution} with input: {merged_node_input}")
@@ -348,6 +261,9 @@ async def execute_preset(
except Exception as e:
logger.exception("Preset execution failed for user %s: %s", user_id, e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"message": str(e),
"hint": "Review preset configuration and graph ID.",
},
)

View File

@@ -50,8 +50,6 @@ async def test_get_library_agents_success(
creator_name="Test Creator",
creator_image_url="",
input_schema={"type": "object", "properties": {}},
credentials_input_schema={"type": "object", "properties": {}},
has_external_trigger=False,
status=library_model.LibraryAgentStatus.COMPLETED,
new_output=False,
can_access_graph=True,
@@ -68,8 +66,6 @@ async def test_get_library_agents_success(
creator_name="Test Creator",
creator_image_url="",
input_schema={"type": "object", "properties": {}},
credentials_input_schema={"type": "object", "properties": {}},
has_external_trigger=False,
status=library_model.LibraryAgentStatus.COMPLETED,
new_output=False,
can_access_graph=False,
@@ -121,57 +117,26 @@ def test_get_library_agents_error(mocker: pytest_mock.MockFixture):
)
@pytest.mark.skip(reason="Mocker Not implemented")
def test_add_agent_to_library_success(mocker: pytest_mock.MockFixture):
mock_library_agent = library_model.LibraryAgent(
id="test-library-agent-id",
graph_id="test-agent-1",
graph_version=1,
name="Test Agent 1",
description="Test Description 1",
image_url=None,
creator_name="Test Creator",
creator_image_url="",
input_schema={"type": "object", "properties": {}},
credentials_input_schema={"type": "object", "properties": {}},
has_external_trigger=False,
status=library_model.LibraryAgentStatus.COMPLETED,
new_output=False,
can_access_graph=True,
is_latest_version=True,
updated_at=FIXED_NOW,
)
mock_db_call = mocker.patch("backend.server.v2.library.db.add_agent_to_library")
mock_db_call.return_value = None
mock_db_call = mocker.patch(
"backend.server.v2.library.db.add_store_agent_to_library"
)
mock_db_call.return_value = mock_library_agent
response = client.post(
"/agents", json={"store_listing_version_id": "test-version-id"}
)
response = client.post("/agents/test-version-id")
assert response.status_code == 201
# Verify the response contains the library agent data
data = library_model.LibraryAgent.model_validate(response.json())
assert data.id == "test-library-agent-id"
assert data.graph_id == "test-agent-1"
mock_db_call.assert_called_once_with(
store_listing_version_id="test-version-id", user_id="test-user-id"
)
@pytest.mark.skip(reason="Mocker Not implemented")
def test_add_agent_to_library_error(mocker: pytest_mock.MockFixture):
mock_db_call = mocker.patch(
"backend.server.v2.library.db.add_store_agent_to_library"
)
mock_db_call = mocker.patch("backend.server.v2.library.db.add_agent_to_library")
mock_db_call.side_effect = Exception("Test error")
response = client.post(
"/agents", json={"store_listing_version_id": "test-version-id"}
)
response = client.post("/agents/test-version-id")
assert response.status_code == 500
assert "detail" in response.json() # Verify error response structure
assert response.json()["detail"] == "Failed to add agent to library"
mock_db_call.assert_called_once_with(
store_listing_version_id="test-version-id", user_id="test-user-id"
)

View File

@@ -14,10 +14,7 @@ router = APIRouter()
@router.post(
"/ask",
response_model=ApiResponse,
dependencies=[Depends(auth_middleware)],
summary="Proxy Otto Chat Request",
"/ask", response_model=ApiResponse, dependencies=[Depends(auth_middleware)]
)
async def proxy_otto_request(
request: ChatRequest, user_id: str = Depends(get_user_id)

View File

@@ -259,8 +259,8 @@ def test_ask_otto_unauthenticated(mocker: pytest_mock.MockFixture) -> None:
}
response = client.post("/ask", json=request_data)
# When auth is disabled and Otto API URL is not configured, we get 502 (wrapped from 503)
assert response.status_code == 502
# When auth is disabled and Otto API URL is not configured, we get 503
assert response.status_code == 503
# Restore the override
app.dependency_overrides[autogpt_libs.auth.middleware.auth_middleware] = (

View File

@@ -93,14 +93,6 @@ async def test_get_store_agent_details(mocker):
mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
# Mock Profile prisma call
mock_profile = mocker.MagicMock()
mock_profile.userId = "user-id-123"
mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
mock_profile_db.return_value.find_first = mocker.AsyncMock(
return_value=mock_profile
)
# Mock StoreListing prisma call - this is what was missing
mock_store_listing_db = mocker.patch("prisma.models.StoreListing.prisma")
mock_store_listing_db.return_value.find_first = mocker.AsyncMock(

View File

@@ -34,20 +34,6 @@ class StorageUploadError(MediaUploadError):
pass
class VirusDetectedError(MediaUploadError):
"""Raised when a virus is detected in uploaded file"""
def __init__(self, threat_name: str, message: str | None = None):
self.threat_name = threat_name
super().__init__(message or f"Virus detected: {threat_name}")
class VirusScanError(MediaUploadError):
"""Raised when virus scanning fails"""
pass
class StoreError(Exception):
"""Base exception for store-related errors"""

View File

@@ -8,7 +8,6 @@ from google.cloud import storage
import backend.server.v2.store.exceptions
from backend.util.exceptions import MissingConfigError
from backend.util.settings import Settings
from backend.util.virus_scanner import scan_content_safe
logger = logging.getLogger(__name__)
@@ -68,7 +67,7 @@ async def upload_media(
# Validate file signature/magic bytes
if file.content_type in ALLOWED_IMAGE_TYPES:
# Check image file signatures
if content.startswith(b"\xff\xd8\xff"): # JPEG
if content.startswith(b"\xFF\xD8\xFF"): # JPEG
if file.content_type != "image/jpeg":
raise backend.server.v2.store.exceptions.InvalidFileTypeError(
"File signature does not match content type"
@@ -176,7 +175,6 @@ async def upload_media(
blob.content_type = content_type
file_bytes = await file.read()
await scan_content_safe(file_bytes, filename=unique_filename)
blob.upload_from_string(file_bytes, content_type=content_type)
public_url = blob.public_url

View File

@@ -12,7 +12,6 @@ from autogpt_libs.auth.depends import auth_middleware, get_user_id
import backend.data.block
import backend.data.graph
import backend.server.v2.store.db
import backend.server.v2.store.exceptions
import backend.server.v2.store.image_gen
import backend.server.v2.store.media
import backend.server.v2.store.model
@@ -30,7 +29,6 @@ router = fastapi.APIRouter()
@router.get(
"/profile",
summary="Get user profile",
tags=["store", "private"],
response_model=backend.server.v2.store.model.ProfileDetails,
)
@@ -63,7 +61,6 @@ async def get_profile(
@router.post(
"/profile",
summary="Update user profile",
tags=["store", "private"],
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
response_model=backend.server.v2.store.model.CreatorDetails,
@@ -110,7 +107,6 @@ async def update_or_create_profile(
@router.get(
"/agents",
summary="List store agents",
tags=["store", "public"],
response_model=backend.server.v2.store.model.StoreAgentsResponse,
)
@@ -183,7 +179,6 @@ async def get_agents(
@router.get(
"/agents/{username}/{agent_name}",
summary="Get specific agent",
tags=["store", "public"],
response_model=backend.server.v2.store.model.StoreAgentDetails,
)
@@ -213,7 +208,6 @@ async def get_agent(username: str, agent_name: str):
@router.get(
"/graph/{store_listing_version_id}",
summary="Get agent graph",
tags=["store"],
)
async def get_graph_meta_by_store_listing_version_id(
@@ -238,7 +232,6 @@ async def get_graph_meta_by_store_listing_version_id(
@router.get(
"/agents/{store_listing_version_id}",
summary="Get agent by version",
tags=["store"],
response_model=backend.server.v2.store.model.StoreAgentDetails,
)
@@ -264,7 +257,6 @@ async def get_store_agent(
@router.post(
"/agents/{username}/{agent_name}/review",
summary="Create agent review",
tags=["store"],
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
response_model=backend.server.v2.store.model.StoreReview,
@@ -316,7 +308,6 @@ async def create_review(
@router.get(
"/creators",
summary="List store creators",
tags=["store", "public"],
response_model=backend.server.v2.store.model.CreatorsResponse,
)
@@ -368,7 +359,6 @@ async def get_creators(
@router.get(
"/creator/{username}",
summary="Get creator details",
tags=["store", "public"],
response_model=backend.server.v2.store.model.CreatorDetails,
)
@@ -400,7 +390,6 @@ async def get_creator(
############################################
@router.get(
"/myagents",
summary="Get my agents",
tags=["store", "private"],
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
response_model=backend.server.v2.store.model.MyAgentsResponse,
@@ -423,7 +412,6 @@ async def get_my_agents(
@router.delete(
"/submissions/{submission_id}",
summary="Delete store submission",
tags=["store", "private"],
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
response_model=bool,
@@ -460,7 +448,6 @@ async def delete_submission(
@router.get(
"/submissions",
summary="List my submissions",
tags=["store", "private"],
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
response_model=backend.server.v2.store.model.StoreSubmissionsResponse,
@@ -514,7 +501,6 @@ async def get_submissions(
@router.post(
"/submissions",
summary="Create store submission",
tags=["store", "private"],
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
response_model=backend.server.v2.store.model.StoreSubmission,
@@ -562,7 +548,6 @@ async def create_submission(
@router.post(
"/submissions/media",
summary="Upload submission media",
tags=["store", "private"],
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
@@ -590,25 +575,6 @@ async def upload_submission_media(
user_id=user_id, file=file
)
return media_url
except backend.server.v2.store.exceptions.VirusDetectedError as e:
logger.warning(f"Virus detected in uploaded file: {e.threat_name}")
return fastapi.responses.JSONResponse(
status_code=400,
content={
"detail": f"File rejected due to virus detection: {e.threat_name}",
"error_type": "virus_detected",
"threat_name": e.threat_name,
},
)
except backend.server.v2.store.exceptions.VirusScanError as e:
logger.error(f"Virus scanning failed: {str(e)}")
return fastapi.responses.JSONResponse(
status_code=503,
content={
"detail": "Virus scanning service unavailable. Please try again later.",
"error_type": "virus_scan_failed",
},
)
except Exception:
logger.exception("Exception occurred whilst uploading submission media")
return fastapi.responses.JSONResponse(
@@ -619,7 +585,6 @@ async def upload_submission_media(
@router.post(
"/submissions/generate_image",
summary="Generate submission image",
tags=["store", "private"],
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
@@ -681,7 +646,6 @@ async def generate_image(
@router.get(
"/download/agents/{store_listing_version_id}",
summary="Download agent file",
tags=["store", "public"],
)
async def download_agent_file(

View File

@@ -13,9 +13,7 @@ router = APIRouter()
settings = Settings()
@router.post(
"/verify", response_model=TurnstileVerifyResponse, summary="Verify Turnstile Token"
)
@router.post("/verify", response_model=TurnstileVerifyResponse)
async def verify_turnstile_token(
request: TurnstileVerifyRequest,
) -> TurnstileVerifyResponse:

View File

@@ -2,17 +2,7 @@ import functools
import logging
import os
import time
from typing import (
Any,
Awaitable,
Callable,
Coroutine,
Literal,
ParamSpec,
Tuple,
TypeVar,
overload,
)
from typing import Any, Awaitable, Callable, Coroutine, ParamSpec, Tuple, TypeVar
from pydantic import BaseModel
@@ -82,115 +72,37 @@ def async_time_measured(
return async_wrapper
@overload
def error_logged(
*, swallow: Literal[True]
) -> Callable[[Callable[P, T]], Callable[P, T | None]]: ...
@overload
def error_logged(
*, swallow: Literal[False]
) -> Callable[[Callable[P, T]], Callable[P, T]]: ...
@overload
def error_logged() -> Callable[[Callable[P, T]], Callable[P, T | None]]: ...
def error_logged(
*, swallow: bool = True
) -> (
Callable[[Callable[P, T]], Callable[P, T | None]]
| Callable[[Callable[P, T]], Callable[P, T]]
):
def error_logged(func: Callable[P, T]) -> Callable[P, T | None]:
"""
Decorator to log any exceptions raised by a function, with optional suppression.
Args:
swallow: Whether to suppress the exception (True) or re-raise it (False). Default is True.
Usage:
@error_logged() # Default behavior (swallow errors)
@error_logged(swallow=False) # Log and re-raise errors
Decorator to suppress and log any exceptions raised by a function.
"""
def decorator(f: Callable[P, T]) -> Callable[P, T | None]:
@functools.wraps(f)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T | None:
try:
return f(*args, **kwargs)
except Exception as e:
logger.exception(
f"Error when calling function {f.__name__} with arguments {args} {kwargs}: {e}"
)
if not swallow:
raise
return None
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T | None:
try:
return func(*args, **kwargs)
except Exception as e:
logger.exception(
f"Error when calling function {func.__name__} with arguments {args} {kwargs}: {e}"
)
return wrapper
return decorator
return wrapper
@overload
def async_error_logged(
*, swallow: Literal[True]
) -> Callable[
[Callable[P, Coroutine[Any, Any, T]]], Callable[P, Coroutine[Any, Any, T | None]]
]: ...
@overload
def async_error_logged(
*, swallow: Literal[False]
) -> Callable[
[Callable[P, Coroutine[Any, Any, T]]], Callable[P, Coroutine[Any, Any, T]]
]: ...
@overload
def async_error_logged() -> Callable[
[Callable[P, Coroutine[Any, Any, T]]],
Callable[P, Coroutine[Any, Any, T | None]],
]: ...
def async_error_logged(*, swallow: bool = True) -> (
Callable[
[Callable[P, Coroutine[Any, Any, T]]],
Callable[P, Coroutine[Any, Any, T | None]],
]
| Callable[
[Callable[P, Coroutine[Any, Any, T]]], Callable[P, Coroutine[Any, Any, T]]
]
):
func: Callable[P, Coroutine[Any, Any, T]],
) -> Callable[P, Coroutine[Any, Any, T | None]]:
"""
Decorator to log any exceptions raised by an async function, with optional suppression.
Args:
swallow: Whether to suppress the exception (True) or re-raise it (False). Default is True.
Usage:
@async_error_logged() # Default behavior (swallow errors)
@async_error_logged(swallow=False) # Log and re-raise errors
Decorator to suppress and log any exceptions raised by an async function.
"""
def decorator(
f: Callable[P, Coroutine[Any, Any, T]]
) -> Callable[P, Coroutine[Any, Any, T | None]]:
@functools.wraps(f)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T | None:
try:
return await f(*args, **kwargs)
except Exception as e:
logger.exception(
f"Error when calling async function {f.__name__} with arguments {args} {kwargs}: {e}"
)
if not swallow:
raise
return None
@functools.wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T | None:
try:
return await func(*args, **kwargs)
except Exception as e:
logger.exception(
f"Error when calling async function {func.__name__} with arguments {args} {kwargs}: {e}"
)
return wrapper
return decorator
return wrapper

View File

@@ -1,74 +0,0 @@
import time
import pytest
from backend.util.decorator import async_error_logged, error_logged, time_measured
@time_measured
def example_function(a: int, b: int, c: int) -> int:
time.sleep(0.5)
return a + b + c
@error_logged(swallow=True)
def example_function_with_error_swallowed(a: int, b: int, c: int) -> int:
raise ValueError("This error should be swallowed")
@error_logged(swallow=False)
def example_function_with_error_not_swallowed(a: int, b: int, c: int) -> int:
raise ValueError("This error should NOT be swallowed")
@async_error_logged(swallow=True)
async def async_function_with_error_swallowed() -> int:
raise ValueError("This async error should be swallowed")
@async_error_logged(swallow=False)
async def async_function_with_error_not_swallowed() -> int:
raise ValueError("This async error should NOT be swallowed")
def test_timer_decorator():
"""Test that the time_measured decorator correctly measures execution time."""
info, res = example_function(1, 2, 3)
assert info.cpu_time >= 0
assert info.wall_time >= 0.4
assert res == 6
def test_error_decorator_swallow_true():
"""Test that error_logged(swallow=True) logs and swallows errors."""
res = example_function_with_error_swallowed(1, 2, 3)
assert res is None
def test_error_decorator_swallow_false():
"""Test that error_logged(swallow=False) logs errors but re-raises them."""
with pytest.raises(ValueError, match="This error should NOT be swallowed"):
example_function_with_error_not_swallowed(1, 2, 3)
def test_async_error_decorator_swallow_true():
"""Test that async_error_logged(swallow=True) logs and swallows errors."""
import asyncio
async def run_test():
res = await async_function_with_error_swallowed()
return res
res = asyncio.run(run_test())
assert res is None
def test_async_error_decorator_swallow_false():
"""Test that async_error_logged(swallow=False) logs errors but re-raises them."""
import asyncio
async def run_test():
await async_function_with_error_not_swallowed()
with pytest.raises(ValueError, match="This async error should NOT be swallowed"):
asyncio.run(run_test())

View File

@@ -10,10 +10,6 @@ class NeedConfirmation(Exception):
"""The user must explicitly confirm that they want to proceed"""
class NotAuthorizedError(ValueError):
"""The user is not authorized to perform the requested operation"""
class InsufficientBalanceError(ValueError):
user_id: str
message: str

View File

@@ -9,7 +9,6 @@ from urllib.parse import urlparse
from backend.util.request import Requests
from backend.util.type import MediaFileType
from backend.util.virus_scanner import scan_content_safe
TEMP_DIR = Path(tempfile.gettempdir()).resolve()
@@ -106,11 +105,7 @@ async def store_media_file(
extension = _extension_from_mime(mime_type)
filename = f"{uuid.uuid4()}{extension}"
target_path = _ensure_inside_base(base_path / filename, base_path)
content = base64.b64decode(b64_content)
# Virus scan the base64 content before writing
await scan_content_safe(content, filename=filename)
target_path.write_bytes(content)
target_path.write_bytes(base64.b64decode(b64_content))
elif file.startswith(("http://", "https://")):
# URL
@@ -120,9 +115,6 @@ async def store_media_file(
# Download and save
resp = await Requests().get(file)
# Virus scan the downloaded content before writing
await scan_content_safe(resp.content, filename=filename)
target_path.write_bytes(resp.content)
else:

View File

@@ -14,37 +14,8 @@ def to_dict(data) -> dict:
return jsonable_encoder(data)
def dumps(data: Any, *args: Any, **kwargs: Any) -> str:
"""
Serialize data to JSON string with automatic conversion of Pydantic models and complex types.
This function converts the input data to a JSON-serializable format using FastAPI's
jsonable_encoder before dumping to JSON. It handles Pydantic models, complex types,
and ensures proper serialization.
Parameters
----------
data : Any
The data to serialize. Can be any type including Pydantic models, dicts, lists, etc.
*args : Any
Additional positional arguments passed to json.dumps()
**kwargs : Any
Additional keyword arguments passed to json.dumps() (e.g., indent, separators)
Returns
-------
str
JSON string representation of the data
Examples
--------
>>> dumps({"name": "Alice", "age": 30})
'{"name": "Alice", "age": 30}'
>>> dumps(pydantic_model_instance, indent=2)
'{\n "field1": "value1",\n "field2": "value2"\n}'
"""
return json.dumps(to_dict(data), *args, **kwargs)
def dumps(data) -> str:
return json.dumps(to_dict(data))
T = TypeVar("T")

View File

@@ -1,4 +1,4 @@
import logging
from logging import Logger
from backend.util.settings import AppEnvironment, BehaveAs, Settings
@@ -6,6 +6,8 @@ settings = Settings()
def configure_logging():
import logging
import autogpt_libs.logging.config
if (
@@ -23,7 +25,7 @@ def configure_logging():
class TruncatedLogger:
def __init__(
self,
logger: logging.Logger,
logger: Logger,
prefix: str = "",
metadata: dict | None = None,
max_length: int = 1000,
@@ -63,13 +65,3 @@ class TruncatedLogger:
if len(text) > self.max_length:
text = text[: self.max_length] + "..."
return text
class PrefixFilter(logging.Filter):
def __init__(self, prefix: str):
super().__init__()
self.prefix = prefix
def filter(self, record):
record.msg = f"{self.prefix} {record.msg}"
return True

View File

@@ -1,206 +0,0 @@
from copy import deepcopy
from typing import Any
from tiktoken import encoding_for_model
from backend.util import json
# ---------------------------------------------------------------------------#
# INTERNAL UTILITIES #
# ---------------------------------------------------------------------------#
def _tok_len(text: str, enc) -> int:
"""True token length of *text* in tokenizer *enc* (no wrapper cost)."""
return len(enc.encode(str(text)))
def _msg_tokens(msg: dict, enc) -> int:
"""
OpenAI counts ≈3 wrapper tokens per chat message, plus 1 if "name"
is present, plus the tokenised content length.
"""
WRAPPER = 3 + (1 if "name" in msg else 0)
return WRAPPER + _tok_len(msg.get("content") or "", enc)
def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
"""
Return *text* shortened to ≈max_tok tokens by keeping the head & tail
and inserting an ellipsis token in the middle.
"""
ids = enc.encode(str(text))
if len(ids) <= max_tok:
return text # nothing to do
# Split the allowance between the two ends:
head = max_tok // 2 - 1 # -1 for the ellipsis
tail = max_tok - head - 1
mid = enc.encode("")
return enc.decode(ids[:head] + mid + ids[-tail:])
# ---------------------------------------------------------------------------#
# PUBLIC API #
# ---------------------------------------------------------------------------#
def compress_prompt(
messages: list[dict],
target_tokens: int,
*,
model: str = "gpt-4o",
reserve: int = 2_048,
start_cap: int = 8_192,
floor_cap: int = 128,
lossy_ok: bool = True,
) -> list[dict]:
"""
Shrink *messages* so that::
token_count(prompt) + reserve ≤ target_tokens
Strategy
--------
1. **Token-aware truncation** progressively halve a per-message cap
(`start_cap`, `start_cap/2`, … `floor_cap`) and apply it to the
*content* of every message except the first and last. Tool shells
are included: we keep the envelope but shorten huge payloads.
2. **Middle-out deletion** if still over the limit, delete whole
messages working outward from the centre, **skipping** any message
that contains ``tool_calls`` or has ``role == "tool"``.
3. **Last-chance trim** if still too big, truncate the *first* and
*last* message bodies down to `floor_cap` tokens.
4. If the prompt is *still* too large:
• raise ``ValueError`` when ``lossy_ok == False`` (default)
• return the partially-trimmed prompt when ``lossy_ok == True``
Parameters
----------
messages Complete chat history (will be deep-copied).
model Model name; passed to tiktoken to pick the right
tokenizer (gpt-4o → 'o200k_base', others fallback).
target_tokens Hard ceiling for prompt size **excluding** the model's
forthcoming answer.
reserve How many tokens you want to leave available for that
answer (`max_tokens` in your subsequent completion call).
start_cap Initial per-message truncation ceiling (tokens).
floor_cap Lowest cap we'll accept before moving to deletions.
lossy_ok If *True* return best-effort prompt instead of raising
after all trim passes have been exhausted.
Returns
-------
list[dict] A *new* messages list that abides by the rules above.
"""
enc = encoding_for_model(model) # best-match tokenizer
msgs = deepcopy(messages) # never mutate caller
def total_tokens() -> int:
"""Current size of *msgs* in tokens."""
return sum(_msg_tokens(m, enc) for m in msgs)
original_token_count = total_tokens()
if original_token_count + reserve <= target_tokens:
return msgs
# ---- STEP 0 : normalise content --------------------------------------
# Convert non-string payloads to strings so token counting is coherent.
for m in msgs[1:-1]: # keep the first & last intact
if not isinstance(m.get("content"), str) and m.get("content") is not None:
# Reasonable 20k-char ceiling prevents pathological blobs
content_str = json.dumps(m["content"], separators=(",", ":"))
if len(content_str) > 20_000:
content_str = _truncate_middle_tokens(content_str, enc, 20_000)
m["content"] = content_str
# ---- STEP 1 : token-aware truncation ---------------------------------
cap = start_cap
while total_tokens() + reserve > target_tokens and cap >= floor_cap:
for m in msgs[1:-1]: # keep first & last intact
if _tok_len(m.get("content") or "", enc) > cap:
m["content"] = _truncate_middle_tokens(m["content"], enc, cap)
cap //= 2 # tighten the screw
# ---- STEP 2 : middle-out deletion -----------------------------------
while total_tokens() + reserve > target_tokens and len(msgs) > 2:
centre = len(msgs) // 2
# Build a symmetrical centre-out index walk: centre, centre+1, centre-1, ...
order = [centre] + [
i
for pair in zip(range(centre + 1, len(msgs) - 1), range(centre - 1, 0, -1))
for i in pair
]
removed = False
for i in order:
msg = msgs[i]
if "tool_calls" in msg or msg.get("role") == "tool":
continue # protect tool shells
del msgs[i]
removed = True
break
if not removed: # nothing more we can drop
break
# ---- STEP 3 : final safety-net trim on first & last ------------------
cap = start_cap
while total_tokens() + reserve > target_tokens and cap >= floor_cap:
for idx in (0, -1): # first and last
text = msgs[idx].get("content") or ""
if _tok_len(text, enc) > cap:
msgs[idx]["content"] = _truncate_middle_tokens(text, enc, cap)
cap //= 2 # tighten the screw
# ---- STEP 4 : success or fail-gracefully -----------------------------
if total_tokens() + reserve > target_tokens and not lossy_ok:
raise ValueError(
"compress_prompt: prompt still exceeds budget "
f"({total_tokens() + reserve} > {target_tokens})."
)
return msgs
def estimate_token_count(
messages: list[dict],
*,
model: str = "gpt-4o",
) -> int:
"""
Return the true token count of *messages* when encoded for *model*.
Parameters
----------
messages Complete chat history.
model Model name; passed to tiktoken to pick the right
tokenizer (gpt-4o → 'o200k_base', others fallback).
Returns
-------
int Token count.
"""
enc = encoding_for_model(model) # best-match tokenizer
return sum(_msg_tokens(m, enc) for m in messages)
def estimate_token_count_str(
text: Any,
*,
model: str = "gpt-4o",
) -> int:
"""
Return the true token count of *text* when encoded for *model*.
Parameters
----------
text Input text.
model Model name; passed to tiktoken to pick the right
tokenizer (gpt-4o → 'o200k_base', others fallback).
Returns
-------
int Token count.
"""
enc = encoding_for_model(model) # best-match tokenizer
text = json.dumps(text) if not isinstance(text, str) else text
return _tok_len(text, enc)

View File

@@ -353,10 +353,6 @@ class Requests:
max_redirects: int = 10,
**kwargs,
) -> Response:
# Convert auth tuple to aiohttp.BasicAuth if necessary
if "auth" in kwargs and isinstance(kwargs["auth"], tuple):
kwargs["auth"] = aiohttp.BasicAuth(*kwargs["auth"])
if files is not None:
if json is not None:
raise ValueError(
@@ -434,13 +430,7 @@ class Requests:
) as response:
if self.raise_for_status:
try:
response.raise_for_status()
except ClientResponseError as e:
body = await response.read()
raise Exception(
f"HTTP {response.status} Error: {response.reason}, Body: {body.decode(errors='replace')}"
) from e
response.raise_for_status()
# If allowed and a redirect is received, follow the redirect manually
if allow_redirects and response.status in (301, 302, 303, 307, 308):

View File

@@ -31,7 +31,7 @@ from tenacity import (
wait_exponential_jitter,
)
import backend.util.exceptions as exceptions
from backend.util.exceptions import InsufficientBalanceError
from backend.util.json import to_dict
from backend.util.metrics import sentry_init
from backend.util.process import AppProcess, get_service_name
@@ -106,13 +106,7 @@ EXCEPTION_MAPPING = {
ValueError,
TimeoutError,
ConnectionError,
*[
ErrorType
for _, ErrorType in inspect.getmembers(exceptions)
if inspect.isclass(ErrorType)
and issubclass(ErrorType, Exception)
and ErrorType.__module__ == exceptions.__name__
],
InsufficientBalanceError,
]
}

View File

@@ -238,31 +238,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="The Discord channel for the platform",
)
clamav_service_host: str = Field(
default="localhost",
description="The host for the ClamAV daemon",
)
clamav_service_port: int = Field(
default=3310,
description="The port for the ClamAV daemon",
)
clamav_service_timeout: int = Field(
default=60,
description="The timeout in seconds for the ClamAV daemon",
)
clamav_service_enabled: bool = Field(
default=True,
description="Whether virus scanning is enabled or not",
)
clamav_max_concurrency: int = Field(
default=10,
description="The maximum number of concurrent scans to perform",
)
clamav_mark_failed_scans_as_clean: bool = Field(
default=False,
description="Whether to mark failed scans as clean or not",
)
@field_validator("platform_base_url", "frontend_base_url")
@classmethod
def validate_platform_base_url(cls, v: str, info: ValidationInfo) -> str:

View File

@@ -1,6 +1,5 @@
import json
import types
from typing import Any, Type, TypeVar, Union, cast, get_args, get_origin, overload
from typing import Any, Type, TypeVar, cast, get_args, get_origin
from prisma import Json as PrismaJson
@@ -105,37 +104,9 @@ def __convert_bool(value: Any) -> bool:
return bool(value)
def _try_convert(value: Any, target_type: Any, raise_on_mismatch: bool) -> Any:
def _try_convert(value: Any, target_type: Type, raise_on_mismatch: bool) -> Any:
origin = get_origin(target_type)
args = get_args(target_type)
# Handle Union types (including Optional which is Union[T, None])
if origin is Union or origin is types.UnionType:
# Handle None values for Optional types
if value is None:
if type(None) in args:
return None
elif raise_on_mismatch:
raise TypeError(f"Value {value} is not of expected type {target_type}")
else:
return value
# Try to convert to each type in the union, excluding None
non_none_types = [arg for arg in args if arg is not type(None)]
# Try each type in the union, using the original raise_on_mismatch behavior
for arg_type in non_none_types:
try:
return _try_convert(value, arg_type, raise_on_mismatch)
except (TypeError, ValueError, ConversionError):
continue
# If no conversion succeeded
if raise_on_mismatch:
raise TypeError(f"Value {value} is not of expected type {target_type}")
else:
return value
if origin is None:
origin = target_type
if origin not in [list, dict, tuple, str, set, int, float, bool]:
@@ -218,19 +189,11 @@ def type_match(value: Any, target_type: Type[T]) -> T:
return cast(T, _try_convert(value, target_type, raise_on_mismatch=True))
@overload
def convert(value: Any, target_type: Type[T]) -> T: ...
@overload
def convert(value: Any, target_type: Any) -> Any: ...
def convert(value: Any, target_type: Any) -> Any:
def convert(value: Any, target_type: Type[T]) -> T:
try:
if isinstance(value, PrismaJson):
value = value.data
return _try_convert(value, target_type, raise_on_mismatch=False)
return cast(T, _try_convert(value, target_type, raise_on_mismatch=False))
except Exception as e:
raise ConversionError(f"Failed to convert {value} to {target_type}") from e
@@ -240,7 +203,6 @@ class FormattedStringType(str):
@classmethod
def __get_pydantic_core_schema__(cls, source_type, handler):
_ = source_type # unused parameter required by pydantic
return handler(str)
@classmethod

View File

@@ -1,209 +0,0 @@
import asyncio
import io
import logging
import time
from typing import Optional, Tuple
import aioclamd
from pydantic import BaseModel
from pydantic_settings import BaseSettings
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
settings = Settings()
class VirusScanResult(BaseModel):
is_clean: bool
scan_time_ms: int
file_size: int
threat_name: Optional[str] = None
class VirusScannerSettings(BaseSettings):
# Tunables for the scanner layer (NOT the ClamAV daemon).
clamav_service_host: str = "localhost"
clamav_service_port: int = 3310
clamav_service_timeout: int = 60
clamav_service_enabled: bool = True
# If the service is disabled, all files are considered clean.
mark_failed_scans_as_clean: bool = False
# Client-side protective limits
max_scan_size: int = 2 * 1024 * 1024 * 1024 # 2 GB guard-rail in memory
min_chunk_size: int = 128 * 1024 # 128 KB hard floor
max_retries: int = 8 # halve ≤ max_retries times
# Concurrency throttle toward the ClamAV daemon. Do *NOT* simply turn this
# up to the number of CPU cores; keep it ≤ (MaxThreads / pods) 1.
max_concurrency: int = 5
class VirusScannerService:
"""Fully-async ClamAV wrapper using **aioclamd**.
• Reuses a single `ClamdAsyncClient` connection (aioclamd keeps the socket open).
• Throttles concurrent `INSTREAM` calls with an `asyncio.Semaphore` so we don't exhaust daemon worker threads or file descriptors.
• Falls back to progressively smaller chunk sizes when the daemon rejects a stream as too large.
"""
def __init__(self, settings: VirusScannerSettings) -> None:
self.settings = settings
self._client = aioclamd.ClamdAsyncClient(
host=settings.clamav_service_host,
port=settings.clamav_service_port,
timeout=settings.clamav_service_timeout,
)
self._sem = asyncio.Semaphore(settings.max_concurrency)
# ------------------------------------------------------------------ #
# Helpers
# ------------------------------------------------------------------ #
@staticmethod
def _parse_raw(raw: Optional[dict]) -> Tuple[bool, Optional[str]]:
"""
Convert aioclamd output to (infected?, threat_name).
Returns (False, None) for clean.
"""
if not raw:
return False, None
status, threat = next(iter(raw.values()))
return status == "FOUND", threat
async def _instream(self, chunk: bytes) -> Tuple[bool, Optional[str]]:
"""Scan **one** chunk with concurrency control."""
async with self._sem:
try:
raw = await self._client.instream(io.BytesIO(chunk))
return self._parse_raw(raw)
except (BrokenPipeError, ConnectionResetError) as exc:
raise RuntimeError("size-limit") from exc
except Exception as exc:
if "INSTREAM size limit exceeded" in str(exc):
raise RuntimeError("size-limit") from exc
raise
# ------------------------------------------------------------------ #
# Public API
# ------------------------------------------------------------------ #
async def scan_file(
self, content: bytes, *, filename: str = "unknown"
) -> VirusScanResult:
"""
Scan `content`. Returns a result object or raises on infrastructure
failure (unreachable daemon, etc.).
The algorithm always tries whole-file first. If the daemon refuses
on size grounds, it falls back to chunked parallel scanning.
"""
if not self.settings.clamav_service_enabled:
logger.warning(f"Virus scanning disabled accepting {filename}")
return VirusScanResult(
is_clean=True, scan_time_ms=0, file_size=len(content)
)
if len(content) > self.settings.max_scan_size:
logger.warning(
f"File {filename} ({len(content)} bytes) exceeds client max scan size ({self.settings.max_scan_size}); Stopping virus scan"
)
return VirusScanResult(
is_clean=self.settings.mark_failed_scans_as_clean,
file_size=len(content),
scan_time_ms=0,
)
# Ensure daemon is reachable (small RTT check)
if not await self._client.ping():
raise RuntimeError("ClamAV service is unreachable")
start = time.monotonic()
chunk_size = len(content) # Start with full content length
for retry in range(self.settings.max_retries):
# For small files, don't check min_chunk_size limit
if chunk_size < self.settings.min_chunk_size and chunk_size < len(content):
break
logger.debug(
f"Scanning {filename} with chunk size: {chunk_size // 1_048_576} MB (retry {retry + 1}/{self.settings.max_retries})"
)
try:
tasks = [
asyncio.create_task(self._instream(content[o : o + chunk_size]))
for o in range(0, len(content), chunk_size)
]
for coro in asyncio.as_completed(tasks):
infected, threat = await coro
if infected:
for t in tasks:
if not t.done():
t.cancel()
return VirusScanResult(
is_clean=False,
threat_name=threat,
file_size=len(content),
scan_time_ms=int((time.monotonic() - start) * 1000),
)
# All chunks clean
return VirusScanResult(
is_clean=True,
file_size=len(content),
scan_time_ms=int((time.monotonic() - start) * 1000),
)
except RuntimeError as exc:
if str(exc) == "size-limit":
chunk_size //= 2
continue
logger.error(f"Cannot scan {filename}: {exc}")
raise
# Phase 3 give up but warn
logger.warning(
f"Unable to virus scan {filename} ({len(content)} bytes) even with minimum chunk size ({self.settings.min_chunk_size} bytes). Recommend manual review."
)
return VirusScanResult(
is_clean=self.settings.mark_failed_scans_as_clean,
file_size=len(content),
scan_time_ms=int((time.monotonic() - start) * 1000),
)
_scanner: Optional[VirusScannerService] = None
def get_virus_scanner() -> VirusScannerService:
global _scanner
if _scanner is None:
_settings = VirusScannerSettings(
clamav_service_host=settings.config.clamav_service_host,
clamav_service_port=settings.config.clamav_service_port,
clamav_service_enabled=settings.config.clamav_service_enabled,
max_concurrency=settings.config.clamav_max_concurrency,
mark_failed_scans_as_clean=settings.config.clamav_mark_failed_scans_as_clean,
)
_scanner = VirusScannerService(_settings)
return _scanner
async def scan_content_safe(content: bytes, *, filename: str = "unknown") -> None:
"""
Helper function to scan content and raise appropriate exceptions.
Raises:
VirusDetectedError: If virus is found
VirusScanError: If scanning fails
"""
from backend.server.v2.store.exceptions import VirusDetectedError, VirusScanError
try:
result = await get_virus_scanner().scan_file(content, filename=filename)
if not result.is_clean:
threat_name = result.threat_name or "Unknown threat"
logger.warning(f"Virus detected in file {filename}: {threat_name}")
raise VirusDetectedError(
threat_name, f"File rejected due to virus detection: {threat_name}"
)
logger.info(f"File {filename} passed virus scan in {result.scan_time_ms}ms")
except VirusDetectedError:
raise
except Exception as e:
logger.error(f"Virus scanning failed for {filename}: {str(e)}")
raise VirusScanError(f"Virus scanning failed: {str(e)}") from e

View File

@@ -1,253 +0,0 @@
import asyncio
from unittest.mock import AsyncMock, Mock, patch
import pytest
from backend.server.v2.store.exceptions import VirusDetectedError, VirusScanError
from backend.util.virus_scanner import (
VirusScannerService,
VirusScannerSettings,
VirusScanResult,
get_virus_scanner,
scan_content_safe,
)
class TestVirusScannerService:
@pytest.fixture
def scanner_settings(self):
return VirusScannerSettings(
clamav_service_host="localhost",
clamav_service_port=3310,
clamav_service_enabled=True,
max_scan_size=10 * 1024 * 1024, # 10MB for testing
mark_failed_scans_as_clean=False, # For testing, failed scans should be clean
)
@pytest.fixture
def scanner(self, scanner_settings):
return VirusScannerService(scanner_settings)
@pytest.fixture
def disabled_scanner(self):
settings = VirusScannerSettings(clamav_service_enabled=False)
return VirusScannerService(settings)
def test_scanner_initialization(self, scanner_settings):
scanner = VirusScannerService(scanner_settings)
assert scanner.settings.clamav_service_host == "localhost"
assert scanner.settings.clamav_service_port == 3310
assert scanner.settings.clamav_service_enabled is True
@pytest.mark.asyncio
async def test_scan_disabled_returns_clean(self, disabled_scanner):
content = b"test file content"
result = await disabled_scanner.scan_file(content, filename="test.txt")
assert result.is_clean is True
assert result.threat_name is None
assert result.file_size == len(content)
assert result.scan_time_ms == 0
@pytest.mark.asyncio
async def test_scan_file_too_large(self, scanner):
# Create content larger than max_scan_size
large_content = b"x" * (scanner.settings.max_scan_size + 1)
# Large files behavior depends on mark_failed_scans_as_clean setting
result = await scanner.scan_file(large_content, filename="large_file.txt")
assert result.is_clean == scanner.settings.mark_failed_scans_as_clean
assert result.file_size == len(large_content)
assert result.scan_time_ms == 0
@pytest.mark.asyncio
async def test_scan_file_too_large_both_configurations(self):
"""Test large file handling with both mark_failed_scans_as_clean configurations"""
large_content = b"x" * (10 * 1024 * 1024 + 1) # Larger than 10MB
# Test with mark_failed_scans_as_clean=True
settings_clean = VirusScannerSettings(
max_scan_size=10 * 1024 * 1024, mark_failed_scans_as_clean=True
)
scanner_clean = VirusScannerService(settings_clean)
result_clean = await scanner_clean.scan_file(
large_content, filename="large_file.txt"
)
assert result_clean.is_clean is True
# Test with mark_failed_scans_as_clean=False
settings_dirty = VirusScannerSettings(
max_scan_size=10 * 1024 * 1024, mark_failed_scans_as_clean=False
)
scanner_dirty = VirusScannerService(settings_dirty)
result_dirty = await scanner_dirty.scan_file(
large_content, filename="large_file.txt"
)
assert result_dirty.is_clean is False
# Note: ping method was removed from current implementation
@pytest.mark.asyncio
async def test_scan_clean_file(self, scanner):
async def mock_instream(_):
await asyncio.sleep(0.001) # Small delay to ensure timing > 0
return None # No virus detected
mock_client = Mock()
mock_client.ping = AsyncMock(return_value=True)
mock_client.instream = AsyncMock(side_effect=mock_instream)
# Replace the client instance that was created in the constructor
scanner._client = mock_client
content = b"clean file content"
result = await scanner.scan_file(content, filename="clean.txt")
assert result.is_clean is True
assert result.threat_name is None
assert result.file_size == len(content)
assert result.scan_time_ms > 0
@pytest.mark.asyncio
async def test_scan_infected_file(self, scanner):
async def mock_instream(_):
await asyncio.sleep(0.001) # Small delay to ensure timing > 0
return {"stream": ("FOUND", "Win.Test.EICAR_HDB-1")}
mock_client = Mock()
mock_client.ping = AsyncMock(return_value=True)
mock_client.instream = AsyncMock(side_effect=mock_instream)
# Replace the client instance that was created in the constructor
scanner._client = mock_client
content = b"infected file content"
result = await scanner.scan_file(content, filename="infected.txt")
assert result.is_clean is False
assert result.threat_name == "Win.Test.EICAR_HDB-1"
assert result.file_size == len(content)
assert result.scan_time_ms > 0
@pytest.mark.asyncio
async def test_scan_clamav_unavailable_fail_safe(self, scanner):
mock_client = Mock()
mock_client.ping = AsyncMock(return_value=False)
# Replace the client instance that was created in the constructor
scanner._client = mock_client
content = b"test content"
with pytest.raises(RuntimeError, match="ClamAV service is unreachable"):
await scanner.scan_file(content, filename="test.txt")
@pytest.mark.asyncio
async def test_scan_error_fail_safe(self, scanner):
mock_client = Mock()
mock_client.ping = AsyncMock(return_value=True)
mock_client.instream = AsyncMock(side_effect=Exception("Scanning error"))
# Replace the client instance that was created in the constructor
scanner._client = mock_client
content = b"test content"
with pytest.raises(Exception, match="Scanning error"):
await scanner.scan_file(content, filename="test.txt")
# Note: scan_file_method and scan_upload_file tests removed as these APIs don't exist in current implementation
def test_get_virus_scanner_singleton(self):
scanner1 = get_virus_scanner()
scanner2 = get_virus_scanner()
# Should return the same instance
assert scanner1 is scanner2
# Note: client_reuse test removed as _get_client method doesn't exist in current implementation
def test_scan_result_model(self):
# Test VirusScanResult model
result = VirusScanResult(
is_clean=False, threat_name="Test.Virus", scan_time_ms=150, file_size=1024
)
assert result.is_clean is False
assert result.threat_name == "Test.Virus"
assert result.scan_time_ms == 150
assert result.file_size == 1024
@pytest.mark.asyncio
async def test_concurrent_scans(self, scanner):
async def mock_instream(_):
await asyncio.sleep(0.001) # Small delay to ensure timing > 0
return None
mock_client = Mock()
mock_client.ping = AsyncMock(return_value=True)
mock_client.instream = AsyncMock(side_effect=mock_instream)
# Replace the client instance that was created in the constructor
scanner._client = mock_client
content1 = b"file1 content"
content2 = b"file2 content"
# Run concurrent scans
results = await asyncio.gather(
scanner.scan_file(content1, filename="file1.txt"),
scanner.scan_file(content2, filename="file2.txt"),
)
assert len(results) == 2
assert all(result.is_clean for result in results)
assert results[0].file_size == len(content1)
assert results[1].file_size == len(content2)
assert all(result.scan_time_ms > 0 for result in results)
class TestHelperFunctions:
"""Test the helper functions scan_content_safe"""
@pytest.mark.asyncio
async def test_scan_content_safe_clean(self):
"""Test scan_content_safe with clean content"""
with patch("backend.util.virus_scanner.get_virus_scanner") as mock_get_scanner:
mock_scanner = Mock()
mock_scanner.scan_file = AsyncMock()
mock_scanner.scan_file.return_value = Mock(
is_clean=True, threat_name=None, scan_time_ms=50, file_size=100
)
mock_get_scanner.return_value = mock_scanner
# Should not raise any exception
await scan_content_safe(b"clean content", filename="test.txt")
@pytest.mark.asyncio
async def test_scan_content_safe_infected(self):
"""Test scan_content_safe with infected content"""
with patch("backend.util.virus_scanner.get_virus_scanner") as mock_get_scanner:
mock_scanner = Mock()
mock_scanner.scan_file = AsyncMock()
mock_scanner.scan_file.return_value = Mock(
is_clean=False, threat_name="Test.Virus", scan_time_ms=50, file_size=100
)
mock_get_scanner.return_value = mock_scanner
with pytest.raises(VirusDetectedError) as exc_info:
await scan_content_safe(b"infected content", filename="virus.txt")
assert exc_info.value.threat_name == "Test.Virus"
@pytest.mark.asyncio
async def test_scan_content_safe_scan_error(self):
"""Test scan_content_safe when scanning fails"""
with patch("backend.util.virus_scanner.get_virus_scanner") as mock_get_scanner:
mock_scanner = Mock()
mock_scanner.scan_file = AsyncMock()
mock_scanner.scan_file.side_effect = Exception("Scan failed")
mock_get_scanner.return_value = mock_scanner
with pytest.raises(VirusScanError, match="Virus scanning failed"):
await scan_content_safe(b"test content", filename="test.txt")

View File

@@ -1,2 +0,0 @@
-- AlterTable
ALTER TABLE "AgentNodeExecutionInputOutput" ALTER COLUMN "data" DROP NOT NULL;

View File

@@ -1,5 +0,0 @@
-- Add webhookId column
ALTER TABLE "AgentPreset" ADD COLUMN "webhookId" TEXT;
-- Add AgentPreset<->IntegrationWebhook relation
ALTER TABLE "AgentPreset" ADD CONSTRAINT "AgentPreset_webhookId_fkey" FOREIGN KEY ("webhookId") REFERENCES "IntegrationWebhook"("id") ON DELETE SET NULL ON UPDATE CASCADE;

View File

@@ -1,11 +0,0 @@
-- CreateTable
CREATE TABLE "AgentNodeExecutionKeyValueData" (
"userId" TEXT NOT NULL,
"key" TEXT NOT NULL,
"agentNodeExecutionId" TEXT NOT NULL,
"data" JSONB,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3),
CONSTRAINT "AgentNodeExecutionKeyValueData_pkey" PRIMARY KEY ("userId","key")
);

View File

@@ -17,32 +17,20 @@ aiormq = ">=6.8,<6.9"
exceptiongroup = ">=1,<2"
yarl = "*"
[[package]]
name = "aioclamd"
version = "1.0.0"
description = "Asynchronous client for virus scanning with ClamAV"
optional = false
python-versions = ">=3.7,<4.0"
groups = ["main"]
files = [
{file = "aioclamd-1.0.0-py3-none-any.whl", hash = "sha256:4727da3953a4b38be4c2de1acb6b3bb3c94c1c171dcac780b80234ee6253f3d9"},
{file = "aioclamd-1.0.0.tar.gz", hash = "sha256:7b14e94e3a2285cc89e2f4d434e2a01f322d3cb95476ce2dda015a7980876047"},
]
[[package]]
name = "aiodns"
version = "3.5.0"
version = "3.4.0"
description = "Simple DNS resolver for asyncio"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "aiodns-3.5.0-py3-none-any.whl", hash = "sha256:6d0404f7d5215849233f6ee44854f2bb2481adf71b336b2279016ea5990ca5c5"},
{file = "aiodns-3.5.0.tar.gz", hash = "sha256:11264edbab51896ecf546c18eb0dd56dff0428c6aa6d2cd87e643e07300eb310"},
{file = "aiodns-3.4.0-py3-none-any.whl", hash = "sha256:4da2b25f7475343f3afbb363a2bfe46afa544f2b318acb9a945065e622f4ed24"},
{file = "aiodns-3.4.0.tar.gz", hash = "sha256:24b0ae58410530367f21234d0c848e4de52c1f16fbddc111726a4ab536ec1b2f"},
]
[package.dependencies]
pycares = ">=4.9.0"
pycares = ">=4.0.0"
[[package]]
name = "aiofiles"
@@ -222,14 +210,14 @@ files = [
[[package]]
name = "anthropic"
version = "0.57.1"
version = "0.51.0"
description = "The official Python library for the anthropic API"
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "anthropic-0.57.1-py3-none-any.whl", hash = "sha256:33afc1f395af207d07ff1bffc0a3d1caac53c371793792569c5d2f09283ea306"},
{file = "anthropic-0.57.1.tar.gz", hash = "sha256:7815dd92245a70d21f65f356f33fc80c5072eada87fb49437767ea2918b2c4b0"},
{file = "anthropic-0.51.0-py3-none-any.whl", hash = "sha256:b8b47d482c9aa1f81b923555cebb687c2730309a20d01be554730c8302e0f62a"},
{file = "anthropic-0.51.0.tar.gz", hash = "sha256:6f824451277992af079554430d5b2c8ff5bc059cc2c968cdc3f06824437da201"},
]
[package.dependencies]
@@ -242,7 +230,6 @@ sniffio = "*"
typing-extensions = ">=4.10,<5"
[package.extras]
aiohttp = ["aiohttp", "httpx-aiohttp (>=0.1.6)"]
bedrock = ["boto3 (>=1.28.57)", "botocore (>=1.31.57)"]
vertex = ["google-auth[requests] (>=2,<3)"]
@@ -1006,14 +993,14 @@ pgp = ["gpg"]
[[package]]
name = "e2b"
version = "1.5.4"
version = "1.5.0"
description = "E2B SDK that give agents cloud environments"
optional = false
python-versions = "<4.0,>=3.9"
groups = ["main"]
files = [
{file = "e2b-1.5.4-py3-none-any.whl", hash = "sha256:9c8d22f9203311dff890e037823596daaba3d793300238117f2efc5426888f2c"},
{file = "e2b-1.5.4.tar.gz", hash = "sha256:49f1c115d0198244beef5854d19cc857fda9382e205f137b98d3dae0e7e0b2d2"},
{file = "e2b-1.5.0-py3-none-any.whl", hash = "sha256:875a843d1d314a9945e24bfb78c9b1b5cac7e2ecb1e799664d827a26a0b2276a"},
{file = "e2b-1.5.0.tar.gz", hash = "sha256:905730eea5c07f271d073d4b5d2a9ef44c8ac04b9b146a99fa0235db77bf6854"},
]
[package.dependencies]
@@ -1027,19 +1014,19 @@ typing-extensions = ">=4.1.0"
[[package]]
name = "e2b-code-interpreter"
version = "1.5.2"
version = "1.5.0"
description = "E2B Code Interpreter - Stateful code execution"
optional = false
python-versions = "<4.0,>=3.9"
groups = ["main"]
files = [
{file = "e2b_code_interpreter-1.5.2-py3-none-any.whl", hash = "sha256:5c3188d8f25226b28fef4b255447cc6a4c36afb748bdd5180b45be486d5169f3"},
{file = "e2b_code_interpreter-1.5.2.tar.gz", hash = "sha256:3bd6ea70596290e85aaf0a2f19f28bf37a5e73d13086f5e6a0080bb591c5a547"},
{file = "e2b_code_interpreter-1.5.0-py3-none-any.whl", hash = "sha256:299f5641a3754264a07f8edc3cccb744d6b009f10dc9285789a9352e24989a9b"},
{file = "e2b_code_interpreter-1.5.0.tar.gz", hash = "sha256:cd6028b6f20c4231e88a002de86484b9d4a99ea588b5be183b9ec7189a0f3cf6"},
]
[package.dependencies]
attrs = ">=21.3.0"
e2b = ">=1.5.4,<2.0.0"
e2b = ">=1.4.0,<2.0.0"
httpx = ">=0.20.0,<1.0.0"
[[package]]
@@ -1110,14 +1097,14 @@ typing-extensions = "*"
[[package]]
name = "fastapi"
version = "0.115.14"
version = "0.115.12"
description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "fastapi-0.115.14-py3-none-any.whl", hash = "sha256:6c0c8bf9420bd58f565e585036d971872472b4f7d3f6c73b698e10cffdefb3ca"},
{file = "fastapi-0.115.14.tar.gz", hash = "sha256:b1de15cdc1c499a4da47914db35d0e4ef8f1ce62b624e94e0e5824421df99739"},
{file = "fastapi-0.115.12-py3-none-any.whl", hash = "sha256:e94613d6c05e27be7ffebdd6ea5f388112e5e430c8f7d6494a9d1d88d43e814d"},
{file = "fastapi-0.115.12.tar.gz", hash = "sha256:1e2c2a2646905f9e83d32f04a3f86aff4a286669c6c950ca95b5fd68c2602681"},
]
[package.dependencies]
@@ -1193,20 +1180,20 @@ packaging = ">=20"
[[package]]
name = "flake8"
version = "7.3.0"
version = "7.2.0"
description = "the modular source code checker: pep8 pyflakes and co"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "flake8-7.3.0-py2.py3-none-any.whl", hash = "sha256:b9696257b9ce8beb888cdbe31cf885c90d31928fe202be0889a7cdafad32f01e"},
{file = "flake8-7.3.0.tar.gz", hash = "sha256:fe044858146b9fc69b551a4b490d69cf960fcb78ad1edcb84e7fbb1b4a8e3872"},
{file = "flake8-7.2.0-py2.py3-none-any.whl", hash = "sha256:93b92ba5bdb60754a6da14fa3b93a9361fd00a59632ada61fd7b130436c40343"},
{file = "flake8-7.2.0.tar.gz", hash = "sha256:fa558ae3f6f7dbf2b4f22663e5343b6b6023620461f8d4ff2019ef4b5ee70426"},
]
[package.dependencies]
mccabe = ">=0.7.0,<0.8.0"
pycodestyle = ">=2.14.0,<2.15.0"
pyflakes = ">=3.4.0,<3.5.0"
pycodestyle = ">=2.13.0,<2.14.0"
pyflakes = ">=3.3.0,<3.4.0"
[[package]]
name = "frozenlist"
@@ -1357,14 +1344,14 @@ grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.0)"]
[[package]]
name = "google-api-python-client"
version = "2.176.0"
version = "2.170.0"
description = "Google API Client Library for Python"
optional = false
python-versions = ">=3.7"
groups = ["main"]
files = [
{file = "google_api_python_client-2.176.0-py3-none-any.whl", hash = "sha256:e22239797f1d085341e12cd924591fc65c56d08e0af02549d7606092e6296510"},
{file = "google_api_python_client-2.176.0.tar.gz", hash = "sha256:2b451cdd7fd10faeb5dd20f7d992f185e1e8f4124c35f2cdcc77c843139a4cf1"},
{file = "google_api_python_client-2.170.0-py3-none-any.whl", hash = "sha256:7bf518a0527ad23322f070fa69f4f24053170d5c766821dc970ff0571ec22748"},
{file = "google_api_python_client-2.170.0.tar.gz", hash = "sha256:75f3a1856f11418ea3723214e0abc59d9b217fd7ed43dcf743aab7f06ab9e2b1"},
]
[package.dependencies]
@@ -1517,27 +1504,27 @@ protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4
[[package]]
name = "google-cloud-storage"
version = "3.2.0"
version = "3.1.0"
description = "Google Cloud Storage API client library"
optional = false
python-versions = ">=3.7"
groups = ["main"]
files = [
{file = "google_cloud_storage-3.2.0-py3-none-any.whl", hash = "sha256:ff7a9a49666954a7c3d1598291220c72d3b9e49d9dfcf9dfaecb301fc4fb0b24"},
{file = "google_cloud_storage-3.2.0.tar.gz", hash = "sha256:decca843076036f45633198c125d1861ffbf47ebf5c0e3b98dcb9b2db155896c"},
{file = "google_cloud_storage-3.1.0-py2.py3-none-any.whl", hash = "sha256:eaf36966b68660a9633f03b067e4a10ce09f1377cae3ff9f2c699f69a81c66c6"},
{file = "google_cloud_storage-3.1.0.tar.gz", hash = "sha256:944273179897c7c8a07ee15f2e6466a02da0c7c4b9ecceac2a26017cb2972049"},
]
[package.dependencies]
google-api-core = ">=2.15.0,<3.0.0"
google-auth = ">=2.26.1,<3.0.0"
google-cloud-core = ">=2.4.2,<3.0.0"
google-crc32c = ">=1.1.3,<2.0.0"
google-resumable-media = ">=2.7.2,<3.0.0"
requests = ">=2.22.0,<3.0.0"
google-api-core = ">=2.15.0,<3.0.0dev"
google-auth = ">=2.26.1,<3.0dev"
google-cloud-core = ">=2.4.2,<3.0dev"
google-crc32c = ">=1.0,<2.0dev"
google-resumable-media = ">=2.7.2"
requests = ">=2.18.0,<3.0.0dev"
[package.extras]
protobuf = ["protobuf (>=3.20.2,<7.0.0)"]
tracing = ["opentelemetry-api (>=1.1.0,<2.0.0)"]
protobuf = ["protobuf (<6.0.0dev)"]
tracing = ["opentelemetry-api (>=1.1.0)"]
[[package]]
name = "google-crc32c"
@@ -1745,14 +1732,14 @@ test = ["objgraph", "psutil"]
[[package]]
name = "groq"
version = "0.29.0"
version = "0.24.0"
description = "The official Python library for the groq API"
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "groq-0.29.0-py3-none-any.whl", hash = "sha256:03515ec46be1ef1feef0cd9d876b6f30a39ee2742e76516153d84acd7c97f23a"},
{file = "groq-0.29.0.tar.gz", hash = "sha256:109dc4d696c05d44e4c2cd157652c4c6600c3e96f093f6e158facb5691e37847"},
{file = "groq-0.24.0-py3-none-any.whl", hash = "sha256:0020e6b0b2b267263c9eb7c318deef13c12f399c6525734200b11d777b00088e"},
{file = "groq-0.24.0.tar.gz", hash = "sha256:e821559de8a77fb81d2585b3faec80ff923d6d64fd52339b33f6c94997d6f7f5"},
]
[package.dependencies]
@@ -1763,9 +1750,6 @@ pydantic = ">=1.9.0,<3"
sniffio = "*"
typing-extensions = ">=4.10,<5"
[package.extras]
aiohttp = ["aiohttp", "httpx-aiohttp (>=0.1.6)"]
[[package]]
name = "grpc-google-iam-v1"
version = "0.14.2"
@@ -2552,14 +2536,14 @@ files = [
[[package]]
name = "mem0ai"
version = "0.1.114"
version = "0.1.102"
description = "Long-term memory for AI Agents"
optional = false
python-versions = "<4.0,>=3.9"
groups = ["main"]
files = [
{file = "mem0ai-0.1.114-py3-none-any.whl", hash = "sha256:dfb7f0079ee282f5d9782e220f6f09707bcf5e107925d1901dbca30d8dd83f9b"},
{file = "mem0ai-0.1.114.tar.gz", hash = "sha256:b27886132eaec78544e8b8b54f0b14a36728f3c99da54cb7cb417150e2fad7e1"},
{file = "mem0ai-0.1.102-py3-none-any.whl", hash = "sha256:1401ccfd2369e2182ce78abb61b817e739fe49508b5a8ad98abcd4f8ad4db0b4"},
{file = "mem0ai-0.1.102.tar.gz", hash = "sha256:7358dba4fbe954b9c3f33204c14df7babaf9067e2eb48241d89a32e6bc774988"},
]
[package.dependencies]
@@ -2572,11 +2556,8 @@ sqlalchemy = ">=2.0.31"
[package.extras]
dev = ["isort (>=5.13.2)", "pytest (>=8.2.2)", "ruff (>=0.6.5)"]
extras = ["boto3 (>=1.34.0)", "elasticsearch (>=8.0.0)", "langchain-community (>=0.0.0)", "langchain-memgraph (>=0.1.0)", "opensearch-py (>=2.0.0)", "sentence-transformers (>=5.0.0)"]
graph = ["langchain-aws (>=0.2.23)", "langchain-neo4j (>=0.4.0)", "neo4j (>=5.23.1)", "rank-bm25 (>=0.2.2)"]
llms = ["google-genai (>=1.0.0)", "google-generativeai (>=0.3.0)", "groq (>=0.3.0)", "litellm (>=0.1.0)", "ollama (>=0.1.0)", "together (>=0.2.10)", "vertexai (>=0.1.0)"]
graph = ["langchain-neo4j (>=0.4.0)", "neo4j (>=5.23.1)", "rank-bm25 (>=0.2.2)"]
test = ["pytest (>=8.2.2)", "pytest-asyncio (>=0.23.7)", "pytest-mock (>=3.14.0)"]
vector-stores = ["azure-search-documents (>=11.4.0b8)", "chromadb (>=0.4.24)", "faiss-cpu (>=1.7.4)", "pinecone (<=7.3.0)", "pinecone-text (>=0.10.0)", "pymochow (>=2.2.9)", "pymongo (>=4.13.2)", "upstash-vector (>=0.1.0)", "vecs (>=0.4.0)", "weaviate-client (>=4.4.0)"]
[[package]]
name = "more-itertools"
@@ -2915,14 +2896,14 @@ signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"]
[[package]]
name = "ollama"
version = "0.5.1"
version = "0.4.9"
description = "The official Python client for Ollama."
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "ollama-0.5.1-py3-none-any.whl", hash = "sha256:4c8839f35bc173c7057b1eb2cbe7f498c1a7e134eafc9192824c8aecb3617506"},
{file = "ollama-0.5.1.tar.gz", hash = "sha256:5a799e4dc4e7af638b11e3ae588ab17623ee019e496caaf4323efbaa8feeff93"},
{file = "ollama-0.4.9-py3-none-any.whl", hash = "sha256:18c8c85358c54d7f73d6a66cda495b0e3ba99fdb88f824ae470d740fbb211a50"},
{file = "ollama-0.4.9.tar.gz", hash = "sha256:5266d4d29b5089a01489872b8e8f980f018bccbdd1082b3903448af1d5615ce7"},
]
[package.dependencies]
@@ -2931,14 +2912,14 @@ pydantic = ">=2.9"
[[package]]
name = "openai"
version = "1.93.2"
version = "1.82.1"
description = "The official Python library for the openai API"
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "openai-1.93.2-py3-none-any.whl", hash = "sha256:5adbbebd48eae160e6d68efc4c0a4f7cb1318a44c62d9fc626cec229f418eab4"},
{file = "openai-1.93.2.tar.gz", hash = "sha256:4a7312b426b5e4c98b78dfa1148b5683371882de3ad3d5f7c8e0c74f3cc90778"},
{file = "openai-1.82.1-py3-none-any.whl", hash = "sha256:334eb5006edf59aa464c9e932b9d137468d810b2659e5daea9b3a8c39d052395"},
{file = "openai-1.82.1.tar.gz", hash = "sha256:ffc529680018e0417acac85f926f92aa0bbcbc26e82e2621087303c66bc7f95d"},
]
[package.dependencies]
@@ -2952,7 +2933,6 @@ tqdm = ">4"
typing-extensions = ">=4.11,<5"
[package.extras]
aiohttp = ["aiohttp", "httpx-aiohttp (>=0.1.6)"]
datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
realtime = ["websockets (>=13,<16)"]
voice-helpers = ["numpy (>=2.0.2)", "sounddevice (>=0.5.1)"]
@@ -3267,14 +3247,14 @@ testing = ["coverage", "pytest", "pytest-benchmark"]
[[package]]
name = "poethepoet"
version = "0.36.0"
description = "A task runner that works well with poetry and uv."
version = "0.34.0"
description = "A task runner that works well with poetry."
optional = false
python-versions = ">=3.9"
groups = ["dev"]
files = [
{file = "poethepoet-0.36.0-py3-none-any.whl", hash = "sha256:693e3c1eae9f6731d3613c3c0c40f747d3c5c68a375beda42e590a63c5623308"},
{file = "poethepoet-0.36.0.tar.gz", hash = "sha256:2217b49cb4e4c64af0b42ff8c4814b17f02e107d38bc461542517348ede25663"},
{file = "poethepoet-0.34.0-py3-none-any.whl", hash = "sha256:c472d6f0fdb341b48d346f4ccd49779840c15b30dfd6bc6347a80d6274b5e34e"},
{file = "poethepoet-0.34.0.tar.gz", hash = "sha256:86203acce555bbfe45cb6ccac61ba8b16a5784264484195874da457ddabf5850"},
]
[package.dependencies]
@@ -3500,14 +3480,14 @@ tqdm = "*"
[[package]]
name = "prometheus-client"
version = "0.22.1"
version = "0.21.1"
description = "Python client for the Prometheus monitoring system."
optional = false
python-versions = ">=3.9"
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "prometheus_client-0.22.1-py3-none-any.whl", hash = "sha256:cca895342e308174341b2cbf99a56bef291fbc0ef7b9e5412a0f26d653ba7094"},
{file = "prometheus_client-0.22.1.tar.gz", hash = "sha256:190f1331e783cf21eb60bca559354e0a4d4378facecf78f5428c39b675d20d28"},
{file = "prometheus_client-0.21.1-py3-none-any.whl", hash = "sha256:594b45c410d6f4f8888940fe80b5cc2521b305a1fafe1c58609ef715a001f301"},
{file = "prometheus_client-0.21.1.tar.gz", hash = "sha256:252505a722ac04b0456be05c05f75f45d760c2911ffc45f2a06bcaed9f3ae3fb"},
]
[package.extras]
@@ -3791,88 +3771,83 @@ pyasn1 = ">=0.6.1,<0.7.0"
[[package]]
name = "pycares"
version = "4.9.0"
version = "4.8.0"
description = "Python interface for c-ares"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "pycares-4.9.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0b8bd9a3ee6e9bc990e1933dc7e7e2f44d4184f49a90fa444297ac12ab6c0c84"},
{file = "pycares-4.9.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:417a5c20861f35977240ad4961479a6778125bcac21eb2ad1c3aad47e2ff7fab"},
{file = "pycares-4.9.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ab290faa4ea53ce53e3ceea1b3a42822daffce2d260005533293a52525076750"},
{file = "pycares-4.9.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b1df81193084c9717734e4615e8c5074b9852478c9007d1a8bb242f7f580e67"},
{file = "pycares-4.9.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:20c7a6af0c2ccd17cc5a70d76e299a90e7ebd6c4d8a3d7fff5ae533339f61431"},
{file = "pycares-4.9.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:370f41442a5b034aebdb2719b04ee04d3e805454a20d3f64f688c1c49f9137c3"},
{file = "pycares-4.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:340e4a3bbfd14d73c01ec0793a321b8a4a93f64c508225883291078b7ee17ac8"},
{file = "pycares-4.9.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f0ec94785856ea4f5556aa18f4c027361ba4b26cb36c4ad97d2105ef4eec68ba"},
{file = "pycares-4.9.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:dd6b7e23a4a9e2039b5d67dfa0499d2d5f114667dc13fb5d7d03eed230c7ac4f"},
{file = "pycares-4.9.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:490c978b0be9d35a253a5e31dd598f6d66b453625f0eb7dc2d81b22b8c3bb3f4"},
{file = "pycares-4.9.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:e433faaf07f44e44f1a1b839fee847480fe3db9431509dafc9f16d618d491d0f"},
{file = "pycares-4.9.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:cf6d8851a06b79d10089962c9dadcb34dad00bf027af000f7102297a54aaff2e"},
{file = "pycares-4.9.0-cp310-cp310-win32.whl", hash = "sha256:4f803e7d66ac7d8342998b8b07393788991353a46b05bbaad0b253d6f3484ea8"},
{file = "pycares-4.9.0-cp310-cp310-win_amd64.whl", hash = "sha256:8e17bd32267e3870855de3baed7d0efa6337344d68f44853fd9195c919f39400"},
{file = "pycares-4.9.0-cp310-cp310-win_arm64.whl", hash = "sha256:6b74f75d8e430f9bb11a1cc99b2e328eed74b17d8d4b476de09126f38d419eb9"},
{file = "pycares-4.9.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:16a97ee83ec60d35c7f716f117719932c27d428b1bb56b242ba1c4aa55521747"},
{file = "pycares-4.9.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:78748521423a211ce699a50c27cc5c19e98b7db610ccea98daad652ace373990"},
{file = "pycares-4.9.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8818b2c7a57d9d6d41e8b64d9ff87992b8ea2522fc0799686725228bc3cff6c5"},
{file = "pycares-4.9.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:96df8990f16013ca5194d6ece19dddb4ef9cd7c3efaab9f196ec3ccd44b40f8d"},
{file = "pycares-4.9.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:61af86fd58b8326e723b0d20fb96b56acaec2261c3a7c9a1c29d0a79659d613a"},
{file = "pycares-4.9.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ec72edb276bda559813cc807bc47b423d409ffab2402417a5381077e9c2c6be1"},
{file = "pycares-4.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:832fb122c7376c76cab62f8862fa5e398b9575fb7c9ff6bc9811086441ee64ca"},
{file = "pycares-4.9.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:cdcfaef24f771a471671470ccfd676c0366ab6b0616fd8217b8f356c40a02b83"},
{file = "pycares-4.9.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:52cb056d06ff55d78a8665b97ae948abaaba2ca200ca59b10346d4526bce1e7d"},
{file = "pycares-4.9.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:54985ed3f2e8a87315269f24cb73441622857a7830adfc3a27c675a94c3261c1"},
{file = "pycares-4.9.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:08048e223615d4aef3dac81fe0ea18fb18d6fc97881f1eb5be95bb1379969b8d"},
{file = "pycares-4.9.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:cc60037421ce05a409484287b2cd428e1363cca73c999b5f119936bb8f255208"},
{file = "pycares-4.9.0-cp311-cp311-win32.whl", hash = "sha256:62b86895b60cfb91befb3086caa0792b53f949231c6c0c3053c7dfee3f1386ab"},
{file = "pycares-4.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:7046b3c80954beaabf2db52b09c3d6fe85f6c4646af973e61be79d1c51589932"},
{file = "pycares-4.9.0-cp311-cp311-win_arm64.whl", hash = "sha256:fcbda3fdf44e94d3962ca74e6ba3dc18c0d7029106f030d61c04c0876f319403"},
{file = "pycares-4.9.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d68ca2da1001aeccdc81c4a2fb1f1f6cfdafd3d00e44e7c1ed93e3e05437f666"},
{file = "pycares-4.9.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4f0c8fa5a384d79551a27eafa39eed29529e66ba8fa795ee432ab88d050432a3"},
{file = "pycares-4.9.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0eb8c428cf3b9c6ff9c641ba50ab6357b4480cd737498733e6169b0ac8a1a89b"},
{file = "pycares-4.9.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6845bd4a43abf6dab7fedbf024ef458ac3750a25b25076ea9913e5ac5fec4548"},
{file = "pycares-4.9.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5e28f4acc3b97e46610cf164665ebf914f709daea6ced0ca4358ce55bc1c3d6b"},
{file = "pycares-4.9.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9464a39861840ce35a79352c34d653a9db44f9333af7c9feddb97998d3e00c07"},
{file = "pycares-4.9.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e0611c1bd46d1fc6bdd9305b8850eb84c77df485769f72c574ed7b8389dfbee2"},
{file = "pycares-4.9.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d4fb5a38a51d03b75ac4320357e632c2e72e03fdeb13263ee333a40621415fdc"},
{file = "pycares-4.9.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:df5edae05fb3e1370ab7639e67e8891fdaa9026cb10f05dbd57893713f7a9cfe"},
{file = "pycares-4.9.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:397123ea53d261007bb0aa7e767ef238778f45026db40bed8196436da2cc73de"},
{file = "pycares-4.9.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:bb0d874d0b131b29894fd8a0f842be91ac21d50f90ec04cff4bb3f598464b523"},
{file = "pycares-4.9.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:497cc03a61ec1585eb17d2cb086a29a6a67d24babf1e9be519b47222916a3b06"},
{file = "pycares-4.9.0-cp312-cp312-win32.whl", hash = "sha256:b46e46313fdb5e82da15478652aac0fd15e1c9f33e08153bad845aa4007d6f84"},
{file = "pycares-4.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:12547a06445777091605a7581da15a0da158058beb8a05a3ebbf7301fd1f58d4"},
{file = "pycares-4.9.0-cp312-cp312-win_arm64.whl", hash = "sha256:f1e10bf1e8eb80b08e5c828627dba1ebc4acd54803bd0a27d92b9063b6aa99d8"},
{file = "pycares-4.9.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:574d815112a95ab09d75d0a9dc7dea737c06985e3125cf31c32ba6a3ed6ca006"},
{file = "pycares-4.9.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:50e5ab06361d59625a27a7ad93d27e067dc7c9f6aa529a07d691eb17f3b43605"},
{file = "pycares-4.9.0-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:785f5fd11ff40237d9bc8afa441551bb449e2812c74334d1d10859569e07515c"},
{file = "pycares-4.9.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e194a500e403eba89b91fb863c917495c5b3dfcd1ce0ee8dc3a6f99a1360e2fc"},
{file = "pycares-4.9.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:112dd49cdec4e6150a8d95b197e8b6b7b4468a3170b30738ed9b248cb2240c04"},
{file = "pycares-4.9.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:94aa3c2f3eb0aa69160137134775501f06c901188e722aac63d2a210d4084f99"},
{file = "pycares-4.9.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b510d71255cf5a92ccc2643a553548fcb0623d6ed11c8c633b421d99d7fa4167"},
{file = "pycares-4.9.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:5c6aa30b1492b8130f7832bf95178642c710ce6b7ba610c2b17377f77177e3cd"},
{file = "pycares-4.9.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:e5767988e044faffe2aff6a76aa08df99a8b6ef2641be8b00ea16334ce5dea93"},
{file = "pycares-4.9.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:b9928a942820a82daa3207509eaba9e0fa9660756ac56667ec2e062815331fcb"},
{file = "pycares-4.9.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:556c854174da76d544714cdfab10745ed5d4b99eec5899f7b13988cd26ff4763"},
{file = "pycares-4.9.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:d42e2202ca9aa9a0a9a6e43a4a4408bbe0311aaa44800fa27b8fd7f82b20152a"},
{file = "pycares-4.9.0-cp313-cp313-win32.whl", hash = "sha256:cce8ef72c9ed4982c84114e6148a4e42e989d745de7862a0ad8b3f1cdc05def2"},
{file = "pycares-4.9.0-cp313-cp313-win_amd64.whl", hash = "sha256:318cdf24f826f1d2f0c5a988730bd597e1683296628c8f1be1a5b96643c284fe"},
{file = "pycares-4.9.0-cp313-cp313-win_arm64.whl", hash = "sha256:faa9de8e647ed06757a2c117b70a7645a755561def814da6aca0d766cf71a402"},
{file = "pycares-4.9.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8310d27d68fa25be9781ce04d330f4860634a2ac34dd9265774b5f404679b41f"},
{file = "pycares-4.9.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:99cf98452d3285307eec123049f2c9c50b109e06751b0727c6acefb6da30c6a0"},
{file = "pycares-4.9.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ffd6e8c8250655504602b076f106653e085e6b1e15318013442558101aa4777"},
{file = "pycares-4.9.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a4065858d8c812159c9a55601fda73760d9e5e3300f7868d9e546eab1084f36c"},
{file = "pycares-4.9.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:91ee6818113faf9013945c2b54bcd6b123d0ac192ae3099cf4288cedaf2dbb25"},
{file = "pycares-4.9.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:21f0602059ec11857ab7ad608c7ec8bc6f7a302c04559ec06d33e82f040585f8"},
{file = "pycares-4.9.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e22e5b46ed9b12183091da56e4a5a20813b5436c4f13135d7a1c20a84027ca8a"},
{file = "pycares-4.9.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:9eded8649867bfd7aea7589c5755eae4d37686272f6ed7a995da40890d02de71"},
{file = "pycares-4.9.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:f71d31cbbe066657a2536c98aad850724a9ab7b1cd2624f491832ae9667ea8e7"},
{file = "pycares-4.9.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:2b30945982ab4741f097efc5b0853051afc3c11df26996ed53a700c7575175af"},
{file = "pycares-4.9.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:54a8f1f067d64810426491d33033f5353b54f35e5339126440ad4e6afbf3f149"},
{file = "pycares-4.9.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:41556a269a192349e92eee953f62eddd867e9eddb27f444b261e2c1c4a4a9eff"},
{file = "pycares-4.9.0-cp39-cp39-win32.whl", hash = "sha256:524d6c14eaa167ed098a4fe54856d1248fa20c296cdd6976f9c1b838ba32d014"},
{file = "pycares-4.9.0-cp39-cp39-win_amd64.whl", hash = "sha256:15f930c733d36aa487b4ad60413013bd811281b5ea4ca620070fa38505d84df4"},
{file = "pycares-4.9.0-cp39-cp39-win_arm64.whl", hash = "sha256:79b7addb2a41267d46650ac0d9c4f3b3233b036f186b85606f7586881dfb4b69"},
{file = "pycares-4.9.0.tar.gz", hash = "sha256:8ee484ddb23dbec4d88d14ed5b6d592c1960d2e93c385d5e52b6fad564d82395"},
{file = "pycares-4.8.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f40d9f4a8de398b110fdf226cdfadd86e8c7eb71d5298120ec41cf8d94b0012f"},
{file = "pycares-4.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:339de06fc849a51015968038d2bbed68fc24047522404af9533f32395ca80d25"},
{file = "pycares-4.8.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:372a236c1502b9056b0bea195c64c329603b4efa70b593a33b7ae37fbb7fad00"},
{file = "pycares-4.8.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03f66a5e143d102ccc204bd4e29edd70bed28420f707efd2116748241e30cb73"},
{file = "pycares-4.8.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ef50504296cd5fc58cfd6318f82e20af24fbe2c83004f6ff16259adb13afdf14"},
{file = "pycares-4.8.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d1bc541b627c7951dd36136b18bd185c5244a0fb2af5b1492ffb8acaceec1c5b"},
{file = "pycares-4.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:938d188ed6bed696099be67ebdcdf121827b9432b17a9ea9e40dc35fd9d85363"},
{file = "pycares-4.8.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:327837ffdc0c7adda09c98e1263c64b2aff814eea51a423f66733c75ccd9a642"},
{file = "pycares-4.8.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:a6b9b8d08c4508c45bd39e0c74e9e7052736f18ca1d25a289365bb9ac36e5849"},
{file = "pycares-4.8.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:feac07d5e6d2d8f031c71237c21c21b8c995b41a1eba64560e8cf1e42ac11bc6"},
{file = "pycares-4.8.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:5bcdbf37012fd2323ca9f2a1074421a9ccf277d772632f8f0ce8c46ec7564250"},
{file = "pycares-4.8.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e3ebb692cb43fcf34fe0d26f2cf9a0ea53fdfb136463845b81fad651277922db"},
{file = "pycares-4.8.0-cp310-cp310-win32.whl", hash = "sha256:d98447ec0efff3fa868ccc54dcc56e71faff498f8848ecec2004c3108efb4da2"},
{file = "pycares-4.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:1abb8f40917960ead3c2771277f0bdee1967393b0fdf68743c225b606787da68"},
{file = "pycares-4.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5e25db89005ddd8d9c5720293afe6d6dd92e682fc6bc7a632535b84511e2060d"},
{file = "pycares-4.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6f9665ef116e6ee216c396f5f927756c2164f9f3316aec7ff1a9a1e1e7ec9b2a"},
{file = "pycares-4.8.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:54a96893133471f6889b577147adcc21a480dbe316f56730871028379c8313f3"},
{file = "pycares-4.8.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:51024b3a69762bd3100d94986a29922be15e13f56f991aaefb41f5bcd3d7f0bb"},
{file = "pycares-4.8.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:47ff9db50c599e4d965ae3bec99cc30941c1d2b0f078ec816680b70d052dd54a"},
{file = "pycares-4.8.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:27ef8ff4e0f60ea6769a60d1c3d1d2aefed1d832e7bb83fc3934884e2dba5cdd"},
{file = "pycares-4.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63511af7a3f9663f562fbb6bfa3591a259505d976e2aba1fa2da13dde43c6ca7"},
{file = "pycares-4.8.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:73c3219b47616e6a5ad1810de96ed59721c7751f19b70ae7bf24997a8365408f"},
{file = "pycares-4.8.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:da42a45207c18f37be5e491c14b6d1063cfe1e46620eb661735d0cedc2b59099"},
{file = "pycares-4.8.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:8a068e898bb5dd09cd654e19cd2abf20f93d0cc59d5d955135ed48ea0f806aa1"},
{file = "pycares-4.8.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:962aed95675bb66c0b785a2fbbd1bb58ce7f009e283e4ef5aaa4a1f2dc00d217"},
{file = "pycares-4.8.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ce8b1a16c1e4517a82a0ebd7664783a327166a3764d844cf96b1fb7b9dd1e493"},
{file = "pycares-4.8.0-cp311-cp311-win32.whl", hash = "sha256:b3749ddbcbd216376c3b53d42d8b640b457133f1a12b0e003f3838f953037ae7"},
{file = "pycares-4.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:5ce8a4e1b485b2360ab666c4ea1db97f57ede345a3b566d80bfa52b17e616610"},
{file = "pycares-4.8.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:3273e01a75308ed06d2492d83c7ba476e579a60a24d9f20fe178ce5e9d8d028b"},
{file = "pycares-4.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:fcedaadea1f452911fd29935749f98d144dae758d6003b7e9b6c5d5bd47d1dff"},
{file = "pycares-4.8.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aae6cb33e287e06a4aabcbc57626df682c9a4fa8026207f5b498697f1c2fb562"},
{file = "pycares-4.8.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25038b930e5be82839503fb171385b2aefd6d541bc5b7da0938bdb67780467d2"},
{file = "pycares-4.8.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cc8499b6e7dfbe4af65f6938db710ce9acd1debf34af2cbb93b898b1e5da6a5a"},
{file = "pycares-4.8.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c4e1c6a68ef56a7622f6176d9946d4e51f3c853327a0123ef35a5380230c84cd"},
{file = "pycares-4.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7cc8c3c9114b9c84e4062d25ca9b4bddc80a65d0b074c7cb059275273382f89"},
{file = "pycares-4.8.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4404014069d3e362abf404c9932d4335bb9c07ba834cfe7d683c725b92e0f9da"},
{file = "pycares-4.8.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:ee0a58c32ec2a352cef0e1d20335a7caf9871cd79b73be2ca2896fe70f09c9d7"},
{file = "pycares-4.8.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:35f32f52b486b8fede3cbebf088f30b01242d0321b5216887c28e80490595302"},
{file = "pycares-4.8.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:ecbb506e27a3b3a2abc001c77beeccf265475c84b98629a6b3e61bd9f2987eaa"},
{file = "pycares-4.8.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:9392b2a34adbf60cb9e38f4a0d363413ecea8d835b5a475122f50f76676d59dd"},
{file = "pycares-4.8.0-cp312-cp312-win32.whl", hash = "sha256:f0fbefe68403ffcff19c869b8d621c88a6d2cef18d53cf0dab0fa9458a6ca712"},
{file = "pycares-4.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:fa8aab6085a2ddfb1b43a06ddf1b498347117bb47cd620d9b12c43383c9c2737"},
{file = "pycares-4.8.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:358a9a2c6fed59f62788e63d88669224955443048a1602016d4358e92aedb365"},
{file = "pycares-4.8.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0e3e1278967fa8d4a0056be3fcc8fc551b8bad1fc7d0e5172196dccb8ddb036a"},
{file = "pycares-4.8.0-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:79befb773e370a8f97de9f16f5ea2c7e7fa0e3c6c74fbea6d332bf58164d7d06"},
{file = "pycares-4.8.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2b00d3695db64ce98a34e632e1d53f5a1cdb25451489f227bec2a6c03ff87ee8"},
{file = "pycares-4.8.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:37bdc4f2ff0612d60fc4f7547e12ff02cdcaa9a9e42e827bb64d4748994719f1"},
{file = "pycares-4.8.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cd92c44498ec7a6139888b464b28c49f7ba975933689bd67ea8d572b94188404"},
{file = "pycares-4.8.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2665a0d810e2bbc41e97f3c3e5ea7950f666b3aa19c5f6c99d6b018ccd2e0052"},
{file = "pycares-4.8.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:45a629a6470a33478514c566bce50c63f1b17d1c5f2f964c9a6790330dc105fb"},
{file = "pycares-4.8.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:47bb378f1773f41cca8e31dcdf009ce4a9b8aff8a30c7267aaff9a099c407ba5"},
{file = "pycares-4.8.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:fb3feae38458005cc101956e38f16eb3145fff8cd793e35cd4bdef6bf1aa2623"},
{file = "pycares-4.8.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:14bc28aeaa66b0f4331ac94455e8043c8a06b3faafd78cc49d4b677bae0d0b08"},
{file = "pycares-4.8.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:62c82b871470f2864a1febf7b96bb1d108ce9063e6d3d43727e8a46f0028a456"},
{file = "pycares-4.8.0-cp313-cp313-win32.whl", hash = "sha256:01afa8964c698c8f548b46d726f766aa7817b2d4386735af1f7996903d724920"},
{file = "pycares-4.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:22f86f81b12ab17b0a7bd0da1e27938caaed11715225c1168763af97f8bb51a7"},
{file = "pycares-4.8.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:61325d13a95255e858f42a7a1a9e482ff47ef2233f95ad9a4f308a3bd8ecf903"},
{file = "pycares-4.8.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:dfec3a7d42336fa46a1e7e07f67000fd4b97860598c59a894c08f81378629e4e"},
{file = "pycares-4.8.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b65067e4b4f5345688817fff6be06b9b1f4ec3619b0b9ecc639bc681b73f646b"},
{file = "pycares-4.8.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0322ad94bbaa7016139b5bbdcd0de6f6feb9d146d69e03a82aaca342e06830a6"},
{file = "pycares-4.8.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:456c60f170c997f9a43c7afa1085fced8efb7e13ae49dd5656f998ae13c4bdb4"},
{file = "pycares-4.8.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:57a2c4c9ce423a85b0e0227409dbaf0d478f5e0c31d9e626768e77e1e887d32f"},
{file = "pycares-4.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:478d9c479108b7527266864c0affe3d6e863492c9bc269217e36100c8fd89b91"},
{file = "pycares-4.8.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:aed56bca096990ca0aa9bbf95761fc87e02880e04b0845922b5c12ea9abe523f"},
{file = "pycares-4.8.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:ef265a390928ee2f77f8901c2273c53293157860451ad453ce7f45dd268b72f9"},
{file = "pycares-4.8.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:a5f17d7a76d8335f1c90a8530c8f1e8bb22e9a1d70a96f686efaed946de1c908"},
{file = "pycares-4.8.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:891f981feb2ef34367378f813fc17b3d706ce95b6548eeea0c9fe7705d7e54b1"},
{file = "pycares-4.8.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:4102f6d9117466cc0a1f527907a1454d109cc9e8551b8074888071ef16050fe3"},
{file = "pycares-4.8.0-cp39-cp39-win32.whl", hash = "sha256:d6775308659652adc88c82c53eda59b5e86a154aaba5ad1e287bbb3e0be77076"},
{file = "pycares-4.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:8bc05462aa44788d48544cca3d2532466fed2cdc5a2f24a43a92b620a61c9d19"},
{file = "pycares-4.8.0.tar.gz", hash = "sha256:2fc2ebfab960f654b3e3cf08a732486950da99393a657f8b44618ad3ed2d39c1"},
]
[package.dependencies]
@@ -3883,14 +3858,14 @@ idna = ["idna (>=2.1)"]
[[package]]
name = "pycodestyle"
version = "2.14.0"
version = "2.13.0"
description = "Python style guide checker"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "pycodestyle-2.14.0-py2.py3-none-any.whl", hash = "sha256:dd6bf7cb4ee77f8e016f9c8e74a35ddd9f67e1d5fd4184d86c3b98e07099f42d"},
{file = "pycodestyle-2.14.0.tar.gz", hash = "sha256:c4b5b517d278089ff9d0abdec919cd97262a3367449ea1c8b49b91529167b783"},
{file = "pycodestyle-2.13.0-py2.py3-none-any.whl", hash = "sha256:35863c5974a271c7a726ed228a14a4f6daf49df369d8c50cd9a6f58a5e143ba9"},
{file = "pycodestyle-2.13.0.tar.gz", hash = "sha256:c8415bf09abe81d9c7f872502a6eee881fbe85d8763dd5b9924bb0a01d67efae"},
]
[[package]]
@@ -3907,14 +3882,14 @@ files = [
[[package]]
name = "pydantic"
version = "2.11.7"
version = "2.11.5"
description = "Data validation using Python type hints"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "pydantic-2.11.7-py3-none-any.whl", hash = "sha256:dde5df002701f6de26248661f6835bbe296a47bf73990135c7d07ce741b9623b"},
{file = "pydantic-2.11.7.tar.gz", hash = "sha256:d989c3c6cb79469287b1569f7447a17848c998458d49ebe294e975b9baf0f0db"},
{file = "pydantic-2.11.5-py3-none-any.whl", hash = "sha256:f9c26ba06f9747749ca1e5c94d6a85cb84254577553c8785576fd38fa64dc0f7"},
{file = "pydantic-2.11.5.tar.gz", hash = "sha256:7f853db3d0ce78ce8bbb148c401c2cdd6431b3473c0cdff2755c7690952a7b7a"},
]
[package.dependencies]
@@ -4042,14 +4017,14 @@ typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0"
[[package]]
name = "pydantic-settings"
version = "2.10.1"
version = "2.9.1"
description = "Settings management using Pydantic"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "pydantic_settings-2.10.1-py3-none-any.whl", hash = "sha256:a60952460b99cf661dc25c29c0ef171721f98bfcb52ef8d9ea4c943d7c8cc796"},
{file = "pydantic_settings-2.10.1.tar.gz", hash = "sha256:06f0062169818d0f5524420a360d632d5857b83cffd4d42fe29597807a1614ee"},
{file = "pydantic_settings-2.9.1-py3-none-any.whl", hash = "sha256:59b4f431b1defb26fe620c71a7d3968a710d719f5f4cdbbdb7926edeb770f6ef"},
{file = "pydantic_settings-2.9.1.tar.gz", hash = "sha256:c509bf79d27563add44e8446233359004ed85066cd096d8b510f715e6ef5d268"},
]
[package.dependencies]
@@ -4066,31 +4041,16 @@ yaml = ["pyyaml (>=6.0.1)"]
[[package]]
name = "pyflakes"
version = "3.4.0"
version = "3.3.2"
description = "passive checker of Python programs"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "pyflakes-3.4.0-py2.py3-none-any.whl", hash = "sha256:f742a7dbd0d9cb9ea41e9a24a918996e8170c799fa528688d40dd582c8265f4f"},
{file = "pyflakes-3.4.0.tar.gz", hash = "sha256:b24f96fafb7d2ab0ec5075b7350b3d2d2218eab42003821c06344973d3ea2f58"},
{file = "pyflakes-3.3.2-py2.py3-none-any.whl", hash = "sha256:5039c8339cbb1944045f4ee5466908906180f13cc99cc9949348d10f82a5c32a"},
{file = "pyflakes-3.3.2.tar.gz", hash = "sha256:6dfd61d87b97fba5dcfaaf781171ac16be16453be6d816147989e7f6e6a9576b"},
]
[[package]]
name = "pygments"
version = "2.19.2"
description = "Pygments is a syntax highlighting package written in Python."
optional = false
python-versions = ">=3.8"
groups = ["main", "dev"]
files = [
{file = "pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b"},
{file = "pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887"},
]
[package.extras]
windows-terminal = ["colorama (>=0.4.6)"]
[[package]]
name = "pyjwt"
version = "2.10.1"
@@ -4150,14 +4110,14 @@ files = [
[[package]]
name = "pyright"
version = "1.1.402"
version = "1.1.401"
description = "Command line wrapper for pyright"
optional = false
python-versions = ">=3.7"
groups = ["dev"]
files = [
{file = "pyright-1.1.402-py3-none-any.whl", hash = "sha256:2c721f11869baac1884e846232800fe021c33f1b4acb3929cff321f7ea4e2982"},
{file = "pyright-1.1.402.tar.gz", hash = "sha256:85a33c2d40cd4439c66aa946fd4ce71ab2f3f5b8c22ce36a623f59ac22937683"},
{file = "pyright-1.1.401-py3-none-any.whl", hash = "sha256:6fde30492ba5b0d7667c16ecaf6c699fab8d7a1263f6a18549e0b00bf7724c06"},
{file = "pyright-1.1.401.tar.gz", hash = "sha256:788a82b6611fa5e34a326a921d86d898768cddf59edde8e93e56087d277cc6f1"},
]
[package.dependencies]
@@ -4171,27 +4131,26 @@ nodejs = ["nodejs-wheel-binaries"]
[[package]]
name = "pytest"
version = "8.4.1"
version = "8.3.5"
description = "pytest: simple powerful testing with Python"
optional = false
python-versions = ">=3.9"
python-versions = ">=3.8"
groups = ["main", "dev"]
files = [
{file = "pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7"},
{file = "pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c"},
{file = "pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820"},
{file = "pytest-8.3.5.tar.gz", hash = "sha256:f4efe70cc14e511565ac476b57c279e12a855b11f48f212af1080ef2263d3845"},
]
[package.dependencies]
colorama = {version = ">=0.4", markers = "sys_platform == \"win32\""}
exceptiongroup = {version = ">=1", markers = "python_version < \"3.11\""}
iniconfig = ">=1"
packaging = ">=20"
colorama = {version = "*", markers = "sys_platform == \"win32\""}
exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
iniconfig = "*"
packaging = "*"
pluggy = ">=1.5,<2"
pygments = ">=2.7.2"
tomli = {version = ">=1", markers = "python_version < \"3.11\""}
[package.extras]
dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests", "setuptools", "xmlschema"]
dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
[[package]]
name = "pytest-asyncio"
@@ -4278,14 +4237,14 @@ six = ">=1.5"
[[package]]
name = "python-dotenv"
version = "1.1.1"
version = "1.1.0"
description = "Read key-value pairs from a .env file and set them as environment variables"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "python_dotenv-1.1.1-py3-none-any.whl", hash = "sha256:31f23644fe2602f88ff55e1f5c79ba497e01224ee7737937930c448e4d0e24dc"},
{file = "python_dotenv-1.1.1.tar.gz", hash = "sha256:a8a6399716257f45be6a007360200409fce5cda2661e3dec71d23dc15f6189ab"},
{file = "python_dotenv-1.1.0-py3-none-any.whl", hash = "sha256:d7c01d9e2293916c18baf562d95698754b0dbbb5e74d457c45d4f6561fb9d55d"},
{file = "python_dotenv-1.1.0.tar.gz", hash = "sha256:41f90bc6f5f177fb41f53e87666db362025010eb28f60a01c9143bfa33a2b2d5"},
]
[package.extras]
@@ -4731,19 +4690,19 @@ typing_extensions = ">=4.5.0"
[[package]]
name = "requests"
version = "2.32.4"
version = "2.32.3"
description = "Python HTTP for Humans."
optional = false
python-versions = ">=3.8"
groups = ["main", "dev"]
files = [
{file = "requests-2.32.4-py3-none-any.whl", hash = "sha256:27babd3cda2a6d50b30443204ee89830707d396671944c998b5975b031ac2b2c"},
{file = "requests-2.32.4.tar.gz", hash = "sha256:27d0316682c8a29834d3264820024b62a36942083d52caf2f14c0591336d3422"},
{file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"},
{file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"},
]
[package.dependencies]
certifi = ">=2017.4.17"
charset_normalizer = ">=2,<4"
charset-normalizer = ">=2,<4"
idna = ">=2.5,<4"
urllib3 = ">=1.21.1,<3"
@@ -4929,30 +4888,30 @@ pyasn1 = ">=0.1.3"
[[package]]
name = "ruff"
version = "0.12.2"
version = "0.11.12"
description = "An extremely fast Python linter and code formatter, written in Rust."
optional = false
python-versions = ">=3.7"
groups = ["dev"]
files = [
{file = "ruff-0.12.2-py3-none-linux_armv6l.whl", hash = "sha256:093ea2b221df1d2b8e7ad92fc6ffdca40a2cb10d8564477a987b44fd4008a7be"},
{file = "ruff-0.12.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:09e4cf27cc10f96b1708100fa851e0daf21767e9709e1649175355280e0d950e"},
{file = "ruff-0.12.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:8ae64755b22f4ff85e9c52d1f82644abd0b6b6b6deedceb74bd71f35c24044cc"},
{file = "ruff-0.12.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3eb3a6b2db4d6e2c77e682f0b988d4d61aff06860158fdb413118ca133d57922"},
{file = "ruff-0.12.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:73448de992d05517170fc37169cbca857dfeaeaa8c2b9be494d7bcb0d36c8f4b"},
{file = "ruff-0.12.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3b8b94317cbc2ae4a2771af641739f933934b03555e51515e6e021c64441532d"},
{file = "ruff-0.12.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:45fc42c3bf1d30d2008023a0a9a0cfb06bf9835b147f11fe0679f21ae86d34b1"},
{file = "ruff-0.12.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ce48f675c394c37e958bf229fb5c1e843e20945a6d962cf3ea20b7a107dcd9f4"},
{file = "ruff-0.12.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:793d8859445ea47591272021a81391350205a4af65a9392401f418a95dfb75c9"},
{file = "ruff-0.12.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6932323db80484dda89153da3d8e58164d01d6da86857c79f1961934354992da"},
{file = "ruff-0.12.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:6aa7e623a3a11538108f61e859ebf016c4f14a7e6e4eba1980190cacb57714ce"},
{file = "ruff-0.12.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2a4a20aeed74671b2def096bdf2eac610c7d8ffcbf4fb0e627c06947a1d7078d"},
{file = "ruff-0.12.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:71a4c550195612f486c9d1f2b045a600aeba851b298c667807ae933478fcef04"},
{file = "ruff-0.12.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:4987b8f4ceadf597c927beee65a5eaf994c6e2b631df963f86d8ad1bdea99342"},
{file = "ruff-0.12.2-py3-none-win32.whl", hash = "sha256:369ffb69b70cd55b6c3fc453b9492d98aed98062db9fec828cdfd069555f5f1a"},
{file = "ruff-0.12.2-py3-none-win_amd64.whl", hash = "sha256:dca8a3b6d6dc9810ed8f328d406516bf4d660c00caeaef36eb831cf4871b0639"},
{file = "ruff-0.12.2-py3-none-win_arm64.whl", hash = "sha256:48d6c6bfb4761df68bc05ae630e24f506755e702d4fb08f08460be778c7ccb12"},
{file = "ruff-0.12.2.tar.gz", hash = "sha256:d7b4f55cd6f325cb7621244f19c873c565a08aff5a4ba9c69aa7355f3f7afd3e"},
{file = "ruff-0.11.12-py3-none-linux_armv6l.whl", hash = "sha256:c7680aa2f0d4c4f43353d1e72123955c7a2159b8646cd43402de6d4a3a25d7cc"},
{file = "ruff-0.11.12-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:2cad64843da9f134565c20bcc430642de897b8ea02e2e79e6e02a76b8dcad7c3"},
{file = "ruff-0.11.12-py3-none-macosx_11_0_arm64.whl", hash = "sha256:9b6886b524a1c659cee1758140138455d3c029783d1b9e643f3624a5ee0cb0aa"},
{file = "ruff-0.11.12-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3cc3a3690aad6e86c1958d3ec3c38c4594b6ecec75c1f531e84160bd827b2012"},
{file = "ruff-0.11.12-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f97fdbc2549f456c65b3b0048560d44ddd540db1f27c778a938371424b49fe4a"},
{file = "ruff-0.11.12-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:74adf84960236961090e2d1348c1a67d940fd12e811a33fb3d107df61eef8fc7"},
{file = "ruff-0.11.12-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:b56697e5b8bcf1d61293ccfe63873aba08fdbcbbba839fc046ec5926bdb25a3a"},
{file = "ruff-0.11.12-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4d47afa45e7b0eaf5e5969c6b39cbd108be83910b5c74626247e366fd7a36a13"},
{file = "ruff-0.11.12-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:692bf9603fe1bf949de8b09a2da896f05c01ed7a187f4a386cdba6760e7f61be"},
{file = "ruff-0.11.12-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:08033320e979df3b20dba567c62f69c45e01df708b0f9c83912d7abd3e0801cd"},
{file = "ruff-0.11.12-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:929b7706584f5bfd61d67d5070f399057d07c70585fa8c4491d78ada452d3bef"},
{file = "ruff-0.11.12-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:7de4a73205dc5756b8e09ee3ed67c38312dce1aa28972b93150f5751199981b5"},
{file = "ruff-0.11.12-py3-none-musllinux_1_2_i686.whl", hash = "sha256:2635c2a90ac1b8ca9e93b70af59dfd1dd2026a40e2d6eebaa3efb0465dd9cf02"},
{file = "ruff-0.11.12-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:d05d6a78a89166f03f03a198ecc9d18779076ad0eec476819467acb401028c0c"},
{file = "ruff-0.11.12-py3-none-win32.whl", hash = "sha256:f5a07f49767c4be4772d161bfc049c1f242db0cfe1bd976e0f0886732a4765d6"},
{file = "ruff-0.11.12-py3-none-win_amd64.whl", hash = "sha256:5a4d9f8030d8c3a45df201d7fb3ed38d0219bccd7955268e863ee4a115fa0832"},
{file = "ruff-0.11.12-py3-none-win_arm64.whl", hash = "sha256:65194e37853158d368e333ba282217941029a28ea90913c67e558c611d04daa5"},
{file = "ruff-0.11.12.tar.gz", hash = "sha256:43cf7f69c7d7c7d7513b9d59c5d8cafd704e05944f978614aa9faff6ac202603"},
]
[[package]]
@@ -4986,14 +4945,14 @@ files = [
[[package]]
name = "sentry-sdk"
version = "2.32.0"
version = "2.29.1"
description = "Python client for Sentry (https://sentry.io)"
optional = false
python-versions = ">=3.6"
groups = ["main"]
files = [
{file = "sentry_sdk-2.32.0-py2.py3-none-any.whl", hash = "sha256:6cf51521b099562d7ce3606da928c473643abe99b00ce4cb5626ea735f4ec345"},
{file = "sentry_sdk-2.32.0.tar.gz", hash = "sha256:9016c75d9316b0f6921ac14c8cd4fb938f26002430ac5be9945ab280f78bec6b"},
{file = "sentry_sdk-2.29.1-py2.py3-none-any.whl", hash = "sha256:90862fe0616ded4572da6c9dadb363121a1ae49a49e21c418f0634e9d10b4c19"},
{file = "sentry_sdk-2.29.1.tar.gz", hash = "sha256:8d4a0206b95fa5fe85e5e7517ed662e3888374bdc342c00e435e10e6d831aa6d"},
]
[package.dependencies]
@@ -5047,27 +5006,6 @@ statsig = ["statsig (>=0.55.3)"]
tornado = ["tornado (>=6)"]
unleash = ["UnleashClient (>=6.0.1)"]
[[package]]
name = "setuptools"
version = "80.9.0"
description = "Easily download, build, install, upgrade, and uninstall Python packages"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922"},
{file = "setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c"},
]
[package.extras]
check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.8.0) ; sys_platform != \"cygwin\""]
core = ["importlib_metadata (>=6) ; python_version < \"3.10\"", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1) ; python_version < \"3.11\"", "wheel (>=0.43.0)"]
cover = ["pytest-cov"]
doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"]
enabler = ["pytest-enabler (>=2.2)"]
test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21) ; python_version >= \"3.9\" and sys_platform != \"cygwin\"", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf ; sys_platform != \"cygwin\"", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"]
type = ["importlib_metadata (>=7.0.2) ; python_version < \"3.10\"", "jaraco.develop (>=7.21) ; sys_platform != \"cygwin\"", "mypy (==1.14.*)", "pytest-mypy"]
[[package]]
name = "sgmllib3k"
version = "1.0.0"
@@ -5280,23 +5218,23 @@ typing-extensions = {version = ">=4.5.0", markers = "python_version >= \"3.7\""}
[[package]]
name = "supabase"
version = "2.16.0"
version = "2.15.1"
description = "Supabase client for Python."
optional = false
python-versions = "<4.0,>=3.9"
groups = ["main"]
files = [
{file = "supabase-2.16.0-py3-none-any.whl", hash = "sha256:99065caab3d90a56650bf39fbd0e49740995da3738ab28706c61bd7f2401db55"},
{file = "supabase-2.16.0.tar.gz", hash = "sha256:98f3810158012d4ec0e3083f2e5515f5e10b32bd71e7d458662140e963c1d164"},
{file = "supabase-2.15.1-py3-none-any.whl", hash = "sha256:749299cdd74ecf528f52045c1e60d9dba81cc2054656f754c0ca7fba0dd34827"},
{file = "supabase-2.15.1.tar.gz", hash = "sha256:66e847dab9346062aa6a25b4e81ac786b972c5d4299827c57d1d5bd6a0346070"},
]
[package.dependencies]
gotrue = ">=2.11.0,<3.0.0"
httpx = ">=0.26,<0.29"
postgrest = ">0.19,<1.2"
realtime = ">=2.4.0,<2.6.0"
storage3 = ">=0.10,<0.13"
supafunc = ">=0.9,<0.11"
postgrest = ">0.19,<1.1"
realtime = ">=2.4.0,<2.5.0"
storage3 = ">=0.10,<0.12"
supafunc = ">=0.9,<0.10"
[[package]]
name = "supafunc"
@@ -5503,14 +5441,14 @@ files = [
[[package]]
name = "tweepy"
version = "4.16.0"
description = "Library for accessing the X API (Twitter)"
version = "4.15.0"
description = "Twitter library for Python"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "tweepy-4.16.0-py3-none-any.whl", hash = "sha256:48d1a1eb311d2c4b8990abcfa6f9fa2b2ad61be05c723b1a9b4f242656badae2"},
{file = "tweepy-4.16.0.tar.gz", hash = "sha256:1d95cbdc50bf6353a387f881f2584eaf60d14e00dbbdd8872a73de79c66878e3"},
{file = "tweepy-4.15.0-py3-none-any.whl", hash = "sha256:64adcea317158937059e4e2897b3ceb750b0c2dd5df58938c2da8f7eb3b88e6a"},
{file = "tweepy-4.15.0.tar.gz", hash = "sha256:1345cbcdf0a75e2d89f424c559fd49fda4d8cd7be25cd5131e3b57bad8a21d76"},
]
[package.dependencies]
@@ -5521,6 +5459,8 @@ requests-oauthlib = ">=1.2.0,<3"
[package.extras]
async = ["aiohttp (>=3.7.3,<4)", "async-lru (>=1.0.3,<3)"]
dev = ["coverage (>=4.4.2)", "coveralls (>=2.1.0)", "tox (>=3.21.0)"]
docs = ["myst-parser (==0.15.2)", "readthedocs-sphinx-search (==0.1.1)", "sphinx (==4.2.0)", "sphinx-hoverxref (==0.7b1)", "sphinx-tabs (==3.2.0)", "sphinx_rtd_theme (==1.0.0)"]
socks = ["requests[socks] (>=2.27.0,<3)"]
test = ["urllib3 (<2)", "vcrpy (>=1.10.3)"]
[[package]]
@@ -6280,14 +6220,14 @@ requests = "*"
[[package]]
name = "zerobouncesdk"
version = "1.1.2"
version = "1.1.1"
description = "ZeroBounce Python API - https://www.zerobounce.net."
optional = false
python-versions = ">=3.7"
groups = ["main"]
files = [
{file = "zerobouncesdk-1.1.2-py3-none-any.whl", hash = "sha256:a89febfb3adade01c314e6bad2113ad093f1e1cca6ddf9fcf445a8b2a9a458b4"},
{file = "zerobouncesdk-1.1.2.tar.gz", hash = "sha256:24810a2e39c963bc75b4732356b0fc8b10091f2c892f0c8b08fbb32640fdccaf"},
{file = "zerobouncesdk-1.1.1-py3-none-any.whl", hash = "sha256:9fb9dfa44fe4ce35d6f2e43d5144c31ca03544a3317d75643cb9f86b0c028675"},
{file = "zerobouncesdk-1.1.1.tar.gz", hash = "sha256:00aa537263d5bc21534c0007dd9f94ce8e0986caa530c5a0bbe0bd917451f236"},
]
[package.dependencies]
@@ -6429,4 +6369,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<3.13"
content-hash = "476228d2bf59b90edc5425c462c1263cbc1f2d346f79a826ac5e7efe7823aaa6"
content-hash = "6c93e51cf22c06548aa6d0e23ca8ceb4450f5e02d4142715e941aabc1a2cbd6a"

View File

@@ -10,67 +10,64 @@ packages = [{ include = "backend", format = "sdist" }]
[tool.poetry.dependencies]
python = ">=3.10,<3.13"
aio-pika = "^9.5.5"
aiodns = "^3.5.0"
anthropic = "^0.57.1"
aiodns = "^3.1.1"
anthropic = "^0.51.0"
apscheduler = "^3.11.0"
autogpt-libs = { path = "../autogpt_libs", develop = true }
bleach = { extras = ["css"], version = "^6.2.0" }
click = "^8.2.0"
cryptography = "^43.0"
discord-py = "^2.5.2"
e2b-code-interpreter = "^1.5.2"
fastapi = "^0.115.14"
e2b-code-interpreter = "^1.5.0"
fastapi = "^0.115.12"
feedparser = "^6.0.11"
flake8 = "^7.3.0"
google-api-python-client = "^2.176.0"
flake8 = "^7.2.0"
google-api-python-client = "^2.169.0"
google-auth-oauthlib = "^1.2.2"
google-cloud-storage = "^3.2.0"
google-cloud-storage = "^3.1.0"
googlemaps = "^4.10.0"
gravitasml = "^0.1.3"
groq = "^0.29.0"
groq = "^0.24.0"
jinja2 = "^3.1.6"
jsonref = "^1.1.0"
jsonschema = "^4.22.0"
launchdarkly-server-sdk = "^9.11.0"
mem0ai = "^0.1.114"
mem0ai = "^0.1.98"
moviepy = "^2.1.2"
ollama = "^0.5.1"
openai = "^1.93.2"
ollama = "^0.4.8"
openai = "^1.78.1"
pika = "^1.3.2"
pinecone = "^5.3.1"
poetry = "2.1.1" # CHECK DEPENDABOT SUPPORT BEFORE UPGRADING
postmarker = "^1.0"
praw = "~7.8.1"
prisma = "^0.15.0"
prometheus-client = "^0.22.1"
prometheus-client = "^0.21.1"
psutil = "^7.0.0"
psycopg2-binary = "^2.9.10"
pydantic = { extras = ["email"], version = "^2.11.7" }
pydantic-settings = "^2.10.1"
pytest = "^8.4.1"
pydantic = { extras = ["email"], version = "^2.11.4" }
pydantic-settings = "^2.9.1"
pytest = "^8.3.5"
pytest-asyncio = "^0.26.0"
python-dotenv = "^1.1.1"
python-dotenv = "^1.1.0"
python-multipart = "^0.0.20"
redis = "^5.2.0"
replicate = "^1.0.6"
sentry-sdk = {extras = ["anthropic", "fastapi", "launchdarkly", "openai", "sqlalchemy"], version = "^2.32.0"}
sentry-sdk = {extras = ["anthropic", "fastapi", "launchdarkly", "openai", "sqlalchemy"], version = "^2.28.0"}
sqlalchemy = "^2.0.40"
strenum = "^0.4.9"
stripe = "^11.5.0"
supabase = "2.16.0"
supabase = "2.15.1"
tenacity = "^9.1.2"
todoist-api-python = "^2.1.7"
tweepy = "^4.16.0"
tweepy = "^4.14.0"
uvicorn = { extras = ["standard"], version = "^0.34.2" }
websockets = "^14.2"
youtube-transcript-api = "^0.6.2"
zerobouncesdk = "^1.1.2"
zerobouncesdk = "^1.1.1"
# NOTE: please insert new dependencies in their alphabetical location
pytest-snapshot = "^0.9.0"
aiofiles = "^24.1.0"
tiktoken = "^0.9.0"
aioclamd = "^1.0.0"
setuptools = "^80.9.0"
[tool.poetry.group.dev.dependencies]
aiohappyeyeballs = "^2.6.1"
@@ -78,12 +75,12 @@ black = "^24.10.0"
faker = "^33.3.1"
httpx = "^0.28.1"
isort = "^5.13.2"
poethepoet = "^0.36.0"
pyright = "^1.1.402"
poethepoet = "^0.34.0"
pyright = "^1.1.400"
pytest-mock = "^3.14.0"
pytest-watcher = "^0.4.2"
requests = "^2.32.4"
ruff = "^0.12.2"
requests = "^2.32.3"
ruff = "^0.11.10"
# NOTE: please insert new dependencies in their alphabetical location
[build-system]
@@ -115,11 +112,6 @@ ignore_patterns = []
[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "session"
filterwarnings = [
"ignore:'audioop' is deprecated:DeprecationWarning:discord.player",
"ignore:invalid escape sequence:DeprecationWarning:tweepy.api",
]
[tool.ruff]
target-version = "py310"

Some files were not shown because too many files have changed in this diff Show More