mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-13 00:58:16 -05:00
Compare commits
26 Commits
fix/blocks
...
fix/cookie
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8958357343 | ||
|
|
0b1b29a9bb | ||
|
|
d5dfc40263 | ||
|
|
d9f9f80346 | ||
|
|
a47e1916fb | ||
|
|
49b22576b5 | ||
|
|
db3d62eaa0 | ||
|
|
46da6a1c5f | ||
|
|
f1471377c3 | ||
|
|
13e5f6bf8e | ||
|
|
add32b8449 | ||
|
|
2f11dade70 | ||
|
|
8f1ebfc696 | ||
|
|
fc975e9e17 | ||
|
|
7985da3e8e | ||
|
|
34184f7cc0 | ||
|
|
ade66f3d27 | ||
|
|
9b221ff931 | ||
|
|
0955cfb869 | ||
|
|
bf26b8f14a | ||
|
|
82f6687646 | ||
|
|
10efb1772e | ||
|
|
692c6defce | ||
|
|
08c56a337b | ||
|
|
41ebd5fe5d | ||
|
|
e8657ed711 |
51
.github/workflows/platform-backend-ci.yml
vendored
51
.github/workflows/platform-backend-ci.yml
vendored
@@ -50,23 +50,6 @@ jobs:
|
||||
env:
|
||||
RABBITMQ_DEFAULT_USER: ${{ env.RABBITMQ_DEFAULT_USER }}
|
||||
RABBITMQ_DEFAULT_PASS: ${{ env.RABBITMQ_DEFAULT_PASS }}
|
||||
clamav:
|
||||
image: clamav/clamav-debian:latest
|
||||
ports:
|
||||
- 3310:3310
|
||||
env:
|
||||
CLAMAV_NO_FRESHCLAMD: false
|
||||
CLAMD_CONF_StreamMaxLength: 50M
|
||||
CLAMD_CONF_MaxFileSize: 100M
|
||||
CLAMD_CONF_MaxScanSize: 100M
|
||||
CLAMD_CONF_MaxThreads: 4
|
||||
CLAMD_CONF_ReadTimeout: 300
|
||||
options: >-
|
||||
--health-cmd "clamdscan --version || exit 1"
|
||||
--health-interval 30s
|
||||
--health-timeout 10s
|
||||
--health-retries 5
|
||||
--health-start-period 180s
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -148,35 +131,6 @@ jobs:
|
||||
# outputs:
|
||||
# DB_URL, API_URL, GRAPHQL_URL, ANON_KEY, SERVICE_ROLE_KEY, JWT_SECRET
|
||||
|
||||
- name: Wait for ClamAV to be ready
|
||||
run: |
|
||||
echo "Waiting for ClamAV daemon to start..."
|
||||
max_attempts=60
|
||||
attempt=0
|
||||
|
||||
until nc -z localhost 3310 || [ $attempt -eq $max_attempts ]; do
|
||||
echo "ClamAV is unavailable - sleeping (attempt $((attempt+1))/$max_attempts)"
|
||||
sleep 5
|
||||
attempt=$((attempt+1))
|
||||
done
|
||||
|
||||
if [ $attempt -eq $max_attempts ]; then
|
||||
echo "ClamAV failed to start after $((max_attempts*5)) seconds"
|
||||
echo "Checking ClamAV service logs..."
|
||||
docker logs $(docker ps -q --filter "ancestor=clamav/clamav-debian:latest") 2>&1 | tail -50 || echo "No ClamAV container found"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "ClamAV is ready!"
|
||||
|
||||
# Verify ClamAV is responsive
|
||||
echo "Testing ClamAV connection..."
|
||||
timeout 10 bash -c 'echo "PING" | nc localhost 3310' || {
|
||||
echo "ClamAV is not responding to PING"
|
||||
docker logs $(docker ps -q --filter "ancestor=clamav/clamav-debian:latest") 2>&1 | tail -50 || echo "No ClamAV container found"
|
||||
exit 1
|
||||
}
|
||||
|
||||
- name: Run Database Migrations
|
||||
run: poetry run prisma migrate dev --name updates
|
||||
env:
|
||||
@@ -190,9 +144,9 @@ jobs:
|
||||
- name: Run pytest with coverage
|
||||
run: |
|
||||
if [[ "${{ runner.debug }}" == "1" ]]; then
|
||||
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG
|
||||
poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG test
|
||||
else
|
||||
poetry run pytest -s -vv
|
||||
poetry run pytest -s -vv test
|
||||
fi
|
||||
if: success() || (failure() && steps.lint.outcome == 'failure')
|
||||
env:
|
||||
@@ -205,7 +159,6 @@ jobs:
|
||||
REDIS_HOST: "localhost"
|
||||
REDIS_PORT: "6379"
|
||||
REDIS_PASSWORD: "testpassword"
|
||||
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
|
||||
|
||||
env:
|
||||
CI: true
|
||||
|
||||
106
.github/workflows/platform-frontend-ci.yml
vendored
106
.github/workflows/platform-frontend-ci.yml
vendored
@@ -18,45 +18,11 @@ defaults:
|
||||
working-directory: autogpt_platform/frontend
|
||||
|
||||
jobs:
|
||||
setup:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
cache-key: ${{ steps.cache-key.outputs.key }}
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Generate cache key
|
||||
id: cache-key
|
||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('**/pnpm-lock.yaml') }}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ steps.cache-key.outputs.key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
@@ -66,14 +32,6 @@ jobs:
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
@@ -82,11 +40,9 @@ jobs:
|
||||
|
||||
type-check:
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
@@ -96,62 +52,14 @@ jobs:
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Run tsc check
|
||||
run: pnpm type-check
|
||||
|
||||
chromatic:
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
# Only run on dev branch pushes or PRs targeting dev
|
||||
if: github.ref == 'refs/heads/dev' || github.base_ref == 'dev'
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
|
||||
- name: Enable corepack
|
||||
run: corepack enable
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Run Chromatic
|
||||
uses: chromaui/action@latest
|
||||
with:
|
||||
projectToken: chpt_9e7c1a76478c9c8
|
||||
onlyChanged: true
|
||||
workingDir: autogpt_platform/frontend
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
needs: setup
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@@ -189,14 +97,6 @@ jobs:
|
||||
run: |
|
||||
docker compose -f ../docker-compose.yml up -d
|
||||
|
||||
- name: Restore dependencies cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.pnpm-store
|
||||
key: ${{ needs.setup.outputs.cache-key }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
@@ -212,8 +112,6 @@ jobs:
|
||||
|
||||
- name: Run Playwright tests
|
||||
run: pnpm test:no-build --project=${{ matrix.browser }}
|
||||
env:
|
||||
BROWSER_TYPE: ${{ matrix.browser }}
|
||||
|
||||
- name: Print Final Docker Compose logs
|
||||
if: always()
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -165,7 +165,7 @@ package-lock.json
|
||||
|
||||
# Allow for locally private items
|
||||
# private
|
||||
pri*
|
||||
pri*
|
||||
# ignore
|
||||
ig*
|
||||
.github_access_token
|
||||
|
||||
@@ -19,7 +19,7 @@ cd backend && poetry install
|
||||
# Run database migrations
|
||||
poetry run prisma migrate dev
|
||||
|
||||
# Start all services (database, redis, rabbitmq, clamav)
|
||||
# Start all services (database, redis, rabbitmq)
|
||||
docker compose up -d
|
||||
|
||||
# Run the backend server
|
||||
@@ -32,7 +32,6 @@ poetry run test
|
||||
poetry run pytest path/to/test_file.py::test_function_name
|
||||
|
||||
# Lint and format
|
||||
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
|
||||
poetry run format # Black + isort
|
||||
poetry run lint # ruff
|
||||
```
|
||||
@@ -78,7 +77,6 @@ npm run type-check
|
||||
- **Queue System**: RabbitMQ for async task processing
|
||||
- **Execution Engine**: Separate executor service processes agent workflows
|
||||
- **Authentication**: JWT-based with Supabase integration
|
||||
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
||||
|
||||
### Frontend Architecture
|
||||
- **Framework**: Next.js App Router with React Server Components
|
||||
@@ -92,7 +90,6 @@ npm run type-check
|
||||
2. **Blocks**: Reusable components in `/backend/blocks/` that perform specific tasks
|
||||
3. **Integrations**: OAuth and API connections stored per user
|
||||
4. **Store**: Marketplace for sharing agent templates
|
||||
5. **Virus Scanning**: ClamAV integration for file upload security
|
||||
|
||||
### Testing Approach
|
||||
- Backend uses pytest with snapshot testing for API responses
|
||||
@@ -121,7 +118,6 @@ Key models (defined in `/backend/schema.prisma`):
|
||||
3. Define input/output schemas
|
||||
4. Implement `run` method
|
||||
5. Register in block registry
|
||||
6. Generate the block uuid using `uuid.uuid4()`
|
||||
|
||||
**Modifying the API:**
|
||||
1. Update route in `/backend/backend/server/routers/`
|
||||
@@ -133,15 +129,4 @@ Key models (defined in `/backend/schema.prisma`):
|
||||
1. Components go in `/frontend/src/components/`
|
||||
2. Use existing UI components from `/frontend/src/components/ui/`
|
||||
3. Add Storybook stories for new components
|
||||
4. Test with Playwright if user-facing
|
||||
|
||||
### Security Implementation
|
||||
|
||||
**Cache Protection Middleware:**
|
||||
- Located in `/backend/backend/server/middleware/security.py`
|
||||
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
||||
- Uses an allow list approach - only explicitly permitted paths can be cached
|
||||
- Cacheable paths include: static assets (`/static/*`, `/_next/static/*`), health checks, public store pages, documentation
|
||||
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
|
||||
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
|
||||
- Applied to both main API server and external API applications
|
||||
4. Test with Playwright if user-facing
|
||||
@@ -62,12 +62,6 @@ To run the AutoGPT Platform, follow these steps:
|
||||
pnpm i
|
||||
```
|
||||
|
||||
Generate the API client (this step is required before running the frontend):
|
||||
|
||||
```
|
||||
pnpm generate:api-client
|
||||
```
|
||||
|
||||
Then start the frontend application in development mode:
|
||||
|
||||
```
|
||||
@@ -170,27 +164,3 @@ To persist data for PostgreSQL and Redis, you can modify the `docker-compose.yml
|
||||
3. Save the file and run `docker compose up -d` to apply the changes.
|
||||
|
||||
This configuration will create named volumes for PostgreSQL and Redis, ensuring that your data persists across container restarts.
|
||||
|
||||
### API Client Generation
|
||||
|
||||
The platform includes scripts for generating and managing the API client:
|
||||
|
||||
- `pnpm fetch:openapi`: Fetches the OpenAPI specification from the backend service (requires backend to be running on port 8006)
|
||||
- `pnpm generate:api-client`: Generates the TypeScript API client from the OpenAPI specification using Orval
|
||||
- `pnpm generate:api-all`: Runs both fetch and generate commands in sequence
|
||||
|
||||
#### Manual API Client Updates
|
||||
|
||||
If you need to update the API client after making changes to the backend API:
|
||||
|
||||
1. Ensure the backend services are running:
|
||||
```
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
2. Generate the updated API client:
|
||||
```
|
||||
pnpm generate:api-all
|
||||
```
|
||||
|
||||
This will fetch the latest OpenAPI specification and regenerate the TypeScript client code.
|
||||
|
||||
48
autogpt_platform/autogpt_libs/poetry.lock
generated
48
autogpt_platform/autogpt_libs/poetry.lock
generated
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aiohappyeyeballs"
|
||||
@@ -177,7 +177,7 @@ files = [
|
||||
{file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"},
|
||||
{file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"},
|
||||
]
|
||||
markers = {main = "python_version < \"3.11\"", dev = "python_full_version < \"3.11.3\""}
|
||||
markers = {main = "python_version == \"3.10\"", dev = "python_full_version < \"3.11.3\""}
|
||||
|
||||
[[package]]
|
||||
name = "attrs"
|
||||
@@ -390,7 +390,7 @@ description = "Backport of PEP 654 (exception groups)"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["main"]
|
||||
markers = "python_version < \"3.11\""
|
||||
markers = "python_version == \"3.10\""
|
||||
files = [
|
||||
{file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"},
|
||||
{file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"},
|
||||
@@ -1667,30 +1667,30 @@ pyasn1 = ">=0.1.3"
|
||||
|
||||
[[package]]
|
||||
name = "ruff"
|
||||
version = "0.12.2"
|
||||
version = "0.11.10"
|
||||
description = "An extremely fast Python linter and code formatter, written in Rust."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "ruff-0.12.2-py3-none-linux_armv6l.whl", hash = "sha256:093ea2b221df1d2b8e7ad92fc6ffdca40a2cb10d8564477a987b44fd4008a7be"},
|
||||
{file = "ruff-0.12.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:09e4cf27cc10f96b1708100fa851e0daf21767e9709e1649175355280e0d950e"},
|
||||
{file = "ruff-0.12.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:8ae64755b22f4ff85e9c52d1f82644abd0b6b6b6deedceb74bd71f35c24044cc"},
|
||||
{file = "ruff-0.12.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3eb3a6b2db4d6e2c77e682f0b988d4d61aff06860158fdb413118ca133d57922"},
|
||||
{file = "ruff-0.12.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:73448de992d05517170fc37169cbca857dfeaeaa8c2b9be494d7bcb0d36c8f4b"},
|
||||
{file = "ruff-0.12.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3b8b94317cbc2ae4a2771af641739f933934b03555e51515e6e021c64441532d"},
|
||||
{file = "ruff-0.12.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:45fc42c3bf1d30d2008023a0a9a0cfb06bf9835b147f11fe0679f21ae86d34b1"},
|
||||
{file = "ruff-0.12.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ce48f675c394c37e958bf229fb5c1e843e20945a6d962cf3ea20b7a107dcd9f4"},
|
||||
{file = "ruff-0.12.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:793d8859445ea47591272021a81391350205a4af65a9392401f418a95dfb75c9"},
|
||||
{file = "ruff-0.12.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6932323db80484dda89153da3d8e58164d01d6da86857c79f1961934354992da"},
|
||||
{file = "ruff-0.12.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:6aa7e623a3a11538108f61e859ebf016c4f14a7e6e4eba1980190cacb57714ce"},
|
||||
{file = "ruff-0.12.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2a4a20aeed74671b2def096bdf2eac610c7d8ffcbf4fb0e627c06947a1d7078d"},
|
||||
{file = "ruff-0.12.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:71a4c550195612f486c9d1f2b045a600aeba851b298c667807ae933478fcef04"},
|
||||
{file = "ruff-0.12.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:4987b8f4ceadf597c927beee65a5eaf994c6e2b631df963f86d8ad1bdea99342"},
|
||||
{file = "ruff-0.12.2-py3-none-win32.whl", hash = "sha256:369ffb69b70cd55b6c3fc453b9492d98aed98062db9fec828cdfd069555f5f1a"},
|
||||
{file = "ruff-0.12.2-py3-none-win_amd64.whl", hash = "sha256:dca8a3b6d6dc9810ed8f328d406516bf4d660c00caeaef36eb831cf4871b0639"},
|
||||
{file = "ruff-0.12.2-py3-none-win_arm64.whl", hash = "sha256:48d6c6bfb4761df68bc05ae630e24f506755e702d4fb08f08460be778c7ccb12"},
|
||||
{file = "ruff-0.12.2.tar.gz", hash = "sha256:d7b4f55cd6f325cb7621244f19c873c565a08aff5a4ba9c69aa7355f3f7afd3e"},
|
||||
{file = "ruff-0.11.10-py3-none-linux_armv6l.whl", hash = "sha256:859a7bfa7bc8888abbea31ef8a2b411714e6a80f0d173c2a82f9041ed6b50f58"},
|
||||
{file = "ruff-0.11.10-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:968220a57e09ea5e4fd48ed1c646419961a0570727c7e069842edd018ee8afed"},
|
||||
{file = "ruff-0.11.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:1067245bad978e7aa7b22f67113ecc6eb241dca0d9b696144256c3a879663bca"},
|
||||
{file = "ruff-0.11.10-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4854fd09c7aed5b1590e996a81aeff0c9ff51378b084eb5a0b9cd9518e6cff2"},
|
||||
{file = "ruff-0.11.10-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8b4564e9f99168c0f9195a0fd5fa5928004b33b377137f978055e40008a082c5"},
|
||||
{file = "ruff-0.11.10-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5b6a9cc5b62c03cc1fea0044ed8576379dbaf751d5503d718c973d5418483641"},
|
||||
{file = "ruff-0.11.10-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:607ecbb6f03e44c9e0a93aedacb17b4eb4f3563d00e8b474298a201622677947"},
|
||||
{file = "ruff-0.11.10-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7b3a522fa389402cd2137df9ddefe848f727250535c70dafa840badffb56b7a4"},
|
||||
{file = "ruff-0.11.10-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2f071b0deed7e9245d5820dac235cbdd4ef99d7b12ff04c330a241ad3534319f"},
|
||||
{file = "ruff-0.11.10-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a60e3a0a617eafba1f2e4186d827759d65348fa53708ca547e384db28406a0b"},
|
||||
{file = "ruff-0.11.10-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:da8ec977eaa4b7bf75470fb575bea2cb41a0e07c7ea9d5a0a97d13dbca697bf2"},
|
||||
{file = "ruff-0.11.10-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:ddf8967e08227d1bd95cc0851ef80d2ad9c7c0c5aab1eba31db49cf0a7b99523"},
|
||||
{file = "ruff-0.11.10-py3-none-musllinux_1_2_i686.whl", hash = "sha256:5a94acf798a82db188f6f36575d80609072b032105d114b0f98661e1679c9125"},
|
||||
{file = "ruff-0.11.10-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:3afead355f1d16d95630df28d4ba17fb2cb9c8dfac8d21ced14984121f639bad"},
|
||||
{file = "ruff-0.11.10-py3-none-win32.whl", hash = "sha256:dc061a98d32a97211af7e7f3fa1d4ca2fcf919fb96c28f39551f35fc55bdbc19"},
|
||||
{file = "ruff-0.11.10-py3-none-win_amd64.whl", hash = "sha256:5cc725fbb4d25b0f185cb42df07ab6b76c4489b4bfb740a175f3a59c70e8a224"},
|
||||
{file = "ruff-0.11.10-py3-none-win_arm64.whl", hash = "sha256:ef69637b35fb8b210743926778d0e45e1bffa850a7c61e428c6b971549b5f5d1"},
|
||||
{file = "ruff-0.11.10.tar.gz", hash = "sha256:d522fb204b4959909ecac47da02830daec102eeb100fb50ea9554818d47a5fa6"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1823,7 +1823,7 @@ description = "A lil' TOML parser"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
markers = "python_version < \"3.11\""
|
||||
markers = "python_version == \"3.10\""
|
||||
files = [
|
||||
{file = "tomli-2.1.0-py3-none-any.whl", hash = "sha256:a5c57c3d1c56f5ccdf89f6523458f60ef716e210fc47c4cfb188c5ba473e0391"},
|
||||
{file = "tomli-2.1.0.tar.gz", hash = "sha256:3f646cae2aec94e17d04973e4249548320197cfabdf130015d023de4b74d8ab8"},
|
||||
@@ -2176,4 +2176,4 @@ type = ["pytest-mypy"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<4.0"
|
||||
content-hash = "574057127b05f28c2ae39f7b11aa0d7c52f857655e9223e23a27c9989b2ac10f"
|
||||
content-hash = "d92143928a88ca3a56ac200c335910eafac938940022fed8bd0d17c95040b54f"
|
||||
|
||||
@@ -23,7 +23,7 @@ uvicorn = "^0.34.3"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
redis = "^5.2.1"
|
||||
ruff = "^0.12.2"
|
||||
ruff = "^0.11.10"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
@@ -20,7 +20,7 @@ def load_all_blocks() -> dict[str, type["Block"]]:
|
||||
modules = [
|
||||
str(f.relative_to(current_dir))[:-3].replace(os.path.sep, ".")
|
||||
for f in current_dir.rglob("*.py")
|
||||
if f.is_file() and f.name != "__init__.py" and not f.name.startswith("test_")
|
||||
if f.is_file() and f.name != "__init__.py"
|
||||
]
|
||||
for module in modules:
|
||||
if not re.match("^[a-z0-9_.]+$", module):
|
||||
|
||||
@@ -2,8 +2,6 @@ import asyncio
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
@@ -14,10 +12,10 @@ from backend.data.block import (
|
||||
get_block,
|
||||
)
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util import json, retry
|
||||
from backend.data.model import CredentialsMetaInput, SchemaField
|
||||
from backend.util import json
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentExecutorBlock(Block):
|
||||
@@ -30,9 +28,9 @@ class AgentExecutorBlock(Block):
|
||||
input_schema: dict = SchemaField(description="Input schema for the graph")
|
||||
output_schema: dict = SchemaField(description="Output schema for the graph")
|
||||
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = SchemaField(
|
||||
default=None, hidden=True
|
||||
)
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = SchemaField(default=None, hidden=True)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
|
||||
@@ -73,46 +71,31 @@ class AgentExecutorBlock(Block):
|
||||
graph_version=input_data.graph_version,
|
||||
user_id=input_data.user_id,
|
||||
inputs=input_data.inputs,
|
||||
nodes_input_masks=input_data.nodes_input_masks,
|
||||
node_credentials_input_map=input_data.node_credentials_input_map,
|
||||
use_db_query=False,
|
||||
)
|
||||
|
||||
logger = execution_utils.LogMetadata(
|
||||
logger=_logger,
|
||||
user_id=input_data.user_id,
|
||||
graph_eid=graph_exec.id,
|
||||
graph_id=input_data.graph_id,
|
||||
node_eid="*",
|
||||
node_id="*",
|
||||
block_name=self.name,
|
||||
)
|
||||
|
||||
try:
|
||||
async for name, data in self._run(
|
||||
graph_id=input_data.graph_id,
|
||||
graph_version=input_data.graph_version,
|
||||
graph_exec_id=graph_exec.id,
|
||||
user_id=input_data.user_id,
|
||||
logger=logger,
|
||||
):
|
||||
yield name, data
|
||||
except asyncio.CancelledError:
|
||||
await self._stop(
|
||||
graph_exec_id=graph_exec.id,
|
||||
user_id=input_data.user_id,
|
||||
logger=logger,
|
||||
)
|
||||
logger.warning(
|
||||
f"Execution of graph {input_data.graph_id}v{input_data.graph_version} was cancelled."
|
||||
f"Execution of graph {input_data.graph_id} version {input_data.graph_version} was cancelled."
|
||||
)
|
||||
await execution_utils.stop_graph_execution(
|
||||
graph_exec.id, use_db_query=False
|
||||
)
|
||||
except Exception as e:
|
||||
await self._stop(
|
||||
graph_exec_id=graph_exec.id,
|
||||
user_id=input_data.user_id,
|
||||
logger=logger,
|
||||
)
|
||||
logger.error(
|
||||
f"Execution of graph {input_data.graph_id}v{input_data.graph_version} failed: {e}, execution is stopped."
|
||||
f"Execution of graph {input_data.graph_id} version {input_data.graph_version} failed: {e}, stopping execution."
|
||||
)
|
||||
await execution_utils.stop_graph_execution(
|
||||
graph_exec.id, use_db_query=False
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -122,7 +105,6 @@ class AgentExecutorBlock(Block):
|
||||
graph_version: int,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
logger,
|
||||
) -> BlockOutput:
|
||||
|
||||
from backend.data.execution import ExecutionEventType
|
||||
@@ -175,25 +157,3 @@ class AgentExecutorBlock(Block):
|
||||
f"Execution {log_id} produced {output_name}: {output_data}"
|
||||
)
|
||||
yield output_name, output_data
|
||||
|
||||
@retry.func_retry
|
||||
async def _stop(
|
||||
self,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
logger,
|
||||
) -> None:
|
||||
from backend.executor import utils as execution_utils
|
||||
|
||||
log_id = f"Graph exec-id: {graph_exec_id}"
|
||||
logger.info(f"Stopping execution of {log_id}")
|
||||
|
||||
try:
|
||||
await execution_utils.stop_graph_execution(
|
||||
graph_exec_id=graph_exec_id,
|
||||
user_id=user_id,
|
||||
use_db_query=False,
|
||||
)
|
||||
logger.info(f"Execution {log_id} stopped successfully.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to stop execution {log_id}: {e}")
|
||||
|
||||
@@ -53,7 +53,6 @@ class AudioTrack(str, Enum):
|
||||
REFRESHER = ("Refresher",)
|
||||
TOURIST = ("Tourist",)
|
||||
TWIN_TYCHES = ("Twin Tyches",)
|
||||
DONT_STOP_ME_ABSTRACT_FUTURE_BASS = ("Dont Stop Me Abstract Future Bass",)
|
||||
|
||||
@property
|
||||
def audio_url(self):
|
||||
@@ -79,7 +78,6 @@ class AudioTrack(str, Enum):
|
||||
AudioTrack.REFRESHER: "https://cdn.tfrv.xyz/audio/refresher.mp3",
|
||||
AudioTrack.TOURIST: "https://cdn.tfrv.xyz/audio/tourist.mp3",
|
||||
AudioTrack.TWIN_TYCHES: "https://cdn.tfrv.xyz/audio/twin-tynches.mp3",
|
||||
AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS: "https://cdn.revid.ai/audio/_dont-stop-me-abstract-future-bass.mp3",
|
||||
}
|
||||
return audio_urls[self]
|
||||
|
||||
@@ -107,7 +105,6 @@ class GenerationPreset(str, Enum):
|
||||
MOVIE = ("Movie",)
|
||||
STYLIZED_ILLUSTRATION = ("Stylized Illustration",)
|
||||
MANGA = ("Manga",)
|
||||
DEFAULT = ("DEFAULT",)
|
||||
|
||||
|
||||
class Voice(str, Enum):
|
||||
@@ -117,7 +114,6 @@ class Voice(str, Enum):
|
||||
JESSICA = "Jessica"
|
||||
CHARLOTTE = "Charlotte"
|
||||
CALLUM = "Callum"
|
||||
EVA = "Eva"
|
||||
|
||||
@property
|
||||
def voice_id(self):
|
||||
@@ -128,7 +124,6 @@ class Voice(str, Enum):
|
||||
Voice.JESSICA: "cgSgspJ2msm6clMCkdW9",
|
||||
Voice.CHARLOTTE: "XB0fDUnXU5powFXDhCwa",
|
||||
Voice.CALLUM: "N2lVS1w4EtoT3dr4eOWO",
|
||||
Voice.EVA: "FGY2WhTYpPnrIDTdsKH5",
|
||||
}
|
||||
return voice_id_map[self]
|
||||
|
||||
@@ -146,8 +141,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AIShortformVideoCreatorBlock(Block):
|
||||
"""Creates a short‑form text‑to‑video clip using stock or AI imagery."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REVID], Literal["api_key"]
|
||||
@@ -191,58 +184,6 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
video_url: str = SchemaField(description="The URL of the created video")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
async def create_webhook(self) -> tuple[str, str]:
|
||||
"""Create a new webhook URL for receiving notifications."""
|
||||
url = "https://webhook.site/token"
|
||||
headers = {"Accept": "application/json", "Content-Type": "application/json"}
|
||||
response = await Requests().post(url, headers=headers)
|
||||
webhook_data = response.json()
|
||||
return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"
|
||||
|
||||
async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
|
||||
"""Create a video using the Revid API."""
|
||||
url = "https://www.revid.ai/api/public/v2/render"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = await Requests().post(url, json=payload, headers=headers)
|
||||
logger.debug(
|
||||
f"API Response Status Code: {response.status}, Content: {response.text}"
|
||||
)
|
||||
return response.json()
|
||||
|
||||
async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
|
||||
"""Check the status of a video creation job."""
|
||||
url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = await Requests().get(url, headers=headers)
|
||||
return response.json()
|
||||
|
||||
async def wait_for_video(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
pid: str,
|
||||
max_wait_time: int = 1000,
|
||||
) -> str:
|
||||
"""Wait for video creation to complete and return the video URL."""
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < max_wait_time:
|
||||
status = await self.check_video_status(api_key, pid)
|
||||
logger.debug(f"Video status: {status}")
|
||||
|
||||
if status.get("status") == "ready" and "videoUrl" in status:
|
||||
return status["videoUrl"]
|
||||
elif status.get("status") == "error":
|
||||
error_message = status.get("error", "Unknown error occurred")
|
||||
logger.error(f"Video creation failed: {error_message}")
|
||||
raise ValueError(f"Video creation failed: {error_message}")
|
||||
elif status.get("status") in ["FAILED", "CANCELED"]:
|
||||
logger.error(f"Video creation failed: {status.get('message')}")
|
||||
raise ValueError(f"Video creation failed: {status.get('message')}")
|
||||
|
||||
await asyncio.sleep(10)
|
||||
|
||||
logger.error("Video creation timed out")
|
||||
raise TimeoutError("Video creation timed out")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="361697fb-0c4f-4feb-aed3-8320c88c771b",
|
||||
@@ -261,22 +202,70 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
"voice": Voice.LILY,
|
||||
"video_style": VisualMediaType.STOCK_VIDEOS,
|
||||
},
|
||||
test_output=("video_url", "https://example.com/video.mp4"),
|
||||
test_output=(
|
||||
"video_url",
|
||||
"https://example.com/video.mp4",
|
||||
),
|
||||
test_mock={
|
||||
"create_webhook": lambda *args, **kwargs: (
|
||||
"create_webhook": lambda: (
|
||||
"test_uuid",
|
||||
"https://webhook.site/test_uuid",
|
||||
),
|
||||
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
||||
"check_video_status": lambda *args, **kwargs: {
|
||||
"status": "ready",
|
||||
"videoUrl": "https://example.com/video.mp4",
|
||||
},
|
||||
"wait_for_video": lambda *args, **kwargs: "https://example.com/video.mp4",
|
||||
"create_video": lambda api_key, payload: {"pid": "test_pid"},
|
||||
"wait_for_video": lambda api_key, pid, webhook_token, max_wait_time=1000: "https://example.com/video.mp4",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def create_webhook(self):
|
||||
url = "https://webhook.site/token"
|
||||
headers = {"Accept": "application/json", "Content-Type": "application/json"}
|
||||
response = await Requests().post(url, headers=headers)
|
||||
webhook_data = response.json()
|
||||
return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"
|
||||
|
||||
async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
|
||||
url = "https://www.revid.ai/api/public/v2/render"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = await Requests().post(url, json=payload, headers=headers)
|
||||
logger.debug(
|
||||
f"API Response Status Code: {response.status}, Content: {response.text}"
|
||||
)
|
||||
return response.json()
|
||||
|
||||
async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
|
||||
url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = await Requests().get(url, headers=headers)
|
||||
return response.json()
|
||||
|
||||
async def wait_for_video(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
pid: str,
|
||||
webhook_token: str,
|
||||
max_wait_time: int = 1000,
|
||||
) -> str:
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < max_wait_time:
|
||||
status = await self.check_video_status(api_key, pid)
|
||||
logger.debug(f"Video status: {status}")
|
||||
|
||||
if status.get("status") == "ready" and "videoUrl" in status:
|
||||
return status["videoUrl"]
|
||||
elif status.get("status") == "error":
|
||||
error_message = status.get("error", "Unknown error occurred")
|
||||
logger.error(f"Video creation failed: {error_message}")
|
||||
raise ValueError(f"Video creation failed: {error_message}")
|
||||
elif status.get("status") in ["FAILED", "CANCELED"]:
|
||||
logger.error(f"Video creation failed: {status.get('message')}")
|
||||
raise ValueError(f"Video creation failed: {status.get('message')}")
|
||||
|
||||
await asyncio.sleep(10)
|
||||
|
||||
logger.error("Video creation timed out")
|
||||
raise TimeoutError("Video creation timed out")
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
@@ -284,18 +273,20 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
webhook_token, webhook_url = await self.create_webhook()
|
||||
logger.debug(f"Webhook URL: {webhook_url}")
|
||||
|
||||
audio_url = input_data.background_music.audio_url
|
||||
|
||||
payload = {
|
||||
"frameRate": input_data.frame_rate,
|
||||
"resolution": input_data.resolution,
|
||||
"frameDurationMultiplier": 18,
|
||||
"webhook": None,
|
||||
"webhook": webhook_url,
|
||||
"creationParams": {
|
||||
"mediaType": input_data.video_style,
|
||||
"captionPresetName": "Wrap 1",
|
||||
"selectedVoice": input_data.voice.voice_id,
|
||||
"hasEnhancedGeneration": True,
|
||||
"generationPreset": input_data.generation_preset.name,
|
||||
"selectedAudio": input_data.background_music.value,
|
||||
"selectedAudio": input_data.background_music,
|
||||
"origin": "/create",
|
||||
"inputText": input_data.script,
|
||||
"flowType": "text-to-video",
|
||||
@@ -311,7 +302,7 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
"selectedStoryStyle": {"value": "custom", "label": "Custom"},
|
||||
"hasToGenerateVideos": input_data.video_style
|
||||
!= VisualMediaType.STOCK_VIDEOS,
|
||||
"audioUrl": input_data.background_music.audio_url,
|
||||
"audioUrl": audio_url,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -328,370 +319,8 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
logger.debug(
|
||||
f"Video created with project ID: {pid}. Waiting for completion..."
|
||||
)
|
||||
video_url = await self.wait_for_video(credentials.api_key, pid)
|
||||
video_url = await self.wait_for_video(
|
||||
credentials.api_key, pid, webhook_token
|
||||
)
|
||||
logger.debug(f"Video ready: {video_url}")
|
||||
yield "video_url", video_url
|
||||
|
||||
|
||||
class AIAdMakerVideoCreatorBlock(Block):
|
||||
"""Generates a 30‑second vertical AI advert using optional user‑supplied imagery."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REVID], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description="Credentials for Revid.ai API access.",
|
||||
)
|
||||
script: str = SchemaField(
|
||||
description="Short advertising copy. Line breaks create new scenes.",
|
||||
placeholder="Introducing Foobar – [show product photo] the gadget that does it all.",
|
||||
)
|
||||
ratio: str = SchemaField(description="Aspect ratio", default="9 / 16")
|
||||
target_duration: int = SchemaField(
|
||||
description="Desired length of the ad in seconds.", default=30
|
||||
)
|
||||
voice: Voice = SchemaField(
|
||||
description="Narration voice", default=Voice.EVA, placeholder=Voice.EVA
|
||||
)
|
||||
background_music: AudioTrack = SchemaField(
|
||||
description="Background track",
|
||||
default=AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS,
|
||||
)
|
||||
input_media_urls: list[str] = SchemaField(
|
||||
description="List of image URLs to feature in the advert.", default=[]
|
||||
)
|
||||
use_only_provided_media: bool = SchemaField(
|
||||
description="Restrict visuals to supplied images only.", default=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
video_url: str = SchemaField(description="URL of the finished advert")
|
||||
error: str = SchemaField(description="Error message on failure")
|
||||
|
||||
async def create_webhook(self) -> tuple[str, str]:
|
||||
"""Create a new webhook URL for receiving notifications."""
|
||||
url = "https://webhook.site/token"
|
||||
headers = {"Accept": "application/json", "Content-Type": "application/json"}
|
||||
response = await Requests().post(url, headers=headers)
|
||||
webhook_data = response.json()
|
||||
return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"
|
||||
|
||||
async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
|
||||
"""Create a video using the Revid API."""
|
||||
url = "https://www.revid.ai/api/public/v2/render"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = await Requests().post(url, json=payload, headers=headers)
|
||||
logger.debug(
|
||||
f"API Response Status Code: {response.status}, Content: {response.text}"
|
||||
)
|
||||
return response.json()
|
||||
|
||||
async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
|
||||
"""Check the status of a video creation job."""
|
||||
url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = await Requests().get(url, headers=headers)
|
||||
return response.json()
|
||||
|
||||
async def wait_for_video(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
pid: str,
|
||||
max_wait_time: int = 1000,
|
||||
) -> str:
|
||||
"""Wait for video creation to complete and return the video URL."""
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < max_wait_time:
|
||||
status = await self.check_video_status(api_key, pid)
|
||||
logger.debug(f"Video status: {status}")
|
||||
|
||||
if status.get("status") == "ready" and "videoUrl" in status:
|
||||
return status["videoUrl"]
|
||||
elif status.get("status") == "error":
|
||||
error_message = status.get("error", "Unknown error occurred")
|
||||
logger.error(f"Video creation failed: {error_message}")
|
||||
raise ValueError(f"Video creation failed: {error_message}")
|
||||
elif status.get("status") in ["FAILED", "CANCELED"]:
|
||||
logger.error(f"Video creation failed: {status.get('message')}")
|
||||
raise ValueError(f"Video creation failed: {status.get('message')}")
|
||||
|
||||
await asyncio.sleep(10)
|
||||
|
||||
logger.error("Video creation timed out")
|
||||
raise TimeoutError("Video creation timed out")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="58bd2a19-115d-4fd1-8ca4-13b9e37fa6a0",
|
||||
description="Creates an AI‑generated 30‑second advert (text + images)",
|
||||
categories={BlockCategory.MARKETING, BlockCategory.AI},
|
||||
input_schema=AIAdMakerVideoCreatorBlock.Input,
|
||||
output_schema=AIAdMakerVideoCreatorBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"script": "Test product launch!",
|
||||
"input_media_urls": [
|
||||
"https://cdn.revid.ai/uploads/1747076315114-image.png",
|
||||
],
|
||||
},
|
||||
test_output=("video_url", "https://example.com/ad.mp4"),
|
||||
test_mock={
|
||||
"create_webhook": lambda *args, **kwargs: (
|
||||
"test_uuid",
|
||||
"https://webhook.site/test_uuid",
|
||||
),
|
||||
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
||||
"check_video_status": lambda *args, **kwargs: {
|
||||
"status": "ready",
|
||||
"videoUrl": "https://example.com/ad.mp4",
|
||||
},
|
||||
"wait_for_video": lambda *args, **kwargs: "https://example.com/ad.mp4",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
|
||||
webhook_token, webhook_url = await self.create_webhook()
|
||||
|
||||
payload = {
|
||||
"webhook": webhook_url,
|
||||
"creationParams": {
|
||||
"targetDuration": input_data.target_duration,
|
||||
"ratio": input_data.ratio,
|
||||
"mediaType": "aiVideo",
|
||||
"inputText": input_data.script,
|
||||
"flowType": "text-to-video",
|
||||
"slug": "ai-ad-generator",
|
||||
"slugNew": "",
|
||||
"isCopiedFrom": False,
|
||||
"hasToGenerateVoice": True,
|
||||
"hasToTranscript": False,
|
||||
"hasToSearchMedia": True,
|
||||
"hasAvatar": False,
|
||||
"hasWebsiteRecorder": False,
|
||||
"hasTextSmallAtBottom": False,
|
||||
"selectedAudio": input_data.background_music.value,
|
||||
"selectedVoice": input_data.voice.voice_id,
|
||||
"selectedAvatar": "https://cdn.revid.ai/avatars/young-woman.mp4",
|
||||
"selectedAvatarType": "video/mp4",
|
||||
"websiteToRecord": "",
|
||||
"hasToGenerateCover": True,
|
||||
"nbGenerations": 1,
|
||||
"disableCaptions": False,
|
||||
"mediaMultiplier": "medium",
|
||||
"characters": [],
|
||||
"captionPresetName": "Revid",
|
||||
"sourceType": "contentScraping",
|
||||
"selectedStoryStyle": {"value": "custom", "label": "General"},
|
||||
"generationPreset": "DEFAULT",
|
||||
"hasToGenerateMusic": False,
|
||||
"isOptimizedForChinese": False,
|
||||
"generationUserPrompt": "",
|
||||
"enableNsfwFilter": False,
|
||||
"addStickers": False,
|
||||
"typeMovingImageAnim": "dynamic",
|
||||
"hasToGenerateSoundEffects": False,
|
||||
"forceModelType": "gpt-image-1",
|
||||
"selectedCharacters": [],
|
||||
"lang": "",
|
||||
"voiceSpeed": 1,
|
||||
"disableAudio": False,
|
||||
"disableVoice": False,
|
||||
"useOnlyProvidedMedia": input_data.use_only_provided_media,
|
||||
"imageGenerationModel": "ultra",
|
||||
"videoGenerationModel": "pro",
|
||||
"hasEnhancedGeneration": True,
|
||||
"hasEnhancedGenerationPro": True,
|
||||
"inputMedias": [
|
||||
{"url": url, "title": "", "type": "image"}
|
||||
for url in input_data.input_media_urls
|
||||
],
|
||||
"hasToGenerateVideos": True,
|
||||
"audioUrl": input_data.background_music.audio_url,
|
||||
"watermark": None,
|
||||
},
|
||||
}
|
||||
|
||||
response = await self.create_video(credentials.api_key, payload)
|
||||
pid = response.get("pid")
|
||||
if not pid:
|
||||
raise RuntimeError("Failed to create video: No project ID returned")
|
||||
|
||||
video_url = await self.wait_for_video(credentials.api_key, pid)
|
||||
yield "video_url", video_url
|
||||
|
||||
|
||||
class AIScreenshotToVideoAdBlock(Block):
|
||||
"""Creates an advert where the supplied screenshot is narrated by an AI avatar."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REVID], Literal["api_key"]
|
||||
] = CredentialsField(description="Revid.ai API key")
|
||||
script: str = SchemaField(
|
||||
description="Narration that will accompany the screenshot.",
|
||||
placeholder="Check out these amazing stats!",
|
||||
)
|
||||
screenshot_url: str = SchemaField(
|
||||
description="Screenshot or image URL to showcase."
|
||||
)
|
||||
ratio: str = SchemaField(default="9 / 16")
|
||||
target_duration: int = SchemaField(default=30)
|
||||
voice: Voice = SchemaField(default=Voice.EVA)
|
||||
background_music: AudioTrack = SchemaField(
|
||||
default=AudioTrack.DONT_STOP_ME_ABSTRACT_FUTURE_BASS
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
video_url: str = SchemaField(description="Rendered video URL")
|
||||
error: str = SchemaField(description="Error, if encountered")
|
||||
|
||||
async def create_webhook(self) -> tuple[str, str]:
|
||||
"""Create a new webhook URL for receiving notifications."""
|
||||
url = "https://webhook.site/token"
|
||||
headers = {"Accept": "application/json", "Content-Type": "application/json"}
|
||||
response = await Requests().post(url, headers=headers)
|
||||
webhook_data = response.json()
|
||||
return webhook_data["uuid"], f"https://webhook.site/{webhook_data['uuid']}"
|
||||
|
||||
async def create_video(self, api_key: SecretStr, payload: dict) -> dict:
|
||||
"""Create a video using the Revid API."""
|
||||
url = "https://www.revid.ai/api/public/v2/render"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = await Requests().post(url, json=payload, headers=headers)
|
||||
logger.debug(
|
||||
f"API Response Status Code: {response.status}, Content: {response.text}"
|
||||
)
|
||||
return response.json()
|
||||
|
||||
async def check_video_status(self, api_key: SecretStr, pid: str) -> dict:
|
||||
"""Check the status of a video creation job."""
|
||||
url = f"https://www.revid.ai/api/public/v2/status?pid={pid}"
|
||||
headers = {"key": api_key.get_secret_value()}
|
||||
response = await Requests().get(url, headers=headers)
|
||||
return response.json()
|
||||
|
||||
async def wait_for_video(
|
||||
self,
|
||||
api_key: SecretStr,
|
||||
pid: str,
|
||||
max_wait_time: int = 1000,
|
||||
) -> str:
|
||||
"""Wait for video creation to complete and return the video URL."""
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < max_wait_time:
|
||||
status = await self.check_video_status(api_key, pid)
|
||||
logger.debug(f"Video status: {status}")
|
||||
|
||||
if status.get("status") == "ready" and "videoUrl" in status:
|
||||
return status["videoUrl"]
|
||||
elif status.get("status") == "error":
|
||||
error_message = status.get("error", "Unknown error occurred")
|
||||
logger.error(f"Video creation failed: {error_message}")
|
||||
raise ValueError(f"Video creation failed: {error_message}")
|
||||
elif status.get("status") in ["FAILED", "CANCELED"]:
|
||||
logger.error(f"Video creation failed: {status.get('message')}")
|
||||
raise ValueError(f"Video creation failed: {status.get('message')}")
|
||||
|
||||
await asyncio.sleep(10)
|
||||
|
||||
logger.error("Video creation timed out")
|
||||
raise TimeoutError("Video creation timed out")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0f3e4635-e810-43d9-9e81-49e6f4e83b7c",
|
||||
description="Turns a screenshot into an engaging, avatar‑narrated video advert.",
|
||||
categories={BlockCategory.AI, BlockCategory.MARKETING},
|
||||
input_schema=AIScreenshotToVideoAdBlock.Input,
|
||||
output_schema=AIScreenshotToVideoAdBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"script": "Amazing numbers!",
|
||||
"screenshot_url": "https://cdn.revid.ai/uploads/1747080376028-image.png",
|
||||
},
|
||||
test_output=("video_url", "https://example.com/screenshot.mp4"),
|
||||
test_mock={
|
||||
"create_webhook": lambda *args, **kwargs: (
|
||||
"test_uuid",
|
||||
"https://webhook.site/test_uuid",
|
||||
),
|
||||
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
||||
"check_video_status": lambda *args, **kwargs: {
|
||||
"status": "ready",
|
||||
"videoUrl": "https://example.com/screenshot.mp4",
|
||||
},
|
||||
"wait_for_video": lambda *args, **kwargs: "https://example.com/screenshot.mp4",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
|
||||
webhook_token, webhook_url = await self.create_webhook()
|
||||
|
||||
payload = {
|
||||
"webhook": webhook_url,
|
||||
"creationParams": {
|
||||
"targetDuration": input_data.target_duration,
|
||||
"ratio": input_data.ratio,
|
||||
"mediaType": "aiVideo",
|
||||
"hasAvatar": True,
|
||||
"removeAvatarBackground": True,
|
||||
"inputText": input_data.script,
|
||||
"flowType": "text-to-video",
|
||||
"slug": "ai-ad-generator",
|
||||
"slugNew": "screenshot-to-video-ad",
|
||||
"isCopiedFrom": "ai-ad-generator",
|
||||
"hasToGenerateVoice": True,
|
||||
"hasToTranscript": False,
|
||||
"hasToSearchMedia": True,
|
||||
"hasWebsiteRecorder": False,
|
||||
"hasTextSmallAtBottom": False,
|
||||
"selectedAudio": input_data.background_music.value,
|
||||
"selectedVoice": input_data.voice.voice_id,
|
||||
"selectedAvatar": "https://cdn.revid.ai/avatars/young-woman.mp4",
|
||||
"selectedAvatarType": "video/mp4",
|
||||
"websiteToRecord": "",
|
||||
"hasToGenerateCover": True,
|
||||
"nbGenerations": 1,
|
||||
"disableCaptions": False,
|
||||
"mediaMultiplier": "medium",
|
||||
"characters": [],
|
||||
"captionPresetName": "Revid",
|
||||
"sourceType": "contentScraping",
|
||||
"selectedStoryStyle": {"value": "custom", "label": "General"},
|
||||
"generationPreset": "DEFAULT",
|
||||
"hasToGenerateMusic": False,
|
||||
"isOptimizedForChinese": False,
|
||||
"generationUserPrompt": "",
|
||||
"enableNsfwFilter": False,
|
||||
"addStickers": False,
|
||||
"typeMovingImageAnim": "dynamic",
|
||||
"hasToGenerateSoundEffects": False,
|
||||
"forceModelType": "gpt-image-1",
|
||||
"selectedCharacters": [],
|
||||
"lang": "",
|
||||
"voiceSpeed": 1,
|
||||
"disableAudio": False,
|
||||
"disableVoice": False,
|
||||
"useOnlyProvidedMedia": True,
|
||||
"imageGenerationModel": "ultra",
|
||||
"videoGenerationModel": "ultra",
|
||||
"hasEnhancedGeneration": True,
|
||||
"hasEnhancedGenerationPro": True,
|
||||
"inputMedias": [
|
||||
{"url": input_data.screenshot_url, "title": "", "type": "image"}
|
||||
],
|
||||
"hasToGenerateVideos": True,
|
||||
"audioUrl": input_data.background_music.audio_url,
|
||||
"watermark": None,
|
||||
},
|
||||
}
|
||||
|
||||
response = await self.create_video(credentials.api_key, payload)
|
||||
pid = response.get("pid")
|
||||
if not pid:
|
||||
raise RuntimeError("Failed to create video: No project ID returned")
|
||||
|
||||
video_url = await self.wait_for_video(credentials.api_key, pid)
|
||||
yield "video_url", video_url
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import List
|
||||
from backend.blocks.apollo._auth import ApolloCredentials
|
||||
from backend.blocks.apollo.models import (
|
||||
Contact,
|
||||
EnrichPersonRequest,
|
||||
Organization,
|
||||
SearchOrganizationsRequest,
|
||||
SearchOrganizationsResponse,
|
||||
@@ -30,10 +29,10 @@ class ApolloClient:
|
||||
|
||||
async def search_people(self, query: SearchPeopleRequest) -> List[Contact]:
|
||||
"""Search for people in Apollo"""
|
||||
response = await self.requests.post(
|
||||
response = await self.requests.get(
|
||||
f"{self.API_URL}/mixed_people/search",
|
||||
headers=self._get_headers(),
|
||||
json=query.model_dump(exclude={"max_results"}),
|
||||
params=query.model_dump(exclude={"credentials", "max_results"}),
|
||||
)
|
||||
data = response.json()
|
||||
parsed_response = SearchPeopleResponse(**data)
|
||||
@@ -54,10 +53,10 @@ class ApolloClient:
|
||||
and len(parsed_response.people) > 0
|
||||
):
|
||||
query.page += 1
|
||||
response = await self.requests.post(
|
||||
response = await self.requests.get(
|
||||
f"{self.API_URL}/mixed_people/search",
|
||||
headers=self._get_headers(),
|
||||
json=query.model_dump(exclude={"max_results"}),
|
||||
params=query.model_dump(exclude={"credentials", "max_results"}),
|
||||
)
|
||||
data = response.json()
|
||||
parsed_response = SearchPeopleResponse(**data)
|
||||
@@ -70,10 +69,10 @@ class ApolloClient:
|
||||
self, query: SearchOrganizationsRequest
|
||||
) -> List[Organization]:
|
||||
"""Search for organizations in Apollo"""
|
||||
response = await self.requests.post(
|
||||
response = await self.requests.get(
|
||||
f"{self.API_URL}/mixed_companies/search",
|
||||
headers=self._get_headers(),
|
||||
json=query.model_dump(exclude={"max_results"}),
|
||||
params=query.model_dump(exclude={"credentials", "max_results"}),
|
||||
)
|
||||
data = response.json()
|
||||
parsed_response = SearchOrganizationsResponse(**data)
|
||||
@@ -94,10 +93,10 @@ class ApolloClient:
|
||||
and len(parsed_response.organizations) > 0
|
||||
):
|
||||
query.page += 1
|
||||
response = await self.requests.post(
|
||||
response = await self.requests.get(
|
||||
f"{self.API_URL}/mixed_companies/search",
|
||||
headers=self._get_headers(),
|
||||
json=query.model_dump(exclude={"max_results"}),
|
||||
params=query.model_dump(exclude={"credentials", "max_results"}),
|
||||
)
|
||||
data = response.json()
|
||||
parsed_response = SearchOrganizationsResponse(**data)
|
||||
@@ -111,21 +110,3 @@ class ApolloClient:
|
||||
return (
|
||||
organizations[: query.max_results] if query.max_results else organizations
|
||||
)
|
||||
|
||||
async def enrich_person(self, query: EnrichPersonRequest) -> Contact:
|
||||
"""Enrich a person's data including email & phone reveal"""
|
||||
response = await self.requests.post(
|
||||
f"{self.API_URL}/people/match",
|
||||
headers=self._get_headers(),
|
||||
json=query.model_dump(),
|
||||
params={
|
||||
"reveal_personal_emails": "true",
|
||||
},
|
||||
)
|
||||
data = response.json()
|
||||
if "person" not in data:
|
||||
raise ValueError(f"Person not found or enrichment failed: {data}")
|
||||
|
||||
contact = Contact(**data["person"])
|
||||
contact.email = contact.email or "-"
|
||||
return contact
|
||||
|
||||
@@ -1,31 +1,17 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel as OriginalBaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class BaseModel(OriginalBaseModel):
|
||||
def model_dump(self, *args, exclude: set[str] | None = None, **kwargs):
|
||||
if exclude is None:
|
||||
exclude = set("credentials")
|
||||
else:
|
||||
exclude.add("credentials")
|
||||
|
||||
kwargs.setdefault("exclude_none", True)
|
||||
kwargs.setdefault("exclude_unset", True)
|
||||
kwargs.setdefault("exclude_defaults", True)
|
||||
return super().model_dump(*args, exclude=exclude, **kwargs)
|
||||
|
||||
|
||||
class PrimaryPhone(BaseModel):
|
||||
"""A primary phone in Apollo"""
|
||||
|
||||
number: Optional[str] = ""
|
||||
source: Optional[str] = ""
|
||||
sanitized_number: Optional[str] = ""
|
||||
number: str
|
||||
source: str
|
||||
sanitized_number: str
|
||||
|
||||
|
||||
class SenorityLevels(str, Enum):
|
||||
@@ -56,102 +42,102 @@ class ContactEmailStatuses(str, Enum):
|
||||
class RuleConfigStatus(BaseModel):
|
||||
"""A rule config status in Apollo"""
|
||||
|
||||
_id: Optional[str] = ""
|
||||
created_at: Optional[str] = ""
|
||||
rule_action_config_id: Optional[str] = ""
|
||||
rule_config_id: Optional[str] = ""
|
||||
status_cd: Optional[str] = ""
|
||||
updated_at: Optional[str] = ""
|
||||
id: Optional[str] = ""
|
||||
key: Optional[str] = ""
|
||||
_id: str
|
||||
created_at: str
|
||||
rule_action_config_id: str
|
||||
rule_config_id: str
|
||||
status_cd: str
|
||||
updated_at: str
|
||||
id: str
|
||||
key: str
|
||||
|
||||
|
||||
class ContactCampaignStatus(BaseModel):
|
||||
"""A contact campaign status in Apollo"""
|
||||
|
||||
id: Optional[str] = ""
|
||||
emailer_campaign_id: Optional[str] = ""
|
||||
send_email_from_user_id: Optional[str] = ""
|
||||
inactive_reason: Optional[str] = ""
|
||||
status: Optional[str] = ""
|
||||
added_at: Optional[str] = ""
|
||||
added_by_user_id: Optional[str] = ""
|
||||
finished_at: Optional[str] = ""
|
||||
paused_at: Optional[str] = ""
|
||||
auto_unpause_at: Optional[str] = ""
|
||||
send_email_from_email_address: Optional[str] = ""
|
||||
send_email_from_email_account_id: Optional[str] = ""
|
||||
manually_set_unpause: Optional[str] = ""
|
||||
failure_reason: Optional[str] = ""
|
||||
current_step_id: Optional[str] = ""
|
||||
in_response_to_emailer_message_id: Optional[str] = ""
|
||||
cc_emails: Optional[str] = ""
|
||||
bcc_emails: Optional[str] = ""
|
||||
to_emails: Optional[str] = ""
|
||||
id: str
|
||||
emailer_campaign_id: str
|
||||
send_email_from_user_id: str
|
||||
inactive_reason: str
|
||||
status: str
|
||||
added_at: str
|
||||
added_by_user_id: str
|
||||
finished_at: str
|
||||
paused_at: str
|
||||
auto_unpause_at: str
|
||||
send_email_from_email_address: str
|
||||
send_email_from_email_account_id: str
|
||||
manually_set_unpause: str
|
||||
failure_reason: str
|
||||
current_step_id: str
|
||||
in_response_to_emailer_message_id: str
|
||||
cc_emails: str
|
||||
bcc_emails: str
|
||||
to_emails: str
|
||||
|
||||
|
||||
class Account(BaseModel):
|
||||
"""An account in Apollo"""
|
||||
|
||||
id: Optional[str] = ""
|
||||
name: Optional[str] = ""
|
||||
website_url: Optional[str] = ""
|
||||
blog_url: Optional[str] = ""
|
||||
angellist_url: Optional[str] = ""
|
||||
linkedin_url: Optional[str] = ""
|
||||
twitter_url: Optional[str] = ""
|
||||
facebook_url: Optional[str] = ""
|
||||
primary_phone: Optional[PrimaryPhone] = PrimaryPhone()
|
||||
languages: Optional[list[str]] = []
|
||||
alexa_ranking: Optional[int] = 0
|
||||
phone: Optional[str] = ""
|
||||
linkedin_uid: Optional[str] = ""
|
||||
founded_year: Optional[int] = 0
|
||||
publicly_traded_symbol: Optional[str] = ""
|
||||
publicly_traded_exchange: Optional[str] = ""
|
||||
logo_url: Optional[str] = ""
|
||||
chrunchbase_url: Optional[str] = ""
|
||||
primary_domain: Optional[str] = ""
|
||||
domain: Optional[str] = ""
|
||||
team_id: Optional[str] = ""
|
||||
organization_id: Optional[str] = ""
|
||||
account_stage_id: Optional[str] = ""
|
||||
source: Optional[str] = ""
|
||||
original_source: Optional[str] = ""
|
||||
creator_id: Optional[str] = ""
|
||||
owner_id: Optional[str] = ""
|
||||
created_at: Optional[str] = ""
|
||||
phone_status: Optional[str] = ""
|
||||
hubspot_id: Optional[str] = ""
|
||||
salesforce_id: Optional[str] = ""
|
||||
crm_owner_id: Optional[str] = ""
|
||||
parent_account_id: Optional[str] = ""
|
||||
sanitized_phone: Optional[str] = ""
|
||||
id: str
|
||||
name: str
|
||||
website_url: str
|
||||
blog_url: str
|
||||
angellist_url: str
|
||||
linkedin_url: str
|
||||
twitter_url: str
|
||||
facebook_url: str
|
||||
primary_phone: PrimaryPhone
|
||||
languages: list[str]
|
||||
alexa_ranking: int
|
||||
phone: str
|
||||
linkedin_uid: str
|
||||
founded_year: int
|
||||
publicly_traded_symbol: str
|
||||
publicly_traded_exchange: str
|
||||
logo_url: str
|
||||
chrunchbase_url: str
|
||||
primary_domain: str
|
||||
domain: str
|
||||
team_id: str
|
||||
organization_id: str
|
||||
account_stage_id: str
|
||||
source: str
|
||||
original_source: str
|
||||
creator_id: str
|
||||
owner_id: str
|
||||
created_at: str
|
||||
phone_status: str
|
||||
hubspot_id: str
|
||||
salesforce_id: str
|
||||
crm_owner_id: str
|
||||
parent_account_id: str
|
||||
sanitized_phone: str
|
||||
# no listed type on the API docs
|
||||
account_playbook_statues: Optional[list[Any]] = []
|
||||
account_rule_config_statuses: Optional[list[RuleConfigStatus]] = []
|
||||
existence_level: Optional[str] = ""
|
||||
label_ids: Optional[list[str]] = []
|
||||
typed_custom_fields: Optional[Any] = {}
|
||||
custom_field_errors: Optional[Any] = {}
|
||||
modality: Optional[str] = ""
|
||||
source_display_name: Optional[str] = ""
|
||||
salesforce_record_id: Optional[str] = ""
|
||||
crm_record_url: Optional[str] = ""
|
||||
account_playbook_statues: list[Any]
|
||||
account_rule_config_statuses: list[RuleConfigStatus]
|
||||
existence_level: str
|
||||
label_ids: list[str]
|
||||
typed_custom_fields: Any
|
||||
custom_field_errors: Any
|
||||
modality: str
|
||||
source_display_name: str
|
||||
salesforce_record_id: str
|
||||
crm_record_url: str
|
||||
|
||||
|
||||
class ContactEmail(BaseModel):
|
||||
"""A contact email in Apollo"""
|
||||
|
||||
email: Optional[str] = ""
|
||||
email_md5: Optional[str] = ""
|
||||
email_sha256: Optional[str] = ""
|
||||
email_status: Optional[str] = ""
|
||||
email_source: Optional[str] = ""
|
||||
extrapolated_email_confidence: Optional[str] = ""
|
||||
position: Optional[int] = 0
|
||||
email_from_customer: Optional[str] = ""
|
||||
free_domain: Optional[bool] = True
|
||||
email: str = ""
|
||||
email_md5: str = ""
|
||||
email_sha256: str = ""
|
||||
email_status: str = ""
|
||||
email_source: str = ""
|
||||
extrapolated_email_confidence: str = ""
|
||||
position: int = 0
|
||||
email_from_customer: str = ""
|
||||
free_domain: bool = True
|
||||
|
||||
|
||||
class EmploymentHistory(BaseModel):
|
||||
@@ -164,40 +150,40 @@ class EmploymentHistory(BaseModel):
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
_id: Optional[str] = ""
|
||||
created_at: Optional[str] = ""
|
||||
current: Optional[bool] = False
|
||||
degree: Optional[str] = ""
|
||||
description: Optional[str] = ""
|
||||
emails: Optional[str] = ""
|
||||
end_date: Optional[str] = ""
|
||||
grade_level: Optional[str] = ""
|
||||
kind: Optional[str] = ""
|
||||
major: Optional[str] = ""
|
||||
organization_id: Optional[str] = ""
|
||||
organization_name: Optional[str] = ""
|
||||
raw_address: Optional[str] = ""
|
||||
start_date: Optional[str] = ""
|
||||
title: Optional[str] = ""
|
||||
updated_at: Optional[str] = ""
|
||||
id: Optional[str] = ""
|
||||
key: Optional[str] = ""
|
||||
_id: Optional[str] = None
|
||||
created_at: Optional[str] = None
|
||||
current: Optional[bool] = None
|
||||
degree: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
emails: Optional[str] = None
|
||||
end_date: Optional[str] = None
|
||||
grade_level: Optional[str] = None
|
||||
kind: Optional[str] = None
|
||||
major: Optional[str] = None
|
||||
organization_id: Optional[str] = None
|
||||
organization_name: Optional[str] = None
|
||||
raw_address: Optional[str] = None
|
||||
start_date: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
updated_at: Optional[str] = None
|
||||
id: Optional[str] = None
|
||||
key: Optional[str] = None
|
||||
|
||||
|
||||
class Breadcrumb(BaseModel):
|
||||
"""A breadcrumb in Apollo"""
|
||||
|
||||
label: Optional[str] = ""
|
||||
signal_field_name: Optional[str] = ""
|
||||
value: str | list | None = ""
|
||||
display_name: Optional[str] = ""
|
||||
label: Optional[str] = "N/A"
|
||||
signal_field_name: Optional[str] = "N/A"
|
||||
value: str | list | None = "N/A"
|
||||
display_name: Optional[str] = "N/A"
|
||||
|
||||
|
||||
class TypedCustomField(BaseModel):
|
||||
"""A typed custom field in Apollo"""
|
||||
|
||||
id: Optional[str] = ""
|
||||
value: Optional[str] = ""
|
||||
id: Optional[str] = "N/A"
|
||||
value: Optional[str] = "N/A"
|
||||
|
||||
|
||||
class Pagination(BaseModel):
|
||||
@@ -219,23 +205,23 @@ class Pagination(BaseModel):
|
||||
class DialerFlags(BaseModel):
|
||||
"""A dialer flags in Apollo"""
|
||||
|
||||
country_name: Optional[str] = ""
|
||||
country_enabled: Optional[bool] = True
|
||||
high_risk_calling_enabled: Optional[bool] = True
|
||||
potential_high_risk_number: Optional[bool] = True
|
||||
country_name: str
|
||||
country_enabled: bool
|
||||
high_risk_calling_enabled: bool
|
||||
potential_high_risk_number: bool
|
||||
|
||||
|
||||
class PhoneNumber(BaseModel):
|
||||
"""A phone number in Apollo"""
|
||||
|
||||
raw_number: Optional[str] = ""
|
||||
sanitized_number: Optional[str] = ""
|
||||
type: Optional[str] = ""
|
||||
position: Optional[int] = 0
|
||||
status: Optional[str] = ""
|
||||
dnc_status: Optional[str] = ""
|
||||
dnc_other_info: Optional[str] = ""
|
||||
dailer_flags: Optional[DialerFlags] = DialerFlags(
|
||||
raw_number: str = ""
|
||||
sanitized_number: str = ""
|
||||
type: str = ""
|
||||
position: int = 0
|
||||
status: str = ""
|
||||
dnc_status: str = ""
|
||||
dnc_other_info: str = ""
|
||||
dailer_flags: DialerFlags = DialerFlags(
|
||||
country_name="",
|
||||
country_enabled=True,
|
||||
high_risk_calling_enabled=True,
|
||||
@@ -253,31 +239,33 @@ class Organization(BaseModel):
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
id: Optional[str] = ""
|
||||
name: Optional[str] = ""
|
||||
website_url: Optional[str] = ""
|
||||
blog_url: Optional[str] = ""
|
||||
angellist_url: Optional[str] = ""
|
||||
linkedin_url: Optional[str] = ""
|
||||
twitter_url: Optional[str] = ""
|
||||
facebook_url: Optional[str] = ""
|
||||
primary_phone: Optional[PrimaryPhone] = PrimaryPhone()
|
||||
languages: Optional[list[str]] = []
|
||||
id: Optional[str] = "N/A"
|
||||
name: Optional[str] = "N/A"
|
||||
website_url: Optional[str] = "N/A"
|
||||
blog_url: Optional[str] = "N/A"
|
||||
angellist_url: Optional[str] = "N/A"
|
||||
linkedin_url: Optional[str] = "N/A"
|
||||
twitter_url: Optional[str] = "N/A"
|
||||
facebook_url: Optional[str] = "N/A"
|
||||
primary_phone: Optional[PrimaryPhone] = PrimaryPhone(
|
||||
number="N/A", source="N/A", sanitized_number="N/A"
|
||||
)
|
||||
languages: list[str] = []
|
||||
alexa_ranking: Optional[int] = 0
|
||||
phone: Optional[str] = ""
|
||||
linkedin_uid: Optional[str] = ""
|
||||
phone: Optional[str] = "N/A"
|
||||
linkedin_uid: Optional[str] = "N/A"
|
||||
founded_year: Optional[int] = 0
|
||||
publicly_traded_symbol: Optional[str] = ""
|
||||
publicly_traded_exchange: Optional[str] = ""
|
||||
logo_url: Optional[str] = ""
|
||||
chrunchbase_url: Optional[str] = ""
|
||||
primary_domain: Optional[str] = ""
|
||||
sanitized_phone: Optional[str] = ""
|
||||
owned_by_organization_id: Optional[str] = ""
|
||||
intent_strength: Optional[str] = ""
|
||||
show_intent: Optional[bool] = True
|
||||
publicly_traded_symbol: Optional[str] = "N/A"
|
||||
publicly_traded_exchange: Optional[str] = "N/A"
|
||||
logo_url: Optional[str] = "N/A"
|
||||
chrunchbase_url: Optional[str] = "N/A"
|
||||
primary_domain: Optional[str] = "N/A"
|
||||
sanitized_phone: Optional[str] = "N/A"
|
||||
owned_by_organization_id: Optional[str] = "N/A"
|
||||
intent_strength: Optional[str] = "N/A"
|
||||
show_intent: bool = True
|
||||
has_intent_signal_account: Optional[bool] = True
|
||||
intent_signal_account: Optional[str] = ""
|
||||
intent_signal_account: Optional[str] = "N/A"
|
||||
|
||||
|
||||
class Contact(BaseModel):
|
||||
@@ -290,95 +278,95 @@ class Contact(BaseModel):
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
contact_roles: Optional[list[Any]] = []
|
||||
id: Optional[str] = ""
|
||||
first_name: Optional[str] = ""
|
||||
last_name: Optional[str] = ""
|
||||
name: Optional[str] = ""
|
||||
linkedin_url: Optional[str] = ""
|
||||
title: Optional[str] = ""
|
||||
contact_stage_id: Optional[str] = ""
|
||||
owner_id: Optional[str] = ""
|
||||
creator_id: Optional[str] = ""
|
||||
person_id: Optional[str] = ""
|
||||
email_needs_tickling: Optional[bool] = True
|
||||
organization_name: Optional[str] = ""
|
||||
source: Optional[str] = ""
|
||||
original_source: Optional[str] = ""
|
||||
organization_id: Optional[str] = ""
|
||||
headline: Optional[str] = ""
|
||||
photo_url: Optional[str] = ""
|
||||
present_raw_address: Optional[str] = ""
|
||||
linkededin_uid: Optional[str] = ""
|
||||
extrapolated_email_confidence: Optional[float] = 0.0
|
||||
salesforce_id: Optional[str] = ""
|
||||
salesforce_lead_id: Optional[str] = ""
|
||||
salesforce_contact_id: Optional[str] = ""
|
||||
saleforce_account_id: Optional[str] = ""
|
||||
crm_owner_id: Optional[str] = ""
|
||||
created_at: Optional[str] = ""
|
||||
emailer_campaign_ids: Optional[list[str]] = []
|
||||
direct_dial_status: Optional[str] = ""
|
||||
direct_dial_enrichment_failed_at: Optional[str] = ""
|
||||
email_status: Optional[str] = ""
|
||||
email_source: Optional[str] = ""
|
||||
account_id: Optional[str] = ""
|
||||
last_activity_date: Optional[str] = ""
|
||||
hubspot_vid: Optional[str] = ""
|
||||
hubspot_company_id: Optional[str] = ""
|
||||
crm_id: Optional[str] = ""
|
||||
sanitized_phone: Optional[str] = ""
|
||||
merged_crm_ids: Optional[str] = ""
|
||||
updated_at: Optional[str] = ""
|
||||
queued_for_crm_push: Optional[bool] = True
|
||||
suggested_from_rule_engine_config_id: Optional[str] = ""
|
||||
email_unsubscribed: Optional[str] = ""
|
||||
label_ids: Optional[list[Any]] = []
|
||||
has_pending_email_arcgate_request: Optional[bool] = True
|
||||
has_email_arcgate_request: Optional[bool] = True
|
||||
existence_level: Optional[str] = ""
|
||||
email: Optional[str] = ""
|
||||
email_from_customer: Optional[str] = ""
|
||||
typed_custom_fields: Optional[list[TypedCustomField]] = []
|
||||
custom_field_errors: Optional[Any] = {}
|
||||
salesforce_record_id: Optional[str] = ""
|
||||
crm_record_url: Optional[str] = ""
|
||||
email_status_unavailable_reason: Optional[str] = ""
|
||||
email_true_status: Optional[str] = ""
|
||||
updated_email_true_status: Optional[bool] = True
|
||||
contact_rule_config_statuses: Optional[list[RuleConfigStatus]] = []
|
||||
source_display_name: Optional[str] = ""
|
||||
twitter_url: Optional[str] = ""
|
||||
contact_campaign_statuses: Optional[list[ContactCampaignStatus]] = []
|
||||
state: Optional[str] = ""
|
||||
city: Optional[str] = ""
|
||||
country: Optional[str] = ""
|
||||
account: Optional[Account] = Account()
|
||||
contact_emails: Optional[list[ContactEmail]] = []
|
||||
organization: Optional[Organization] = Organization()
|
||||
employment_history: Optional[list[EmploymentHistory]] = []
|
||||
time_zone: Optional[str] = ""
|
||||
intent_strength: Optional[str] = ""
|
||||
show_intent: Optional[bool] = True
|
||||
phone_numbers: Optional[list[PhoneNumber]] = []
|
||||
account_phone_note: Optional[str] = ""
|
||||
free_domain: Optional[bool] = True
|
||||
is_likely_to_engage: Optional[bool] = True
|
||||
email_domain_catchall: Optional[bool] = True
|
||||
contact_job_change_event: Optional[str] = ""
|
||||
contact_roles: list[Any] = []
|
||||
id: Optional[str] = None
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
linkedin_url: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
contact_stage_id: Optional[str] = None
|
||||
owner_id: Optional[str] = None
|
||||
creator_id: Optional[str] = None
|
||||
person_id: Optional[str] = None
|
||||
email_needs_tickling: bool = True
|
||||
organization_name: Optional[str] = None
|
||||
source: Optional[str] = None
|
||||
original_source: Optional[str] = None
|
||||
organization_id: Optional[str] = None
|
||||
headline: Optional[str] = None
|
||||
photo_url: Optional[str] = None
|
||||
present_raw_address: Optional[str] = None
|
||||
linkededin_uid: Optional[str] = None
|
||||
extrapolated_email_confidence: Optional[float] = None
|
||||
salesforce_id: Optional[str] = None
|
||||
salesforce_lead_id: Optional[str] = None
|
||||
salesforce_contact_id: Optional[str] = None
|
||||
saleforce_account_id: Optional[str] = None
|
||||
crm_owner_id: Optional[str] = None
|
||||
created_at: Optional[str] = None
|
||||
emailer_campaign_ids: list[str] = []
|
||||
direct_dial_status: Optional[str] = None
|
||||
direct_dial_enrichment_failed_at: Optional[str] = None
|
||||
email_status: Optional[str] = None
|
||||
email_source: Optional[str] = None
|
||||
account_id: Optional[str] = None
|
||||
last_activity_date: Optional[str] = None
|
||||
hubspot_vid: Optional[str] = None
|
||||
hubspot_company_id: Optional[str] = None
|
||||
crm_id: Optional[str] = None
|
||||
sanitized_phone: Optional[str] = None
|
||||
merged_crm_ids: Optional[str] = None
|
||||
updated_at: Optional[str] = None
|
||||
queued_for_crm_push: bool = True
|
||||
suggested_from_rule_engine_config_id: Optional[str] = None
|
||||
email_unsubscribed: Optional[str] = None
|
||||
label_ids: list[Any] = []
|
||||
has_pending_email_arcgate_request: bool = True
|
||||
has_email_arcgate_request: bool = True
|
||||
existence_level: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
email_from_customer: Optional[str] = None
|
||||
typed_custom_fields: list[TypedCustomField] = []
|
||||
custom_field_errors: Any = None
|
||||
salesforce_record_id: Optional[str] = None
|
||||
crm_record_url: Optional[str] = None
|
||||
email_status_unavailable_reason: Optional[str] = None
|
||||
email_true_status: Optional[str] = None
|
||||
updated_email_true_status: bool = True
|
||||
contact_rule_config_statuses: list[RuleConfigStatus] = []
|
||||
source_display_name: Optional[str] = None
|
||||
twitter_url: Optional[str] = None
|
||||
contact_campaign_statuses: list[ContactCampaignStatus] = []
|
||||
state: Optional[str] = None
|
||||
city: Optional[str] = None
|
||||
country: Optional[str] = None
|
||||
account: Optional[Account] = None
|
||||
contact_emails: list[ContactEmail] = []
|
||||
organization: Optional[Organization] = None
|
||||
employment_history: list[EmploymentHistory] = []
|
||||
time_zone: Optional[str] = None
|
||||
intent_strength: Optional[str] = None
|
||||
show_intent: bool = True
|
||||
phone_numbers: list[PhoneNumber] = []
|
||||
account_phone_note: Optional[str] = None
|
||||
free_domain: bool = True
|
||||
is_likely_to_engage: bool = True
|
||||
email_domain_catchall: bool = True
|
||||
contact_job_change_event: Optional[str] = None
|
||||
|
||||
|
||||
class SearchOrganizationsRequest(BaseModel):
|
||||
"""Request for Apollo's search organizations API"""
|
||||
|
||||
organization_num_employees_range: Optional[list[int]] = SchemaField(
|
||||
organization_num_empoloyees_range: list[int] = SchemaField(
|
||||
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
|
||||
|
||||
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
|
||||
default=[0, 1000000],
|
||||
)
|
||||
|
||||
organization_locations: Optional[list[str]] = SchemaField(
|
||||
organization_locations: list[str] = SchemaField(
|
||||
description="""The location of the company headquarters. You can search across cities, US states, and countries.
|
||||
|
||||
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, any Boston-based companies will not appearch in your search results, even if they match other parameters.
|
||||
@@ -387,30 +375,28 @@ To exclude companies based on location, use the organization_not_locations param
|
||||
""",
|
||||
default_factory=list,
|
||||
)
|
||||
organizations_not_locations: Optional[list[str]] = SchemaField(
|
||||
organizations_not_locations: list[str] = SchemaField(
|
||||
description="""Exclude companies from search results based on the location of the company headquarters. You can use cities, US states, and countries as locations to exclude.
|
||||
|
||||
This parameter is useful for ensuring you do not prospect in an undesirable territory. For example, if you use ireland as a value, no Ireland-based companies will appear in your search results.
|
||||
""",
|
||||
default_factory=list,
|
||||
)
|
||||
q_organization_keyword_tags: Optional[list[str]] = SchemaField(
|
||||
description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry.""",
|
||||
default_factory=list,
|
||||
q_organization_keyword_tags: list[str] = SchemaField(
|
||||
description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry."""
|
||||
)
|
||||
q_organization_name: Optional[str] = SchemaField(
|
||||
q_organization_name: str = SchemaField(
|
||||
description="""Filter search results to include a specific company name.
|
||||
|
||||
If the value you enter for this parameter does not match with a company's name, the company will not appear in search results, even if it matches other parameters. Partial matches are accepted. For example, if you filter by the value marketing, a company called NY Marketing Unlimited would still be eligible as a search result, but NY Market Analysis would not be eligible.""",
|
||||
default="",
|
||||
If the value you enter for this parameter does not match with a company's name, the company will not appear in search results, even if it matches other parameters. Partial matches are accepted. For example, if you filter by the value marketing, a company called NY Marketing Unlimited would still be eligible as a search result, but NY Market Analysis would not be eligible."""
|
||||
)
|
||||
organization_ids: Optional[list[str]] = SchemaField(
|
||||
organization_ids: list[str] = SchemaField(
|
||||
description="""The Apollo IDs for the companies you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
|
||||
|
||||
To find IDs, identify the values for organization_id when you call this endpoint.""",
|
||||
default_factory=list,
|
||||
)
|
||||
max_results: Optional[int] = SchemaField(
|
||||
max_results: int = SchemaField(
|
||||
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
|
||||
default=100,
|
||||
ge=1,
|
||||
@@ -435,11 +421,11 @@ Use the page parameter to search the different pages of data.""",
|
||||
class SearchOrganizationsResponse(BaseModel):
|
||||
"""Response from Apollo's search organizations API"""
|
||||
|
||||
breadcrumbs: Optional[list[Breadcrumb]] = []
|
||||
partial_results_only: Optional[bool] = True
|
||||
has_join: Optional[bool] = True
|
||||
disable_eu_prospecting: Optional[bool] = True
|
||||
partial_results_limit: Optional[int] = 0
|
||||
breadcrumbs: list[Breadcrumb] = []
|
||||
partial_results_only: bool = True
|
||||
has_join: bool = True
|
||||
disable_eu_prospecting: bool = True
|
||||
partial_results_limit: int = 0
|
||||
pagination: Pagination = Pagination(
|
||||
page=0, per_page=0, total_entries=0, total_pages=0
|
||||
)
|
||||
@@ -447,14 +433,14 @@ class SearchOrganizationsResponse(BaseModel):
|
||||
accounts: list[Any] = []
|
||||
organizations: list[Organization] = []
|
||||
models_ids: list[str] = []
|
||||
num_fetch_result: Optional[str] = ""
|
||||
derived_params: Optional[str] = ""
|
||||
num_fetch_result: Optional[str] = "N/A"
|
||||
derived_params: Optional[str] = "N/A"
|
||||
|
||||
|
||||
class SearchPeopleRequest(BaseModel):
|
||||
"""Request for Apollo's search people API"""
|
||||
|
||||
person_titles: Optional[list[str]] = SchemaField(
|
||||
person_titles: list[str] = SchemaField(
|
||||
description="""Job titles held by the people you want to find. For a person to be included in search results, they only need to match 1 of the job titles you add. Adding more job titles expands your search results.
|
||||
|
||||
Results also include job titles with the same terms, even if they are not exact matches. For example, searching for marketing manager might return people with the job title content marketing manager.
|
||||
@@ -464,13 +450,13 @@ Use this parameter in combination with the person_seniorities[] parameter to fin
|
||||
default_factory=list,
|
||||
placeholder="marketing manager",
|
||||
)
|
||||
person_locations: Optional[list[str]] = SchemaField(
|
||||
person_locations: list[str] = SchemaField(
|
||||
description="""The location where people live. You can search across cities, US states, and countries.
|
||||
|
||||
To find people based on the headquarters locations of their current employer, use the organization_locations parameter.""",
|
||||
default_factory=list,
|
||||
)
|
||||
person_seniorities: Optional[list[SenorityLevels]] = SchemaField(
|
||||
person_seniorities: list[SenorityLevels] = SchemaField(
|
||||
description="""The job seniority that people hold within their current employer. This enables you to find people that currently hold positions at certain reporting levels, such as Director level or senior IC level.
|
||||
|
||||
For a person to be included in search results, they only need to match 1 of the seniorities you add. Adding more seniorities expands your search results.
|
||||
@@ -480,7 +466,7 @@ Searches only return results based on their current job title, so searching for
|
||||
Use this parameter in combination with the person_titles[] parameter to find people based on specific job functions and seniority levels.""",
|
||||
default_factory=list,
|
||||
)
|
||||
organization_locations: Optional[list[str]] = SchemaField(
|
||||
organization_locations: list[str] = SchemaField(
|
||||
description="""The location of the company headquarters for a person's current employer. You can search across cities, US states, and countries.
|
||||
|
||||
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, people that work for the Boston-based company will not appear in your results, even if they match other parameters.
|
||||
@@ -488,7 +474,7 @@ If a company has several office locations, results are still based on the headqu
|
||||
To find people based on their personal location, use the person_locations parameter.""",
|
||||
default_factory=list,
|
||||
)
|
||||
q_organization_domains: Optional[list[str]] = SchemaField(
|
||||
q_organization_domains: list[str] = SchemaField(
|
||||
description="""The domain name for the person's employer. This can be the current employer or a previous employer. Do not include www., the @ symbol, or similar.
|
||||
|
||||
You can add multiple domains to search across companies.
|
||||
@@ -496,23 +482,23 @@ You can add multiple domains to search across companies.
|
||||
Examples: apollo.io and microsoft.com""",
|
||||
default_factory=list,
|
||||
)
|
||||
contact_email_statuses: Optional[list[ContactEmailStatuses]] = SchemaField(
|
||||
contact_email_statuses: list[ContactEmailStatuses] = SchemaField(
|
||||
description="""The email statuses for the people you want to find. You can add multiple statuses to expand your search.""",
|
||||
default_factory=list,
|
||||
)
|
||||
organization_ids: Optional[list[str]] = SchemaField(
|
||||
organization_ids: list[str] = SchemaField(
|
||||
description="""The Apollo IDs for the companies (employers) you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
|
||||
|
||||
To find IDs, call the Organization Search endpoint and identify the values for organization_id.""",
|
||||
default_factory=list,
|
||||
)
|
||||
organization_num_employees_range: Optional[list[int]] = SchemaField(
|
||||
organization_num_empoloyees_range: list[int] = SchemaField(
|
||||
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
|
||||
|
||||
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
|
||||
default_factory=list,
|
||||
)
|
||||
q_keywords: Optional[str] = SchemaField(
|
||||
q_keywords: str = SchemaField(
|
||||
description="""A string of words over which we want to filter the results""",
|
||||
default="",
|
||||
)
|
||||
@@ -528,7 +514,7 @@ Use this parameter in combination with the per_page parameter to make search res
|
||||
Use the page parameter to search the different pages of data.""",
|
||||
default=100,
|
||||
)
|
||||
max_results: Optional[int] = SchemaField(
|
||||
max_results: int = SchemaField(
|
||||
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
|
||||
default=100,
|
||||
ge=1,
|
||||
@@ -547,61 +533,16 @@ class SearchPeopleResponse(BaseModel):
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
breadcrumbs: Optional[list[Breadcrumb]] = []
|
||||
partial_results_only: Optional[bool] = True
|
||||
has_join: Optional[bool] = True
|
||||
disable_eu_prospecting: Optional[bool] = True
|
||||
partial_results_limit: Optional[int] = 0
|
||||
breadcrumbs: list[Breadcrumb] = []
|
||||
partial_results_only: bool = True
|
||||
has_join: bool = True
|
||||
disable_eu_prospecting: bool = True
|
||||
partial_results_limit: int = 0
|
||||
pagination: Pagination = Pagination(
|
||||
page=0, per_page=0, total_entries=0, total_pages=0
|
||||
)
|
||||
contacts: list[Contact] = []
|
||||
people: list[Contact] = []
|
||||
model_ids: list[str] = []
|
||||
num_fetch_result: Optional[str] = ""
|
||||
derived_params: Optional[str] = ""
|
||||
|
||||
|
||||
class EnrichPersonRequest(BaseModel):
|
||||
"""Request for Apollo's person enrichment API"""
|
||||
|
||||
person_id: Optional[str] = SchemaField(
|
||||
description="Apollo person ID to enrich (most accurate method)",
|
||||
default="",
|
||||
)
|
||||
first_name: Optional[str] = SchemaField(
|
||||
description="First name of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
last_name: Optional[str] = SchemaField(
|
||||
description="Last name of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
name: Optional[str] = SchemaField(
|
||||
description="Full name of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
email: Optional[str] = SchemaField(
|
||||
description="Email address of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
domain: Optional[str] = SchemaField(
|
||||
description="Company domain of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
company: Optional[str] = SchemaField(
|
||||
description="Company name of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
linkedin_url: Optional[str] = SchemaField(
|
||||
description="LinkedIn URL of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
organization_id: Optional[str] = SchemaField(
|
||||
description="Apollo organization ID of the person's company",
|
||||
default="",
|
||||
)
|
||||
title: Optional[str] = SchemaField(
|
||||
description="Job title of the person to enrich",
|
||||
default="",
|
||||
)
|
||||
num_fetch_result: Optional[str] = "N/A"
|
||||
derived_params: Optional[str] = "N/A"
|
||||
|
||||
@@ -11,14 +11,14 @@ from backend.blocks.apollo.models import (
|
||||
SearchOrganizationsRequest,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class SearchOrganizationsBlock(Block):
|
||||
"""Search for organizations in Apollo"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
organization_num_employees_range: list[int] = SchemaField(
|
||||
organization_num_empoloyees_range: list[int] = SchemaField(
|
||||
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
|
||||
|
||||
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
|
||||
@@ -65,7 +65,7 @@ To find IDs, identify the values for organization_id when you call this endpoint
|
||||
le=50000,
|
||||
advanced=True,
|
||||
)
|
||||
credentials: ApolloCredentialsInput = CredentialsField(
|
||||
credentials: ApolloCredentialsInput = SchemaField(
|
||||
description="Apollo credentials",
|
||||
)
|
||||
|
||||
@@ -210,7 +210,9 @@ To find IDs, identify the values for organization_id when you call this endpoint
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: ApolloCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
query = SearchOrganizationsRequest(**input_data.model_dump())
|
||||
query = SearchOrganizationsRequest(
|
||||
**input_data.model_dump(exclude={"credentials"})
|
||||
)
|
||||
organizations = await self.search_organizations(query, credentials)
|
||||
for organization in organizations:
|
||||
yield "organization", organization
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import asyncio
|
||||
|
||||
from backend.blocks.apollo._api import ApolloClient
|
||||
from backend.blocks.apollo._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
@@ -10,12 +8,11 @@ from backend.blocks.apollo._auth import (
|
||||
from backend.blocks.apollo.models import (
|
||||
Contact,
|
||||
ContactEmailStatuses,
|
||||
EnrichPersonRequest,
|
||||
SearchPeopleRequest,
|
||||
SenorityLevels,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class SearchPeopleBlock(Block):
|
||||
@@ -80,7 +77,7 @@ class SearchPeopleBlock(Block):
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
organization_num_employees_range: list[int] = SchemaField(
|
||||
organization_num_empoloyees_range: list[int] = SchemaField(
|
||||
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
|
||||
|
||||
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
|
||||
@@ -93,19 +90,14 @@ class SearchPeopleBlock(Block):
|
||||
advanced=False,
|
||||
)
|
||||
max_results: int = SchemaField(
|
||||
description="""The maximum number of results to return. If you don't specify this parameter, the default is 25. Limited to 500 to prevent overspending.""",
|
||||
default=25,
|
||||
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
|
||||
default=100,
|
||||
ge=1,
|
||||
le=500,
|
||||
advanced=True,
|
||||
)
|
||||
enrich_info: bool = SchemaField(
|
||||
description="""Whether to enrich contacts with detailed information including real email addresses. This will double the search cost.""",
|
||||
default=False,
|
||||
le=50000,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
credentials: ApolloCredentialsInput = CredentialsField(
|
||||
credentials: ApolloCredentialsInput = SchemaField(
|
||||
description="Apollo credentials",
|
||||
)
|
||||
|
||||
@@ -114,6 +106,9 @@ class SearchPeopleBlock(Block):
|
||||
description="List of people found",
|
||||
default_factory=list,
|
||||
)
|
||||
person: Contact = SchemaField(
|
||||
description="Each found person, one at a time",
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the search failed",
|
||||
default="",
|
||||
@@ -129,6 +124,87 @@ class SearchPeopleBlock(Block):
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
(
|
||||
"person",
|
||||
Contact(
|
||||
contact_roles=[],
|
||||
id="1",
|
||||
name="John Doe",
|
||||
first_name="John",
|
||||
last_name="Doe",
|
||||
linkedin_url="https://www.linkedin.com/in/johndoe",
|
||||
title="Software Engineer",
|
||||
organization_name="Google",
|
||||
organization_id="123456",
|
||||
contact_stage_id="1",
|
||||
owner_id="1",
|
||||
creator_id="1",
|
||||
person_id="1",
|
||||
email_needs_tickling=True,
|
||||
source="apollo",
|
||||
original_source="apollo",
|
||||
headline="Software Engineer",
|
||||
photo_url="https://www.linkedin.com/in/johndoe",
|
||||
present_raw_address="123 Main St, Anytown, USA",
|
||||
linkededin_uid="123456",
|
||||
extrapolated_email_confidence=0.8,
|
||||
salesforce_id="123456",
|
||||
salesforce_lead_id="123456",
|
||||
salesforce_contact_id="123456",
|
||||
saleforce_account_id="123456",
|
||||
crm_owner_id="123456",
|
||||
created_at="2021-01-01",
|
||||
emailer_campaign_ids=[],
|
||||
direct_dial_status="active",
|
||||
direct_dial_enrichment_failed_at="2021-01-01",
|
||||
email_status="active",
|
||||
email_source="apollo",
|
||||
account_id="123456",
|
||||
last_activity_date="2021-01-01",
|
||||
hubspot_vid="123456",
|
||||
hubspot_company_id="123456",
|
||||
crm_id="123456",
|
||||
sanitized_phone="123456",
|
||||
merged_crm_ids="123456",
|
||||
updated_at="2021-01-01",
|
||||
queued_for_crm_push=True,
|
||||
suggested_from_rule_engine_config_id="123456",
|
||||
email_unsubscribed=None,
|
||||
label_ids=[],
|
||||
has_pending_email_arcgate_request=True,
|
||||
has_email_arcgate_request=True,
|
||||
existence_level=None,
|
||||
email=None,
|
||||
email_from_customer=None,
|
||||
typed_custom_fields=[],
|
||||
custom_field_errors=None,
|
||||
salesforce_record_id=None,
|
||||
crm_record_url=None,
|
||||
email_status_unavailable_reason=None,
|
||||
email_true_status=None,
|
||||
updated_email_true_status=True,
|
||||
contact_rule_config_statuses=[],
|
||||
source_display_name=None,
|
||||
twitter_url=None,
|
||||
contact_campaign_statuses=[],
|
||||
state=None,
|
||||
city=None,
|
||||
country=None,
|
||||
account=None,
|
||||
contact_emails=[],
|
||||
organization=None,
|
||||
employment_history=[],
|
||||
time_zone=None,
|
||||
intent_strength=None,
|
||||
show_intent=True,
|
||||
phone_numbers=[],
|
||||
account_phone_note=None,
|
||||
free_domain=True,
|
||||
is_likely_to_engage=True,
|
||||
email_domain_catchall=True,
|
||||
contact_job_change_event=None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"people",
|
||||
[
|
||||
@@ -303,34 +379,6 @@ class SearchPeopleBlock(Block):
|
||||
client = ApolloClient(credentials)
|
||||
return await client.search_people(query)
|
||||
|
||||
@staticmethod
|
||||
async def enrich_person(
|
||||
query: EnrichPersonRequest, credentials: ApolloCredentials
|
||||
) -> Contact:
|
||||
client = ApolloClient(credentials)
|
||||
return await client.enrich_person(query)
|
||||
|
||||
@staticmethod
|
||||
def merge_contact_data(original: Contact, enriched: Contact) -> Contact:
|
||||
"""
|
||||
Merge contact data from original search with enriched data.
|
||||
Enriched data complements original data, only filling in missing values.
|
||||
"""
|
||||
merged_data = original.model_dump()
|
||||
enriched_data = enriched.model_dump()
|
||||
|
||||
# Only update fields that are None, empty string, empty list, or default values in original
|
||||
for key, enriched_value in enriched_data.items():
|
||||
# Skip if enriched value is None, empty string, or empty list
|
||||
if enriched_value is None or enriched_value == "" or enriched_value == []:
|
||||
continue
|
||||
|
||||
# Update if original value is None, empty string, empty list, or zero
|
||||
if enriched_value:
|
||||
merged_data[key] = enriched_value
|
||||
|
||||
return Contact(**merged_data)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
@@ -339,25 +387,8 @@ class SearchPeopleBlock(Block):
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
|
||||
query = SearchPeopleRequest(**input_data.model_dump())
|
||||
query = SearchPeopleRequest(**input_data.model_dump(exclude={"credentials"}))
|
||||
people = await self.search_people(query, credentials)
|
||||
|
||||
# Enrich with detailed info if requested
|
||||
if input_data.enrich_info:
|
||||
|
||||
async def enrich_or_fallback(person: Contact):
|
||||
try:
|
||||
enrich_query = EnrichPersonRequest(person_id=person.id)
|
||||
enriched_person = await self.enrich_person(
|
||||
enrich_query, credentials
|
||||
)
|
||||
# Merge enriched data with original data, complementing instead of replacing
|
||||
return self.merge_contact_data(person, enriched_person)
|
||||
except Exception:
|
||||
return person # If enrichment fails, use original person data
|
||||
|
||||
people = await asyncio.gather(
|
||||
*(enrich_or_fallback(person) for person in people)
|
||||
)
|
||||
|
||||
for person in people:
|
||||
yield "person", person
|
||||
yield "people", people
|
||||
|
||||
@@ -1,138 +0,0 @@
|
||||
from backend.blocks.apollo._api import ApolloClient
|
||||
from backend.blocks.apollo._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
ApolloCredentials,
|
||||
ApolloCredentialsInput,
|
||||
)
|
||||
from backend.blocks.apollo.models import Contact, EnrichPersonRequest
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
|
||||
|
||||
class GetPersonDetailBlock(Block):
|
||||
"""Get detailed person data with Apollo API, including email reveal"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
person_id: str = SchemaField(
|
||||
description="Apollo person ID to enrich (most accurate method)",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
first_name: str = SchemaField(
|
||||
description="First name of the person to enrich",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
last_name: str = SchemaField(
|
||||
description="Last name of the person to enrich",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
name: str = SchemaField(
|
||||
description="Full name of the person to enrich (alternative to first_name + last_name)",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
email: str = SchemaField(
|
||||
description="Known email address of the person (helps with matching)",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
domain: str = SchemaField(
|
||||
description="Company domain of the person (e.g., 'google.com')",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
company: str = SchemaField(
|
||||
description="Company name of the person",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
linkedin_url: str = SchemaField(
|
||||
description="LinkedIn URL of the person",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
organization_id: str = SchemaField(
|
||||
description="Apollo organization ID of the person's company",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
title: str = SchemaField(
|
||||
description="Job title of the person to enrich",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
credentials: ApolloCredentialsInput = CredentialsField(
|
||||
description="Apollo credentials",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
contact: Contact = SchemaField(
|
||||
description="Enriched contact information",
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if enrichment failed",
|
||||
default="",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3b18d46c-3db6-42ae-a228-0ba441bdd176",
|
||||
description="Get detailed person data with Apollo API, including email reveal",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=GetPersonDetailBlock.Input,
|
||||
output_schema=GetPersonDetailBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"first_name": "John",
|
||||
"last_name": "Doe",
|
||||
"company": "Google",
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"contact",
|
||||
Contact(
|
||||
id="1",
|
||||
name="John Doe",
|
||||
first_name="John",
|
||||
last_name="Doe",
|
||||
email="john.doe@gmail.com",
|
||||
title="Software Engineer",
|
||||
organization_name="Google",
|
||||
linkedin_url="https://www.linkedin.com/in/johndoe",
|
||||
),
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"enrich_person": lambda query, credentials: Contact(
|
||||
id="1",
|
||||
name="John Doe",
|
||||
first_name="John",
|
||||
last_name="Doe",
|
||||
email="john.doe@gmail.com",
|
||||
title="Software Engineer",
|
||||
organization_name="Google",
|
||||
linkedin_url="https://www.linkedin.com/in/johndoe",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def enrich_person(
|
||||
query: EnrichPersonRequest, credentials: ApolloCredentials
|
||||
) -> Contact:
|
||||
client = ApolloClient(credentials)
|
||||
return await client.enrich_person(query)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: ApolloCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
query = EnrichPersonRequest(**input_data.model_dump())
|
||||
yield "contact", await self.enrich_person(query, credentials)
|
||||
@@ -1,9 +1,11 @@
|
||||
import enum
|
||||
from typing import Any
|
||||
from typing import Any, List
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util import json
|
||||
from backend.util.file import store_media_file
|
||||
from backend.util.mock import MockObject
|
||||
from backend.util.type import MediaFileType, convert
|
||||
|
||||
|
||||
@@ -12,12 +14,6 @@ class FileStoreBlock(Block):
|
||||
file_in: MediaFileType = SchemaField(
|
||||
description="The file to store in the temporary directory, it can be a URL, data URI, or local path."
|
||||
)
|
||||
base_64: bool = SchemaField(
|
||||
description="Whether produce an output in base64 format (not recommended, you can pass the string path just fine accross blocks).",
|
||||
default=False,
|
||||
advanced=True,
|
||||
title="Produce Base64 Output",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
file_out: MediaFileType = SchemaField(
|
||||
@@ -41,11 +37,12 @@ class FileStoreBlock(Block):
|
||||
graph_exec_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
yield "file_out", await store_media_file(
|
||||
file_path = await store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.file_in,
|
||||
return_content=input_data.base_64,
|
||||
return_content=False,
|
||||
)
|
||||
yield "file_out", file_path
|
||||
|
||||
|
||||
class StoreValueBlock(Block):
|
||||
@@ -118,6 +115,266 @@ class PrintToConsoleBlock(Block):
|
||||
yield "status", "printed"
|
||||
|
||||
|
||||
class FindInDictionaryBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
input: Any = SchemaField(description="Dictionary to lookup from")
|
||||
key: str | int = SchemaField(description="Key to lookup in the dictionary")
|
||||
|
||||
class Output(BlockSchema):
|
||||
output: Any = SchemaField(description="Value found for the given key")
|
||||
missing: Any = SchemaField(
|
||||
description="Value of the input that missing the key"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0e50422c-6dee-4145-83d6-3a5a392f65de",
|
||||
description="Lookup the given key in the input dictionary/object/list and return the value.",
|
||||
input_schema=FindInDictionaryBlock.Input,
|
||||
output_schema=FindInDictionaryBlock.Output,
|
||||
test_input=[
|
||||
{"input": {"apple": 1, "banana": 2, "cherry": 3}, "key": "banana"},
|
||||
{"input": {"x": 10, "y": 20, "z": 30}, "key": "w"},
|
||||
{"input": [1, 2, 3], "key": 1},
|
||||
{"input": [1, 2, 3], "key": 3},
|
||||
{"input": MockObject(value="!!", key="key"), "key": "key"},
|
||||
{"input": [{"k1": "v1"}, {"k2": "v2"}, {"k1": "v3"}], "key": "k1"},
|
||||
],
|
||||
test_output=[
|
||||
("output", 2),
|
||||
("missing", {"x": 10, "y": 20, "z": 30}),
|
||||
("output", 2),
|
||||
("missing", [1, 2, 3]),
|
||||
("output", "key"),
|
||||
("output", ["v1", "v3"]),
|
||||
],
|
||||
categories={BlockCategory.BASIC},
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
obj = input_data.input
|
||||
key = input_data.key
|
||||
|
||||
if isinstance(obj, str):
|
||||
obj = json.loads(obj)
|
||||
|
||||
if isinstance(obj, dict) and key in obj:
|
||||
yield "output", obj[key]
|
||||
elif isinstance(obj, list) and isinstance(key, int) and 0 <= key < len(obj):
|
||||
yield "output", obj[key]
|
||||
elif isinstance(obj, list) and isinstance(key, str):
|
||||
if len(obj) == 0:
|
||||
yield "output", []
|
||||
elif isinstance(obj[0], dict) and key in obj[0]:
|
||||
yield "output", [item[key] for item in obj if key in item]
|
||||
else:
|
||||
yield "output", [getattr(val, key) for val in obj if hasattr(val, key)]
|
||||
elif isinstance(obj, object) and isinstance(key, str) and hasattr(obj, key):
|
||||
yield "output", getattr(obj, key)
|
||||
else:
|
||||
yield "missing", input_data.input
|
||||
|
||||
|
||||
class AddToDictionaryBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
dictionary: dict[Any, Any] = SchemaField(
|
||||
default_factory=dict,
|
||||
description="The dictionary to add the entry to. If not provided, a new dictionary will be created.",
|
||||
)
|
||||
key: str = SchemaField(
|
||||
default="",
|
||||
description="The key for the new entry.",
|
||||
placeholder="new_key",
|
||||
advanced=False,
|
||||
)
|
||||
value: Any = SchemaField(
|
||||
default=None,
|
||||
description="The value for the new entry.",
|
||||
placeholder="new_value",
|
||||
advanced=False,
|
||||
)
|
||||
entries: dict[Any, Any] = SchemaField(
|
||||
default_factory=dict,
|
||||
description="The entries to add to the dictionary. This is the batch version of the `key` and `value` fields.",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
updated_dictionary: dict = SchemaField(
|
||||
description="The dictionary with the new entry added."
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="31d1064e-7446-4693-a7d4-65e5ca1180d1",
|
||||
description="Adds a new key-value pair to a dictionary. If no dictionary is provided, a new one is created.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=AddToDictionaryBlock.Input,
|
||||
output_schema=AddToDictionaryBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"dictionary": {"existing_key": "existing_value"},
|
||||
"key": "new_key",
|
||||
"value": "new_value",
|
||||
},
|
||||
{"key": "first_key", "value": "first_value"},
|
||||
{
|
||||
"dictionary": {"existing_key": "existing_value"},
|
||||
"entries": {"new_key": "new_value", "first_key": "first_value"},
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
(
|
||||
"updated_dictionary",
|
||||
{"existing_key": "existing_value", "new_key": "new_value"},
|
||||
),
|
||||
("updated_dictionary", {"first_key": "first_value"}),
|
||||
(
|
||||
"updated_dictionary",
|
||||
{
|
||||
"existing_key": "existing_value",
|
||||
"new_key": "new_value",
|
||||
"first_key": "first_value",
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
updated_dict = input_data.dictionary.copy()
|
||||
|
||||
if input_data.value is not None and input_data.key:
|
||||
updated_dict[input_data.key] = input_data.value
|
||||
|
||||
for key, value in input_data.entries.items():
|
||||
updated_dict[key] = value
|
||||
|
||||
yield "updated_dictionary", updated_dict
|
||||
|
||||
|
||||
class AddToListBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
list: List[Any] = SchemaField(
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
description="The list to add the entry to. If not provided, a new list will be created.",
|
||||
)
|
||||
entry: Any = SchemaField(
|
||||
description="The entry to add to the list. Can be of any type (string, int, dict, etc.).",
|
||||
advanced=False,
|
||||
default=None,
|
||||
)
|
||||
entries: List[Any] = SchemaField(
|
||||
default_factory=lambda: list(),
|
||||
description="The entries to add to the list. This is the batch version of the `entry` field.",
|
||||
advanced=True,
|
||||
)
|
||||
position: int | None = SchemaField(
|
||||
default=None,
|
||||
description="The position to insert the new entry. If not provided, the entry will be appended to the end of the list.",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
updated_list: List[Any] = SchemaField(
|
||||
description="The list with the new entry added."
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="aeb08fc1-2fc1-4141-bc8e-f758f183a822",
|
||||
description="Adds a new entry to a list. The entry can be of any type. If no list is provided, a new one is created.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=AddToListBlock.Input,
|
||||
output_schema=AddToListBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"list": [1, "string", {"existing_key": "existing_value"}],
|
||||
"entry": {"new_key": "new_value"},
|
||||
"position": 1,
|
||||
},
|
||||
{"entry": "first_entry"},
|
||||
{"list": ["a", "b", "c"], "entry": "d"},
|
||||
{
|
||||
"entry": "e",
|
||||
"entries": ["f", "g"],
|
||||
"list": ["a", "b"],
|
||||
"position": 1,
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
(
|
||||
"updated_list",
|
||||
[
|
||||
1,
|
||||
{"new_key": "new_value"},
|
||||
"string",
|
||||
{"existing_key": "existing_value"},
|
||||
],
|
||||
),
|
||||
("updated_list", ["first_entry"]),
|
||||
("updated_list", ["a", "b", "c", "d"]),
|
||||
("updated_list", ["a", "f", "g", "e", "b"]),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
entries_added = input_data.entries.copy()
|
||||
if input_data.entry:
|
||||
entries_added.append(input_data.entry)
|
||||
|
||||
updated_list = input_data.list.copy()
|
||||
if (pos := input_data.position) is not None:
|
||||
updated_list = updated_list[:pos] + entries_added + updated_list[pos:]
|
||||
else:
|
||||
updated_list += entries_added
|
||||
|
||||
yield "updated_list", updated_list
|
||||
|
||||
|
||||
class FindInListBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
list: List[Any] = SchemaField(description="The list to search in.")
|
||||
value: Any = SchemaField(description="The value to search for.")
|
||||
|
||||
class Output(BlockSchema):
|
||||
index: int = SchemaField(description="The index of the value in the list.")
|
||||
found: bool = SchemaField(
|
||||
description="Whether the value was found in the list."
|
||||
)
|
||||
not_found_value: Any = SchemaField(
|
||||
description="The value that was not found in the list."
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="5e2c6d0a-1e37-489f-b1d0-8e1812b23333",
|
||||
description="Finds the index of the value in the list.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=FindInListBlock.Input,
|
||||
output_schema=FindInListBlock.Output,
|
||||
test_input=[
|
||||
{"list": [1, 2, 3, 4, 5], "value": 3},
|
||||
{"list": [1, 2, 3, 4, 5], "value": 6},
|
||||
],
|
||||
test_output=[
|
||||
("index", 2),
|
||||
("found", True),
|
||||
("found", False),
|
||||
("not_found_value", 6),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
yield "index", input_data.list.index(input_data.value)
|
||||
yield "found", True
|
||||
except ValueError:
|
||||
yield "found", False
|
||||
yield "not_found_value", input_data.value
|
||||
|
||||
|
||||
class NoteBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
text: str = SchemaField(description="The text to display in the sticky note.")
|
||||
@@ -143,6 +400,104 @@ class NoteBlock(Block):
|
||||
yield "output", input_data.text
|
||||
|
||||
|
||||
class CreateDictionaryBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
values: dict[str, Any] = SchemaField(
|
||||
description="Key-value pairs to create the dictionary with",
|
||||
placeholder="e.g., {'name': 'Alice', 'age': 25}",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
dictionary: dict[str, Any] = SchemaField(
|
||||
description="The created dictionary containing the specified key-value pairs"
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if dictionary creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b924ddf4-de4f-4b56-9a85-358930dcbc91",
|
||||
description="Creates a dictionary with the specified key-value pairs. Use this when you know all the values you want to add upfront.",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=CreateDictionaryBlock.Input,
|
||||
output_schema=CreateDictionaryBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"values": {"name": "Alice", "age": 25, "city": "New York"},
|
||||
},
|
||||
{
|
||||
"values": {"numbers": [1, 2, 3], "active": True, "score": 95.5},
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
(
|
||||
"dictionary",
|
||||
{"name": "Alice", "age": 25, "city": "New York"},
|
||||
),
|
||||
(
|
||||
"dictionary",
|
||||
{"numbers": [1, 2, 3], "active": True, "score": 95.5},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
# The values are already validated by Pydantic schema
|
||||
yield "dictionary", input_data.values
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to create dictionary: {str(e)}"
|
||||
|
||||
|
||||
class CreateListBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
values: List[Any] = SchemaField(
|
||||
description="A list of values to be combined into a new list.",
|
||||
placeholder="e.g., ['Alice', 25, True]",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
list: List[Any] = SchemaField(
|
||||
description="The created list containing the specified values."
|
||||
)
|
||||
error: str = SchemaField(description="Error message if list creation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a912d5c7-6e00-4542-b2a9-8034136930e4",
|
||||
description="Creates a list with the specified values. Use this when you know all the values you want to add upfront.",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=CreateListBlock.Input,
|
||||
output_schema=CreateListBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"values": ["Alice", 25, True],
|
||||
},
|
||||
{
|
||||
"values": [1, 2, 3, "four", {"key": "value"}],
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
(
|
||||
"list",
|
||||
["Alice", 25, True],
|
||||
),
|
||||
(
|
||||
"list",
|
||||
[1, 2, 3, "four", {"key": "value"}],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
# The values are already validated by Pydantic schema
|
||||
yield "list", input_data.values
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to create list: {str(e)}"
|
||||
|
||||
|
||||
class TypeOptions(enum.Enum):
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
@@ -160,7 +515,6 @@ class UniversalTypeConverterBlock(Block):
|
||||
|
||||
class Output(BlockSchema):
|
||||
value: Any = SchemaField(description="The converted value.")
|
||||
error: str = SchemaField(description="Error message if conversion failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
|
||||
@@ -3,7 +3,6 @@ from typing import Any
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.type import convert
|
||||
|
||||
|
||||
class ComparisonOperator(Enum):
|
||||
@@ -164,7 +163,7 @@ class IfInputMatchesBlock(Block):
|
||||
},
|
||||
{
|
||||
"input": 10,
|
||||
"value": "None",
|
||||
"value": None,
|
||||
"yes_value": "Yes",
|
||||
"no_value": "No",
|
||||
},
|
||||
@@ -182,23 +181,7 @@ class IfInputMatchesBlock(Block):
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
|
||||
# If input_data.value is not matching input_data.input, convert value to type of input
|
||||
if (
|
||||
input_data.input != input_data.value
|
||||
and input_data.input is not input_data.value
|
||||
):
|
||||
try:
|
||||
# Only attempt conversion if input is not None and value is not None
|
||||
if input_data.input is not None and input_data.value is not None:
|
||||
input_type = type(input_data.input)
|
||||
# Avoid converting if input_type is Any or object
|
||||
if input_type not in (Any, object):
|
||||
input_data.value = convert(input_data.value, input_type)
|
||||
except Exception:
|
||||
pass # If conversion fails, just leave value as is
|
||||
|
||||
if input_data.input == input_data.value:
|
||||
if input_data.input == input_data.value or input_data.input is input_data.value:
|
||||
yield "result", True
|
||||
yield "yes_output", input_data.yes_value
|
||||
else:
|
||||
|
||||
@@ -1,683 +0,0 @@
|
||||
from typing import Any, List
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.json import loads
|
||||
from backend.util.mock import MockObject
|
||||
from backend.util.prompt import estimate_token_count_str
|
||||
|
||||
# =============================================================================
|
||||
# Dictionary Manipulation Blocks
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class CreateDictionaryBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
values: dict[str, Any] = SchemaField(
|
||||
description="Key-value pairs to create the dictionary with",
|
||||
placeholder="e.g., {'name': 'Alice', 'age': 25}",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
dictionary: dict[str, Any] = SchemaField(
|
||||
description="The created dictionary containing the specified key-value pairs"
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if dictionary creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b924ddf4-de4f-4b56-9a85-358930dcbc91",
|
||||
description="Creates a dictionary with the specified key-value pairs. Use this when you know all the values you want to add upfront.",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=CreateDictionaryBlock.Input,
|
||||
output_schema=CreateDictionaryBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"values": {"name": "Alice", "age": 25, "city": "New York"},
|
||||
},
|
||||
{
|
||||
"values": {"numbers": [1, 2, 3], "active": True, "score": 95.5},
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
(
|
||||
"dictionary",
|
||||
{"name": "Alice", "age": 25, "city": "New York"},
|
||||
),
|
||||
(
|
||||
"dictionary",
|
||||
{"numbers": [1, 2, 3], "active": True, "score": 95.5},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
# The values are already validated by Pydantic schema
|
||||
yield "dictionary", input_data.values
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to create dictionary: {str(e)}"
|
||||
|
||||
|
||||
class AddToDictionaryBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
dictionary: dict[Any, Any] = SchemaField(
|
||||
default_factory=dict,
|
||||
description="The dictionary to add the entry to. If not provided, a new dictionary will be created.",
|
||||
)
|
||||
key: str = SchemaField(
|
||||
default="",
|
||||
description="The key for the new entry.",
|
||||
placeholder="new_key",
|
||||
advanced=False,
|
||||
)
|
||||
value: Any = SchemaField(
|
||||
default=None,
|
||||
description="The value for the new entry.",
|
||||
placeholder="new_value",
|
||||
advanced=False,
|
||||
)
|
||||
entries: dict[Any, Any] = SchemaField(
|
||||
default_factory=dict,
|
||||
description="The entries to add to the dictionary. This is the batch version of the `key` and `value` fields.",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
updated_dictionary: dict = SchemaField(
|
||||
description="The dictionary with the new entry added."
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="31d1064e-7446-4693-a7d4-65e5ca1180d1",
|
||||
description="Adds a new key-value pair to a dictionary. If no dictionary is provided, a new one is created.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=AddToDictionaryBlock.Input,
|
||||
output_schema=AddToDictionaryBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"dictionary": {"existing_key": "existing_value"},
|
||||
"key": "new_key",
|
||||
"value": "new_value",
|
||||
},
|
||||
{"key": "first_key", "value": "first_value"},
|
||||
{
|
||||
"dictionary": {"existing_key": "existing_value"},
|
||||
"entries": {"new_key": "new_value", "first_key": "first_value"},
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
(
|
||||
"updated_dictionary",
|
||||
{"existing_key": "existing_value", "new_key": "new_value"},
|
||||
),
|
||||
("updated_dictionary", {"first_key": "first_value"}),
|
||||
(
|
||||
"updated_dictionary",
|
||||
{
|
||||
"existing_key": "existing_value",
|
||||
"new_key": "new_value",
|
||||
"first_key": "first_value",
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
updated_dict = input_data.dictionary.copy()
|
||||
|
||||
if input_data.value is not None and input_data.key:
|
||||
updated_dict[input_data.key] = input_data.value
|
||||
|
||||
for key, value in input_data.entries.items():
|
||||
updated_dict[key] = value
|
||||
|
||||
yield "updated_dictionary", updated_dict
|
||||
|
||||
|
||||
class FindInDictionaryBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
input: Any = SchemaField(description="Dictionary to lookup from")
|
||||
key: str | int = SchemaField(description="Key to lookup in the dictionary")
|
||||
|
||||
class Output(BlockSchema):
|
||||
output: Any = SchemaField(description="Value found for the given key")
|
||||
missing: Any = SchemaField(
|
||||
description="Value of the input that missing the key"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0e50422c-6dee-4145-83d6-3a5a392f65de",
|
||||
description="Lookup the given key in the input dictionary/object/list and return the value.",
|
||||
input_schema=FindInDictionaryBlock.Input,
|
||||
output_schema=FindInDictionaryBlock.Output,
|
||||
test_input=[
|
||||
{"input": {"apple": 1, "banana": 2, "cherry": 3}, "key": "banana"},
|
||||
{"input": {"x": 10, "y": 20, "z": 30}, "key": "w"},
|
||||
{"input": [1, 2, 3], "key": 1},
|
||||
{"input": [1, 2, 3], "key": 3},
|
||||
{"input": MockObject(value="!!", key="key"), "key": "key"},
|
||||
{"input": [{"k1": "v1"}, {"k2": "v2"}, {"k1": "v3"}], "key": "k1"},
|
||||
],
|
||||
test_output=[
|
||||
("output", 2),
|
||||
("missing", {"x": 10, "y": 20, "z": 30}),
|
||||
("output", 2),
|
||||
("missing", [1, 2, 3]),
|
||||
("output", "key"),
|
||||
("output", ["v1", "v3"]),
|
||||
],
|
||||
categories={BlockCategory.BASIC},
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
obj = input_data.input
|
||||
key = input_data.key
|
||||
|
||||
if isinstance(obj, str):
|
||||
obj = loads(obj)
|
||||
|
||||
if isinstance(obj, dict) and key in obj:
|
||||
yield "output", obj[key]
|
||||
elif isinstance(obj, list) and isinstance(key, int) and 0 <= key < len(obj):
|
||||
yield "output", obj[key]
|
||||
elif isinstance(obj, list) and isinstance(key, str):
|
||||
if len(obj) == 0:
|
||||
yield "output", []
|
||||
elif isinstance(obj[0], dict) and key in obj[0]:
|
||||
yield "output", [item[key] for item in obj if key in item]
|
||||
else:
|
||||
yield "output", [getattr(val, key) for val in obj if hasattr(val, key)]
|
||||
elif isinstance(obj, object) and isinstance(key, str) and hasattr(obj, key):
|
||||
yield "output", getattr(obj, key)
|
||||
else:
|
||||
yield "missing", input_data.input
|
||||
|
||||
|
||||
class RemoveFromDictionaryBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
dictionary: dict[Any, Any] = SchemaField(
|
||||
description="The dictionary to modify."
|
||||
)
|
||||
key: str | int = SchemaField(description="Key to remove from the dictionary.")
|
||||
return_value: bool = SchemaField(
|
||||
default=False, description="Whether to return the removed value."
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
updated_dictionary: dict[Any, Any] = SchemaField(
|
||||
description="The dictionary after removal."
|
||||
)
|
||||
removed_value: Any = SchemaField(description="The removed value if requested.")
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="46afe2ea-c613-43f8-95ff-6692c3ef6876",
|
||||
description="Removes a key-value pair from a dictionary.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=RemoveFromDictionaryBlock.Input,
|
||||
output_schema=RemoveFromDictionaryBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"dictionary": {"a": 1, "b": 2, "c": 3},
|
||||
"key": "b",
|
||||
"return_value": True,
|
||||
},
|
||||
{"dictionary": {"x": "hello", "y": "world"}, "key": "x"},
|
||||
],
|
||||
test_output=[
|
||||
("updated_dictionary", {"a": 1, "c": 3}),
|
||||
("removed_value", 2),
|
||||
("updated_dictionary", {"y": "world"}),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
updated_dict = input_data.dictionary.copy()
|
||||
try:
|
||||
removed_value = updated_dict.pop(input_data.key)
|
||||
yield "updated_dictionary", updated_dict
|
||||
if input_data.return_value:
|
||||
yield "removed_value", removed_value
|
||||
except KeyError:
|
||||
yield "error", f"Key '{input_data.key}' not found in dictionary"
|
||||
|
||||
|
||||
class ReplaceDictionaryValueBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
dictionary: dict[Any, Any] = SchemaField(
|
||||
description="The dictionary to modify."
|
||||
)
|
||||
key: str | int = SchemaField(description="Key to replace the value for.")
|
||||
value: Any = SchemaField(description="The new value for the given key.")
|
||||
|
||||
class Output(BlockSchema):
|
||||
updated_dictionary: dict[Any, Any] = SchemaField(
|
||||
description="The dictionary after replacement."
|
||||
)
|
||||
old_value: Any = SchemaField(description="The value that was replaced.")
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="27e31876-18b6-44f3-ab97-f6226d8b3889",
|
||||
description="Replaces the value for a specified key in a dictionary.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=ReplaceDictionaryValueBlock.Input,
|
||||
output_schema=ReplaceDictionaryValueBlock.Output,
|
||||
test_input=[
|
||||
{"dictionary": {"a": 1, "b": 2, "c": 3}, "key": "b", "value": 99},
|
||||
{
|
||||
"dictionary": {"x": "hello", "y": "world"},
|
||||
"key": "y",
|
||||
"value": "universe",
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("updated_dictionary", {"a": 1, "b": 99, "c": 3}),
|
||||
("old_value", 2),
|
||||
("updated_dictionary", {"x": "hello", "y": "universe"}),
|
||||
("old_value", "world"),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
updated_dict = input_data.dictionary.copy()
|
||||
try:
|
||||
old_value = updated_dict[input_data.key]
|
||||
updated_dict[input_data.key] = input_data.value
|
||||
yield "updated_dictionary", updated_dict
|
||||
yield "old_value", old_value
|
||||
except KeyError:
|
||||
yield "error", f"Key '{input_data.key}' not found in dictionary"
|
||||
|
||||
|
||||
class DictionaryIsEmptyBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
dictionary: dict[Any, Any] = SchemaField(description="The dictionary to check.")
|
||||
|
||||
class Output(BlockSchema):
|
||||
is_empty: bool = SchemaField(description="True if the dictionary is empty.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a3cf3f64-6bb9-4cc6-9900-608a0b3359b0",
|
||||
description="Checks if a dictionary is empty.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=DictionaryIsEmptyBlock.Input,
|
||||
output_schema=DictionaryIsEmptyBlock.Output,
|
||||
test_input=[{"dictionary": {}}, {"dictionary": {"a": 1}}],
|
||||
test_output=[("is_empty", True), ("is_empty", False)],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
yield "is_empty", len(input_data.dictionary) == 0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# List Manipulation Blocks
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class CreateListBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
values: List[Any] = SchemaField(
|
||||
description="A list of values to be combined into a new list.",
|
||||
placeholder="e.g., ['Alice', 25, True]",
|
||||
)
|
||||
max_size: int | None = SchemaField(
|
||||
default=None,
|
||||
description="Maximum size of the list. If provided, the list will be yielded in chunks of this size.",
|
||||
advanced=True,
|
||||
)
|
||||
max_tokens: int | None = SchemaField(
|
||||
default=None,
|
||||
description="Maximum tokens for the list. If provided, the list will be yielded in chunks that fit within this token limit.",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
list: List[Any] = SchemaField(
|
||||
description="The created list containing the specified values."
|
||||
)
|
||||
error: str = SchemaField(description="Error message if list creation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a912d5c7-6e00-4542-b2a9-8034136930e4",
|
||||
description="Creates a list with the specified values. Use this when you know all the values you want to add upfront. This block can also yield the list in batches based on a maximum size or token limit.",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=CreateListBlock.Input,
|
||||
output_schema=CreateListBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"values": ["Alice", 25, True],
|
||||
},
|
||||
{
|
||||
"values": [1, 2, 3, "four", {"key": "value"}],
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
(
|
||||
"list",
|
||||
["Alice", 25, True],
|
||||
),
|
||||
(
|
||||
"list",
|
||||
[1, 2, 3, "four", {"key": "value"}],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
chunk = []
|
||||
cur_tokens, max_tokens = 0, input_data.max_tokens
|
||||
cur_size, max_size = 0, input_data.max_size
|
||||
|
||||
for value in input_data.values:
|
||||
if max_tokens:
|
||||
tokens = estimate_token_count_str(value)
|
||||
else:
|
||||
tokens = 0
|
||||
|
||||
# Check if adding this value would exceed either limit
|
||||
if (max_tokens and (cur_tokens + tokens > max_tokens)) or (
|
||||
max_size and (cur_size + 1 > max_size)
|
||||
):
|
||||
yield "list", chunk
|
||||
chunk = [value]
|
||||
cur_size, cur_tokens = 1, tokens
|
||||
else:
|
||||
chunk.append(value)
|
||||
cur_size, cur_tokens = cur_size + 1, cur_tokens + tokens
|
||||
|
||||
# Yield final chunk if any
|
||||
if chunk or not input_data.values:
|
||||
yield "list", chunk
|
||||
|
||||
|
||||
class AddToListBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
list: List[Any] = SchemaField(
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
description="The list to add the entry to. If not provided, a new list will be created.",
|
||||
)
|
||||
entry: Any = SchemaField(
|
||||
description="The entry to add to the list. Can be of any type (string, int, dict, etc.).",
|
||||
advanced=False,
|
||||
default=None,
|
||||
)
|
||||
entries: List[Any] = SchemaField(
|
||||
default_factory=lambda: list(),
|
||||
description="The entries to add to the list. This is the batch version of the `entry` field.",
|
||||
advanced=True,
|
||||
)
|
||||
position: int | None = SchemaField(
|
||||
default=None,
|
||||
description="The position to insert the new entry. If not provided, the entry will be appended to the end of the list.",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
updated_list: List[Any] = SchemaField(
|
||||
description="The list with the new entry added."
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="aeb08fc1-2fc1-4141-bc8e-f758f183a822",
|
||||
description="Adds a new entry to a list. The entry can be of any type. If no list is provided, a new one is created.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=AddToListBlock.Input,
|
||||
output_schema=AddToListBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"list": [1, "string", {"existing_key": "existing_value"}],
|
||||
"entry": {"new_key": "new_value"},
|
||||
"position": 1,
|
||||
},
|
||||
{"entry": "first_entry"},
|
||||
{"list": ["a", "b", "c"], "entry": "d"},
|
||||
{
|
||||
"entry": "e",
|
||||
"entries": ["f", "g"],
|
||||
"list": ["a", "b"],
|
||||
"position": 1,
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
(
|
||||
"updated_list",
|
||||
[
|
||||
1,
|
||||
{"new_key": "new_value"},
|
||||
"string",
|
||||
{"existing_key": "existing_value"},
|
||||
],
|
||||
),
|
||||
("updated_list", ["first_entry"]),
|
||||
("updated_list", ["a", "b", "c", "d"]),
|
||||
("updated_list", ["a", "f", "g", "e", "b"]),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
entries_added = input_data.entries.copy()
|
||||
if input_data.entry:
|
||||
entries_added.append(input_data.entry)
|
||||
|
||||
updated_list = input_data.list.copy()
|
||||
if (pos := input_data.position) is not None:
|
||||
updated_list = updated_list[:pos] + entries_added + updated_list[pos:]
|
||||
else:
|
||||
updated_list += entries_added
|
||||
|
||||
yield "updated_list", updated_list
|
||||
|
||||
|
||||
class FindInListBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
list: List[Any] = SchemaField(description="The list to search in.")
|
||||
value: Any = SchemaField(description="The value to search for.")
|
||||
|
||||
class Output(BlockSchema):
|
||||
index: int = SchemaField(description="The index of the value in the list.")
|
||||
found: bool = SchemaField(
|
||||
description="Whether the value was found in the list."
|
||||
)
|
||||
not_found_value: Any = SchemaField(
|
||||
description="The value that was not found in the list."
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="5e2c6d0a-1e37-489f-b1d0-8e1812b23333",
|
||||
description="Finds the index of the value in the list.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=FindInListBlock.Input,
|
||||
output_schema=FindInListBlock.Output,
|
||||
test_input=[
|
||||
{"list": [1, 2, 3, 4, 5], "value": 3},
|
||||
{"list": [1, 2, 3, 4, 5], "value": 6},
|
||||
],
|
||||
test_output=[
|
||||
("index", 2),
|
||||
("found", True),
|
||||
("found", False),
|
||||
("not_found_value", 6),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
yield "index", input_data.list.index(input_data.value)
|
||||
yield "found", True
|
||||
except ValueError:
|
||||
yield "found", False
|
||||
yield "not_found_value", input_data.value
|
||||
|
||||
|
||||
class GetListItemBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
list: List[Any] = SchemaField(description="The list to get the item from.")
|
||||
index: int = SchemaField(
|
||||
description="The 0-based index of the item (supports negative indices)."
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
item: Any = SchemaField(description="The item at the specified index.")
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="262ca24c-1025-43cf-a578-534e23234e97",
|
||||
description="Returns the element at the given index.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=GetListItemBlock.Input,
|
||||
output_schema=GetListItemBlock.Output,
|
||||
test_input=[
|
||||
{"list": [1, 2, 3], "index": 1},
|
||||
{"list": [1, 2, 3], "index": -1},
|
||||
],
|
||||
test_output=[
|
||||
("item", 2),
|
||||
("item", 3),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
yield "item", input_data.list[input_data.index]
|
||||
except IndexError:
|
||||
yield "error", "Index out of range"
|
||||
|
||||
|
||||
class RemoveFromListBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
list: List[Any] = SchemaField(description="The list to modify.")
|
||||
value: Any = SchemaField(
|
||||
default=None, description="Value to remove from the list."
|
||||
)
|
||||
index: int | None = SchemaField(
|
||||
default=None,
|
||||
description="Index of the item to pop (supports negative indices).",
|
||||
)
|
||||
return_item: bool = SchemaField(
|
||||
default=False, description="Whether to return the removed item."
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
updated_list: List[Any] = SchemaField(description="The list after removal.")
|
||||
removed_item: Any = SchemaField(description="The removed item if requested.")
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d93c5a93-ac7e-41c1-ae5c-ef67e6e9b826",
|
||||
description="Removes an item from a list by value or index.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=RemoveFromListBlock.Input,
|
||||
output_schema=RemoveFromListBlock.Output,
|
||||
test_input=[
|
||||
{"list": [1, 2, 3], "index": 1, "return_item": True},
|
||||
{"list": ["a", "b", "c"], "value": "b"},
|
||||
],
|
||||
test_output=[
|
||||
("updated_list", [1, 3]),
|
||||
("removed_item", 2),
|
||||
("updated_list", ["a", "c"]),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
lst = input_data.list.copy()
|
||||
removed = None
|
||||
try:
|
||||
if input_data.index is not None:
|
||||
removed = lst.pop(input_data.index)
|
||||
elif input_data.value is not None:
|
||||
lst.remove(input_data.value)
|
||||
removed = input_data.value
|
||||
else:
|
||||
raise ValueError("No index or value provided for removal")
|
||||
except (IndexError, ValueError):
|
||||
yield "error", "Index or value not found"
|
||||
return
|
||||
|
||||
yield "updated_list", lst
|
||||
if input_data.return_item:
|
||||
yield "removed_item", removed
|
||||
|
||||
|
||||
class ReplaceListItemBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
list: List[Any] = SchemaField(description="The list to modify.")
|
||||
index: int = SchemaField(
|
||||
description="Index of the item to replace (supports negative indices)."
|
||||
)
|
||||
value: Any = SchemaField(description="The new value for the given index.")
|
||||
|
||||
class Output(BlockSchema):
|
||||
updated_list: List[Any] = SchemaField(description="The list after replacement.")
|
||||
old_item: Any = SchemaField(description="The item that was replaced.")
|
||||
error: str = SchemaField(description="Error message if the operation failed.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="fbf62922-bea1-4a3d-8bac-23587f810b38",
|
||||
description="Replaces an item at the specified index.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=ReplaceListItemBlock.Input,
|
||||
output_schema=ReplaceListItemBlock.Output,
|
||||
test_input=[
|
||||
{"list": [1, 2, 3], "index": 1, "value": 99},
|
||||
{"list": ["a", "b"], "index": -1, "value": "c"},
|
||||
],
|
||||
test_output=[
|
||||
("updated_list", [1, 99, 3]),
|
||||
("old_item", 2),
|
||||
("updated_list", ["a", "c"]),
|
||||
("old_item", "b"),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
lst = input_data.list.copy()
|
||||
try:
|
||||
old = lst[input_data.index]
|
||||
lst[input_data.index] = input_data.value
|
||||
except IndexError:
|
||||
yield "error", "Index out of range"
|
||||
return
|
||||
|
||||
yield "updated_list", lst
|
||||
yield "old_item", old
|
||||
|
||||
|
||||
class ListIsEmptyBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
list: List[Any] = SchemaField(description="The list to check.")
|
||||
|
||||
class Output(BlockSchema):
|
||||
is_empty: bool = SchemaField(description="True if the list is empty.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="896ed73b-27d0-41be-813c-c1c1dc856c03",
|
||||
description="Checks if a list is empty.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=ListIsEmptyBlock.Input,
|
||||
output_schema=ListIsEmptyBlock.Output,
|
||||
test_input=[{"list": []}, {"list": [1]}],
|
||||
test_output=[("is_empty", True), ("is_empty", False)],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
yield "is_empty", len(input_data.list) == 0
|
||||
@@ -13,7 +13,7 @@ from backend.data.model import (
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.file import MediaFileType, store_media_file
|
||||
from backend.util.file import MediaFileType
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
@@ -108,7 +108,7 @@ class AIImageEditorBlock(Block):
|
||||
output_schema=AIImageEditorBlock.Output,
|
||||
test_input={
|
||||
"prompt": "Add a hat to the cat",
|
||||
"input_image": "",
|
||||
"input_image": "https://example.com/cat.png",
|
||||
"aspect_ratio": AspectRatio.MATCH_INPUT_IMAGE,
|
||||
"seed": None,
|
||||
"model": FluxKontextModelName.PRO,
|
||||
@@ -128,22 +128,13 @@ class AIImageEditorBlock(Block):
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
graph_exec_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
result = await self.run_model(
|
||||
api_key=credentials.api_key,
|
||||
model_name=input_data.model.api_name,
|
||||
prompt=input_data.prompt,
|
||||
input_image_b64=(
|
||||
await store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.input_image,
|
||||
return_content=True,
|
||||
)
|
||||
if input_data.input_image
|
||||
else None
|
||||
),
|
||||
input_image=input_data.input_image,
|
||||
aspect_ratio=input_data.aspect_ratio.value,
|
||||
seed=input_data.seed,
|
||||
)
|
||||
@@ -154,14 +145,14 @@ class AIImageEditorBlock(Block):
|
||||
api_key: SecretStr,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
input_image_b64: Optional[str],
|
||||
input_image: Optional[MediaFileType],
|
||||
aspect_ratio: str,
|
||||
seed: Optional[int],
|
||||
) -> MediaFileType:
|
||||
client = ReplicateClient(api_token=api_key.get_secret_value())
|
||||
input_params = {
|
||||
"prompt": prompt,
|
||||
"input_image": input_image_b64,
|
||||
"input_image": input_image,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
**({"seed": seed} if seed is not None else {}),
|
||||
}
|
||||
|
||||
@@ -498,9 +498,6 @@ class GithubListIssuesBlock(Block):
|
||||
issue: IssueItem = SchemaField(
|
||||
title="Issue", description="Issues with their title and URL"
|
||||
)
|
||||
issues: list[IssueItem] = SchemaField(
|
||||
description="List of issues with their title and URL"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if listing issues failed")
|
||||
|
||||
def __init__(self):
|
||||
@@ -516,22 +513,13 @@ class GithubListIssuesBlock(Block):
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"issues",
|
||||
[
|
||||
{
|
||||
"title": "Issue 1",
|
||||
"url": "https://github.com/owner/repo/issues/1",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"issue",
|
||||
{
|
||||
"title": "Issue 1",
|
||||
"url": "https://github.com/owner/repo/issues/1",
|
||||
},
|
||||
),
|
||||
)
|
||||
],
|
||||
test_mock={
|
||||
"list_issues": lambda *args, **kwargs: [
|
||||
@@ -563,12 +551,10 @@ class GithubListIssuesBlock(Block):
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
issues = await self.list_issues(
|
||||
for issue in await self.list_issues(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
)
|
||||
yield "issues", issues
|
||||
for issue in issues:
|
||||
):
|
||||
yield "issue", issue
|
||||
|
||||
|
||||
|
||||
@@ -31,12 +31,7 @@ class GithubListPullRequestsBlock(Block):
|
||||
pull_request: PRItem = SchemaField(
|
||||
title="Pull Request", description="PRs with their title and URL"
|
||||
)
|
||||
pull_requests: list[PRItem] = SchemaField(
|
||||
description="List of pull requests with their title and URL"
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if listing pull requests failed"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if listing issues failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -51,22 +46,13 @@ class GithubListPullRequestsBlock(Block):
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"pull_requests",
|
||||
[
|
||||
{
|
||||
"title": "Pull request 1",
|
||||
"url": "https://github.com/owner/repo/pull/1",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"pull_request",
|
||||
{
|
||||
"title": "Pull request 1",
|
||||
"url": "https://github.com/owner/repo/pull/1",
|
||||
},
|
||||
),
|
||||
)
|
||||
],
|
||||
test_mock={
|
||||
"list_prs": lambda *args, **kwargs: [
|
||||
@@ -102,7 +88,6 @@ class GithubListPullRequestsBlock(Block):
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
)
|
||||
yield "pull_requests", pull_requests
|
||||
for pr in pull_requests:
|
||||
yield "pull_request", pr
|
||||
|
||||
@@ -280,26 +265,10 @@ class GithubReadPullRequestBlock(Block):
|
||||
files = response.json()
|
||||
changes = []
|
||||
for file in files:
|
||||
status: str = file.get("status", "")
|
||||
diff: str = file.get("patch", "")
|
||||
if status != "removed":
|
||||
is_filename: str = file.get("filename", "")
|
||||
was_filename: str = (
|
||||
file.get("previous_filename", is_filename)
|
||||
if status != "added"
|
||||
else ""
|
||||
)
|
||||
else:
|
||||
is_filename = ""
|
||||
was_filename: str = file.get("filename", "")
|
||||
|
||||
patch_header = ""
|
||||
if was_filename:
|
||||
patch_header += f"--- {was_filename}\n"
|
||||
if is_filename:
|
||||
patch_header += f"+++ {is_filename}\n"
|
||||
changes.append(patch_header + diff)
|
||||
return "\n\n".join(changes)
|
||||
filename = file.get("filename", "")
|
||||
status = file.get("status", "")
|
||||
changes.append(f"{filename}: {status}")
|
||||
return "\n".join(changes)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
@@ -475,9 +444,6 @@ class GithubListPRReviewersBlock(Block):
|
||||
title="Reviewer",
|
||||
description="Reviewers with their username and profile URL",
|
||||
)
|
||||
reviewers: list[ReviewerItem] = SchemaField(
|
||||
description="List of reviewers with their username and profile URL"
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if listing reviewers failed"
|
||||
)
|
||||
@@ -495,22 +461,13 @@ class GithubListPRReviewersBlock(Block):
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"reviewers",
|
||||
[
|
||||
{
|
||||
"username": "reviewer1",
|
||||
"url": "https://github.com/reviewer1",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"reviewer",
|
||||
{
|
||||
"username": "reviewer1",
|
||||
"url": "https://github.com/reviewer1",
|
||||
},
|
||||
),
|
||||
)
|
||||
],
|
||||
test_mock={
|
||||
"list_reviewers": lambda *args, **kwargs: [
|
||||
@@ -543,12 +500,10 @@ class GithubListPRReviewersBlock(Block):
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
reviewers = await self.list_reviewers(
|
||||
for reviewer in await self.list_reviewers(
|
||||
credentials,
|
||||
input_data.pr_url,
|
||||
)
|
||||
yield "reviewers", reviewers
|
||||
for reviewer in reviewers:
|
||||
):
|
||||
yield "reviewer", reviewer
|
||||
|
||||
|
||||
|
||||
@@ -31,9 +31,6 @@ class GithubListTagsBlock(Block):
|
||||
tag: TagItem = SchemaField(
|
||||
title="Tag", description="Tags with their name and file tree browser URL"
|
||||
)
|
||||
tags: list[TagItem] = SchemaField(
|
||||
description="List of tags with their name and file tree browser URL"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if listing tags failed")
|
||||
|
||||
def __init__(self):
|
||||
@@ -49,22 +46,13 @@ class GithubListTagsBlock(Block):
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"tags",
|
||||
[
|
||||
{
|
||||
"name": "v1.0.0",
|
||||
"url": "https://github.com/owner/repo/tree/v1.0.0",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"tag",
|
||||
{
|
||||
"name": "v1.0.0",
|
||||
"url": "https://github.com/owner/repo/tree/v1.0.0",
|
||||
},
|
||||
),
|
||||
)
|
||||
],
|
||||
test_mock={
|
||||
"list_tags": lambda *args, **kwargs: [
|
||||
@@ -105,7 +93,6 @@ class GithubListTagsBlock(Block):
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
)
|
||||
yield "tags", tags
|
||||
for tag in tags:
|
||||
yield "tag", tag
|
||||
|
||||
@@ -127,9 +114,6 @@ class GithubListBranchesBlock(Block):
|
||||
title="Branch",
|
||||
description="Branches with their name and file tree browser URL",
|
||||
)
|
||||
branches: list[BranchItem] = SchemaField(
|
||||
description="List of branches with their name and file tree browser URL"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if listing branches failed")
|
||||
|
||||
def __init__(self):
|
||||
@@ -145,22 +129,13 @@ class GithubListBranchesBlock(Block):
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"branches",
|
||||
[
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"branch",
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
},
|
||||
),
|
||||
)
|
||||
],
|
||||
test_mock={
|
||||
"list_branches": lambda *args, **kwargs: [
|
||||
@@ -201,7 +176,6 @@ class GithubListBranchesBlock(Block):
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
)
|
||||
yield "branches", branches
|
||||
for branch in branches:
|
||||
yield "branch", branch
|
||||
|
||||
@@ -225,9 +199,6 @@ class GithubListDiscussionsBlock(Block):
|
||||
discussion: DiscussionItem = SchemaField(
|
||||
title="Discussion", description="Discussions with their title and URL"
|
||||
)
|
||||
discussions: list[DiscussionItem] = SchemaField(
|
||||
description="List of discussions with their title and URL"
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if listing discussions failed"
|
||||
)
|
||||
@@ -246,22 +217,13 @@ class GithubListDiscussionsBlock(Block):
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"discussions",
|
||||
[
|
||||
{
|
||||
"title": "Discussion 1",
|
||||
"url": "https://github.com/owner/repo/discussions/1",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"discussion",
|
||||
{
|
||||
"title": "Discussion 1",
|
||||
"url": "https://github.com/owner/repo/discussions/1",
|
||||
},
|
||||
),
|
||||
)
|
||||
],
|
||||
test_mock={
|
||||
"list_discussions": lambda *args, **kwargs: [
|
||||
@@ -317,7 +279,6 @@ class GithubListDiscussionsBlock(Block):
|
||||
input_data.repo_url,
|
||||
input_data.num_discussions,
|
||||
)
|
||||
yield "discussions", discussions
|
||||
for discussion in discussions:
|
||||
yield "discussion", discussion
|
||||
|
||||
@@ -339,9 +300,6 @@ class GithubListReleasesBlock(Block):
|
||||
title="Release",
|
||||
description="Releases with their name and file tree browser URL",
|
||||
)
|
||||
releases: list[ReleaseItem] = SchemaField(
|
||||
description="List of releases with their name and file tree browser URL"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if listing releases failed")
|
||||
|
||||
def __init__(self):
|
||||
@@ -357,22 +315,13 @@ class GithubListReleasesBlock(Block):
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"releases",
|
||||
[
|
||||
{
|
||||
"name": "v1.0.0",
|
||||
"url": "https://github.com/owner/repo/releases/tag/v1.0.0",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"release",
|
||||
{
|
||||
"name": "v1.0.0",
|
||||
"url": "https://github.com/owner/repo/releases/tag/v1.0.0",
|
||||
},
|
||||
),
|
||||
)
|
||||
],
|
||||
test_mock={
|
||||
"list_releases": lambda *args, **kwargs: [
|
||||
@@ -408,7 +357,6 @@ class GithubListReleasesBlock(Block):
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
)
|
||||
yield "releases", releases
|
||||
for release in releases:
|
||||
yield "release", release
|
||||
|
||||
@@ -1093,9 +1041,6 @@ class GithubListStargazersBlock(Block):
|
||||
title="Stargazer",
|
||||
description="Stargazers with their username and profile URL",
|
||||
)
|
||||
stargazers: list[StargazerItem] = SchemaField(
|
||||
description="List of stargazers with their username and profile URL"
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if listing stargazers failed"
|
||||
)
|
||||
@@ -1113,22 +1058,13 @@ class GithubListStargazersBlock(Block):
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"stargazers",
|
||||
[
|
||||
{
|
||||
"username": "octocat",
|
||||
"url": "https://github.com/octocat",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
"stargazer",
|
||||
{
|
||||
"username": "octocat",
|
||||
"url": "https://github.com/octocat",
|
||||
},
|
||||
),
|
||||
)
|
||||
],
|
||||
test_mock={
|
||||
"list_stargazers": lambda *args, **kwargs: [
|
||||
@@ -1168,6 +1104,5 @@ class GithubListStargazersBlock(Block):
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
)
|
||||
yield "stargazers", stargazers
|
||||
for stargazer in stargazers:
|
||||
yield "stargazer", stargazer
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -3,19 +3,11 @@ import logging
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import aiofiles
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import (
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
HostScopedCredentials,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import (
|
||||
MediaFileType,
|
||||
get_exec_file_path,
|
||||
@@ -27,30 +19,6 @@ from backend.util.request import Requests
|
||||
logger = logging.getLogger(name=__name__)
|
||||
|
||||
|
||||
# Host-scoped credentials for HTTP requests
|
||||
HttpCredentials = CredentialsMetaInput[
|
||||
Literal[ProviderName.HTTP], Literal["host_scoped"]
|
||||
]
|
||||
|
||||
|
||||
TEST_CREDENTIALS = HostScopedCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="http",
|
||||
host="api.example.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("Bearer test-token"),
|
||||
"X-API-Key": SecretStr("test-api-key"),
|
||||
},
|
||||
title="Mock HTTP Host-Scoped Credentials",
|
||||
)
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
class HttpMethod(Enum):
|
||||
GET = "GET"
|
||||
POST = "POST"
|
||||
@@ -201,62 +169,3 @@ class SendWebRequestBlock(Block):
|
||||
yield "client_error", result
|
||||
else:
|
||||
yield "server_error", result
|
||||
|
||||
|
||||
class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
|
||||
class Input(SendWebRequestBlock.Input):
|
||||
credentials: HttpCredentials = CredentialsField(
|
||||
description="HTTP host-scoped credentials for automatic header injection",
|
||||
discriminator="url",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
Block.__init__(
|
||||
self,
|
||||
id="fff86bcd-e001-4bad-a7f6-2eae4720c8dc",
|
||||
description="Make an authenticated HTTP request with host-scoped credentials (JSON / form / multipart).",
|
||||
categories={BlockCategory.OUTPUT},
|
||||
input_schema=SendAuthenticatedWebRequestBlock.Input,
|
||||
output_schema=SendWebRequestBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def run( # type: ignore[override]
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
graph_exec_id: str,
|
||||
credentials: HostScopedCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# Create SendWebRequestBlock.Input from our input (removing credentials field)
|
||||
base_input = SendWebRequestBlock.Input(
|
||||
url=input_data.url,
|
||||
method=input_data.method,
|
||||
headers=input_data.headers,
|
||||
json_format=input_data.json_format,
|
||||
body=input_data.body,
|
||||
files_name=input_data.files_name,
|
||||
files=input_data.files,
|
||||
)
|
||||
|
||||
# Apply host-scoped credentials to headers
|
||||
extra_headers = {}
|
||||
if credentials.matches_url(input_data.url):
|
||||
logger.debug(
|
||||
f"Applying host-scoped credentials {credentials.id} for URL {input_data.url}"
|
||||
)
|
||||
extra_headers.update(credentials.get_headers_dict())
|
||||
else:
|
||||
logger.warning(
|
||||
f"Host-scoped credentials {credentials.id} do not match URL {input_data.url}"
|
||||
)
|
||||
|
||||
# Merge with user-provided headers (user headers take precedence)
|
||||
base_input.headers = {**extra_headers, **input_data.headers}
|
||||
|
||||
# Use parent class run method
|
||||
async for output_name, output_data in super().run(
|
||||
base_input, graph_exec_id=graph_exec_id, **kwargs
|
||||
):
|
||||
yield output_name, output_data
|
||||
|
||||
@@ -413,12 +413,6 @@ class AgentFileInputBlock(AgentInputBlock):
|
||||
advanced=False,
|
||||
title="Default Value",
|
||||
)
|
||||
base_64: bool = SchemaField(
|
||||
description="Whether produce an output in base64 format (not recommended, you can pass the string path just fine accross blocks).",
|
||||
default=False,
|
||||
advanced=True,
|
||||
title="Produce Base64 Output",
|
||||
)
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: str = SchemaField(description="File reference/path result.")
|
||||
@@ -452,11 +446,12 @@ class AgentFileInputBlock(AgentInputBlock):
|
||||
if not input_data.value:
|
||||
return
|
||||
|
||||
yield "result", await store_media_file(
|
||||
file_path = await store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.value,
|
||||
return_content=input_data.base_64,
|
||||
return_content=False,
|
||||
)
|
||||
yield "result", file_path
|
||||
|
||||
|
||||
class AgentDropdownInputBlock(AgentInputBlock):
|
||||
|
||||
@@ -23,7 +23,6 @@ from backend.data.model import (
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util import json
|
||||
from backend.util.logging import TruncatedLogger
|
||||
from backend.util.prompt import compress_prompt, estimate_token_count
|
||||
from backend.util.text import TextFormatter
|
||||
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
|
||||
@@ -41,7 +40,7 @@ LLMProviderName = Literal[
|
||||
AICredentials = CredentialsMetaInput[LLMProviderName, Literal["api_key"]]
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="769f6af7-820b-4d5d-9b7a-ab82bbc165f",
|
||||
id="ed55ac19-356e-4243-a6cb-bc599e9b716f",
|
||||
provider="openai",
|
||||
api_key=SecretStr("mock-openai-api-key"),
|
||||
title="Mock OpenAI API key",
|
||||
@@ -127,9 +126,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
PERPLEXITY_LLAMA_3_1_SONAR_LARGE_128K_ONLINE = (
|
||||
"perplexity/llama-3.1-sonar-large-128k-online"
|
||||
)
|
||||
PERPLEXITY_SONAR = "perplexity/sonar"
|
||||
PERPLEXITY_SONAR_PRO = "perplexity/sonar-pro"
|
||||
PERPLEXITY_SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
|
||||
QWEN_QWQ_32B_PREVIEW = "qwen/qwq-32b-preview"
|
||||
NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
|
||||
NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
|
||||
@@ -232,13 +228,6 @@ MODEL_METADATA = {
|
||||
LlmModel.PERPLEXITY_LLAMA_3_1_SONAR_LARGE_128K_ONLINE: ModelMetadata(
|
||||
"open_router", 127072, 127072
|
||||
),
|
||||
LlmModel.PERPLEXITY_SONAR: ModelMetadata("open_router", 127000, 127000),
|
||||
LlmModel.PERPLEXITY_SONAR_PRO: ModelMetadata("open_router", 200000, 8000),
|
||||
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: ModelMetadata(
|
||||
"open_router",
|
||||
128000,
|
||||
128000,
|
||||
),
|
||||
LlmModel.QWEN_QWQ_32B_PREVIEW: ModelMetadata("open_router", 32768, 32768),
|
||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: ModelMetadata(
|
||||
"open_router", 131000, 4096
|
||||
@@ -283,7 +272,6 @@ class LLMResponse(BaseModel):
|
||||
tool_calls: Optional[List[ToolContentBlock]] | None
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
reasoning: Optional[str] = None
|
||||
|
||||
|
||||
def convert_openai_tool_fmt_to_anthropic(
|
||||
@@ -318,44 +306,11 @@ def convert_openai_tool_fmt_to_anthropic(
|
||||
return anthropic_tools
|
||||
|
||||
|
||||
def extract_openai_reasoning(response) -> str | None:
|
||||
"""Extract reasoning from OpenAI-compatible response if available."""
|
||||
"""Note: This will likely not working since the reasoning is not present in another Response API"""
|
||||
reasoning = None
|
||||
choice = response.choices[0]
|
||||
if hasattr(choice, "reasoning") and getattr(choice, "reasoning", None):
|
||||
reasoning = str(getattr(choice, "reasoning"))
|
||||
elif hasattr(response, "reasoning") and getattr(response, "reasoning", None):
|
||||
reasoning = str(getattr(response, "reasoning"))
|
||||
elif hasattr(choice.message, "reasoning") and getattr(
|
||||
choice.message, "reasoning", None
|
||||
):
|
||||
reasoning = str(getattr(choice.message, "reasoning"))
|
||||
return reasoning
|
||||
|
||||
|
||||
def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
|
||||
"""Extract tool calls from OpenAI-compatible response."""
|
||||
if response.choices[0].message.tool_calls:
|
||||
return [
|
||||
ToolContentBlock(
|
||||
id=tool.id,
|
||||
type=tool.type,
|
||||
function=ToolCall(
|
||||
name=tool.function.name,
|
||||
arguments=tool.function.arguments,
|
||||
),
|
||||
)
|
||||
for tool in response.choices[0].message.tool_calls
|
||||
]
|
||||
return None
|
||||
|
||||
|
||||
def get_parallel_tool_calls_param(llm_model: LlmModel, parallel_tool_calls):
|
||||
"""Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
|
||||
if llm_model.startswith("o") or parallel_tool_calls is None:
|
||||
return openai.NOT_GIVEN
|
||||
return parallel_tool_calls
|
||||
def estimate_token_count(prompt_messages: list[dict]) -> int:
|
||||
char_count = sum(len(str(msg.get("content", ""))) for msg in prompt_messages)
|
||||
message_overhead = len(prompt_messages) * 4
|
||||
estimated_tokens = (char_count // 4) + message_overhead
|
||||
return int(estimated_tokens * 1.2)
|
||||
|
||||
|
||||
async def llm_call(
|
||||
@@ -366,8 +321,7 @@ async def llm_call(
|
||||
max_tokens: int | None,
|
||||
tools: list[dict] | None = None,
|
||||
ollama_host: str = "localhost:11434",
|
||||
parallel_tool_calls=None,
|
||||
compress_prompt_to_fit: bool = True,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Make a call to a language model.
|
||||
@@ -390,32 +344,28 @@ async def llm_call(
|
||||
- completion_tokens: The number of tokens used in the completion.
|
||||
"""
|
||||
provider = llm_model.metadata.provider
|
||||
context_window = llm_model.context_window
|
||||
|
||||
if compress_prompt_to_fit:
|
||||
prompt = compress_prompt(
|
||||
messages=prompt,
|
||||
target_tokens=llm_model.context_window // 2,
|
||||
lossy_ok=True,
|
||||
)
|
||||
|
||||
# Calculate available tokens based on context window and input length
|
||||
estimated_input_tokens = estimate_token_count(prompt)
|
||||
model_max_output = llm_model.max_output_tokens or int(2**15)
|
||||
context_window = llm_model.context_window
|
||||
model_max_output = llm_model.max_output_tokens or 4096
|
||||
user_max = max_tokens or model_max_output
|
||||
available_tokens = max(context_window - estimated_input_tokens, 0)
|
||||
max_tokens = max(min(available_tokens, model_max_output, user_max), 1)
|
||||
max_tokens = max(min(available_tokens, model_max_output, user_max), 0)
|
||||
|
||||
if provider == "openai":
|
||||
tools_param = tools if tools else openai.NOT_GIVEN
|
||||
oai_client = openai.AsyncOpenAI(api_key=credentials.api_key.get_secret_value())
|
||||
response_format = None
|
||||
|
||||
parallel_tool_calls = get_parallel_tool_calls_param(
|
||||
llm_model, parallel_tool_calls
|
||||
)
|
||||
|
||||
if json_format:
|
||||
if llm_model in [LlmModel.O1_MINI, LlmModel.O1_PREVIEW]:
|
||||
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
||||
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
|
||||
prompt = [
|
||||
{"role": "user", "content": "\n".join(sys_messages)},
|
||||
{"role": "user", "content": "\n".join(usr_messages)},
|
||||
]
|
||||
elif json_format:
|
||||
response_format = {"type": "json_object"}
|
||||
|
||||
response = await oai_client.chat.completions.create(
|
||||
@@ -424,11 +374,25 @@ async def llm_call(
|
||||
response_format=response_format, # type: ignore
|
||||
max_completion_tokens=max_tokens,
|
||||
tools=tools_param, # type: ignore
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
parallel_tool_calls=(
|
||||
openai.NOT_GIVEN if parallel_tool_calls is None else parallel_tool_calls
|
||||
),
|
||||
)
|
||||
|
||||
tool_calls = extract_openai_tool_calls(response)
|
||||
reasoning = extract_openai_reasoning(response)
|
||||
if response.choices[0].message.tool_calls:
|
||||
tool_calls = [
|
||||
ToolContentBlock(
|
||||
id=tool.id,
|
||||
type=tool.type,
|
||||
function=ToolCall(
|
||||
name=tool.function.name,
|
||||
arguments=tool.function.arguments,
|
||||
),
|
||||
)
|
||||
for tool in response.choices[0].message.tool_calls
|
||||
]
|
||||
else:
|
||||
tool_calls = None
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=response.choices[0].message,
|
||||
@@ -437,7 +401,6 @@ async def llm_call(
|
||||
tool_calls=tool_calls,
|
||||
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
|
||||
completion_tokens=response.usage.completion_tokens if response.usage else 0,
|
||||
reasoning=reasoning,
|
||||
)
|
||||
elif provider == "anthropic":
|
||||
|
||||
@@ -499,12 +462,6 @@ async def llm_call(
|
||||
f"Tool use stop reason but no tool calls found in content. {resp}"
|
||||
)
|
||||
|
||||
reasoning = None
|
||||
for content_block in resp.content:
|
||||
if hasattr(content_block, "type") and content_block.type == "thinking":
|
||||
reasoning = content_block.thinking
|
||||
break
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=resp,
|
||||
prompt=prompt,
|
||||
@@ -516,7 +473,6 @@ async def llm_call(
|
||||
tool_calls=tool_calls,
|
||||
prompt_tokens=resp.usage.input_tokens,
|
||||
completion_tokens=resp.usage.output_tokens,
|
||||
reasoning=reasoning,
|
||||
)
|
||||
except anthropic.APIError as e:
|
||||
error_message = f"Anthropic API error: {str(e)}"
|
||||
@@ -541,7 +497,6 @@ async def llm_call(
|
||||
tool_calls=None,
|
||||
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
|
||||
completion_tokens=response.usage.completion_tokens if response.usage else 0,
|
||||
reasoning=None,
|
||||
)
|
||||
elif provider == "ollama":
|
||||
if tools:
|
||||
@@ -563,7 +518,6 @@ async def llm_call(
|
||||
tool_calls=None,
|
||||
prompt_tokens=response.get("prompt_eval_count") or 0,
|
||||
completion_tokens=response.get("eval_count") or 0,
|
||||
reasoning=None,
|
||||
)
|
||||
elif provider == "open_router":
|
||||
tools_param = tools if tools else openai.NOT_GIVEN
|
||||
@@ -572,10 +526,6 @@ async def llm_call(
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
parallel_tool_calls_param = get_parallel_tool_calls_param(
|
||||
llm_model, parallel_tool_calls
|
||||
)
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
extra_headers={
|
||||
"HTTP-Referer": "https://agpt.co",
|
||||
@@ -585,7 +535,6 @@ async def llm_call(
|
||||
messages=prompt, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
tools=tools_param, # type: ignore
|
||||
parallel_tool_calls=parallel_tool_calls_param,
|
||||
)
|
||||
|
||||
# If there's no response, raise an error
|
||||
@@ -595,8 +544,19 @@ async def llm_call(
|
||||
else:
|
||||
raise ValueError("No response from OpenRouter.")
|
||||
|
||||
tool_calls = extract_openai_tool_calls(response)
|
||||
reasoning = extract_openai_reasoning(response)
|
||||
if response.choices[0].message.tool_calls:
|
||||
tool_calls = [
|
||||
ToolContentBlock(
|
||||
id=tool.id,
|
||||
type=tool.type,
|
||||
function=ToolCall(
|
||||
name=tool.function.name, arguments=tool.function.arguments
|
||||
),
|
||||
)
|
||||
for tool in response.choices[0].message.tool_calls
|
||||
]
|
||||
else:
|
||||
tool_calls = None
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=response.choices[0].message,
|
||||
@@ -605,7 +565,6 @@ async def llm_call(
|
||||
tool_calls=tool_calls,
|
||||
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
|
||||
completion_tokens=response.usage.completion_tokens if response.usage else 0,
|
||||
reasoning=reasoning,
|
||||
)
|
||||
elif provider == "llama_api":
|
||||
tools_param = tools if tools else openai.NOT_GIVEN
|
||||
@@ -614,10 +573,6 @@ async def llm_call(
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
parallel_tool_calls_param = get_parallel_tool_calls_param(
|
||||
llm_model, parallel_tool_calls
|
||||
)
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
extra_headers={
|
||||
"HTTP-Referer": "https://agpt.co",
|
||||
@@ -627,7 +582,9 @@ async def llm_call(
|
||||
messages=prompt, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
tools=tools_param, # type: ignore
|
||||
parallel_tool_calls=parallel_tool_calls_param,
|
||||
parallel_tool_calls=(
|
||||
openai.NOT_GIVEN if parallel_tool_calls is None else parallel_tool_calls
|
||||
),
|
||||
)
|
||||
|
||||
# If there's no response, raise an error
|
||||
@@ -637,8 +594,19 @@ async def llm_call(
|
||||
else:
|
||||
raise ValueError("No response from Llama API.")
|
||||
|
||||
tool_calls = extract_openai_tool_calls(response)
|
||||
reasoning = extract_openai_reasoning(response)
|
||||
if response.choices[0].message.tool_calls:
|
||||
tool_calls = [
|
||||
ToolContentBlock(
|
||||
id=tool.id,
|
||||
type=tool.type,
|
||||
function=ToolCall(
|
||||
name=tool.function.name, arguments=tool.function.arguments
|
||||
),
|
||||
)
|
||||
for tool in response.choices[0].message.tool_calls
|
||||
]
|
||||
else:
|
||||
tool_calls = None
|
||||
|
||||
return LLMResponse(
|
||||
raw_response=response.choices[0].message,
|
||||
@@ -647,7 +615,6 @@ async def llm_call(
|
||||
tool_calls=tool_calls,
|
||||
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
|
||||
completion_tokens=response.usage.completion_tokens if response.usage else 0,
|
||||
reasoning=reasoning,
|
||||
)
|
||||
elif provider == "aiml_api":
|
||||
client = openai.OpenAI(
|
||||
@@ -671,7 +638,6 @@ async def llm_call(
|
||||
completion_tokens=(
|
||||
completion.usage.completion_tokens if completion.usage else 0
|
||||
),
|
||||
reasoning=None,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
@@ -697,11 +663,6 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
description="Expected format of the response. If provided, the response will be validated against this format. "
|
||||
"The keys should be the expected fields in the response, and the values should be the description of the field.",
|
||||
)
|
||||
list_result: bool = SchemaField(
|
||||
title="List Result",
|
||||
default=False,
|
||||
description="Whether the response should be a list of objects in the expected format.",
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=LlmModel.GPT4O,
|
||||
@@ -733,11 +694,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
default=None,
|
||||
description="The maximum number of tokens to generate in the chat completion.",
|
||||
)
|
||||
compress_prompt_to_fit: bool = SchemaField(
|
||||
advanced=True,
|
||||
default=True,
|
||||
description="Whether to compress the prompt to fit within the model's context window.",
|
||||
)
|
||||
|
||||
ollama_host: str = SchemaField(
|
||||
advanced=True,
|
||||
default="localhost:11434",
|
||||
@@ -745,7 +702,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
response: dict[str, Any] | list[dict[str, Any]] = SchemaField(
|
||||
response: dict[str, Any] = SchemaField(
|
||||
description="The response object generated by the language model."
|
||||
)
|
||||
prompt: list = SchemaField(description="The prompt sent to the language model.")
|
||||
@@ -785,7 +742,6 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
tool_calls=None,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
reasoning=None,
|
||||
)
|
||||
},
|
||||
)
|
||||
@@ -796,7 +752,6 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
llm_model: LlmModel,
|
||||
prompt: list[dict],
|
||||
json_format: bool,
|
||||
compress_prompt_to_fit: bool,
|
||||
max_tokens: int | None,
|
||||
tools: list[dict] | None = None,
|
||||
ollama_host: str = "localhost:11434",
|
||||
@@ -814,7 +769,6 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
max_tokens=max_tokens,
|
||||
tools=tools,
|
||||
ollama_host=ollama_host,
|
||||
compress_prompt_to_fit=compress_prompt_to_fit,
|
||||
)
|
||||
|
||||
async def run(
|
||||
@@ -839,22 +793,13 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
expected_format = [
|
||||
f'"{k}": "{v}"' for k, v in input_data.expected_format.items()
|
||||
]
|
||||
if input_data.list_result:
|
||||
format_prompt = (
|
||||
f'"results": [\n {{\n {", ".join(expected_format)}\n }}\n]'
|
||||
)
|
||||
else:
|
||||
format_prompt = "\n ".join(expected_format)
|
||||
|
||||
format_prompt = ",\n ".join(expected_format)
|
||||
sys_prompt = trim_prompt(
|
||||
f"""
|
||||
|Reply strictly only in the following JSON format:
|
||||
|{{
|
||||
| {format_prompt}
|
||||
|}}
|
||||
|
|
||||
|Ensure the response is valid JSON. Do not include any additional text outside of the JSON.
|
||||
|If you cannot provide all the keys, provide an empty string for the values you cannot answer.
|
||||
"""
|
||||
)
|
||||
prompt.append({"role": "system", "content": sys_prompt})
|
||||
@@ -862,18 +807,19 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
if input_data.prompt:
|
||||
prompt.append({"role": "user", "content": input_data.prompt})
|
||||
|
||||
def validate_response(parsed: object) -> str | None:
|
||||
def parse_response(resp: str) -> tuple[dict[str, Any], str | None]:
|
||||
try:
|
||||
parsed = json.loads(resp)
|
||||
if not isinstance(parsed, dict):
|
||||
return f"Expected a dictionary, but got {type(parsed)}"
|
||||
return {}, f"Expected a dictionary, but got {type(parsed)}"
|
||||
miss_keys = set(input_data.expected_format.keys()) - set(parsed.keys())
|
||||
if miss_keys:
|
||||
return f"Missing keys: {miss_keys}"
|
||||
return None
|
||||
return parsed, f"Missing keys: {miss_keys}"
|
||||
return parsed, None
|
||||
except JSONDecodeError as e:
|
||||
return f"JSON decode error: {e}"
|
||||
return {}, f"JSON decode error: {e}"
|
||||
|
||||
logger.debug(f"LLM request: {prompt}")
|
||||
logger.info(f"LLM request: {prompt}")
|
||||
retry_prompt = ""
|
||||
llm_model = input_data.model
|
||||
|
||||
@@ -883,7 +829,6 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
credentials=credentials,
|
||||
llm_model=llm_model,
|
||||
prompt=prompt,
|
||||
compress_prompt_to_fit=input_data.compress_prompt_to_fit,
|
||||
json_format=bool(input_data.expected_format),
|
||||
ollama_host=input_data.ollama_host,
|
||||
max_tokens=input_data.max_tokens,
|
||||
@@ -895,32 +840,21 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
output_token_count=llm_response.completion_tokens,
|
||||
)
|
||||
)
|
||||
logger.debug(f"LLM attempt-{retry_count} response: {response_text}")
|
||||
logger.info(f"LLM attempt-{retry_count} response: {response_text}")
|
||||
|
||||
if input_data.expected_format:
|
||||
|
||||
response_obj = json.loads(response_text)
|
||||
|
||||
if input_data.list_result and isinstance(response_obj, dict):
|
||||
if "results" in response_obj:
|
||||
response_obj = response_obj.get("results", [])
|
||||
elif len(response_obj) == 1:
|
||||
response_obj = list(response_obj.values())
|
||||
|
||||
response_error = "\n".join(
|
||||
[
|
||||
validation_error
|
||||
for response_item in (
|
||||
response_obj
|
||||
if isinstance(response_obj, list)
|
||||
else [response_obj]
|
||||
parsed_dict, parsed_error = parse_response(response_text)
|
||||
if not parsed_error:
|
||||
yield "response", {
|
||||
k: (
|
||||
json.loads(v)
|
||||
if isinstance(v, str)
|
||||
and v.startswith("[")
|
||||
and v.endswith("]")
|
||||
else (", ".join(v) if isinstance(v, list) else v)
|
||||
)
|
||||
if (validation_error := validate_response(response_item))
|
||||
]
|
||||
)
|
||||
|
||||
if not response_error:
|
||||
yield "response", response_obj
|
||||
for k, v in parsed_dict.items()
|
||||
}
|
||||
yield "prompt", self.prompt
|
||||
return
|
||||
else:
|
||||
@@ -937,7 +871,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
|
|
||||
|And this is the error:
|
||||
|--
|
||||
|{response_error}
|
||||
|{parsed_error}
|
||||
|--
|
||||
"""
|
||||
)
|
||||
|
||||
@@ -13,7 +13,7 @@ from backend.data.model import (
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="8cc8b2c5-d3e4-4b1c-84ad-e1e9fe2a0122",
|
||||
id="ed55ac19-356e-4243-a6cb-bc599e9b716f",
|
||||
provider="mem0",
|
||||
api_key=SecretStr("mock-mem0-api-key"),
|
||||
title="Mock Mem0 API key",
|
||||
@@ -67,19 +67,17 @@ class AddMemoryBlock(Block, Mem0Base):
|
||||
metadata: dict[str, Any] = SchemaField(
|
||||
description="Optional metadata for the memory", default_factory=dict
|
||||
)
|
||||
|
||||
limit_memory_to_run: bool = SchemaField(
|
||||
description="Limit the memory to the run", default=False
|
||||
)
|
||||
limit_memory_to_agent: bool = SchemaField(
|
||||
description="Limit the memory to the agent", default=True
|
||||
description="Limit the memory to the agent", default=False
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
action: str = SchemaField(description="Action of the operation")
|
||||
memory: str = SchemaField(description="Memory created")
|
||||
results: list[dict[str, str]] = SchemaField(
|
||||
description="List of all results from the operation"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if operation fails")
|
||||
|
||||
def __init__(self):
|
||||
@@ -106,14 +104,7 @@ class AddMemoryBlock(Block, Mem0Base):
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("results", [{"event": "CREATED", "memory": "test memory"}]),
|
||||
("action", "CREATED"),
|
||||
("memory", "test memory"),
|
||||
("results", [{"event": "CREATED", "memory": "test memory"}]),
|
||||
("action", "CREATED"),
|
||||
("memory", "test memory"),
|
||||
],
|
||||
test_output=[("action", "NO_CHANGE"), ("action", "NO_CHANGE")],
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_mock={"_get_client": lambda credentials: MockMemoryClient()},
|
||||
)
|
||||
@@ -126,7 +117,7 @@ class AddMemoryBlock(Block, Mem0Base):
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
graph_exec_id: str,
|
||||
**kwargs,
|
||||
**kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
client = self._get_client(credentials)
|
||||
@@ -155,11 +146,8 @@ class AddMemoryBlock(Block, Mem0Base):
|
||||
**params,
|
||||
)
|
||||
|
||||
results = result.get("results", [])
|
||||
yield "results", results
|
||||
|
||||
if len(results) > 0:
|
||||
for result in results:
|
||||
if len(result.get("results", [])) > 0:
|
||||
for result in result.get("results", []):
|
||||
yield "action", result["event"]
|
||||
yield "memory", result["memory"]
|
||||
else:
|
||||
@@ -190,10 +178,6 @@ class SearchMemoryBlock(Block, Mem0Base):
|
||||
default_factory=list,
|
||||
advanced=True,
|
||||
)
|
||||
metadata_filter: Optional[dict[str, Any]] = SchemaField(
|
||||
description="Optional metadata filters to apply",
|
||||
default=None,
|
||||
)
|
||||
limit_memory_to_run: bool = SchemaField(
|
||||
description="Limit the memory to the run", default=False
|
||||
)
|
||||
@@ -232,7 +216,7 @@ class SearchMemoryBlock(Block, Mem0Base):
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
graph_exec_id: str,
|
||||
**kwargs,
|
||||
**kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
client = self._get_client(credentials)
|
||||
@@ -251,8 +235,6 @@ class SearchMemoryBlock(Block, Mem0Base):
|
||||
filters["AND"].append({"run_id": graph_exec_id})
|
||||
if input_data.limit_memory_to_agent:
|
||||
filters["AND"].append({"agent_id": graph_id})
|
||||
if input_data.metadata_filter:
|
||||
filters["AND"].append({"metadata": input_data.metadata_filter})
|
||||
|
||||
result: list[dict[str, Any]] = client.search(
|
||||
input_data.query, version="v2", filters=filters
|
||||
@@ -278,15 +260,11 @@ class GetAllMemoriesBlock(Block, Mem0Base):
|
||||
categories: Optional[list[str]] = SchemaField(
|
||||
description="Filter by categories", default=None
|
||||
)
|
||||
metadata_filter: Optional[dict[str, Any]] = SchemaField(
|
||||
description="Optional metadata filters to apply",
|
||||
default=None,
|
||||
)
|
||||
limit_memory_to_run: bool = SchemaField(
|
||||
description="Limit the memory to the run", default=False
|
||||
)
|
||||
limit_memory_to_agent: bool = SchemaField(
|
||||
description="Limit the memory to the agent", default=True
|
||||
description="Limit the memory to the agent", default=False
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
@@ -296,11 +274,11 @@ class GetAllMemoriesBlock(Block, Mem0Base):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="45aee5bf-4767-45d1-a28b-e01c5aae9fc1",
|
||||
description="Retrieve all memories from Mem0 with optional conversation filtering",
|
||||
description="Retrieve all memories from Mem0 with pagination",
|
||||
input_schema=GetAllMemoriesBlock.Input,
|
||||
output_schema=GetAllMemoriesBlock.Output,
|
||||
test_input={
|
||||
"metadata_filter": {"type": "test"},
|
||||
"user_id": "test_user",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
@@ -318,7 +296,7 @@ class GetAllMemoriesBlock(Block, Mem0Base):
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
graph_exec_id: str,
|
||||
**kwargs,
|
||||
**kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
client = self._get_client(credentials)
|
||||
@@ -336,8 +314,6 @@ class GetAllMemoriesBlock(Block, Mem0Base):
|
||||
filters["AND"].append(
|
||||
{"categories": {"contains": input_data.categories}}
|
||||
)
|
||||
if input_data.metadata_filter:
|
||||
filters["AND"].append({"metadata": input_data.metadata_filter})
|
||||
|
||||
memories: list[dict[str, Any]] = client.get_all(
|
||||
filters=filters,
|
||||
@@ -350,116 +326,14 @@ class GetAllMemoriesBlock(Block, Mem0Base):
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GetLatestMemoryBlock(Block, Mem0Base):
|
||||
"""Block for retrieving the latest memory from Mem0"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.MEM0], Literal["api_key"]
|
||||
] = CredentialsField(description="Mem0 API key credentials")
|
||||
trigger: bool = SchemaField(
|
||||
description="An unused field that is used to trigger the block when you have no other inputs",
|
||||
default=False,
|
||||
advanced=False,
|
||||
)
|
||||
categories: Optional[list[str]] = SchemaField(
|
||||
description="Filter by categories", default=None
|
||||
)
|
||||
conversation_id: Optional[str] = SchemaField(
|
||||
description="Optional conversation ID to retrieve the latest memory from (uses run_id)",
|
||||
default=None,
|
||||
)
|
||||
metadata_filter: Optional[dict[str, Any]] = SchemaField(
|
||||
description="Optional metadata filters to apply",
|
||||
default=None,
|
||||
)
|
||||
limit_memory_to_run: bool = SchemaField(
|
||||
description="Limit the memory to the run", default=False
|
||||
)
|
||||
limit_memory_to_agent: bool = SchemaField(
|
||||
description="Limit the memory to the agent", default=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
memory: Optional[dict[str, Any]] = SchemaField(
|
||||
description="Latest memory if found"
|
||||
)
|
||||
found: bool = SchemaField(description="Whether a memory was found")
|
||||
error: str = SchemaField(description="Error message if operation fails")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0f9d81b5-a145-4c23-b87f-01d6bf37b677",
|
||||
description="Retrieve the latest memory from Mem0 with optional key filtering",
|
||||
input_schema=GetLatestMemoryBlock.Input,
|
||||
output_schema=GetLatestMemoryBlock.Output,
|
||||
test_input={
|
||||
"metadata_filter": {"type": "test"},
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
("memory", {"id": "test-memory", "content": "test content"}),
|
||||
("found", True),
|
||||
],
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_mock={"_get_client": lambda credentials: MockMemoryClient()},
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
graph_exec_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
client = self._get_client(credentials)
|
||||
|
||||
filters: Filter = {
|
||||
"AND": [
|
||||
{"user_id": user_id},
|
||||
]
|
||||
}
|
||||
if input_data.limit_memory_to_run:
|
||||
filters["AND"].append({"run_id": graph_exec_id})
|
||||
if input_data.limit_memory_to_agent:
|
||||
filters["AND"].append({"agent_id": graph_id})
|
||||
if input_data.categories:
|
||||
filters["AND"].append(
|
||||
{"categories": {"contains": input_data.categories}}
|
||||
)
|
||||
if input_data.metadata_filter:
|
||||
filters["AND"].append({"metadata": input_data.metadata_filter})
|
||||
|
||||
memories: list[dict[str, Any]] = client.get_all(
|
||||
filters=filters,
|
||||
version="v2",
|
||||
)
|
||||
|
||||
if memories:
|
||||
# Return the latest memory (first in the list as they're sorted by recency)
|
||||
latest_memory = memories[0]
|
||||
yield "memory", latest_memory
|
||||
yield "found", True
|
||||
else:
|
||||
yield "memory", None
|
||||
yield "found", False
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
# Mock client for testing
|
||||
class MockMemoryClient:
|
||||
"""Mock Mem0 client for testing"""
|
||||
|
||||
def add(self, *args, **kwargs):
|
||||
return {"results": [{"event": "CREATED", "memory": "test memory"}]}
|
||||
return {"memory_id": "test-memory-id", "status": "success"}
|
||||
|
||||
def search(self, *args, **kwargs) -> list[dict[str, Any]]:
|
||||
def search(self, *args, **kwargs) -> list[dict[str, str]]:
|
||||
return [{"id": "test-memory", "content": "test content"}]
|
||||
|
||||
def get_all(self, *args, **kwargs) -> list[dict[str, str]]:
|
||||
|
||||
@@ -1,155 +0,0 @@
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_database_manager_client():
|
||||
from backend.executor import DatabaseManagerAsyncClient
|
||||
from backend.util.service import get_service_client
|
||||
|
||||
return get_service_client(DatabaseManagerAsyncClient, health_check=False)
|
||||
|
||||
|
||||
StorageScope = Literal["within_agent", "across_agents"]
|
||||
|
||||
|
||||
def get_storage_key(key: str, scope: StorageScope, graph_id: str) -> str:
|
||||
"""Generate the storage key based on scope"""
|
||||
if scope == "across_agents":
|
||||
return f"global#{key}"
|
||||
else:
|
||||
return f"agent#{graph_id}#{key}"
|
||||
|
||||
|
||||
class PersistInformationBlock(Block):
|
||||
"""Block for persisting key-value data for the current user with configurable scope"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
key: str = SchemaField(description="Key to store the information under")
|
||||
value: Any = SchemaField(description="Value to store")
|
||||
scope: StorageScope = SchemaField(
|
||||
description="Scope of persistence: within_agent (shared across all runs of this agent) or across_agents (shared across all agents for this user)",
|
||||
default="within_agent",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
value: Any = SchemaField(description="Value that was stored")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="1d055e55-a2b9-4547-8311-907d05b0304d",
|
||||
description="Persist key-value information for the current user",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=PersistInformationBlock.Input,
|
||||
output_schema=PersistInformationBlock.Output,
|
||||
test_input={
|
||||
"key": "user_preference",
|
||||
"value": {"theme": "dark", "language": "en"},
|
||||
"scope": "within_agent",
|
||||
},
|
||||
test_output=[
|
||||
("value", {"theme": "dark", "language": "en"}),
|
||||
],
|
||||
test_mock={
|
||||
"_store_data": lambda *args, **kwargs: {
|
||||
"theme": "dark",
|
||||
"language": "en",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
node_exec_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# Determine the storage key based on scope
|
||||
storage_key = get_storage_key(input_data.key, input_data.scope, graph_id)
|
||||
|
||||
# Store the data
|
||||
yield "value", await self._store_data(
|
||||
user_id=user_id,
|
||||
node_exec_id=node_exec_id,
|
||||
key=storage_key,
|
||||
data=input_data.value,
|
||||
)
|
||||
|
||||
async def _store_data(
|
||||
self, user_id: str, node_exec_id: str, key: str, data: Any
|
||||
) -> Any | None:
|
||||
return await get_database_manager_client().set_execution_kv_data(
|
||||
user_id=user_id,
|
||||
node_exec_id=node_exec_id,
|
||||
key=key,
|
||||
data=data,
|
||||
)
|
||||
|
||||
|
||||
class RetrieveInformationBlock(Block):
|
||||
"""Block for retrieving key-value data for the current user with configurable scope"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
key: str = SchemaField(description="Key to retrieve the information for")
|
||||
scope: StorageScope = SchemaField(
|
||||
description="Scope of persistence: within_agent (shared across all runs of this agent) or across_agents (shared across all agents for this user)",
|
||||
default="within_agent",
|
||||
)
|
||||
default_value: Any = SchemaField(
|
||||
description="Default value to return if key is not found", default=None
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
value: Any = SchemaField(description="Retrieved value or default value")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d8710fc9-6e29-481e-a7d5-165eb16f8471",
|
||||
description="Retrieve key-value information for the current user",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=RetrieveInformationBlock.Input,
|
||||
output_schema=RetrieveInformationBlock.Output,
|
||||
test_input={
|
||||
"key": "user_preference",
|
||||
"scope": "within_agent",
|
||||
"default_value": {"theme": "light", "language": "en"},
|
||||
},
|
||||
test_output=[
|
||||
("value", {"theme": "light", "language": "en"}),
|
||||
],
|
||||
test_mock={"_retrieve_data": lambda *args, **kwargs: None},
|
||||
static_output=True,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, user_id: str, graph_id: str, **kwargs
|
||||
) -> BlockOutput:
|
||||
# Determine the storage key based on scope
|
||||
storage_key = get_storage_key(input_data.key, input_data.scope, graph_id)
|
||||
|
||||
# Retrieve the data
|
||||
stored_value = await self._retrieve_data(
|
||||
user_id=user_id,
|
||||
key=storage_key,
|
||||
)
|
||||
|
||||
if stored_value is not None:
|
||||
yield "value", stored_value
|
||||
else:
|
||||
yield "value", input_data.default_value
|
||||
|
||||
async def _retrieve_data(self, user_id: str, key: str) -> Any | None:
|
||||
return await get_database_manager_client().get_execution_kv_data(
|
||||
user_id=user_id,
|
||||
key=key,
|
||||
)
|
||||
@@ -96,7 +96,6 @@ class GetRedditPostsBlock(Block):
|
||||
|
||||
class Output(BlockSchema):
|
||||
post: RedditPost = SchemaField(description="Reddit post")
|
||||
posts: list[RedditPost] = SchemaField(description="List of all Reddit posts")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -129,23 +128,6 @@ class GetRedditPostsBlock(Block):
|
||||
id="id2", subreddit="subreddit", title="title2", body="body2"
|
||||
),
|
||||
),
|
||||
(
|
||||
"posts",
|
||||
[
|
||||
RedditPost(
|
||||
id="id1",
|
||||
subreddit="subreddit",
|
||||
title="title1",
|
||||
body="body1",
|
||||
),
|
||||
RedditPost(
|
||||
id="id2",
|
||||
subreddit="subreddit",
|
||||
title="title2",
|
||||
body="body2",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_posts": lambda input_data, credentials: [
|
||||
@@ -168,7 +150,6 @@ class GetRedditPostsBlock(Block):
|
||||
self, input_data: Input, *, credentials: RedditCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
current_time = datetime.now(tz=timezone.utc)
|
||||
all_posts = []
|
||||
for post in self.get_posts(input_data=input_data, credentials=credentials):
|
||||
if input_data.last_minutes:
|
||||
post_datetime = datetime.fromtimestamp(
|
||||
@@ -181,16 +162,12 @@ class GetRedditPostsBlock(Block):
|
||||
if input_data.last_post and post.id == input_data.last_post:
|
||||
break
|
||||
|
||||
reddit_post = RedditPost(
|
||||
yield "post", RedditPost(
|
||||
id=post.id,
|
||||
subreddit=input_data.subreddit,
|
||||
title=post.title,
|
||||
body=post.selftext,
|
||||
)
|
||||
all_posts.append(reddit_post)
|
||||
yield "post", reddit_post
|
||||
|
||||
yield "posts", all_posts
|
||||
|
||||
|
||||
class PostRedditCommentBlock(Block):
|
||||
|
||||
@@ -40,7 +40,6 @@ class ReadRSSFeedBlock(Block):
|
||||
|
||||
class Output(BlockSchema):
|
||||
entry: RSSEntry = SchemaField(description="The RSS item")
|
||||
entries: list[RSSEntry] = SchemaField(description="List of all RSS entries")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -67,21 +66,6 @@ class ReadRSSFeedBlock(Block):
|
||||
categories=["Technology", "News"],
|
||||
),
|
||||
),
|
||||
(
|
||||
"entries",
|
||||
[
|
||||
RSSEntry(
|
||||
title="Example RSS Item",
|
||||
link="https://example.com/article",
|
||||
description="This is an example RSS item description.",
|
||||
pub_date=datetime(
|
||||
2023, 6, 23, 12, 30, 0, tzinfo=timezone.utc
|
||||
),
|
||||
author="John Doe",
|
||||
categories=["Technology", "News"],
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"parse_feed": lambda *args, **kwargs: {
|
||||
@@ -112,22 +96,21 @@ class ReadRSSFeedBlock(Block):
|
||||
keep_going = input_data.run_continuously
|
||||
|
||||
feed = self.parse_feed(input_data.rss_url)
|
||||
all_entries = []
|
||||
|
||||
for entry in feed["entries"]:
|
||||
pub_date = datetime(*entry["published_parsed"][:6], tzinfo=timezone.utc)
|
||||
|
||||
if pub_date > start_time:
|
||||
rss_entry = RSSEntry(
|
||||
title=entry["title"],
|
||||
link=entry["link"],
|
||||
description=entry.get("summary", ""),
|
||||
pub_date=pub_date,
|
||||
author=entry.get("author", ""),
|
||||
categories=[tag["term"] for tag in entry.get("tags", [])],
|
||||
yield (
|
||||
"entry",
|
||||
RSSEntry(
|
||||
title=entry["title"],
|
||||
link=entry["link"],
|
||||
description=entry.get("summary", ""),
|
||||
pub_date=pub_date,
|
||||
author=entry.get("author", ""),
|
||||
categories=[tag["term"] for tag in entry.get("tags", [])],
|
||||
),
|
||||
)
|
||||
all_entries.append(rss_entry)
|
||||
yield "entry", rss_entry
|
||||
|
||||
yield "entries", all_entries
|
||||
await asyncio.sleep(input_data.polling_rate)
|
||||
|
||||
@@ -26,10 +26,10 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
@thread_cached
|
||||
def get_database_manager_client():
|
||||
from backend.executor import DatabaseManagerAsyncClient
|
||||
from backend.executor import DatabaseManagerClient
|
||||
from backend.util.service import get_service_client
|
||||
|
||||
return get_service_client(DatabaseManagerAsyncClient, health_check=False)
|
||||
return get_service_client(DatabaseManagerClient)
|
||||
|
||||
|
||||
def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
|
||||
@@ -85,7 +85,7 @@ def _get_tool_responses(entry: dict[str, Any]) -> list[str]:
|
||||
return tool_call_ids
|
||||
|
||||
|
||||
def _create_tool_response(call_id: str, output: Any) -> dict[str, Any]:
|
||||
def _create_tool_response(call_id: str, output: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Create a tool response message for either OpenAI or Anthropics,
|
||||
based on the tool_id format.
|
||||
@@ -142,12 +142,6 @@ class SmartDecisionMakerBlock(Block):
|
||||
advanced=False,
|
||||
)
|
||||
credentials: llm.AICredentials = llm.AICredentialsField()
|
||||
multiple_tool_calls: bool = SchemaField(
|
||||
title="Multiple Tool Calls",
|
||||
default=False,
|
||||
description="Whether to allow multiple tool calls in a single response.",
|
||||
advanced=True,
|
||||
)
|
||||
sys_prompt: str = SchemaField(
|
||||
title="System Prompt",
|
||||
default="Thinking carefully step by step decide which function to call. "
|
||||
@@ -156,7 +150,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
"matching the required jsonschema signature, no missing argument is allowed. "
|
||||
"If you have already completed the task objective, you can end the task "
|
||||
"by providing the end result of your work as a finish message. "
|
||||
"Function parameters that has no default value and not optional typed has to be provided. ",
|
||||
"Only provide EXACTLY one function call, multiple tool calls is strictly prohibited.",
|
||||
description="The system prompt to provide additional context to the model.",
|
||||
)
|
||||
conversation_history: list[dict] = SchemaField(
|
||||
@@ -212,15 +206,6 @@ class SmartDecisionMakerBlock(Block):
|
||||
"link like the output of `StoreValue` or `AgentInput` block"
|
||||
)
|
||||
|
||||
# Check that both conversation_history and last_tool_output are connected together
|
||||
if any(link.sink_name == "conversation_history" for link in links) != any(
|
||||
link.sink_name == "last_tool_output" for link in links
|
||||
):
|
||||
raise ValueError(
|
||||
"Last Tool Output is needed when Conversation History is used, "
|
||||
"and vice versa. Please connect both inputs together."
|
||||
)
|
||||
|
||||
return missing_links
|
||||
|
||||
@classmethod
|
||||
@@ -231,15 +216,8 @@ class SmartDecisionMakerBlock(Block):
|
||||
conversation_history = data.get("conversation_history", [])
|
||||
pending_tool_calls = get_pending_tool_calls(conversation_history)
|
||||
last_tool_output = data.get("last_tool_output")
|
||||
|
||||
# Tool call is pending, wait for the tool output to be provided.
|
||||
if last_tool_output is None and pending_tool_calls:
|
||||
if not last_tool_output and pending_tool_calls:
|
||||
return {"last_tool_output"}
|
||||
|
||||
# No tool call is pending, wait for the conversation history to be updated.
|
||||
if last_tool_output is not None and not pending_tool_calls:
|
||||
return {"conversation_history"}
|
||||
|
||||
return set()
|
||||
|
||||
class Output(BlockSchema):
|
||||
@@ -273,7 +251,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
return re.sub(r"[^a-zA-Z0-9_-]", "_", s).lower()
|
||||
|
||||
@staticmethod
|
||||
async def _create_block_function_signature(
|
||||
def _create_block_function_signature(
|
||||
sink_node: "Node", links: list["Link"]
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
@@ -295,24 +273,35 @@ class SmartDecisionMakerBlock(Block):
|
||||
"name": SmartDecisionMakerBlock.cleanup(block.name),
|
||||
"description": block.description,
|
||||
}
|
||||
sink_block_input_schema = block.input_schema
|
||||
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
for link in links:
|
||||
sink_name = SmartDecisionMakerBlock.cleanup(link.sink_name)
|
||||
properties[sink_name] = sink_block_input_schema.get_field_schema(
|
||||
link.sink_name
|
||||
sink_block_input_schema = block.input_schema
|
||||
description = (
|
||||
sink_block_input_schema.model_fields[link.sink_name].description
|
||||
if link.sink_name in sink_block_input_schema.model_fields
|
||||
and sink_block_input_schema.model_fields[link.sink_name].description
|
||||
else f"The {link.sink_name} of the tool"
|
||||
)
|
||||
properties[SmartDecisionMakerBlock.cleanup(link.sink_name)] = {
|
||||
"type": "string",
|
||||
"description": description,
|
||||
}
|
||||
|
||||
tool_function["parameters"] = {
|
||||
**block.input_schema.jsonschema(),
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
"additionalProperties": False,
|
||||
"strict": True,
|
||||
}
|
||||
|
||||
return {"type": "function", "function": tool_function}
|
||||
|
||||
@staticmethod
|
||||
async def _create_agent_function_signature(
|
||||
def _create_agent_function_signature(
|
||||
sink_node: "Node", links: list["Link"]
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
@@ -334,7 +323,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
raise ValueError("Graph ID or Graph Version not found in sink node.")
|
||||
|
||||
db_client = get_database_manager_client()
|
||||
sink_graph_meta = await db_client.get_graph_metadata(graph_id, graph_version)
|
||||
sink_graph_meta = db_client.get_graph_metadata(graph_id, graph_version)
|
||||
if not sink_graph_meta:
|
||||
raise ValueError(
|
||||
f"Sink graph metadata not found: {graph_id} {graph_version}"
|
||||
@@ -346,27 +335,25 @@ class SmartDecisionMakerBlock(Block):
|
||||
}
|
||||
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
for link in links:
|
||||
sink_block_input_schema = sink_node.input_default["input_schema"]
|
||||
sink_block_properties = sink_block_input_schema.get("properties", {}).get(
|
||||
link.sink_name, {}
|
||||
)
|
||||
sink_name = SmartDecisionMakerBlock.cleanup(link.sink_name)
|
||||
description = (
|
||||
sink_block_properties["description"]
|
||||
if "description" in sink_block_properties
|
||||
sink_block_input_schema["properties"][link.sink_name]["description"]
|
||||
if "description"
|
||||
in sink_block_input_schema["properties"][link.sink_name]
|
||||
else f"The {link.sink_name} of the tool"
|
||||
)
|
||||
properties[sink_name] = {
|
||||
properties[SmartDecisionMakerBlock.cleanup(link.sink_name)] = {
|
||||
"type": "string",
|
||||
"description": description,
|
||||
"default": json.dumps(sink_block_properties.get("default", None)),
|
||||
}
|
||||
|
||||
tool_function["parameters"] = {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
"additionalProperties": False,
|
||||
"strict": True,
|
||||
}
|
||||
@@ -374,7 +361,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
return {"type": "function", "function": tool_function}
|
||||
|
||||
@staticmethod
|
||||
async def _create_function_signature(node_id: str) -> list[dict[str, Any]]:
|
||||
def _create_function_signature(node_id: str) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Creates function signatures for tools linked to a specified node within a graph.
|
||||
|
||||
@@ -396,13 +383,13 @@ class SmartDecisionMakerBlock(Block):
|
||||
db_client = get_database_manager_client()
|
||||
tools = [
|
||||
(link, node)
|
||||
for link, node in await db_client.get_connected_output_nodes(node_id)
|
||||
for link, node in db_client.get_connected_output_nodes(node_id)
|
||||
if link.source_name.startswith("tools_^_") and link.source_id == node_id
|
||||
]
|
||||
if not tools:
|
||||
raise ValueError("There is no next node to execute.")
|
||||
|
||||
return_tool_functions: list[dict[str, Any]] = []
|
||||
return_tool_functions = []
|
||||
|
||||
grouped_tool_links: dict[str, tuple["Node", list["Link"]]] = {}
|
||||
for link, node in tools:
|
||||
@@ -417,13 +404,13 @@ class SmartDecisionMakerBlock(Block):
|
||||
|
||||
if sink_node.block_id == AgentExecutorBlock().id:
|
||||
return_tool_functions.append(
|
||||
await SmartDecisionMakerBlock._create_agent_function_signature(
|
||||
SmartDecisionMakerBlock._create_agent_function_signature(
|
||||
sink_node, links
|
||||
)
|
||||
)
|
||||
else:
|
||||
return_tool_functions.append(
|
||||
await SmartDecisionMakerBlock._create_block_function_signature(
|
||||
SmartDecisionMakerBlock._create_block_function_signature(
|
||||
sink_node, links
|
||||
)
|
||||
)
|
||||
@@ -442,43 +429,37 @@ class SmartDecisionMakerBlock(Block):
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
tool_functions = await self._create_function_signature(node_id)
|
||||
yield "tool_functions", json.dumps(tool_functions)
|
||||
tool_functions = self._create_function_signature(node_id)
|
||||
|
||||
input_data.conversation_history = input_data.conversation_history or []
|
||||
prompt = [json.to_dict(p) for p in input_data.conversation_history if p]
|
||||
|
||||
pending_tool_calls = get_pending_tool_calls(input_data.conversation_history)
|
||||
if pending_tool_calls and input_data.last_tool_output is None:
|
||||
if pending_tool_calls and not input_data.last_tool_output:
|
||||
raise ValueError(f"Tool call requires an output for {pending_tool_calls}")
|
||||
|
||||
# Only assign the last tool output to the first pending tool call
|
||||
tool_output = []
|
||||
if pending_tool_calls and input_data.last_tool_output is not None:
|
||||
# Get the first pending tool call ID
|
||||
first_call_id = next(iter(pending_tool_calls.keys()))
|
||||
tool_output.append(
|
||||
_create_tool_response(first_call_id, input_data.last_tool_output)
|
||||
# Prefill all missing tool calls with the last tool output/
|
||||
# TODO: we need a better way to handle this.
|
||||
tool_output = [
|
||||
_create_tool_response(pending_call_id, input_data.last_tool_output)
|
||||
for pending_call_id, count in pending_tool_calls.items()
|
||||
for _ in range(count)
|
||||
]
|
||||
|
||||
# If the SDM block only calls 1 tool at a time, this should not happen.
|
||||
if len(tool_output) > 1:
|
||||
logger.warning(
|
||||
f"[SmartDecisionMakerBlock-node_exec_id={node_exec_id}] "
|
||||
f"Multiple pending tool calls are prefilled using a single output. "
|
||||
f"Execution may not be accurate."
|
||||
)
|
||||
|
||||
# Add tool output to prompt right away
|
||||
prompt.extend(tool_output)
|
||||
|
||||
# Check if there are still pending tool calls after handling the first one
|
||||
remaining_pending_calls = get_pending_tool_calls(prompt)
|
||||
|
||||
# If there are still pending tool calls, yield the conversation and return early
|
||||
if remaining_pending_calls:
|
||||
yield "conversations", prompt
|
||||
return
|
||||
|
||||
# Fallback on adding tool output in the conversation history as user prompt.
|
||||
elif input_data.last_tool_output:
|
||||
logger.error(
|
||||
if len(tool_output) == 0 and input_data.last_tool_output:
|
||||
logger.warning(
|
||||
f"[SmartDecisionMakerBlock-node_exec_id={node_exec_id}] "
|
||||
f"No pending tool calls found. This may indicate an issue with the "
|
||||
f"conversation history, or the tool giving response more than once."
|
||||
f"This should not happen! Please check the conversation history for any inconsistencies."
|
||||
f"conversation history, or an LLM calling two tools at the same time."
|
||||
)
|
||||
tool_output.append(
|
||||
{
|
||||
@@ -486,11 +467,8 @@ class SmartDecisionMakerBlock(Block):
|
||||
"content": f"Last tool output: {json.dumps(input_data.last_tool_output)}",
|
||||
}
|
||||
)
|
||||
prompt.extend(tool_output)
|
||||
if input_data.multiple_tool_calls:
|
||||
input_data.sys_prompt += "\nYou can call a tool (different tools) multiple times in a single response."
|
||||
else:
|
||||
input_data.sys_prompt += "\nOnly provide EXACTLY one function call, multiple tool calls is strictly prohibited."
|
||||
|
||||
prompt.extend(tool_output)
|
||||
|
||||
values = input_data.prompt_values
|
||||
if values:
|
||||
@@ -517,7 +495,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
max_tokens=input_data.max_tokens,
|
||||
tools=tool_functions,
|
||||
ollama_host=input_data.ollama_host,
|
||||
parallel_tool_calls=input_data.multiple_tool_calls,
|
||||
parallel_tool_calls=False,
|
||||
)
|
||||
|
||||
if not response.tool_calls:
|
||||
@@ -528,37 +506,8 @@ class SmartDecisionMakerBlock(Block):
|
||||
tool_name = tool_call.function.name
|
||||
tool_args = json.loads(tool_call.function.arguments)
|
||||
|
||||
# Find the tool definition to get the expected arguments
|
||||
tool_def = next(
|
||||
(
|
||||
tool
|
||||
for tool in tool_functions
|
||||
if tool["function"]["name"] == tool_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
for arg_name, arg_value in tool_args.items():
|
||||
yield f"tools_^_{tool_name}_~_{arg_name}", arg_value
|
||||
|
||||
if (
|
||||
tool_def
|
||||
and "function" in tool_def
|
||||
and "parameters" in tool_def["function"]
|
||||
):
|
||||
expected_args = tool_def["function"]["parameters"].get("properties", {})
|
||||
else:
|
||||
expected_args = tool_args.keys()
|
||||
|
||||
# Yield provided arguments and None for missing ones
|
||||
for arg_name in expected_args:
|
||||
if arg_name in tool_args:
|
||||
yield f"tools_^_{tool_name}_~_{arg_name}", tool_args[arg_name]
|
||||
else:
|
||||
yield f"tools_^_{tool_name}_~_{arg_name}", None
|
||||
|
||||
# Add reasoning to conversation history if available
|
||||
if response.reasoning:
|
||||
prompt.append(
|
||||
{"role": "assistant", "content": f"[Reasoning]: {response.reasoning}"}
|
||||
)
|
||||
|
||||
prompt.append(response.raw_response)
|
||||
yield "conversations", prompt
|
||||
response.prompt.append(response.raw_response)
|
||||
yield "conversations", response.prompt
|
||||
|
||||
@@ -17,7 +17,7 @@ from backend.blocks.smartlead.models import (
|
||||
Sequence,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class CreateCampaignBlock(Block):
|
||||
@@ -27,7 +27,7 @@ class CreateCampaignBlock(Block):
|
||||
name: str = SchemaField(
|
||||
description="The name of the campaign",
|
||||
)
|
||||
credentials: SmartLeadCredentialsInput = CredentialsField(
|
||||
credentials: SmartLeadCredentialsInput = SchemaField(
|
||||
description="SmartLead credentials",
|
||||
)
|
||||
|
||||
@@ -119,7 +119,7 @@ class AddLeadToCampaignBlock(Block):
|
||||
description="Settings for lead upload",
|
||||
default=LeadUploadSettings(),
|
||||
)
|
||||
credentials: SmartLeadCredentialsInput = CredentialsField(
|
||||
credentials: SmartLeadCredentialsInput = SchemaField(
|
||||
description="SmartLead credentials",
|
||||
)
|
||||
|
||||
@@ -251,7 +251,7 @@ class SaveCampaignSequencesBlock(Block):
|
||||
default_factory=list,
|
||||
advanced=False,
|
||||
)
|
||||
credentials: SmartLeadCredentialsInput = CredentialsField(
|
||||
credentials: SmartLeadCredentialsInput = SchemaField(
|
||||
description="SmartLead credentials",
|
||||
)
|
||||
|
||||
|
||||
@@ -1,485 +0,0 @@
|
||||
"""Comprehensive tests for HTTP block with HostScopedCredentials functionality."""
|
||||
|
||||
from typing import cast
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks.http import (
|
||||
HttpCredentials,
|
||||
HttpMethod,
|
||||
SendAuthenticatedWebRequestBlock,
|
||||
)
|
||||
from backend.data.model import HostScopedCredentials
|
||||
from backend.util.request import Response
|
||||
|
||||
|
||||
class TestHttpBlockWithHostScopedCredentials:
|
||||
"""Test suite for HTTP block integration with HostScopedCredentials."""
|
||||
|
||||
@pytest.fixture
|
||||
def http_block(self):
|
||||
"""Create an HTTP block instance."""
|
||||
return SendAuthenticatedWebRequestBlock()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_response(self):
|
||||
"""Mock a successful HTTP response."""
|
||||
response = MagicMock(spec=Response)
|
||||
response.status = 200
|
||||
response.headers = {"content-type": "application/json"}
|
||||
response.json.return_value = {"success": True, "data": "test"}
|
||||
return response
|
||||
|
||||
@pytest.fixture
|
||||
def exact_match_credentials(self):
|
||||
"""Create host-scoped credentials for exact domain matching."""
|
||||
return HostScopedCredentials(
|
||||
provider="http",
|
||||
host="api.example.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("Bearer exact-match-token"),
|
||||
"X-API-Key": SecretStr("api-key-123"),
|
||||
},
|
||||
title="Exact Match API Credentials",
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def wildcard_credentials(self):
|
||||
"""Create host-scoped credentials with wildcard pattern."""
|
||||
return HostScopedCredentials(
|
||||
provider="http",
|
||||
host="*.github.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("token ghp_wildcard123"),
|
||||
},
|
||||
title="GitHub Wildcard Credentials",
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def non_matching_credentials(self):
|
||||
"""Create credentials that don't match test URLs."""
|
||||
return HostScopedCredentials(
|
||||
provider="http",
|
||||
host="different.api.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("Bearer non-matching-token"),
|
||||
},
|
||||
title="Non-matching Credentials",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("backend.blocks.http.Requests")
|
||||
async def test_http_block_with_exact_host_match(
|
||||
self,
|
||||
mock_requests_class,
|
||||
http_block,
|
||||
exact_match_credentials,
|
||||
mock_response,
|
||||
):
|
||||
"""Test HTTP block with exact host matching credentials."""
|
||||
# Setup mocks
|
||||
mock_requests = AsyncMock()
|
||||
mock_requests.request.return_value = mock_response
|
||||
mock_requests_class.return_value = mock_requests
|
||||
|
||||
# Prepare input data
|
||||
input_data = SendAuthenticatedWebRequestBlock.Input(
|
||||
url="https://api.example.com/data",
|
||||
method=HttpMethod.GET,
|
||||
headers={"User-Agent": "test-agent"},
|
||||
credentials=cast(
|
||||
HttpCredentials,
|
||||
{
|
||||
"id": exact_match_credentials.id,
|
||||
"provider": "http",
|
||||
"type": "host_scoped",
|
||||
"title": exact_match_credentials.title,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Execute with credentials provided by execution manager
|
||||
result = []
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=exact_match_credentials,
|
||||
graph_exec_id="test-exec-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
# Verify request headers include both credential and user headers
|
||||
mock_requests.request.assert_called_once()
|
||||
call_args = mock_requests.request.call_args
|
||||
expected_headers = {
|
||||
"Authorization": "Bearer exact-match-token",
|
||||
"X-API-Key": "api-key-123",
|
||||
"User-Agent": "test-agent",
|
||||
}
|
||||
assert call_args.kwargs["headers"] == expected_headers
|
||||
|
||||
# Verify response handling
|
||||
assert len(result) == 1
|
||||
assert result[0][0] == "response"
|
||||
assert result[0][1] == {"success": True, "data": "test"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("backend.blocks.http.Requests")
|
||||
async def test_http_block_with_wildcard_host_match(
|
||||
self,
|
||||
mock_requests_class,
|
||||
http_block,
|
||||
wildcard_credentials,
|
||||
mock_response,
|
||||
):
|
||||
"""Test HTTP block with wildcard host pattern matching."""
|
||||
# Setup mocks
|
||||
mock_requests = AsyncMock()
|
||||
mock_requests.request.return_value = mock_response
|
||||
mock_requests_class.return_value = mock_requests
|
||||
|
||||
# Test with subdomain that should match *.github.com
|
||||
input_data = SendAuthenticatedWebRequestBlock.Input(
|
||||
url="https://api.github.com/user",
|
||||
method=HttpMethod.GET,
|
||||
headers={},
|
||||
credentials=cast(
|
||||
HttpCredentials,
|
||||
{
|
||||
"id": wildcard_credentials.id,
|
||||
"provider": "http",
|
||||
"type": "host_scoped",
|
||||
"title": wildcard_credentials.title,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Execute with wildcard credentials
|
||||
result = []
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=wildcard_credentials,
|
||||
graph_exec_id="test-exec-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
# Verify wildcard matching works
|
||||
mock_requests.request.assert_called_once()
|
||||
call_args = mock_requests.request.call_args
|
||||
expected_headers = {"Authorization": "token ghp_wildcard123"}
|
||||
assert call_args.kwargs["headers"] == expected_headers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("backend.blocks.http.Requests")
|
||||
async def test_http_block_with_non_matching_credentials(
|
||||
self,
|
||||
mock_requests_class,
|
||||
http_block,
|
||||
non_matching_credentials,
|
||||
mock_response,
|
||||
):
|
||||
"""Test HTTP block when credentials don't match the target URL."""
|
||||
# Setup mocks
|
||||
mock_requests = AsyncMock()
|
||||
mock_requests.request.return_value = mock_response
|
||||
mock_requests_class.return_value = mock_requests
|
||||
|
||||
# Test with URL that doesn't match the credentials
|
||||
input_data = SendAuthenticatedWebRequestBlock.Input(
|
||||
url="https://api.example.com/data",
|
||||
method=HttpMethod.GET,
|
||||
headers={"User-Agent": "test-agent"},
|
||||
credentials=cast(
|
||||
HttpCredentials,
|
||||
{
|
||||
"id": non_matching_credentials.id,
|
||||
"provider": "http",
|
||||
"type": "host_scoped",
|
||||
"title": non_matching_credentials.title,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Execute with non-matching credentials
|
||||
result = []
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=non_matching_credentials,
|
||||
graph_exec_id="test-exec-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
# Verify only user headers are included (no credential headers)
|
||||
mock_requests.request.assert_called_once()
|
||||
call_args = mock_requests.request.call_args
|
||||
expected_headers = {"User-Agent": "test-agent"}
|
||||
assert call_args.kwargs["headers"] == expected_headers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("backend.blocks.http.Requests")
|
||||
async def test_user_headers_override_credential_headers(
|
||||
self,
|
||||
mock_requests_class,
|
||||
http_block,
|
||||
exact_match_credentials,
|
||||
mock_response,
|
||||
):
|
||||
"""Test that user-provided headers take precedence over credential headers."""
|
||||
# Setup mocks
|
||||
mock_requests = AsyncMock()
|
||||
mock_requests.request.return_value = mock_response
|
||||
mock_requests_class.return_value = mock_requests
|
||||
|
||||
# Test with user header that conflicts with credential header
|
||||
input_data = SendAuthenticatedWebRequestBlock.Input(
|
||||
url="https://api.example.com/data",
|
||||
method=HttpMethod.POST,
|
||||
headers={
|
||||
"Authorization": "Bearer user-override-token", # Should override
|
||||
"Content-Type": "application/json", # Additional user header
|
||||
},
|
||||
credentials=cast(
|
||||
HttpCredentials,
|
||||
{
|
||||
"id": exact_match_credentials.id,
|
||||
"provider": "http",
|
||||
"type": "host_scoped",
|
||||
"title": exact_match_credentials.title,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Execute with conflicting headers
|
||||
result = []
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=exact_match_credentials,
|
||||
graph_exec_id="test-exec-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
# Verify user headers take precedence
|
||||
mock_requests.request.assert_called_once()
|
||||
call_args = mock_requests.request.call_args
|
||||
expected_headers = {
|
||||
"X-API-Key": "api-key-123", # From credentials
|
||||
"Authorization": "Bearer user-override-token", # User override
|
||||
"Content-Type": "application/json", # User header
|
||||
}
|
||||
assert call_args.kwargs["headers"] == expected_headers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("backend.blocks.http.Requests")
|
||||
async def test_auto_discovered_credentials_flow(
|
||||
self,
|
||||
mock_requests_class,
|
||||
http_block,
|
||||
mock_response,
|
||||
):
|
||||
"""Test the auto-discovery flow where execution manager provides matching credentials."""
|
||||
# Create auto-discovered credentials
|
||||
auto_discovered_creds = HostScopedCredentials(
|
||||
provider="http",
|
||||
host="*.example.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("Bearer auto-discovered-token"),
|
||||
},
|
||||
title="Auto-discovered Credentials",
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_requests = AsyncMock()
|
||||
mock_requests.request.return_value = mock_response
|
||||
mock_requests_class.return_value = mock_requests
|
||||
|
||||
# Test with empty credentials field (triggers auto-discovery)
|
||||
input_data = SendAuthenticatedWebRequestBlock.Input(
|
||||
url="https://api.example.com/data",
|
||||
method=HttpMethod.GET,
|
||||
headers={},
|
||||
credentials=cast(
|
||||
HttpCredentials,
|
||||
{
|
||||
"id": "", # Empty ID triggers auto-discovery in execution manager
|
||||
"provider": "http",
|
||||
"type": "host_scoped",
|
||||
"title": "",
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Execute with auto-discovered credentials provided by execution manager
|
||||
result = []
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=auto_discovered_creds, # Execution manager found these
|
||||
graph_exec_id="test-exec-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
# Verify auto-discovered credentials were applied
|
||||
mock_requests.request.assert_called_once()
|
||||
call_args = mock_requests.request.call_args
|
||||
expected_headers = {"Authorization": "Bearer auto-discovered-token"}
|
||||
assert call_args.kwargs["headers"] == expected_headers
|
||||
|
||||
# Verify response handling
|
||||
assert len(result) == 1
|
||||
assert result[0][0] == "response"
|
||||
assert result[0][1] == {"success": True, "data": "test"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("backend.blocks.http.Requests")
|
||||
async def test_multiple_header_credentials(
|
||||
self,
|
||||
mock_requests_class,
|
||||
http_block,
|
||||
mock_response,
|
||||
):
|
||||
"""Test credentials with multiple headers are all applied."""
|
||||
# Create credentials with multiple headers
|
||||
multi_header_creds = HostScopedCredentials(
|
||||
provider="http",
|
||||
host="api.example.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("Bearer multi-token"),
|
||||
"X-API-Key": SecretStr("api-key-456"),
|
||||
"X-Client-ID": SecretStr("client-789"),
|
||||
"X-Custom-Header": SecretStr("custom-value"),
|
||||
},
|
||||
title="Multi-Header Credentials",
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_requests = AsyncMock()
|
||||
mock_requests.request.return_value = mock_response
|
||||
mock_requests_class.return_value = mock_requests
|
||||
|
||||
# Test with credentials containing multiple headers
|
||||
input_data = SendAuthenticatedWebRequestBlock.Input(
|
||||
url="https://api.example.com/data",
|
||||
method=HttpMethod.GET,
|
||||
headers={"User-Agent": "test-agent"},
|
||||
credentials=cast(
|
||||
HttpCredentials,
|
||||
{
|
||||
"id": multi_header_creds.id,
|
||||
"provider": "http",
|
||||
"type": "host_scoped",
|
||||
"title": multi_header_creds.title,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Execute with multi-header credentials
|
||||
result = []
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=multi_header_creds,
|
||||
graph_exec_id="test-exec-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
# Verify all headers are included
|
||||
mock_requests.request.assert_called_once()
|
||||
call_args = mock_requests.request.call_args
|
||||
expected_headers = {
|
||||
"Authorization": "Bearer multi-token",
|
||||
"X-API-Key": "api-key-456",
|
||||
"X-Client-ID": "client-789",
|
||||
"X-Custom-Header": "custom-value",
|
||||
"User-Agent": "test-agent",
|
||||
}
|
||||
assert call_args.kwargs["headers"] == expected_headers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("backend.blocks.http.Requests")
|
||||
async def test_credentials_with_complex_url_patterns(
|
||||
self,
|
||||
mock_requests_class,
|
||||
http_block,
|
||||
mock_response,
|
||||
):
|
||||
"""Test credentials matching various URL patterns."""
|
||||
# Test cases for different URL patterns
|
||||
test_cases = [
|
||||
{
|
||||
"host_pattern": "api.example.com",
|
||||
"test_url": "https://api.example.com/v1/users",
|
||||
"should_match": True,
|
||||
},
|
||||
{
|
||||
"host_pattern": "*.example.com",
|
||||
"test_url": "https://api.example.com/v1/users",
|
||||
"should_match": True,
|
||||
},
|
||||
{
|
||||
"host_pattern": "*.example.com",
|
||||
"test_url": "https://subdomain.example.com/data",
|
||||
"should_match": True,
|
||||
},
|
||||
{
|
||||
"host_pattern": "api.example.com",
|
||||
"test_url": "https://api.different.com/data",
|
||||
"should_match": False,
|
||||
},
|
||||
]
|
||||
|
||||
# Setup mocks
|
||||
mock_requests = AsyncMock()
|
||||
mock_requests.request.return_value = mock_response
|
||||
mock_requests_class.return_value = mock_requests
|
||||
|
||||
for case in test_cases:
|
||||
# Reset mock for each test case
|
||||
mock_requests.reset_mock()
|
||||
|
||||
# Create credentials for this test case
|
||||
test_creds = HostScopedCredentials(
|
||||
provider="http",
|
||||
host=case["host_pattern"],
|
||||
headers={
|
||||
"Authorization": SecretStr(f"Bearer {case['host_pattern']}-token"),
|
||||
},
|
||||
title=f"Credentials for {case['host_pattern']}",
|
||||
)
|
||||
|
||||
input_data = SendAuthenticatedWebRequestBlock.Input(
|
||||
url=case["test_url"],
|
||||
method=HttpMethod.GET,
|
||||
headers={"User-Agent": "test-agent"},
|
||||
credentials=cast(
|
||||
HttpCredentials,
|
||||
{
|
||||
"id": test_creds.id,
|
||||
"provider": "http",
|
||||
"type": "host_scoped",
|
||||
"title": test_creds.title,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Execute with test credentials
|
||||
result = []
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=test_creds,
|
||||
graph_exec_id="test-exec-id",
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
# Verify headers based on whether pattern should match
|
||||
mock_requests.request.assert_called_once()
|
||||
call_args = mock_requests.request.call_args
|
||||
headers = call_args.kwargs["headers"]
|
||||
|
||||
if case["should_match"]:
|
||||
# Should include both user and credential headers
|
||||
expected_auth = f"Bearer {case['host_pattern']}-token"
|
||||
assert headers["Authorization"] == expected_auth
|
||||
assert headers["User-Agent"] == "test-agent"
|
||||
else:
|
||||
# Should only include user headers
|
||||
assert "Authorization" not in headers
|
||||
assert headers["User-Agent"] == "test-agent"
|
||||
@@ -15,7 +15,7 @@ from backend.blocks.zerobounce._auth import (
|
||||
ZeroBounceCredentialsInput,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class Response(BaseModel):
|
||||
@@ -90,7 +90,7 @@ class ValidateEmailsBlock(Block):
|
||||
description="IP address to validate",
|
||||
default="",
|
||||
)
|
||||
credentials: ZeroBounceCredentialsInput = CredentialsField(
|
||||
credentials: ZeroBounceCredentialsInput = SchemaField(
|
||||
description="ZeroBounce credentials",
|
||||
)
|
||||
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
from .graph import NodeModel
|
||||
from .integrations import Webhook # noqa: F401
|
||||
|
||||
# Resolve Webhook <- NodeModel forward reference
|
||||
NodeModel.model_rebuild()
|
||||
@@ -78,7 +78,6 @@ class BlockCategory(Enum):
|
||||
PRODUCTIVITY = "Block that helps with productivity"
|
||||
ISSUE_TRACKING = "Block that helps with issue tracking"
|
||||
MULTIMEDIA = "Block that interacts with multimedia content"
|
||||
MARKETING = "Block that helps with marketing"
|
||||
|
||||
def dict(self) -> dict[str, str]:
|
||||
return {"category": self.name, "description": self.value}
|
||||
@@ -119,10 +118,7 @@ class BlockSchema(BaseModel):
|
||||
|
||||
@classmethod
|
||||
def validate_data(cls, data: BlockInput) -> str | None:
|
||||
return json.validate_with_jsonschema(
|
||||
schema=cls.jsonschema(),
|
||||
data={k: v for k, v in data.items() if v is not None},
|
||||
)
|
||||
return json.validate_with_jsonschema(schema=cls.jsonschema(), data=data)
|
||||
|
||||
@classmethod
|
||||
def get_mismatch_error(cls, data: BlockInput) -> str | None:
|
||||
@@ -475,8 +471,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
)
|
||||
|
||||
async for output_name, output_data in self.run(
|
||||
self.input_schema(**{k: v for k, v in input_data.items() if v is not None}),
|
||||
**kwargs,
|
||||
self.input_schema(**input_data), **kwargs
|
||||
):
|
||||
if output_name == "error":
|
||||
raise RuntimeError(output_data)
|
||||
@@ -486,22 +481,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
raise ValueError(f"Block produced an invalid output data: {error}")
|
||||
yield output_name, output_data
|
||||
|
||||
def is_triggered_by_event_type(
|
||||
self, trigger_config: dict[str, Any], event_type: str
|
||||
) -> bool:
|
||||
if not self.webhook_config:
|
||||
raise TypeError("This method can't be used on non-trigger blocks")
|
||||
if not self.webhook_config.event_filter_input:
|
||||
return True
|
||||
event_filter = trigger_config.get(self.webhook_config.event_filter_input)
|
||||
if not event_filter:
|
||||
raise ValueError("Event filter is not configured on trigger")
|
||||
return event_type in [
|
||||
self.webhook_config.event_format.format(event=k)
|
||||
for k in event_filter
|
||||
if event_filter[k] is True
|
||||
]
|
||||
|
||||
|
||||
# ======================= Block Helper Functions ======================= #
|
||||
|
||||
|
||||
@@ -2,9 +2,6 @@ from typing import Type
|
||||
|
||||
from backend.blocks.ai_music_generator import AIMusicGeneratorBlock
|
||||
from backend.blocks.ai_shortform_video_block import AIShortformVideoCreatorBlock
|
||||
from backend.blocks.apollo.organization import SearchOrganizationsBlock
|
||||
from backend.blocks.apollo.people import SearchPeopleBlock
|
||||
from backend.blocks.apollo.person import GetPersonDetailBlock
|
||||
from backend.blocks.flux_kontext import AIImageEditorBlock, FluxKontextModelName
|
||||
from backend.blocks.ideogram import IdeogramModelBlock
|
||||
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
|
||||
@@ -27,7 +24,6 @@ from backend.data.cost import BlockCost, BlockCostType
|
||||
from backend.integrations.credentials_store import (
|
||||
aiml_api_credentials,
|
||||
anthropic_credentials,
|
||||
apollo_credentials,
|
||||
did_credentials,
|
||||
groq_credentials,
|
||||
ideogram_credentials,
|
||||
@@ -85,9 +81,6 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.EVA_QWEN_2_5_32B: 1,
|
||||
LlmModel.DEEPSEEK_CHAT: 2,
|
||||
LlmModel.PERPLEXITY_LLAMA_3_1_SONAR_LARGE_128K_ONLINE: 1,
|
||||
LlmModel.PERPLEXITY_SONAR: 1,
|
||||
LlmModel.PERPLEXITY_SONAR_PRO: 5,
|
||||
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: 10,
|
||||
LlmModel.QWEN_QWQ_32B_PREVIEW: 2,
|
||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: 1,
|
||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: 1,
|
||||
@@ -352,52 +345,4 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
|
||||
)
|
||||
],
|
||||
SmartDecisionMakerBlock: LLM_COST,
|
||||
SearchOrganizationsBlock: [
|
||||
BlockCost(
|
||||
cost_amount=2,
|
||||
cost_filter={
|
||||
"credentials": {
|
||||
"id": apollo_credentials.id,
|
||||
"provider": apollo_credentials.provider,
|
||||
"type": apollo_credentials.type,
|
||||
}
|
||||
},
|
||||
)
|
||||
],
|
||||
SearchPeopleBlock: [
|
||||
BlockCost(
|
||||
cost_amount=10,
|
||||
cost_filter={
|
||||
"enrich_info": False,
|
||||
"credentials": {
|
||||
"id": apollo_credentials.id,
|
||||
"provider": apollo_credentials.provider,
|
||||
"type": apollo_credentials.type,
|
||||
},
|
||||
},
|
||||
),
|
||||
BlockCost(
|
||||
cost_amount=20,
|
||||
cost_filter={
|
||||
"enrich_info": True,
|
||||
"credentials": {
|
||||
"id": apollo_credentials.id,
|
||||
"provider": apollo_credentials.provider,
|
||||
"type": apollo_credentials.type,
|
||||
},
|
||||
},
|
||||
),
|
||||
],
|
||||
GetPersonDetailBlock: [
|
||||
BlockCost(
|
||||
cost_amount=1,
|
||||
cost_filter={
|
||||
"credentials": {
|
||||
"id": apollo_credentials.id,
|
||||
"provider": apollo_credentials.provider,
|
||||
"type": apollo_credentials.type,
|
||||
}
|
||||
},
|
||||
)
|
||||
],
|
||||
}
|
||||
|
||||
@@ -22,7 +22,6 @@ from prisma.models import (
|
||||
AgentGraphExecution,
|
||||
AgentNodeExecution,
|
||||
AgentNodeExecutionInputOutput,
|
||||
AgentNodeExecutionKeyValueData,
|
||||
)
|
||||
from prisma.types import (
|
||||
AgentGraphExecutionCreateInput,
|
||||
@@ -30,11 +29,10 @@ from prisma.types import (
|
||||
AgentGraphExecutionWhereInput,
|
||||
AgentNodeExecutionCreateInput,
|
||||
AgentNodeExecutionInputOutputCreateInput,
|
||||
AgentNodeExecutionKeyValueDataCreateInput,
|
||||
AgentNodeExecutionUpdateInput,
|
||||
AgentNodeExecutionWhereInput,
|
||||
)
|
||||
from pydantic import BaseModel, ConfigDict, JsonValue
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic.fields import Field
|
||||
|
||||
from backend.server.v2.store.exceptions import DatabaseError
|
||||
@@ -50,14 +48,14 @@ from .block import (
|
||||
get_webhook_block_ids,
|
||||
)
|
||||
from .db import BaseDbModel
|
||||
from .event_bus import AsyncRedisEventBus, RedisEventBus
|
||||
from .includes import (
|
||||
EXECUTION_RESULT_INCLUDE,
|
||||
EXECUTION_RESULT_ORDER,
|
||||
GRAPH_EXECUTION_INCLUDE_WITH_NODES,
|
||||
graph_execution_include,
|
||||
)
|
||||
from .model import GraphExecutionStats, NodeExecutionStats
|
||||
from .model import CredentialsMetaInput, GraphExecutionStats, NodeExecutionStats
|
||||
from .queue import AsyncRedisEventBus, RedisEventBus
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -273,7 +271,7 @@ class GraphExecutionWithNodes(GraphExecution):
|
||||
graph_id=self.graph_id,
|
||||
graph_version=self.graph_version or 0,
|
||||
graph_exec_id=self.id,
|
||||
nodes_input_masks={}, # FIXME: store credentials on AgentGraphExecution
|
||||
node_credentials_input_map={}, # FIXME
|
||||
)
|
||||
|
||||
|
||||
@@ -349,7 +347,6 @@ class NodeExecutionResult(BaseModel):
|
||||
|
||||
|
||||
async def get_graph_executions(
|
||||
graph_exec_id: str | None = None,
|
||||
graph_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
statuses: list[ExecutionStatus] | None = None,
|
||||
@@ -360,8 +357,6 @@ async def get_graph_executions(
|
||||
where_filter: AgentGraphExecutionWhereInput = {
|
||||
"isDeleted": False,
|
||||
}
|
||||
if graph_exec_id:
|
||||
where_filter["id"] = graph_exec_id
|
||||
if user_id:
|
||||
where_filter["userId"] = user_id
|
||||
if graph_id:
|
||||
@@ -561,18 +556,18 @@ async def upsert_execution_input(
|
||||
async def upsert_execution_output(
|
||||
node_exec_id: str,
|
||||
output_name: str,
|
||||
output_data: Any | None,
|
||||
output_data: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Insert AgentNodeExecutionInputOutput record for as one of AgentNodeExecution.Output.
|
||||
"""
|
||||
data = AgentNodeExecutionInputOutputCreateInput(
|
||||
name=output_name,
|
||||
referencedByOutputExecId=node_exec_id,
|
||||
await AgentNodeExecutionInputOutput.prisma().create(
|
||||
data=AgentNodeExecutionInputOutputCreateInput(
|
||||
name=output_name,
|
||||
data=Json(output_data),
|
||||
referencedByOutputExecId=node_exec_id,
|
||||
)
|
||||
)
|
||||
if output_data is not None:
|
||||
data["data"] = Json(output_data)
|
||||
await AgentNodeExecutionInputOutput.prisma().create(data=data)
|
||||
|
||||
|
||||
async def update_graph_execution_start_time(
|
||||
@@ -593,10 +588,12 @@ async def update_graph_execution_start_time(
|
||||
|
||||
async def update_graph_execution_stats(
|
||||
graph_exec_id: str,
|
||||
status: ExecutionStatus | None = None,
|
||||
status: ExecutionStatus,
|
||||
stats: GraphExecutionStats | None = None,
|
||||
) -> GraphExecution | None:
|
||||
update_data: AgentGraphExecutionUpdateManyMutationInput = {}
|
||||
update_data: AgentGraphExecutionUpdateManyMutationInput = {
|
||||
"executionStatus": status
|
||||
}
|
||||
|
||||
if stats:
|
||||
stats_dict = stats.model_dump()
|
||||
@@ -604,9 +601,6 @@ async def update_graph_execution_stats(
|
||||
stats_dict["error"] = str(stats_dict["error"])
|
||||
update_data["stats"] = Json(stats_dict)
|
||||
|
||||
if status:
|
||||
update_data["executionStatus"] = status
|
||||
|
||||
updated_count = await AgentGraphExecution.prisma().update_many(
|
||||
where={
|
||||
"id": graph_exec_id,
|
||||
@@ -789,7 +783,7 @@ class GraphExecutionEntry(BaseModel):
|
||||
graph_exec_id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None
|
||||
node_credentials_input_map: Optional[dict[str, dict[str, CredentialsMetaInput]]]
|
||||
|
||||
|
||||
class NodeExecutionEntry(BaseModel):
|
||||
@@ -909,57 +903,3 @@ class AsyncRedisExecutionEventBus(AsyncRedisEventBus[ExecutionEvent]):
|
||||
) -> AsyncGenerator[ExecutionEvent, None]:
|
||||
async for event in self.listen_events(f"{user_id}/{graph_id}/{graph_exec_id}"):
|
||||
yield event
|
||||
|
||||
|
||||
# --------------------- KV Data Functions --------------------- #
|
||||
|
||||
|
||||
async def get_execution_kv_data(user_id: str, key: str) -> Any | None:
|
||||
"""
|
||||
Get key-value data for a user and key.
|
||||
|
||||
Args:
|
||||
user_id: The id of the User.
|
||||
key: The key to retrieve data for.
|
||||
|
||||
Returns:
|
||||
The data associated with the key, or None if not found.
|
||||
"""
|
||||
kv_data = await AgentNodeExecutionKeyValueData.prisma().find_unique(
|
||||
where={"userId_key": {"userId": user_id, "key": key}}
|
||||
)
|
||||
return (
|
||||
type_utils.convert(kv_data.data, type[Any])
|
||||
if kv_data and kv_data.data
|
||||
else None
|
||||
)
|
||||
|
||||
|
||||
async def set_execution_kv_data(
|
||||
user_id: str, node_exec_id: str, key: str, data: Any
|
||||
) -> Any | None:
|
||||
"""
|
||||
Set key-value data for a user and key.
|
||||
|
||||
Args:
|
||||
user_id: The id of the User.
|
||||
node_exec_id: The id of the AgentNodeExecution.
|
||||
key: The key to store data under.
|
||||
data: The data to store.
|
||||
"""
|
||||
resp = await AgentNodeExecutionKeyValueData.prisma().upsert(
|
||||
where={"userId_key": {"userId": user_id, "key": key}},
|
||||
data={
|
||||
"create": AgentNodeExecutionKeyValueDataCreateInput(
|
||||
userId=user_id,
|
||||
agentNodeExecutionId=node_exec_id,
|
||||
key=key,
|
||||
data=Json(data) if data is not None else None,
|
||||
),
|
||||
"update": {
|
||||
"agentNodeExecutionId": node_exec_id,
|
||||
"data": Json(data) if data is not None else None,
|
||||
},
|
||||
},
|
||||
)
|
||||
return type_utils.convert(resp.data, type[Any]) if resp and resp.data else None
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
||||
from typing import Any, Literal, Optional, cast
|
||||
|
||||
import prisma
|
||||
from prisma import Json
|
||||
@@ -14,7 +14,7 @@ from prisma.types import (
|
||||
AgentNodeLinkCreateInput,
|
||||
StoreListingVersionWhereInput,
|
||||
)
|
||||
from pydantic import JsonValue, create_model
|
||||
from pydantic import create_model
|
||||
from pydantic.fields import computed_field
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
@@ -27,15 +27,12 @@ from backend.data.model import (
|
||||
CredentialsMetaInput,
|
||||
is_credentials_field_name,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util import type as type_utils
|
||||
|
||||
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
|
||||
from .db import BaseDbModel, transaction
|
||||
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .integrations import Webhook
|
||||
from .integrations import Webhook
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -84,12 +81,10 @@ class NodeModel(Node):
|
||||
graph_version: int
|
||||
|
||||
webhook_id: Optional[str] = None
|
||||
webhook: Optional["Webhook"] = None
|
||||
webhook: Optional[Webhook] = None
|
||||
|
||||
@staticmethod
|
||||
def from_db(node: AgentNode, for_export: bool = False) -> "NodeModel":
|
||||
from .integrations import Webhook
|
||||
|
||||
obj = NodeModel(
|
||||
id=node.id,
|
||||
block_id=node.agentBlockId,
|
||||
@@ -107,7 +102,19 @@ class NodeModel(Node):
|
||||
return obj
|
||||
|
||||
def is_triggered_by_event_type(self, event_type: str) -> bool:
|
||||
return self.block.is_triggered_by_event_type(self.input_default, event_type)
|
||||
block = self.block
|
||||
if not block.webhook_config:
|
||||
raise TypeError("This method can't be used on non-webhook blocks")
|
||||
if not block.webhook_config.event_filter_input:
|
||||
return True
|
||||
event_filter = self.input_default.get(block.webhook_config.event_filter_input)
|
||||
if not event_filter:
|
||||
raise ValueError(f"Event filter is not configured on node #{self.id}")
|
||||
return event_type in [
|
||||
block.webhook_config.event_format.format(event=k)
|
||||
for k in event_filter
|
||||
if event_filter[k] is True
|
||||
]
|
||||
|
||||
def stripped_for_export(self) -> "NodeModel":
|
||||
"""
|
||||
@@ -155,6 +162,10 @@ class NodeModel(Node):
|
||||
return result
|
||||
|
||||
|
||||
# Fix 2-way reference Node <-> Webhook
|
||||
Webhook.model_rebuild()
|
||||
|
||||
|
||||
class BaseGraph(BaseDbModel):
|
||||
version: int = 1
|
||||
is_active: bool = True
|
||||
@@ -244,8 +255,6 @@ class Graph(BaseGraph):
|
||||
for other_field, other_keys in list(graph_cred_fields)[i + 1 :]:
|
||||
if field.provider != other_field.provider:
|
||||
continue
|
||||
if ProviderName.HTTP in field.provider:
|
||||
continue
|
||||
|
||||
# If this happens, that means a block implementation probably needs
|
||||
# to be updated.
|
||||
@@ -267,7 +276,6 @@ class Graph(BaseGraph):
|
||||
required_scopes=set(field_info.required_scopes or []),
|
||||
discriminator=field_info.discriminator,
|
||||
discriminator_mapping=field_info.discriminator_mapping,
|
||||
discriminator_values=field_info.discriminator_values,
|
||||
),
|
||||
)
|
||||
for agg_field_key, (field_info, _) in graph_credentials_inputs.items()
|
||||
@@ -286,40 +294,37 @@ class Graph(BaseGraph):
|
||||
Returns:
|
||||
dict[aggregated_field_key, tuple(
|
||||
CredentialsFieldInfo: A spec for one aggregated credentials field
|
||||
(now includes discriminator_values from matching nodes)
|
||||
set[(node_id, field_name)]: Node credentials fields that are
|
||||
compatible with this aggregated field spec
|
||||
)]
|
||||
"""
|
||||
# First collect all credential field data with input defaults
|
||||
node_credential_data = []
|
||||
|
||||
for graph in [self] + self.sub_graphs:
|
||||
for node in graph.nodes:
|
||||
for (
|
||||
field_name,
|
||||
field_info,
|
||||
) in node.block.input_schema.get_credentials_fields_info().items():
|
||||
|
||||
discriminator = field_info.discriminator
|
||||
if not discriminator:
|
||||
node_credential_data.append((field_info, (node.id, field_name)))
|
||||
continue
|
||||
|
||||
discriminator_value = node.input_default.get(discriminator)
|
||||
if discriminator_value is None:
|
||||
node_credential_data.append((field_info, (node.id, field_name)))
|
||||
continue
|
||||
|
||||
discriminated_info = field_info.discriminate(discriminator_value)
|
||||
discriminated_info.discriminator_values.add(discriminator_value)
|
||||
|
||||
node_credential_data.append(
|
||||
(discriminated_info, (node.id, field_name))
|
||||
return {
|
||||
"_".join(sorted(agg_field_info.provider))
|
||||
+ "_"
|
||||
+ "_".join(sorted(agg_field_info.supported_types))
|
||||
+ "_credentials": (agg_field_info, node_fields)
|
||||
for agg_field_info, node_fields in CredentialsFieldInfo.combine(
|
||||
*(
|
||||
(
|
||||
# Apply discrimination before aggregating credentials inputs
|
||||
(
|
||||
field_info.discriminate(
|
||||
node.input_default[field_info.discriminator]
|
||||
)
|
||||
if (
|
||||
field_info.discriminator
|
||||
and node.input_default.get(field_info.discriminator)
|
||||
)
|
||||
else field_info
|
||||
),
|
||||
(node.id, field_name),
|
||||
)
|
||||
|
||||
# Combine credential field info (this will merge discriminator_values automatically)
|
||||
return CredentialsFieldInfo.combine(*node_credential_data)
|
||||
for graph in [self] + self.sub_graphs
|
||||
for node in graph.nodes
|
||||
for field_name, field_info in node.block.input_schema.get_credentials_fields_info().items()
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
class GraphModel(Graph):
|
||||
@@ -389,10 +394,8 @@ class GraphModel(Graph):
|
||||
|
||||
# Reassign Link IDs
|
||||
for link in graph.links:
|
||||
if link.source_id in id_map:
|
||||
link.source_id = id_map[link.source_id]
|
||||
if link.sink_id in id_map:
|
||||
link.sink_id = id_map[link.sink_id]
|
||||
link.source_id = id_map[link.source_id]
|
||||
link.sink_id = id_map[link.sink_id]
|
||||
|
||||
# Reassign User IDs for agent blocks
|
||||
for node in graph.nodes:
|
||||
@@ -400,26 +403,16 @@ class GraphModel(Graph):
|
||||
continue
|
||||
node.input_default["user_id"] = user_id
|
||||
node.input_default.setdefault("inputs", {})
|
||||
if (
|
||||
graph_id := node.input_default.get("graph_id")
|
||||
) and graph_id in graph_id_map:
|
||||
if (graph_id := node.input_default.get("graph_id")) in graph_id_map:
|
||||
node.input_default["graph_id"] = graph_id_map[graph_id]
|
||||
|
||||
def validate_graph(
|
||||
self,
|
||||
for_run: bool = False,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
):
|
||||
self._validate_graph(self, for_run, nodes_input_masks)
|
||||
def validate_graph(self, for_run: bool = False):
|
||||
self._validate_graph(self, for_run)
|
||||
for sub_graph in self.sub_graphs:
|
||||
self._validate_graph(sub_graph, for_run, nodes_input_masks)
|
||||
self._validate_graph(sub_graph, for_run)
|
||||
|
||||
@staticmethod
|
||||
def _validate_graph(
|
||||
graph: BaseGraph,
|
||||
for_run: bool = False,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
):
|
||||
def _validate_graph(graph: BaseGraph, for_run: bool = False):
|
||||
def is_tool_pin(name: str) -> bool:
|
||||
return name.startswith("tools_^_")
|
||||
|
||||
@@ -446,18 +439,20 @@ class GraphModel(Graph):
|
||||
if (block := nodes_block.get(node.id)) is None:
|
||||
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
|
||||
|
||||
node_input_mask = (
|
||||
nodes_input_masks.get(node.id, {}) if nodes_input_masks else {}
|
||||
)
|
||||
provided_inputs = set(
|
||||
[sanitize(name) for name in node.input_default]
|
||||
+ [sanitize(link.sink_name) for link in input_links.get(node.id, [])]
|
||||
+ ([name for name in node_input_mask] if node_input_mask else [])
|
||||
)
|
||||
InputSchema = block.input_schema
|
||||
for name in (required_fields := InputSchema.get_required_fields()):
|
||||
if (
|
||||
name not in provided_inputs
|
||||
# Webhook payload is passed in by ExecutionManager
|
||||
and not (
|
||||
name == "payload"
|
||||
and block.block_type
|
||||
in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
|
||||
)
|
||||
# Checking availability of credentials is done by ExecutionManager
|
||||
and name not in InputSchema.get_credentials_fields()
|
||||
# Validate only I/O nodes, or validate everything when executing
|
||||
@@ -490,18 +485,10 @@ class GraphModel(Graph):
|
||||
|
||||
def has_value(node: Node, name: str):
|
||||
return (
|
||||
(
|
||||
name in node.input_default
|
||||
and node.input_default[name] is not None
|
||||
and str(node.input_default[name]).strip() != ""
|
||||
)
|
||||
or (name in input_fields and input_fields[name].default is not None)
|
||||
or (
|
||||
name in node_input_mask
|
||||
and node_input_mask[name] is not None
|
||||
and str(node_input_mask[name]).strip() != ""
|
||||
)
|
||||
)
|
||||
name in node.input_default
|
||||
and node.input_default[name] is not None
|
||||
and str(node.input_default[name]).strip() != ""
|
||||
) or (name in input_fields and input_fields[name].default is not None)
|
||||
|
||||
# Validate dependencies between fields
|
||||
for field_name in input_fields.keys():
|
||||
@@ -587,7 +574,7 @@ class GraphModel(Graph):
|
||||
graph: AgentGraph,
|
||||
for_export: bool = False,
|
||||
sub_graphs: list[AgentGraph] | None = None,
|
||||
) -> "GraphModel":
|
||||
):
|
||||
return GraphModel(
|
||||
id=graph.id,
|
||||
user_id=graph.userId if not for_export else "",
|
||||
@@ -616,7 +603,6 @@ class GraphModel(Graph):
|
||||
|
||||
|
||||
async def get_node(node_id: str) -> NodeModel:
|
||||
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
|
||||
node = await AgentNode.prisma().find_unique_or_raise(
|
||||
where={"id": node_id},
|
||||
include=AGENT_NODE_INCLUDE,
|
||||
@@ -625,7 +611,6 @@ async def get_node(node_id: str) -> NodeModel:
|
||||
|
||||
|
||||
async def set_node_webhook(node_id: str, webhook_id: str | None) -> NodeModel:
|
||||
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
|
||||
node = await AgentNode.prisma().update(
|
||||
where={"id": node_id},
|
||||
data=(
|
||||
|
||||
@@ -60,8 +60,7 @@ def graph_execution_include(
|
||||
|
||||
|
||||
INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
|
||||
"AgentNodes": {"include": AGENT_NODE_INCLUDE},
|
||||
"AgentPresets": {"include": {"InputPresets": True}},
|
||||
"AgentNodes": {"include": AGENT_NODE_INCLUDE}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,25 +1,21 @@
|
||||
import logging
|
||||
from typing import AsyncGenerator, Literal, Optional, overload
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, Optional
|
||||
|
||||
from prisma import Json
|
||||
from prisma.models import IntegrationWebhook
|
||||
from prisma.types import (
|
||||
IntegrationWebhookCreateInput,
|
||||
IntegrationWebhookUpdateInput,
|
||||
IntegrationWebhookWhereInput,
|
||||
Serializable,
|
||||
)
|
||||
from prisma.types import IntegrationWebhookCreateInput
|
||||
from pydantic import Field, computed_field
|
||||
|
||||
from backend.data.event_bus import AsyncRedisEventBus
|
||||
from backend.data.includes import INTEGRATION_WEBHOOK_INCLUDE
|
||||
from backend.data.queue import AsyncRedisEventBus
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks.utils import webhook_ingress_url
|
||||
from backend.server.v2.library.model import LibraryAgentPreset
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from .db import BaseDbModel
|
||||
from .graph import NodeModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .graph import NodeModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -36,6 +32,8 @@ class Webhook(BaseDbModel):
|
||||
|
||||
provider_webhook_id: str
|
||||
|
||||
attached_nodes: Optional[list["NodeModel"]] = None
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def url(self) -> str:
|
||||
@@ -43,6 +41,8 @@ class Webhook(BaseDbModel):
|
||||
|
||||
@staticmethod
|
||||
def from_db(webhook: IntegrationWebhook):
|
||||
from .graph import NodeModel
|
||||
|
||||
return Webhook(
|
||||
id=webhook.id,
|
||||
user_id=webhook.userId,
|
||||
@@ -54,26 +54,11 @@ class Webhook(BaseDbModel):
|
||||
config=dict(webhook.config),
|
||||
secret=webhook.secret,
|
||||
provider_webhook_id=webhook.providerWebhookId,
|
||||
)
|
||||
|
||||
|
||||
class WebhookWithRelations(Webhook):
|
||||
triggered_nodes: list[NodeModel]
|
||||
triggered_presets: list[LibraryAgentPreset]
|
||||
|
||||
@staticmethod
|
||||
def from_db(webhook: IntegrationWebhook):
|
||||
if webhook.AgentNodes is None or webhook.AgentPresets is None:
|
||||
raise ValueError(
|
||||
"AgentNodes and AgentPresets must be included in "
|
||||
"IntegrationWebhook query with relations"
|
||||
)
|
||||
return WebhookWithRelations(
|
||||
**Webhook.from_db(webhook).model_dump(),
|
||||
triggered_nodes=[NodeModel.from_db(node) for node in webhook.AgentNodes],
|
||||
triggered_presets=[
|
||||
LibraryAgentPreset.from_db(preset) for preset in webhook.AgentPresets
|
||||
],
|
||||
attached_nodes=(
|
||||
[NodeModel.from_db(node) for node in webhook.AgentNodes]
|
||||
if webhook.AgentNodes is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -98,19 +83,7 @@ async def create_webhook(webhook: Webhook) -> Webhook:
|
||||
return Webhook.from_db(created_webhook)
|
||||
|
||||
|
||||
@overload
|
||||
async def get_webhook(
|
||||
webhook_id: str, *, include_relations: Literal[True]
|
||||
) -> WebhookWithRelations: ...
|
||||
@overload
|
||||
async def get_webhook(
|
||||
webhook_id: str, *, include_relations: Literal[False] = False
|
||||
) -> Webhook: ...
|
||||
|
||||
|
||||
async def get_webhook(
|
||||
webhook_id: str, *, include_relations: bool = False
|
||||
) -> Webhook | WebhookWithRelations:
|
||||
async def get_webhook(webhook_id: str) -> Webhook:
|
||||
"""
|
||||
⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints.
|
||||
|
||||
@@ -119,113 +92,73 @@ async def get_webhook(
|
||||
"""
|
||||
webhook = await IntegrationWebhook.prisma().find_unique(
|
||||
where={"id": webhook_id},
|
||||
include=INTEGRATION_WEBHOOK_INCLUDE if include_relations else None,
|
||||
include=INTEGRATION_WEBHOOK_INCLUDE,
|
||||
)
|
||||
if not webhook:
|
||||
raise NotFoundError(f"Webhook #{webhook_id} not found")
|
||||
return (WebhookWithRelations if include_relations else Webhook).from_db(webhook)
|
||||
return Webhook.from_db(webhook)
|
||||
|
||||
|
||||
@overload
|
||||
async def get_all_webhooks_by_creds(
|
||||
user_id: str, credentials_id: str, *, include_relations: Literal[True]
|
||||
) -> list[WebhookWithRelations]: ...
|
||||
@overload
|
||||
async def get_all_webhooks_by_creds(
|
||||
user_id: str, credentials_id: str, *, include_relations: Literal[False] = False
|
||||
) -> list[Webhook]: ...
|
||||
|
||||
|
||||
async def get_all_webhooks_by_creds(
|
||||
user_id: str, credentials_id: str, *, include_relations: bool = False
|
||||
) -> list[Webhook] | list[WebhookWithRelations]:
|
||||
async def get_all_webhooks_by_creds(credentials_id: str) -> list[Webhook]:
|
||||
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
|
||||
if not credentials_id:
|
||||
raise ValueError("credentials_id must not be empty")
|
||||
webhooks = await IntegrationWebhook.prisma().find_many(
|
||||
where={"userId": user_id, "credentialsId": credentials_id},
|
||||
include=INTEGRATION_WEBHOOK_INCLUDE if include_relations else None,
|
||||
where={"credentialsId": credentials_id},
|
||||
include=INTEGRATION_WEBHOOK_INCLUDE,
|
||||
)
|
||||
return [
|
||||
(WebhookWithRelations if include_relations else Webhook).from_db(webhook)
|
||||
for webhook in webhooks
|
||||
]
|
||||
return [Webhook.from_db(webhook) for webhook in webhooks]
|
||||
|
||||
|
||||
async def find_webhook_by_credentials_and_props(
|
||||
user_id: str,
|
||||
credentials_id: str,
|
||||
webhook_type: str,
|
||||
resource: str,
|
||||
events: list[str],
|
||||
credentials_id: str, webhook_type: str, resource: str, events: list[str]
|
||||
) -> Webhook | None:
|
||||
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
|
||||
webhook = await IntegrationWebhook.prisma().find_first(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"credentialsId": credentials_id,
|
||||
"webhookType": webhook_type,
|
||||
"resource": resource,
|
||||
"events": {"has_every": events},
|
||||
},
|
||||
include=INTEGRATION_WEBHOOK_INCLUDE,
|
||||
)
|
||||
return Webhook.from_db(webhook) if webhook else None
|
||||
|
||||
|
||||
async def find_webhook_by_graph_and_props(
|
||||
user_id: str,
|
||||
provider: str,
|
||||
webhook_type: str,
|
||||
graph_id: Optional[str] = None,
|
||||
preset_id: Optional[str] = None,
|
||||
graph_id: str, provider: str, webhook_type: str, events: list[str]
|
||||
) -> Webhook | None:
|
||||
"""Either `graph_id` or `preset_id` must be provided."""
|
||||
where_clause: IntegrationWebhookWhereInput = {
|
||||
"userId": user_id,
|
||||
"provider": provider,
|
||||
"webhookType": webhook_type,
|
||||
}
|
||||
|
||||
if preset_id:
|
||||
where_clause["AgentPresets"] = {"some": {"id": preset_id}}
|
||||
elif graph_id:
|
||||
where_clause["AgentNodes"] = {"some": {"agentGraphId": graph_id}}
|
||||
else:
|
||||
raise ValueError("Either graph_id or preset_id must be provided")
|
||||
|
||||
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
|
||||
webhook = await IntegrationWebhook.prisma().find_first(
|
||||
where=where_clause,
|
||||
where={
|
||||
"provider": provider,
|
||||
"webhookType": webhook_type,
|
||||
"events": {"has_every": events},
|
||||
"AgentNodes": {"some": {"agentGraphId": graph_id}},
|
||||
},
|
||||
include=INTEGRATION_WEBHOOK_INCLUDE,
|
||||
)
|
||||
return Webhook.from_db(webhook) if webhook else None
|
||||
|
||||
|
||||
async def update_webhook(
|
||||
webhook_id: str,
|
||||
config: Optional[dict[str, Serializable]] = None,
|
||||
events: Optional[list[str]] = None,
|
||||
) -> Webhook:
|
||||
async def update_webhook_config(webhook_id: str, updated_config: dict) -> Webhook:
|
||||
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
|
||||
data: IntegrationWebhookUpdateInput = {}
|
||||
if config is not None:
|
||||
data["config"] = Json(config)
|
||||
if events is not None:
|
||||
data["events"] = events
|
||||
if not data:
|
||||
raise ValueError("Empty update query")
|
||||
|
||||
_updated_webhook = await IntegrationWebhook.prisma().update(
|
||||
where={"id": webhook_id},
|
||||
data=data,
|
||||
data={"config": Json(updated_config)},
|
||||
include=INTEGRATION_WEBHOOK_INCLUDE,
|
||||
)
|
||||
if _updated_webhook is None:
|
||||
raise NotFoundError(f"Webhook #{webhook_id} not found")
|
||||
raise ValueError(f"Webhook #{webhook_id} not found")
|
||||
return Webhook.from_db(_updated_webhook)
|
||||
|
||||
|
||||
async def delete_webhook(user_id: str, webhook_id: str) -> None:
|
||||
deleted = await IntegrationWebhook.prisma().delete_many(
|
||||
where={"id": webhook_id, "userId": user_id}
|
||||
)
|
||||
if deleted < 1:
|
||||
raise NotFoundError(f"Webhook #{webhook_id} not found")
|
||||
async def delete_webhook(webhook_id: str) -> None:
|
||||
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
|
||||
deleted = await IntegrationWebhook.prisma().delete(where={"id": webhook_id})
|
||||
if not deleted:
|
||||
raise ValueError(f"Webhook #{webhook_id} not found")
|
||||
|
||||
|
||||
# --------------------- WEBHOOK EVENTS --------------------- #
|
||||
|
||||
@@ -14,12 +14,11 @@ from typing import (
|
||||
Generic,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
cast,
|
||||
get_args,
|
||||
)
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
from prisma.enums import CreditTransactionType
|
||||
@@ -241,65 +240,13 @@ class UserPasswordCredentials(_BaseCredentials):
|
||||
return f"Basic {base64.b64encode(f'{self.username.get_secret_value()}:{self.password.get_secret_value()}'.encode()).decode()}"
|
||||
|
||||
|
||||
class HostScopedCredentials(_BaseCredentials):
|
||||
type: Literal["host_scoped"] = "host_scoped"
|
||||
host: str = Field(description="The host/URI pattern to match against request URLs")
|
||||
headers: dict[str, SecretStr] = Field(
|
||||
description="Key-value header map to add to matching requests",
|
||||
default_factory=dict,
|
||||
)
|
||||
|
||||
def _extract_headers(self, headers: dict[str, SecretStr]) -> dict[str, str]:
|
||||
"""Helper to extract secret values from headers."""
|
||||
return {key: value.get_secret_value() for key, value in headers.items()}
|
||||
|
||||
@field_serializer("headers")
|
||||
def serialize_headers(self, headers: dict[str, SecretStr]) -> dict[str, str]:
|
||||
"""Serialize headers by extracting secret values."""
|
||||
return self._extract_headers(headers)
|
||||
|
||||
def get_headers_dict(self) -> dict[str, str]:
|
||||
"""Get headers with secret values extracted."""
|
||||
return self._extract_headers(self.headers)
|
||||
|
||||
def auth_header(self) -> str:
|
||||
"""Get authorization header for backward compatibility."""
|
||||
auth_headers = self.get_headers_dict()
|
||||
if "Authorization" in auth_headers:
|
||||
return auth_headers["Authorization"]
|
||||
return ""
|
||||
|
||||
def matches_url(self, url: str) -> bool:
|
||||
"""Check if this credential should be applied to the given URL."""
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
# Extract hostname without port
|
||||
request_host = parsed_url.hostname
|
||||
if not request_host:
|
||||
return False
|
||||
|
||||
# Simple host matching - exact match or wildcard subdomain match
|
||||
if self.host == request_host:
|
||||
return True
|
||||
|
||||
# Support wildcard matching (e.g., "*.example.com" matches "api.example.com")
|
||||
if self.host.startswith("*."):
|
||||
domain = self.host[2:] # Remove "*."
|
||||
return request_host.endswith(f".{domain}") or request_host == domain
|
||||
|
||||
return False
|
||||
|
||||
|
||||
Credentials = Annotated[
|
||||
OAuth2Credentials
|
||||
| APIKeyCredentials
|
||||
| UserPasswordCredentials
|
||||
| HostScopedCredentials,
|
||||
OAuth2Credentials | APIKeyCredentials | UserPasswordCredentials,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
CredentialsType = Literal["api_key", "oauth2", "user_password", "host_scoped"]
|
||||
CredentialsType = Literal["api_key", "oauth2", "user_password"]
|
||||
|
||||
|
||||
class OAuthState(BaseModel):
|
||||
@@ -373,29 +320,15 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _add_json_schema_extra(schema: dict, model_class: type):
|
||||
# Use model_class for allowed_providers/cred_types
|
||||
if hasattr(model_class, "allowed_providers") and hasattr(
|
||||
model_class, "allowed_cred_types"
|
||||
):
|
||||
schema["credentials_provider"] = model_class.allowed_providers()
|
||||
schema["credentials_types"] = model_class.allowed_cred_types()
|
||||
# Do not return anything, just mutate schema in place
|
||||
def _add_json_schema_extra(schema, cls: CredentialsMetaInput):
|
||||
schema["credentials_provider"] = cls.allowed_providers()
|
||||
schema["credentials_types"] = cls.allowed_cred_types()
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra=_add_json_schema_extra, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
def _extract_host_from_url(url: str) -> str:
|
||||
"""Extract host from URL for grouping host-scoped credentials."""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
return parsed.hostname or url
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
# TODO: move discrimination mechanism out of CredentialsField (frontend + backend)
|
||||
provider: frozenset[CP] = Field(..., alias="credentials_provider")
|
||||
@@ -403,12 +336,11 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
required_scopes: Optional[frozenset[str]] = Field(None, alias="credentials_scopes")
|
||||
discriminator: Optional[str] = None
|
||||
discriminator_mapping: Optional[dict[str, CP]] = None
|
||||
discriminator_values: set[Any] = Field(default_factory=set)
|
||||
|
||||
@classmethod
|
||||
def combine(
|
||||
cls, *fields: tuple[CredentialsFieldInfo[CP, CT], T]
|
||||
) -> dict[str, tuple[CredentialsFieldInfo[CP, CT], set[T]]]:
|
||||
) -> Sequence[tuple[CredentialsFieldInfo[CP, CT], set[T]]]:
|
||||
"""
|
||||
Combines multiple CredentialsFieldInfo objects into as few as possible.
|
||||
|
||||
@@ -426,36 +358,22 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
the set of keys of the respective original items that were grouped together.
|
||||
"""
|
||||
if not fields:
|
||||
return {}
|
||||
return []
|
||||
|
||||
# Group fields by their provider and supported_types
|
||||
# For HTTP host-scoped credentials, also group by host
|
||||
grouped_fields: defaultdict[
|
||||
tuple[frozenset[CP], frozenset[CT]],
|
||||
list[tuple[T, CredentialsFieldInfo[CP, CT]]],
|
||||
] = defaultdict(list)
|
||||
|
||||
for field, key in fields:
|
||||
if field.provider == frozenset([ProviderName.HTTP]):
|
||||
# HTTP host-scoped credentials can have different hosts that reqires different credential sets.
|
||||
# Group by host extracted from the URL
|
||||
providers = frozenset(
|
||||
[cast(CP, "http")]
|
||||
+ [
|
||||
cast(CP, _extract_host_from_url(str(value)))
|
||||
for value in field.discriminator_values
|
||||
]
|
||||
)
|
||||
else:
|
||||
providers = frozenset(field.provider)
|
||||
|
||||
group_key = (providers, frozenset(field.supported_types))
|
||||
group_key = (frozenset(field.provider), frozenset(field.supported_types))
|
||||
grouped_fields[group_key].append((key, field))
|
||||
|
||||
# Combine fields within each group
|
||||
result: dict[str, tuple[CredentialsFieldInfo[CP, CT], set[T]]] = {}
|
||||
result: list[tuple[CredentialsFieldInfo[CP, CT], set[T]]] = []
|
||||
|
||||
for key, group in grouped_fields.items():
|
||||
for group in grouped_fields.values():
|
||||
# Start with the first field in the group
|
||||
_, combined = group[0]
|
||||
|
||||
@@ -468,32 +386,18 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
if field.required_scopes:
|
||||
all_scopes.update(field.required_scopes)
|
||||
|
||||
# Combine discriminator_values from all fields in the group (removing duplicates)
|
||||
all_discriminator_values = []
|
||||
for _, field in group:
|
||||
for value in field.discriminator_values:
|
||||
if value not in all_discriminator_values:
|
||||
all_discriminator_values.append(value)
|
||||
|
||||
# Generate the key for the combined result
|
||||
providers_key, supported_types_key = key
|
||||
group_key = (
|
||||
"-".join(sorted(providers_key))
|
||||
+ "_"
|
||||
+ "-".join(sorted(supported_types_key))
|
||||
+ "_credentials"
|
||||
)
|
||||
|
||||
result[group_key] = (
|
||||
CredentialsFieldInfo[CP, CT](
|
||||
credentials_provider=combined.provider,
|
||||
credentials_types=combined.supported_types,
|
||||
credentials_scopes=frozenset(all_scopes) or None,
|
||||
discriminator=combined.discriminator,
|
||||
discriminator_mapping=combined.discriminator_mapping,
|
||||
discriminator_values=set(all_discriminator_values),
|
||||
),
|
||||
combined_keys,
|
||||
# Create a new combined field
|
||||
result.append(
|
||||
(
|
||||
CredentialsFieldInfo[CP, CT](
|
||||
credentials_provider=combined.provider,
|
||||
credentials_types=combined.supported_types,
|
||||
credentials_scopes=frozenset(all_scopes) or None,
|
||||
discriminator=combined.discriminator,
|
||||
discriminator_mapping=combined.discriminator_mapping,
|
||||
),
|
||||
combined_keys,
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -502,15 +406,11 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
if not (self.discriminator and self.discriminator_mapping):
|
||||
return self
|
||||
|
||||
discriminator_value = self.discriminator_mapping[discriminator_value]
|
||||
return CredentialsFieldInfo(
|
||||
credentials_provider=frozenset(
|
||||
[self.discriminator_mapping[discriminator_value]]
|
||||
),
|
||||
credentials_provider=frozenset([discriminator_value]),
|
||||
credentials_types=self.supported_types,
|
||||
credentials_scopes=self.required_scopes,
|
||||
discriminator=self.discriminator,
|
||||
discriminator_mapping=self.discriminator_mapping,
|
||||
discriminator_values=self.discriminator_values,
|
||||
)
|
||||
|
||||
|
||||
@@ -519,7 +419,6 @@ def CredentialsField(
|
||||
*,
|
||||
discriminator: Optional[str] = None,
|
||||
discriminator_mapping: Optional[dict[str, Any]] = None,
|
||||
discriminator_values: Optional[set[Any]] = None,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
**kwargs,
|
||||
@@ -535,7 +434,6 @@ def CredentialsField(
|
||||
"credentials_scopes": list(required_scopes) or None,
|
||||
"discriminator": discriminator,
|
||||
"discriminator_mapping": discriminator_mapping,
|
||||
"discriminator_values": discriminator_values,
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
@@ -1,143 +0,0 @@
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import HostScopedCredentials
|
||||
|
||||
|
||||
class TestHostScopedCredentials:
|
||||
def test_host_scoped_credentials_creation(self):
|
||||
"""Test creating HostScopedCredentials with required fields."""
|
||||
creds = HostScopedCredentials(
|
||||
provider="custom",
|
||||
host="api.example.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("Bearer secret-token"),
|
||||
"X-API-Key": SecretStr("api-key-123"),
|
||||
},
|
||||
title="Example API Credentials",
|
||||
)
|
||||
|
||||
assert creds.type == "host_scoped"
|
||||
assert creds.provider == "custom"
|
||||
assert creds.host == "api.example.com"
|
||||
assert creds.title == "Example API Credentials"
|
||||
assert len(creds.headers) == 2
|
||||
assert "Authorization" in creds.headers
|
||||
assert "X-API-Key" in creds.headers
|
||||
|
||||
def test_get_headers_dict(self):
|
||||
"""Test getting headers with secret values extracted."""
|
||||
creds = HostScopedCredentials(
|
||||
provider="custom",
|
||||
host="api.example.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("Bearer secret-token"),
|
||||
"X-Custom-Header": SecretStr("custom-value"),
|
||||
},
|
||||
)
|
||||
|
||||
headers_dict = creds.get_headers_dict()
|
||||
|
||||
assert headers_dict == {
|
||||
"Authorization": "Bearer secret-token",
|
||||
"X-Custom-Header": "custom-value",
|
||||
}
|
||||
|
||||
def test_matches_url_exact_host(self):
|
||||
"""Test URL matching with exact host match."""
|
||||
creds = HostScopedCredentials(
|
||||
provider="custom",
|
||||
host="api.example.com",
|
||||
headers={"Authorization": SecretStr("Bearer token")},
|
||||
)
|
||||
|
||||
assert creds.matches_url("https://api.example.com/v1/data")
|
||||
assert creds.matches_url("http://api.example.com/endpoint")
|
||||
assert not creds.matches_url("https://other.example.com/v1/data")
|
||||
assert not creds.matches_url("https://subdomain.api.example.com/v1/data")
|
||||
|
||||
def test_matches_url_wildcard_subdomain(self):
|
||||
"""Test URL matching with wildcard subdomain pattern."""
|
||||
creds = HostScopedCredentials(
|
||||
provider="custom",
|
||||
host="*.example.com",
|
||||
headers={"Authorization": SecretStr("Bearer token")},
|
||||
)
|
||||
|
||||
assert creds.matches_url("https://api.example.com/v1/data")
|
||||
assert creds.matches_url("https://subdomain.example.com/endpoint")
|
||||
assert creds.matches_url("https://deep.nested.example.com/path")
|
||||
assert creds.matches_url("https://example.com/path") # Base domain should match
|
||||
assert not creds.matches_url("https://example.org/v1/data")
|
||||
assert not creds.matches_url("https://notexample.com/v1/data")
|
||||
|
||||
def test_matches_url_with_port_and_path(self):
|
||||
"""Test URL matching with ports and paths."""
|
||||
creds = HostScopedCredentials(
|
||||
provider="custom",
|
||||
host="localhost",
|
||||
headers={"Authorization": SecretStr("Bearer token")},
|
||||
)
|
||||
|
||||
assert creds.matches_url("http://localhost:8080/api/v1")
|
||||
assert creds.matches_url("https://localhost:443/secure/endpoint")
|
||||
assert creds.matches_url("http://localhost/simple")
|
||||
|
||||
def test_empty_headers_dict(self):
|
||||
"""Test HostScopedCredentials with empty headers."""
|
||||
creds = HostScopedCredentials(
|
||||
provider="custom", host="api.example.com", headers={}
|
||||
)
|
||||
|
||||
assert creds.get_headers_dict() == {}
|
||||
assert creds.matches_url("https://api.example.com/test")
|
||||
|
||||
def test_credential_serialization(self):
|
||||
"""Test that credentials can be serialized/deserialized properly."""
|
||||
original_creds = HostScopedCredentials(
|
||||
provider="custom",
|
||||
host="api.example.com",
|
||||
headers={
|
||||
"Authorization": SecretStr("Bearer secret-token"),
|
||||
"X-API-Key": SecretStr("api-key-123"),
|
||||
},
|
||||
title="Test Credentials",
|
||||
)
|
||||
|
||||
# Serialize to dict (simulating storage)
|
||||
serialized = original_creds.model_dump()
|
||||
|
||||
# Deserialize back
|
||||
restored_creds = HostScopedCredentials.model_validate(serialized)
|
||||
|
||||
assert restored_creds.id == original_creds.id
|
||||
assert restored_creds.provider == original_creds.provider
|
||||
assert restored_creds.host == original_creds.host
|
||||
assert restored_creds.title == original_creds.title
|
||||
assert restored_creds.type == "host_scoped"
|
||||
|
||||
# Check that headers are properly restored
|
||||
assert restored_creds.get_headers_dict() == original_creds.get_headers_dict()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"host,test_url,expected",
|
||||
[
|
||||
("api.example.com", "https://api.example.com/test", True),
|
||||
("api.example.com", "https://different.example.com/test", False),
|
||||
("*.example.com", "https://api.example.com/test", True),
|
||||
("*.example.com", "https://sub.api.example.com/test", True),
|
||||
("*.example.com", "https://example.com/test", True),
|
||||
("*.example.com", "https://example.org/test", False),
|
||||
("localhost", "http://localhost:3000/test", True),
|
||||
("localhost", "http://127.0.0.1:3000/test", False),
|
||||
],
|
||||
)
|
||||
def test_url_matching_parametrized(self, host: str, test_url: str, expected: bool):
|
||||
"""Parametrized test for various URL matching scenarios."""
|
||||
creds = HostScopedCredentials(
|
||||
provider="test",
|
||||
host=host,
|
||||
headers={"Authorization": SecretStr("Bearer token")},
|
||||
)
|
||||
|
||||
assert creds.matches_url(test_url) == expected
|
||||
@@ -7,7 +7,7 @@ from pydantic import BaseModel
|
||||
from redis.asyncio.client import PubSub as AsyncPubSub
|
||||
from redis.client import PubSub
|
||||
|
||||
from backend.data import redis_client as redis
|
||||
from backend.data import redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -5,14 +5,12 @@ from backend.data import db
|
||||
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
|
||||
from backend.data.execution import (
|
||||
create_graph_execution,
|
||||
get_execution_kv_data,
|
||||
get_graph_execution,
|
||||
get_graph_execution_meta,
|
||||
get_graph_executions,
|
||||
get_latest_node_execution,
|
||||
get_node_execution,
|
||||
get_node_executions,
|
||||
set_execution_kv_data,
|
||||
update_graph_execution_start_time,
|
||||
update_graph_execution_stats,
|
||||
update_node_execution_stats,
|
||||
@@ -103,8 +101,6 @@ class DatabaseManager(AppService):
|
||||
update_node_execution_stats = _(update_node_execution_stats)
|
||||
upsert_execution_input = _(upsert_execution_input)
|
||||
upsert_execution_output = _(upsert_execution_output)
|
||||
get_execution_kv_data = _(get_execution_kv_data)
|
||||
set_execution_kv_data = _(set_execution_kv_data)
|
||||
|
||||
# Graphs
|
||||
get_node = _(get_node)
|
||||
@@ -163,8 +159,6 @@ class DatabaseManagerClient(AppServiceClient):
|
||||
update_node_execution_stats = _(d.update_node_execution_stats)
|
||||
upsert_execution_input = _(d.upsert_execution_input)
|
||||
upsert_execution_output = _(d.upsert_execution_output)
|
||||
get_execution_kv_data = _(d.get_execution_kv_data)
|
||||
set_execution_kv_data = _(d.set_execution_kv_data)
|
||||
|
||||
# Graphs
|
||||
get_node = _(d.get_node)
|
||||
@@ -208,11 +202,8 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
return DatabaseManager
|
||||
|
||||
create_graph_execution = d.create_graph_execution
|
||||
get_connected_output_nodes = d.get_connected_output_nodes
|
||||
get_latest_node_execution = d.get_latest_node_execution
|
||||
get_graph = d.get_graph
|
||||
get_graph_metadata = d.get_graph_metadata
|
||||
get_graph_execution_meta = d.get_graph_execution_meta
|
||||
get_node = d.get_node
|
||||
get_node_execution = d.get_node_execution
|
||||
get_node_executions = d.get_node_executions
|
||||
@@ -224,5 +215,3 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
update_node_execution_status = d.update_node_execution_status
|
||||
update_node_execution_status_batch = d.update_node_execution_status_batch
|
||||
update_user_integrations = d.update_user_integrations
|
||||
get_execution_kv_data = d.get_execution_kv_data
|
||||
set_execution_kv_data = d.set_execution_kv_data
|
||||
|
||||
@@ -12,11 +12,14 @@ from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
|
||||
|
||||
from pika.adapters.blocking_connection import BlockingChannel
|
||||
from pika.spec import Basic, BasicProperties
|
||||
from pydantic import JsonValue
|
||||
from redis.asyncio.lock import Lock as RedisLock
|
||||
|
||||
from backend.blocks.io import AgentOutputBlock
|
||||
from backend.data.model import GraphExecutionStats, NodeExecutionStats
|
||||
from backend.data.model import (
|
||||
CredentialsMetaInput,
|
||||
GraphExecutionStats,
|
||||
NodeExecutionStats,
|
||||
)
|
||||
from backend.data.notifications import (
|
||||
AgentRunData,
|
||||
LowBalanceData,
|
||||
@@ -24,7 +27,7 @@ from backend.data.notifications import (
|
||||
NotificationType,
|
||||
)
|
||||
from backend.data.rabbitmq import SyncRabbitMQ
|
||||
from backend.executor.utils import LogMetadata, create_execution_queue_config
|
||||
from backend.executor.utils import create_execution_queue_config
|
||||
from backend.notifications.notifications import queue_notification
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
|
||||
@@ -35,7 +38,7 @@ from autogpt_libs.utils.cache import thread_cached
|
||||
from prometheus_client import Gauge, start_http_server
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.data import redis_client as redis
|
||||
from backend.data import redis
|
||||
from backend.data.block import (
|
||||
BlockData,
|
||||
BlockInput,
|
||||
@@ -98,6 +101,35 @@ utilization_gauge = Gauge(
|
||||
)
|
||||
|
||||
|
||||
class LogMetadata(TruncatedLogger):
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str,
|
||||
graph_eid: str,
|
||||
graph_id: str,
|
||||
node_eid: str,
|
||||
node_id: str,
|
||||
block_name: str,
|
||||
max_length: int = 1000,
|
||||
):
|
||||
metadata = {
|
||||
"component": "ExecutionManager",
|
||||
"user_id": user_id,
|
||||
"graph_eid": graph_eid,
|
||||
"graph_id": graph_id,
|
||||
"node_eid": node_eid,
|
||||
"node_id": node_id,
|
||||
"block_name": block_name,
|
||||
}
|
||||
prefix = f"[ExecutionManager|uid:{user_id}|gid:{graph_id}|nid:{node_id}]|geid:{graph_eid}|neid:{node_eid}|{block_name}]"
|
||||
super().__init__(
|
||||
_logger,
|
||||
max_length=max_length,
|
||||
prefix=prefix,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@@ -106,7 +138,9 @@ async def execute_node(
|
||||
creds_manager: IntegrationCredentialsManager,
|
||||
data: NodeExecutionEntry,
|
||||
execution_stats: NodeExecutionStats | None = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = None,
|
||||
) -> BlockOutput:
|
||||
"""
|
||||
Execute a node in the graph. This will trigger a block execution on a node,
|
||||
@@ -129,7 +163,6 @@ async def execute_node(
|
||||
node_block = node.block
|
||||
|
||||
log_metadata = LogMetadata(
|
||||
logger=_logger,
|
||||
user_id=user_id,
|
||||
graph_eid=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
@@ -150,8 +183,8 @@ async def execute_node(
|
||||
if isinstance(node_block, AgentExecutorBlock):
|
||||
_input_data = AgentExecutorBlock.Input(**node.input_default)
|
||||
_input_data.inputs = input_data
|
||||
if nodes_input_masks:
|
||||
_input_data.nodes_input_masks = nodes_input_masks
|
||||
if node_credentials_input_map:
|
||||
_input_data.node_credentials_input_map = node_credentials_input_map
|
||||
input_data = _input_data.model_dump()
|
||||
data.inputs = input_data
|
||||
|
||||
@@ -222,7 +255,7 @@ async def _enqueue_next_nodes(
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
log_metadata: LogMetadata,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
|
||||
node_credentials_input_map: Optional[dict[str, dict[str, CredentialsMetaInput]]],
|
||||
) -> list[NodeExecutionEntry]:
|
||||
async def add_enqueued_execution(
|
||||
node_exec_id: str, node_id: str, block_id: str, data: BlockInput
|
||||
@@ -256,9 +289,8 @@ async def _enqueue_next_nodes(
|
||||
next_input_name = node_link.sink_name
|
||||
next_node_id = node_link.sink_id
|
||||
|
||||
output_name, _ = output
|
||||
next_data = parse_execution_output(output, next_output_name)
|
||||
if next_data is None and output_name != next_output_name:
|
||||
if next_data is None:
|
||||
return enqueued_executions
|
||||
next_node = await db_client.get_node(next_node_id)
|
||||
|
||||
@@ -293,12 +325,14 @@ async def _enqueue_next_nodes(
|
||||
for name in static_link_names:
|
||||
next_node_input[name] = latest_execution.input_data.get(name)
|
||||
|
||||
# Apply node input overrides
|
||||
node_input_mask = None
|
||||
if nodes_input_masks and (
|
||||
node_input_mask := nodes_input_masks.get(next_node.id)
|
||||
# Apply node credentials overrides
|
||||
node_credentials = None
|
||||
if node_credentials_input_map and (
|
||||
node_credentials := node_credentials_input_map.get(next_node.id)
|
||||
):
|
||||
next_node_input.update(node_input_mask)
|
||||
next_node_input.update(
|
||||
{k: v.model_dump() for k, v in node_credentials.items()}
|
||||
)
|
||||
|
||||
# Validate the input data for the next node.
|
||||
next_node_input, validation_msg = validate_exec(next_node, next_node_input)
|
||||
@@ -342,9 +376,11 @@ async def _enqueue_next_nodes(
|
||||
for input_name in static_link_names:
|
||||
idata[input_name] = next_node_input[input_name]
|
||||
|
||||
# Apply node input overrides
|
||||
if node_input_mask:
|
||||
idata.update(node_input_mask)
|
||||
# Apply node credentials overrides
|
||||
if node_credentials:
|
||||
idata.update(
|
||||
{k: v.model_dump() for k, v in node_credentials.items()}
|
||||
)
|
||||
|
||||
idata, msg = validate_exec(next_node, idata)
|
||||
suffix = f"{next_output_name}>{next_input_name}~{ineid}:{msg}"
|
||||
@@ -393,15 +429,16 @@ class Executor:
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@async_error_logged(swallow=True)
|
||||
@async_error_logged
|
||||
async def on_node_execution(
|
||||
cls,
|
||||
node_exec: NodeExecutionEntry,
|
||||
node_exec_progress: NodeExecutionProgress,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = None,
|
||||
) -> NodeExecutionStats:
|
||||
log_metadata = LogMetadata(
|
||||
logger=_logger,
|
||||
user_id=node_exec.user_id,
|
||||
graph_eid=node_exec.graph_exec_id,
|
||||
graph_id=node_exec.graph_id,
|
||||
@@ -420,7 +457,7 @@ class Executor:
|
||||
db_client=db_client,
|
||||
log_metadata=log_metadata,
|
||||
stats=execution_stats,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
node_credentials_input_map=node_credentials_input_map,
|
||||
)
|
||||
execution_stats.walltime = timing_info.wall_time
|
||||
execution_stats.cputime = timing_info.cpu_time
|
||||
@@ -443,7 +480,9 @@ class Executor:
|
||||
db_client: "DatabaseManagerAsyncClient",
|
||||
log_metadata: LogMetadata,
|
||||
stats: NodeExecutionStats | None = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = None,
|
||||
):
|
||||
try:
|
||||
log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
|
||||
@@ -458,7 +497,7 @@ class Executor:
|
||||
creds_manager=cls.creds_manager,
|
||||
data=node_exec,
|
||||
execution_stats=stats,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
node_credentials_input_map=node_credentials_input_map,
|
||||
):
|
||||
node_exec_progress.add_output(
|
||||
ExecutionOutputEntry(
|
||||
@@ -502,12 +541,11 @@ class Executor:
|
||||
logger.info(f"[GraphExecutor] {cls.pid} started")
|
||||
|
||||
@classmethod
|
||||
@error_logged(swallow=False)
|
||||
@error_logged
|
||||
def on_graph_execution(
|
||||
cls, graph_exec: GraphExecutionEntry, cancel: threading.Event
|
||||
):
|
||||
log_metadata = LogMetadata(
|
||||
logger=_logger,
|
||||
user_id=graph_exec.user_id,
|
||||
graph_eid=graph_exec.graph_exec_id,
|
||||
graph_id=graph_exec.graph_id,
|
||||
@@ -555,15 +593,6 @@ class Executor:
|
||||
exec_stats.cputime += timing_info.cpu_time
|
||||
exec_stats.error = str(error) if error else exec_stats.error
|
||||
|
||||
if status not in {
|
||||
ExecutionStatus.COMPLETED,
|
||||
ExecutionStatus.TERMINATED,
|
||||
ExecutionStatus.FAILED,
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"Graph Execution #{graph_exec.graph_exec_id} ended with unexpected status {status}"
|
||||
)
|
||||
|
||||
if graph_exec_result := db_client.update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
status=status,
|
||||
@@ -667,6 +696,7 @@ class Executor:
|
||||
|
||||
if _graph_exec := db_client.update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
status=execution_status,
|
||||
stats=execution_stats,
|
||||
):
|
||||
send_execution_update(_graph_exec)
|
||||
@@ -748,19 +778,24 @@ class Executor:
|
||||
)
|
||||
raise
|
||||
|
||||
# Add input overrides -----------------------------
|
||||
# Add credential overrides -----------------------------
|
||||
node_id = queued_node_exec.node_id
|
||||
if (nodes_input_masks := graph_exec.nodes_input_masks) and (
|
||||
node_input_mask := nodes_input_masks.get(node_id)
|
||||
if (node_creds_map := graph_exec.node_credentials_input_map) and (
|
||||
node_field_creds_map := node_creds_map.get(node_id)
|
||||
):
|
||||
queued_node_exec.inputs.update(node_input_mask)
|
||||
queued_node_exec.inputs.update(
|
||||
{
|
||||
field_name: creds_meta.model_dump()
|
||||
for field_name, creds_meta in node_field_creds_map.items()
|
||||
}
|
||||
)
|
||||
|
||||
# Kick off async node execution -------------------------
|
||||
node_execution_task = asyncio.run_coroutine_threadsafe(
|
||||
cls.on_node_execution(
|
||||
node_exec=queued_node_exec,
|
||||
node_exec_progress=running_node_execution[node_id],
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
node_credentials_input_map=node_creds_map,
|
||||
),
|
||||
cls.node_execution_loop,
|
||||
)
|
||||
@@ -804,7 +839,7 @@ class Executor:
|
||||
node_id=node_id,
|
||||
graph_exec=graph_exec,
|
||||
log_metadata=log_metadata,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
node_creds_map=node_creds_map,
|
||||
execution_queue=execution_queue,
|
||||
),
|
||||
cls.node_evaluation_loop,
|
||||
@@ -835,7 +870,6 @@ class Executor:
|
||||
f"Failed graph execution {graph_exec.graph_exec_id}: {error}"
|
||||
)
|
||||
finally:
|
||||
# Cancel and wait for all node executions to complete
|
||||
for node_id, inflight_exec in running_node_execution.items():
|
||||
if inflight_exec.is_done():
|
||||
continue
|
||||
@@ -848,28 +882,6 @@ class Executor:
|
||||
log_metadata.info(f"Stopping node evaluation {node_id}")
|
||||
inflight_eval.cancel()
|
||||
|
||||
for node_id, inflight_exec in running_node_execution.items():
|
||||
if inflight_exec.is_done():
|
||||
continue
|
||||
try:
|
||||
inflight_exec.wait_for_cancellation(timeout=60.0)
|
||||
except TimeoutError:
|
||||
log_metadata.exception(
|
||||
f"Node execution #{node_id} did not stop in time, "
|
||||
"it may be stuck or taking too long."
|
||||
)
|
||||
|
||||
for node_id, inflight_eval in running_node_evaluation.items():
|
||||
if inflight_eval.done():
|
||||
continue
|
||||
try:
|
||||
inflight_eval.result(timeout=60.0)
|
||||
except TimeoutError:
|
||||
log_metadata.exception(
|
||||
f"Node evaluation #{node_id} did not stop in time, "
|
||||
"it may be stuck or taking too long."
|
||||
)
|
||||
|
||||
if execution_status in [ExecutionStatus.TERMINATED, ExecutionStatus.FAILED]:
|
||||
inflight_executions = db_client.get_node_executions(
|
||||
graph_exec.graph_exec_id,
|
||||
@@ -897,7 +909,7 @@ class Executor:
|
||||
node_id: str,
|
||||
graph_exec: GraphExecutionEntry,
|
||||
log_metadata: LogMetadata,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
|
||||
node_creds_map: Optional[dict[str, dict[str, CredentialsMetaInput]]],
|
||||
execution_queue: ExecutionQueue[NodeExecutionEntry],
|
||||
) -> None:
|
||||
"""Process a node's output, update its status, and enqueue next nodes.
|
||||
@@ -907,7 +919,7 @@ class Executor:
|
||||
node_id: The ID of the node that produced the output
|
||||
graph_exec: The graph execution entry
|
||||
log_metadata: Logger metadata for consistent logging
|
||||
nodes_input_masks: Optional map of node input overrides
|
||||
node_creds_map: Optional map of node credentials
|
||||
execution_queue: Queue to add next executions to
|
||||
"""
|
||||
db_client = get_db_async_client()
|
||||
@@ -931,7 +943,7 @@ class Executor:
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
graph_id=graph_exec.graph_id,
|
||||
log_metadata=log_metadata,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
node_credentials_input_map=node_creds_map,
|
||||
):
|
||||
execution_queue.add(next_execution)
|
||||
except Exception as e:
|
||||
|
||||
@@ -3,7 +3,6 @@ import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
||||
|
||||
from apscheduler.events import EVENT_JOB_ERROR, EVENT_JOB_EXECUTED
|
||||
@@ -15,16 +14,13 @@ from apscheduler.triggers.cron import CronTrigger
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from dotenv import load_dotenv
|
||||
from prisma.enums import NotificationType
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from sqlalchemy import MetaData, create_engine
|
||||
|
||||
from backend.data.block import BlockInput
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.notifications.notifications import NotificationManagerClient
|
||||
from backend.util.exceptions import NotAuthorizedError, NotFoundError
|
||||
from backend.util.logging import PrefixFilter
|
||||
from backend.util.metrics import sentry_capture_error
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
@@ -56,19 +52,19 @@ def _extract_schema_from_url(database_url) -> tuple[str, str]:
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.addFilter(PrefixFilter("[Scheduler]"))
|
||||
apscheduler_logger = logger.getChild("apscheduler")
|
||||
apscheduler_logger.addFilter(PrefixFilter("[Scheduler] [APScheduler]"))
|
||||
|
||||
config = Config()
|
||||
|
||||
|
||||
def log(msg, **kwargs):
|
||||
logger.info("[Scheduler] " + msg, **kwargs)
|
||||
|
||||
|
||||
def job_listener(event):
|
||||
"""Logs job execution outcomes for better monitoring."""
|
||||
if event.exception:
|
||||
logger.error(f"Job {event.job_id} failed.")
|
||||
log(f"Job {event.job_id} failed.")
|
||||
else:
|
||||
logger.info(f"Job {event.job_id} completed successfully.")
|
||||
log(f"Job {event.job_id} completed successfully.")
|
||||
|
||||
|
||||
@thread_cached
|
||||
@@ -88,17 +84,16 @@ def execute_graph(**kwargs):
|
||||
async def _execute_graph(**kwargs):
|
||||
args = GraphExecutionJobArgs(**kwargs)
|
||||
try:
|
||||
logger.info(f"Executing recurring job for graph #{args.graph_id}")
|
||||
log(f"Executing recurring job for graph #{args.graph_id}")
|
||||
await execution_utils.add_graph_execution(
|
||||
user_id=args.user_id,
|
||||
graph_id=args.graph_id,
|
||||
graph_version=args.graph_version,
|
||||
inputs=args.input_data,
|
||||
graph_credentials_inputs=args.input_credentials,
|
||||
user_id=args.user_id,
|
||||
graph_version=args.graph_version,
|
||||
use_db_query=False,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing graph {args.graph_id}: {e}")
|
||||
logger.exception(f"Error executing graph {args.graph_id}: {e}")
|
||||
|
||||
|
||||
class LateExecutionException(Exception):
|
||||
@@ -142,20 +137,20 @@ def report_late_executions() -> str:
|
||||
def process_existing_batches(**kwargs):
|
||||
args = NotificationJobArgs(**kwargs)
|
||||
try:
|
||||
logger.info(
|
||||
log(
|
||||
f"Processing existing batches for notification type {args.notification_types}"
|
||||
)
|
||||
get_notification_client().process_existing_batches(args.notification_types)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing existing batches: {e}")
|
||||
logger.exception(f"Error processing existing batches: {e}")
|
||||
|
||||
|
||||
def process_weekly_summary(**kwargs):
|
||||
try:
|
||||
logger.info("Processing weekly summary")
|
||||
log("Processing weekly summary")
|
||||
get_notification_client().queue_weekly_summary()
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing weekly summary: {e}")
|
||||
logger.exception(f"Error processing weekly summary: {e}")
|
||||
|
||||
|
||||
class Jobstores(Enum):
|
||||
@@ -165,12 +160,11 @@ class Jobstores(Enum):
|
||||
|
||||
|
||||
class GraphExecutionJobArgs(BaseModel):
|
||||
user_id: str
|
||||
graph_id: str
|
||||
input_data: BlockInput
|
||||
user_id: str
|
||||
graph_version: int
|
||||
cron: str
|
||||
input_data: BlockInput
|
||||
input_credentials: dict[str, CredentialsMetaInput] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class GraphExecutionJobInfo(GraphExecutionJobArgs):
|
||||
@@ -253,8 +247,7 @@ class Scheduler(AppService):
|
||||
),
|
||||
# These don't really need persistence
|
||||
Jobstores.WEEKLY_NOTIFICATIONS.value: MemoryJobStore(),
|
||||
},
|
||||
logger=apscheduler_logger,
|
||||
}
|
||||
)
|
||||
|
||||
if self.register_system_tasks:
|
||||
@@ -292,40 +285,34 @@ class Scheduler(AppService):
|
||||
|
||||
def cleanup(self):
|
||||
super().cleanup()
|
||||
logger.info("⏳ Shutting down scheduler...")
|
||||
logger.info(f"[{self.service_name}] ⏳ Shutting down scheduler...")
|
||||
if self.scheduler:
|
||||
self.scheduler.shutdown(wait=False)
|
||||
|
||||
@expose
|
||||
def add_graph_execution_schedule(
|
||||
self,
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
cron: str,
|
||||
input_data: BlockInput,
|
||||
input_credentials: dict[str, CredentialsMetaInput],
|
||||
name: Optional[str] = None,
|
||||
user_id: str,
|
||||
) -> GraphExecutionJobInfo:
|
||||
job_args = GraphExecutionJobArgs(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
input_data=input_data,
|
||||
user_id=user_id,
|
||||
graph_version=graph_version,
|
||||
cron=cron,
|
||||
input_data=input_data,
|
||||
input_credentials=input_credentials,
|
||||
)
|
||||
job = self.scheduler.add_job(
|
||||
execute_graph,
|
||||
CronTrigger.from_crontab(cron),
|
||||
kwargs=job_args.model_dump(),
|
||||
name=name,
|
||||
trigger=CronTrigger.from_crontab(cron),
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
replace_existing=True,
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
)
|
||||
logger.info(
|
||||
f"Added job {job.id} with cron schedule '{cron}' input data: {input_data}"
|
||||
)
|
||||
log(f"Added job {job.id} with cron schedule '{cron}' input data: {input_data}")
|
||||
return GraphExecutionJobInfo.from_db(job_args, job)
|
||||
|
||||
@expose
|
||||
@@ -334,13 +321,14 @@ class Scheduler(AppService):
|
||||
) -> GraphExecutionJobInfo:
|
||||
job = self.scheduler.get_job(schedule_id, jobstore=Jobstores.EXECUTION.value)
|
||||
if not job:
|
||||
raise NotFoundError(f"Job #{schedule_id} not found.")
|
||||
log(f"Job {schedule_id} not found.")
|
||||
raise ValueError(f"Job #{schedule_id} not found.")
|
||||
|
||||
job_args = GraphExecutionJobArgs(**job.kwargs)
|
||||
if job_args.user_id != user_id:
|
||||
raise NotAuthorizedError("User ID does not match the job's user ID")
|
||||
raise ValueError("User ID does not match the job's user ID.")
|
||||
|
||||
logger.info(f"Deleting job {schedule_id}")
|
||||
log(f"Deleting job {schedule_id}")
|
||||
job.remove()
|
||||
|
||||
return GraphExecutionJobInfo.from_db(job_args, job)
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import Future
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, cast
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from pydantic import BaseModel, JsonValue
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockData,
|
||||
@@ -26,8 +23,12 @@ from backend.data.execution import (
|
||||
GraphExecutionStats,
|
||||
GraphExecutionWithNodes,
|
||||
RedisExecutionEventBus,
|
||||
create_graph_execution,
|
||||
get_node_executions,
|
||||
update_graph_execution_stats,
|
||||
update_node_execution_status_batch,
|
||||
)
|
||||
from backend.data.graph import GraphModel, Node
|
||||
from backend.data.graph import GraphModel, Node, get_graph
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.rabbitmq import (
|
||||
AsyncRabbitMQ,
|
||||
@@ -54,36 +55,6 @@ logger = TruncatedLogger(logging.getLogger(__name__), prefix="[GraphExecutorUtil
|
||||
# ============ Resource Helpers ============ #
|
||||
|
||||
|
||||
class LogMetadata(TruncatedLogger):
|
||||
def __init__(
|
||||
self,
|
||||
logger: logging.Logger,
|
||||
user_id: str,
|
||||
graph_eid: str,
|
||||
graph_id: str,
|
||||
node_eid: str,
|
||||
node_id: str,
|
||||
block_name: str,
|
||||
max_length: int = 1000,
|
||||
):
|
||||
metadata = {
|
||||
"component": "ExecutionManager",
|
||||
"user_id": user_id,
|
||||
"graph_eid": graph_eid,
|
||||
"graph_id": graph_id,
|
||||
"node_eid": node_eid,
|
||||
"node_id": node_id,
|
||||
"block_name": block_name,
|
||||
}
|
||||
prefix = f"[ExecutionManager|uid:{user_id}|gid:{graph_id}|nid:{node_id}]|geid:{graph_eid}|neid:{node_eid}|{block_name}]"
|
||||
super().__init__(
|
||||
logger,
|
||||
max_length=max_length,
|
||||
prefix=prefix,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_execution_event_bus() -> RedisExecutionEventBus:
|
||||
return RedisExecutionEventBus()
|
||||
@@ -431,6 +402,12 @@ def validate_exec(
|
||||
return None, f"Block for {node.block_id} not found."
|
||||
schema = node_block.input_schema
|
||||
|
||||
# Convert non-matching data types to the expected input schema.
|
||||
for name, data_type in schema.__annotations__.items():
|
||||
value = data.get(name)
|
||||
if (value is not None) and (type(value) is not data_type):
|
||||
data[name] = convert(value, data_type)
|
||||
|
||||
# Input data (without default values) should contain all required fields.
|
||||
error_prefix = f"Input data missing or mismatch for `{node_block.name}`:"
|
||||
if missing_links := schema.get_missing_links(data, node.input_links):
|
||||
@@ -442,12 +419,6 @@ def validate_exec(
|
||||
if resolve_input:
|
||||
data = merge_execution_input(data)
|
||||
|
||||
# Convert non-matching data types to the expected input schema.
|
||||
for name, data_type in schema.__annotations__.items():
|
||||
value = data.get(name)
|
||||
if (value is not None) and (type(value) is not data_type):
|
||||
data[name] = convert(value, data_type)
|
||||
|
||||
# Input data post-merge should contain all required fields from the schema.
|
||||
if missing_input := schema.get_missing_input(data):
|
||||
return None, f"{error_prefix} missing input {missing_input}"
|
||||
@@ -464,7 +435,9 @@ def validate_exec(
|
||||
async def _validate_node_input_credentials(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = None,
|
||||
):
|
||||
"""Checks all credentials for all nodes of the graph"""
|
||||
|
||||
@@ -480,13 +453,11 @@ async def _validate_node_input_credentials(
|
||||
|
||||
for field_name, credentials_meta_type in credentials_fields.items():
|
||||
if (
|
||||
nodes_input_masks
|
||||
and (node_input_mask := nodes_input_masks.get(node.id))
|
||||
and field_name in node_input_mask
|
||||
node_credentials_input_map
|
||||
and (node_credentials_inputs := node_credentials_input_map.get(node.id))
|
||||
and field_name in node_credentials_inputs
|
||||
):
|
||||
credentials_meta = credentials_meta_type.model_validate(
|
||||
node_input_mask[field_name]
|
||||
)
|
||||
credentials_meta = node_credentials_input_map[node.id][field_name]
|
||||
elif field_name in node.input_default:
|
||||
credentials_meta = credentials_meta_type.model_validate(
|
||||
node.input_default[field_name]
|
||||
@@ -525,7 +496,7 @@ async def _validate_node_input_credentials(
|
||||
def make_node_credentials_input_map(
|
||||
graph: GraphModel,
|
||||
graph_credentials_input: dict[str, CredentialsMetaInput],
|
||||
) -> dict[str, dict[str, JsonValue]]:
|
||||
) -> dict[str, dict[str, CredentialsMetaInput]]:
|
||||
"""
|
||||
Maps credentials for an execution to the correct nodes.
|
||||
|
||||
@@ -534,9 +505,9 @@ def make_node_credentials_input_map(
|
||||
graph_credentials_input: A (graph_input_name, credentials_meta) map.
|
||||
|
||||
Returns:
|
||||
dict[node_id, dict[field_name, CredentialsMetaRaw]]: Node credentials input map.
|
||||
dict[node_id, dict[field_name, CredentialsMetaInput]]: Node credentials input map.
|
||||
"""
|
||||
result: dict[str, dict[str, JsonValue]] = {}
|
||||
result: dict[str, dict[str, CredentialsMetaInput]] = {}
|
||||
|
||||
# Get aggregated credentials fields for the graph
|
||||
graph_cred_inputs = graph.aggregate_credentials_inputs()
|
||||
@@ -550,9 +521,7 @@ def make_node_credentials_input_map(
|
||||
for node_id, node_field_name in compatible_node_fields:
|
||||
if node_id not in result:
|
||||
result[node_id] = {}
|
||||
result[node_id][node_field_name] = graph_credentials_input[
|
||||
graph_input_name
|
||||
].model_dump(exclude_none=True)
|
||||
result[node_id][node_field_name] = graph_credentials_input[graph_input_name]
|
||||
|
||||
return result
|
||||
|
||||
@@ -561,7 +530,9 @@ async def construct_node_execution_input(
|
||||
graph: GraphModel,
|
||||
user_id: str,
|
||||
graph_inputs: BlockInput,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = None,
|
||||
) -> list[tuple[str, BlockInput]]:
|
||||
"""
|
||||
Validates and prepares the input data for executing a graph.
|
||||
@@ -579,8 +550,8 @@ async def construct_node_execution_input(
|
||||
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID and
|
||||
the corresponding input data for that node.
|
||||
"""
|
||||
graph.validate_graph(for_run=True, nodes_input_masks=nodes_input_masks)
|
||||
await _validate_node_input_credentials(graph, user_id, nodes_input_masks)
|
||||
graph.validate_graph(for_run=True)
|
||||
await _validate_node_input_credentials(graph, user_id, node_credentials_input_map)
|
||||
|
||||
nodes_input = []
|
||||
for node in graph.starting_nodes:
|
||||
@@ -597,9 +568,23 @@ async def construct_node_execution_input(
|
||||
if input_name and input_name in graph_inputs:
|
||||
input_data = {"value": graph_inputs[input_name]}
|
||||
|
||||
# Apply node input overrides
|
||||
if nodes_input_masks and (node_input_mask := nodes_input_masks.get(node.id)):
|
||||
input_data.update(node_input_mask)
|
||||
# Extract webhook payload, and assign it to the input pin
|
||||
webhook_payload_key = f"webhook_{node.webhook_id}_payload"
|
||||
if (
|
||||
block.block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
|
||||
and node.webhook_id
|
||||
):
|
||||
if webhook_payload_key not in graph_inputs:
|
||||
raise ValueError(
|
||||
f"Node {block.name} #{node.id} webhook payload is missing"
|
||||
)
|
||||
input_data = {"payload": graph_inputs[webhook_payload_key]}
|
||||
|
||||
# Apply node credentials overrides
|
||||
if node_credentials_input_map and (
|
||||
node_credentials := node_credentials_input_map.get(node.id)
|
||||
):
|
||||
input_data.update({k: v.model_dump() for k, v in node_credentials.items()})
|
||||
|
||||
input_data, error = validate_exec(node, input_data)
|
||||
if input_data is None:
|
||||
@@ -615,20 +600,6 @@ async def construct_node_execution_input(
|
||||
return nodes_input
|
||||
|
||||
|
||||
def _merge_nodes_input_masks(
|
||||
overrides_map_1: dict[str, dict[str, JsonValue]],
|
||||
overrides_map_2: dict[str, dict[str, JsonValue]],
|
||||
) -> dict[str, dict[str, JsonValue]]:
|
||||
"""Perform a per-node merge of input overrides"""
|
||||
result = overrides_map_1.copy()
|
||||
for node_id, overrides2 in overrides_map_2.items():
|
||||
if node_id in result:
|
||||
result[node_id] = {**result[node_id], **overrides2}
|
||||
else:
|
||||
result[node_id] = overrides2
|
||||
return result
|
||||
|
||||
|
||||
# ============ Execution Queue Helpers ============ #
|
||||
|
||||
|
||||
@@ -682,10 +653,8 @@ def create_execution_queue_config() -> RabbitMQConfig:
|
||||
|
||||
|
||||
async def stop_graph_execution(
|
||||
user_id: str,
|
||||
graph_exec_id: str,
|
||||
use_db_query: bool = True,
|
||||
wait_timeout: float = 60.0,
|
||||
):
|
||||
"""
|
||||
Mechanism:
|
||||
@@ -695,67 +664,79 @@ async def stop_graph_execution(
|
||||
3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
|
||||
"""
|
||||
queue_client = await get_async_execution_queue()
|
||||
db = execution_db if use_db_query else get_db_async_client()
|
||||
await queue_client.publish_message(
|
||||
routing_key="",
|
||||
message=CancelExecutionEvent(graph_exec_id=graph_exec_id).model_dump_json(),
|
||||
exchange=GRAPH_EXECUTION_CANCEL_EXCHANGE,
|
||||
)
|
||||
|
||||
if not wait_timeout:
|
||||
return
|
||||
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < wait_timeout:
|
||||
graph_exec = await db.get_graph_execution_meta(
|
||||
execution_id=graph_exec_id, user_id=user_id
|
||||
# Update the status of the graph execution
|
||||
if use_db_query:
|
||||
graph_execution = await update_graph_execution_stats(
|
||||
graph_exec_id,
|
||||
ExecutionStatus.TERMINATED,
|
||||
)
|
||||
else:
|
||||
graph_execution = await get_db_async_client().update_graph_execution_stats(
|
||||
graph_exec_id,
|
||||
ExecutionStatus.TERMINATED,
|
||||
)
|
||||
|
||||
if not graph_exec:
|
||||
raise NotFoundError(f"Graph execution #{graph_exec_id} not found.")
|
||||
if graph_execution:
|
||||
await get_async_execution_event_bus().publish(graph_execution)
|
||||
else:
|
||||
raise NotFoundError(
|
||||
f"Graph execution #{graph_exec_id} not found for termination."
|
||||
)
|
||||
|
||||
if graph_exec.status in [
|
||||
# Update the status of the node executions
|
||||
if use_db_query:
|
||||
node_executions = await get_node_executions(
|
||||
graph_exec_id=graph_exec_id,
|
||||
statuses=[
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.RUNNING,
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
],
|
||||
)
|
||||
await update_node_execution_status_batch(
|
||||
[v.node_exec_id for v in node_executions],
|
||||
ExecutionStatus.TERMINATED,
|
||||
ExecutionStatus.COMPLETED,
|
||||
ExecutionStatus.FAILED,
|
||||
]:
|
||||
# If graph execution is terminated/completed/failed, cancellation is complete
|
||||
return
|
||||
)
|
||||
else:
|
||||
node_executions = await get_db_async_client().get_node_executions(
|
||||
graph_exec_id=graph_exec_id,
|
||||
statuses=[
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.RUNNING,
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
],
|
||||
)
|
||||
await get_db_async_client().update_node_execution_status_batch(
|
||||
[v.node_exec_id for v in node_executions],
|
||||
ExecutionStatus.TERMINATED,
|
||||
)
|
||||
|
||||
elif graph_exec.status in [
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
]:
|
||||
# If the graph is still on the queue, we can prevent them from being executed
|
||||
# by setting the status to TERMINATED.
|
||||
node_execs = await db.get_node_executions(
|
||||
graph_exec_id=graph_exec_id,
|
||||
statuses=[ExecutionStatus.QUEUED, ExecutionStatus.INCOMPLETE],
|
||||
await asyncio.gather(
|
||||
*[
|
||||
get_async_execution_event_bus().publish(
|
||||
v.model_copy(update={"status": ExecutionStatus.TERMINATED})
|
||||
)
|
||||
await db.update_node_execution_status_batch(
|
||||
[node_exec.node_exec_id for node_exec in node_execs],
|
||||
ExecutionStatus.TERMINATED,
|
||||
)
|
||||
await db.update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec_id,
|
||||
status=ExecutionStatus.TERMINATED,
|
||||
)
|
||||
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
raise TimeoutError(
|
||||
f"Timed out waiting for graph execution #{graph_exec_id} to terminate."
|
||||
for v in node_executions
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
async def add_graph_execution(
|
||||
graph_id: str,
|
||||
user_id: str,
|
||||
inputs: Optional[BlockInput] = None,
|
||||
inputs: BlockInput,
|
||||
preset_id: Optional[str] = None,
|
||||
graph_version: Optional[int] = None,
|
||||
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
node_credentials_input_map: Optional[
|
||||
dict[str, dict[str, CredentialsMetaInput]]
|
||||
] = None,
|
||||
use_db_query: bool = True,
|
||||
) -> GraphExecutionWithNodes:
|
||||
"""
|
||||
@@ -769,52 +750,68 @@ async def add_graph_execution(
|
||||
graph_version: The version of the graph to execute.
|
||||
graph_credentials_inputs: Credentials inputs to use in the execution.
|
||||
Keys should map to the keys generated by `GraphModel.aggregate_credentials_inputs`.
|
||||
nodes_input_masks: Node inputs to use in the execution.
|
||||
node_credentials_input_map: Credentials inputs to use in the execution, mapped to specific nodes.
|
||||
Returns:
|
||||
GraphExecutionEntry: The entry for the graph execution.
|
||||
Raises:
|
||||
ValueError: If the graph is not found or if there are validation errors.
|
||||
"""
|
||||
gdb = graph_db if use_db_query else get_db_async_client()
|
||||
edb = execution_db if use_db_query else get_db_async_client()
|
||||
""" # noqa
|
||||
if use_db_query:
|
||||
graph: GraphModel | None = await get_graph(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
version=graph_version,
|
||||
include_subgraphs=True,
|
||||
)
|
||||
else:
|
||||
graph: GraphModel | None = await get_db_async_client().get_graph(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
version=graph_version,
|
||||
include_subgraphs=True,
|
||||
)
|
||||
|
||||
graph: GraphModel | None = await gdb.get_graph(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
version=graph_version,
|
||||
include_subgraphs=True,
|
||||
)
|
||||
if not graph:
|
||||
raise NotFoundError(f"Graph #{graph_id} not found.")
|
||||
|
||||
nodes_input_masks = _merge_nodes_input_masks(
|
||||
(
|
||||
make_node_credentials_input_map(graph, graph_credentials_inputs)
|
||||
if graph_credentials_inputs
|
||||
else {}
|
||||
),
|
||||
nodes_input_masks or {},
|
||||
)
|
||||
starting_nodes_input = await construct_node_execution_input(
|
||||
graph=graph,
|
||||
user_id=user_id,
|
||||
graph_inputs=inputs or {},
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
node_credentials_input_map = node_credentials_input_map or (
|
||||
make_node_credentials_input_map(graph, graph_credentials_inputs)
|
||||
if graph_credentials_inputs
|
||||
else None
|
||||
)
|
||||
|
||||
graph_exec = await edb.create_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
starting_nodes_input=starting_nodes_input,
|
||||
preset_id=preset_id,
|
||||
)
|
||||
if use_db_query:
|
||||
graph_exec = await create_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
starting_nodes_input=await construct_node_execution_input(
|
||||
graph=graph,
|
||||
user_id=user_id,
|
||||
graph_inputs=inputs,
|
||||
node_credentials_input_map=node_credentials_input_map,
|
||||
),
|
||||
preset_id=preset_id,
|
||||
)
|
||||
else:
|
||||
graph_exec = await get_db_async_client().create_graph_execution(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
starting_nodes_input=await construct_node_execution_input(
|
||||
graph=graph,
|
||||
user_id=user_id,
|
||||
graph_inputs=inputs,
|
||||
node_credentials_input_map=node_credentials_input_map,
|
||||
),
|
||||
preset_id=preset_id,
|
||||
)
|
||||
|
||||
try:
|
||||
queue = await get_async_execution_queue()
|
||||
graph_exec_entry = graph_exec.to_graph_execution_entry()
|
||||
if nodes_input_masks:
|
||||
graph_exec_entry.nodes_input_masks = nodes_input_masks
|
||||
if node_credentials_input_map:
|
||||
graph_exec_entry.node_credentials_input_map = node_credentials_input_map
|
||||
await queue.publish_message(
|
||||
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
|
||||
message=graph_exec_entry.model_dump_json(),
|
||||
@@ -827,15 +824,28 @@ async def add_graph_execution(
|
||||
return graph_exec
|
||||
except Exception as e:
|
||||
logger.error(f"Unable to publish graph #{graph_id} exec #{graph_exec.id}: {e}")
|
||||
await edb.update_node_execution_status_batch(
|
||||
[node_exec.node_exec_id for node_exec in graph_exec.node_executions],
|
||||
ExecutionStatus.FAILED,
|
||||
)
|
||||
await edb.update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec.id,
|
||||
status=ExecutionStatus.FAILED,
|
||||
stats=GraphExecutionStats(error=str(e)),
|
||||
)
|
||||
|
||||
if use_db_query:
|
||||
await update_node_execution_status_batch(
|
||||
[node_exec.node_exec_id for node_exec in graph_exec.node_executions],
|
||||
ExecutionStatus.FAILED,
|
||||
)
|
||||
await update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec.id,
|
||||
status=ExecutionStatus.FAILED,
|
||||
stats=GraphExecutionStats(error=str(e)),
|
||||
)
|
||||
else:
|
||||
await get_db_async_client().update_node_execution_status_batch(
|
||||
[node_exec.node_exec_id for node_exec in graph_exec.node_executions],
|
||||
ExecutionStatus.FAILED,
|
||||
)
|
||||
await get_db_async_client().update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec.id,
|
||||
status=ExecutionStatus.FAILED,
|
||||
stats=GraphExecutionStats(error=str(e)),
|
||||
)
|
||||
|
||||
raise
|
||||
|
||||
|
||||
@@ -890,10 +900,14 @@ class NodeExecutionProgress:
|
||||
try:
|
||||
self.tasks[exec_id].result(wait_time)
|
||||
except TimeoutError:
|
||||
print(
|
||||
">>>>>>> -- Timeout, after waiting for",
|
||||
wait_time,
|
||||
"seconds for node_id",
|
||||
exec_id,
|
||||
)
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Task for exec ID {exec_id} failed with error: {str(e)}")
|
||||
pass
|
||||
|
||||
return self.is_done(0)
|
||||
|
||||
def stop(self) -> list[str]:
|
||||
@@ -910,25 +924,6 @@ class NodeExecutionProgress:
|
||||
cancelled_ids.append(task_id)
|
||||
return cancelled_ids
|
||||
|
||||
def wait_for_cancellation(self, timeout: float = 5.0):
|
||||
"""
|
||||
Wait for all cancelled tasks to complete cancellation.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait for cancellation in seconds
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
# Check if all tasks are done (either completed or cancelled)
|
||||
if all(task.done() for task in self.tasks.values()):
|
||||
return True
|
||||
time.sleep(0.1) # Small delay to avoid busy waiting
|
||||
|
||||
raise TimeoutError(
|
||||
f"Timeout waiting for cancellation of tasks: {list(self.tasks.keys())}"
|
||||
)
|
||||
|
||||
def _pop_done_task(self, exec_id: str) -> bool:
|
||||
task = self.tasks.get(exec_id)
|
||||
if not task:
|
||||
@@ -941,10 +936,8 @@ class NodeExecutionProgress:
|
||||
return False
|
||||
|
||||
if task := self.tasks.pop(exec_id):
|
||||
try:
|
||||
self.on_done_task(exec_id, task.result())
|
||||
except Exception as e:
|
||||
logger.error(f"Task for exec ID {exec_id} failed with error: {str(e)}")
|
||||
self.on_done_task(exec_id, task.result())
|
||||
|
||||
return True
|
||||
|
||||
def _next_exec(self) -> str | None:
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.redis import get_redis_async
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.executor.database import DatabaseManagerAsyncClient
|
||||
|
||||
@@ -7,7 +7,7 @@ from autogpt_libs.utils.synchronize import AsyncRedisKeyedMutex
|
||||
from redis.asyncio.lock import Lock as AsyncRedisLock
|
||||
|
||||
from backend.data.model import Credentials, OAuth2Credentials
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.data.redis import get_redis_async
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
from backend.integrations.oauth import HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
@@ -17,7 +17,6 @@ class ProviderName(str, Enum):
|
||||
GOOGLE = "google"
|
||||
GOOGLE_MAPS = "google_maps"
|
||||
GROQ = "groq"
|
||||
HTTP = "http"
|
||||
HUBSPOT = "hubspot"
|
||||
IDEOGRAM = "ideogram"
|
||||
JINA = "jina"
|
||||
|
||||
@@ -1,22 +1,23 @@
|
||||
import functools
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..providers import ProviderName
|
||||
from ._base import BaseWebhooksManager
|
||||
|
||||
_WEBHOOK_MANAGERS: dict["ProviderName", type["BaseWebhooksManager"]] = {}
|
||||
|
||||
|
||||
# --8<-- [start:load_webhook_managers]
|
||||
@functools.cache
|
||||
def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]:
|
||||
webhook_managers = {}
|
||||
if _WEBHOOK_MANAGERS:
|
||||
return _WEBHOOK_MANAGERS
|
||||
|
||||
from .compass import CompassWebhookManager
|
||||
from .generic import GenericWebhooksManager
|
||||
from .github import GithubWebhooksManager
|
||||
from .slant3d import Slant3DWebhooksManager
|
||||
|
||||
webhook_managers.update(
|
||||
_WEBHOOK_MANAGERS.update(
|
||||
{
|
||||
handler.PROVIDER_NAME: handler
|
||||
for handler in [
|
||||
@@ -27,7 +28,7 @@ def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]
|
||||
]
|
||||
}
|
||||
)
|
||||
return webhook_managers
|
||||
return _WEBHOOK_MANAGERS
|
||||
|
||||
|
||||
# --8<-- [end:load_webhook_managers]
|
||||
|
||||
@@ -7,14 +7,13 @@ from uuid import uuid4
|
||||
from fastapi import Request
|
||||
from strenum import StrEnum
|
||||
|
||||
import backend.data.integrations as integrations
|
||||
from backend.data import integrations
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks.utils import webhook_ingress_url
|
||||
from backend.util.exceptions import MissingConfigError
|
||||
from backend.util.settings import Config
|
||||
|
||||
from .utils import webhook_ingress_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
app_config = Config()
|
||||
|
||||
@@ -42,74 +41,44 @@ class BaseWebhooksManager(ABC, Generic[WT]):
|
||||
)
|
||||
|
||||
if webhook := await integrations.find_webhook_by_credentials_and_props(
|
||||
user_id=user_id,
|
||||
credentials_id=credentials.id,
|
||||
webhook_type=webhook_type,
|
||||
resource=resource,
|
||||
events=events,
|
||||
credentials.id, webhook_type, resource, events
|
||||
):
|
||||
return webhook
|
||||
|
||||
return await self._create_webhook(
|
||||
user_id=user_id,
|
||||
webhook_type=webhook_type,
|
||||
events=events,
|
||||
resource=resource,
|
||||
credentials=credentials,
|
||||
user_id, webhook_type, events, resource, credentials
|
||||
)
|
||||
|
||||
async def get_manual_webhook(
|
||||
self,
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
webhook_type: WT,
|
||||
events: list[str],
|
||||
graph_id: Optional[str] = None,
|
||||
preset_id: Optional[str] = None,
|
||||
) -> integrations.Webhook:
|
||||
"""
|
||||
Tries to find an existing webhook tied to `graph_id`/`preset_id`,
|
||||
or creates a new webhook if none exists.
|
||||
|
||||
Existing webhooks are matched by `user_id`, `webhook_type`,
|
||||
and `graph_id`/`preset_id`.
|
||||
|
||||
If an existing webhook is found, we check if the events match and update them
|
||||
if necessary. We do this rather than creating a new webhook
|
||||
to avoid changing the webhook URL for existing manual webhooks.
|
||||
"""
|
||||
if (graph_id or preset_id) and (
|
||||
current_webhook := await integrations.find_webhook_by_graph_and_props(
|
||||
user_id=user_id,
|
||||
provider=self.PROVIDER_NAME.value,
|
||||
webhook_type=webhook_type,
|
||||
graph_id=graph_id,
|
||||
preset_id=preset_id,
|
||||
)
|
||||
):
|
||||
if current_webhook := await integrations.find_webhook_by_graph_and_props(
|
||||
graph_id, self.PROVIDER_NAME, webhook_type, events
|
||||
):
|
||||
if set(current_webhook.events) != set(events):
|
||||
current_webhook = await integrations.update_webhook(
|
||||
current_webhook.id, events=events
|
||||
)
|
||||
return current_webhook
|
||||
|
||||
return await self._create_webhook(
|
||||
user_id=user_id,
|
||||
webhook_type=webhook_type,
|
||||
events=events,
|
||||
user_id,
|
||||
webhook_type,
|
||||
events,
|
||||
register=False,
|
||||
)
|
||||
|
||||
async def prune_webhook_if_dangling(
|
||||
self, user_id: str, webhook_id: str, credentials: Optional[Credentials]
|
||||
self, webhook_id: str, credentials: Optional[Credentials]
|
||||
) -> bool:
|
||||
webhook = await integrations.get_webhook(webhook_id, include_relations=True)
|
||||
if webhook.triggered_nodes or webhook.triggered_presets:
|
||||
webhook = await integrations.get_webhook(webhook_id)
|
||||
if webhook.attached_nodes is None:
|
||||
raise ValueError("Error retrieving webhook including attached nodes")
|
||||
if webhook.attached_nodes:
|
||||
# Don't prune webhook if in use
|
||||
return False
|
||||
|
||||
if credentials:
|
||||
await self._deregister_webhook(webhook, credentials)
|
||||
await integrations.delete_webhook(user_id, webhook.id)
|
||||
await integrations.delete_webhook(webhook.id)
|
||||
return True
|
||||
|
||||
# --8<-- [start:BaseWebhooksManager3]
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Optional, cast
|
||||
|
||||
from backend.data.block import BlockSchema
|
||||
from backend.data.block import BlockSchema, BlockWebhookConfig
|
||||
from backend.data.graph import set_node_webhook
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
|
||||
from . import get_webhook_manager, supports_webhooks
|
||||
from .utils import setup_webhook_for_block
|
||||
from backend.integrations.webhooks import get_webhook_manager, supports_webhooks
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.graph import GraphModel, NodeModel
|
||||
@@ -83,9 +81,7 @@ async def on_graph_deactivate(graph: "GraphModel", user_id: str):
|
||||
f"credentials #{creds_meta['id']}"
|
||||
)
|
||||
|
||||
updated_node = await on_node_deactivate(
|
||||
user_id, node, credentials=node_credentials
|
||||
)
|
||||
updated_node = await on_node_deactivate(node, credentials=node_credentials)
|
||||
updated_nodes.append(updated_node)
|
||||
|
||||
graph.nodes = updated_nodes
|
||||
@@ -100,25 +96,105 @@ async def on_node_activate(
|
||||
) -> "NodeModel":
|
||||
"""Hook to be called when the node is activated/created"""
|
||||
|
||||
if node.block.webhook_config:
|
||||
new_webhook, feedback = await setup_webhook_for_block(
|
||||
user_id=user_id,
|
||||
trigger_block=node.block,
|
||||
trigger_config=node.input_default,
|
||||
for_graph_id=node.graph_id,
|
||||
block = node.block
|
||||
|
||||
if not block.webhook_config:
|
||||
return node
|
||||
|
||||
provider = block.webhook_config.provider
|
||||
if not supports_webhooks(provider):
|
||||
raise ValueError(
|
||||
f"Block #{block.id} has webhook_config for provider {provider} "
|
||||
"which does not support webhooks"
|
||||
)
|
||||
if new_webhook:
|
||||
node = await set_node_webhook(node.id, new_webhook.id)
|
||||
else:
|
||||
logger.debug(
|
||||
f"Node #{node.id} does not have everything for a webhook: {feedback}"
|
||||
|
||||
logger.debug(
|
||||
f"Activating webhook node #{node.id} with config {block.webhook_config}"
|
||||
)
|
||||
|
||||
webhooks_manager = get_webhook_manager(provider)
|
||||
|
||||
if auto_setup_webhook := isinstance(block.webhook_config, BlockWebhookConfig):
|
||||
try:
|
||||
resource = block.webhook_config.resource_format.format(**node.input_default)
|
||||
except KeyError:
|
||||
resource = None
|
||||
logger.debug(
|
||||
f"Constructed resource string {resource} from input {node.input_default}"
|
||||
)
|
||||
else:
|
||||
resource = "" # not relevant for manual webhooks
|
||||
|
||||
block_input_schema = cast(BlockSchema, block.input_schema)
|
||||
credentials_field_name = next(iter(block_input_schema.get_credentials_fields()), "")
|
||||
credentials_meta = (
|
||||
node.input_default.get(credentials_field_name)
|
||||
if credentials_field_name
|
||||
else None
|
||||
)
|
||||
event_filter_input_name = block.webhook_config.event_filter_input
|
||||
has_everything_for_webhook = (
|
||||
resource is not None
|
||||
and (credentials_meta or not credentials_field_name)
|
||||
and (
|
||||
not event_filter_input_name
|
||||
or (
|
||||
event_filter_input_name in node.input_default
|
||||
and any(
|
||||
is_on
|
||||
for is_on in node.input_default[event_filter_input_name].values()
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if has_everything_for_webhook and resource is not None:
|
||||
logger.debug(f"Node #{node} has everything for a webhook!")
|
||||
if credentials_meta and not credentials:
|
||||
raise ValueError(
|
||||
f"Cannot set up webhook for node #{node.id}: "
|
||||
f"credentials #{credentials_meta['id']} not available"
|
||||
)
|
||||
|
||||
if event_filter_input_name:
|
||||
# Shape of the event filter is enforced in Block.__init__
|
||||
event_filter = cast(dict, node.input_default[event_filter_input_name])
|
||||
events = [
|
||||
block.webhook_config.event_format.format(event=event)
|
||||
for event, enabled in event_filter.items()
|
||||
if enabled is True
|
||||
]
|
||||
logger.debug(f"Webhook events to subscribe to: {', '.join(events)}")
|
||||
else:
|
||||
events = []
|
||||
|
||||
# Find/make and attach a suitable webhook to the node
|
||||
if auto_setup_webhook:
|
||||
assert credentials is not None
|
||||
new_webhook = await webhooks_manager.get_suitable_auto_webhook(
|
||||
user_id,
|
||||
credentials,
|
||||
block.webhook_config.webhook_type,
|
||||
resource,
|
||||
events,
|
||||
)
|
||||
else:
|
||||
# Manual webhook -> no credentials -> don't register but do create
|
||||
new_webhook = await webhooks_manager.get_manual_webhook(
|
||||
user_id,
|
||||
node.graph_id,
|
||||
block.webhook_config.webhook_type,
|
||||
events,
|
||||
)
|
||||
logger.debug(f"Acquired webhook: {new_webhook}")
|
||||
return await set_node_webhook(node.id, new_webhook.id)
|
||||
else:
|
||||
logger.debug(f"Node #{node.id} does not have everything for a webhook")
|
||||
|
||||
return node
|
||||
|
||||
|
||||
async def on_node_deactivate(
|
||||
user_id: str,
|
||||
node: "NodeModel",
|
||||
*,
|
||||
credentials: Optional["Credentials"] = None,
|
||||
@@ -157,9 +233,7 @@ async def on_node_deactivate(
|
||||
f"Pruning{' and deregistering' if credentials else ''} "
|
||||
f"webhook #{webhook.id}"
|
||||
)
|
||||
await webhooks_manager.prune_webhook_if_dangling(
|
||||
user_id, webhook.id, credentials
|
||||
)
|
||||
await webhooks_manager.prune_webhook_if_dangling(webhook.id, credentials)
|
||||
if (
|
||||
cast(BlockSchema, block.input_schema).get_credentials_fields()
|
||||
and not credentials
|
||||
|
||||
@@ -1,22 +1,7 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Optional, cast
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.settings import Config
|
||||
|
||||
from . import get_webhook_manager, supports_webhooks
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.block import Block, BlockSchema
|
||||
from backend.data.integrations import Webhook
|
||||
from backend.data.model import Credentials
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
app_config = Config()
|
||||
credentials_manager = IntegrationCredentialsManager()
|
||||
|
||||
|
||||
# TODO: add test to assert this matches the actual API route
|
||||
@@ -25,122 +10,3 @@ def webhook_ingress_url(provider_name: ProviderName, webhook_id: str) -> str:
|
||||
f"{app_config.platform_base_url}/api/integrations/{provider_name.value}"
|
||||
f"/webhooks/{webhook_id}/ingress"
|
||||
)
|
||||
|
||||
|
||||
async def setup_webhook_for_block(
|
||||
user_id: str,
|
||||
trigger_block: "Block[BlockSchema, BlockSchema]",
|
||||
trigger_config: dict[str, JsonValue], # = Trigger block inputs
|
||||
for_graph_id: Optional[str] = None,
|
||||
for_preset_id: Optional[str] = None,
|
||||
credentials: Optional["Credentials"] = None,
|
||||
) -> tuple["Webhook", None] | tuple[None, str]:
|
||||
"""
|
||||
Utility function to create (and auto-setup if possible) a webhook for a given provider.
|
||||
|
||||
Returns:
|
||||
Webhook: The created or found webhook object, if successful.
|
||||
str: A feedback message, if any required inputs are missing.
|
||||
"""
|
||||
from backend.data.block import BlockWebhookConfig
|
||||
|
||||
if not (trigger_base_config := trigger_block.webhook_config):
|
||||
raise ValueError(f"Block #{trigger_block.id} does not have a webhook_config")
|
||||
|
||||
provider = trigger_base_config.provider
|
||||
if not supports_webhooks(provider):
|
||||
raise NotImplementedError(
|
||||
f"Block #{trigger_block.id} has webhook_config for provider {provider} "
|
||||
"for which we do not have a WebhooksManager"
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Setting up webhook for block #{trigger_block.id} with config {trigger_config}"
|
||||
)
|
||||
|
||||
# Check & parse the event filter input, if any
|
||||
events: list[str] = []
|
||||
if event_filter_input_name := trigger_base_config.event_filter_input:
|
||||
if not (event_filter := trigger_config.get(event_filter_input_name)):
|
||||
return None, (
|
||||
f"Cannot set up {provider.value} webhook without event filter input: "
|
||||
f"missing input for '{event_filter_input_name}'"
|
||||
)
|
||||
elif not (
|
||||
# Shape of the event filter is enforced in Block.__init__
|
||||
any((event_filter := cast(dict[str, bool], event_filter)).values())
|
||||
):
|
||||
return None, (
|
||||
f"Cannot set up {provider.value} webhook without any enabled events "
|
||||
f"in event filter input '{event_filter_input_name}'"
|
||||
)
|
||||
|
||||
events = [
|
||||
trigger_base_config.event_format.format(event=event)
|
||||
for event, enabled in event_filter.items()
|
||||
if enabled is True
|
||||
]
|
||||
logger.debug(f"Webhook events to subscribe to: {', '.join(events)}")
|
||||
|
||||
# Check & process prerequisites for auto-setup webhooks
|
||||
if auto_setup_webhook := isinstance(trigger_base_config, BlockWebhookConfig):
|
||||
try:
|
||||
resource = trigger_base_config.resource_format.format(**trigger_config)
|
||||
except KeyError as missing_key:
|
||||
return None, (
|
||||
f"Cannot auto-setup {provider.value} webhook without resource: "
|
||||
f"missing input for '{missing_key}'"
|
||||
)
|
||||
logger.debug(
|
||||
f"Constructed resource string {resource} from input {trigger_config}"
|
||||
)
|
||||
|
||||
creds_field_name = next(
|
||||
# presence of this field is enforced in Block.__init__
|
||||
iter(trigger_block.input_schema.get_credentials_fields())
|
||||
)
|
||||
|
||||
if not (
|
||||
credentials_meta := cast(dict, trigger_config.get(creds_field_name, None))
|
||||
):
|
||||
return None, f"Cannot set up {provider.value} webhook without credentials"
|
||||
elif not (
|
||||
credentials := credentials
|
||||
or await credentials_manager.get(user_id, credentials_meta["id"])
|
||||
):
|
||||
raise ValueError(
|
||||
f"Cannot set up {provider.value} webhook without credentials: "
|
||||
f"credentials #{credentials_meta['id']} not found for user #{user_id}"
|
||||
)
|
||||
elif credentials.provider != provider:
|
||||
raise ValueError(
|
||||
f"Credentials #{credentials.id} do not match provider {provider.value}"
|
||||
)
|
||||
else:
|
||||
# not relevant for manual webhooks:
|
||||
resource = ""
|
||||
credentials = None
|
||||
|
||||
webhooks_manager = get_webhook_manager(provider)
|
||||
|
||||
# Find/make and attach a suitable webhook to the node
|
||||
if auto_setup_webhook:
|
||||
assert credentials is not None
|
||||
webhook = await webhooks_manager.get_suitable_auto_webhook(
|
||||
user_id=user_id,
|
||||
credentials=credentials,
|
||||
webhook_type=trigger_base_config.webhook_type,
|
||||
resource=resource,
|
||||
events=events,
|
||||
)
|
||||
else:
|
||||
# Manual webhook -> no credentials -> don't register but do create
|
||||
webhook = await webhooks_manager.get_manual_webhook(
|
||||
user_id=user_id,
|
||||
webhook_type=trigger_base_config.webhook_type,
|
||||
events=events,
|
||||
graph_id=for_graph_id,
|
||||
preset_id=for_preset_id,
|
||||
)
|
||||
logger.debug(f"Acquired webhook: {webhook}")
|
||||
return webhook, None
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from fastapi import FastAPI
|
||||
|
||||
from backend.server.middleware.security import SecurityHeadersMiddleware
|
||||
|
||||
from .routes.v1 import v1_router
|
||||
|
||||
external_app = FastAPI(
|
||||
@@ -10,6 +8,4 @@ external_app = FastAPI(
|
||||
docs_url="/docs",
|
||||
version="1.0",
|
||||
)
|
||||
|
||||
external_app.add_middleware(SecurityHeadersMiddleware)
|
||||
external_app.include_router(v1_router, prefix="/v1")
|
||||
|
||||
@@ -2,19 +2,11 @@ import asyncio
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Annotated, Awaitable, Literal
|
||||
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Body,
|
||||
Depends,
|
||||
HTTPException,
|
||||
Path,
|
||||
Query,
|
||||
Request,
|
||||
status,
|
||||
)
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
|
||||
from pydantic import BaseModel, Field
|
||||
from starlette.status import HTTP_404_NOT_FOUND
|
||||
|
||||
from backend.data.graph import get_graph, set_node_webhook
|
||||
from backend.data.graph import set_node_webhook
|
||||
from backend.data.integrations import (
|
||||
WebhookEvent,
|
||||
get_all_webhooks_by_creds,
|
||||
@@ -22,18 +14,12 @@ from backend.data.integrations import (
|
||||
publish_webhook_event,
|
||||
wait_for_webhook_event,
|
||||
)
|
||||
from backend.data.model import (
|
||||
Credentials,
|
||||
CredentialsType,
|
||||
HostScopedCredentials,
|
||||
OAuth2Credentials,
|
||||
)
|
||||
from backend.data.model import Credentials, CredentialsType, OAuth2Credentials
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.oauth import HANDLERS_BY_NAME
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
from backend.server.v2.library.db import set_preset_webhook, update_preset
|
||||
from backend.util.exceptions import NeedConfirmation, NotFoundError
|
||||
from backend.util.settings import Settings
|
||||
|
||||
@@ -87,9 +73,6 @@ class CredentialsMetaResponse(BaseModel):
|
||||
title: str | None
|
||||
scopes: list[str] | None
|
||||
username: str | None
|
||||
host: str | None = Field(
|
||||
default=None, description="Host pattern for host-scoped credentials"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{provider}/callback")
|
||||
@@ -112,10 +95,7 @@ async def callback(
|
||||
|
||||
if not valid_state:
|
||||
logger.warning(f"Invalid or expired state token for user {user_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid or expired state token",
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="Invalid or expired state token")
|
||||
try:
|
||||
scopes = valid_state.scopes
|
||||
logger.debug(f"Retrieved scopes from state token: {scopes}")
|
||||
@@ -142,12 +122,17 @@ async def callback(
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"OAuth2 Code->Token exchange failed for provider {provider.value}: {e}"
|
||||
logger.exception(
|
||||
"OAuth callback for provider %s failed during code exchange: %s. Confirm provider credentials.",
|
||||
provider.value,
|
||||
e,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"OAuth2 callback failed to exchange code for tokens: {str(e)}",
|
||||
status_code=400,
|
||||
detail={
|
||||
"message": str(e),
|
||||
"hint": "Verify OAuth configuration and try again.",
|
||||
},
|
||||
)
|
||||
|
||||
# TODO: Allow specifying `title` to set on `credentials`
|
||||
@@ -164,9 +149,6 @@ async def callback(
|
||||
title=credentials.title,
|
||||
scopes=credentials.scopes,
|
||||
username=credentials.username,
|
||||
host=(
|
||||
credentials.host if isinstance(credentials, HostScopedCredentials) else None
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -183,7 +165,6 @@ async def list_credentials(
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
||||
)
|
||||
for cred in credentials
|
||||
]
|
||||
@@ -205,7 +186,6 @@ async def list_credentials_by_provider(
|
||||
title=cred.title,
|
||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
||||
)
|
||||
for cred in credentials
|
||||
]
|
||||
@@ -221,13 +201,10 @@ async def get_credential(
|
||||
) -> Credentials:
|
||||
credential = await creds_manager.get(user_id, cred_id)
|
||||
if not credential:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
raise HTTPException(status_code=404, detail="Credentials not found")
|
||||
if credential.provider != provider:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Credentials do not match the specified provider",
|
||||
status_code=404, detail="Credentials do not match the specified provider"
|
||||
)
|
||||
return credential
|
||||
|
||||
@@ -245,8 +222,7 @@ async def create_credentials(
|
||||
await creds_manager.create(user_id, credentials)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to store credentials: {str(e)}",
|
||||
status_code=500, detail=f"Failed to store credentials: {str(e)}"
|
||||
)
|
||||
return credentials
|
||||
|
||||
@@ -280,17 +256,14 @@ async def delete_credentials(
|
||||
) -> CredentialsDeletionResponse | CredentialsDeletionNeedsConfirmationResponse:
|
||||
creds = await creds_manager.store.get_creds_by_id(user_id, cred_id)
|
||||
if not creds:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
|
||||
)
|
||||
raise HTTPException(status_code=404, detail="Credentials not found")
|
||||
if creds.provider != provider:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Credentials do not match the specified provider",
|
||||
status_code=404, detail="Credentials do not match the specified provider"
|
||||
)
|
||||
|
||||
try:
|
||||
await remove_all_webhooks_for_credentials(user_id, creds, force)
|
||||
await remove_all_webhooks_for_credentials(creds, force)
|
||||
except NeedConfirmation as e:
|
||||
return CredentialsDeletionNeedsConfirmationResponse(message=str(e))
|
||||
|
||||
@@ -321,10 +294,16 @@ async def webhook_ingress_generic(
|
||||
logger.debug(f"Received {provider.value} webhook ingress for ID {webhook_id}")
|
||||
webhook_manager = get_webhook_manager(provider)
|
||||
try:
|
||||
webhook = await get_webhook(webhook_id, include_relations=True)
|
||||
webhook = await get_webhook(webhook_id)
|
||||
except NotFoundError as e:
|
||||
logger.warning(f"Webhook payload received for unknown webhook #{webhook_id}")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
|
||||
logger.warning(
|
||||
"Webhook payload received for unknown webhook %s. Confirm the webhook ID.",
|
||||
webhook_id,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_404_NOT_FOUND,
|
||||
detail={"message": str(e), "hint": "Check if the webhook ID is correct."},
|
||||
) from e
|
||||
logger.debug(f"Webhook #{webhook_id}: {webhook}")
|
||||
payload, event_type = await webhook_manager.validate_payload(webhook, request)
|
||||
logger.debug(
|
||||
@@ -341,11 +320,11 @@ async def webhook_ingress_generic(
|
||||
await publish_webhook_event(webhook_event)
|
||||
logger.debug(f"Webhook event published: {webhook_event}")
|
||||
|
||||
if not (webhook.triggered_nodes or webhook.triggered_presets):
|
||||
if not webhook.attached_nodes:
|
||||
return
|
||||
|
||||
executions: list[Awaitable] = []
|
||||
for node in webhook.triggered_nodes:
|
||||
for node in webhook.attached_nodes:
|
||||
logger.debug(f"Webhook-attached node: {node}")
|
||||
if not node.is_triggered_by_event_type(event_type):
|
||||
logger.debug(f"Node #{node.id} doesn't trigger on event {event_type}")
|
||||
@@ -356,48 +335,7 @@ async def webhook_ingress_generic(
|
||||
user_id=webhook.user_id,
|
||||
graph_id=node.graph_id,
|
||||
graph_version=node.graph_version,
|
||||
nodes_input_masks={node.id: {"payload": payload}},
|
||||
)
|
||||
)
|
||||
for preset in webhook.triggered_presets:
|
||||
logger.debug(f"Webhook-attached preset: {preset}")
|
||||
if not preset.is_active:
|
||||
logger.debug(f"Preset #{preset.id} is inactive")
|
||||
continue
|
||||
|
||||
graph = await get_graph(preset.graph_id, preset.graph_version, webhook.user_id)
|
||||
if not graph:
|
||||
logger.error(
|
||||
f"User #{webhook.user_id} has preset #{preset.id} for graph "
|
||||
f"#{preset.graph_id} v{preset.graph_version}, "
|
||||
"but no access to the graph itself."
|
||||
)
|
||||
logger.info(f"Automatically deactivating broken preset #{preset.id}")
|
||||
await update_preset(preset.user_id, preset.id, is_active=False)
|
||||
continue
|
||||
if not (trigger_node := graph.webhook_input_node):
|
||||
# NOTE: this should NEVER happen, but we log and handle it gracefully
|
||||
logger.error(
|
||||
f"Preset #{preset.id} is triggered by webhook #{webhook.id}, but graph "
|
||||
f"#{preset.graph_id} v{preset.graph_version} has no webhook input node"
|
||||
)
|
||||
await set_preset_webhook(preset.user_id, preset.id, None)
|
||||
continue
|
||||
if not trigger_node.block.is_triggered_by_event_type(preset.inputs, event_type):
|
||||
logger.debug(f"Preset #{preset.id} doesn't trigger on event {event_type}")
|
||||
continue
|
||||
logger.debug(f"Executing preset #{preset.id} for webhook #{webhook.id}")
|
||||
|
||||
executions.append(
|
||||
add_graph_execution(
|
||||
user_id=webhook.user_id,
|
||||
graph_id=preset.graph_id,
|
||||
preset_id=preset.id,
|
||||
graph_version=preset.graph_version,
|
||||
graph_credentials_inputs=preset.credentials,
|
||||
nodes_input_masks={
|
||||
trigger_node.id: {**preset.inputs, "payload": payload}
|
||||
},
|
||||
inputs={f"webhook_{webhook_id}_payload": payload},
|
||||
)
|
||||
)
|
||||
asyncio.gather(*executions)
|
||||
@@ -422,9 +360,7 @@ async def webhook_ping(
|
||||
return False
|
||||
|
||||
if not await wait_for_webhook_event(webhook_id, event_type="ping", timeout=10):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail="Webhook ping timed out"
|
||||
)
|
||||
raise HTTPException(status_code=504, detail="Webhook ping timed out")
|
||||
|
||||
return True
|
||||
|
||||
@@ -433,37 +369,32 @@ async def webhook_ping(
|
||||
|
||||
|
||||
async def remove_all_webhooks_for_credentials(
|
||||
user_id: str, credentials: Credentials, force: bool = False
|
||||
credentials: Credentials, force: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Remove and deregister all webhooks that were registered using the given credentials.
|
||||
|
||||
Params:
|
||||
user_id: The ID of the user who owns the credentials and webhooks.
|
||||
credentials: The credentials for which to remove the associated webhooks.
|
||||
force: Whether to proceed if any of the webhooks are still in use.
|
||||
|
||||
Raises:
|
||||
NeedConfirmation: If any of the webhooks are still in use and `force` is `False`
|
||||
"""
|
||||
webhooks = await get_all_webhooks_by_creds(
|
||||
user_id, credentials.id, include_relations=True
|
||||
)
|
||||
if any(w.triggered_nodes or w.triggered_presets for w in webhooks) and not force:
|
||||
webhooks = await get_all_webhooks_by_creds(credentials.id)
|
||||
if any(w.attached_nodes for w in webhooks) and not force:
|
||||
raise NeedConfirmation(
|
||||
"Some webhooks linked to these credentials are still in use by an agent"
|
||||
)
|
||||
for webhook in webhooks:
|
||||
# Unlink all nodes & presets
|
||||
for node in webhook.triggered_nodes:
|
||||
# Unlink all nodes
|
||||
for node in webhook.attached_nodes or []:
|
||||
await set_node_webhook(node.id, None)
|
||||
for preset in webhook.triggered_presets:
|
||||
await set_preset_webhook(user_id, preset.id, None)
|
||||
|
||||
# Prune the webhook
|
||||
webhook_manager = get_webhook_manager(ProviderName(credentials.provider))
|
||||
success = await webhook_manager.prune_webhook_if_dangling(
|
||||
user_id, webhook.id, credentials
|
||||
webhook.id, credentials
|
||||
)
|
||||
if not success:
|
||||
logger.warning(f"Webhook #{webhook.id} failed to prune")
|
||||
@@ -474,7 +405,7 @@ def _get_provider_oauth_handler(
|
||||
) -> "BaseOAuthHandler":
|
||||
if provider_name not in HANDLERS_BY_NAME:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
status_code=404,
|
||||
detail=f"Provider '{provider_name.value}' does not support OAuth",
|
||||
)
|
||||
|
||||
@@ -482,13 +413,14 @@ def _get_provider_oauth_handler(
|
||||
client_secret = getattr(settings.secrets, f"{provider_name.value}_client_secret")
|
||||
if not (client_id and client_secret):
|
||||
logger.error(
|
||||
f"Attempt to use unconfigured {provider_name.value} OAuth integration"
|
||||
"OAuth credentials for provider %s are missing. Check environment configuration.",
|
||||
provider_name.value,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
status_code=501,
|
||||
detail={
|
||||
"message": f"Integration with provider '{provider_name.value}' is not configured.",
|
||||
"hint": "Set client ID and secret in the application's deployment environment",
|
||||
"message": f"Integration with provider '{provider_name.value}' is not configured",
|
||||
"hint": "Set client ID and secret in the environment.",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -1,93 +0,0 @@
|
||||
import re
|
||||
from typing import Set
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware to add security headers to responses, with cache control
|
||||
disabled by default for all endpoints except those explicitly allowed.
|
||||
"""
|
||||
|
||||
CACHEABLE_PATHS: Set[str] = {
|
||||
# Static assets
|
||||
"/static",
|
||||
"/_next/static",
|
||||
"/assets",
|
||||
"/images",
|
||||
"/css",
|
||||
"/js",
|
||||
"/fonts",
|
||||
# Public API endpoints
|
||||
"/api/health",
|
||||
"/api/v1/health",
|
||||
"/api/status",
|
||||
# Public store/marketplace pages (read-only)
|
||||
"/api/store/agents",
|
||||
"/api/v1/store/agents",
|
||||
"/api/store/categories",
|
||||
"/api/v1/store/categories",
|
||||
"/api/store/featured",
|
||||
"/api/v1/store/featured",
|
||||
# Public graph templates (read-only, no user data)
|
||||
"/api/graphs/templates",
|
||||
"/api/v1/graphs/templates",
|
||||
# Documentation endpoints
|
||||
"/api/docs",
|
||||
"/api/v1/docs",
|
||||
"/docs",
|
||||
"/swagger",
|
||||
"/openapi.json",
|
||||
# Favicon and manifest
|
||||
"/favicon.ico",
|
||||
"/manifest.json",
|
||||
"/robots.txt",
|
||||
"/sitemap.xml",
|
||||
}
|
||||
|
||||
def __init__(self, app: ASGIApp):
|
||||
super().__init__(app)
|
||||
# Compile regex patterns for wildcard matching
|
||||
self.cacheable_patterns = [
|
||||
re.compile(pattern.replace("*", "[^/]+"))
|
||||
for pattern in self.CACHEABLE_PATHS
|
||||
if "*" in pattern
|
||||
]
|
||||
self.exact_paths = {path for path in self.CACHEABLE_PATHS if "*" not in path}
|
||||
|
||||
def is_cacheable_path(self, path: str) -> bool:
|
||||
"""Check if the given path is allowed to be cached."""
|
||||
# Check exact matches first
|
||||
for cacheable_path in self.exact_paths:
|
||||
if path.startswith(cacheable_path):
|
||||
return True
|
||||
|
||||
# Check pattern matches
|
||||
for pattern in self.cacheable_patterns:
|
||||
if pattern.match(path):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
response: Response = await call_next(request)
|
||||
|
||||
# Add general security headers
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
response.headers["X-Frame-Options"] = "DENY"
|
||||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||
|
||||
# Default: Disable caching for all endpoints
|
||||
# Only allow caching for explicitly permitted paths
|
||||
if not self.is_cacheable_path(request.url.path):
|
||||
response.headers["Cache-Control"] = (
|
||||
"no-store, no-cache, must-revalidate, private"
|
||||
)
|
||||
response.headers["Pragma"] = "no-cache"
|
||||
response.headers["Expires"] = "0"
|
||||
|
||||
return response
|
||||
@@ -1,143 +0,0 @@
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from starlette.applications import Starlette
|
||||
|
||||
from backend.server.middleware.security import SecurityHeadersMiddleware
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
"""Create a test FastAPI app with security middleware."""
|
||||
app = FastAPI()
|
||||
app.add_middleware(SecurityHeadersMiddleware)
|
||||
|
||||
@app.get("/api/auth/user")
|
||||
def get_user():
|
||||
return {"user": "test"}
|
||||
|
||||
@app.get("/api/v1/integrations/oauth/google")
|
||||
def oauth_endpoint():
|
||||
return {"oauth": "data"}
|
||||
|
||||
@app.get("/api/graphs/123/execute")
|
||||
def execute_graph():
|
||||
return {"execution": "data"}
|
||||
|
||||
@app.get("/api/integrations/credentials")
|
||||
def get_credentials():
|
||||
return {"credentials": "sensitive"}
|
||||
|
||||
@app.get("/api/store/agents")
|
||||
def store_agents():
|
||||
return {"agents": "public list"}
|
||||
|
||||
@app.get("/api/health")
|
||||
def health_check():
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.get("/static/logo.png")
|
||||
def static_file():
|
||||
return {"static": "content"}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
"""Create a test client."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def test_non_cacheable_endpoints_have_cache_control_headers(client):
|
||||
"""Test that non-cacheable endpoints (most endpoints) have proper cache control headers."""
|
||||
non_cacheable_endpoints = [
|
||||
"/api/auth/user",
|
||||
"/api/v1/integrations/oauth/google",
|
||||
"/api/graphs/123/execute",
|
||||
"/api/integrations/credentials",
|
||||
]
|
||||
|
||||
for endpoint in non_cacheable_endpoints:
|
||||
response = client.get(endpoint)
|
||||
|
||||
# Check cache control headers are present (default behavior)
|
||||
assert (
|
||||
response.headers["Cache-Control"]
|
||||
== "no-store, no-cache, must-revalidate, private"
|
||||
)
|
||||
assert response.headers["Pragma"] == "no-cache"
|
||||
assert response.headers["Expires"] == "0"
|
||||
|
||||
# Check general security headers
|
||||
assert response.headers["X-Content-Type-Options"] == "nosniff"
|
||||
assert response.headers["X-Frame-Options"] == "DENY"
|
||||
assert response.headers["X-XSS-Protection"] == "1; mode=block"
|
||||
assert response.headers["Referrer-Policy"] == "strict-origin-when-cross-origin"
|
||||
|
||||
|
||||
def test_cacheable_endpoints_dont_have_cache_control_headers(client):
|
||||
"""Test that explicitly cacheable endpoints don't have restrictive cache control headers."""
|
||||
cacheable_endpoints = [
|
||||
"/api/store/agents",
|
||||
"/api/health",
|
||||
"/static/logo.png",
|
||||
]
|
||||
|
||||
for endpoint in cacheable_endpoints:
|
||||
response = client.get(endpoint)
|
||||
|
||||
# Should NOT have restrictive cache control headers
|
||||
assert (
|
||||
"Cache-Control" not in response.headers
|
||||
or "no-store" not in response.headers.get("Cache-Control", "")
|
||||
)
|
||||
assert (
|
||||
"Pragma" not in response.headers
|
||||
or response.headers.get("Pragma") != "no-cache"
|
||||
)
|
||||
assert (
|
||||
"Expires" not in response.headers or response.headers.get("Expires") != "0"
|
||||
)
|
||||
|
||||
# Should still have general security headers
|
||||
assert response.headers["X-Content-Type-Options"] == "nosniff"
|
||||
assert response.headers["X-Frame-Options"] == "DENY"
|
||||
assert response.headers["X-XSS-Protection"] == "1; mode=block"
|
||||
assert response.headers["Referrer-Policy"] == "strict-origin-when-cross-origin"
|
||||
|
||||
|
||||
def test_is_cacheable_path_detection():
|
||||
"""Test the path detection logic."""
|
||||
middleware = SecurityHeadersMiddleware(Starlette())
|
||||
|
||||
# Test cacheable paths (allow list)
|
||||
assert middleware.is_cacheable_path("/api/health")
|
||||
assert middleware.is_cacheable_path("/api/v1/health")
|
||||
assert middleware.is_cacheable_path("/static/image.png")
|
||||
assert middleware.is_cacheable_path("/api/store/agents")
|
||||
assert middleware.is_cacheable_path("/docs")
|
||||
assert middleware.is_cacheable_path("/favicon.ico")
|
||||
|
||||
# Test non-cacheable paths (everything else)
|
||||
assert not middleware.is_cacheable_path("/api/auth/user")
|
||||
assert not middleware.is_cacheable_path("/api/v1/integrations/oauth/callback")
|
||||
assert not middleware.is_cacheable_path("/api/integrations/credentials/123")
|
||||
assert not middleware.is_cacheable_path("/api/graphs/abc123/execute")
|
||||
assert not middleware.is_cacheable_path("/api/store/xyz/submissions")
|
||||
|
||||
|
||||
def test_path_prefix_matching():
|
||||
"""Test that path prefix matching works correctly."""
|
||||
middleware = SecurityHeadersMiddleware(Starlette())
|
||||
|
||||
# Test that paths starting with cacheable prefixes are cacheable
|
||||
assert middleware.is_cacheable_path("/static/css/style.css")
|
||||
assert middleware.is_cacheable_path("/static/js/app.js")
|
||||
assert middleware.is_cacheable_path("/assets/images/logo.png")
|
||||
assert middleware.is_cacheable_path("/_next/static/chunks/main.js")
|
||||
|
||||
# Test that other API paths are not cacheable by default
|
||||
assert not middleware.is_cacheable_path("/api/users/profile")
|
||||
assert not middleware.is_cacheable_path("/api/v1/private/data")
|
||||
assert not middleware.is_cacheable_path("/api/billing/subscription")
|
||||
@@ -1,6 +1,5 @@
|
||||
import contextlib
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
import autogpt_libs.auth.models
|
||||
@@ -15,7 +14,6 @@ from autogpt_libs.feature_flag.client import (
|
||||
)
|
||||
from autogpt_libs.logging.utils import generate_uvicorn_config
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.routing import APIRoute
|
||||
|
||||
import backend.data.block
|
||||
import backend.data.db
|
||||
@@ -38,7 +36,6 @@ from backend.blocks.llm import LlmModel
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.server.external.api import external_app
|
||||
from backend.server.middleware.security import SecurityHeadersMiddleware
|
||||
|
||||
settings = backend.util.settings.Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -70,33 +67,6 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
await backend.data.db.disconnect()
|
||||
|
||||
|
||||
def custom_generate_unique_id(route: APIRoute):
|
||||
"""Generate clean operation IDs for OpenAPI spec following the format:
|
||||
{method}{tag}{summary}
|
||||
"""
|
||||
if not route.tags or not route.methods:
|
||||
return f"{route.name}"
|
||||
|
||||
method = list(route.methods)[0].lower()
|
||||
first_tag = route.tags[0]
|
||||
if isinstance(first_tag, Enum):
|
||||
tag_str = first_tag.name
|
||||
else:
|
||||
tag_str = str(first_tag)
|
||||
|
||||
tag = "".join(word.capitalize() for word in tag_str.split("_")) # v1/v2
|
||||
|
||||
summary = (
|
||||
route.summary if route.summary else route.name
|
||||
) # need to be unique, a different version could have the same summary
|
||||
summary = "".join(word.capitalize() for word in str(summary).split("_"))
|
||||
|
||||
if tag:
|
||||
return f"{method}{tag}{summary}"
|
||||
else:
|
||||
return f"{method}{summary}"
|
||||
|
||||
|
||||
docs_url = (
|
||||
"/docs"
|
||||
if settings.config.app_env == backend.util.settings.AppEnvironment.LOCAL
|
||||
@@ -112,11 +82,8 @@ app = fastapi.FastAPI(
|
||||
version="0.1",
|
||||
lifespan=lifespan_context,
|
||||
docs_url=docs_url,
|
||||
generate_unique_id_function=custom_generate_unique_id,
|
||||
)
|
||||
|
||||
app.add_middleware(SecurityHeadersMiddleware)
|
||||
|
||||
|
||||
def handle_internal_http_error(status_code: int = 500, log_error: bool = True):
|
||||
def handler(request: fastapi.Request, exc: Exception):
|
||||
@@ -191,12 +158,10 @@ app.include_router(
|
||||
backend.server.v2.library.routes.router, tags=["v2"], prefix="/api/library"
|
||||
)
|
||||
app.include_router(
|
||||
backend.server.v2.otto.routes.router, tags=["v2", "otto"], prefix="/api/otto"
|
||||
backend.server.v2.otto.routes.router, tags=["v2"], prefix="/api/otto"
|
||||
)
|
||||
app.include_router(
|
||||
backend.server.v2.turnstile.routes.router,
|
||||
tags=["v2", "turnstile"],
|
||||
prefix="/api/turnstile",
|
||||
backend.server.v2.turnstile.routes.router, tags=["v2"], prefix="/api/turnstile"
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
@@ -279,7 +244,6 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
|
||||
@staticmethod
|
||||
async def test_delete_graph(graph_id: str, user_id: str):
|
||||
"""Used for clean-up after a test run"""
|
||||
await backend.server.v2.library.db.delete_library_agent_by_graph_id(
|
||||
graph_id=graph_id, user_id=user_id
|
||||
)
|
||||
@@ -324,14 +288,18 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
|
||||
@staticmethod
|
||||
async def test_execute_preset(
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
preset_id: str,
|
||||
user_id: str,
|
||||
inputs: Optional[dict[str, Any]] = None,
|
||||
node_input: Optional[dict[str, Any]] = None,
|
||||
):
|
||||
return await backend.server.v2.library.routes.presets.execute_preset(
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
preset_id=preset_id,
|
||||
node_input=node_input or {},
|
||||
user_id=user_id,
|
||||
inputs=inputs or {},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -357,22 +325,11 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
provider: ProviderName,
|
||||
credentials: Credentials,
|
||||
) -> Credentials:
|
||||
from backend.server.integrations.router import (
|
||||
create_credentials,
|
||||
get_credential,
|
||||
)
|
||||
from backend.server.integrations.router import create_credentials
|
||||
|
||||
try:
|
||||
return await create_credentials(
|
||||
user_id=user_id, provider=provider, credentials=credentials
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating credentials: {e}")
|
||||
return await get_credential(
|
||||
provider=provider,
|
||||
user_id=user_id,
|
||||
cred_id=credentials.id,
|
||||
)
|
||||
return await create_credentials(
|
||||
user_id=user_id, provider=provider, credentials=credentials
|
||||
)
|
||||
|
||||
def set_test_dependency_overrides(self, overrides: dict):
|
||||
app.dependency_overrides.update(overrides)
|
||||
|
||||
@@ -4,7 +4,6 @@ import logging
|
||||
from typing import Annotated
|
||||
|
||||
import fastapi
|
||||
import pydantic
|
||||
|
||||
import backend.data.analytics
|
||||
from backend.server.utils import get_user_id
|
||||
@@ -13,28 +12,24 @@ router = fastapi.APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LogRawMetricRequest(pydantic.BaseModel):
|
||||
metric_name: str = pydantic.Field(..., min_length=1)
|
||||
metric_value: float = pydantic.Field(..., allow_inf_nan=False)
|
||||
data_string: str = pydantic.Field(..., min_length=1)
|
||||
|
||||
|
||||
@router.post(path="/log_raw_metric")
|
||||
async def log_raw_metric(
|
||||
user_id: Annotated[str, fastapi.Depends(get_user_id)],
|
||||
request: LogRawMetricRequest,
|
||||
metric_name: Annotated[str, fastapi.Body(..., embed=True)],
|
||||
metric_value: Annotated[float, fastapi.Body(..., embed=True)],
|
||||
data_string: Annotated[str, fastapi.Body(..., embed=True)],
|
||||
):
|
||||
try:
|
||||
result = await backend.data.analytics.log_raw_metric(
|
||||
user_id=user_id,
|
||||
metric_name=request.metric_name,
|
||||
metric_value=request.metric_value,
|
||||
data_string=request.data_string,
|
||||
metric_name=metric_name,
|
||||
metric_value=metric_value,
|
||||
data_string=data_string,
|
||||
)
|
||||
return result.id
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to log metric %s for user %s: %s", request.metric_name, user_id, e
|
||||
"Failed to log metric %s for user %s: %s", metric_name, user_id, e
|
||||
)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500,
|
||||
|
||||
@@ -97,17 +97,8 @@ def test_log_raw_metric_invalid_request_improved() -> None:
|
||||
assert "data_string" in error_fields, "Should report missing data_string"
|
||||
|
||||
|
||||
def test_log_raw_metric_type_validation_improved(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
) -> None:
|
||||
def test_log_raw_metric_type_validation_improved() -> None:
|
||||
"""Test metric type validation with improved assertions."""
|
||||
# Mock the analytics function to avoid event loop issues
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
return_value=Mock(id="test-id"),
|
||||
)
|
||||
|
||||
invalid_requests = [
|
||||
{
|
||||
"data": {
|
||||
@@ -128,10 +119,10 @@ def test_log_raw_metric_type_validation_improved(
|
||||
{
|
||||
"data": {
|
||||
"metric_name": "test",
|
||||
"metric_value": 123, # Valid number
|
||||
"data_string": "", # Empty data_string
|
||||
"metric_value": float("inf"), # Infinity
|
||||
"data_string": "test",
|
||||
},
|
||||
"expected_error": "String should have at least 1 character",
|
||||
"expected_error": "ensure this value is finite",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@@ -93,18 +93,10 @@ def test_log_raw_metric_values_parametrized(
|
||||
],
|
||||
)
|
||||
def test_log_raw_metric_invalid_requests_parametrized(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
invalid_data: dict,
|
||||
expected_error: str,
|
||||
) -> None:
|
||||
"""Test invalid metric requests with parametrize."""
|
||||
# Mock the analytics function to avoid event loop issues
|
||||
mocker.patch(
|
||||
"backend.data.analytics.log_raw_metric",
|
||||
new_callable=AsyncMock,
|
||||
return_value=Mock(id="test-id"),
|
||||
)
|
||||
|
||||
response = client.post("/log_raw_metric", json=invalid_data)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
@@ -34,7 +34,7 @@ router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.post("/unsubscribe", summary="One Click Email Unsubscribe")
|
||||
@router.post("/unsubscribe")
|
||||
async def unsubscribe_via_one_click(token: Annotated[str, Query()]):
|
||||
logger.info("Received unsubscribe request from One Click Unsubscribe")
|
||||
try:
|
||||
@@ -48,11 +48,7 @@ async def unsubscribe_via_one_click(token: Annotated[str, Query()]):
|
||||
return JSONResponse(status_code=200, content={"status": "ok"})
|
||||
|
||||
|
||||
@router.post(
|
||||
"/",
|
||||
dependencies=[Depends(postmark_validator.get_dependency())],
|
||||
summary="Handle Postmark Email Webhooks",
|
||||
)
|
||||
@router.post("/", dependencies=[Depends(postmark_validator.get_dependency())])
|
||||
async def postmark_webhook_handler(
|
||||
webhook: Annotated[
|
||||
PostmarkWebhook,
|
||||
|
||||
@@ -9,7 +9,7 @@ import stripe
|
||||
from autogpt_libs.auth.middleware import auth_middleware
|
||||
from autogpt_libs.feature_flag.client import feature_flag
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Request, Response
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Request, Response
|
||||
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
|
||||
from typing_extensions import Optional, TypedDict
|
||||
|
||||
@@ -72,7 +72,6 @@ from backend.server.model import (
|
||||
UpdatePermissionsRequest,
|
||||
)
|
||||
from backend.server.utils import get_user_id
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.service import get_service_client
|
||||
from backend.util.settings import Settings
|
||||
|
||||
@@ -114,22 +113,14 @@ v1_router.include_router(
|
||||
########################################################
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
"/auth/user",
|
||||
summary="Get or create user",
|
||||
tags=["auth"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@v1_router.post("/auth/user", tags=["auth"], dependencies=[Depends(auth_middleware)])
|
||||
async def get_or_create_user_route(user_data: dict = Depends(auth_middleware)):
|
||||
user = await get_or_create_user(user_data)
|
||||
return user.model_dump()
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
"/auth/user/email",
|
||||
summary="Update user email",
|
||||
tags=["auth"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
"/auth/user/email", tags=["auth"], dependencies=[Depends(auth_middleware)]
|
||||
)
|
||||
async def update_user_email_route(
|
||||
user_id: Annotated[str, Depends(get_user_id)], email: str = Body(...)
|
||||
@@ -141,7 +132,6 @@ async def update_user_email_route(
|
||||
|
||||
@v1_router.get(
|
||||
"/auth/user/preferences",
|
||||
summary="Get notification preferences",
|
||||
tags=["auth"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -154,7 +144,6 @@ async def get_preferences(
|
||||
|
||||
@v1_router.post(
|
||||
"/auth/user/preferences",
|
||||
summary="Update notification preferences",
|
||||
tags=["auth"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -172,20 +161,14 @@ async def update_preferences(
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
"/onboarding",
|
||||
summary="Get onboarding status",
|
||||
tags=["onboarding"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
"/onboarding", tags=["onboarding"], dependencies=[Depends(auth_middleware)]
|
||||
)
|
||||
async def get_onboarding(user_id: Annotated[str, Depends(get_user_id)]):
|
||||
return await get_user_onboarding(user_id)
|
||||
|
||||
|
||||
@v1_router.patch(
|
||||
"/onboarding",
|
||||
summary="Update onboarding progress",
|
||||
tags=["onboarding"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
"/onboarding", tags=["onboarding"], dependencies=[Depends(auth_middleware)]
|
||||
)
|
||||
async def update_onboarding(
|
||||
user_id: Annotated[str, Depends(get_user_id)], data: UserOnboardingUpdate
|
||||
@@ -195,7 +178,6 @@ async def update_onboarding(
|
||||
|
||||
@v1_router.get(
|
||||
"/onboarding/agents",
|
||||
summary="Get recommended agents",
|
||||
tags=["onboarding"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -207,7 +189,6 @@ async def get_onboarding_agents(
|
||||
|
||||
@v1_router.get(
|
||||
"/onboarding/enabled",
|
||||
summary="Check onboarding enabled",
|
||||
tags=["onboarding", "public"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -220,12 +201,7 @@ async def is_onboarding_enabled():
|
||||
########################################################
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/blocks",
|
||||
summary="List available blocks",
|
||||
tags=["blocks"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@v1_router.get(path="/blocks", tags=["blocks"], dependencies=[Depends(auth_middleware)])
|
||||
def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
blocks = [block() for block in get_blocks().values()]
|
||||
costs = get_block_costs()
|
||||
@@ -236,7 +212,6 @@ def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
|
||||
@v1_router.post(
|
||||
path="/blocks/{block_id}/execute",
|
||||
summary="Execute graph block",
|
||||
tags=["blocks"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -256,12 +231,7 @@ async def execute_graph_block(block_id: str, data: BlockInput) -> CompletedBlock
|
||||
########################################################
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/credits",
|
||||
tags=["credits"],
|
||||
summary="Get user credits",
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@v1_router.get(path="/credits", dependencies=[Depends(auth_middleware)])
|
||||
async def get_user_credits(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> dict[str, int]:
|
||||
@@ -269,10 +239,7 @@ async def get_user_credits(
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
path="/credits",
|
||||
summary="Request credit top up",
|
||||
tags=["credits"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
path="/credits", tags=["credits"], dependencies=[Depends(auth_middleware)]
|
||||
)
|
||||
async def request_top_up(
|
||||
request: RequestTopUp, user_id: Annotated[str, Depends(get_user_id)]
|
||||
@@ -285,7 +252,6 @@ async def request_top_up(
|
||||
|
||||
@v1_router.post(
|
||||
path="/credits/{transaction_key}/refund",
|
||||
summary="Refund credit transaction",
|
||||
tags=["credits"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -298,10 +264,7 @@ async def refund_top_up(
|
||||
|
||||
|
||||
@v1_router.patch(
|
||||
path="/credits",
|
||||
summary="Fulfill checkout session",
|
||||
tags=["credits"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
path="/credits", tags=["credits"], dependencies=[Depends(auth_middleware)]
|
||||
)
|
||||
async def fulfill_checkout(user_id: Annotated[str, Depends(get_user_id)]):
|
||||
await _user_credit_model.fulfill_checkout(user_id=user_id)
|
||||
@@ -310,7 +273,6 @@ async def fulfill_checkout(user_id: Annotated[str, Depends(get_user_id)]):
|
||||
|
||||
@v1_router.post(
|
||||
path="/credits/auto-top-up",
|
||||
summary="Configure auto top up",
|
||||
tags=["credits"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -339,7 +301,6 @@ async def configure_user_auto_top_up(
|
||||
|
||||
@v1_router.get(
|
||||
path="/credits/auto-top-up",
|
||||
summary="Get auto top up",
|
||||
tags=["credits"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -349,9 +310,7 @@ async def get_user_auto_top_up(
|
||||
return await get_auto_top_up(user_id)
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
path="/credits/stripe_webhook", summary="Handle Stripe webhooks", tags=["credits"]
|
||||
)
|
||||
@v1_router.post(path="/credits/stripe_webhook", tags=["credits"])
|
||||
async def stripe_webhook(request: Request):
|
||||
# Get the raw request body
|
||||
payload = await request.body()
|
||||
@@ -386,24 +345,14 @@ async def stripe_webhook(request: Request):
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/credits/manage",
|
||||
tags=["credits"],
|
||||
summary="Manage payment methods",
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@v1_router.get(path="/credits/manage", dependencies=[Depends(auth_middleware)])
|
||||
async def manage_payment_method(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> dict[str, str]:
|
||||
return {"url": await _user_credit_model.create_billing_portal_session(user_id)}
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/credits/transactions",
|
||||
tags=["credits"],
|
||||
summary="Get credit history",
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@v1_router.get(path="/credits/transactions", dependencies=[Depends(auth_middleware)])
|
||||
async def get_credit_history(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
transaction_time: datetime | None = None,
|
||||
@@ -421,12 +370,7 @@ async def get_credit_history(
|
||||
)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/credits/refunds",
|
||||
tags=["credits"],
|
||||
summary="Get refund requests",
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@v1_router.get(path="/credits/refunds", dependencies=[Depends(auth_middleware)])
|
||||
async def get_refund_requests(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> list[RefundRequest]:
|
||||
@@ -442,12 +386,7 @@ class DeleteGraphResponse(TypedDict):
|
||||
version_counts: int
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/graphs",
|
||||
summary="List user graphs",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@v1_router.get(path="/graphs", tags=["graphs"], dependencies=[Depends(auth_middleware)])
|
||||
async def get_graphs(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> Sequence[graph_db.GraphModel]:
|
||||
@@ -455,14 +394,10 @@ async def get_graphs(
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}",
|
||||
summary="Get specific graph",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
path="/graphs/{graph_id}", tags=["graphs"], dependencies=[Depends(auth_middleware)]
|
||||
)
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}/versions/{version}",
|
||||
summary="Get graph version",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -486,7 +421,6 @@ async def get_graph(
|
||||
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}/versions",
|
||||
summary="Get all graph versions",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -500,10 +434,7 @@ async def get_graph_all_versions(
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
path="/graphs",
|
||||
summary="Create new graph",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
path="/graphs", tags=["graphs"], dependencies=[Depends(auth_middleware)]
|
||||
)
|
||||
async def create_new_graph(
|
||||
create_graph: CreateGraph,
|
||||
@@ -526,10 +457,7 @@ async def create_new_graph(
|
||||
|
||||
|
||||
@v1_router.delete(
|
||||
path="/graphs/{graph_id}",
|
||||
summary="Delete graph permanently",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
path="/graphs/{graph_id}", tags=["graphs"], dependencies=[Depends(auth_middleware)]
|
||||
)
|
||||
async def delete_graph(
|
||||
graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
@@ -541,10 +469,7 @@ async def delete_graph(
|
||||
|
||||
|
||||
@v1_router.put(
|
||||
path="/graphs/{graph_id}",
|
||||
summary="Update graph version",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
path="/graphs/{graph_id}", tags=["graphs"], dependencies=[Depends(auth_middleware)]
|
||||
)
|
||||
async def update_graph(
|
||||
graph_id: str,
|
||||
@@ -590,7 +515,6 @@ async def update_graph(
|
||||
|
||||
@v1_router.put(
|
||||
path="/graphs/{graph_id}/versions/active",
|
||||
summary="Set active graph version",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -629,7 +553,6 @@ async def set_graph_active_version(
|
||||
|
||||
@v1_router.post(
|
||||
path="/graphs/{graph_id}/execute/{graph_version}",
|
||||
summary="Execute graph agent",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -663,65 +586,33 @@ async def execute_graph(
|
||||
|
||||
@v1_router.post(
|
||||
path="/graphs/{graph_id}/executions/{graph_exec_id}/stop",
|
||||
summary="Stop graph execution",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def stop_graph_run(
|
||||
graph_id: str, graph_exec_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> execution_db.GraphExecutionMeta | None:
|
||||
res = await _stop_graph_run(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
) -> execution_db.GraphExecution:
|
||||
if not await execution_db.get_graph_execution_meta(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
):
|
||||
raise HTTPException(404, detail=f"Agent execution #{graph_exec_id} not found")
|
||||
|
||||
await execution_utils.stop_graph_execution(graph_exec_id)
|
||||
|
||||
# Retrieve & return canceled graph execution in its final state
|
||||
result = await execution_db.get_graph_execution(
|
||||
execution_id=graph_exec_id, user_id=user_id
|
||||
)
|
||||
if not res:
|
||||
return None
|
||||
return res[0]
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
path="/executions",
|
||||
summary="Stop graph executions",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def stop_graph_runs(
|
||||
graph_id: str, graph_exec_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> list[execution_db.GraphExecutionMeta]:
|
||||
return await _stop_graph_run(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
)
|
||||
|
||||
|
||||
async def _stop_graph_run(
|
||||
user_id: str,
|
||||
graph_id: Optional[str] = None,
|
||||
graph_exec_id: Optional[str] = None,
|
||||
) -> list[execution_db.GraphExecutionMeta]:
|
||||
graph_execs = await execution_db.get_graph_executions(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
statuses=[
|
||||
execution_db.ExecutionStatus.INCOMPLETE,
|
||||
execution_db.ExecutionStatus.QUEUED,
|
||||
execution_db.ExecutionStatus.RUNNING,
|
||||
],
|
||||
)
|
||||
stopped_execs = [
|
||||
execution_utils.stop_graph_execution(graph_exec_id=exec.id, user_id=user_id)
|
||||
for exec in graph_execs
|
||||
]
|
||||
await asyncio.gather(*stopped_execs)
|
||||
return graph_execs
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
500,
|
||||
detail=f"Could not fetch graph execution #{graph_exec_id} after stopping",
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/executions",
|
||||
summary="Get all executions",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -733,7 +624,6 @@ async def get_graphs_executions(
|
||||
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}/executions",
|
||||
summary="Get graph executions",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -746,7 +636,6 @@ async def get_graph_executions(
|
||||
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}/executions/{graph_exec_id}",
|
||||
summary="Get execution details",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@@ -776,7 +665,6 @@ async def get_graph_execution(
|
||||
|
||||
@v1_router.delete(
|
||||
path="/executions/{graph_exec_id}",
|
||||
summary="Delete graph execution",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
status_code=HTTP_204_NO_CONTENT,
|
||||
@@ -796,55 +684,60 @@ async def delete_graph_execution(
|
||||
|
||||
|
||||
class ScheduleCreationRequest(pydantic.BaseModel):
|
||||
graph_version: Optional[int] = None
|
||||
name: str
|
||||
cron: str
|
||||
inputs: dict[str, Any]
|
||||
credentials: dict[str, CredentialsMetaInput] = pydantic.Field(default_factory=dict)
|
||||
input_data: dict[Any, Any]
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
path="/graphs/{graph_id}/schedules",
|
||||
summary="Create execution schedule",
|
||||
path="/schedules",
|
||||
tags=["schedules"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def create_graph_execution_schedule(
|
||||
async def create_schedule(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
graph_id: str = Path(..., description="ID of the graph to schedule"),
|
||||
schedule_params: ScheduleCreationRequest = Body(),
|
||||
schedule: ScheduleCreationRequest,
|
||||
) -> scheduler.GraphExecutionJobInfo:
|
||||
graph = await graph_db.get_graph(
|
||||
graph_id=graph_id,
|
||||
version=schedule_params.graph_version,
|
||||
user_id=user_id,
|
||||
schedule.graph_id, schedule.graph_version, user_id=user_id
|
||||
)
|
||||
if not graph:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Graph #{graph_id} v{schedule_params.graph_version} not found.",
|
||||
detail=f"Graph #{schedule.graph_id} v.{schedule.graph_version} not found.",
|
||||
)
|
||||
|
||||
return await execution_scheduler_client().add_execution_schedule(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_id=schedule.graph_id,
|
||||
graph_version=graph.version,
|
||||
name=schedule_params.name,
|
||||
cron=schedule_params.cron,
|
||||
input_data=schedule_params.inputs,
|
||||
input_credentials=schedule_params.credentials,
|
||||
cron=schedule.cron,
|
||||
input_data=schedule.input_data,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}/schedules",
|
||||
summary="List execution schedules for a graph",
|
||||
@v1_router.delete(
|
||||
path="/schedules/{schedule_id}",
|
||||
tags=["schedules"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def list_graph_execution_schedules(
|
||||
async def delete_schedule(
|
||||
schedule_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
graph_id: str = Path(),
|
||||
) -> dict[Any, Any]:
|
||||
await execution_scheduler_client().delete_schedule(schedule_id, user_id=user_id)
|
||||
return {"id": schedule_id}
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/schedules",
|
||||
tags=["schedules"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def get_execution_schedules(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
graph_id: str | None = None,
|
||||
) -> list[scheduler.GraphExecutionJobInfo]:
|
||||
return await execution_scheduler_client().get_execution_schedules(
|
||||
user_id=user_id,
|
||||
@@ -852,38 +745,6 @@ async def list_graph_execution_schedules(
|
||||
)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/schedules",
|
||||
summary="List execution schedules for a user",
|
||||
tags=["schedules"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def list_all_graphs_execution_schedules(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> list[scheduler.GraphExecutionJobInfo]:
|
||||
return await execution_scheduler_client().get_execution_schedules(user_id=user_id)
|
||||
|
||||
|
||||
@v1_router.delete(
|
||||
path="/schedules/{schedule_id}",
|
||||
summary="Delete execution schedule",
|
||||
tags=["schedules"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def delete_graph_execution_schedule(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
schedule_id: str = Path(..., description="ID of the schedule to delete"),
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
await execution_scheduler_client().delete_schedule(schedule_id, user_id=user_id)
|
||||
except NotFoundError:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_404_NOT_FOUND,
|
||||
detail=f"Schedule #{schedule_id} not found",
|
||||
)
|
||||
return {"id": schedule_id}
|
||||
|
||||
|
||||
########################################################
|
||||
##################### API KEY ##############################
|
||||
########################################################
|
||||
@@ -891,7 +752,6 @@ async def delete_graph_execution_schedule(
|
||||
|
||||
@v1_router.post(
|
||||
"/api-keys",
|
||||
summary="Create new API key",
|
||||
response_model=CreateAPIKeyResponse,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
@@ -922,7 +782,6 @@ async def create_api_key(
|
||||
|
||||
@v1_router.get(
|
||||
"/api-keys",
|
||||
summary="List user API keys",
|
||||
response_model=list[APIKeyWithoutHash] | dict[str, str],
|
||||
tags=["api-keys"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
@@ -943,7 +802,6 @@ async def get_api_keys(
|
||||
|
||||
@v1_router.get(
|
||||
"/api-keys/{key_id}",
|
||||
summary="Get specific API key",
|
||||
response_model=APIKeyWithoutHash,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
@@ -967,7 +825,6 @@ async def get_api_key(
|
||||
|
||||
@v1_router.delete(
|
||||
"/api-keys/{key_id}",
|
||||
summary="Revoke API key",
|
||||
response_model=APIKeyWithoutHash,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
@@ -996,7 +853,6 @@ async def delete_api_key(
|
||||
|
||||
@v1_router.post(
|
||||
"/api-keys/{key_id}/suspend",
|
||||
summary="Suspend API key",
|
||||
response_model=APIKeyWithoutHash,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
@@ -1022,7 +878,6 @@ async def suspend_key(
|
||||
|
||||
@v1_router.put(
|
||||
"/api-keys/{key_id}/permissions",
|
||||
summary="Update key permissions",
|
||||
response_model=APIKeyWithoutHash,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
|
||||
@@ -108,16 +108,11 @@ class TestDatabaseIsolation:
|
||||
where={"email": {"contains": "@test.example"}}
|
||||
)
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def test_create_user(self, test_db_connection):
|
||||
"""Test that demonstrates proper isolation."""
|
||||
# This test has access to a clean database
|
||||
user = await test_db_connection.user.create(
|
||||
data={
|
||||
"id": "test-user-id",
|
||||
"email": "test@test.example",
|
||||
"name": "Test User",
|
||||
}
|
||||
data={"email": "test@test.example", "name": "Test User"}
|
||||
)
|
||||
assert user.email == "test@test.example"
|
||||
# User will be cleaned up automatically
|
||||
|
||||
@@ -22,9 +22,7 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/add_credits", response_model=AddUserCreditsResponse, summary="Add Credits to User"
|
||||
)
|
||||
@router.post("/add_credits", response_model=AddUserCreditsResponse)
|
||||
async def add_user_credits(
|
||||
user_id: typing.Annotated[str, Body()],
|
||||
amount: typing.Annotated[int, Body()],
|
||||
@@ -51,7 +49,6 @@ async def add_user_credits(
|
||||
@router.get(
|
||||
"/users_history",
|
||||
response_model=UserHistoryResponse,
|
||||
summary="Get All Users History",
|
||||
)
|
||||
async def admin_get_all_user_history(
|
||||
admin_user: typing.Annotated[
|
||||
|
||||
@@ -19,7 +19,6 @@ router = fastapi.APIRouter(prefix="/admin", tags=["store", "admin"])
|
||||
|
||||
@router.get(
|
||||
"/listings",
|
||||
summary="Get Admin Listings History",
|
||||
response_model=backend.server.v2.store.model.StoreListingsWithVersionsResponse,
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user)],
|
||||
)
|
||||
@@ -64,7 +63,6 @@ async def get_admin_listings_with_versions(
|
||||
|
||||
@router.post(
|
||||
"/submissions/{store_listing_version_id}/review",
|
||||
summary="Review Store Submission",
|
||||
response_model=backend.server.v2.store.model.StoreSubmission,
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user)],
|
||||
)
|
||||
@@ -106,7 +104,6 @@ async def review_submission(
|
||||
|
||||
@router.get(
|
||||
"/submissions/download/{store_listing_version_id}",
|
||||
summary="Admin Download Agent File",
|
||||
tags=["store", "admin"],
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user)],
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Literal, Optional
|
||||
from typing import Optional
|
||||
|
||||
import fastapi
|
||||
import prisma.errors
|
||||
@@ -7,17 +7,17 @@ import prisma.fields
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
|
||||
import backend.data.graph as graph_db
|
||||
import backend.data.graph
|
||||
import backend.server.model
|
||||
import backend.server.v2.library.model as library_model
|
||||
import backend.server.v2.store.exceptions as store_exceptions
|
||||
import backend.server.v2.store.image_gen as store_image_gen
|
||||
import backend.server.v2.store.media as store_media
|
||||
from backend.data.block import BlockInput
|
||||
from backend.data.db import locked_transaction, transaction
|
||||
from backend.data import db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.db import locked_transaction
|
||||
from backend.data.execution import get_graph_execution
|
||||
from backend.data.includes import library_agent_include
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
|
||||
from backend.util.exceptions import NotFoundError
|
||||
@@ -122,7 +122,7 @@ async def list_library_agents(
|
||||
except Exception as e:
|
||||
# Skip this agent if there was an error
|
||||
logger.error(
|
||||
f"Error parsing LibraryAgent #{agent.id} from DB item: {e}"
|
||||
f"Error parsing LibraryAgent when getting library agents from db: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
@@ -168,16 +168,9 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
|
||||
)
|
||||
|
||||
if not library_agent:
|
||||
raise NotFoundError(f"Library agent #{id} not found")
|
||||
raise store_exceptions.AgentNotFoundError(f"Library agent #{id} not found")
|
||||
|
||||
return library_model.LibraryAgent.from_db(
|
||||
library_agent,
|
||||
sub_graphs=(
|
||||
await graph_db.get_sub_graphs(library_agent.AgentGraph)
|
||||
if library_agent.AgentGraph
|
||||
else None
|
||||
),
|
||||
)
|
||||
return library_model.LibraryAgent.from_db(library_agent)
|
||||
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error fetching library agent: {e}")
|
||||
@@ -222,34 +215,8 @@ async def get_library_agent_by_store_version_id(
|
||||
return None
|
||||
|
||||
|
||||
async def get_library_agent_by_graph_id(
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
graph_version: Optional[int] = None,
|
||||
) -> library_model.LibraryAgent | None:
|
||||
try:
|
||||
filter: prisma.types.LibraryAgentWhereInput = {
|
||||
"agentGraphId": graph_id,
|
||||
"userId": user_id,
|
||||
"isDeleted": False,
|
||||
}
|
||||
if graph_version is not None:
|
||||
filter["agentGraphVersion"] = graph_version
|
||||
|
||||
agent = await prisma.models.LibraryAgent.prisma().find_first(
|
||||
where=filter,
|
||||
include=library_agent_include(user_id),
|
||||
)
|
||||
if not agent:
|
||||
return None
|
||||
return library_model.LibraryAgent.from_db(agent)
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error fetching library agent by graph ID: {e}")
|
||||
raise store_exceptions.DatabaseError("Failed to fetch library agent") from e
|
||||
|
||||
|
||||
async def add_generated_agent_image(
|
||||
graph: graph_db.GraphModel,
|
||||
graph: backend.data.graph.GraphModel,
|
||||
library_agent_id: str,
|
||||
) -> Optional[prisma.models.LibraryAgent]:
|
||||
"""
|
||||
@@ -282,7 +249,7 @@ async def add_generated_agent_image(
|
||||
|
||||
|
||||
async def create_library_agent(
|
||||
graph: graph_db.GraphModel,
|
||||
graph: backend.data.graph.GraphModel,
|
||||
user_id: str,
|
||||
) -> library_model.LibraryAgent:
|
||||
"""
|
||||
@@ -379,8 +346,8 @@ async def update_library_agent(
|
||||
auto_update_version: Optional[bool] = None,
|
||||
is_favorite: Optional[bool] = None,
|
||||
is_archived: Optional[bool] = None,
|
||||
is_deleted: Optional[Literal[False]] = None,
|
||||
) -> library_model.LibraryAgent:
|
||||
is_deleted: Optional[bool] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Updates the specified LibraryAgent record.
|
||||
|
||||
@@ -390,18 +357,15 @@ async def update_library_agent(
|
||||
auto_update_version: Whether the agent should auto-update to active version.
|
||||
is_favorite: Whether this agent is marked as a favorite.
|
||||
is_archived: Whether this agent is archived.
|
||||
|
||||
Returns:
|
||||
The updated LibraryAgent.
|
||||
is_deleted: Whether this agent is deleted.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the specified LibraryAgent does not exist.
|
||||
DatabaseError: If there's an error in the update operation.
|
||||
"""
|
||||
logger.debug(
|
||||
f"Updating library agent {library_agent_id} for user {user_id} with "
|
||||
f"auto_update_version={auto_update_version}, is_favorite={is_favorite}, "
|
||||
f"is_archived={is_archived}"
|
||||
f"is_archived={is_archived}, is_deleted={is_deleted}"
|
||||
)
|
||||
update_fields: prisma.types.LibraryAgentUpdateManyMutationInput = {}
|
||||
if auto_update_version is not None:
|
||||
@@ -411,46 +375,17 @@ async def update_library_agent(
|
||||
if is_archived is not None:
|
||||
update_fields["isArchived"] = is_archived
|
||||
if is_deleted is not None:
|
||||
if is_deleted is True:
|
||||
raise RuntimeError(
|
||||
"Use delete_library_agent() to (soft-)delete library agents"
|
||||
)
|
||||
update_fields["isDeleted"] = is_deleted
|
||||
if not update_fields:
|
||||
raise ValueError("No values were passed to update")
|
||||
|
||||
try:
|
||||
n_updated = await prisma.models.LibraryAgent.prisma().update_many(
|
||||
where={"id": library_agent_id, "userId": user_id},
|
||||
data=update_fields,
|
||||
)
|
||||
if n_updated < 1:
|
||||
raise NotFoundError(f"Library agent {library_agent_id} not found")
|
||||
|
||||
return await get_library_agent(
|
||||
id=library_agent_id,
|
||||
user_id=user_id,
|
||||
await prisma.models.LibraryAgent.prisma().update_many(
|
||||
where={"id": library_agent_id, "userId": user_id}, data=update_fields
|
||||
)
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error updating library agent: {str(e)}")
|
||||
raise store_exceptions.DatabaseError("Failed to update library agent") from e
|
||||
|
||||
|
||||
async def delete_library_agent(
|
||||
library_agent_id: str, user_id: str, soft_delete: bool = True
|
||||
) -> None:
|
||||
if soft_delete:
|
||||
deleted_count = await prisma.models.LibraryAgent.prisma().update_many(
|
||||
where={"id": library_agent_id, "userId": user_id}, data={"isDeleted": True}
|
||||
)
|
||||
else:
|
||||
deleted_count = await prisma.models.LibraryAgent.prisma().delete_many(
|
||||
where={"id": library_agent_id, "userId": user_id}
|
||||
)
|
||||
if deleted_count < 1:
|
||||
raise NotFoundError(f"Library agent #{library_agent_id} not found")
|
||||
|
||||
|
||||
async def delete_library_agent_by_graph_id(graph_id: str, user_id: str) -> None:
|
||||
"""
|
||||
Deletes a library agent for the given user
|
||||
@@ -590,10 +525,7 @@ async def list_presets(
|
||||
)
|
||||
raise store_exceptions.DatabaseError("Invalid pagination parameters")
|
||||
|
||||
query_filter: prisma.types.AgentPresetWhereInput = {
|
||||
"userId": user_id,
|
||||
"isDeleted": False,
|
||||
}
|
||||
query_filter: prisma.types.AgentPresetWhereInput = {"userId": user_id}
|
||||
if graph_id:
|
||||
query_filter["agentGraphId"] = graph_id
|
||||
|
||||
@@ -649,7 +581,7 @@ async def get_preset(
|
||||
where={"id": preset_id},
|
||||
include={"InputPresets": True},
|
||||
)
|
||||
if not preset or preset.userId != user_id or preset.isDeleted:
|
||||
if not preset or preset.userId != user_id:
|
||||
return None
|
||||
return library_model.LibraryAgentPreset.from_db(preset)
|
||||
except prisma.errors.PrismaError as e:
|
||||
@@ -686,19 +618,12 @@ async def create_preset(
|
||||
agentGraphId=preset.graph_id,
|
||||
agentGraphVersion=preset.graph_version,
|
||||
isActive=preset.is_active,
|
||||
webhookId=preset.webhook_id,
|
||||
InputPresets={
|
||||
"create": [
|
||||
prisma.types.AgentNodeExecutionInputOutputCreateWithoutRelationsInput( # noqa
|
||||
name=name, data=prisma.fields.Json(data)
|
||||
)
|
||||
for name, data in {
|
||||
**preset.inputs,
|
||||
**{
|
||||
key: creds_meta.model_dump(exclude_none=True)
|
||||
for key, creds_meta in preset.credentials.items()
|
||||
},
|
||||
}.items()
|
||||
for name, data in preset.inputs.items()
|
||||
]
|
||||
},
|
||||
),
|
||||
@@ -739,7 +664,6 @@ async def create_preset_from_graph_execution(
|
||||
user_id=user_id,
|
||||
preset=library_model.LibraryAgentPresetCreatable(
|
||||
inputs=graph_execution.inputs,
|
||||
credentials={}, # FIXME
|
||||
graph_id=graph_execution.graph_id,
|
||||
graph_version=graph_execution.graph_version,
|
||||
name=create_request.name,
|
||||
@@ -752,11 +676,7 @@ async def create_preset_from_graph_execution(
|
||||
async def update_preset(
|
||||
user_id: str,
|
||||
preset_id: str,
|
||||
inputs: Optional[BlockInput] = None,
|
||||
credentials: Optional[dict[str, CredentialsMetaInput]] = None,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
is_active: Optional[bool] = None,
|
||||
preset: library_model.LibraryAgentPresetUpdatable,
|
||||
) -> library_model.LibraryAgentPreset:
|
||||
"""
|
||||
Updates an existing AgentPreset for a user.
|
||||
@@ -764,95 +684,49 @@ async def update_preset(
|
||||
Args:
|
||||
user_id: The ID of the user updating the preset.
|
||||
preset_id: The ID of the preset to update.
|
||||
inputs: New inputs object to set on the preset.
|
||||
credentials: New credentials to set on the preset.
|
||||
name: New name for the preset.
|
||||
description: New description for the preset.
|
||||
is_active: New active status for the preset.
|
||||
preset: The preset data used for the update.
|
||||
|
||||
Returns:
|
||||
The updated LibraryAgentPreset.
|
||||
|
||||
Raises:
|
||||
DatabaseError: If there's a database error in updating the preset.
|
||||
NotFoundError: If attempting to update a non-existent preset.
|
||||
ValueError: If attempting to update a non-existent preset.
|
||||
"""
|
||||
current = await get_preset(user_id, preset_id) # assert ownership
|
||||
if not current:
|
||||
raise NotFoundError(f"Preset #{preset_id} not found for user #{user_id}")
|
||||
logger.debug(
|
||||
f"Updating preset #{preset_id} ({repr(current.name)}) for user #{user_id}",
|
||||
f"Updating preset #{preset_id} ({repr(preset.name)}) for user #{user_id}",
|
||||
)
|
||||
try:
|
||||
async with transaction() as tx:
|
||||
update_data: prisma.types.AgentPresetUpdateInput = {}
|
||||
if name:
|
||||
update_data["name"] = name
|
||||
if description:
|
||||
update_data["description"] = description
|
||||
if is_active is not None:
|
||||
update_data["isActive"] = is_active
|
||||
if inputs or credentials:
|
||||
if not (inputs and credentials):
|
||||
raise ValueError(
|
||||
"Preset inputs and credentials must be provided together"
|
||||
update_data: prisma.types.AgentPresetUpdateInput = {}
|
||||
if preset.name:
|
||||
update_data["name"] = preset.name
|
||||
if preset.description:
|
||||
update_data["description"] = preset.description
|
||||
if preset.inputs:
|
||||
update_data["InputPresets"] = {
|
||||
"create": [
|
||||
prisma.types.AgentNodeExecutionInputOutputCreateWithoutRelationsInput( # noqa
|
||||
name=name, data=prisma.fields.Json(data)
|
||||
)
|
||||
update_data["InputPresets"] = {
|
||||
"create": [
|
||||
prisma.types.AgentNodeExecutionInputOutputCreateWithoutRelationsInput( # noqa
|
||||
name=name, data=prisma.fields.Json(data)
|
||||
)
|
||||
for name, data in {
|
||||
**inputs,
|
||||
**{
|
||||
key: creds_meta.model_dump(exclude_none=True)
|
||||
for key, creds_meta in credentials.items()
|
||||
},
|
||||
}.items()
|
||||
],
|
||||
}
|
||||
# Existing InputPresets must be deleted, in a separate query
|
||||
await prisma.models.AgentNodeExecutionInputOutput.prisma(
|
||||
tx
|
||||
).delete_many(where={"agentPresetId": preset_id})
|
||||
for name, data in preset.inputs.items()
|
||||
]
|
||||
}
|
||||
if preset.is_active:
|
||||
update_data["isActive"] = preset.is_active
|
||||
|
||||
updated = await prisma.models.AgentPreset.prisma(tx).update(
|
||||
where={"id": preset_id},
|
||||
data=update_data,
|
||||
include={"InputPresets": True},
|
||||
)
|
||||
updated = await prisma.models.AgentPreset.prisma().update(
|
||||
where={"id": preset_id},
|
||||
data=update_data,
|
||||
include={"InputPresets": True},
|
||||
)
|
||||
if not updated:
|
||||
raise RuntimeError(f"AgentPreset #{preset_id} vanished while updating")
|
||||
raise ValueError(f"AgentPreset #{preset_id} not found")
|
||||
return library_model.LibraryAgentPreset.from_db(updated)
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error updating preset: {e}")
|
||||
raise store_exceptions.DatabaseError("Failed to update preset") from e
|
||||
|
||||
|
||||
async def set_preset_webhook(
|
||||
user_id: str, preset_id: str, webhook_id: str | None
|
||||
) -> library_model.LibraryAgentPreset:
|
||||
current = await prisma.models.AgentPreset.prisma().find_unique(
|
||||
where={"id": preset_id},
|
||||
include={"InputPresets": True},
|
||||
)
|
||||
if not current or current.userId != user_id:
|
||||
raise NotFoundError(f"Preset #{preset_id} not found")
|
||||
|
||||
updated = await prisma.models.AgentPreset.prisma().update(
|
||||
where={"id": preset_id},
|
||||
data=(
|
||||
{"Webhook": {"connect": {"id": webhook_id}}}
|
||||
if webhook_id
|
||||
else {"Webhook": {"disconnect": True}}
|
||||
),
|
||||
include={"InputPresets": True},
|
||||
)
|
||||
if not updated:
|
||||
raise RuntimeError(f"AgentPreset #{preset_id} vanished while updating")
|
||||
return library_model.LibraryAgentPreset.from_db(updated)
|
||||
|
||||
|
||||
async def delete_preset(user_id: str, preset_id: str) -> None:
|
||||
"""
|
||||
Soft-deletes a preset by marking it as isDeleted = True.
|
||||
@@ -864,7 +738,7 @@ async def delete_preset(user_id: str, preset_id: str) -> None:
|
||||
Raises:
|
||||
DatabaseError: If there's a database error during deletion.
|
||||
"""
|
||||
logger.debug(f"Setting preset #{preset_id} for user #{user_id} to deleted")
|
||||
logger.info(f"Deleting preset {preset_id} for user {user_id}")
|
||||
try:
|
||||
await prisma.models.AgentPreset.prisma().update_many(
|
||||
where={"id": preset_id, "userId": user_id},
|
||||
@@ -891,7 +765,7 @@ async def fork_library_agent(library_agent_id: str, user_id: str):
|
||||
"""
|
||||
logger.debug(f"Forking library agent {library_agent_id} for user {user_id}")
|
||||
try:
|
||||
async with locked_transaction(f"usr_trx_{user_id}-fork_agent"):
|
||||
async with db.locked_transaction(f"usr_trx_{user_id}-fork_agent"):
|
||||
# Fetch the original agent
|
||||
original_agent = await get_library_agent(library_agent_id, user_id)
|
||||
|
||||
|
||||
@@ -143,7 +143,7 @@ async def test_add_agent_to_library(mocker):
|
||||
)
|
||||
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_unique = mocker.AsyncMock(return_value=None)
|
||||
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||
mock_library_agent.return_value.create = mocker.AsyncMock(
|
||||
return_value=mock_library_agent_data
|
||||
)
|
||||
@@ -159,24 +159,21 @@ async def test_add_agent_to_library(mocker):
|
||||
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
|
||||
where={"id": "version123"}, include={"AgentGraph": True}
|
||||
)
|
||||
mock_library_agent.return_value.find_unique.assert_called_once_with(
|
||||
mock_library_agent.return_value.find_first.assert_called_once_with(
|
||||
where={
|
||||
"userId_agentGraphId_agentGraphVersion": {
|
||||
"userId": "test-user",
|
||||
"agentGraphId": "agent1",
|
||||
"agentGraphVersion": 1,
|
||||
}
|
||||
"userId": "test-user",
|
||||
"agentGraphId": "agent1",
|
||||
"agentGraphVersion": 1,
|
||||
},
|
||||
include={"AgentGraph": True},
|
||||
include=library_agent_include("test-user"),
|
||||
)
|
||||
mock_library_agent.return_value.create.assert_called_once_with(
|
||||
data={
|
||||
"User": {"connect": {"id": "test-user"}},
|
||||
"AgentGraph": {
|
||||
"connect": {"graphVersionId": {"id": "agent1", "version": 1}}
|
||||
},
|
||||
"isCreatedByUser": False,
|
||||
},
|
||||
data=prisma.types.LibraryAgentCreateInput(
|
||||
userId="test-user",
|
||||
agentGraphId="agent1",
|
||||
agentGraphVersion=1,
|
||||
isCreatedByUser=False,
|
||||
),
|
||||
include=library_agent_include("test-user"),
|
||||
)
|
||||
|
||||
|
||||
@@ -9,8 +9,6 @@ import pydantic
|
||||
import backend.data.block as block_model
|
||||
import backend.data.graph as graph_model
|
||||
import backend.server.model as server_model
|
||||
from backend.data.model import CredentialsMetaInput, is_credentials_field_name
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
|
||||
class LibraryAgentStatus(str, Enum):
|
||||
@@ -20,14 +18,6 @@ class LibraryAgentStatus(str, Enum):
|
||||
ERROR = "ERROR" # Agent is in an error state
|
||||
|
||||
|
||||
class LibraryAgentTriggerInfo(pydantic.BaseModel):
|
||||
provider: ProviderName
|
||||
config_schema: dict[str, Any] = pydantic.Field(
|
||||
description="Input schema for the trigger block"
|
||||
)
|
||||
credentials_input_name: Optional[str]
|
||||
|
||||
|
||||
class LibraryAgent(pydantic.BaseModel):
|
||||
"""
|
||||
Represents an agent in the library, including metadata for display and
|
||||
@@ -50,15 +40,8 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
|
||||
# Made input_schema and output_schema match GraphMeta's type
|
||||
input_schema: dict[str, Any] # Should be BlockIOObjectSubSchema in frontend
|
||||
credentials_input_schema: dict[str, Any] | None = pydantic.Field(
|
||||
description="Input schema for credentials required by the agent",
|
||||
)
|
||||
|
||||
has_external_trigger: bool = pydantic.Field(
|
||||
description="Whether the agent has an external trigger (e.g. webhook) node"
|
||||
)
|
||||
trigger_setup_info: Optional[LibraryAgentTriggerInfo] = None
|
||||
|
||||
# Indicates whether there's a new output (based on recent runs)
|
||||
new_output: bool
|
||||
@@ -70,10 +53,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
is_latest_version: bool
|
||||
|
||||
@staticmethod
|
||||
def from_db(
|
||||
agent: prisma.models.LibraryAgent,
|
||||
sub_graphs: Optional[list[prisma.models.AgentGraph]] = None,
|
||||
) -> "LibraryAgent":
|
||||
def from_db(agent: prisma.models.LibraryAgent) -> "LibraryAgent":
|
||||
"""
|
||||
Factory method that constructs a LibraryAgent from a Prisma LibraryAgent
|
||||
model instance.
|
||||
@@ -81,7 +61,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
if not agent.AgentGraph:
|
||||
raise ValueError("Associated Agent record is required.")
|
||||
|
||||
graph = graph_model.GraphModel.from_db(agent.AgentGraph, sub_graphs=sub_graphs)
|
||||
graph = graph_model.GraphModel.from_db(agent.AgentGraph)
|
||||
|
||||
agent_updated_at = agent.AgentGraph.updatedAt
|
||||
lib_agent_updated_at = agent.updatedAt
|
||||
@@ -126,34 +106,6 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
name=graph.name,
|
||||
description=graph.description,
|
||||
input_schema=graph.input_schema,
|
||||
credentials_input_schema=(
|
||||
graph.credentials_input_schema if sub_graphs else None
|
||||
),
|
||||
has_external_trigger=graph.has_webhook_trigger,
|
||||
trigger_setup_info=(
|
||||
LibraryAgentTriggerInfo(
|
||||
provider=trigger_block.webhook_config.provider,
|
||||
config_schema={
|
||||
**(json_schema := trigger_block.input_schema.jsonschema()),
|
||||
"properties": {
|
||||
pn: sub_schema
|
||||
for pn, sub_schema in json_schema["properties"].items()
|
||||
if not is_credentials_field_name(pn)
|
||||
},
|
||||
"required": [
|
||||
pn
|
||||
for pn in json_schema.get("required", [])
|
||||
if not is_credentials_field_name(pn)
|
||||
],
|
||||
},
|
||||
credentials_input_name=next(
|
||||
iter(trigger_block.input_schema.get_credentials_fields()), None
|
||||
),
|
||||
)
|
||||
if graph.webhook_input_node
|
||||
and (trigger_block := graph.webhook_input_node.block).webhook_config
|
||||
else None
|
||||
),
|
||||
new_output=new_output,
|
||||
can_access_graph=can_access_graph,
|
||||
is_latest_version=is_latest_version,
|
||||
@@ -225,15 +177,12 @@ class LibraryAgentPresetCreatable(pydantic.BaseModel):
|
||||
graph_version: int
|
||||
|
||||
inputs: block_model.BlockInput
|
||||
credentials: dict[str, CredentialsMetaInput]
|
||||
|
||||
name: str
|
||||
description: str
|
||||
|
||||
is_active: bool = True
|
||||
|
||||
webhook_id: Optional[str] = None
|
||||
|
||||
|
||||
class LibraryAgentPresetCreatableFromGraphExecution(pydantic.BaseModel):
|
||||
"""
|
||||
@@ -254,7 +203,6 @@ class LibraryAgentPresetUpdatable(pydantic.BaseModel):
|
||||
"""
|
||||
|
||||
inputs: Optional[block_model.BlockInput] = None
|
||||
credentials: Optional[dict[str, CredentialsMetaInput]] = None
|
||||
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
@@ -266,28 +214,20 @@ class LibraryAgentPreset(LibraryAgentPresetCreatable):
|
||||
"""Represents a preset configuration for a library agent."""
|
||||
|
||||
id: str
|
||||
user_id: str
|
||||
updated_at: datetime.datetime
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, preset: prisma.models.AgentPreset) -> "LibraryAgentPreset":
|
||||
if preset.InputPresets is None:
|
||||
raise ValueError("InputPresets must be included in AgentPreset query")
|
||||
raise ValueError("Input values must be included in object")
|
||||
|
||||
input_data: block_model.BlockInput = {}
|
||||
input_credentials: dict[str, CredentialsMetaInput] = {}
|
||||
|
||||
for preset_input in preset.InputPresets:
|
||||
if not is_credentials_field_name(preset_input.name):
|
||||
input_data[preset_input.name] = preset_input.data
|
||||
else:
|
||||
input_credentials[preset_input.name] = (
|
||||
CredentialsMetaInput.model_validate(preset_input.data)
|
||||
)
|
||||
input_data[preset_input.name] = preset_input.data
|
||||
|
||||
return cls(
|
||||
id=preset.id,
|
||||
user_id=preset.userId,
|
||||
updated_at=preset.updatedAt,
|
||||
graph_id=preset.agentGraphId,
|
||||
graph_version=preset.agentGraphVersion,
|
||||
@@ -295,8 +235,6 @@ class LibraryAgentPreset(LibraryAgentPresetCreatable):
|
||||
description=preset.description,
|
||||
is_active=preset.isActive,
|
||||
inputs=input_data,
|
||||
credentials=input_credentials,
|
||||
webhook_id=preset.webhookId,
|
||||
)
|
||||
|
||||
|
||||
@@ -338,3 +276,6 @@ class LibraryAgentUpdateRequest(pydantic.BaseModel):
|
||||
is_archived: Optional[bool] = pydantic.Field(
|
||||
default=None, description="Archive the agent"
|
||||
)
|
||||
is_deleted: Optional[bool] = pydantic.Field(
|
||||
default=None, description="Delete the agent"
|
||||
)
|
||||
|
||||
@@ -1,19 +1,13 @@
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
import autogpt_libs.auth as autogpt_auth_lib
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, status
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel, Field
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
import backend.server.v2.library.db as library_db
|
||||
import backend.server.v2.library.model as library_model
|
||||
import backend.server.v2.store.exceptions as store_exceptions
|
||||
from backend.data.graph import get_graph
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.executor.utils import make_node_credentials_input_map
|
||||
from backend.integrations.webhooks.utils import setup_webhook_for_block
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -26,7 +20,6 @@ router = APIRouter(
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
summary="List Library Agents",
|
||||
responses={
|
||||
500: {"description": "Server error", "content": {"application/json": {}}},
|
||||
},
|
||||
@@ -77,14 +70,14 @@ async def list_library_agents(
|
||||
page_size=page_size,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not list library agents for user #{user_id}: {e}")
|
||||
logger.exception("Listing library agents failed for user %s: %s", user_id, e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e),
|
||||
detail={"message": str(e), "hint": "Inspect database connectivity."},
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/{library_agent_id}", summary="Get Library Agent")
|
||||
@router.get("/{library_agent_id}")
|
||||
async def get_library_agent(
|
||||
library_agent_id: str,
|
||||
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
|
||||
@@ -92,26 +85,8 @@ async def get_library_agent(
|
||||
return await library_db.get_library_agent(id=library_agent_id, user_id=user_id)
|
||||
|
||||
|
||||
@router.get("/by-graph/{graph_id}")
|
||||
async def get_library_agent_by_graph_id(
|
||||
graph_id: str,
|
||||
version: Optional[int] = Query(default=None),
|
||||
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
|
||||
) -> library_model.LibraryAgent:
|
||||
library_agent = await library_db.get_library_agent_by_graph_id(
|
||||
user_id, graph_id, version
|
||||
)
|
||||
if not library_agent:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Library agent for graph #{graph_id} and user #{user_id} not found",
|
||||
)
|
||||
return library_agent
|
||||
|
||||
|
||||
@router.get(
|
||||
"/marketplace/{store_listing_version_id}",
|
||||
summary="Get Agent By Store ID",
|
||||
tags=["store, library"],
|
||||
response_model=library_model.LibraryAgent | None,
|
||||
)
|
||||
@@ -126,22 +101,23 @@ async def get_library_agent_by_store_listing_version_id(
|
||||
return await library_db.get_library_agent_by_store_version_id(
|
||||
store_listing_version_id, user_id
|
||||
)
|
||||
except NotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not fetch library agent from store version ID: {e}")
|
||||
logger.exception(
|
||||
"Retrieving library agent by store version failed for user %s: %s",
|
||||
user_id,
|
||||
e,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e),
|
||||
detail={
|
||||
"message": str(e),
|
||||
"hint": "Check if the store listing ID is valid.",
|
||||
},
|
||||
) from e
|
||||
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
summary="Add Marketplace Agent",
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
responses={
|
||||
201: {"description": "Agent added successfully"},
|
||||
@@ -173,20 +149,26 @@ async def add_marketplace_agent_to_library(
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
except store_exceptions.AgentNotFoundError as e:
|
||||
except store_exceptions.AgentNotFoundError:
|
||||
logger.warning(
|
||||
f"Could not find store listing version {store_listing_version_id} "
|
||||
"to add to library"
|
||||
"Store listing version %s not found when adding to library",
|
||||
store_listing_version_id,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={
|
||||
"message": f"Store listing version {store_listing_version_id} not found",
|
||||
"hint": "Confirm the ID provided.",
|
||||
},
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
|
||||
except store_exceptions.DatabaseError as e:
|
||||
logger.error(f"Database error while adding agent to library: {e}", e)
|
||||
logger.exception("Database error whilst adding agent to library: %s", e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"message": str(e), "hint": "Inspect DB logs for details."},
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while adding agent to library: {e}")
|
||||
logger.exception("Unexpected error while adding agent to library: %s", e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={
|
||||
@@ -196,11 +178,11 @@ async def add_marketplace_agent_to_library(
|
||||
) from e
|
||||
|
||||
|
||||
@router.patch(
|
||||
@router.put(
|
||||
"/{library_agent_id}",
|
||||
summary="Update Library Agent",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
responses={
|
||||
200: {"description": "Agent updated successfully"},
|
||||
204: {"description": "Agent updated successfully"},
|
||||
500: {"description": "Server error"},
|
||||
},
|
||||
)
|
||||
@@ -208,7 +190,7 @@ async def update_library_agent(
|
||||
library_agent_id: str,
|
||||
payload: library_model.LibraryAgentUpdateRequest,
|
||||
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
|
||||
) -> library_model.LibraryAgent:
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
Update the library agent with the given fields.
|
||||
|
||||
@@ -217,76 +199,40 @@ async def update_library_agent(
|
||||
payload: Fields to update (auto_update_version, is_favorite, etc.).
|
||||
user_id: ID of the authenticated user.
|
||||
|
||||
Returns:
|
||||
204 (No Content) on success.
|
||||
|
||||
Raises:
|
||||
HTTPException(500): If a server/database error occurs.
|
||||
"""
|
||||
try:
|
||||
return await library_db.update_library_agent(
|
||||
await library_db.update_library_agent(
|
||||
library_agent_id=library_agent_id,
|
||||
user_id=user_id,
|
||||
auto_update_version=payload.auto_update_version,
|
||||
is_favorite=payload.is_favorite,
|
||||
is_archived=payload.is_archived,
|
||||
is_deleted=payload.is_deleted,
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
content={"message": "Agent updated successfully"},
|
||||
)
|
||||
except NotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
) from e
|
||||
except store_exceptions.DatabaseError as e:
|
||||
logger.error(f"Database error while updating library agent: {e}")
|
||||
logger.exception("Database error while updating library agent: %s", e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"message": str(e), "hint": "Verify DB connection."},
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while updating library agent: {e}")
|
||||
logger.exception("Unexpected error while updating library agent: %s", e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"message": str(e), "hint": "Check server logs."},
|
||||
) from e
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{library_agent_id}",
|
||||
summary="Delete Library Agent",
|
||||
responses={
|
||||
204: {"description": "Agent deleted successfully"},
|
||||
404: {"description": "Agent not found"},
|
||||
500: {"description": "Server error"},
|
||||
},
|
||||
)
|
||||
async def delete_library_agent(
|
||||
library_agent_id: str,
|
||||
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
|
||||
) -> Response:
|
||||
"""
|
||||
Soft-delete the specified library agent.
|
||||
|
||||
Args:
|
||||
library_agent_id: ID of the library agent to delete.
|
||||
user_id: ID of the authenticated user.
|
||||
|
||||
Returns:
|
||||
204 No Content if successful.
|
||||
|
||||
Raises:
|
||||
HTTPException(404): If the agent does not exist.
|
||||
HTTPException(500): If a server/database error occurs.
|
||||
"""
|
||||
try:
|
||||
await library_db.delete_library_agent(
|
||||
library_agent_id=library_agent_id, user_id=user_id
|
||||
)
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
except NotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/{library_agent_id}/fork", summary="Fork Library Agent")
|
||||
@router.post("/{library_agent_id}/fork")
|
||||
async def fork_library_agent(
|
||||
library_agent_id: str,
|
||||
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
|
||||
@@ -295,81 +241,3 @@ async def fork_library_agent(
|
||||
library_agent_id=library_agent_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
|
||||
class TriggeredPresetSetupParams(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
|
||||
trigger_config: dict[str, Any]
|
||||
agent_credentials: dict[str, CredentialsMetaInput] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@router.post("/{library_agent_id}/setup-trigger")
|
||||
async def setup_trigger(
|
||||
library_agent_id: str = Path(..., description="ID of the library agent"),
|
||||
params: TriggeredPresetSetupParams = Body(),
|
||||
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
|
||||
) -> library_model.LibraryAgentPreset:
|
||||
"""
|
||||
Sets up a webhook-triggered `LibraryAgentPreset` for a `LibraryAgent`.
|
||||
Returns the correspondingly created `LibraryAgentPreset` with `webhook_id` set.
|
||||
"""
|
||||
library_agent = await library_db.get_library_agent(
|
||||
id=library_agent_id, user_id=user_id
|
||||
)
|
||||
if not library_agent:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Library agent #{library_agent_id} not found",
|
||||
)
|
||||
|
||||
graph = await get_graph(
|
||||
library_agent.graph_id, version=library_agent.graph_version, user_id=user_id
|
||||
)
|
||||
if not graph:
|
||||
raise HTTPException(
|
||||
status.HTTP_410_GONE,
|
||||
f"Graph #{library_agent.graph_id} not accessible (anymore)",
|
||||
)
|
||||
if not (trigger_node := graph.webhook_input_node):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Graph #{library_agent.graph_id} does not have a webhook node",
|
||||
)
|
||||
|
||||
trigger_config_with_credentials = {
|
||||
**params.trigger_config,
|
||||
**(
|
||||
make_node_credentials_input_map(graph, params.agent_credentials).get(
|
||||
trigger_node.id
|
||||
)
|
||||
or {}
|
||||
),
|
||||
}
|
||||
|
||||
new_webhook, feedback = await setup_webhook_for_block(
|
||||
user_id=user_id,
|
||||
trigger_block=trigger_node.block,
|
||||
trigger_config=trigger_config_with_credentials,
|
||||
)
|
||||
if not new_webhook:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Could not set up webhook: {feedback}",
|
||||
)
|
||||
|
||||
new_preset = await library_db.create_preset(
|
||||
user_id=user_id,
|
||||
preset=library_model.LibraryAgentPresetCreatable(
|
||||
graph_id=library_agent.graph_id,
|
||||
graph_version=library_agent.graph_version,
|
||||
name=params.name,
|
||||
description=params.description,
|
||||
inputs=trigger_config_with_credentials,
|
||||
credentials=params.agent_credentials,
|
||||
webhook_id=new_webhook.id,
|
||||
is_active=True,
|
||||
),
|
||||
)
|
||||
return new_preset
|
||||
|
||||
@@ -1,23 +1,17 @@
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from typing import Annotated, Any, Optional
|
||||
|
||||
import autogpt_libs.auth as autogpt_auth_lib
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
|
||||
|
||||
import backend.server.v2.library.db as db
|
||||
import backend.server.v2.library.model as models
|
||||
from backend.data.graph import get_graph
|
||||
from backend.data.integrations import get_webhook
|
||||
from backend.executor.utils import add_graph_execution, make_node_credentials_input_map
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.webhooks import get_webhook_manager
|
||||
from backend.integrations.webhooks.utils import setup_webhook_for_block
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
credentials_manager = IntegrationCredentialsManager()
|
||||
router = APIRouter(tags=["presets"])
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -55,7 +49,11 @@ async def list_presets(
|
||||
except Exception as e:
|
||||
logger.exception("Failed to list presets for user %s: %s", user_id, e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={
|
||||
"message": str(e),
|
||||
"hint": "Ensure the presets DB table is accessible.",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -83,21 +81,21 @@ async def get_preset(
|
||||
"""
|
||||
try:
|
||||
preset = await db.get_preset(user_id, preset_id)
|
||||
if not preset:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Preset {preset_id} not found",
|
||||
)
|
||||
return preset
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Error retrieving preset %s for user %s: %s", preset_id, user_id, e
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"message": str(e), "hint": "Validate preset ID and retry."},
|
||||
)
|
||||
|
||||
if not preset:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Preset #{preset_id} not found",
|
||||
)
|
||||
return preset
|
||||
|
||||
|
||||
@router.post(
|
||||
"/presets",
|
||||
@@ -134,7 +132,8 @@ async def create_preset(
|
||||
except Exception as e:
|
||||
logger.exception("Preset creation failed for user %s: %s", user_id, e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"message": str(e), "hint": "Check preset payload format."},
|
||||
)
|
||||
|
||||
|
||||
@@ -162,85 +161,17 @@ async def update_preset(
|
||||
Raises:
|
||||
HTTPException: If an error occurs while updating the preset.
|
||||
"""
|
||||
current = await get_preset(preset_id, user_id=user_id)
|
||||
if not current:
|
||||
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Preset #{preset_id} not found")
|
||||
|
||||
graph = await get_graph(
|
||||
current.graph_id,
|
||||
current.graph_version,
|
||||
user_id=user_id,
|
||||
)
|
||||
if not graph:
|
||||
raise HTTPException(
|
||||
status.HTTP_410_GONE,
|
||||
f"Graph #{current.graph_id} not accessible (anymore)",
|
||||
)
|
||||
|
||||
trigger_inputs_updated, new_webhook, feedback = False, None, None
|
||||
if (trigger_node := graph.webhook_input_node) and (
|
||||
preset.inputs is not None and preset.credentials is not None
|
||||
):
|
||||
trigger_config_with_credentials = {
|
||||
**preset.inputs,
|
||||
**(
|
||||
make_node_credentials_input_map(graph, preset.credentials).get(
|
||||
trigger_node.id
|
||||
)
|
||||
or {}
|
||||
),
|
||||
}
|
||||
new_webhook, feedback = await setup_webhook_for_block(
|
||||
user_id=user_id,
|
||||
trigger_block=graph.webhook_input_node.block,
|
||||
trigger_config=trigger_config_with_credentials,
|
||||
for_preset_id=preset_id,
|
||||
)
|
||||
trigger_inputs_updated = True
|
||||
if not new_webhook:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Could not update trigger configuration: {feedback}",
|
||||
)
|
||||
|
||||
try:
|
||||
updated = await db.update_preset(
|
||||
user_id=user_id,
|
||||
preset_id=preset_id,
|
||||
inputs=preset.inputs,
|
||||
credentials=preset.credentials,
|
||||
name=preset.name,
|
||||
description=preset.description,
|
||||
is_active=preset.is_active,
|
||||
return await db.update_preset(
|
||||
user_id=user_id, preset_id=preset_id, preset=preset
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Preset update failed for user %s: %s", user_id, e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"message": str(e), "hint": "Check preset data and try again."},
|
||||
)
|
||||
|
||||
# Update the webhook as well, if necessary
|
||||
if trigger_inputs_updated:
|
||||
updated = await db.set_preset_webhook(
|
||||
user_id, preset_id, new_webhook.id if new_webhook else None
|
||||
)
|
||||
|
||||
# Clean up webhook if it is now unused
|
||||
if current.webhook_id and (
|
||||
current.webhook_id != (new_webhook.id if new_webhook else None)
|
||||
):
|
||||
current_webhook = await get_webhook(current.webhook_id)
|
||||
credentials = (
|
||||
await credentials_manager.get(user_id, current_webhook.credentials_id)
|
||||
if current_webhook.credentials_id
|
||||
else None
|
||||
)
|
||||
await get_webhook_manager(
|
||||
current_webhook.provider
|
||||
).prune_webhook_if_dangling(user_id, current_webhook.id, credentials)
|
||||
|
||||
return updated
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/presets/{preset_id}",
|
||||
@@ -262,28 +193,6 @@ async def delete_preset(
|
||||
Raises:
|
||||
HTTPException: If an error occurs while deleting the preset.
|
||||
"""
|
||||
preset = await db.get_preset(user_id, preset_id)
|
||||
if not preset:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Preset #{preset_id} not found for user #{user_id}",
|
||||
)
|
||||
|
||||
# Detach and clean up the attached webhook, if any
|
||||
if preset.webhook_id:
|
||||
webhook = await get_webhook(preset.webhook_id)
|
||||
await db.set_preset_webhook(user_id, preset_id, None)
|
||||
|
||||
# Clean up webhook if it is now unused
|
||||
credentials = (
|
||||
await credentials_manager.get(user_id, webhook.credentials_id)
|
||||
if webhook.credentials_id
|
||||
else None
|
||||
)
|
||||
await get_webhook_manager(webhook.provider).prune_webhook_if_dangling(
|
||||
user_id, webhook.id, credentials
|
||||
)
|
||||
|
||||
try:
|
||||
await db.delete_preset(user_id, preset_id)
|
||||
except Exception as e:
|
||||
@@ -292,7 +201,7 @@ async def delete_preset(
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e),
|
||||
detail={"message": str(e), "hint": "Ensure preset exists before deleting."},
|
||||
)
|
||||
|
||||
|
||||
@@ -303,20 +212,24 @@ async def delete_preset(
|
||||
description="Execute a preset with the given graph and node input for the current user.",
|
||||
)
|
||||
async def execute_preset(
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
preset_id: str,
|
||||
node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
|
||||
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
|
||||
inputs: dict[str, Any] = Body(..., embed=True, default_factory=dict),
|
||||
) -> dict[str, Any]: # FIXME: add proper return type
|
||||
"""
|
||||
Execute a preset given graph parameters, returning the execution ID on success.
|
||||
|
||||
Args:
|
||||
graph_id (str): ID of the graph to execute.
|
||||
graph_version (int): Version of the graph to execute.
|
||||
preset_id (str): ID of the preset to execute.
|
||||
node_input (Dict[Any, Any]): Input data for the node.
|
||||
user_id (str): ID of the authenticated user.
|
||||
inputs (dict[str, Any]): Optionally, additional input data for the graph execution.
|
||||
|
||||
Returns:
|
||||
{id: graph_exec_id}: A response containing the execution ID.
|
||||
Dict[str, Any]: A response containing the execution ID.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the preset is not found or an error occurs while executing the preset.
|
||||
@@ -326,18 +239,18 @@ async def execute_preset(
|
||||
if not preset:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Preset #{preset_id} not found",
|
||||
detail="Preset not found",
|
||||
)
|
||||
|
||||
# Merge input overrides with preset inputs
|
||||
merged_node_input = preset.inputs | inputs
|
||||
merged_node_input = preset.inputs | node_input
|
||||
|
||||
execution = await add_graph_execution(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
graph_id=preset.graph_id,
|
||||
graph_version=preset.graph_version,
|
||||
preset_id=preset_id,
|
||||
inputs=merged_node_input,
|
||||
preset_id=preset_id,
|
||||
graph_version=graph_version,
|
||||
)
|
||||
|
||||
logger.debug(f"Execution added: {execution} with input: {merged_node_input}")
|
||||
@@ -348,6 +261,9 @@ async def execute_preset(
|
||||
except Exception as e:
|
||||
logger.exception("Preset execution failed for user %s: %s", user_id, e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e),
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"message": str(e),
|
||||
"hint": "Review preset configuration and graph ID.",
|
||||
},
|
||||
)
|
||||
|
||||
@@ -50,8 +50,6 @@ async def test_get_library_agents_success(
|
||||
creator_name="Test Creator",
|
||||
creator_image_url="",
|
||||
input_schema={"type": "object", "properties": {}},
|
||||
credentials_input_schema={"type": "object", "properties": {}},
|
||||
has_external_trigger=False,
|
||||
status=library_model.LibraryAgentStatus.COMPLETED,
|
||||
new_output=False,
|
||||
can_access_graph=True,
|
||||
@@ -68,8 +66,6 @@ async def test_get_library_agents_success(
|
||||
creator_name="Test Creator",
|
||||
creator_image_url="",
|
||||
input_schema={"type": "object", "properties": {}},
|
||||
credentials_input_schema={"type": "object", "properties": {}},
|
||||
has_external_trigger=False,
|
||||
status=library_model.LibraryAgentStatus.COMPLETED,
|
||||
new_output=False,
|
||||
can_access_graph=False,
|
||||
@@ -121,57 +117,26 @@ def test_get_library_agents_error(mocker: pytest_mock.MockFixture):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Mocker Not implemented")
|
||||
def test_add_agent_to_library_success(mocker: pytest_mock.MockFixture):
|
||||
mock_library_agent = library_model.LibraryAgent(
|
||||
id="test-library-agent-id",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
name="Test Agent 1",
|
||||
description="Test Description 1",
|
||||
image_url=None,
|
||||
creator_name="Test Creator",
|
||||
creator_image_url="",
|
||||
input_schema={"type": "object", "properties": {}},
|
||||
credentials_input_schema={"type": "object", "properties": {}},
|
||||
has_external_trigger=False,
|
||||
status=library_model.LibraryAgentStatus.COMPLETED,
|
||||
new_output=False,
|
||||
can_access_graph=True,
|
||||
is_latest_version=True,
|
||||
updated_at=FIXED_NOW,
|
||||
)
|
||||
mock_db_call = mocker.patch("backend.server.v2.library.db.add_agent_to_library")
|
||||
mock_db_call.return_value = None
|
||||
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.server.v2.library.db.add_store_agent_to_library"
|
||||
)
|
||||
mock_db_call.return_value = mock_library_agent
|
||||
|
||||
response = client.post(
|
||||
"/agents", json={"store_listing_version_id": "test-version-id"}
|
||||
)
|
||||
response = client.post("/agents/test-version-id")
|
||||
assert response.status_code == 201
|
||||
|
||||
# Verify the response contains the library agent data
|
||||
data = library_model.LibraryAgent.model_validate(response.json())
|
||||
assert data.id == "test-library-agent-id"
|
||||
assert data.graph_id == "test-agent-1"
|
||||
|
||||
mock_db_call.assert_called_once_with(
|
||||
store_listing_version_id="test-version-id", user_id="test-user-id"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Mocker Not implemented")
|
||||
def test_add_agent_to_library_error(mocker: pytest_mock.MockFixture):
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.server.v2.library.db.add_store_agent_to_library"
|
||||
)
|
||||
mock_db_call = mocker.patch("backend.server.v2.library.db.add_agent_to_library")
|
||||
mock_db_call.side_effect = Exception("Test error")
|
||||
|
||||
response = client.post(
|
||||
"/agents", json={"store_listing_version_id": "test-version-id"}
|
||||
)
|
||||
response = client.post("/agents/test-version-id")
|
||||
assert response.status_code == 500
|
||||
assert "detail" in response.json() # Verify error response structure
|
||||
assert response.json()["detail"] == "Failed to add agent to library"
|
||||
mock_db_call.assert_called_once_with(
|
||||
store_listing_version_id="test-version-id", user_id="test-user-id"
|
||||
)
|
||||
|
||||
@@ -14,10 +14,7 @@ router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/ask",
|
||||
response_model=ApiResponse,
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
summary="Proxy Otto Chat Request",
|
||||
"/ask", response_model=ApiResponse, dependencies=[Depends(auth_middleware)]
|
||||
)
|
||||
async def proxy_otto_request(
|
||||
request: ChatRequest, user_id: str = Depends(get_user_id)
|
||||
|
||||
@@ -259,8 +259,8 @@ def test_ask_otto_unauthenticated(mocker: pytest_mock.MockFixture) -> None:
|
||||
}
|
||||
|
||||
response = client.post("/ask", json=request_data)
|
||||
# When auth is disabled and Otto API URL is not configured, we get 502 (wrapped from 503)
|
||||
assert response.status_code == 502
|
||||
# When auth is disabled and Otto API URL is not configured, we get 503
|
||||
assert response.status_code == 503
|
||||
|
||||
# Restore the override
|
||||
app.dependency_overrides[autogpt_libs.auth.middleware.auth_middleware] = (
|
||||
|
||||
@@ -93,14 +93,6 @@ async def test_get_store_agent_details(mocker):
|
||||
mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
|
||||
mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
|
||||
|
||||
# Mock Profile prisma call
|
||||
mock_profile = mocker.MagicMock()
|
||||
mock_profile.userId = "user-id-123"
|
||||
mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
|
||||
mock_profile_db.return_value.find_first = mocker.AsyncMock(
|
||||
return_value=mock_profile
|
||||
)
|
||||
|
||||
# Mock StoreListing prisma call - this is what was missing
|
||||
mock_store_listing_db = mocker.patch("prisma.models.StoreListing.prisma")
|
||||
mock_store_listing_db.return_value.find_first = mocker.AsyncMock(
|
||||
|
||||
@@ -34,20 +34,6 @@ class StorageUploadError(MediaUploadError):
|
||||
pass
|
||||
|
||||
|
||||
class VirusDetectedError(MediaUploadError):
|
||||
"""Raised when a virus is detected in uploaded file"""
|
||||
|
||||
def __init__(self, threat_name: str, message: str | None = None):
|
||||
self.threat_name = threat_name
|
||||
super().__init__(message or f"Virus detected: {threat_name}")
|
||||
|
||||
|
||||
class VirusScanError(MediaUploadError):
|
||||
"""Raised when virus scanning fails"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class StoreError(Exception):
|
||||
"""Base exception for store-related errors"""
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ from google.cloud import storage
|
||||
import backend.server.v2.store.exceptions
|
||||
from backend.util.exceptions import MissingConfigError
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -68,7 +67,7 @@ async def upload_media(
|
||||
# Validate file signature/magic bytes
|
||||
if file.content_type in ALLOWED_IMAGE_TYPES:
|
||||
# Check image file signatures
|
||||
if content.startswith(b"\xff\xd8\xff"): # JPEG
|
||||
if content.startswith(b"\xFF\xD8\xFF"): # JPEG
|
||||
if file.content_type != "image/jpeg":
|
||||
raise backend.server.v2.store.exceptions.InvalidFileTypeError(
|
||||
"File signature does not match content type"
|
||||
@@ -176,7 +175,6 @@ async def upload_media(
|
||||
blob.content_type = content_type
|
||||
|
||||
file_bytes = await file.read()
|
||||
await scan_content_safe(file_bytes, filename=unique_filename)
|
||||
blob.upload_from_string(file_bytes, content_type=content_type)
|
||||
|
||||
public_url = blob.public_url
|
||||
|
||||
@@ -12,7 +12,6 @@ from autogpt_libs.auth.depends import auth_middleware, get_user_id
|
||||
import backend.data.block
|
||||
import backend.data.graph
|
||||
import backend.server.v2.store.db
|
||||
import backend.server.v2.store.exceptions
|
||||
import backend.server.v2.store.image_gen
|
||||
import backend.server.v2.store.media
|
||||
import backend.server.v2.store.model
|
||||
@@ -30,7 +29,6 @@ router = fastapi.APIRouter()
|
||||
|
||||
@router.get(
|
||||
"/profile",
|
||||
summary="Get user profile",
|
||||
tags=["store", "private"],
|
||||
response_model=backend.server.v2.store.model.ProfileDetails,
|
||||
)
|
||||
@@ -63,7 +61,6 @@ async def get_profile(
|
||||
|
||||
@router.post(
|
||||
"/profile",
|
||||
summary="Update user profile",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
||||
response_model=backend.server.v2.store.model.CreatorDetails,
|
||||
@@ -110,7 +107,6 @@ async def update_or_create_profile(
|
||||
|
||||
@router.get(
|
||||
"/agents",
|
||||
summary="List store agents",
|
||||
tags=["store", "public"],
|
||||
response_model=backend.server.v2.store.model.StoreAgentsResponse,
|
||||
)
|
||||
@@ -183,7 +179,6 @@ async def get_agents(
|
||||
|
||||
@router.get(
|
||||
"/agents/{username}/{agent_name}",
|
||||
summary="Get specific agent",
|
||||
tags=["store", "public"],
|
||||
response_model=backend.server.v2.store.model.StoreAgentDetails,
|
||||
)
|
||||
@@ -213,7 +208,6 @@ async def get_agent(username: str, agent_name: str):
|
||||
|
||||
@router.get(
|
||||
"/graph/{store_listing_version_id}",
|
||||
summary="Get agent graph",
|
||||
tags=["store"],
|
||||
)
|
||||
async def get_graph_meta_by_store_listing_version_id(
|
||||
@@ -238,7 +232,6 @@ async def get_graph_meta_by_store_listing_version_id(
|
||||
|
||||
@router.get(
|
||||
"/agents/{store_listing_version_id}",
|
||||
summary="Get agent by version",
|
||||
tags=["store"],
|
||||
response_model=backend.server.v2.store.model.StoreAgentDetails,
|
||||
)
|
||||
@@ -264,7 +257,6 @@ async def get_store_agent(
|
||||
|
||||
@router.post(
|
||||
"/agents/{username}/{agent_name}/review",
|
||||
summary="Create agent review",
|
||||
tags=["store"],
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
||||
response_model=backend.server.v2.store.model.StoreReview,
|
||||
@@ -316,7 +308,6 @@ async def create_review(
|
||||
|
||||
@router.get(
|
||||
"/creators",
|
||||
summary="List store creators",
|
||||
tags=["store", "public"],
|
||||
response_model=backend.server.v2.store.model.CreatorsResponse,
|
||||
)
|
||||
@@ -368,7 +359,6 @@ async def get_creators(
|
||||
|
||||
@router.get(
|
||||
"/creator/{username}",
|
||||
summary="Get creator details",
|
||||
tags=["store", "public"],
|
||||
response_model=backend.server.v2.store.model.CreatorDetails,
|
||||
)
|
||||
@@ -400,7 +390,6 @@ async def get_creator(
|
||||
############################################
|
||||
@router.get(
|
||||
"/myagents",
|
||||
summary="Get my agents",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
||||
response_model=backend.server.v2.store.model.MyAgentsResponse,
|
||||
@@ -423,7 +412,6 @@ async def get_my_agents(
|
||||
|
||||
@router.delete(
|
||||
"/submissions/{submission_id}",
|
||||
summary="Delete store submission",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
||||
response_model=bool,
|
||||
@@ -460,7 +448,6 @@ async def delete_submission(
|
||||
|
||||
@router.get(
|
||||
"/submissions",
|
||||
summary="List my submissions",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
||||
response_model=backend.server.v2.store.model.StoreSubmissionsResponse,
|
||||
@@ -514,7 +501,6 @@ async def get_submissions(
|
||||
|
||||
@router.post(
|
||||
"/submissions",
|
||||
summary="Create store submission",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
||||
response_model=backend.server.v2.store.model.StoreSubmission,
|
||||
@@ -562,7 +548,6 @@ async def create_submission(
|
||||
|
||||
@router.post(
|
||||
"/submissions/media",
|
||||
summary="Upload submission media",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
||||
)
|
||||
@@ -590,25 +575,6 @@ async def upload_submission_media(
|
||||
user_id=user_id, file=file
|
||||
)
|
||||
return media_url
|
||||
except backend.server.v2.store.exceptions.VirusDetectedError as e:
|
||||
logger.warning(f"Virus detected in uploaded file: {e.threat_name}")
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"detail": f"File rejected due to virus detection: {e.threat_name}",
|
||||
"error_type": "virus_detected",
|
||||
"threat_name": e.threat_name,
|
||||
},
|
||||
)
|
||||
except backend.server.v2.store.exceptions.VirusScanError as e:
|
||||
logger.error(f"Virus scanning failed: {str(e)}")
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=503,
|
||||
content={
|
||||
"detail": "Virus scanning service unavailable. Please try again later.",
|
||||
"error_type": "virus_scan_failed",
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Exception occurred whilst uploading submission media")
|
||||
return fastapi.responses.JSONResponse(
|
||||
@@ -619,7 +585,6 @@ async def upload_submission_media(
|
||||
|
||||
@router.post(
|
||||
"/submissions/generate_image",
|
||||
summary="Generate submission image",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
||||
)
|
||||
@@ -681,7 +646,6 @@ async def generate_image(
|
||||
|
||||
@router.get(
|
||||
"/download/agents/{store_listing_version_id}",
|
||||
summary="Download agent file",
|
||||
tags=["store", "public"],
|
||||
)
|
||||
async def download_agent_file(
|
||||
|
||||
@@ -13,9 +13,7 @@ router = APIRouter()
|
||||
settings = Settings()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/verify", response_model=TurnstileVerifyResponse, summary="Verify Turnstile Token"
|
||||
)
|
||||
@router.post("/verify", response_model=TurnstileVerifyResponse)
|
||||
async def verify_turnstile_token(
|
||||
request: TurnstileVerifyRequest,
|
||||
) -> TurnstileVerifyResponse:
|
||||
|
||||
@@ -2,17 +2,7 @@ import functools
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Literal,
|
||||
ParamSpec,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
overload,
|
||||
)
|
||||
from typing import Any, Awaitable, Callable, Coroutine, ParamSpec, Tuple, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -82,115 +72,37 @@ def async_time_measured(
|
||||
return async_wrapper
|
||||
|
||||
|
||||
@overload
|
||||
def error_logged(
|
||||
*, swallow: Literal[True]
|
||||
) -> Callable[[Callable[P, T]], Callable[P, T | None]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def error_logged(
|
||||
*, swallow: Literal[False]
|
||||
) -> Callable[[Callable[P, T]], Callable[P, T]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def error_logged() -> Callable[[Callable[P, T]], Callable[P, T | None]]: ...
|
||||
|
||||
|
||||
def error_logged(
|
||||
*, swallow: bool = True
|
||||
) -> (
|
||||
Callable[[Callable[P, T]], Callable[P, T | None]]
|
||||
| Callable[[Callable[P, T]], Callable[P, T]]
|
||||
):
|
||||
def error_logged(func: Callable[P, T]) -> Callable[P, T | None]:
|
||||
"""
|
||||
Decorator to log any exceptions raised by a function, with optional suppression.
|
||||
|
||||
Args:
|
||||
swallow: Whether to suppress the exception (True) or re-raise it (False). Default is True.
|
||||
|
||||
Usage:
|
||||
@error_logged() # Default behavior (swallow errors)
|
||||
@error_logged(swallow=False) # Log and re-raise errors
|
||||
Decorator to suppress and log any exceptions raised by a function.
|
||||
"""
|
||||
|
||||
def decorator(f: Callable[P, T]) -> Callable[P, T | None]:
|
||||
@functools.wraps(f)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T | None:
|
||||
try:
|
||||
return f(*args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Error when calling function {f.__name__} with arguments {args} {kwargs}: {e}"
|
||||
)
|
||||
if not swallow:
|
||||
raise
|
||||
return None
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T | None:
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Error when calling function {func.__name__} with arguments {args} {kwargs}: {e}"
|
||||
)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
return wrapper
|
||||
|
||||
|
||||
@overload
|
||||
def async_error_logged(
|
||||
*, swallow: Literal[True]
|
||||
) -> Callable[
|
||||
[Callable[P, Coroutine[Any, Any, T]]], Callable[P, Coroutine[Any, Any, T | None]]
|
||||
]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def async_error_logged(
|
||||
*, swallow: Literal[False]
|
||||
) -> Callable[
|
||||
[Callable[P, Coroutine[Any, Any, T]]], Callable[P, Coroutine[Any, Any, T]]
|
||||
]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def async_error_logged() -> Callable[
|
||||
[Callable[P, Coroutine[Any, Any, T]]],
|
||||
Callable[P, Coroutine[Any, Any, T | None]],
|
||||
]: ...
|
||||
|
||||
|
||||
def async_error_logged(*, swallow: bool = True) -> (
|
||||
Callable[
|
||||
[Callable[P, Coroutine[Any, Any, T]]],
|
||||
Callable[P, Coroutine[Any, Any, T | None]],
|
||||
]
|
||||
| Callable[
|
||||
[Callable[P, Coroutine[Any, Any, T]]], Callable[P, Coroutine[Any, Any, T]]
|
||||
]
|
||||
):
|
||||
func: Callable[P, Coroutine[Any, Any, T]],
|
||||
) -> Callable[P, Coroutine[Any, Any, T | None]]:
|
||||
"""
|
||||
Decorator to log any exceptions raised by an async function, with optional suppression.
|
||||
|
||||
Args:
|
||||
swallow: Whether to suppress the exception (True) or re-raise it (False). Default is True.
|
||||
|
||||
Usage:
|
||||
@async_error_logged() # Default behavior (swallow errors)
|
||||
@async_error_logged(swallow=False) # Log and re-raise errors
|
||||
Decorator to suppress and log any exceptions raised by an async function.
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
f: Callable[P, Coroutine[Any, Any, T]]
|
||||
) -> Callable[P, Coroutine[Any, Any, T | None]]:
|
||||
@functools.wraps(f)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T | None:
|
||||
try:
|
||||
return await f(*args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Error when calling async function {f.__name__} with arguments {args} {kwargs}: {e}"
|
||||
)
|
||||
if not swallow:
|
||||
raise
|
||||
return None
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T | None:
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Error when calling async function {func.__name__} with arguments {args} {kwargs}: {e}"
|
||||
)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
return wrapper
|
||||
|
||||
@@ -1,74 +0,0 @@
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util.decorator import async_error_logged, error_logged, time_measured
|
||||
|
||||
|
||||
@time_measured
|
||||
def example_function(a: int, b: int, c: int) -> int:
|
||||
time.sleep(0.5)
|
||||
return a + b + c
|
||||
|
||||
|
||||
@error_logged(swallow=True)
|
||||
def example_function_with_error_swallowed(a: int, b: int, c: int) -> int:
|
||||
raise ValueError("This error should be swallowed")
|
||||
|
||||
|
||||
@error_logged(swallow=False)
|
||||
def example_function_with_error_not_swallowed(a: int, b: int, c: int) -> int:
|
||||
raise ValueError("This error should NOT be swallowed")
|
||||
|
||||
|
||||
@async_error_logged(swallow=True)
|
||||
async def async_function_with_error_swallowed() -> int:
|
||||
raise ValueError("This async error should be swallowed")
|
||||
|
||||
|
||||
@async_error_logged(swallow=False)
|
||||
async def async_function_with_error_not_swallowed() -> int:
|
||||
raise ValueError("This async error should NOT be swallowed")
|
||||
|
||||
|
||||
def test_timer_decorator():
|
||||
"""Test that the time_measured decorator correctly measures execution time."""
|
||||
info, res = example_function(1, 2, 3)
|
||||
assert info.cpu_time >= 0
|
||||
assert info.wall_time >= 0.4
|
||||
assert res == 6
|
||||
|
||||
|
||||
def test_error_decorator_swallow_true():
|
||||
"""Test that error_logged(swallow=True) logs and swallows errors."""
|
||||
res = example_function_with_error_swallowed(1, 2, 3)
|
||||
assert res is None
|
||||
|
||||
|
||||
def test_error_decorator_swallow_false():
|
||||
"""Test that error_logged(swallow=False) logs errors but re-raises them."""
|
||||
with pytest.raises(ValueError, match="This error should NOT be swallowed"):
|
||||
example_function_with_error_not_swallowed(1, 2, 3)
|
||||
|
||||
|
||||
def test_async_error_decorator_swallow_true():
|
||||
"""Test that async_error_logged(swallow=True) logs and swallows errors."""
|
||||
import asyncio
|
||||
|
||||
async def run_test():
|
||||
res = await async_function_with_error_swallowed()
|
||||
return res
|
||||
|
||||
res = asyncio.run(run_test())
|
||||
assert res is None
|
||||
|
||||
|
||||
def test_async_error_decorator_swallow_false():
|
||||
"""Test that async_error_logged(swallow=False) logs errors but re-raises them."""
|
||||
import asyncio
|
||||
|
||||
async def run_test():
|
||||
await async_function_with_error_not_swallowed()
|
||||
|
||||
with pytest.raises(ValueError, match="This async error should NOT be swallowed"):
|
||||
asyncio.run(run_test())
|
||||
@@ -10,10 +10,6 @@ class NeedConfirmation(Exception):
|
||||
"""The user must explicitly confirm that they want to proceed"""
|
||||
|
||||
|
||||
class NotAuthorizedError(ValueError):
|
||||
"""The user is not authorized to perform the requested operation"""
|
||||
|
||||
|
||||
class InsufficientBalanceError(ValueError):
|
||||
user_id: str
|
||||
message: str
|
||||
|
||||
@@ -9,7 +9,6 @@ from urllib.parse import urlparse
|
||||
|
||||
from backend.util.request import Requests
|
||||
from backend.util.type import MediaFileType
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
|
||||
TEMP_DIR = Path(tempfile.gettempdir()).resolve()
|
||||
|
||||
@@ -106,11 +105,7 @@ async def store_media_file(
|
||||
extension = _extension_from_mime(mime_type)
|
||||
filename = f"{uuid.uuid4()}{extension}"
|
||||
target_path = _ensure_inside_base(base_path / filename, base_path)
|
||||
content = base64.b64decode(b64_content)
|
||||
|
||||
# Virus scan the base64 content before writing
|
||||
await scan_content_safe(content, filename=filename)
|
||||
target_path.write_bytes(content)
|
||||
target_path.write_bytes(base64.b64decode(b64_content))
|
||||
|
||||
elif file.startswith(("http://", "https://")):
|
||||
# URL
|
||||
@@ -120,9 +115,6 @@ async def store_media_file(
|
||||
|
||||
# Download and save
|
||||
resp = await Requests().get(file)
|
||||
|
||||
# Virus scan the downloaded content before writing
|
||||
await scan_content_safe(resp.content, filename=filename)
|
||||
target_path.write_bytes(resp.content)
|
||||
|
||||
else:
|
||||
|
||||
@@ -14,37 +14,8 @@ def to_dict(data) -> dict:
|
||||
return jsonable_encoder(data)
|
||||
|
||||
|
||||
def dumps(data: Any, *args: Any, **kwargs: Any) -> str:
|
||||
"""
|
||||
Serialize data to JSON string with automatic conversion of Pydantic models and complex types.
|
||||
|
||||
This function converts the input data to a JSON-serializable format using FastAPI's
|
||||
jsonable_encoder before dumping to JSON. It handles Pydantic models, complex types,
|
||||
and ensures proper serialization.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : Any
|
||||
The data to serialize. Can be any type including Pydantic models, dicts, lists, etc.
|
||||
*args : Any
|
||||
Additional positional arguments passed to json.dumps()
|
||||
**kwargs : Any
|
||||
Additional keyword arguments passed to json.dumps() (e.g., indent, separators)
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
JSON string representation of the data
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> dumps({"name": "Alice", "age": 30})
|
||||
'{"name": "Alice", "age": 30}'
|
||||
|
||||
>>> dumps(pydantic_model_instance, indent=2)
|
||||
'{\n "field1": "value1",\n "field2": "value2"\n}'
|
||||
"""
|
||||
return json.dumps(to_dict(data), *args, **kwargs)
|
||||
def dumps(data) -> str:
|
||||
return json.dumps(to_dict(data))
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import logging
|
||||
from logging import Logger
|
||||
|
||||
from backend.util.settings import AppEnvironment, BehaveAs, Settings
|
||||
|
||||
@@ -6,6 +6,8 @@ settings = Settings()
|
||||
|
||||
|
||||
def configure_logging():
|
||||
import logging
|
||||
|
||||
import autogpt_libs.logging.config
|
||||
|
||||
if (
|
||||
@@ -23,7 +25,7 @@ def configure_logging():
|
||||
class TruncatedLogger:
|
||||
def __init__(
|
||||
self,
|
||||
logger: logging.Logger,
|
||||
logger: Logger,
|
||||
prefix: str = "",
|
||||
metadata: dict | None = None,
|
||||
max_length: int = 1000,
|
||||
@@ -63,13 +65,3 @@ class TruncatedLogger:
|
||||
if len(text) > self.max_length:
|
||||
text = text[: self.max_length] + "..."
|
||||
return text
|
||||
|
||||
|
||||
class PrefixFilter(logging.Filter):
|
||||
def __init__(self, prefix: str):
|
||||
super().__init__()
|
||||
self.prefix = prefix
|
||||
|
||||
def filter(self, record):
|
||||
record.msg = f"{self.prefix} {record.msg}"
|
||||
return True
|
||||
|
||||
@@ -1,206 +0,0 @@
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from tiktoken import encoding_for_model
|
||||
|
||||
from backend.util import json
|
||||
|
||||
# ---------------------------------------------------------------------------#
|
||||
# INTERNAL UTILITIES #
|
||||
# ---------------------------------------------------------------------------#
|
||||
|
||||
|
||||
def _tok_len(text: str, enc) -> int:
|
||||
"""True token length of *text* in tokenizer *enc* (no wrapper cost)."""
|
||||
return len(enc.encode(str(text)))
|
||||
|
||||
|
||||
def _msg_tokens(msg: dict, enc) -> int:
|
||||
"""
|
||||
OpenAI counts ≈3 wrapper tokens per chat message, plus 1 if "name"
|
||||
is present, plus the tokenised content length.
|
||||
"""
|
||||
WRAPPER = 3 + (1 if "name" in msg else 0)
|
||||
return WRAPPER + _tok_len(msg.get("content") or "", enc)
|
||||
|
||||
|
||||
def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
|
||||
"""
|
||||
Return *text* shortened to ≈max_tok tokens by keeping the head & tail
|
||||
and inserting an ellipsis token in the middle.
|
||||
"""
|
||||
ids = enc.encode(str(text))
|
||||
if len(ids) <= max_tok:
|
||||
return text # nothing to do
|
||||
|
||||
# Split the allowance between the two ends:
|
||||
head = max_tok // 2 - 1 # -1 for the ellipsis
|
||||
tail = max_tok - head - 1
|
||||
mid = enc.encode(" … ")
|
||||
return enc.decode(ids[:head] + mid + ids[-tail:])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------#
|
||||
# PUBLIC API #
|
||||
# ---------------------------------------------------------------------------#
|
||||
|
||||
|
||||
def compress_prompt(
|
||||
messages: list[dict],
|
||||
target_tokens: int,
|
||||
*,
|
||||
model: str = "gpt-4o",
|
||||
reserve: int = 2_048,
|
||||
start_cap: int = 8_192,
|
||||
floor_cap: int = 128,
|
||||
lossy_ok: bool = True,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Shrink *messages* so that::
|
||||
|
||||
token_count(prompt) + reserve ≤ target_tokens
|
||||
|
||||
Strategy
|
||||
--------
|
||||
1. **Token-aware truncation** – progressively halve a per-message cap
|
||||
(`start_cap`, `start_cap/2`, … `floor_cap`) and apply it to the
|
||||
*content* of every message except the first and last. Tool shells
|
||||
are included: we keep the envelope but shorten huge payloads.
|
||||
2. **Middle-out deletion** – if still over the limit, delete whole
|
||||
messages working outward from the centre, **skipping** any message
|
||||
that contains ``tool_calls`` or has ``role == "tool"``.
|
||||
3. **Last-chance trim** – if still too big, truncate the *first* and
|
||||
*last* message bodies down to `floor_cap` tokens.
|
||||
4. If the prompt is *still* too large:
|
||||
• raise ``ValueError`` when ``lossy_ok == False`` (default)
|
||||
• return the partially-trimmed prompt when ``lossy_ok == True``
|
||||
|
||||
Parameters
|
||||
----------
|
||||
messages Complete chat history (will be deep-copied).
|
||||
model Model name; passed to tiktoken to pick the right
|
||||
tokenizer (gpt-4o → 'o200k_base', others fallback).
|
||||
target_tokens Hard ceiling for prompt size **excluding** the model's
|
||||
forthcoming answer.
|
||||
reserve How many tokens you want to leave available for that
|
||||
answer (`max_tokens` in your subsequent completion call).
|
||||
start_cap Initial per-message truncation ceiling (tokens).
|
||||
floor_cap Lowest cap we'll accept before moving to deletions.
|
||||
lossy_ok If *True* return best-effort prompt instead of raising
|
||||
after all trim passes have been exhausted.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[dict] – A *new* messages list that abides by the rules above.
|
||||
"""
|
||||
enc = encoding_for_model(model) # best-match tokenizer
|
||||
msgs = deepcopy(messages) # never mutate caller
|
||||
|
||||
def total_tokens() -> int:
|
||||
"""Current size of *msgs* in tokens."""
|
||||
return sum(_msg_tokens(m, enc) for m in msgs)
|
||||
|
||||
original_token_count = total_tokens()
|
||||
if original_token_count + reserve <= target_tokens:
|
||||
return msgs
|
||||
|
||||
# ---- STEP 0 : normalise content --------------------------------------
|
||||
# Convert non-string payloads to strings so token counting is coherent.
|
||||
for m in msgs[1:-1]: # keep the first & last intact
|
||||
if not isinstance(m.get("content"), str) and m.get("content") is not None:
|
||||
# Reasonable 20k-char ceiling prevents pathological blobs
|
||||
content_str = json.dumps(m["content"], separators=(",", ":"))
|
||||
if len(content_str) > 20_000:
|
||||
content_str = _truncate_middle_tokens(content_str, enc, 20_000)
|
||||
m["content"] = content_str
|
||||
|
||||
# ---- STEP 1 : token-aware truncation ---------------------------------
|
||||
cap = start_cap
|
||||
while total_tokens() + reserve > target_tokens and cap >= floor_cap:
|
||||
for m in msgs[1:-1]: # keep first & last intact
|
||||
if _tok_len(m.get("content") or "", enc) > cap:
|
||||
m["content"] = _truncate_middle_tokens(m["content"], enc, cap)
|
||||
cap //= 2 # tighten the screw
|
||||
|
||||
# ---- STEP 2 : middle-out deletion -----------------------------------
|
||||
while total_tokens() + reserve > target_tokens and len(msgs) > 2:
|
||||
centre = len(msgs) // 2
|
||||
# Build a symmetrical centre-out index walk: centre, centre+1, centre-1, ...
|
||||
order = [centre] + [
|
||||
i
|
||||
for pair in zip(range(centre + 1, len(msgs) - 1), range(centre - 1, 0, -1))
|
||||
for i in pair
|
||||
]
|
||||
removed = False
|
||||
for i in order:
|
||||
msg = msgs[i]
|
||||
if "tool_calls" in msg or msg.get("role") == "tool":
|
||||
continue # protect tool shells
|
||||
del msgs[i]
|
||||
removed = True
|
||||
break
|
||||
if not removed: # nothing more we can drop
|
||||
break
|
||||
|
||||
# ---- STEP 3 : final safety-net trim on first & last ------------------
|
||||
cap = start_cap
|
||||
while total_tokens() + reserve > target_tokens and cap >= floor_cap:
|
||||
for idx in (0, -1): # first and last
|
||||
text = msgs[idx].get("content") or ""
|
||||
if _tok_len(text, enc) > cap:
|
||||
msgs[idx]["content"] = _truncate_middle_tokens(text, enc, cap)
|
||||
cap //= 2 # tighten the screw
|
||||
|
||||
# ---- STEP 4 : success or fail-gracefully -----------------------------
|
||||
if total_tokens() + reserve > target_tokens and not lossy_ok:
|
||||
raise ValueError(
|
||||
"compress_prompt: prompt still exceeds budget "
|
||||
f"({total_tokens() + reserve} > {target_tokens})."
|
||||
)
|
||||
|
||||
return msgs
|
||||
|
||||
|
||||
def estimate_token_count(
|
||||
messages: list[dict],
|
||||
*,
|
||||
model: str = "gpt-4o",
|
||||
) -> int:
|
||||
"""
|
||||
Return the true token count of *messages* when encoded for *model*.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
messages Complete chat history.
|
||||
model Model name; passed to tiktoken to pick the right
|
||||
tokenizer (gpt-4o → 'o200k_base', others fallback).
|
||||
|
||||
Returns
|
||||
-------
|
||||
int – Token count.
|
||||
"""
|
||||
enc = encoding_for_model(model) # best-match tokenizer
|
||||
return sum(_msg_tokens(m, enc) for m in messages)
|
||||
|
||||
|
||||
def estimate_token_count_str(
|
||||
text: Any,
|
||||
*,
|
||||
model: str = "gpt-4o",
|
||||
) -> int:
|
||||
"""
|
||||
Return the true token count of *text* when encoded for *model*.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
text Input text.
|
||||
model Model name; passed to tiktoken to pick the right
|
||||
tokenizer (gpt-4o → 'o200k_base', others fallback).
|
||||
|
||||
Returns
|
||||
-------
|
||||
int – Token count.
|
||||
"""
|
||||
enc = encoding_for_model(model) # best-match tokenizer
|
||||
text = json.dumps(text) if not isinstance(text, str) else text
|
||||
return _tok_len(text, enc)
|
||||
@@ -353,10 +353,6 @@ class Requests:
|
||||
max_redirects: int = 10,
|
||||
**kwargs,
|
||||
) -> Response:
|
||||
# Convert auth tuple to aiohttp.BasicAuth if necessary
|
||||
if "auth" in kwargs and isinstance(kwargs["auth"], tuple):
|
||||
kwargs["auth"] = aiohttp.BasicAuth(*kwargs["auth"])
|
||||
|
||||
if files is not None:
|
||||
if json is not None:
|
||||
raise ValueError(
|
||||
@@ -434,13 +430,7 @@ class Requests:
|
||||
) as response:
|
||||
|
||||
if self.raise_for_status:
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except ClientResponseError as e:
|
||||
body = await response.read()
|
||||
raise Exception(
|
||||
f"HTTP {response.status} Error: {response.reason}, Body: {body.decode(errors='replace')}"
|
||||
) from e
|
||||
response.raise_for_status()
|
||||
|
||||
# If allowed and a redirect is received, follow the redirect manually
|
||||
if allow_redirects and response.status in (301, 302, 303, 307, 308):
|
||||
|
||||
@@ -31,7 +31,7 @@ from tenacity import (
|
||||
wait_exponential_jitter,
|
||||
)
|
||||
|
||||
import backend.util.exceptions as exceptions
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
from backend.util.json import to_dict
|
||||
from backend.util.metrics import sentry_init
|
||||
from backend.util.process import AppProcess, get_service_name
|
||||
@@ -106,13 +106,7 @@ EXCEPTION_MAPPING = {
|
||||
ValueError,
|
||||
TimeoutError,
|
||||
ConnectionError,
|
||||
*[
|
||||
ErrorType
|
||||
for _, ErrorType in inspect.getmembers(exceptions)
|
||||
if inspect.isclass(ErrorType)
|
||||
and issubclass(ErrorType, Exception)
|
||||
and ErrorType.__module__ == exceptions.__name__
|
||||
],
|
||||
InsufficientBalanceError,
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@@ -238,31 +238,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="The Discord channel for the platform",
|
||||
)
|
||||
|
||||
clamav_service_host: str = Field(
|
||||
default="localhost",
|
||||
description="The host for the ClamAV daemon",
|
||||
)
|
||||
clamav_service_port: int = Field(
|
||||
default=3310,
|
||||
description="The port for the ClamAV daemon",
|
||||
)
|
||||
clamav_service_timeout: int = Field(
|
||||
default=60,
|
||||
description="The timeout in seconds for the ClamAV daemon",
|
||||
)
|
||||
clamav_service_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Whether virus scanning is enabled or not",
|
||||
)
|
||||
clamav_max_concurrency: int = Field(
|
||||
default=10,
|
||||
description="The maximum number of concurrent scans to perform",
|
||||
)
|
||||
clamav_mark_failed_scans_as_clean: bool = Field(
|
||||
default=False,
|
||||
description="Whether to mark failed scans as clean or not",
|
||||
)
|
||||
|
||||
@field_validator("platform_base_url", "frontend_base_url")
|
||||
@classmethod
|
||||
def validate_platform_base_url(cls, v: str, info: ValidationInfo) -> str:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import json
|
||||
import types
|
||||
from typing import Any, Type, TypeVar, Union, cast, get_args, get_origin, overload
|
||||
from typing import Any, Type, TypeVar, cast, get_args, get_origin
|
||||
|
||||
from prisma import Json as PrismaJson
|
||||
|
||||
@@ -105,37 +104,9 @@ def __convert_bool(value: Any) -> bool:
|
||||
return bool(value)
|
||||
|
||||
|
||||
def _try_convert(value: Any, target_type: Any, raise_on_mismatch: bool) -> Any:
|
||||
def _try_convert(value: Any, target_type: Type, raise_on_mismatch: bool) -> Any:
|
||||
origin = get_origin(target_type)
|
||||
args = get_args(target_type)
|
||||
|
||||
# Handle Union types (including Optional which is Union[T, None])
|
||||
if origin is Union or origin is types.UnionType:
|
||||
# Handle None values for Optional types
|
||||
if value is None:
|
||||
if type(None) in args:
|
||||
return None
|
||||
elif raise_on_mismatch:
|
||||
raise TypeError(f"Value {value} is not of expected type {target_type}")
|
||||
else:
|
||||
return value
|
||||
|
||||
# Try to convert to each type in the union, excluding None
|
||||
non_none_types = [arg for arg in args if arg is not type(None)]
|
||||
|
||||
# Try each type in the union, using the original raise_on_mismatch behavior
|
||||
for arg_type in non_none_types:
|
||||
try:
|
||||
return _try_convert(value, arg_type, raise_on_mismatch)
|
||||
except (TypeError, ValueError, ConversionError):
|
||||
continue
|
||||
|
||||
# If no conversion succeeded
|
||||
if raise_on_mismatch:
|
||||
raise TypeError(f"Value {value} is not of expected type {target_type}")
|
||||
else:
|
||||
return value
|
||||
|
||||
if origin is None:
|
||||
origin = target_type
|
||||
if origin not in [list, dict, tuple, str, set, int, float, bool]:
|
||||
@@ -218,19 +189,11 @@ def type_match(value: Any, target_type: Type[T]) -> T:
|
||||
return cast(T, _try_convert(value, target_type, raise_on_mismatch=True))
|
||||
|
||||
|
||||
@overload
|
||||
def convert(value: Any, target_type: Type[T]) -> T: ...
|
||||
|
||||
|
||||
@overload
|
||||
def convert(value: Any, target_type: Any) -> Any: ...
|
||||
|
||||
|
||||
def convert(value: Any, target_type: Any) -> Any:
|
||||
def convert(value: Any, target_type: Type[T]) -> T:
|
||||
try:
|
||||
if isinstance(value, PrismaJson):
|
||||
value = value.data
|
||||
return _try_convert(value, target_type, raise_on_mismatch=False)
|
||||
return cast(T, _try_convert(value, target_type, raise_on_mismatch=False))
|
||||
except Exception as e:
|
||||
raise ConversionError(f"Failed to convert {value} to {target_type}") from e
|
||||
|
||||
@@ -240,7 +203,6 @@ class FormattedStringType(str):
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type, handler):
|
||||
_ = source_type # unused parameter required by pydantic
|
||||
return handler(str)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -1,209 +0,0 @@
|
||||
import asyncio
|
||||
import io
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import aioclamd
|
||||
from pydantic import BaseModel
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
|
||||
class VirusScanResult(BaseModel):
|
||||
is_clean: bool
|
||||
scan_time_ms: int
|
||||
file_size: int
|
||||
threat_name: Optional[str] = None
|
||||
|
||||
|
||||
class VirusScannerSettings(BaseSettings):
|
||||
# Tunables for the scanner layer (NOT the ClamAV daemon).
|
||||
clamav_service_host: str = "localhost"
|
||||
clamav_service_port: int = 3310
|
||||
clamav_service_timeout: int = 60
|
||||
clamav_service_enabled: bool = True
|
||||
# If the service is disabled, all files are considered clean.
|
||||
mark_failed_scans_as_clean: bool = False
|
||||
# Client-side protective limits
|
||||
max_scan_size: int = 2 * 1024 * 1024 * 1024 # 2 GB guard-rail in memory
|
||||
min_chunk_size: int = 128 * 1024 # 128 KB hard floor
|
||||
max_retries: int = 8 # halve ≤ max_retries times
|
||||
# Concurrency throttle toward the ClamAV daemon. Do *NOT* simply turn this
|
||||
# up to the number of CPU cores; keep it ≤ (MaxThreads / pods) – 1.
|
||||
max_concurrency: int = 5
|
||||
|
||||
|
||||
class VirusScannerService:
|
||||
"""Fully-async ClamAV wrapper using **aioclamd**.
|
||||
|
||||
• Reuses a single `ClamdAsyncClient` connection (aioclamd keeps the socket open).
|
||||
• Throttles concurrent `INSTREAM` calls with an `asyncio.Semaphore` so we don't exhaust daemon worker threads or file descriptors.
|
||||
• Falls back to progressively smaller chunk sizes when the daemon rejects a stream as too large.
|
||||
"""
|
||||
|
||||
def __init__(self, settings: VirusScannerSettings) -> None:
|
||||
self.settings = settings
|
||||
self._client = aioclamd.ClamdAsyncClient(
|
||||
host=settings.clamav_service_host,
|
||||
port=settings.clamav_service_port,
|
||||
timeout=settings.clamav_service_timeout,
|
||||
)
|
||||
self._sem = asyncio.Semaphore(settings.max_concurrency)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
@staticmethod
|
||||
def _parse_raw(raw: Optional[dict]) -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
Convert aioclamd output to (infected?, threat_name).
|
||||
Returns (False, None) for clean.
|
||||
"""
|
||||
if not raw:
|
||||
return False, None
|
||||
status, threat = next(iter(raw.values()))
|
||||
return status == "FOUND", threat
|
||||
|
||||
async def _instream(self, chunk: bytes) -> Tuple[bool, Optional[str]]:
|
||||
"""Scan **one** chunk with concurrency control."""
|
||||
async with self._sem:
|
||||
try:
|
||||
raw = await self._client.instream(io.BytesIO(chunk))
|
||||
return self._parse_raw(raw)
|
||||
except (BrokenPipeError, ConnectionResetError) as exc:
|
||||
raise RuntimeError("size-limit") from exc
|
||||
except Exception as exc:
|
||||
if "INSTREAM size limit exceeded" in str(exc):
|
||||
raise RuntimeError("size-limit") from exc
|
||||
raise
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Public API
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
async def scan_file(
|
||||
self, content: bytes, *, filename: str = "unknown"
|
||||
) -> VirusScanResult:
|
||||
"""
|
||||
Scan `content`. Returns a result object or raises on infrastructure
|
||||
failure (unreachable daemon, etc.).
|
||||
The algorithm always tries whole-file first. If the daemon refuses
|
||||
on size grounds, it falls back to chunked parallel scanning.
|
||||
"""
|
||||
if not self.settings.clamav_service_enabled:
|
||||
logger.warning(f"Virus scanning disabled – accepting {filename}")
|
||||
return VirusScanResult(
|
||||
is_clean=True, scan_time_ms=0, file_size=len(content)
|
||||
)
|
||||
if len(content) > self.settings.max_scan_size:
|
||||
logger.warning(
|
||||
f"File {filename} ({len(content)} bytes) exceeds client max scan size ({self.settings.max_scan_size}); Stopping virus scan"
|
||||
)
|
||||
return VirusScanResult(
|
||||
is_clean=self.settings.mark_failed_scans_as_clean,
|
||||
file_size=len(content),
|
||||
scan_time_ms=0,
|
||||
)
|
||||
|
||||
# Ensure daemon is reachable (small RTT check)
|
||||
if not await self._client.ping():
|
||||
raise RuntimeError("ClamAV service is unreachable")
|
||||
|
||||
start = time.monotonic()
|
||||
chunk_size = len(content) # Start with full content length
|
||||
for retry in range(self.settings.max_retries):
|
||||
# For small files, don't check min_chunk_size limit
|
||||
if chunk_size < self.settings.min_chunk_size and chunk_size < len(content):
|
||||
break
|
||||
logger.debug(
|
||||
f"Scanning {filename} with chunk size: {chunk_size // 1_048_576} MB (retry {retry + 1}/{self.settings.max_retries})"
|
||||
)
|
||||
try:
|
||||
tasks = [
|
||||
asyncio.create_task(self._instream(content[o : o + chunk_size]))
|
||||
for o in range(0, len(content), chunk_size)
|
||||
]
|
||||
for coro in asyncio.as_completed(tasks):
|
||||
infected, threat = await coro
|
||||
if infected:
|
||||
for t in tasks:
|
||||
if not t.done():
|
||||
t.cancel()
|
||||
return VirusScanResult(
|
||||
is_clean=False,
|
||||
threat_name=threat,
|
||||
file_size=len(content),
|
||||
scan_time_ms=int((time.monotonic() - start) * 1000),
|
||||
)
|
||||
# All chunks clean
|
||||
return VirusScanResult(
|
||||
is_clean=True,
|
||||
file_size=len(content),
|
||||
scan_time_ms=int((time.monotonic() - start) * 1000),
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
if str(exc) == "size-limit":
|
||||
chunk_size //= 2
|
||||
continue
|
||||
logger.error(f"Cannot scan {filename}: {exc}")
|
||||
raise
|
||||
# Phase 3 – give up but warn
|
||||
logger.warning(
|
||||
f"Unable to virus scan {filename} ({len(content)} bytes) even with minimum chunk size ({self.settings.min_chunk_size} bytes). Recommend manual review."
|
||||
)
|
||||
return VirusScanResult(
|
||||
is_clean=self.settings.mark_failed_scans_as_clean,
|
||||
file_size=len(content),
|
||||
scan_time_ms=int((time.monotonic() - start) * 1000),
|
||||
)
|
||||
|
||||
|
||||
_scanner: Optional[VirusScannerService] = None
|
||||
|
||||
|
||||
def get_virus_scanner() -> VirusScannerService:
|
||||
global _scanner
|
||||
if _scanner is None:
|
||||
_settings = VirusScannerSettings(
|
||||
clamav_service_host=settings.config.clamav_service_host,
|
||||
clamav_service_port=settings.config.clamav_service_port,
|
||||
clamav_service_enabled=settings.config.clamav_service_enabled,
|
||||
max_concurrency=settings.config.clamav_max_concurrency,
|
||||
mark_failed_scans_as_clean=settings.config.clamav_mark_failed_scans_as_clean,
|
||||
)
|
||||
_scanner = VirusScannerService(_settings)
|
||||
return _scanner
|
||||
|
||||
|
||||
async def scan_content_safe(content: bytes, *, filename: str = "unknown") -> None:
|
||||
"""
|
||||
Helper function to scan content and raise appropriate exceptions.
|
||||
|
||||
Raises:
|
||||
VirusDetectedError: If virus is found
|
||||
VirusScanError: If scanning fails
|
||||
"""
|
||||
from backend.server.v2.store.exceptions import VirusDetectedError, VirusScanError
|
||||
|
||||
try:
|
||||
result = await get_virus_scanner().scan_file(content, filename=filename)
|
||||
if not result.is_clean:
|
||||
threat_name = result.threat_name or "Unknown threat"
|
||||
logger.warning(f"Virus detected in file {filename}: {threat_name}")
|
||||
raise VirusDetectedError(
|
||||
threat_name, f"File rejected due to virus detection: {threat_name}"
|
||||
)
|
||||
|
||||
logger.info(f"File {filename} passed virus scan in {result.scan_time_ms}ms")
|
||||
|
||||
except VirusDetectedError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Virus scanning failed for {filename}: {str(e)}")
|
||||
raise VirusScanError(f"Virus scanning failed: {str(e)}") from e
|
||||
@@ -1,253 +0,0 @@
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.server.v2.store.exceptions import VirusDetectedError, VirusScanError
|
||||
from backend.util.virus_scanner import (
|
||||
VirusScannerService,
|
||||
VirusScannerSettings,
|
||||
VirusScanResult,
|
||||
get_virus_scanner,
|
||||
scan_content_safe,
|
||||
)
|
||||
|
||||
|
||||
class TestVirusScannerService:
|
||||
@pytest.fixture
|
||||
def scanner_settings(self):
|
||||
return VirusScannerSettings(
|
||||
clamav_service_host="localhost",
|
||||
clamav_service_port=3310,
|
||||
clamav_service_enabled=True,
|
||||
max_scan_size=10 * 1024 * 1024, # 10MB for testing
|
||||
mark_failed_scans_as_clean=False, # For testing, failed scans should be clean
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def scanner(self, scanner_settings):
|
||||
return VirusScannerService(scanner_settings)
|
||||
|
||||
@pytest.fixture
|
||||
def disabled_scanner(self):
|
||||
settings = VirusScannerSettings(clamav_service_enabled=False)
|
||||
return VirusScannerService(settings)
|
||||
|
||||
def test_scanner_initialization(self, scanner_settings):
|
||||
scanner = VirusScannerService(scanner_settings)
|
||||
assert scanner.settings.clamav_service_host == "localhost"
|
||||
assert scanner.settings.clamav_service_port == 3310
|
||||
assert scanner.settings.clamav_service_enabled is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_disabled_returns_clean(self, disabled_scanner):
|
||||
content = b"test file content"
|
||||
result = await disabled_scanner.scan_file(content, filename="test.txt")
|
||||
|
||||
assert result.is_clean is True
|
||||
assert result.threat_name is None
|
||||
assert result.file_size == len(content)
|
||||
assert result.scan_time_ms == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_file_too_large(self, scanner):
|
||||
# Create content larger than max_scan_size
|
||||
large_content = b"x" * (scanner.settings.max_scan_size + 1)
|
||||
|
||||
# Large files behavior depends on mark_failed_scans_as_clean setting
|
||||
result = await scanner.scan_file(large_content, filename="large_file.txt")
|
||||
assert result.is_clean == scanner.settings.mark_failed_scans_as_clean
|
||||
assert result.file_size == len(large_content)
|
||||
assert result.scan_time_ms == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_file_too_large_both_configurations(self):
|
||||
"""Test large file handling with both mark_failed_scans_as_clean configurations"""
|
||||
large_content = b"x" * (10 * 1024 * 1024 + 1) # Larger than 10MB
|
||||
|
||||
# Test with mark_failed_scans_as_clean=True
|
||||
settings_clean = VirusScannerSettings(
|
||||
max_scan_size=10 * 1024 * 1024, mark_failed_scans_as_clean=True
|
||||
)
|
||||
scanner_clean = VirusScannerService(settings_clean)
|
||||
result_clean = await scanner_clean.scan_file(
|
||||
large_content, filename="large_file.txt"
|
||||
)
|
||||
assert result_clean.is_clean is True
|
||||
|
||||
# Test with mark_failed_scans_as_clean=False
|
||||
settings_dirty = VirusScannerSettings(
|
||||
max_scan_size=10 * 1024 * 1024, mark_failed_scans_as_clean=False
|
||||
)
|
||||
scanner_dirty = VirusScannerService(settings_dirty)
|
||||
result_dirty = await scanner_dirty.scan_file(
|
||||
large_content, filename="large_file.txt"
|
||||
)
|
||||
assert result_dirty.is_clean is False
|
||||
|
||||
# Note: ping method was removed from current implementation
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_clean_file(self, scanner):
|
||||
async def mock_instream(_):
|
||||
await asyncio.sleep(0.001) # Small delay to ensure timing > 0
|
||||
return None # No virus detected
|
||||
|
||||
mock_client = Mock()
|
||||
mock_client.ping = AsyncMock(return_value=True)
|
||||
mock_client.instream = AsyncMock(side_effect=mock_instream)
|
||||
|
||||
# Replace the client instance that was created in the constructor
|
||||
scanner._client = mock_client
|
||||
|
||||
content = b"clean file content"
|
||||
result = await scanner.scan_file(content, filename="clean.txt")
|
||||
|
||||
assert result.is_clean is True
|
||||
assert result.threat_name is None
|
||||
assert result.file_size == len(content)
|
||||
assert result.scan_time_ms > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_infected_file(self, scanner):
|
||||
async def mock_instream(_):
|
||||
await asyncio.sleep(0.001) # Small delay to ensure timing > 0
|
||||
return {"stream": ("FOUND", "Win.Test.EICAR_HDB-1")}
|
||||
|
||||
mock_client = Mock()
|
||||
mock_client.ping = AsyncMock(return_value=True)
|
||||
mock_client.instream = AsyncMock(side_effect=mock_instream)
|
||||
|
||||
# Replace the client instance that was created in the constructor
|
||||
scanner._client = mock_client
|
||||
|
||||
content = b"infected file content"
|
||||
result = await scanner.scan_file(content, filename="infected.txt")
|
||||
|
||||
assert result.is_clean is False
|
||||
assert result.threat_name == "Win.Test.EICAR_HDB-1"
|
||||
assert result.file_size == len(content)
|
||||
assert result.scan_time_ms > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_clamav_unavailable_fail_safe(self, scanner):
|
||||
mock_client = Mock()
|
||||
mock_client.ping = AsyncMock(return_value=False)
|
||||
|
||||
# Replace the client instance that was created in the constructor
|
||||
scanner._client = mock_client
|
||||
|
||||
content = b"test content"
|
||||
|
||||
with pytest.raises(RuntimeError, match="ClamAV service is unreachable"):
|
||||
await scanner.scan_file(content, filename="test.txt")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_error_fail_safe(self, scanner):
|
||||
mock_client = Mock()
|
||||
mock_client.ping = AsyncMock(return_value=True)
|
||||
mock_client.instream = AsyncMock(side_effect=Exception("Scanning error"))
|
||||
|
||||
# Replace the client instance that was created in the constructor
|
||||
scanner._client = mock_client
|
||||
|
||||
content = b"test content"
|
||||
|
||||
with pytest.raises(Exception, match="Scanning error"):
|
||||
await scanner.scan_file(content, filename="test.txt")
|
||||
|
||||
# Note: scan_file_method and scan_upload_file tests removed as these APIs don't exist in current implementation
|
||||
|
||||
def test_get_virus_scanner_singleton(self):
|
||||
scanner1 = get_virus_scanner()
|
||||
scanner2 = get_virus_scanner()
|
||||
|
||||
# Should return the same instance
|
||||
assert scanner1 is scanner2
|
||||
|
||||
# Note: client_reuse test removed as _get_client method doesn't exist in current implementation
|
||||
|
||||
def test_scan_result_model(self):
|
||||
# Test VirusScanResult model
|
||||
result = VirusScanResult(
|
||||
is_clean=False, threat_name="Test.Virus", scan_time_ms=150, file_size=1024
|
||||
)
|
||||
|
||||
assert result.is_clean is False
|
||||
assert result.threat_name == "Test.Virus"
|
||||
assert result.scan_time_ms == 150
|
||||
assert result.file_size == 1024
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_scans(self, scanner):
|
||||
async def mock_instream(_):
|
||||
await asyncio.sleep(0.001) # Small delay to ensure timing > 0
|
||||
return None
|
||||
|
||||
mock_client = Mock()
|
||||
mock_client.ping = AsyncMock(return_value=True)
|
||||
mock_client.instream = AsyncMock(side_effect=mock_instream)
|
||||
|
||||
# Replace the client instance that was created in the constructor
|
||||
scanner._client = mock_client
|
||||
|
||||
content1 = b"file1 content"
|
||||
content2 = b"file2 content"
|
||||
|
||||
# Run concurrent scans
|
||||
results = await asyncio.gather(
|
||||
scanner.scan_file(content1, filename="file1.txt"),
|
||||
scanner.scan_file(content2, filename="file2.txt"),
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
assert all(result.is_clean for result in results)
|
||||
assert results[0].file_size == len(content1)
|
||||
assert results[1].file_size == len(content2)
|
||||
assert all(result.scan_time_ms > 0 for result in results)
|
||||
|
||||
|
||||
class TestHelperFunctions:
|
||||
"""Test the helper functions scan_content_safe"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_content_safe_clean(self):
|
||||
"""Test scan_content_safe with clean content"""
|
||||
with patch("backend.util.virus_scanner.get_virus_scanner") as mock_get_scanner:
|
||||
mock_scanner = Mock()
|
||||
mock_scanner.scan_file = AsyncMock()
|
||||
mock_scanner.scan_file.return_value = Mock(
|
||||
is_clean=True, threat_name=None, scan_time_ms=50, file_size=100
|
||||
)
|
||||
mock_get_scanner.return_value = mock_scanner
|
||||
|
||||
# Should not raise any exception
|
||||
await scan_content_safe(b"clean content", filename="test.txt")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_content_safe_infected(self):
|
||||
"""Test scan_content_safe with infected content"""
|
||||
with patch("backend.util.virus_scanner.get_virus_scanner") as mock_get_scanner:
|
||||
mock_scanner = Mock()
|
||||
mock_scanner.scan_file = AsyncMock()
|
||||
mock_scanner.scan_file.return_value = Mock(
|
||||
is_clean=False, threat_name="Test.Virus", scan_time_ms=50, file_size=100
|
||||
)
|
||||
mock_get_scanner.return_value = mock_scanner
|
||||
|
||||
with pytest.raises(VirusDetectedError) as exc_info:
|
||||
await scan_content_safe(b"infected content", filename="virus.txt")
|
||||
|
||||
assert exc_info.value.threat_name == "Test.Virus"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_content_safe_scan_error(self):
|
||||
"""Test scan_content_safe when scanning fails"""
|
||||
with patch("backend.util.virus_scanner.get_virus_scanner") as mock_get_scanner:
|
||||
mock_scanner = Mock()
|
||||
mock_scanner.scan_file = AsyncMock()
|
||||
mock_scanner.scan_file.side_effect = Exception("Scan failed")
|
||||
mock_get_scanner.return_value = mock_scanner
|
||||
|
||||
with pytest.raises(VirusScanError, match="Virus scanning failed"):
|
||||
await scan_content_safe(b"test content", filename="test.txt")
|
||||
@@ -1,2 +0,0 @@
|
||||
-- AlterTable
|
||||
ALTER TABLE "AgentNodeExecutionInputOutput" ALTER COLUMN "data" DROP NOT NULL;
|
||||
@@ -1,5 +0,0 @@
|
||||
-- Add webhookId column
|
||||
ALTER TABLE "AgentPreset" ADD COLUMN "webhookId" TEXT;
|
||||
|
||||
-- Add AgentPreset<->IntegrationWebhook relation
|
||||
ALTER TABLE "AgentPreset" ADD CONSTRAINT "AgentPreset_webhookId_fkey" FOREIGN KEY ("webhookId") REFERENCES "IntegrationWebhook"("id") ON DELETE SET NULL ON UPDATE CASCADE;
|
||||
@@ -1,11 +0,0 @@
|
||||
-- CreateTable
|
||||
CREATE TABLE "AgentNodeExecutionKeyValueData" (
|
||||
"userId" TEXT NOT NULL,
|
||||
"key" TEXT NOT NULL,
|
||||
"agentNodeExecutionId" TEXT NOT NULL,
|
||||
"data" JSONB,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3),
|
||||
|
||||
CONSTRAINT "AgentNodeExecutionKeyValueData_pkey" PRIMARY KEY ("userId","key")
|
||||
);
|
||||
458
autogpt_platform/backend/poetry.lock
generated
458
autogpt_platform/backend/poetry.lock
generated
@@ -17,32 +17,20 @@ aiormq = ">=6.8,<6.9"
|
||||
exceptiongroup = ">=1,<2"
|
||||
yarl = "*"
|
||||
|
||||
[[package]]
|
||||
name = "aioclamd"
|
||||
version = "1.0.0"
|
||||
description = "Asynchronous client for virus scanning with ClamAV"
|
||||
optional = false
|
||||
python-versions = ">=3.7,<4.0"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "aioclamd-1.0.0-py3-none-any.whl", hash = "sha256:4727da3953a4b38be4c2de1acb6b3bb3c94c1c171dcac780b80234ee6253f3d9"},
|
||||
{file = "aioclamd-1.0.0.tar.gz", hash = "sha256:7b14e94e3a2285cc89e2f4d434e2a01f322d3cb95476ce2dda015a7980876047"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aiodns"
|
||||
version = "3.5.0"
|
||||
version = "3.4.0"
|
||||
description = "Simple DNS resolver for asyncio"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "aiodns-3.5.0-py3-none-any.whl", hash = "sha256:6d0404f7d5215849233f6ee44854f2bb2481adf71b336b2279016ea5990ca5c5"},
|
||||
{file = "aiodns-3.5.0.tar.gz", hash = "sha256:11264edbab51896ecf546c18eb0dd56dff0428c6aa6d2cd87e643e07300eb310"},
|
||||
{file = "aiodns-3.4.0-py3-none-any.whl", hash = "sha256:4da2b25f7475343f3afbb363a2bfe46afa544f2b318acb9a945065e622f4ed24"},
|
||||
{file = "aiodns-3.4.0.tar.gz", hash = "sha256:24b0ae58410530367f21234d0c848e4de52c1f16fbddc111726a4ab536ec1b2f"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pycares = ">=4.9.0"
|
||||
pycares = ">=4.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "aiofiles"
|
||||
@@ -222,14 +210,14 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "anthropic"
|
||||
version = "0.57.1"
|
||||
version = "0.51.0"
|
||||
description = "The official Python library for the anthropic API"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "anthropic-0.57.1-py3-none-any.whl", hash = "sha256:33afc1f395af207d07ff1bffc0a3d1caac53c371793792569c5d2f09283ea306"},
|
||||
{file = "anthropic-0.57.1.tar.gz", hash = "sha256:7815dd92245a70d21f65f356f33fc80c5072eada87fb49437767ea2918b2c4b0"},
|
||||
{file = "anthropic-0.51.0-py3-none-any.whl", hash = "sha256:b8b47d482c9aa1f81b923555cebb687c2730309a20d01be554730c8302e0f62a"},
|
||||
{file = "anthropic-0.51.0.tar.gz", hash = "sha256:6f824451277992af079554430d5b2c8ff5bc059cc2c968cdc3f06824437da201"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -242,7 +230,6 @@ sniffio = "*"
|
||||
typing-extensions = ">=4.10,<5"
|
||||
|
||||
[package.extras]
|
||||
aiohttp = ["aiohttp", "httpx-aiohttp (>=0.1.6)"]
|
||||
bedrock = ["boto3 (>=1.28.57)", "botocore (>=1.31.57)"]
|
||||
vertex = ["google-auth[requests] (>=2,<3)"]
|
||||
|
||||
@@ -1006,14 +993,14 @@ pgp = ["gpg"]
|
||||
|
||||
[[package]]
|
||||
name = "e2b"
|
||||
version = "1.5.4"
|
||||
version = "1.5.0"
|
||||
description = "E2B SDK that give agents cloud environments"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "e2b-1.5.4-py3-none-any.whl", hash = "sha256:9c8d22f9203311dff890e037823596daaba3d793300238117f2efc5426888f2c"},
|
||||
{file = "e2b-1.5.4.tar.gz", hash = "sha256:49f1c115d0198244beef5854d19cc857fda9382e205f137b98d3dae0e7e0b2d2"},
|
||||
{file = "e2b-1.5.0-py3-none-any.whl", hash = "sha256:875a843d1d314a9945e24bfb78c9b1b5cac7e2ecb1e799664d827a26a0b2276a"},
|
||||
{file = "e2b-1.5.0.tar.gz", hash = "sha256:905730eea5c07f271d073d4b5d2a9ef44c8ac04b9b146a99fa0235db77bf6854"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -1027,19 +1014,19 @@ typing-extensions = ">=4.1.0"
|
||||
|
||||
[[package]]
|
||||
name = "e2b-code-interpreter"
|
||||
version = "1.5.2"
|
||||
version = "1.5.0"
|
||||
description = "E2B Code Interpreter - Stateful code execution"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "e2b_code_interpreter-1.5.2-py3-none-any.whl", hash = "sha256:5c3188d8f25226b28fef4b255447cc6a4c36afb748bdd5180b45be486d5169f3"},
|
||||
{file = "e2b_code_interpreter-1.5.2.tar.gz", hash = "sha256:3bd6ea70596290e85aaf0a2f19f28bf37a5e73d13086f5e6a0080bb591c5a547"},
|
||||
{file = "e2b_code_interpreter-1.5.0-py3-none-any.whl", hash = "sha256:299f5641a3754264a07f8edc3cccb744d6b009f10dc9285789a9352e24989a9b"},
|
||||
{file = "e2b_code_interpreter-1.5.0.tar.gz", hash = "sha256:cd6028b6f20c4231e88a002de86484b9d4a99ea588b5be183b9ec7189a0f3cf6"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
attrs = ">=21.3.0"
|
||||
e2b = ">=1.5.4,<2.0.0"
|
||||
e2b = ">=1.4.0,<2.0.0"
|
||||
httpx = ">=0.20.0,<1.0.0"
|
||||
|
||||
[[package]]
|
||||
@@ -1110,14 +1097,14 @@ typing-extensions = "*"
|
||||
|
||||
[[package]]
|
||||
name = "fastapi"
|
||||
version = "0.115.14"
|
||||
version = "0.115.12"
|
||||
description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "fastapi-0.115.14-py3-none-any.whl", hash = "sha256:6c0c8bf9420bd58f565e585036d971872472b4f7d3f6c73b698e10cffdefb3ca"},
|
||||
{file = "fastapi-0.115.14.tar.gz", hash = "sha256:b1de15cdc1c499a4da47914db35d0e4ef8f1ce62b624e94e0e5824421df99739"},
|
||||
{file = "fastapi-0.115.12-py3-none-any.whl", hash = "sha256:e94613d6c05e27be7ffebdd6ea5f388112e5e430c8f7d6494a9d1d88d43e814d"},
|
||||
{file = "fastapi-0.115.12.tar.gz", hash = "sha256:1e2c2a2646905f9e83d32f04a3f86aff4a286669c6c950ca95b5fd68c2602681"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -1193,20 +1180,20 @@ packaging = ">=20"
|
||||
|
||||
[[package]]
|
||||
name = "flake8"
|
||||
version = "7.3.0"
|
||||
version = "7.2.0"
|
||||
description = "the modular source code checker: pep8 pyflakes and co"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "flake8-7.3.0-py2.py3-none-any.whl", hash = "sha256:b9696257b9ce8beb888cdbe31cf885c90d31928fe202be0889a7cdafad32f01e"},
|
||||
{file = "flake8-7.3.0.tar.gz", hash = "sha256:fe044858146b9fc69b551a4b490d69cf960fcb78ad1edcb84e7fbb1b4a8e3872"},
|
||||
{file = "flake8-7.2.0-py2.py3-none-any.whl", hash = "sha256:93b92ba5bdb60754a6da14fa3b93a9361fd00a59632ada61fd7b130436c40343"},
|
||||
{file = "flake8-7.2.0.tar.gz", hash = "sha256:fa558ae3f6f7dbf2b4f22663e5343b6b6023620461f8d4ff2019ef4b5ee70426"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
mccabe = ">=0.7.0,<0.8.0"
|
||||
pycodestyle = ">=2.14.0,<2.15.0"
|
||||
pyflakes = ">=3.4.0,<3.5.0"
|
||||
pycodestyle = ">=2.13.0,<2.14.0"
|
||||
pyflakes = ">=3.3.0,<3.4.0"
|
||||
|
||||
[[package]]
|
||||
name = "frozenlist"
|
||||
@@ -1357,14 +1344,14 @@ grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "google-api-python-client"
|
||||
version = "2.176.0"
|
||||
version = "2.170.0"
|
||||
description = "Google API Client Library for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "google_api_python_client-2.176.0-py3-none-any.whl", hash = "sha256:e22239797f1d085341e12cd924591fc65c56d08e0af02549d7606092e6296510"},
|
||||
{file = "google_api_python_client-2.176.0.tar.gz", hash = "sha256:2b451cdd7fd10faeb5dd20f7d992f185e1e8f4124c35f2cdcc77c843139a4cf1"},
|
||||
{file = "google_api_python_client-2.170.0-py3-none-any.whl", hash = "sha256:7bf518a0527ad23322f070fa69f4f24053170d5c766821dc970ff0571ec22748"},
|
||||
{file = "google_api_python_client-2.170.0.tar.gz", hash = "sha256:75f3a1856f11418ea3723214e0abc59d9b217fd7ed43dcf743aab7f06ab9e2b1"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -1517,27 +1504,27 @@ protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4
|
||||
|
||||
[[package]]
|
||||
name = "google-cloud-storage"
|
||||
version = "3.2.0"
|
||||
version = "3.1.0"
|
||||
description = "Google Cloud Storage API client library"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "google_cloud_storage-3.2.0-py3-none-any.whl", hash = "sha256:ff7a9a49666954a7c3d1598291220c72d3b9e49d9dfcf9dfaecb301fc4fb0b24"},
|
||||
{file = "google_cloud_storage-3.2.0.tar.gz", hash = "sha256:decca843076036f45633198c125d1861ffbf47ebf5c0e3b98dcb9b2db155896c"},
|
||||
{file = "google_cloud_storage-3.1.0-py2.py3-none-any.whl", hash = "sha256:eaf36966b68660a9633f03b067e4a10ce09f1377cae3ff9f2c699f69a81c66c6"},
|
||||
{file = "google_cloud_storage-3.1.0.tar.gz", hash = "sha256:944273179897c7c8a07ee15f2e6466a02da0c7c4b9ecceac2a26017cb2972049"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
google-api-core = ">=2.15.0,<3.0.0"
|
||||
google-auth = ">=2.26.1,<3.0.0"
|
||||
google-cloud-core = ">=2.4.2,<3.0.0"
|
||||
google-crc32c = ">=1.1.3,<2.0.0"
|
||||
google-resumable-media = ">=2.7.2,<3.0.0"
|
||||
requests = ">=2.22.0,<3.0.0"
|
||||
google-api-core = ">=2.15.0,<3.0.0dev"
|
||||
google-auth = ">=2.26.1,<3.0dev"
|
||||
google-cloud-core = ">=2.4.2,<3.0dev"
|
||||
google-crc32c = ">=1.0,<2.0dev"
|
||||
google-resumable-media = ">=2.7.2"
|
||||
requests = ">=2.18.0,<3.0.0dev"
|
||||
|
||||
[package.extras]
|
||||
protobuf = ["protobuf (>=3.20.2,<7.0.0)"]
|
||||
tracing = ["opentelemetry-api (>=1.1.0,<2.0.0)"]
|
||||
protobuf = ["protobuf (<6.0.0dev)"]
|
||||
tracing = ["opentelemetry-api (>=1.1.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "google-crc32c"
|
||||
@@ -1745,14 +1732,14 @@ test = ["objgraph", "psutil"]
|
||||
|
||||
[[package]]
|
||||
name = "groq"
|
||||
version = "0.29.0"
|
||||
version = "0.24.0"
|
||||
description = "The official Python library for the groq API"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "groq-0.29.0-py3-none-any.whl", hash = "sha256:03515ec46be1ef1feef0cd9d876b6f30a39ee2742e76516153d84acd7c97f23a"},
|
||||
{file = "groq-0.29.0.tar.gz", hash = "sha256:109dc4d696c05d44e4c2cd157652c4c6600c3e96f093f6e158facb5691e37847"},
|
||||
{file = "groq-0.24.0-py3-none-any.whl", hash = "sha256:0020e6b0b2b267263c9eb7c318deef13c12f399c6525734200b11d777b00088e"},
|
||||
{file = "groq-0.24.0.tar.gz", hash = "sha256:e821559de8a77fb81d2585b3faec80ff923d6d64fd52339b33f6c94997d6f7f5"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -1763,9 +1750,6 @@ pydantic = ">=1.9.0,<3"
|
||||
sniffio = "*"
|
||||
typing-extensions = ">=4.10,<5"
|
||||
|
||||
[package.extras]
|
||||
aiohttp = ["aiohttp", "httpx-aiohttp (>=0.1.6)"]
|
||||
|
||||
[[package]]
|
||||
name = "grpc-google-iam-v1"
|
||||
version = "0.14.2"
|
||||
@@ -2552,14 +2536,14 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "mem0ai"
|
||||
version = "0.1.114"
|
||||
version = "0.1.102"
|
||||
description = "Long-term memory for AI Agents"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "mem0ai-0.1.114-py3-none-any.whl", hash = "sha256:dfb7f0079ee282f5d9782e220f6f09707bcf5e107925d1901dbca30d8dd83f9b"},
|
||||
{file = "mem0ai-0.1.114.tar.gz", hash = "sha256:b27886132eaec78544e8b8b54f0b14a36728f3c99da54cb7cb417150e2fad7e1"},
|
||||
{file = "mem0ai-0.1.102-py3-none-any.whl", hash = "sha256:1401ccfd2369e2182ce78abb61b817e739fe49508b5a8ad98abcd4f8ad4db0b4"},
|
||||
{file = "mem0ai-0.1.102.tar.gz", hash = "sha256:7358dba4fbe954b9c3f33204c14df7babaf9067e2eb48241d89a32e6bc774988"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -2572,11 +2556,8 @@ sqlalchemy = ">=2.0.31"
|
||||
|
||||
[package.extras]
|
||||
dev = ["isort (>=5.13.2)", "pytest (>=8.2.2)", "ruff (>=0.6.5)"]
|
||||
extras = ["boto3 (>=1.34.0)", "elasticsearch (>=8.0.0)", "langchain-community (>=0.0.0)", "langchain-memgraph (>=0.1.0)", "opensearch-py (>=2.0.0)", "sentence-transformers (>=5.0.0)"]
|
||||
graph = ["langchain-aws (>=0.2.23)", "langchain-neo4j (>=0.4.0)", "neo4j (>=5.23.1)", "rank-bm25 (>=0.2.2)"]
|
||||
llms = ["google-genai (>=1.0.0)", "google-generativeai (>=0.3.0)", "groq (>=0.3.0)", "litellm (>=0.1.0)", "ollama (>=0.1.0)", "together (>=0.2.10)", "vertexai (>=0.1.0)"]
|
||||
graph = ["langchain-neo4j (>=0.4.0)", "neo4j (>=5.23.1)", "rank-bm25 (>=0.2.2)"]
|
||||
test = ["pytest (>=8.2.2)", "pytest-asyncio (>=0.23.7)", "pytest-mock (>=3.14.0)"]
|
||||
vector-stores = ["azure-search-documents (>=11.4.0b8)", "chromadb (>=0.4.24)", "faiss-cpu (>=1.7.4)", "pinecone (<=7.3.0)", "pinecone-text (>=0.10.0)", "pymochow (>=2.2.9)", "pymongo (>=4.13.2)", "upstash-vector (>=0.1.0)", "vecs (>=0.4.0)", "weaviate-client (>=4.4.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "more-itertools"
|
||||
@@ -2915,14 +2896,14 @@ signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"]
|
||||
|
||||
[[package]]
|
||||
name = "ollama"
|
||||
version = "0.5.1"
|
||||
version = "0.4.9"
|
||||
description = "The official Python client for Ollama."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "ollama-0.5.1-py3-none-any.whl", hash = "sha256:4c8839f35bc173c7057b1eb2cbe7f498c1a7e134eafc9192824c8aecb3617506"},
|
||||
{file = "ollama-0.5.1.tar.gz", hash = "sha256:5a799e4dc4e7af638b11e3ae588ab17623ee019e496caaf4323efbaa8feeff93"},
|
||||
{file = "ollama-0.4.9-py3-none-any.whl", hash = "sha256:18c8c85358c54d7f73d6a66cda495b0e3ba99fdb88f824ae470d740fbb211a50"},
|
||||
{file = "ollama-0.4.9.tar.gz", hash = "sha256:5266d4d29b5089a01489872b8e8f980f018bccbdd1082b3903448af1d5615ce7"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -2931,14 +2912,14 @@ pydantic = ">=2.9"
|
||||
|
||||
[[package]]
|
||||
name = "openai"
|
||||
version = "1.93.2"
|
||||
version = "1.82.1"
|
||||
description = "The official Python library for the openai API"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openai-1.93.2-py3-none-any.whl", hash = "sha256:5adbbebd48eae160e6d68efc4c0a4f7cb1318a44c62d9fc626cec229f418eab4"},
|
||||
{file = "openai-1.93.2.tar.gz", hash = "sha256:4a7312b426b5e4c98b78dfa1148b5683371882de3ad3d5f7c8e0c74f3cc90778"},
|
||||
{file = "openai-1.82.1-py3-none-any.whl", hash = "sha256:334eb5006edf59aa464c9e932b9d137468d810b2659e5daea9b3a8c39d052395"},
|
||||
{file = "openai-1.82.1.tar.gz", hash = "sha256:ffc529680018e0417acac85f926f92aa0bbcbc26e82e2621087303c66bc7f95d"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -2952,7 +2933,6 @@ tqdm = ">4"
|
||||
typing-extensions = ">=4.11,<5"
|
||||
|
||||
[package.extras]
|
||||
aiohttp = ["aiohttp", "httpx-aiohttp (>=0.1.6)"]
|
||||
datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
|
||||
realtime = ["websockets (>=13,<16)"]
|
||||
voice-helpers = ["numpy (>=2.0.2)", "sounddevice (>=0.5.1)"]
|
||||
@@ -3267,14 +3247,14 @@ testing = ["coverage", "pytest", "pytest-benchmark"]
|
||||
|
||||
[[package]]
|
||||
name = "poethepoet"
|
||||
version = "0.36.0"
|
||||
description = "A task runner that works well with poetry and uv."
|
||||
version = "0.34.0"
|
||||
description = "A task runner that works well with poetry."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "poethepoet-0.36.0-py3-none-any.whl", hash = "sha256:693e3c1eae9f6731d3613c3c0c40f747d3c5c68a375beda42e590a63c5623308"},
|
||||
{file = "poethepoet-0.36.0.tar.gz", hash = "sha256:2217b49cb4e4c64af0b42ff8c4814b17f02e107d38bc461542517348ede25663"},
|
||||
{file = "poethepoet-0.34.0-py3-none-any.whl", hash = "sha256:c472d6f0fdb341b48d346f4ccd49779840c15b30dfd6bc6347a80d6274b5e34e"},
|
||||
{file = "poethepoet-0.34.0.tar.gz", hash = "sha256:86203acce555bbfe45cb6ccac61ba8b16a5784264484195874da457ddabf5850"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -3500,14 +3480,14 @@ tqdm = "*"
|
||||
|
||||
[[package]]
|
||||
name = "prometheus-client"
|
||||
version = "0.22.1"
|
||||
version = "0.21.1"
|
||||
description = "Python client for the Prometheus monitoring system."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "prometheus_client-0.22.1-py3-none-any.whl", hash = "sha256:cca895342e308174341b2cbf99a56bef291fbc0ef7b9e5412a0f26d653ba7094"},
|
||||
{file = "prometheus_client-0.22.1.tar.gz", hash = "sha256:190f1331e783cf21eb60bca559354e0a4d4378facecf78f5428c39b675d20d28"},
|
||||
{file = "prometheus_client-0.21.1-py3-none-any.whl", hash = "sha256:594b45c410d6f4f8888940fe80b5cc2521b305a1fafe1c58609ef715a001f301"},
|
||||
{file = "prometheus_client-0.21.1.tar.gz", hash = "sha256:252505a722ac04b0456be05c05f75f45d760c2911ffc45f2a06bcaed9f3ae3fb"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
@@ -3791,88 +3771,83 @@ pyasn1 = ">=0.6.1,<0.7.0"
|
||||
|
||||
[[package]]
|
||||
name = "pycares"
|
||||
version = "4.9.0"
|
||||
version = "4.8.0"
|
||||
description = "Python interface for c-ares"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "pycares-4.9.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0b8bd9a3ee6e9bc990e1933dc7e7e2f44d4184f49a90fa444297ac12ab6c0c84"},
|
||||
{file = "pycares-4.9.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:417a5c20861f35977240ad4961479a6778125bcac21eb2ad1c3aad47e2ff7fab"},
|
||||
{file = "pycares-4.9.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ab290faa4ea53ce53e3ceea1b3a42822daffce2d260005533293a52525076750"},
|
||||
{file = "pycares-4.9.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b1df81193084c9717734e4615e8c5074b9852478c9007d1a8bb242f7f580e67"},
|
||||
{file = "pycares-4.9.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:20c7a6af0c2ccd17cc5a70d76e299a90e7ebd6c4d8a3d7fff5ae533339f61431"},
|
||||
{file = "pycares-4.9.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:370f41442a5b034aebdb2719b04ee04d3e805454a20d3f64f688c1c49f9137c3"},
|
||||
{file = "pycares-4.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:340e4a3bbfd14d73c01ec0793a321b8a4a93f64c508225883291078b7ee17ac8"},
|
||||
{file = "pycares-4.9.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f0ec94785856ea4f5556aa18f4c027361ba4b26cb36c4ad97d2105ef4eec68ba"},
|
||||
{file = "pycares-4.9.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:dd6b7e23a4a9e2039b5d67dfa0499d2d5f114667dc13fb5d7d03eed230c7ac4f"},
|
||||
{file = "pycares-4.9.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:490c978b0be9d35a253a5e31dd598f6d66b453625f0eb7dc2d81b22b8c3bb3f4"},
|
||||
{file = "pycares-4.9.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:e433faaf07f44e44f1a1b839fee847480fe3db9431509dafc9f16d618d491d0f"},
|
||||
{file = "pycares-4.9.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:cf6d8851a06b79d10089962c9dadcb34dad00bf027af000f7102297a54aaff2e"},
|
||||
{file = "pycares-4.9.0-cp310-cp310-win32.whl", hash = "sha256:4f803e7d66ac7d8342998b8b07393788991353a46b05bbaad0b253d6f3484ea8"},
|
||||
{file = "pycares-4.9.0-cp310-cp310-win_amd64.whl", hash = "sha256:8e17bd32267e3870855de3baed7d0efa6337344d68f44853fd9195c919f39400"},
|
||||
{file = "pycares-4.9.0-cp310-cp310-win_arm64.whl", hash = "sha256:6b74f75d8e430f9bb11a1cc99b2e328eed74b17d8d4b476de09126f38d419eb9"},
|
||||
{file = "pycares-4.9.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:16a97ee83ec60d35c7f716f117719932c27d428b1bb56b242ba1c4aa55521747"},
|
||||
{file = "pycares-4.9.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:78748521423a211ce699a50c27cc5c19e98b7db610ccea98daad652ace373990"},
|
||||
{file = "pycares-4.9.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8818b2c7a57d9d6d41e8b64d9ff87992b8ea2522fc0799686725228bc3cff6c5"},
|
||||
{file = "pycares-4.9.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:96df8990f16013ca5194d6ece19dddb4ef9cd7c3efaab9f196ec3ccd44b40f8d"},
|
||||
{file = "pycares-4.9.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:61af86fd58b8326e723b0d20fb96b56acaec2261c3a7c9a1c29d0a79659d613a"},
|
||||
{file = "pycares-4.9.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ec72edb276bda559813cc807bc47b423d409ffab2402417a5381077e9c2c6be1"},
|
||||
{file = "pycares-4.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:832fb122c7376c76cab62f8862fa5e398b9575fb7c9ff6bc9811086441ee64ca"},
|
||||
{file = "pycares-4.9.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:cdcfaef24f771a471671470ccfd676c0366ab6b0616fd8217b8f356c40a02b83"},
|
||||
{file = "pycares-4.9.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:52cb056d06ff55d78a8665b97ae948abaaba2ca200ca59b10346d4526bce1e7d"},
|
||||
{file = "pycares-4.9.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:54985ed3f2e8a87315269f24cb73441622857a7830adfc3a27c675a94c3261c1"},
|
||||
{file = "pycares-4.9.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:08048e223615d4aef3dac81fe0ea18fb18d6fc97881f1eb5be95bb1379969b8d"},
|
||||
{file = "pycares-4.9.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:cc60037421ce05a409484287b2cd428e1363cca73c999b5f119936bb8f255208"},
|
||||
{file = "pycares-4.9.0-cp311-cp311-win32.whl", hash = "sha256:62b86895b60cfb91befb3086caa0792b53f949231c6c0c3053c7dfee3f1386ab"},
|
||||
{file = "pycares-4.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:7046b3c80954beaabf2db52b09c3d6fe85f6c4646af973e61be79d1c51589932"},
|
||||
{file = "pycares-4.9.0-cp311-cp311-win_arm64.whl", hash = "sha256:fcbda3fdf44e94d3962ca74e6ba3dc18c0d7029106f030d61c04c0876f319403"},
|
||||
{file = "pycares-4.9.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d68ca2da1001aeccdc81c4a2fb1f1f6cfdafd3d00e44e7c1ed93e3e05437f666"},
|
||||
{file = "pycares-4.9.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4f0c8fa5a384d79551a27eafa39eed29529e66ba8fa795ee432ab88d050432a3"},
|
||||
{file = "pycares-4.9.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0eb8c428cf3b9c6ff9c641ba50ab6357b4480cd737498733e6169b0ac8a1a89b"},
|
||||
{file = "pycares-4.9.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6845bd4a43abf6dab7fedbf024ef458ac3750a25b25076ea9913e5ac5fec4548"},
|
||||
{file = "pycares-4.9.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5e28f4acc3b97e46610cf164665ebf914f709daea6ced0ca4358ce55bc1c3d6b"},
|
||||
{file = "pycares-4.9.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9464a39861840ce35a79352c34d653a9db44f9333af7c9feddb97998d3e00c07"},
|
||||
{file = "pycares-4.9.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e0611c1bd46d1fc6bdd9305b8850eb84c77df485769f72c574ed7b8389dfbee2"},
|
||||
{file = "pycares-4.9.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d4fb5a38a51d03b75ac4320357e632c2e72e03fdeb13263ee333a40621415fdc"},
|
||||
{file = "pycares-4.9.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:df5edae05fb3e1370ab7639e67e8891fdaa9026cb10f05dbd57893713f7a9cfe"},
|
||||
{file = "pycares-4.9.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:397123ea53d261007bb0aa7e767ef238778f45026db40bed8196436da2cc73de"},
|
||||
{file = "pycares-4.9.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:bb0d874d0b131b29894fd8a0f842be91ac21d50f90ec04cff4bb3f598464b523"},
|
||||
{file = "pycares-4.9.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:497cc03a61ec1585eb17d2cb086a29a6a67d24babf1e9be519b47222916a3b06"},
|
||||
{file = "pycares-4.9.0-cp312-cp312-win32.whl", hash = "sha256:b46e46313fdb5e82da15478652aac0fd15e1c9f33e08153bad845aa4007d6f84"},
|
||||
{file = "pycares-4.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:12547a06445777091605a7581da15a0da158058beb8a05a3ebbf7301fd1f58d4"},
|
||||
{file = "pycares-4.9.0-cp312-cp312-win_arm64.whl", hash = "sha256:f1e10bf1e8eb80b08e5c828627dba1ebc4acd54803bd0a27d92b9063b6aa99d8"},
|
||||
{file = "pycares-4.9.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:574d815112a95ab09d75d0a9dc7dea737c06985e3125cf31c32ba6a3ed6ca006"},
|
||||
{file = "pycares-4.9.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:50e5ab06361d59625a27a7ad93d27e067dc7c9f6aa529a07d691eb17f3b43605"},
|
||||
{file = "pycares-4.9.0-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:785f5fd11ff40237d9bc8afa441551bb449e2812c74334d1d10859569e07515c"},
|
||||
{file = "pycares-4.9.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e194a500e403eba89b91fb863c917495c5b3dfcd1ce0ee8dc3a6f99a1360e2fc"},
|
||||
{file = "pycares-4.9.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:112dd49cdec4e6150a8d95b197e8b6b7b4468a3170b30738ed9b248cb2240c04"},
|
||||
{file = "pycares-4.9.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:94aa3c2f3eb0aa69160137134775501f06c901188e722aac63d2a210d4084f99"},
|
||||
{file = "pycares-4.9.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b510d71255cf5a92ccc2643a553548fcb0623d6ed11c8c633b421d99d7fa4167"},
|
||||
{file = "pycares-4.9.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:5c6aa30b1492b8130f7832bf95178642c710ce6b7ba610c2b17377f77177e3cd"},
|
||||
{file = "pycares-4.9.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:e5767988e044faffe2aff6a76aa08df99a8b6ef2641be8b00ea16334ce5dea93"},
|
||||
{file = "pycares-4.9.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:b9928a942820a82daa3207509eaba9e0fa9660756ac56667ec2e062815331fcb"},
|
||||
{file = "pycares-4.9.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:556c854174da76d544714cdfab10745ed5d4b99eec5899f7b13988cd26ff4763"},
|
||||
{file = "pycares-4.9.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:d42e2202ca9aa9a0a9a6e43a4a4408bbe0311aaa44800fa27b8fd7f82b20152a"},
|
||||
{file = "pycares-4.9.0-cp313-cp313-win32.whl", hash = "sha256:cce8ef72c9ed4982c84114e6148a4e42e989d745de7862a0ad8b3f1cdc05def2"},
|
||||
{file = "pycares-4.9.0-cp313-cp313-win_amd64.whl", hash = "sha256:318cdf24f826f1d2f0c5a988730bd597e1683296628c8f1be1a5b96643c284fe"},
|
||||
{file = "pycares-4.9.0-cp313-cp313-win_arm64.whl", hash = "sha256:faa9de8e647ed06757a2c117b70a7645a755561def814da6aca0d766cf71a402"},
|
||||
{file = "pycares-4.9.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8310d27d68fa25be9781ce04d330f4860634a2ac34dd9265774b5f404679b41f"},
|
||||
{file = "pycares-4.9.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:99cf98452d3285307eec123049f2c9c50b109e06751b0727c6acefb6da30c6a0"},
|
||||
{file = "pycares-4.9.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ffd6e8c8250655504602b076f106653e085e6b1e15318013442558101aa4777"},
|
||||
{file = "pycares-4.9.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a4065858d8c812159c9a55601fda73760d9e5e3300f7868d9e546eab1084f36c"},
|
||||
{file = "pycares-4.9.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:91ee6818113faf9013945c2b54bcd6b123d0ac192ae3099cf4288cedaf2dbb25"},
|
||||
{file = "pycares-4.9.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:21f0602059ec11857ab7ad608c7ec8bc6f7a302c04559ec06d33e82f040585f8"},
|
||||
{file = "pycares-4.9.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e22e5b46ed9b12183091da56e4a5a20813b5436c4f13135d7a1c20a84027ca8a"},
|
||||
{file = "pycares-4.9.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:9eded8649867bfd7aea7589c5755eae4d37686272f6ed7a995da40890d02de71"},
|
||||
{file = "pycares-4.9.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:f71d31cbbe066657a2536c98aad850724a9ab7b1cd2624f491832ae9667ea8e7"},
|
||||
{file = "pycares-4.9.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:2b30945982ab4741f097efc5b0853051afc3c11df26996ed53a700c7575175af"},
|
||||
{file = "pycares-4.9.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:54a8f1f067d64810426491d33033f5353b54f35e5339126440ad4e6afbf3f149"},
|
||||
{file = "pycares-4.9.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:41556a269a192349e92eee953f62eddd867e9eddb27f444b261e2c1c4a4a9eff"},
|
||||
{file = "pycares-4.9.0-cp39-cp39-win32.whl", hash = "sha256:524d6c14eaa167ed098a4fe54856d1248fa20c296cdd6976f9c1b838ba32d014"},
|
||||
{file = "pycares-4.9.0-cp39-cp39-win_amd64.whl", hash = "sha256:15f930c733d36aa487b4ad60413013bd811281b5ea4ca620070fa38505d84df4"},
|
||||
{file = "pycares-4.9.0-cp39-cp39-win_arm64.whl", hash = "sha256:79b7addb2a41267d46650ac0d9c4f3b3233b036f186b85606f7586881dfb4b69"},
|
||||
{file = "pycares-4.9.0.tar.gz", hash = "sha256:8ee484ddb23dbec4d88d14ed5b6d592c1960d2e93c385d5e52b6fad564d82395"},
|
||||
{file = "pycares-4.8.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f40d9f4a8de398b110fdf226cdfadd86e8c7eb71d5298120ec41cf8d94b0012f"},
|
||||
{file = "pycares-4.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:339de06fc849a51015968038d2bbed68fc24047522404af9533f32395ca80d25"},
|
||||
{file = "pycares-4.8.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:372a236c1502b9056b0bea195c64c329603b4efa70b593a33b7ae37fbb7fad00"},
|
||||
{file = "pycares-4.8.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03f66a5e143d102ccc204bd4e29edd70bed28420f707efd2116748241e30cb73"},
|
||||
{file = "pycares-4.8.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ef50504296cd5fc58cfd6318f82e20af24fbe2c83004f6ff16259adb13afdf14"},
|
||||
{file = "pycares-4.8.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d1bc541b627c7951dd36136b18bd185c5244a0fb2af5b1492ffb8acaceec1c5b"},
|
||||
{file = "pycares-4.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:938d188ed6bed696099be67ebdcdf121827b9432b17a9ea9e40dc35fd9d85363"},
|
||||
{file = "pycares-4.8.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:327837ffdc0c7adda09c98e1263c64b2aff814eea51a423f66733c75ccd9a642"},
|
||||
{file = "pycares-4.8.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:a6b9b8d08c4508c45bd39e0c74e9e7052736f18ca1d25a289365bb9ac36e5849"},
|
||||
{file = "pycares-4.8.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:feac07d5e6d2d8f031c71237c21c21b8c995b41a1eba64560e8cf1e42ac11bc6"},
|
||||
{file = "pycares-4.8.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:5bcdbf37012fd2323ca9f2a1074421a9ccf277d772632f8f0ce8c46ec7564250"},
|
||||
{file = "pycares-4.8.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e3ebb692cb43fcf34fe0d26f2cf9a0ea53fdfb136463845b81fad651277922db"},
|
||||
{file = "pycares-4.8.0-cp310-cp310-win32.whl", hash = "sha256:d98447ec0efff3fa868ccc54dcc56e71faff498f8848ecec2004c3108efb4da2"},
|
||||
{file = "pycares-4.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:1abb8f40917960ead3c2771277f0bdee1967393b0fdf68743c225b606787da68"},
|
||||
{file = "pycares-4.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5e25db89005ddd8d9c5720293afe6d6dd92e682fc6bc7a632535b84511e2060d"},
|
||||
{file = "pycares-4.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6f9665ef116e6ee216c396f5f927756c2164f9f3316aec7ff1a9a1e1e7ec9b2a"},
|
||||
{file = "pycares-4.8.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:54a96893133471f6889b577147adcc21a480dbe316f56730871028379c8313f3"},
|
||||
{file = "pycares-4.8.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:51024b3a69762bd3100d94986a29922be15e13f56f991aaefb41f5bcd3d7f0bb"},
|
||||
{file = "pycares-4.8.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:47ff9db50c599e4d965ae3bec99cc30941c1d2b0f078ec816680b70d052dd54a"},
|
||||
{file = "pycares-4.8.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:27ef8ff4e0f60ea6769a60d1c3d1d2aefed1d832e7bb83fc3934884e2dba5cdd"},
|
||||
{file = "pycares-4.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63511af7a3f9663f562fbb6bfa3591a259505d976e2aba1fa2da13dde43c6ca7"},
|
||||
{file = "pycares-4.8.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:73c3219b47616e6a5ad1810de96ed59721c7751f19b70ae7bf24997a8365408f"},
|
||||
{file = "pycares-4.8.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:da42a45207c18f37be5e491c14b6d1063cfe1e46620eb661735d0cedc2b59099"},
|
||||
{file = "pycares-4.8.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:8a068e898bb5dd09cd654e19cd2abf20f93d0cc59d5d955135ed48ea0f806aa1"},
|
||||
{file = "pycares-4.8.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:962aed95675bb66c0b785a2fbbd1bb58ce7f009e283e4ef5aaa4a1f2dc00d217"},
|
||||
{file = "pycares-4.8.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ce8b1a16c1e4517a82a0ebd7664783a327166a3764d844cf96b1fb7b9dd1e493"},
|
||||
{file = "pycares-4.8.0-cp311-cp311-win32.whl", hash = "sha256:b3749ddbcbd216376c3b53d42d8b640b457133f1a12b0e003f3838f953037ae7"},
|
||||
{file = "pycares-4.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:5ce8a4e1b485b2360ab666c4ea1db97f57ede345a3b566d80bfa52b17e616610"},
|
||||
{file = "pycares-4.8.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:3273e01a75308ed06d2492d83c7ba476e579a60a24d9f20fe178ce5e9d8d028b"},
|
||||
{file = "pycares-4.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:fcedaadea1f452911fd29935749f98d144dae758d6003b7e9b6c5d5bd47d1dff"},
|
||||
{file = "pycares-4.8.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aae6cb33e287e06a4aabcbc57626df682c9a4fa8026207f5b498697f1c2fb562"},
|
||||
{file = "pycares-4.8.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25038b930e5be82839503fb171385b2aefd6d541bc5b7da0938bdb67780467d2"},
|
||||
{file = "pycares-4.8.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cc8499b6e7dfbe4af65f6938db710ce9acd1debf34af2cbb93b898b1e5da6a5a"},
|
||||
{file = "pycares-4.8.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c4e1c6a68ef56a7622f6176d9946d4e51f3c853327a0123ef35a5380230c84cd"},
|
||||
{file = "pycares-4.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7cc8c3c9114b9c84e4062d25ca9b4bddc80a65d0b074c7cb059275273382f89"},
|
||||
{file = "pycares-4.8.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4404014069d3e362abf404c9932d4335bb9c07ba834cfe7d683c725b92e0f9da"},
|
||||
{file = "pycares-4.8.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:ee0a58c32ec2a352cef0e1d20335a7caf9871cd79b73be2ca2896fe70f09c9d7"},
|
||||
{file = "pycares-4.8.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:35f32f52b486b8fede3cbebf088f30b01242d0321b5216887c28e80490595302"},
|
||||
{file = "pycares-4.8.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:ecbb506e27a3b3a2abc001c77beeccf265475c84b98629a6b3e61bd9f2987eaa"},
|
||||
{file = "pycares-4.8.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:9392b2a34adbf60cb9e38f4a0d363413ecea8d835b5a475122f50f76676d59dd"},
|
||||
{file = "pycares-4.8.0-cp312-cp312-win32.whl", hash = "sha256:f0fbefe68403ffcff19c869b8d621c88a6d2cef18d53cf0dab0fa9458a6ca712"},
|
||||
{file = "pycares-4.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:fa8aab6085a2ddfb1b43a06ddf1b498347117bb47cd620d9b12c43383c9c2737"},
|
||||
{file = "pycares-4.8.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:358a9a2c6fed59f62788e63d88669224955443048a1602016d4358e92aedb365"},
|
||||
{file = "pycares-4.8.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0e3e1278967fa8d4a0056be3fcc8fc551b8bad1fc7d0e5172196dccb8ddb036a"},
|
||||
{file = "pycares-4.8.0-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:79befb773e370a8f97de9f16f5ea2c7e7fa0e3c6c74fbea6d332bf58164d7d06"},
|
||||
{file = "pycares-4.8.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2b00d3695db64ce98a34e632e1d53f5a1cdb25451489f227bec2a6c03ff87ee8"},
|
||||
{file = "pycares-4.8.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:37bdc4f2ff0612d60fc4f7547e12ff02cdcaa9a9e42e827bb64d4748994719f1"},
|
||||
{file = "pycares-4.8.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cd92c44498ec7a6139888b464b28c49f7ba975933689bd67ea8d572b94188404"},
|
||||
{file = "pycares-4.8.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2665a0d810e2bbc41e97f3c3e5ea7950f666b3aa19c5f6c99d6b018ccd2e0052"},
|
||||
{file = "pycares-4.8.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:45a629a6470a33478514c566bce50c63f1b17d1c5f2f964c9a6790330dc105fb"},
|
||||
{file = "pycares-4.8.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:47bb378f1773f41cca8e31dcdf009ce4a9b8aff8a30c7267aaff9a099c407ba5"},
|
||||
{file = "pycares-4.8.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:fb3feae38458005cc101956e38f16eb3145fff8cd793e35cd4bdef6bf1aa2623"},
|
||||
{file = "pycares-4.8.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:14bc28aeaa66b0f4331ac94455e8043c8a06b3faafd78cc49d4b677bae0d0b08"},
|
||||
{file = "pycares-4.8.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:62c82b871470f2864a1febf7b96bb1d108ce9063e6d3d43727e8a46f0028a456"},
|
||||
{file = "pycares-4.8.0-cp313-cp313-win32.whl", hash = "sha256:01afa8964c698c8f548b46d726f766aa7817b2d4386735af1f7996903d724920"},
|
||||
{file = "pycares-4.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:22f86f81b12ab17b0a7bd0da1e27938caaed11715225c1168763af97f8bb51a7"},
|
||||
{file = "pycares-4.8.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:61325d13a95255e858f42a7a1a9e482ff47ef2233f95ad9a4f308a3bd8ecf903"},
|
||||
{file = "pycares-4.8.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:dfec3a7d42336fa46a1e7e07f67000fd4b97860598c59a894c08f81378629e4e"},
|
||||
{file = "pycares-4.8.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b65067e4b4f5345688817fff6be06b9b1f4ec3619b0b9ecc639bc681b73f646b"},
|
||||
{file = "pycares-4.8.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0322ad94bbaa7016139b5bbdcd0de6f6feb9d146d69e03a82aaca342e06830a6"},
|
||||
{file = "pycares-4.8.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:456c60f170c997f9a43c7afa1085fced8efb7e13ae49dd5656f998ae13c4bdb4"},
|
||||
{file = "pycares-4.8.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:57a2c4c9ce423a85b0e0227409dbaf0d478f5e0c31d9e626768e77e1e887d32f"},
|
||||
{file = "pycares-4.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:478d9c479108b7527266864c0affe3d6e863492c9bc269217e36100c8fd89b91"},
|
||||
{file = "pycares-4.8.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:aed56bca096990ca0aa9bbf95761fc87e02880e04b0845922b5c12ea9abe523f"},
|
||||
{file = "pycares-4.8.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:ef265a390928ee2f77f8901c2273c53293157860451ad453ce7f45dd268b72f9"},
|
||||
{file = "pycares-4.8.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:a5f17d7a76d8335f1c90a8530c8f1e8bb22e9a1d70a96f686efaed946de1c908"},
|
||||
{file = "pycares-4.8.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:891f981feb2ef34367378f813fc17b3d706ce95b6548eeea0c9fe7705d7e54b1"},
|
||||
{file = "pycares-4.8.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:4102f6d9117466cc0a1f527907a1454d109cc9e8551b8074888071ef16050fe3"},
|
||||
{file = "pycares-4.8.0-cp39-cp39-win32.whl", hash = "sha256:d6775308659652adc88c82c53eda59b5e86a154aaba5ad1e287bbb3e0be77076"},
|
||||
{file = "pycares-4.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:8bc05462aa44788d48544cca3d2532466fed2cdc5a2f24a43a92b620a61c9d19"},
|
||||
{file = "pycares-4.8.0.tar.gz", hash = "sha256:2fc2ebfab960f654b3e3cf08a732486950da99393a657f8b44618ad3ed2d39c1"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -3883,14 +3858,14 @@ idna = ["idna (>=2.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "pycodestyle"
|
||||
version = "2.14.0"
|
||||
version = "2.13.0"
|
||||
description = "Python style guide checker"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "pycodestyle-2.14.0-py2.py3-none-any.whl", hash = "sha256:dd6bf7cb4ee77f8e016f9c8e74a35ddd9f67e1d5fd4184d86c3b98e07099f42d"},
|
||||
{file = "pycodestyle-2.14.0.tar.gz", hash = "sha256:c4b5b517d278089ff9d0abdec919cd97262a3367449ea1c8b49b91529167b783"},
|
||||
{file = "pycodestyle-2.13.0-py2.py3-none-any.whl", hash = "sha256:35863c5974a271c7a726ed228a14a4f6daf49df369d8c50cd9a6f58a5e143ba9"},
|
||||
{file = "pycodestyle-2.13.0.tar.gz", hash = "sha256:c8415bf09abe81d9c7f872502a6eee881fbe85d8763dd5b9924bb0a01d67efae"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3907,14 +3882,14 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "pydantic"
|
||||
version = "2.11.7"
|
||||
version = "2.11.5"
|
||||
description = "Data validation using Python type hints"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "pydantic-2.11.7-py3-none-any.whl", hash = "sha256:dde5df002701f6de26248661f6835bbe296a47bf73990135c7d07ce741b9623b"},
|
||||
{file = "pydantic-2.11.7.tar.gz", hash = "sha256:d989c3c6cb79469287b1569f7447a17848c998458d49ebe294e975b9baf0f0db"},
|
||||
{file = "pydantic-2.11.5-py3-none-any.whl", hash = "sha256:f9c26ba06f9747749ca1e5c94d6a85cb84254577553c8785576fd38fa64dc0f7"},
|
||||
{file = "pydantic-2.11.5.tar.gz", hash = "sha256:7f853db3d0ce78ce8bbb148c401c2cdd6431b3473c0cdff2755c7690952a7b7a"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -4042,14 +4017,14 @@ typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0"
|
||||
|
||||
[[package]]
|
||||
name = "pydantic-settings"
|
||||
version = "2.10.1"
|
||||
version = "2.9.1"
|
||||
description = "Settings management using Pydantic"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "pydantic_settings-2.10.1-py3-none-any.whl", hash = "sha256:a60952460b99cf661dc25c29c0ef171721f98bfcb52ef8d9ea4c943d7c8cc796"},
|
||||
{file = "pydantic_settings-2.10.1.tar.gz", hash = "sha256:06f0062169818d0f5524420a360d632d5857b83cffd4d42fe29597807a1614ee"},
|
||||
{file = "pydantic_settings-2.9.1-py3-none-any.whl", hash = "sha256:59b4f431b1defb26fe620c71a7d3968a710d719f5f4cdbbdb7926edeb770f6ef"},
|
||||
{file = "pydantic_settings-2.9.1.tar.gz", hash = "sha256:c509bf79d27563add44e8446233359004ed85066cd096d8b510f715e6ef5d268"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -4066,31 +4041,16 @@ yaml = ["pyyaml (>=6.0.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "pyflakes"
|
||||
version = "3.4.0"
|
||||
version = "3.3.2"
|
||||
description = "passive checker of Python programs"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "pyflakes-3.4.0-py2.py3-none-any.whl", hash = "sha256:f742a7dbd0d9cb9ea41e9a24a918996e8170c799fa528688d40dd582c8265f4f"},
|
||||
{file = "pyflakes-3.4.0.tar.gz", hash = "sha256:b24f96fafb7d2ab0ec5075b7350b3d2d2218eab42003821c06344973d3ea2f58"},
|
||||
{file = "pyflakes-3.3.2-py2.py3-none-any.whl", hash = "sha256:5039c8339cbb1944045f4ee5466908906180f13cc99cc9949348d10f82a5c32a"},
|
||||
{file = "pyflakes-3.3.2.tar.gz", hash = "sha256:6dfd61d87b97fba5dcfaaf781171ac16be16453be6d816147989e7f6e6a9576b"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pygments"
|
||||
version = "2.19.2"
|
||||
description = "Pygments is a syntax highlighting package written in Python."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main", "dev"]
|
||||
files = [
|
||||
{file = "pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b"},
|
||||
{file = "pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
windows-terminal = ["colorama (>=0.4.6)"]
|
||||
|
||||
[[package]]
|
||||
name = "pyjwt"
|
||||
version = "2.10.1"
|
||||
@@ -4150,14 +4110,14 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "pyright"
|
||||
version = "1.1.402"
|
||||
version = "1.1.401"
|
||||
description = "Command line wrapper for pyright"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pyright-1.1.402-py3-none-any.whl", hash = "sha256:2c721f11869baac1884e846232800fe021c33f1b4acb3929cff321f7ea4e2982"},
|
||||
{file = "pyright-1.1.402.tar.gz", hash = "sha256:85a33c2d40cd4439c66aa946fd4ce71ab2f3f5b8c22ce36a623f59ac22937683"},
|
||||
{file = "pyright-1.1.401-py3-none-any.whl", hash = "sha256:6fde30492ba5b0d7667c16ecaf6c699fab8d7a1263f6a18549e0b00bf7724c06"},
|
||||
{file = "pyright-1.1.401.tar.gz", hash = "sha256:788a82b6611fa5e34a326a921d86d898768cddf59edde8e93e56087d277cc6f1"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -4171,27 +4131,26 @@ nodejs = ["nodejs-wheel-binaries"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest"
|
||||
version = "8.4.1"
|
||||
version = "8.3.5"
|
||||
description = "pytest: simple powerful testing with Python"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main", "dev"]
|
||||
files = [
|
||||
{file = "pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7"},
|
||||
{file = "pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c"},
|
||||
{file = "pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820"},
|
||||
{file = "pytest-8.3.5.tar.gz", hash = "sha256:f4efe70cc14e511565ac476b57c279e12a855b11f48f212af1080ef2263d3845"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
colorama = {version = ">=0.4", markers = "sys_platform == \"win32\""}
|
||||
exceptiongroup = {version = ">=1", markers = "python_version < \"3.11\""}
|
||||
iniconfig = ">=1"
|
||||
packaging = ">=20"
|
||||
colorama = {version = "*", markers = "sys_platform == \"win32\""}
|
||||
exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
|
||||
iniconfig = "*"
|
||||
packaging = "*"
|
||||
pluggy = ">=1.5,<2"
|
||||
pygments = ">=2.7.2"
|
||||
tomli = {version = ">=1", markers = "python_version < \"3.11\""}
|
||||
|
||||
[package.extras]
|
||||
dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests", "setuptools", "xmlschema"]
|
||||
dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-asyncio"
|
||||
@@ -4278,14 +4237,14 @@ six = ">=1.5"
|
||||
|
||||
[[package]]
|
||||
name = "python-dotenv"
|
||||
version = "1.1.1"
|
||||
version = "1.1.0"
|
||||
description = "Read key-value pairs from a .env file and set them as environment variables"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "python_dotenv-1.1.1-py3-none-any.whl", hash = "sha256:31f23644fe2602f88ff55e1f5c79ba497e01224ee7737937930c448e4d0e24dc"},
|
||||
{file = "python_dotenv-1.1.1.tar.gz", hash = "sha256:a8a6399716257f45be6a007360200409fce5cda2661e3dec71d23dc15f6189ab"},
|
||||
{file = "python_dotenv-1.1.0-py3-none-any.whl", hash = "sha256:d7c01d9e2293916c18baf562d95698754b0dbbb5e74d457c45d4f6561fb9d55d"},
|
||||
{file = "python_dotenv-1.1.0.tar.gz", hash = "sha256:41f90bc6f5f177fb41f53e87666db362025010eb28f60a01c9143bfa33a2b2d5"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
@@ -4731,19 +4690,19 @@ typing_extensions = ">=4.5.0"
|
||||
|
||||
[[package]]
|
||||
name = "requests"
|
||||
version = "2.32.4"
|
||||
version = "2.32.3"
|
||||
description = "Python HTTP for Humans."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main", "dev"]
|
||||
files = [
|
||||
{file = "requests-2.32.4-py3-none-any.whl", hash = "sha256:27babd3cda2a6d50b30443204ee89830707d396671944c998b5975b031ac2b2c"},
|
||||
{file = "requests-2.32.4.tar.gz", hash = "sha256:27d0316682c8a29834d3264820024b62a36942083d52caf2f14c0591336d3422"},
|
||||
{file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"},
|
||||
{file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
certifi = ">=2017.4.17"
|
||||
charset_normalizer = ">=2,<4"
|
||||
charset-normalizer = ">=2,<4"
|
||||
idna = ">=2.5,<4"
|
||||
urllib3 = ">=1.21.1,<3"
|
||||
|
||||
@@ -4929,30 +4888,30 @@ pyasn1 = ">=0.1.3"
|
||||
|
||||
[[package]]
|
||||
name = "ruff"
|
||||
version = "0.12.2"
|
||||
version = "0.11.12"
|
||||
description = "An extremely fast Python linter and code formatter, written in Rust."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "ruff-0.12.2-py3-none-linux_armv6l.whl", hash = "sha256:093ea2b221df1d2b8e7ad92fc6ffdca40a2cb10d8564477a987b44fd4008a7be"},
|
||||
{file = "ruff-0.12.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:09e4cf27cc10f96b1708100fa851e0daf21767e9709e1649175355280e0d950e"},
|
||||
{file = "ruff-0.12.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:8ae64755b22f4ff85e9c52d1f82644abd0b6b6b6deedceb74bd71f35c24044cc"},
|
||||
{file = "ruff-0.12.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3eb3a6b2db4d6e2c77e682f0b988d4d61aff06860158fdb413118ca133d57922"},
|
||||
{file = "ruff-0.12.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:73448de992d05517170fc37169cbca857dfeaeaa8c2b9be494d7bcb0d36c8f4b"},
|
||||
{file = "ruff-0.12.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3b8b94317cbc2ae4a2771af641739f933934b03555e51515e6e021c64441532d"},
|
||||
{file = "ruff-0.12.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:45fc42c3bf1d30d2008023a0a9a0cfb06bf9835b147f11fe0679f21ae86d34b1"},
|
||||
{file = "ruff-0.12.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ce48f675c394c37e958bf229fb5c1e843e20945a6d962cf3ea20b7a107dcd9f4"},
|
||||
{file = "ruff-0.12.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:793d8859445ea47591272021a81391350205a4af65a9392401f418a95dfb75c9"},
|
||||
{file = "ruff-0.12.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6932323db80484dda89153da3d8e58164d01d6da86857c79f1961934354992da"},
|
||||
{file = "ruff-0.12.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:6aa7e623a3a11538108f61e859ebf016c4f14a7e6e4eba1980190cacb57714ce"},
|
||||
{file = "ruff-0.12.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2a4a20aeed74671b2def096bdf2eac610c7d8ffcbf4fb0e627c06947a1d7078d"},
|
||||
{file = "ruff-0.12.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:71a4c550195612f486c9d1f2b045a600aeba851b298c667807ae933478fcef04"},
|
||||
{file = "ruff-0.12.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:4987b8f4ceadf597c927beee65a5eaf994c6e2b631df963f86d8ad1bdea99342"},
|
||||
{file = "ruff-0.12.2-py3-none-win32.whl", hash = "sha256:369ffb69b70cd55b6c3fc453b9492d98aed98062db9fec828cdfd069555f5f1a"},
|
||||
{file = "ruff-0.12.2-py3-none-win_amd64.whl", hash = "sha256:dca8a3b6d6dc9810ed8f328d406516bf4d660c00caeaef36eb831cf4871b0639"},
|
||||
{file = "ruff-0.12.2-py3-none-win_arm64.whl", hash = "sha256:48d6c6bfb4761df68bc05ae630e24f506755e702d4fb08f08460be778c7ccb12"},
|
||||
{file = "ruff-0.12.2.tar.gz", hash = "sha256:d7b4f55cd6f325cb7621244f19c873c565a08aff5a4ba9c69aa7355f3f7afd3e"},
|
||||
{file = "ruff-0.11.12-py3-none-linux_armv6l.whl", hash = "sha256:c7680aa2f0d4c4f43353d1e72123955c7a2159b8646cd43402de6d4a3a25d7cc"},
|
||||
{file = "ruff-0.11.12-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:2cad64843da9f134565c20bcc430642de897b8ea02e2e79e6e02a76b8dcad7c3"},
|
||||
{file = "ruff-0.11.12-py3-none-macosx_11_0_arm64.whl", hash = "sha256:9b6886b524a1c659cee1758140138455d3c029783d1b9e643f3624a5ee0cb0aa"},
|
||||
{file = "ruff-0.11.12-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3cc3a3690aad6e86c1958d3ec3c38c4594b6ecec75c1f531e84160bd827b2012"},
|
||||
{file = "ruff-0.11.12-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f97fdbc2549f456c65b3b0048560d44ddd540db1f27c778a938371424b49fe4a"},
|
||||
{file = "ruff-0.11.12-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:74adf84960236961090e2d1348c1a67d940fd12e811a33fb3d107df61eef8fc7"},
|
||||
{file = "ruff-0.11.12-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:b56697e5b8bcf1d61293ccfe63873aba08fdbcbbba839fc046ec5926bdb25a3a"},
|
||||
{file = "ruff-0.11.12-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4d47afa45e7b0eaf5e5969c6b39cbd108be83910b5c74626247e366fd7a36a13"},
|
||||
{file = "ruff-0.11.12-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:692bf9603fe1bf949de8b09a2da896f05c01ed7a187f4a386cdba6760e7f61be"},
|
||||
{file = "ruff-0.11.12-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:08033320e979df3b20dba567c62f69c45e01df708b0f9c83912d7abd3e0801cd"},
|
||||
{file = "ruff-0.11.12-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:929b7706584f5bfd61d67d5070f399057d07c70585fa8c4491d78ada452d3bef"},
|
||||
{file = "ruff-0.11.12-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:7de4a73205dc5756b8e09ee3ed67c38312dce1aa28972b93150f5751199981b5"},
|
||||
{file = "ruff-0.11.12-py3-none-musllinux_1_2_i686.whl", hash = "sha256:2635c2a90ac1b8ca9e93b70af59dfd1dd2026a40e2d6eebaa3efb0465dd9cf02"},
|
||||
{file = "ruff-0.11.12-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:d05d6a78a89166f03f03a198ecc9d18779076ad0eec476819467acb401028c0c"},
|
||||
{file = "ruff-0.11.12-py3-none-win32.whl", hash = "sha256:f5a07f49767c4be4772d161bfc049c1f242db0cfe1bd976e0f0886732a4765d6"},
|
||||
{file = "ruff-0.11.12-py3-none-win_amd64.whl", hash = "sha256:5a4d9f8030d8c3a45df201d7fb3ed38d0219bccd7955268e863ee4a115fa0832"},
|
||||
{file = "ruff-0.11.12-py3-none-win_arm64.whl", hash = "sha256:65194e37853158d368e333ba282217941029a28ea90913c67e558c611d04daa5"},
|
||||
{file = "ruff-0.11.12.tar.gz", hash = "sha256:43cf7f69c7d7c7d7513b9d59c5d8cafd704e05944f978614aa9faff6ac202603"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4986,14 +4945,14 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "sentry-sdk"
|
||||
version = "2.32.0"
|
||||
version = "2.29.1"
|
||||
description = "Python client for Sentry (https://sentry.io)"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "sentry_sdk-2.32.0-py2.py3-none-any.whl", hash = "sha256:6cf51521b099562d7ce3606da928c473643abe99b00ce4cb5626ea735f4ec345"},
|
||||
{file = "sentry_sdk-2.32.0.tar.gz", hash = "sha256:9016c75d9316b0f6921ac14c8cd4fb938f26002430ac5be9945ab280f78bec6b"},
|
||||
{file = "sentry_sdk-2.29.1-py2.py3-none-any.whl", hash = "sha256:90862fe0616ded4572da6c9dadb363121a1ae49a49e21c418f0634e9d10b4c19"},
|
||||
{file = "sentry_sdk-2.29.1.tar.gz", hash = "sha256:8d4a0206b95fa5fe85e5e7517ed662e3888374bdc342c00e435e10e6d831aa6d"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -5047,27 +5006,6 @@ statsig = ["statsig (>=0.55.3)"]
|
||||
tornado = ["tornado (>=6)"]
|
||||
unleash = ["UnleashClient (>=6.0.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "setuptools"
|
||||
version = "80.9.0"
|
||||
description = "Easily download, build, install, upgrade, and uninstall Python packages"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922"},
|
||||
{file = "setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.8.0) ; sys_platform != \"cygwin\""]
|
||||
core = ["importlib_metadata (>=6) ; python_version < \"3.10\"", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1) ; python_version < \"3.11\"", "wheel (>=0.43.0)"]
|
||||
cover = ["pytest-cov"]
|
||||
doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"]
|
||||
enabler = ["pytest-enabler (>=2.2)"]
|
||||
test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21) ; python_version >= \"3.9\" and sys_platform != \"cygwin\"", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf ; sys_platform != \"cygwin\"", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"]
|
||||
type = ["importlib_metadata (>=7.0.2) ; python_version < \"3.10\"", "jaraco.develop (>=7.21) ; sys_platform != \"cygwin\"", "mypy (==1.14.*)", "pytest-mypy"]
|
||||
|
||||
[[package]]
|
||||
name = "sgmllib3k"
|
||||
version = "1.0.0"
|
||||
@@ -5280,23 +5218,23 @@ typing-extensions = {version = ">=4.5.0", markers = "python_version >= \"3.7\""}
|
||||
|
||||
[[package]]
|
||||
name = "supabase"
|
||||
version = "2.16.0"
|
||||
version = "2.15.1"
|
||||
description = "Supabase client for Python."
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "supabase-2.16.0-py3-none-any.whl", hash = "sha256:99065caab3d90a56650bf39fbd0e49740995da3738ab28706c61bd7f2401db55"},
|
||||
{file = "supabase-2.16.0.tar.gz", hash = "sha256:98f3810158012d4ec0e3083f2e5515f5e10b32bd71e7d458662140e963c1d164"},
|
||||
{file = "supabase-2.15.1-py3-none-any.whl", hash = "sha256:749299cdd74ecf528f52045c1e60d9dba81cc2054656f754c0ca7fba0dd34827"},
|
||||
{file = "supabase-2.15.1.tar.gz", hash = "sha256:66e847dab9346062aa6a25b4e81ac786b972c5d4299827c57d1d5bd6a0346070"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
gotrue = ">=2.11.0,<3.0.0"
|
||||
httpx = ">=0.26,<0.29"
|
||||
postgrest = ">0.19,<1.2"
|
||||
realtime = ">=2.4.0,<2.6.0"
|
||||
storage3 = ">=0.10,<0.13"
|
||||
supafunc = ">=0.9,<0.11"
|
||||
postgrest = ">0.19,<1.1"
|
||||
realtime = ">=2.4.0,<2.5.0"
|
||||
storage3 = ">=0.10,<0.12"
|
||||
supafunc = ">=0.9,<0.10"
|
||||
|
||||
[[package]]
|
||||
name = "supafunc"
|
||||
@@ -5503,14 +5441,14 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "tweepy"
|
||||
version = "4.16.0"
|
||||
description = "Library for accessing the X API (Twitter)"
|
||||
version = "4.15.0"
|
||||
description = "Twitter library for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "tweepy-4.16.0-py3-none-any.whl", hash = "sha256:48d1a1eb311d2c4b8990abcfa6f9fa2b2ad61be05c723b1a9b4f242656badae2"},
|
||||
{file = "tweepy-4.16.0.tar.gz", hash = "sha256:1d95cbdc50bf6353a387f881f2584eaf60d14e00dbbdd8872a73de79c66878e3"},
|
||||
{file = "tweepy-4.15.0-py3-none-any.whl", hash = "sha256:64adcea317158937059e4e2897b3ceb750b0c2dd5df58938c2da8f7eb3b88e6a"},
|
||||
{file = "tweepy-4.15.0.tar.gz", hash = "sha256:1345cbcdf0a75e2d89f424c559fd49fda4d8cd7be25cd5131e3b57bad8a21d76"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -5521,6 +5459,8 @@ requests-oauthlib = ">=1.2.0,<3"
|
||||
[package.extras]
|
||||
async = ["aiohttp (>=3.7.3,<4)", "async-lru (>=1.0.3,<3)"]
|
||||
dev = ["coverage (>=4.4.2)", "coveralls (>=2.1.0)", "tox (>=3.21.0)"]
|
||||
docs = ["myst-parser (==0.15.2)", "readthedocs-sphinx-search (==0.1.1)", "sphinx (==4.2.0)", "sphinx-hoverxref (==0.7b1)", "sphinx-tabs (==3.2.0)", "sphinx_rtd_theme (==1.0.0)"]
|
||||
socks = ["requests[socks] (>=2.27.0,<3)"]
|
||||
test = ["urllib3 (<2)", "vcrpy (>=1.10.3)"]
|
||||
|
||||
[[package]]
|
||||
@@ -6280,14 +6220,14 @@ requests = "*"
|
||||
|
||||
[[package]]
|
||||
name = "zerobouncesdk"
|
||||
version = "1.1.2"
|
||||
version = "1.1.1"
|
||||
description = "ZeroBounce Python API - https://www.zerobounce.net."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "zerobouncesdk-1.1.2-py3-none-any.whl", hash = "sha256:a89febfb3adade01c314e6bad2113ad093f1e1cca6ddf9fcf445a8b2a9a458b4"},
|
||||
{file = "zerobouncesdk-1.1.2.tar.gz", hash = "sha256:24810a2e39c963bc75b4732356b0fc8b10091f2c892f0c8b08fbb32640fdccaf"},
|
||||
{file = "zerobouncesdk-1.1.1-py3-none-any.whl", hash = "sha256:9fb9dfa44fe4ce35d6f2e43d5144c31ca03544a3317d75643cb9f86b0c028675"},
|
||||
{file = "zerobouncesdk-1.1.1.tar.gz", hash = "sha256:00aa537263d5bc21534c0007dd9f94ce8e0986caa530c5a0bbe0bd917451f236"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -6429,4 +6369,4 @@ cffi = ["cffi (>=1.11)"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<3.13"
|
||||
content-hash = "476228d2bf59b90edc5425c462c1263cbc1f2d346f79a826ac5e7efe7823aaa6"
|
||||
content-hash = "6c93e51cf22c06548aa6d0e23ca8ceb4450f5e02d4142715e941aabc1a2cbd6a"
|
||||
|
||||
@@ -10,67 +10,64 @@ packages = [{ include = "backend", format = "sdist" }]
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<3.13"
|
||||
aio-pika = "^9.5.5"
|
||||
aiodns = "^3.5.0"
|
||||
anthropic = "^0.57.1"
|
||||
aiodns = "^3.1.1"
|
||||
anthropic = "^0.51.0"
|
||||
apscheduler = "^3.11.0"
|
||||
autogpt-libs = { path = "../autogpt_libs", develop = true }
|
||||
bleach = { extras = ["css"], version = "^6.2.0" }
|
||||
click = "^8.2.0"
|
||||
cryptography = "^43.0"
|
||||
discord-py = "^2.5.2"
|
||||
e2b-code-interpreter = "^1.5.2"
|
||||
fastapi = "^0.115.14"
|
||||
e2b-code-interpreter = "^1.5.0"
|
||||
fastapi = "^0.115.12"
|
||||
feedparser = "^6.0.11"
|
||||
flake8 = "^7.3.0"
|
||||
google-api-python-client = "^2.176.0"
|
||||
flake8 = "^7.2.0"
|
||||
google-api-python-client = "^2.169.0"
|
||||
google-auth-oauthlib = "^1.2.2"
|
||||
google-cloud-storage = "^3.2.0"
|
||||
google-cloud-storage = "^3.1.0"
|
||||
googlemaps = "^4.10.0"
|
||||
gravitasml = "^0.1.3"
|
||||
groq = "^0.29.0"
|
||||
groq = "^0.24.0"
|
||||
jinja2 = "^3.1.6"
|
||||
jsonref = "^1.1.0"
|
||||
jsonschema = "^4.22.0"
|
||||
launchdarkly-server-sdk = "^9.11.0"
|
||||
mem0ai = "^0.1.114"
|
||||
mem0ai = "^0.1.98"
|
||||
moviepy = "^2.1.2"
|
||||
ollama = "^0.5.1"
|
||||
openai = "^1.93.2"
|
||||
ollama = "^0.4.8"
|
||||
openai = "^1.78.1"
|
||||
pika = "^1.3.2"
|
||||
pinecone = "^5.3.1"
|
||||
poetry = "2.1.1" # CHECK DEPENDABOT SUPPORT BEFORE UPGRADING
|
||||
postmarker = "^1.0"
|
||||
praw = "~7.8.1"
|
||||
prisma = "^0.15.0"
|
||||
prometheus-client = "^0.22.1"
|
||||
prometheus-client = "^0.21.1"
|
||||
psutil = "^7.0.0"
|
||||
psycopg2-binary = "^2.9.10"
|
||||
pydantic = { extras = ["email"], version = "^2.11.7" }
|
||||
pydantic-settings = "^2.10.1"
|
||||
pytest = "^8.4.1"
|
||||
pydantic = { extras = ["email"], version = "^2.11.4" }
|
||||
pydantic-settings = "^2.9.1"
|
||||
pytest = "^8.3.5"
|
||||
pytest-asyncio = "^0.26.0"
|
||||
python-dotenv = "^1.1.1"
|
||||
python-dotenv = "^1.1.0"
|
||||
python-multipart = "^0.0.20"
|
||||
redis = "^5.2.0"
|
||||
replicate = "^1.0.6"
|
||||
sentry-sdk = {extras = ["anthropic", "fastapi", "launchdarkly", "openai", "sqlalchemy"], version = "^2.32.0"}
|
||||
sentry-sdk = {extras = ["anthropic", "fastapi", "launchdarkly", "openai", "sqlalchemy"], version = "^2.28.0"}
|
||||
sqlalchemy = "^2.0.40"
|
||||
strenum = "^0.4.9"
|
||||
stripe = "^11.5.0"
|
||||
supabase = "2.16.0"
|
||||
supabase = "2.15.1"
|
||||
tenacity = "^9.1.2"
|
||||
todoist-api-python = "^2.1.7"
|
||||
tweepy = "^4.16.0"
|
||||
tweepy = "^4.14.0"
|
||||
uvicorn = { extras = ["standard"], version = "^0.34.2" }
|
||||
websockets = "^14.2"
|
||||
youtube-transcript-api = "^0.6.2"
|
||||
zerobouncesdk = "^1.1.2"
|
||||
zerobouncesdk = "^1.1.1"
|
||||
# NOTE: please insert new dependencies in their alphabetical location
|
||||
pytest-snapshot = "^0.9.0"
|
||||
aiofiles = "^24.1.0"
|
||||
tiktoken = "^0.9.0"
|
||||
aioclamd = "^1.0.0"
|
||||
setuptools = "^80.9.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
aiohappyeyeballs = "^2.6.1"
|
||||
@@ -78,12 +75,12 @@ black = "^24.10.0"
|
||||
faker = "^33.3.1"
|
||||
httpx = "^0.28.1"
|
||||
isort = "^5.13.2"
|
||||
poethepoet = "^0.36.0"
|
||||
pyright = "^1.1.402"
|
||||
poethepoet = "^0.34.0"
|
||||
pyright = "^1.1.400"
|
||||
pytest-mock = "^3.14.0"
|
||||
pytest-watcher = "^0.4.2"
|
||||
requests = "^2.32.4"
|
||||
ruff = "^0.12.2"
|
||||
requests = "^2.32.3"
|
||||
ruff = "^0.11.10"
|
||||
# NOTE: please insert new dependencies in their alphabetical location
|
||||
|
||||
[build-system]
|
||||
@@ -115,11 +112,6 @@ ignore_patterns = []
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "session"
|
||||
filterwarnings = [
|
||||
"ignore:'audioop' is deprecated:DeprecationWarning:discord.player",
|
||||
"ignore:invalid escape sequence:DeprecationWarning:tweepy.api",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py310"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user