mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-12 00:28:31 -05:00
Compare commits
1 Commits
clarify-li
...
ntindle-pa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4b4207760f |
@@ -1,18 +0,0 @@
|
||||
version = 1
|
||||
|
||||
test_patterns = ["**/*.spec.ts","**/*_test.py","**/*_tests.py","**/test_*.py"]
|
||||
|
||||
exclude_patterns = ["classic/**"]
|
||||
|
||||
[[analyzers]]
|
||||
name = "javascript"
|
||||
|
||||
[analyzers.meta]
|
||||
plugins = ["react"]
|
||||
environment = ["nodejs"]
|
||||
|
||||
[[analyzers]]
|
||||
name = "python"
|
||||
|
||||
[analyzers.meta]
|
||||
runtime_version = "3.x.x"
|
||||
24
.github/dependabot.yml
vendored
24
.github/dependabot.yml
vendored
@@ -129,6 +129,30 @@ updates:
|
||||
- "minor"
|
||||
- "patch"
|
||||
|
||||
|
||||
# Submodules
|
||||
- package-ecosystem: "gitsubmodule"
|
||||
directory: "autogpt_platform/supabase"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
open-pull-requests-limit: 1
|
||||
target-branch: "dev"
|
||||
commit-message:
|
||||
prefix: "chore(platform/deps)"
|
||||
prefix-development: "chore(platform/deps-dev)"
|
||||
groups:
|
||||
production-dependencies:
|
||||
dependency-type: "production"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
development-dependencies:
|
||||
dependency-type: "development"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
|
||||
|
||||
# Docs
|
||||
- package-ecosystem: 'pip'
|
||||
directory: "docs/"
|
||||
|
||||
9
.github/workflows/classic-autogpt-ci.yml
vendored
9
.github/workflows/classic-autogpt-ci.yml
vendored
@@ -115,7 +115,6 @@ jobs:
|
||||
poetry run pytest -vv \
|
||||
--cov=autogpt --cov-branch --cov-report term-missing --cov-report xml \
|
||||
--numprocesses=logical --durations=10 \
|
||||
--junitxml=junit.xml -o junit_family=legacy \
|
||||
tests/unit tests/integration
|
||||
env:
|
||||
CI: true
|
||||
@@ -125,14 +124,8 @@ jobs:
|
||||
AWS_ACCESS_KEY_ID: minioadmin
|
||||
AWS_SECRET_ACCESS_KEY: minioadmin
|
||||
|
||||
- name: Upload test results to Codecov
|
||||
if: ${{ !cancelled() }} # Run even if tests fail
|
||||
uses: codecov/test-results-action@v1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
uses: codecov/codecov-action@v4
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: autogpt-agent,${{ runner.os }}
|
||||
|
||||
9
.github/workflows/classic-benchmark-ci.yml
vendored
9
.github/workflows/classic-benchmark-ci.yml
vendored
@@ -87,20 +87,13 @@ jobs:
|
||||
poetry run pytest -vv \
|
||||
--cov=agbenchmark --cov-branch --cov-report term-missing --cov-report xml \
|
||||
--durations=10 \
|
||||
--junitxml=junit.xml -o junit_family=legacy \
|
||||
tests
|
||||
env:
|
||||
CI: true
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
- name: Upload test results to Codecov
|
||||
if: ${{ !cancelled() }} # Run even if tests fail
|
||||
uses: codecov/test-results-action@v1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
uses: codecov/codecov-action@v4
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: agbenchmark,${{ runner.os }}
|
||||
|
||||
9
.github/workflows/classic-forge-ci.yml
vendored
9
.github/workflows/classic-forge-ci.yml
vendored
@@ -139,7 +139,6 @@ jobs:
|
||||
poetry run pytest -vv \
|
||||
--cov=forge --cov-branch --cov-report term-missing --cov-report xml \
|
||||
--durations=10 \
|
||||
--junitxml=junit.xml -o junit_family=legacy \
|
||||
forge
|
||||
env:
|
||||
CI: true
|
||||
@@ -149,14 +148,8 @@ jobs:
|
||||
AWS_ACCESS_KEY_ID: minioadmin
|
||||
AWS_SECRET_ACCESS_KEY: minioadmin
|
||||
|
||||
- name: Upload test results to Codecov
|
||||
if: ${{ !cancelled() }} # Run even if tests fail
|
||||
uses: codecov/test-results-action@v1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
uses: codecov/codecov-action@v4
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: forge,${{ runner.os }}
|
||||
|
||||
17
.github/workflows/platform-backend-ci.yml
vendored
17
.github/workflows/platform-backend-ci.yml
vendored
@@ -42,14 +42,6 @@ jobs:
|
||||
REDIS_PASSWORD: testpassword
|
||||
ports:
|
||||
- 6379:6379
|
||||
rabbitmq:
|
||||
image: rabbitmq:3.12-management
|
||||
ports:
|
||||
- 5672:5672
|
||||
- 15672:15672
|
||||
env:
|
||||
RABBITMQ_DEFAULT_USER: ${{ env.RABBITMQ_DEFAULT_USER }}
|
||||
RABBITMQ_DEFAULT_PASS: ${{ env.RABBITMQ_DEFAULT_PASS }}
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -66,7 +58,7 @@ jobs:
|
||||
- name: Setup Supabase
|
||||
uses: supabase/setup-cli@v1
|
||||
with:
|
||||
version: 1.178.1
|
||||
version: latest
|
||||
|
||||
- id: get_date
|
||||
name: Get date
|
||||
@@ -147,13 +139,6 @@ jobs:
|
||||
RUN_ENV: local
|
||||
PORT: 8080
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
# We know these are here, don't report this as a security vulnerability
|
||||
# This is used as the default credential for the entire system's RabbitMQ instance
|
||||
# If you want to replace this, you can do so by making our entire system generate
|
||||
# new credentials for each local user and update the environment variables in
|
||||
# the backend service, docker composes, and examples
|
||||
RABBITMQ_DEFAULT_USER: 'rabbitmq_user_default'
|
||||
RABBITMQ_DEFAULT_PASS: 'k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7'
|
||||
|
||||
# - name: Upload coverage reports to Codecov
|
||||
# uses: codecov/codecov-action@v4
|
||||
|
||||
30
.github/workflows/platform-frontend-ci.yml
vendored
30
.github/workflows/platform-frontend-ci.yml
vendored
@@ -37,25 +37,6 @@ jobs:
|
||||
run: |
|
||||
yarn lint
|
||||
|
||||
type-check:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "21"
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
yarn install --frozen-lockfile
|
||||
|
||||
- name: Run tsc check
|
||||
run: |
|
||||
yarn type-check
|
||||
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
@@ -77,12 +58,12 @@ jobs:
|
||||
- name: Free Disk Space (Ubuntu)
|
||||
uses: jlumbroso/free-disk-space@main
|
||||
with:
|
||||
large-packages: false # slow
|
||||
docker-images: false # limited benefit
|
||||
large-packages: false # slow
|
||||
docker-images: false # limited benefit
|
||||
|
||||
- name: Copy default supabase .env
|
||||
run: |
|
||||
cp ../.env.example ../.env
|
||||
cp ../supabase/docker/.env.example ../.env
|
||||
|
||||
- name: Copy backend .env
|
||||
run: |
|
||||
@@ -104,12 +85,11 @@ jobs:
|
||||
run: yarn playwright install --with-deps ${{ matrix.browser }}
|
||||
|
||||
- name: Run tests
|
||||
timeout-minutes: 20
|
||||
run: |
|
||||
yarn test --project=${{ matrix.browser }}
|
||||
|
||||
- name: Print Final Docker Compose logs
|
||||
if: always()
|
||||
- name: Print Docker Compose logs in debug mode
|
||||
if: runner.debug
|
||||
run: |
|
||||
docker compose -f ../docker-compose.yml logs
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ jobs:
|
||||
close-issue-message: >
|
||||
This issue was closed automatically because it has been stale for 10 days
|
||||
with no activity.
|
||||
days-before-stale: 100
|
||||
days-before-stale: 50
|
||||
days-before-close: 10
|
||||
# Do not touch meta issues:
|
||||
exempt-issue-labels: meta,fridge,project management
|
||||
|
||||
3
.gitmodules
vendored
3
.gitmodules
vendored
@@ -1,3 +1,6 @@
|
||||
[submodule "classic/forge/tests/vcr_cassettes"]
|
||||
path = classic/forge/tests/vcr_cassettes
|
||||
url = https://github.com/Significant-Gravitas/Auto-GPT-test-cassettes
|
||||
[submodule "autogpt_platform/supabase"]
|
||||
path = autogpt_platform/supabase
|
||||
url = https://github.com/supabase/supabase.git
|
||||
|
||||
@@ -140,7 +140,7 @@ repos:
|
||||
language: system
|
||||
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 24.10.0
|
||||
rev: 23.12.1
|
||||
# Black has sensible defaults, doesn't need package context, and ignores
|
||||
# everything in .gitignore, so it works fine without any config or arguments.
|
||||
hooks:
|
||||
@@ -170,16 +170,6 @@ repos:
|
||||
files: ^classic/benchmark/(agbenchmark|tests)/((?!reports).)*[/.]
|
||||
args: [--config=classic/benchmark/.flake8]
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: prettier
|
||||
name: Format (Prettier) - AutoGPT Platform - Frontend
|
||||
alias: format-platform-frontend
|
||||
entry: bash -c 'cd autogpt_platform/frontend && npx prettier --write $(echo "$@" | sed "s|autogpt_platform/frontend/||g")' --
|
||||
files: ^autogpt_platform/frontend/
|
||||
types: [file]
|
||||
language: system
|
||||
|
||||
- repo: local
|
||||
# To have watertight type checking, we check *all* the files in an affected
|
||||
# project. To trigger on poetry.lock we also reset the file `types` filter.
|
||||
@@ -231,16 +221,6 @@ repos:
|
||||
language: system
|
||||
pass_filenames: false
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: tsc
|
||||
name: Typecheck - AutoGPT Platform - Frontend
|
||||
entry: bash -c 'cd autogpt_platform/frontend && npm run type-check'
|
||||
files: ^autogpt_platform/frontend/
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: pytest
|
||||
|
||||
@@ -2,6 +2,9 @@
|
||||
If you are reading this, you are probably looking for the full **[contribution guide]**,
|
||||
which is part of our [wiki].
|
||||
|
||||
Also check out our [🚀 Roadmap][roadmap] for information about our priorities and associated tasks.
|
||||
<!-- You can find our immediate priorities and their progress on our public [kanban board]. -->
|
||||
|
||||
[contribution guide]: https://github.com/Significant-Gravitas/AutoGPT/wiki/Contributing
|
||||
[wiki]: https://github.com/Significant-Gravitas/AutoGPT/wiki
|
||||
[roadmap]: https://github.com/Significant-Gravitas/AutoGPT/discussions/6971
|
||||
|
||||
173
LICENSE
173
LICENSE
@@ -1,8 +1,5 @@
|
||||
All portions of this repository are under one of two licenses.
|
||||
|
||||
The all files outside of the autogpt_platform folder are under the MIT License below.
|
||||
|
||||
The autogpt_platform folder is under the Polyform Shield License below.
|
||||
All portions of this repository are under one of two licenses. The majority of the AutoGPT repository is under the MIT License below. The autogpt_platform folder is under the
|
||||
Polyform Shield License.
|
||||
|
||||
|
||||
MIT License
|
||||
@@ -30,169 +27,3 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
|
||||
# PolyForm Shield License 1.0.0
|
||||
|
||||
<https://polyformproject.org/licenses/shield/1.0.0>
|
||||
|
||||
## Acceptance
|
||||
|
||||
In order to get any license under these terms, you must agree
|
||||
to them as both strict obligations and conditions to all
|
||||
your licenses.
|
||||
|
||||
## Copyright License
|
||||
|
||||
The licensor grants you a copyright license for the
|
||||
software to do everything you might do with the software
|
||||
that would otherwise infringe the licensor's copyright
|
||||
in it for any permitted purpose. However, you may
|
||||
only distribute the software according to [Distribution
|
||||
License](#distribution-license) and make changes or new works
|
||||
based on the software according to [Changes and New Works
|
||||
License](#changes-and-new-works-license).
|
||||
|
||||
## Distribution License
|
||||
|
||||
The licensor grants you an additional copyright license
|
||||
to distribute copies of the software. Your license
|
||||
to distribute covers distributing the software with
|
||||
changes and new works permitted by [Changes and New Works
|
||||
License](#changes-and-new-works-license).
|
||||
|
||||
## Notices
|
||||
|
||||
You must ensure that anyone who gets a copy of any part of
|
||||
the software from you also gets a copy of these terms or the
|
||||
URL for them above, as well as copies of any plain-text lines
|
||||
beginning with `Required Notice:` that the licensor provided
|
||||
with the software. For example:
|
||||
|
||||
> Required Notice: Copyright Yoyodyne, Inc. (http://example.com)
|
||||
|
||||
## Changes and New Works License
|
||||
|
||||
The licensor grants you an additional copyright license to
|
||||
make changes and new works based on the software for any
|
||||
permitted purpose.
|
||||
|
||||
## Patent License
|
||||
|
||||
The licensor grants you a patent license for the software that
|
||||
covers patent claims the licensor can license, or becomes able
|
||||
to license, that you would infringe by using the software.
|
||||
|
||||
## Noncompete
|
||||
|
||||
Any purpose is a permitted purpose, except for providing any
|
||||
product that competes with the software or any product the
|
||||
licensor or any of its affiliates provides using the software.
|
||||
|
||||
## Competition
|
||||
|
||||
Goods and services compete even when they provide functionality
|
||||
through different kinds of interfaces or for different technical
|
||||
platforms. Applications can compete with services, libraries
|
||||
with plugins, frameworks with development tools, and so on,
|
||||
even if they're written in different programming languages
|
||||
or for different computer architectures. Goods and services
|
||||
compete even when provided free of charge. If you market a
|
||||
product as a practical substitute for the software or another
|
||||
product, it definitely competes.
|
||||
|
||||
## New Products
|
||||
|
||||
If you are using the software to provide a product that does
|
||||
not compete, but the licensor or any of its affiliates brings
|
||||
your product into competition by providing a new version of
|
||||
the software or another product using the software, you may
|
||||
continue using versions of the software available under these
|
||||
terms beforehand to provide your competing product, but not
|
||||
any later versions.
|
||||
|
||||
## Discontinued Products
|
||||
|
||||
You may begin using the software to compete with a product
|
||||
or service that the licensor or any of its affiliates has
|
||||
stopped providing, unless the licensor includes a plain-text
|
||||
line beginning with `Licensor Line of Business:` with the
|
||||
software that mentions that line of business. For example:
|
||||
|
||||
> Licensor Line of Business: YoyodyneCMS Content Management
|
||||
System (http://example.com/cms)
|
||||
|
||||
## Sales of Business
|
||||
|
||||
If the licensor or any of its affiliates sells a line of
|
||||
business developing the software or using the software
|
||||
to provide a product, the buyer can also enforce
|
||||
[Noncompete](#noncompete) for that product.
|
||||
|
||||
## Fair Use
|
||||
|
||||
You may have "fair use" rights for the software under the
|
||||
law. These terms do not limit them.
|
||||
|
||||
## No Other Rights
|
||||
|
||||
These terms do not allow you to sublicense or transfer any of
|
||||
your licenses to anyone else, or prevent the licensor from
|
||||
granting licenses to anyone else. These terms do not imply
|
||||
any other licenses.
|
||||
|
||||
## Patent Defense
|
||||
|
||||
If you make any written claim that the software infringes or
|
||||
contributes to infringement of any patent, your patent license
|
||||
for the software granted under these terms ends immediately. If
|
||||
your company makes such a claim, your patent license ends
|
||||
immediately for work on behalf of your company.
|
||||
|
||||
## Violations
|
||||
|
||||
The first time you are notified in writing that you have
|
||||
violated any of these terms, or done anything with the software
|
||||
not covered by your licenses, your licenses can nonetheless
|
||||
continue if you come into full compliance with these terms,
|
||||
and take practical steps to correct past violations, within
|
||||
32 days of receiving notice. Otherwise, all your licenses
|
||||
end immediately.
|
||||
|
||||
## No Liability
|
||||
|
||||
***As far as the law allows, the software comes as is, without
|
||||
any warranty or condition, and the licensor will not be liable
|
||||
to you for any damages arising out of these terms or the use
|
||||
or nature of the software, under any kind of legal claim.***
|
||||
|
||||
## Definitions
|
||||
|
||||
The **licensor** is the individual or entity offering these
|
||||
terms, and the **software** is the software the licensor makes
|
||||
available under these terms.
|
||||
|
||||
A **product** can be a good or service, or a combination
|
||||
of them.
|
||||
|
||||
**You** refers to the individual or entity agreeing to these
|
||||
terms.
|
||||
|
||||
**Your company** is any legal entity, sole proprietorship,
|
||||
or other kind of organization that you work for, plus all
|
||||
its affiliates.
|
||||
|
||||
**Affiliates** means the other organizations than an
|
||||
organization has control over, is under the control of, or is
|
||||
under common control with.
|
||||
|
||||
**Control** means ownership of substantially all the assets of
|
||||
an entity, or the power to direct its management and policies
|
||||
by vote, contract, or otherwise. Control can be direct or
|
||||
indirect.
|
||||
|
||||
**Your licenses** are all the licenses granted to you for the
|
||||
software under these terms.
|
||||
|
||||
**Use** means anything you do with the software requiring one
|
||||
of your licenses.
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
[](https://discord.gg/autogpt)  
|
||||
[](https://twitter.com/Auto_GPT)  
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
|
||||
**AutoGPT** is a powerful platform that allows you to create, deploy, and manage continuous AI agents that automate complex workflows.
|
||||
|
||||
@@ -79,7 +80,7 @@ Be part of the revolution! **AutoGPT** is here to stay, at the forefront of AI i
|
||||
|
||||
**Licensing:**
|
||||
|
||||
MIT License: All files outside of autogpt_platform folder are under the MIT License.
|
||||
MIT License: The majority of the AutoGPT repository is under the MIT License.
|
||||
|
||||
Polyform Shield License: This license applies to the autogpt_platform folder.
|
||||
|
||||
|
||||
@@ -20,7 +20,6 @@ Instead, please report them via:
|
||||
- Please provide detailed reports with reproducible steps
|
||||
- Include the version/commit hash where you discovered the vulnerability
|
||||
- Allow us a 90-day security fix window before any public disclosure
|
||||
- After patch is released, allow 30 days for users to update before public disclosure (for a total of 120 days max between update time and fix time)
|
||||
- Share any potential mitigations or workarounds if known
|
||||
|
||||
## Supported Versions
|
||||
|
||||
@@ -1,123 +0,0 @@
|
||||
############
|
||||
# Secrets
|
||||
# YOU MUST CHANGE THESE BEFORE GOING INTO PRODUCTION
|
||||
############
|
||||
|
||||
POSTGRES_PASSWORD=your-super-secret-and-long-postgres-password
|
||||
JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
ANON_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJhbm9uIiwKICAgICJpc3MiOiAic3VwYWJhc2UtZGVtbyIsCiAgICAiaWF0IjogMTY0MTc2OTIwMCwKICAgICJleHAiOiAxNzk5NTM1NjAwCn0.dc_X5iR_VP_qT0zsiyj_I_OZ2T9FtRU2BBNWN8Bu4GE
|
||||
SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
|
||||
DASHBOARD_USERNAME=supabase
|
||||
DASHBOARD_PASSWORD=this_password_is_insecure_and_should_be_updated
|
||||
SECRET_KEY_BASE=UpNVntn3cDxHJpq99YMc1T1AQgQpc8kfYTuRgBiYa15BLrx8etQoXz3gZv1/u2oq
|
||||
VAULT_ENC_KEY=your-encryption-key-32-chars-min
|
||||
|
||||
|
||||
############
|
||||
# Database - You can change these to any PostgreSQL database that has logical replication enabled.
|
||||
############
|
||||
|
||||
POSTGRES_HOST=db
|
||||
POSTGRES_DB=postgres
|
||||
POSTGRES_PORT=5432
|
||||
# default user is postgres
|
||||
|
||||
|
||||
############
|
||||
# Supavisor -- Database pooler
|
||||
############
|
||||
POOLER_PROXY_PORT_TRANSACTION=6543
|
||||
POOLER_DEFAULT_POOL_SIZE=20
|
||||
POOLER_MAX_CLIENT_CONN=100
|
||||
POOLER_TENANT_ID=your-tenant-id
|
||||
|
||||
|
||||
############
|
||||
# API Proxy - Configuration for the Kong Reverse proxy.
|
||||
############
|
||||
|
||||
KONG_HTTP_PORT=8000
|
||||
KONG_HTTPS_PORT=8443
|
||||
|
||||
|
||||
############
|
||||
# API - Configuration for PostgREST.
|
||||
############
|
||||
|
||||
PGRST_DB_SCHEMAS=public,storage,graphql_public
|
||||
|
||||
|
||||
############
|
||||
# Auth - Configuration for the GoTrue authentication server.
|
||||
############
|
||||
|
||||
## General
|
||||
SITE_URL=http://localhost:3000
|
||||
ADDITIONAL_REDIRECT_URLS=
|
||||
JWT_EXPIRY=3600
|
||||
DISABLE_SIGNUP=false
|
||||
API_EXTERNAL_URL=http://localhost:8000
|
||||
|
||||
## Mailer Config
|
||||
MAILER_URLPATHS_CONFIRMATION="/auth/v1/verify"
|
||||
MAILER_URLPATHS_INVITE="/auth/v1/verify"
|
||||
MAILER_URLPATHS_RECOVERY="/auth/v1/verify"
|
||||
MAILER_URLPATHS_EMAIL_CHANGE="/auth/v1/verify"
|
||||
|
||||
## Email auth
|
||||
ENABLE_EMAIL_SIGNUP=true
|
||||
ENABLE_EMAIL_AUTOCONFIRM=false
|
||||
SMTP_ADMIN_EMAIL=admin@example.com
|
||||
SMTP_HOST=supabase-mail
|
||||
SMTP_PORT=2500
|
||||
SMTP_USER=fake_mail_user
|
||||
SMTP_PASS=fake_mail_password
|
||||
SMTP_SENDER_NAME=fake_sender
|
||||
ENABLE_ANONYMOUS_USERS=false
|
||||
|
||||
## Phone auth
|
||||
ENABLE_PHONE_SIGNUP=true
|
||||
ENABLE_PHONE_AUTOCONFIRM=true
|
||||
|
||||
|
||||
############
|
||||
# Studio - Configuration for the Dashboard
|
||||
############
|
||||
|
||||
STUDIO_DEFAULT_ORGANIZATION=Default Organization
|
||||
STUDIO_DEFAULT_PROJECT=Default Project
|
||||
|
||||
STUDIO_PORT=3000
|
||||
# replace if you intend to use Studio outside of localhost
|
||||
SUPABASE_PUBLIC_URL=http://localhost:8000
|
||||
|
||||
# Enable webp support
|
||||
IMGPROXY_ENABLE_WEBP_DETECTION=true
|
||||
|
||||
# Add your OpenAI API key to enable SQL Editor Assistant
|
||||
OPENAI_API_KEY=
|
||||
|
||||
|
||||
############
|
||||
# Functions - Configuration for Functions
|
||||
############
|
||||
# NOTE: VERIFY_JWT applies to all functions. Per-function VERIFY_JWT is not supported yet.
|
||||
FUNCTIONS_VERIFY_JWT=false
|
||||
|
||||
|
||||
############
|
||||
# Logs - Configuration for Logflare
|
||||
# Please refer to https://supabase.com/docs/reference/self-hosting-analytics/introduction
|
||||
############
|
||||
|
||||
LOGFLARE_LOGGER_BACKEND_API_KEY=your-super-secret-and-long-logflare-key
|
||||
|
||||
# Change vector.toml sinks to reflect this change
|
||||
LOGFLARE_API_KEY=your-super-secret-and-long-logflare-key
|
||||
|
||||
# Docker socket location - this value will differ depending on your OS
|
||||
DOCKER_SOCKET_LOCATION=/var/run/docker.sock
|
||||
|
||||
# Google Cloud Project details
|
||||
GOOGLE_PROJECT_ID=GOOGLE_PROJECT_ID
|
||||
GOOGLE_PROJECT_NUMBER=GOOGLE_PROJECT_NUMBER
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
**Contributor License Agreement (“Agreement”)**
|
||||
|
||||
Thank you for your interest in the AutoGPT project at [https://github.com/Significant-Gravitas/AutoGPT](https://github.com/Significant-Gravitas/AutoGPT) stewarded by Determinist Ltd (“**Determinist**”), with offices at 3rd Floor 1 Ashley Road, Altrincham, Cheshire, WA14 2DT, United Kingdom. The form of license below is a document that clarifies the terms under which You, the person listed below, may contribute software code described below (the “**Contribution**”) to the project. We appreciate your participation in our project, and your help in improving our products, so we want you to understand what will be done with the Contributions. This license is for your protection as well as the protection of Determinist and its licensees; it does not change your rights to use your own Contributions for any other purpose.
|
||||
Thank you for your interest in the AutoGPT open source project at [https://github.com/Significant-Gravitas/AutoGPT](https://github.com/Significant-Gravitas/AutoGPT) stewarded by Determinist Ltd (“**Determinist**”), with offices at 3rd Floor 1 Ashley Road, Altrincham, Cheshire, WA14 2DT, United Kingdom. The form of license below is a document that clarifies the terms under which You, the person listed below, may contribute software code described below (the “**Contribution**”) to the project. We appreciate your participation in our project, and your help in improving our products, so we want you to understand what will be done with the Contributions. This license is for your protection as well as the protection of Determinist and its licensees; it does not change your rights to use your own Contributions for any other purpose.
|
||||
|
||||
By submitting a Pull Request which modifies the content of the “autogpt\_platform” folder at [https://github.com/Significant-Gravitas/AutoGPT/tree/master/autogpt\_platform](https://github.com/Significant-Gravitas/AutoGPT/tree/master/autogpt_platform), You hereby agree:
|
||||
|
||||
|
||||
@@ -22,29 +22,35 @@ To run the AutoGPT Platform, follow these steps:
|
||||
|
||||
2. Run the following command:
|
||||
```
|
||||
cp .env.example .env
|
||||
git submodule update --init --recursive
|
||||
```
|
||||
This command will copy the `.env.example` file to `.env`. You can modify the `.env` file to add your own environment variables.
|
||||
This command will initialize and update the submodules in the repository. The `supabase` folder will be cloned to the root directory.
|
||||
|
||||
3. Run the following command:
|
||||
```
|
||||
cp supabase/docker/.env.example .env
|
||||
```
|
||||
This command will copy the `.env.example` file to `.env` in the `supabase/docker` directory. You can modify the `.env` file to add your own environment variables.
|
||||
|
||||
4. Run the following command:
|
||||
```
|
||||
docker compose up -d
|
||||
```
|
||||
This command will start all the necessary backend services defined in the `docker-compose.yml` file in detached mode.
|
||||
|
||||
4. Navigate to `frontend` within the `autogpt_platform` directory:
|
||||
5. Navigate to `frontend` within the `autogpt_platform` directory:
|
||||
```
|
||||
cd frontend
|
||||
```
|
||||
You will need to run your frontend application separately on your local machine.
|
||||
|
||||
5. Run the following command:
|
||||
6. Run the following command:
|
||||
```
|
||||
cp .env.example .env.local
|
||||
```
|
||||
This command will copy the `.env.example` file to `.env.local` in the `frontend` directory. You can modify the `.env.local` within this folder to add your own environment variables for the frontend application.
|
||||
|
||||
6. Run the following command:
|
||||
7. Run the following command:
|
||||
```
|
||||
npm install
|
||||
npm run dev
|
||||
@@ -55,7 +61,7 @@ To run the AutoGPT Platform, follow these steps:
|
||||
yarn install && yarn dev
|
||||
```
|
||||
|
||||
7. Open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
|
||||
8. Open your browser and navigate to `http://localhost:3000` to access the AutoGPT Platform frontend.
|
||||
|
||||
### Docker Compose Commands
|
||||
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
from .config import Settings
|
||||
from .depends import requires_admin_user, requires_user
|
||||
from .jwt_utils import parse_jwt_token
|
||||
from .middleware import APIKeyValidator, auth_middleware
|
||||
from .middleware import auth_middleware
|
||||
from .models import User
|
||||
|
||||
__all__ = [
|
||||
"Settings",
|
||||
"parse_jwt_token",
|
||||
"requires_user",
|
||||
"requires_admin_user",
|
||||
"APIKeyValidator",
|
||||
"auth_middleware",
|
||||
"User",
|
||||
]
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class Settings:
|
||||
def __init__(self):
|
||||
self.JWT_SECRET_KEY: str = os.getenv("SUPABASE_JWT_SECRET", "")
|
||||
self.ENABLE_AUTH: bool = os.getenv("ENABLE_AUTH", "false").lower() == "true"
|
||||
self.JWT_ALGORITHM: str = "HS256"
|
||||
JWT_SECRET_KEY: str = os.getenv("SUPABASE_JWT_SECRET", "")
|
||||
ENABLE_AUTH: bool = os.getenv("ENABLE_AUTH", "false").lower() == "true"
|
||||
JWT_ALGORITHM: str = "HS256"
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import fastapi
|
||||
|
||||
from .config import settings
|
||||
from .config import Settings
|
||||
from .middleware import auth_middleware
|
||||
from .models import DEFAULT_USER_ID, User
|
||||
|
||||
@@ -17,7 +17,7 @@ def requires_admin_user(
|
||||
|
||||
def verify_user(payload: dict | None, admin_only: bool) -> User:
|
||||
if not payload:
|
||||
if settings.ENABLE_AUTH:
|
||||
if Settings.ENABLE_AUTH:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=401, detail="Authorization header is missing"
|
||||
)
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
import inspect
|
||||
import logging
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from fastapi import HTTPException, Request, Security
|
||||
from fastapi.security import APIKeyHeader, HTTPBearer
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi.security import HTTPBearer
|
||||
|
||||
from .config import settings
|
||||
from .jwt_utils import parse_jwt_token
|
||||
@@ -32,104 +29,3 @@ async def auth_middleware(request: Request):
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=401, detail=str(e))
|
||||
return payload
|
||||
|
||||
|
||||
class APIKeyValidator:
|
||||
"""
|
||||
Configurable API key validator that supports custom validation functions
|
||||
for FastAPI applications.
|
||||
|
||||
This class provides a flexible way to implement API key authentication with optional
|
||||
custom validation logic. It can be used for simple token matching
|
||||
or more complex validation scenarios like database lookups.
|
||||
|
||||
Examples:
|
||||
Simple token validation:
|
||||
```python
|
||||
validator = APIKeyValidator(
|
||||
header_name="X-API-Key",
|
||||
expected_token="your-secret-token"
|
||||
)
|
||||
|
||||
@app.get("/protected", dependencies=[Depends(validator.get_dependency())])
|
||||
def protected_endpoint():
|
||||
return {"message": "Access granted"}
|
||||
```
|
||||
|
||||
Custom validation with database lookup:
|
||||
```python
|
||||
async def validate_with_db(api_key: str):
|
||||
api_key_obj = await db.get_api_key(api_key)
|
||||
return api_key_obj if api_key_obj and api_key_obj.is_active else None
|
||||
|
||||
validator = APIKeyValidator(
|
||||
header_name="X-API-Key",
|
||||
validate_fn=validate_with_db
|
||||
)
|
||||
```
|
||||
|
||||
Args:
|
||||
header_name (str): The name of the header containing the API key
|
||||
expected_token (Optional[str]): The expected API key value for simple token matching
|
||||
validate_fn (Optional[Callable]): Custom validation function that takes an API key
|
||||
string and returns a boolean or object. Can be async.
|
||||
error_status (int): HTTP status code to use for validation errors
|
||||
error_message (str): Error message to return when validation fails
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
header_name: str,
|
||||
expected_token: Optional[str] = None,
|
||||
validate_fn: Optional[Callable[[str], bool]] = None,
|
||||
error_status: int = HTTP_401_UNAUTHORIZED,
|
||||
error_message: str = "Invalid API key",
|
||||
):
|
||||
# Create the APIKeyHeader as a class property
|
||||
self.security_scheme = APIKeyHeader(name=header_name)
|
||||
self.expected_token = expected_token
|
||||
self.custom_validate_fn = validate_fn
|
||||
self.error_status = error_status
|
||||
self.error_message = error_message
|
||||
|
||||
async def default_validator(self, api_key: str) -> bool:
|
||||
return api_key == self.expected_token
|
||||
|
||||
async def __call__(
|
||||
self, request: Request, api_key: str = Security(APIKeyHeader)
|
||||
) -> Any:
|
||||
if api_key is None:
|
||||
raise HTTPException(status_code=self.error_status, detail="Missing API key")
|
||||
|
||||
# Use custom validation if provided, otherwise use default equality check
|
||||
validator = self.custom_validate_fn or self.default_validator
|
||||
result = (
|
||||
await validator(api_key)
|
||||
if inspect.iscoroutinefunction(validator)
|
||||
else validator(api_key)
|
||||
)
|
||||
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=self.error_status, detail=self.error_message
|
||||
)
|
||||
|
||||
# Store validation result in request state if it's not just a boolean
|
||||
if result is not True:
|
||||
request.state.api_key = result
|
||||
|
||||
return result
|
||||
|
||||
def get_dependency(self):
|
||||
"""
|
||||
Returns a callable dependency that FastAPI will recognize as a security scheme
|
||||
"""
|
||||
|
||||
async def validate_api_key(
|
||||
request: Request, api_key: str = Security(self.security_scheme)
|
||||
) -> Any:
|
||||
return await self(request, api_key)
|
||||
|
||||
# This helps FastAPI recognize it as a security dependency
|
||||
validate_api_key.__name__ = f"validate_{self.security_scheme.model.name}"
|
||||
return validate_api_key
|
||||
|
||||
@@ -13,6 +13,7 @@ from typing_extensions import ParamSpec
|
||||
from .config import SETTINGS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -18,7 +18,7 @@ ERROR_LOG_FILE = "error.log"
|
||||
SIMPLE_LOG_FORMAT = "%(asctime)s %(levelname)s %(title)s%(message)s"
|
||||
|
||||
DEBUG_LOG_FORMAT = (
|
||||
"%(asctime)s %(levelname)s %(filename)s:%(lineno)d %(title)s%(message)s"
|
||||
"%(asctime)s %(levelname)s %(filename)s:%(lineno)d" " %(title)s%(message)s"
|
||||
)
|
||||
|
||||
|
||||
@@ -99,6 +99,7 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
|
||||
cloud_handler.setLevel(config.level)
|
||||
cloud_handler.setFormatter(StructuredLoggingFormatter())
|
||||
log_handlers.append(cloud_handler)
|
||||
print("Cloud logging enabled")
|
||||
else:
|
||||
# Console output handlers
|
||||
stdout = logging.StreamHandler(stream=sys.stdout)
|
||||
@@ -117,6 +118,7 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
|
||||
stderr.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
|
||||
|
||||
log_handlers += [stdout, stderr]
|
||||
print("Console logging enabled")
|
||||
|
||||
# File logging setup
|
||||
if config.enable_file_logging:
|
||||
@@ -154,6 +156,7 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
|
||||
error_log_handler.setLevel(logging.ERROR)
|
||||
error_log_handler.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT, no_color=True))
|
||||
log_handlers.append(error_log_handler)
|
||||
print("File logging enabled")
|
||||
|
||||
# Configure the root logger
|
||||
logging.basicConfig(
|
||||
|
||||
938
autogpt_platform/autogpt_libs/poetry.lock
generated
938
autogpt_platform/autogpt_libs/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -9,18 +9,19 @@ packages = [{ include = "autogpt_libs" }]
|
||||
[tool.poetry.dependencies]
|
||||
colorama = "^0.4.6"
|
||||
expiringdict = "^1.2.2"
|
||||
google-cloud-logging = "^3.11.4"
|
||||
pydantic = "^2.11.1"
|
||||
pydantic-settings = "^2.8.1"
|
||||
google-cloud-logging = "^3.11.3"
|
||||
pydantic = "^2.10.3"
|
||||
pydantic-settings = "^2.7.0"
|
||||
pyjwt = "^2.10.1"
|
||||
pytest-asyncio = "^0.26.0"
|
||||
pytest-asyncio = "^0.25.0"
|
||||
pytest-mock = "^3.14.0"
|
||||
python = ">=3.10,<4.0"
|
||||
supabase = "^2.15.0"
|
||||
python-dotenv = "^1.0.1"
|
||||
supabase = "^2.10.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
redis = "^5.2.1"
|
||||
ruff = "^0.11.0"
|
||||
ruff = "^0.8.6"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
@@ -2,32 +2,19 @@ DB_USER=postgres
|
||||
DB_PASS=your-super-secret-and-long-postgres-password
|
||||
DB_NAME=postgres
|
||||
DB_PORT=5432
|
||||
DB_HOST=localhost
|
||||
DB_CONNECTION_LIMIT=12
|
||||
DB_CONNECT_TIMEOUT=60
|
||||
DB_POOL_TIMEOUT=300
|
||||
DB_SCHEMA=platform
|
||||
DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
|
||||
DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@localhost:${DB_PORT}/${DB_NAME}?connect_timeout=60&schema=platform"
|
||||
PRISMA_SCHEMA="postgres/schema.prisma"
|
||||
|
||||
# EXECUTOR
|
||||
NUM_GRAPH_WORKERS=10
|
||||
NUM_NODE_WORKERS=3
|
||||
|
||||
BACKEND_CORS_ALLOW_ORIGINS=["http://localhost:3000"]
|
||||
|
||||
# generate using `from cryptography.fernet import Fernet;Fernet.generate_key().decode()`
|
||||
ENCRYPTION_KEY='dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw='
|
||||
UNSUBSCRIBE_SECRET_KEY = 'HlP8ivStJjmbf6NKi78m_3FnOogut0t5ckzjsIqeaio='
|
||||
|
||||
REDIS_HOST=localhost
|
||||
REDIS_PORT=6379
|
||||
REDIS_PASSWORD=password
|
||||
|
||||
ENABLE_CREDIT=false
|
||||
STRIPE_API_KEY=
|
||||
STRIPE_WEBHOOK_SECRET=
|
||||
|
||||
# What environment things should be logged under: local dev or prod
|
||||
APP_ENV=local
|
||||
# What environment to behave as: "local" or "cloud"
|
||||
@@ -35,26 +22,12 @@ BEHAVE_AS=local
|
||||
PYRO_HOST=localhost
|
||||
SENTRY_DSN=
|
||||
|
||||
# Email For Postmark so we can send emails
|
||||
POSTMARK_SERVER_API_TOKEN=
|
||||
POSTMARK_SENDER_EMAIL=invalid@invalid.com
|
||||
POSTMARK_WEBHOOK_TOKEN=
|
||||
|
||||
## User auth with Supabase is required for any of the 3rd party integrations with auth to work.
|
||||
ENABLE_AUTH=true
|
||||
SUPABASE_URL=http://localhost:8000
|
||||
SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
|
||||
SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
|
||||
# RabbitMQ credentials -- Used for communication between services
|
||||
RABBITMQ_HOST=localhost
|
||||
RABBITMQ_PORT=5672
|
||||
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
|
||||
RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
|
||||
|
||||
## GCS bucket is required for marketplace and library functionality
|
||||
MEDIA_GCS_BUCKET_NAME=
|
||||
|
||||
## For local development, you may need to set FRONTEND_BASE_URL for the OAuth flow
|
||||
## for integrations to work. Defaults to the value of PLATFORM_BASE_URL if not set.
|
||||
# FRONTEND_BASE_URL=http://localhost:3000
|
||||
@@ -63,7 +36,7 @@ MEDIA_GCS_BUCKET_NAME=
|
||||
## to use the platform's webhook-related functionality.
|
||||
## If you are developing locally, you can use something like ngrok to get a publc URL
|
||||
## and tunnel it to your locally running backend.
|
||||
PLATFORM_BASE_URL=http://localhost:3000
|
||||
PLATFORM_BASE_URL=https://your-public-url-here
|
||||
|
||||
## == INTEGRATION CREDENTIALS == ##
|
||||
# Each set of server side credentials is required for the corresponding 3rd party
|
||||
@@ -99,20 +72,6 @@ GOOGLE_CLIENT_SECRET=
|
||||
TWITTER_CLIENT_ID=
|
||||
TWITTER_CLIENT_SECRET=
|
||||
|
||||
# Linear App
|
||||
# Make a new workspace for your OAuth APP -- trust me
|
||||
# https://linear.app/settings/api/applications/new
|
||||
# Callback URL: http://localhost:3000/auth/integrations/oauth_callback
|
||||
LINEAR_CLIENT_ID=
|
||||
LINEAR_CLIENT_SECRET=
|
||||
|
||||
# To obtain Todoist API credentials:
|
||||
# 1. Create a Todoist account at todoist.com
|
||||
# 2. Visit the Developer Console: https://developer.todoist.com/appconsole.html
|
||||
# 3. Click "Create new app"
|
||||
# 4. Once created, copy your Client ID and Client Secret below
|
||||
TODOIST_CLIENT_ID=
|
||||
TODOIST_CLIENT_SECRET=
|
||||
|
||||
## ===== OPTIONAL API KEYS ===== ##
|
||||
|
||||
@@ -123,12 +82,10 @@ GROQ_API_KEY=
|
||||
OPEN_ROUTER_API_KEY=
|
||||
|
||||
# Reddit
|
||||
# Go to https://www.reddit.com/prefs/apps and create a new app
|
||||
# Choose "script" for the type
|
||||
# Fill in the redirect uri as <your_frontend_url>/auth/integrations/oauth_callback, e.g. http://localhost:3000/auth/integrations/oauth_callback
|
||||
REDDIT_CLIENT_ID=
|
||||
REDDIT_CLIENT_SECRET=
|
||||
REDDIT_USER_AGENT="AutoGPT:1.0 (by /u/autogpt)"
|
||||
REDDIT_USERNAME=
|
||||
REDDIT_PASSWORD=
|
||||
|
||||
# Discord
|
||||
DISCORD_BOT_TOKEN=
|
||||
@@ -173,23 +130,9 @@ EXA_API_KEY=
|
||||
# E2B
|
||||
E2B_API_KEY=
|
||||
|
||||
# Mem0
|
||||
MEM0_API_KEY=
|
||||
|
||||
# Nvidia
|
||||
NVIDIA_API_KEY=
|
||||
|
||||
# Apollo
|
||||
APOLLO_API_KEY=
|
||||
|
||||
# SmartLead
|
||||
SMARTLEAD_API_KEY=
|
||||
|
||||
# ZeroBounce
|
||||
ZEROBOUNCE_API_KEY=
|
||||
|
||||
## ===== OPTIONAL API KEYS END ===== ##
|
||||
|
||||
# Logging Configuration
|
||||
LOG_LEVEL=INFO
|
||||
ENABLE_CLOUD_LOGGING=false
|
||||
|
||||
@@ -1 +1,75 @@
|
||||
[Advanced Setup (Dev Branch)](https://dev-docs.agpt.co/platform/advanced_setup/#autogpt_agent_server_advanced_set_up)
|
||||
# AutoGPT Agent Server Advanced set up
|
||||
|
||||
This guide walks you through a dockerized set up, with an external DB (postgres)
|
||||
|
||||
## Setup
|
||||
|
||||
We use the Poetry to manage the dependencies. To set up the project, follow these steps inside this directory:
|
||||
|
||||
0. Install Poetry
|
||||
```sh
|
||||
pip install poetry
|
||||
```
|
||||
|
||||
1. Configure Poetry to use .venv in your project directory
|
||||
```sh
|
||||
poetry config virtualenvs.in-project true
|
||||
```
|
||||
|
||||
2. Enter the poetry shell
|
||||
|
||||
```sh
|
||||
poetry shell
|
||||
```
|
||||
|
||||
3. Install dependencies
|
||||
|
||||
```sh
|
||||
poetry install
|
||||
```
|
||||
|
||||
4. Copy .env.example to .env
|
||||
|
||||
```sh
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
5. Generate the Prisma client
|
||||
|
||||
```sh
|
||||
poetry run prisma generate
|
||||
```
|
||||
|
||||
|
||||
> In case Prisma generates the client for the global Python installation instead of the virtual environment, the current mitigation is to just uninstall the global Prisma package:
|
||||
>
|
||||
> ```sh
|
||||
> pip uninstall prisma
|
||||
> ```
|
||||
>
|
||||
> Then run the generation again. The path *should* look something like this:
|
||||
> `<some path>/pypoetry/virtualenvs/backend-TQIRSwR6-py3.12/bin/prisma`
|
||||
|
||||
6. Run the postgres database from the /rnd folder
|
||||
|
||||
```sh
|
||||
cd autogpt_platform/
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
7. Run the migrations (from the backend folder)
|
||||
|
||||
```sh
|
||||
cd ../backend
|
||||
prisma migrate deploy
|
||||
```
|
||||
|
||||
## Running The Server
|
||||
|
||||
### Starting the server directly
|
||||
|
||||
Run the following command:
|
||||
|
||||
```sh
|
||||
poetry run app
|
||||
```
|
||||
|
||||
@@ -1 +1,203 @@
|
||||
[Getting Started (Released)](https://docs.agpt.co/platform/getting-started/#autogpt_agent_server)
|
||||
# AutoGPT Agent Server
|
||||
|
||||
This is an initial project for creating the next generation of agent execution, which is an AutoGPT agent server.
|
||||
The agent server will enable the creation of composite multi-agent systems that utilize AutoGPT agents and other non-agent components as its primitives.
|
||||
|
||||
## Docs
|
||||
|
||||
You can access the docs for the [AutoGPT Agent Server here](https://docs.agpt.co/server/setup).
|
||||
|
||||
## Setup
|
||||
|
||||
We use the Poetry to manage the dependencies. To set up the project, follow these steps inside this directory:
|
||||
|
||||
0. Install Poetry
|
||||
```sh
|
||||
pip install poetry
|
||||
```
|
||||
|
||||
1. Configure Poetry to use .venv in your project directory
|
||||
```sh
|
||||
poetry config virtualenvs.in-project true
|
||||
```
|
||||
|
||||
2. Enter the poetry shell
|
||||
|
||||
```sh
|
||||
poetry shell
|
||||
```
|
||||
|
||||
3. Install dependencies
|
||||
|
||||
```sh
|
||||
poetry install
|
||||
```
|
||||
|
||||
4. Copy .env.example to .env
|
||||
|
||||
```sh
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
5. Generate the Prisma client
|
||||
|
||||
```sh
|
||||
poetry run prisma generate
|
||||
```
|
||||
|
||||
|
||||
> In case Prisma generates the client for the global Python installation instead of the virtual environment, the current mitigation is to just uninstall the global Prisma package:
|
||||
>
|
||||
> ```sh
|
||||
> pip uninstall prisma
|
||||
> ```
|
||||
>
|
||||
> Then run the generation again. The path *should* look something like this:
|
||||
> `<some path>/pypoetry/virtualenvs/backend-TQIRSwR6-py3.12/bin/prisma`
|
||||
|
||||
6. Migrate the database. Be careful because this deletes current data in the database.
|
||||
|
||||
```sh
|
||||
docker compose up db -d
|
||||
poetry run prisma migrate deploy
|
||||
```
|
||||
|
||||
## Running The Server
|
||||
|
||||
### Starting the server without Docker
|
||||
|
||||
Run the following command to run database in docker but the application locally:
|
||||
|
||||
```sh
|
||||
docker compose --profile local up deps --build --detach
|
||||
poetry run app
|
||||
```
|
||||
|
||||
### Starting the server with Docker
|
||||
|
||||
Run the following command to build the dockerfiles:
|
||||
|
||||
```sh
|
||||
docker compose build
|
||||
```
|
||||
|
||||
Run the following command to run the app:
|
||||
|
||||
```sh
|
||||
docker compose up
|
||||
```
|
||||
|
||||
Run the following to automatically rebuild when code changes, in another terminal:
|
||||
|
||||
```sh
|
||||
docker compose watch
|
||||
```
|
||||
|
||||
Run the following command to shut down:
|
||||
|
||||
```sh
|
||||
docker compose down
|
||||
```
|
||||
|
||||
If you run into issues with dangling orphans, try:
|
||||
|
||||
```sh
|
||||
docker compose down --volumes --remove-orphans && docker-compose up --force-recreate --renew-anon-volumes --remove-orphans
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
To run the tests:
|
||||
|
||||
```sh
|
||||
poetry run test
|
||||
```
|
||||
|
||||
## Development
|
||||
|
||||
### Formatting & Linting
|
||||
Auto formatter and linter are set up in the project. To run them:
|
||||
|
||||
Install:
|
||||
```sh
|
||||
poetry install --with dev
|
||||
```
|
||||
|
||||
Format the code:
|
||||
```sh
|
||||
poetry run format
|
||||
```
|
||||
|
||||
Lint the code:
|
||||
```sh
|
||||
poetry run lint
|
||||
```
|
||||
|
||||
## Project Outline
|
||||
|
||||
The current project has the following main modules:
|
||||
|
||||
### **blocks**
|
||||
|
||||
This module stores all the Agent Blocks, which are reusable components to build a graph that represents the agent's behavior.
|
||||
|
||||
### **data**
|
||||
|
||||
This module stores the logical model that is persisted in the database.
|
||||
It abstracts the database operations into functions that can be called by the service layer.
|
||||
Any code that interacts with Prisma objects or the database should reside in this module.
|
||||
The main models are:
|
||||
* `block`: anything related to the block used in the graph
|
||||
* `execution`: anything related to the execution graph execution
|
||||
* `graph`: anything related to the graph, node, and its relations
|
||||
|
||||
### **execution**
|
||||
|
||||
This module stores the business logic of executing the graph.
|
||||
It currently has the following main modules:
|
||||
* `manager`: A service that consumes the queue of the graph execution and executes the graph. It contains both pieces of logic.
|
||||
* `scheduler`: A service that triggers scheduled graph execution based on a cron expression. It pushes an execution request to the manager.
|
||||
|
||||
### **server**
|
||||
|
||||
This module stores the logic for the server API.
|
||||
It contains all the logic used for the API that allows the client to create, execute, and monitor the graph and its execution.
|
||||
This API service interacts with other services like those defined in `manager` and `scheduler`.
|
||||
|
||||
### **utils**
|
||||
|
||||
This module stores utility functions that are used across the project.
|
||||
Currently, it has two main modules:
|
||||
* `process`: A module that contains the logic to spawn a new process.
|
||||
* `service`: A module that serves as a parent class for all the services in the project.
|
||||
|
||||
## Service Communication
|
||||
|
||||
Currently, there are only 3 active services:
|
||||
|
||||
- AgentServer (the API, defined in `server.py`)
|
||||
- ExecutionManager (the executor, defined in `manager.py`)
|
||||
- ExecutionScheduler (the scheduler, defined in `scheduler.py`)
|
||||
|
||||
The services run in independent Python processes and communicate through an IPC.
|
||||
A communication layer (`service.py`) is created to decouple the communication library from the implementation.
|
||||
|
||||
Currently, the IPC is done using Pyro5 and abstracted in a way that allows a function decorated with `@expose` to be called from a different process.
|
||||
|
||||
|
||||
By default the daemons run on the following ports:
|
||||
|
||||
Execution Manager Daemon: 8002
|
||||
Execution Scheduler Daemon: 8003
|
||||
Rest Server Daemon: 8004
|
||||
|
||||
## Adding a New Agent Block
|
||||
|
||||
To add a new agent block, you need to create a new class that inherits from `Block` and provides the following information:
|
||||
* All the block code should live in the `blocks` (`backend.blocks`) module.
|
||||
* `input_schema`: the schema of the input data, represented by a Pydantic object.
|
||||
* `output_schema`: the schema of the output data, represented by a Pydantic object.
|
||||
* `run` method: the main logic of the block.
|
||||
* `test_input` & `test_output`: the sample input and output data for the block, which will be used to auto-test the block.
|
||||
* You can mock the functions declared in the block using the `test_mock` field for your unit tests.
|
||||
* Once you finish creating the block, you can test it by running `poetry run pytest -s test/block/test_block.py`.
|
||||
|
||||
@@ -1,30 +1,22 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.util.process import AppProcess
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def run_processes(*processes: "AppProcess", **kwargs):
|
||||
"""
|
||||
Execute all processes in the app. The last process is run in the foreground.
|
||||
Includes enhanced error handling and process lifecycle management.
|
||||
"""
|
||||
try:
|
||||
# Run all processes except the last one in the background.
|
||||
for process in processes[:-1]:
|
||||
process.start(background=True, **kwargs)
|
||||
|
||||
# Run the last process in the foreground.
|
||||
# Run the last process in the foreground
|
||||
processes[-1].start(background=False, **kwargs)
|
||||
finally:
|
||||
for process in processes:
|
||||
try:
|
||||
process.stop()
|
||||
except Exception as e:
|
||||
logger.exception(f"[{process.service_name}] unable to stop: {e}")
|
||||
process.stop()
|
||||
|
||||
|
||||
def main(**kwargs):
|
||||
@@ -32,16 +24,14 @@ def main(**kwargs):
|
||||
Run all the processes required for the AutoGPT-server (REST and WebSocket APIs).
|
||||
"""
|
||||
|
||||
from backend.executor import DatabaseManager, ExecutionManager, Scheduler
|
||||
from backend.notifications import NotificationManager
|
||||
from backend.executor import DatabaseManager, ExecutionManager, ExecutionScheduler
|
||||
from backend.server.rest_api import AgentServer
|
||||
from backend.server.ws_api import WebsocketServer
|
||||
|
||||
run_processes(
|
||||
DatabaseManager(),
|
||||
ExecutionManager(),
|
||||
Scheduler(),
|
||||
NotificationManager(),
|
||||
ExecutionScheduler(),
|
||||
WebsocketServer(),
|
||||
AgentServer(),
|
||||
**kwargs,
|
||||
|
||||
@@ -2,103 +2,88 @@ import importlib
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
from typing import Type, TypeVar
|
||||
|
||||
from backend.data.block import Block
|
||||
|
||||
# Dynamically load all modules under backend.blocks
|
||||
AVAILABLE_MODULES = []
|
||||
current_dir = Path(__file__).parent
|
||||
modules = [
|
||||
str(f.relative_to(current_dir))[:-3].replace(os.path.sep, ".")
|
||||
for f in current_dir.rglob("*.py")
|
||||
if f.is_file() and f.name != "__init__.py"
|
||||
]
|
||||
for module in modules:
|
||||
if not re.match("^[a-z0-9_.]+$", module):
|
||||
raise ValueError(
|
||||
f"Block module {module} error: module name must be lowercase, "
|
||||
"and contain only alphanumeric characters and underscores."
|
||||
)
|
||||
|
||||
importlib.import_module(f".{module}", package=__name__)
|
||||
AVAILABLE_MODULES.append(module)
|
||||
|
||||
# Load all Block instances from the available modules
|
||||
AVAILABLE_BLOCKS: dict[str, Type[Block]] = {}
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.block import Block
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
_AVAILABLE_BLOCKS: dict[str, type["Block"]] = {}
|
||||
|
||||
|
||||
def load_all_blocks() -> dict[str, type["Block"]]:
|
||||
from backend.data.block import Block
|
||||
|
||||
if _AVAILABLE_BLOCKS:
|
||||
return _AVAILABLE_BLOCKS
|
||||
|
||||
# Dynamically load all modules under backend.blocks
|
||||
AVAILABLE_MODULES = []
|
||||
current_dir = Path(__file__).parent
|
||||
modules = [
|
||||
str(f.relative_to(current_dir))[:-3].replace(os.path.sep, ".")
|
||||
for f in current_dir.rglob("*.py")
|
||||
if f.is_file() and f.name != "__init__.py"
|
||||
]
|
||||
for module in modules:
|
||||
if not re.match("^[a-z0-9_.]+$", module):
|
||||
raise ValueError(
|
||||
f"Block module {module} error: module name must be lowercase, "
|
||||
"and contain only alphanumeric characters and underscores."
|
||||
)
|
||||
|
||||
importlib.import_module(f".{module}", package=__name__)
|
||||
AVAILABLE_MODULES.append(module)
|
||||
|
||||
# Load all Block instances from the available modules
|
||||
for block_cls in all_subclasses(Block):
|
||||
class_name = block_cls.__name__
|
||||
|
||||
if class_name.endswith("Base"):
|
||||
continue
|
||||
|
||||
if not class_name.endswith("Block"):
|
||||
raise ValueError(
|
||||
f"Block class {class_name} does not end with 'Block'. "
|
||||
"If you are creating an abstract class, "
|
||||
"please name the class with 'Base' at the end"
|
||||
)
|
||||
|
||||
block = block_cls.create()
|
||||
|
||||
if not isinstance(block.id, str) or len(block.id) != 36:
|
||||
raise ValueError(
|
||||
f"Block ID {block.name} error: {block.id} is not a valid UUID"
|
||||
)
|
||||
|
||||
if block.id in _AVAILABLE_BLOCKS:
|
||||
raise ValueError(
|
||||
f"Block ID {block.name} error: {block.id} is already in use"
|
||||
)
|
||||
|
||||
input_schema = block.input_schema.model_fields
|
||||
output_schema = block.output_schema.model_fields
|
||||
|
||||
# Make sure `error` field is a string in the output schema
|
||||
if "error" in output_schema and output_schema["error"].annotation is not str:
|
||||
raise ValueError(
|
||||
f"{block.name} `error` field in output_schema must be a string"
|
||||
)
|
||||
|
||||
# Ensure all fields in input_schema and output_schema are annotated SchemaFields
|
||||
for field_name, field in [*input_schema.items(), *output_schema.items()]:
|
||||
if field.annotation is None:
|
||||
raise ValueError(
|
||||
f"{block.name} has a field {field_name} that is not annotated"
|
||||
)
|
||||
if field.json_schema_extra is None:
|
||||
raise ValueError(
|
||||
f"{block.name} has a field {field_name} not defined as SchemaField"
|
||||
)
|
||||
|
||||
for field in block.input_schema.model_fields.values():
|
||||
if field.annotation is bool and field.default not in (True, False):
|
||||
raise ValueError(
|
||||
f"{block.name} has a boolean field with no default value"
|
||||
)
|
||||
|
||||
_AVAILABLE_BLOCKS[block.id] = block_cls
|
||||
|
||||
return _AVAILABLE_BLOCKS
|
||||
|
||||
|
||||
__all__ = ["load_all_blocks"]
|
||||
|
||||
|
||||
def all_subclasses(cls: type[T]) -> list[type[T]]:
|
||||
def all_subclasses(cls: Type[T]) -> list[Type[T]]:
|
||||
subclasses = cls.__subclasses__()
|
||||
for subclass in subclasses:
|
||||
subclasses += all_subclasses(subclass)
|
||||
return subclasses
|
||||
|
||||
|
||||
for block_cls in all_subclasses(Block):
|
||||
name = block_cls.__name__
|
||||
|
||||
if block_cls.__name__.endswith("Base"):
|
||||
continue
|
||||
|
||||
if not block_cls.__name__.endswith("Block"):
|
||||
raise ValueError(
|
||||
f"Block class {block_cls.__name__} does not end with 'Block', If you are creating an abstract class, please name the class with 'Base' at the end"
|
||||
)
|
||||
|
||||
block = block_cls.create()
|
||||
|
||||
if not isinstance(block.id, str) or len(block.id) != 36:
|
||||
raise ValueError(f"Block ID {block.name} error: {block.id} is not a valid UUID")
|
||||
|
||||
if block.id in AVAILABLE_BLOCKS:
|
||||
raise ValueError(f"Block ID {block.name} error: {block.id} is already in use")
|
||||
|
||||
input_schema = block.input_schema.model_fields
|
||||
output_schema = block.output_schema.model_fields
|
||||
|
||||
# Make sure `error` field is a string in the output schema
|
||||
if "error" in output_schema and output_schema["error"].annotation is not str:
|
||||
raise ValueError(
|
||||
f"{block.name} `error` field in output_schema must be a string"
|
||||
)
|
||||
|
||||
# Make sure all fields in input_schema and output_schema are annotated and has a value
|
||||
for field_name, field in [*input_schema.items(), *output_schema.items()]:
|
||||
if field.annotation is None:
|
||||
raise ValueError(
|
||||
f"{block.name} has a field {field_name} that is not annotated"
|
||||
)
|
||||
if field.json_schema_extra is None:
|
||||
raise ValueError(
|
||||
f"{block.name} has a field {field_name} not defined as SchemaField"
|
||||
)
|
||||
|
||||
for field in block.input_schema.model_fields.values():
|
||||
if field.annotation is bool and field.default not in (True, False):
|
||||
raise ValueError(f"{block.name} has a boolean field with no default value")
|
||||
|
||||
if block.disabled:
|
||||
continue
|
||||
|
||||
AVAILABLE_BLOCKS[block.id] = block_cls
|
||||
|
||||
__all__ = ["AVAILABLE_MODULES", "AVAILABLE_BLOCKS"]
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
|
||||
@@ -14,7 +13,6 @@ from backend.data.block import (
|
||||
)
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -44,23 +42,6 @@ class AgentExecutorBlock(Block):
|
||||
input_schema: dict = SchemaField(description="Input schema for the graph")
|
||||
output_schema: dict = SchemaField(description="Output schema for the graph")
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
|
||||
return data.get("input_schema", {})
|
||||
|
||||
@classmethod
|
||||
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
|
||||
return data.get("data", {})
|
||||
|
||||
@classmethod
|
||||
def get_missing_input(cls, data: BlockInput) -> set[str]:
|
||||
required_fields = cls.get_input_schema(data).get("required", [])
|
||||
return set(required_fields) - set(data)
|
||||
|
||||
@classmethod
|
||||
def get_mismatch_error(cls, data: BlockInput) -> str | None:
|
||||
return json.validate_with_jsonschema(cls.get_input_schema(data), data)
|
||||
|
||||
class Output(BlockSchema):
|
||||
pass
|
||||
|
||||
@@ -75,8 +56,6 @@ class AgentExecutorBlock(Block):
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
from backend.data.execution import ExecutionEventType
|
||||
|
||||
executor_manager = get_executor_manager_client()
|
||||
event_bus = get_event_bus()
|
||||
|
||||
@@ -90,11 +69,13 @@ class AgentExecutorBlock(Block):
|
||||
logger.info(f"Starting execution of {log_id}")
|
||||
|
||||
for event in event_bus.listen(
|
||||
user_id=graph_exec.user_id,
|
||||
graph_id=graph_exec.graph_id,
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
graph_id=graph_exec.graph_id, graph_exec_id=graph_exec.graph_exec_id
|
||||
):
|
||||
if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
|
||||
logger.info(
|
||||
f"Execution {log_id} produced input {event.input_data} output {event.output_data}"
|
||||
)
|
||||
|
||||
if not event.node_id:
|
||||
if event.status in [
|
||||
ExecutionStatus.COMPLETED,
|
||||
ExecutionStatus.TERMINATED,
|
||||
@@ -105,10 +86,6 @@ class AgentExecutorBlock(Block):
|
||||
else:
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"Execution {log_id} produced input {event.input_data} output {event.output_data}"
|
||||
)
|
||||
|
||||
if not event.block_id:
|
||||
logger.warning(f"{log_id} received event without block_id {event}")
|
||||
continue
|
||||
|
||||
@@ -1,108 +0,0 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from backend.blocks.apollo._auth import ApolloCredentials
|
||||
from backend.blocks.apollo.models import (
|
||||
Contact,
|
||||
Organization,
|
||||
SearchOrganizationsRequest,
|
||||
SearchOrganizationsResponse,
|
||||
SearchPeopleRequest,
|
||||
SearchPeopleResponse,
|
||||
)
|
||||
from backend.util.request import Requests
|
||||
|
||||
logger = logging.getLogger(name=__name__)
|
||||
|
||||
|
||||
class ApolloClient:
|
||||
"""Client for the Apollo API"""
|
||||
|
||||
API_URL = "https://api.apollo.io/api/v1"
|
||||
|
||||
def __init__(self, credentials: ApolloCredentials):
|
||||
self.credentials = credentials
|
||||
self.requests = Requests()
|
||||
|
||||
def _get_headers(self) -> dict[str, str]:
|
||||
return {"x-api-key": self.credentials.api_key.get_secret_value()}
|
||||
|
||||
def search_people(self, query: SearchPeopleRequest) -> List[Contact]:
|
||||
"""Search for people in Apollo"""
|
||||
response = self.requests.get(
|
||||
f"{self.API_URL}/mixed_people/search",
|
||||
headers=self._get_headers(),
|
||||
params=query.model_dump(exclude={"credentials", "max_results"}),
|
||||
)
|
||||
parsed_response = SearchPeopleResponse(**response.json())
|
||||
if parsed_response.pagination.total_entries == 0:
|
||||
return []
|
||||
|
||||
people = parsed_response.people
|
||||
|
||||
# handle pagination
|
||||
if (
|
||||
query.max_results is not None
|
||||
and query.max_results < parsed_response.pagination.total_entries
|
||||
and len(people) < query.max_results
|
||||
):
|
||||
while (
|
||||
len(people) < query.max_results
|
||||
and query.page < parsed_response.pagination.total_pages
|
||||
and len(parsed_response.people) > 0
|
||||
):
|
||||
query.page += 1
|
||||
response = self.requests.get(
|
||||
f"{self.API_URL}/mixed_people/search",
|
||||
headers=self._get_headers(),
|
||||
params=query.model_dump(exclude={"credentials", "max_results"}),
|
||||
)
|
||||
parsed_response = SearchPeopleResponse(**response.json())
|
||||
people.extend(parsed_response.people[: query.max_results - len(people)])
|
||||
|
||||
logger.info(f"Found {len(people)} people")
|
||||
return people[: query.max_results] if query.max_results else people
|
||||
|
||||
def search_organizations(
|
||||
self, query: SearchOrganizationsRequest
|
||||
) -> List[Organization]:
|
||||
"""Search for organizations in Apollo"""
|
||||
response = self.requests.get(
|
||||
f"{self.API_URL}/mixed_companies/search",
|
||||
headers=self._get_headers(),
|
||||
params=query.model_dump(exclude={"credentials", "max_results"}),
|
||||
)
|
||||
parsed_response = SearchOrganizationsResponse(**response.json())
|
||||
if parsed_response.pagination.total_entries == 0:
|
||||
return []
|
||||
|
||||
organizations = parsed_response.organizations
|
||||
|
||||
# handle pagination
|
||||
if (
|
||||
query.max_results is not None
|
||||
and query.max_results < parsed_response.pagination.total_entries
|
||||
and len(organizations) < query.max_results
|
||||
):
|
||||
while (
|
||||
len(organizations) < query.max_results
|
||||
and query.page < parsed_response.pagination.total_pages
|
||||
and len(parsed_response.organizations) > 0
|
||||
):
|
||||
query.page += 1
|
||||
response = self.requests.get(
|
||||
f"{self.API_URL}/mixed_companies/search",
|
||||
headers=self._get_headers(),
|
||||
params=query.model_dump(exclude={"credentials", "max_results"}),
|
||||
)
|
||||
parsed_response = SearchOrganizationsResponse(**response.json())
|
||||
organizations.extend(
|
||||
parsed_response.organizations[
|
||||
: query.max_results - len(organizations)
|
||||
]
|
||||
)
|
||||
|
||||
logger.info(f"Found {len(organizations)} organizations")
|
||||
return (
|
||||
organizations[: query.max_results] if query.max_results else organizations
|
||||
)
|
||||
@@ -1,35 +0,0 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
ApolloCredentials = APIKeyCredentials
|
||||
ApolloCredentialsInput = CredentialsMetaInput[
|
||||
Literal[ProviderName.APOLLO],
|
||||
Literal["api_key"],
|
||||
]
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="apollo",
|
||||
api_key=SecretStr("mock-apollo-api-key"),
|
||||
title="Mock Apollo API key",
|
||||
expires_at=None,
|
||||
)
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
def ApolloCredentialsField() -> ApolloCredentialsInput:
|
||||
"""
|
||||
Creates a Apollo credentials input on a block.
|
||||
"""
|
||||
return CredentialsField(
|
||||
description="The Apollo integration can be used with an API Key.",
|
||||
)
|
||||
@@ -1,543 +0,0 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class PrimaryPhone(BaseModel):
|
||||
"""A primary phone in Apollo"""
|
||||
|
||||
number: str
|
||||
source: str
|
||||
sanitized_number: str
|
||||
|
||||
|
||||
class SenorityLevels(str, Enum):
|
||||
"""Seniority levels in Apollo"""
|
||||
|
||||
OWNER = "owner"
|
||||
FOUNDER = "founder"
|
||||
C_SUITE = "c_suite"
|
||||
PARTNER = "partner"
|
||||
VP = "vp"
|
||||
HEAD = "head"
|
||||
DIRECTOR = "director"
|
||||
MANAGER = "manager"
|
||||
SENIOR = "senior"
|
||||
ENTRY = "entry"
|
||||
INTERN = "intern"
|
||||
|
||||
|
||||
class ContactEmailStatuses(str, Enum):
|
||||
"""Contact email statuses in Apollo"""
|
||||
|
||||
VERIFIED = "verified"
|
||||
UNVERIFIED = "unverified"
|
||||
LIKELY_TO_ENGAGE = "likely_to_engage"
|
||||
UNAVAILABLE = "unavailable"
|
||||
|
||||
|
||||
class RuleConfigStatus(BaseModel):
|
||||
"""A rule config status in Apollo"""
|
||||
|
||||
_id: str
|
||||
created_at: str
|
||||
rule_action_config_id: str
|
||||
rule_config_id: str
|
||||
status_cd: str
|
||||
updated_at: str
|
||||
id: str
|
||||
key: str
|
||||
|
||||
|
||||
class ContactCampaignStatus(BaseModel):
|
||||
"""A contact campaign status in Apollo"""
|
||||
|
||||
id: str
|
||||
emailer_campaign_id: str
|
||||
send_email_from_user_id: str
|
||||
inactive_reason: str
|
||||
status: str
|
||||
added_at: str
|
||||
added_by_user_id: str
|
||||
finished_at: str
|
||||
paused_at: str
|
||||
auto_unpause_at: str
|
||||
send_email_from_email_address: str
|
||||
send_email_from_email_account_id: str
|
||||
manually_set_unpause: str
|
||||
failure_reason: str
|
||||
current_step_id: str
|
||||
in_response_to_emailer_message_id: str
|
||||
cc_emails: str
|
||||
bcc_emails: str
|
||||
to_emails: str
|
||||
|
||||
|
||||
class Account(BaseModel):
|
||||
"""An account in Apollo"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
website_url: str
|
||||
blog_url: str
|
||||
angellist_url: str
|
||||
linkedin_url: str
|
||||
twitter_url: str
|
||||
facebook_url: str
|
||||
primary_phone: PrimaryPhone
|
||||
languages: list[str]
|
||||
alexa_ranking: int
|
||||
phone: str
|
||||
linkedin_uid: str
|
||||
founded_year: int
|
||||
publicly_traded_symbol: str
|
||||
publicly_traded_exchange: str
|
||||
logo_url: str
|
||||
chrunchbase_url: str
|
||||
primary_domain: str
|
||||
domain: str
|
||||
team_id: str
|
||||
organization_id: str
|
||||
account_stage_id: str
|
||||
source: str
|
||||
original_source: str
|
||||
creator_id: str
|
||||
owner_id: str
|
||||
created_at: str
|
||||
phone_status: str
|
||||
hubspot_id: str
|
||||
salesforce_id: str
|
||||
crm_owner_id: str
|
||||
parent_account_id: str
|
||||
sanitized_phone: str
|
||||
# no listed type on the API docs
|
||||
account_playbook_statues: list[Any]
|
||||
account_rule_config_statuses: list[RuleConfigStatus]
|
||||
existence_level: str
|
||||
label_ids: list[str]
|
||||
typed_custom_fields: Any
|
||||
custom_field_errors: Any
|
||||
modality: str
|
||||
source_display_name: str
|
||||
salesforce_record_id: str
|
||||
crm_record_url: str
|
||||
|
||||
|
||||
class ContactEmail(BaseModel):
|
||||
"""A contact email in Apollo"""
|
||||
|
||||
email: str = ""
|
||||
email_md5: str = ""
|
||||
email_sha256: str = ""
|
||||
email_status: str = ""
|
||||
email_source: str = ""
|
||||
extrapolated_email_confidence: str = ""
|
||||
position: int = 0
|
||||
email_from_customer: str = ""
|
||||
free_domain: bool = True
|
||||
|
||||
|
||||
class EmploymentHistory(BaseModel):
|
||||
"""An employment history in Apollo"""
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
arbitrary_types_allowed = True
|
||||
from_attributes = True
|
||||
populate_by_name = True
|
||||
|
||||
_id: Optional[str] = None
|
||||
created_at: Optional[str] = None
|
||||
current: Optional[bool] = None
|
||||
degree: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
emails: Optional[str] = None
|
||||
end_date: Optional[str] = None
|
||||
grade_level: Optional[str] = None
|
||||
kind: Optional[str] = None
|
||||
major: Optional[str] = None
|
||||
organization_id: Optional[str] = None
|
||||
organization_name: Optional[str] = None
|
||||
raw_address: Optional[str] = None
|
||||
start_date: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
updated_at: Optional[str] = None
|
||||
id: Optional[str] = None
|
||||
key: Optional[str] = None
|
||||
|
||||
|
||||
class Breadcrumb(BaseModel):
|
||||
"""A breadcrumb in Apollo"""
|
||||
|
||||
label: Optional[str] = "N/A"
|
||||
signal_field_name: Optional[str] = "N/A"
|
||||
value: str | list | None = "N/A"
|
||||
display_name: Optional[str] = "N/A"
|
||||
|
||||
|
||||
class TypedCustomField(BaseModel):
|
||||
"""A typed custom field in Apollo"""
|
||||
|
||||
id: Optional[str] = "N/A"
|
||||
value: Optional[str] = "N/A"
|
||||
|
||||
|
||||
class Pagination(BaseModel):
|
||||
"""Pagination in Apollo"""
|
||||
|
||||
class Config:
|
||||
extra = "allow" # Allow extra fields
|
||||
arbitrary_types_allowed = True # Allow any type
|
||||
from_attributes = True # Allow from_orm
|
||||
populate_by_name = True # Allow field aliases to work both ways
|
||||
|
||||
page: int = 0
|
||||
per_page: int = 0
|
||||
total_entries: int = 0
|
||||
total_pages: int = 0
|
||||
|
||||
|
||||
class DialerFlags(BaseModel):
|
||||
"""A dialer flags in Apollo"""
|
||||
|
||||
country_name: str
|
||||
country_enabled: bool
|
||||
high_risk_calling_enabled: bool
|
||||
potential_high_risk_number: bool
|
||||
|
||||
|
||||
class PhoneNumber(BaseModel):
|
||||
"""A phone number in Apollo"""
|
||||
|
||||
raw_number: str = ""
|
||||
sanitized_number: str = ""
|
||||
type: str = ""
|
||||
position: int = 0
|
||||
status: str = ""
|
||||
dnc_status: str = ""
|
||||
dnc_other_info: str = ""
|
||||
dailer_flags: DialerFlags = DialerFlags(
|
||||
country_name="",
|
||||
country_enabled=True,
|
||||
high_risk_calling_enabled=True,
|
||||
potential_high_risk_number=True,
|
||||
)
|
||||
|
||||
|
||||
class Organization(BaseModel):
|
||||
"""An organization in Apollo"""
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
arbitrary_types_allowed = True
|
||||
from_attributes = True
|
||||
populate_by_name = True
|
||||
|
||||
id: Optional[str] = "N/A"
|
||||
name: Optional[str] = "N/A"
|
||||
website_url: Optional[str] = "N/A"
|
||||
blog_url: Optional[str] = "N/A"
|
||||
angellist_url: Optional[str] = "N/A"
|
||||
linkedin_url: Optional[str] = "N/A"
|
||||
twitter_url: Optional[str] = "N/A"
|
||||
facebook_url: Optional[str] = "N/A"
|
||||
primary_phone: Optional[PrimaryPhone] = PrimaryPhone(
|
||||
number="N/A", source="N/A", sanitized_number="N/A"
|
||||
)
|
||||
languages: list[str] = []
|
||||
alexa_ranking: Optional[int] = 0
|
||||
phone: Optional[str] = "N/A"
|
||||
linkedin_uid: Optional[str] = "N/A"
|
||||
founded_year: Optional[int] = 0
|
||||
publicly_traded_symbol: Optional[str] = "N/A"
|
||||
publicly_traded_exchange: Optional[str] = "N/A"
|
||||
logo_url: Optional[str] = "N/A"
|
||||
chrunchbase_url: Optional[str] = "N/A"
|
||||
primary_domain: Optional[str] = "N/A"
|
||||
sanitized_phone: Optional[str] = "N/A"
|
||||
owned_by_organization_id: Optional[str] = "N/A"
|
||||
intent_strength: Optional[str] = "N/A"
|
||||
show_intent: bool = True
|
||||
has_intent_signal_account: Optional[bool] = True
|
||||
intent_signal_account: Optional[str] = "N/A"
|
||||
|
||||
|
||||
class Contact(BaseModel):
|
||||
"""A contact in Apollo"""
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
arbitrary_types_allowed = True
|
||||
from_attributes = True
|
||||
populate_by_name = True
|
||||
|
||||
contact_roles: list[Any] = []
|
||||
id: Optional[str] = None
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
linkedin_url: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
contact_stage_id: Optional[str] = None
|
||||
owner_id: Optional[str] = None
|
||||
creator_id: Optional[str] = None
|
||||
person_id: Optional[str] = None
|
||||
email_needs_tickling: bool = True
|
||||
organization_name: Optional[str] = None
|
||||
source: Optional[str] = None
|
||||
original_source: Optional[str] = None
|
||||
organization_id: Optional[str] = None
|
||||
headline: Optional[str] = None
|
||||
photo_url: Optional[str] = None
|
||||
present_raw_address: Optional[str] = None
|
||||
linkededin_uid: Optional[str] = None
|
||||
extrapolated_email_confidence: Optional[float] = None
|
||||
salesforce_id: Optional[str] = None
|
||||
salesforce_lead_id: Optional[str] = None
|
||||
salesforce_contact_id: Optional[str] = None
|
||||
saleforce_account_id: Optional[str] = None
|
||||
crm_owner_id: Optional[str] = None
|
||||
created_at: Optional[str] = None
|
||||
emailer_campaign_ids: list[str] = []
|
||||
direct_dial_status: Optional[str] = None
|
||||
direct_dial_enrichment_failed_at: Optional[str] = None
|
||||
email_status: Optional[str] = None
|
||||
email_source: Optional[str] = None
|
||||
account_id: Optional[str] = None
|
||||
last_activity_date: Optional[str] = None
|
||||
hubspot_vid: Optional[str] = None
|
||||
hubspot_company_id: Optional[str] = None
|
||||
crm_id: Optional[str] = None
|
||||
sanitized_phone: Optional[str] = None
|
||||
merged_crm_ids: Optional[str] = None
|
||||
updated_at: Optional[str] = None
|
||||
queued_for_crm_push: bool = True
|
||||
suggested_from_rule_engine_config_id: Optional[str] = None
|
||||
email_unsubscribed: Optional[str] = None
|
||||
label_ids: list[Any] = []
|
||||
has_pending_email_arcgate_request: bool = True
|
||||
has_email_arcgate_request: bool = True
|
||||
existence_level: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
email_from_customer: Optional[str] = None
|
||||
typed_custom_fields: list[TypedCustomField] = []
|
||||
custom_field_errors: Any = None
|
||||
salesforce_record_id: Optional[str] = None
|
||||
crm_record_url: Optional[str] = None
|
||||
email_status_unavailable_reason: Optional[str] = None
|
||||
email_true_status: Optional[str] = None
|
||||
updated_email_true_status: bool = True
|
||||
contact_rule_config_statuses: list[RuleConfigStatus] = []
|
||||
source_display_name: Optional[str] = None
|
||||
twitter_url: Optional[str] = None
|
||||
contact_campaign_statuses: list[ContactCampaignStatus] = []
|
||||
state: Optional[str] = None
|
||||
city: Optional[str] = None
|
||||
country: Optional[str] = None
|
||||
account: Optional[Account] = None
|
||||
contact_emails: list[ContactEmail] = []
|
||||
organization: Optional[Organization] = None
|
||||
employment_history: list[EmploymentHistory] = []
|
||||
time_zone: Optional[str] = None
|
||||
intent_strength: Optional[str] = None
|
||||
show_intent: bool = True
|
||||
phone_numbers: list[PhoneNumber] = []
|
||||
account_phone_note: Optional[str] = None
|
||||
free_domain: bool = True
|
||||
is_likely_to_engage: bool = True
|
||||
email_domain_catchall: bool = True
|
||||
contact_job_change_event: Optional[str] = None
|
||||
|
||||
|
||||
class SearchOrganizationsRequest(BaseModel):
|
||||
"""Request for Apollo's search organizations API"""
|
||||
|
||||
organization_num_empoloyees_range: list[int] = SchemaField(
|
||||
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
|
||||
|
||||
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
|
||||
default=[0, 1000000],
|
||||
)
|
||||
|
||||
organization_locations: list[str] = SchemaField(
|
||||
description="""The location of the company headquarters. You can search across cities, US states, and countries.
|
||||
|
||||
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, any Boston-based companies will not appearch in your search results, even if they match other parameters.
|
||||
|
||||
To exclude companies based on location, use the organization_not_locations parameter.
|
||||
""",
|
||||
default=[],
|
||||
)
|
||||
organizations_not_locations: list[str] = SchemaField(
|
||||
description="""Exclude companies from search results based on the location of the company headquarters. You can use cities, US states, and countries as locations to exclude.
|
||||
|
||||
This parameter is useful for ensuring you do not prospect in an undesirable territory. For example, if you use ireland as a value, no Ireland-based companies will appear in your search results.
|
||||
""",
|
||||
default=[],
|
||||
)
|
||||
q_organization_keyword_tags: list[str] = SchemaField(
|
||||
description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry."""
|
||||
)
|
||||
q_organization_name: str = SchemaField(
|
||||
description="""Filter search results to include a specific company name.
|
||||
|
||||
If the value you enter for this parameter does not match with a company's name, the company will not appear in search results, even if it matches other parameters. Partial matches are accepted. For example, if you filter by the value marketing, a company called NY Marketing Unlimited would still be eligible as a search result, but NY Market Analysis would not be eligible."""
|
||||
)
|
||||
organization_ids: list[str] = SchemaField(
|
||||
description="""The Apollo IDs for the companies you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
|
||||
|
||||
To find IDs, identify the values for organization_id when you call this endpoint.""",
|
||||
default=[],
|
||||
)
|
||||
max_results: int = SchemaField(
|
||||
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
|
||||
default=100,
|
||||
ge=1,
|
||||
le=50000,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
page: int = SchemaField(
|
||||
description="""The page number of the Apollo data that you want to retrieve.
|
||||
|
||||
Use this parameter in combination with the per_page parameter to make search results for navigable and improve the performance of the endpoint.""",
|
||||
default=1,
|
||||
)
|
||||
per_page: int = SchemaField(
|
||||
description="""The number of search results that should be returned for each page. Limited the number of results per page improves the endpoint's performance.
|
||||
|
||||
Use the page parameter to search the different pages of data.""",
|
||||
default=100,
|
||||
)
|
||||
|
||||
|
||||
class SearchOrganizationsResponse(BaseModel):
|
||||
"""Response from Apollo's search organizations API"""
|
||||
|
||||
breadcrumbs: list[Breadcrumb] = []
|
||||
partial_results_only: bool = True
|
||||
has_join: bool = True
|
||||
disable_eu_prospecting: bool = True
|
||||
partial_results_limit: int = 0
|
||||
pagination: Pagination = Pagination(
|
||||
page=0, per_page=0, total_entries=0, total_pages=0
|
||||
)
|
||||
# no listed type on the API docs
|
||||
accounts: list[Any] = []
|
||||
organizations: list[Organization] = []
|
||||
models_ids: list[str] = []
|
||||
num_fetch_result: Optional[str] = "N/A"
|
||||
derived_params: Optional[str] = "N/A"
|
||||
|
||||
|
||||
class SearchPeopleRequest(BaseModel):
|
||||
"""Request for Apollo's search people API"""
|
||||
|
||||
person_titles: list[str] = SchemaField(
|
||||
description="""Job titles held by the people you want to find. For a person to be included in search results, they only need to match 1 of the job titles you add. Adding more job titles expands your search results.
|
||||
|
||||
Results also include job titles with the same terms, even if they are not exact matches. For example, searching for marketing manager might return people with the job title content marketing manager.
|
||||
|
||||
Use this parameter in combination with the person_seniorities[] parameter to find people based on specific job functions and seniority levels.
|
||||
""",
|
||||
default=[],
|
||||
placeholder="marketing manager",
|
||||
)
|
||||
person_locations: list[str] = SchemaField(
|
||||
description="""The location where people live. You can search across cities, US states, and countries.
|
||||
|
||||
To find people based on the headquarters locations of their current employer, use the organization_locations parameter.""",
|
||||
default=[],
|
||||
)
|
||||
person_seniorities: list[SenorityLevels] = SchemaField(
|
||||
description="""The job seniority that people hold within their current employer. This enables you to find people that currently hold positions at certain reporting levels, such as Director level or senior IC level.
|
||||
|
||||
For a person to be included in search results, they only need to match 1 of the seniorities you add. Adding more seniorities expands your search results.
|
||||
|
||||
Searches only return results based on their current job title, so searching for Director-level employees only returns people that currently hold a Director-level title. If someone was previously a Director, but is currently a VP, they would not be included in your search results.
|
||||
|
||||
Use this parameter in combination with the person_titles[] parameter to find people based on specific job functions and seniority levels.""",
|
||||
default=[],
|
||||
)
|
||||
organization_locations: list[str] = SchemaField(
|
||||
description="""The location of the company headquarters for a person's current employer. You can search across cities, US states, and countries.
|
||||
|
||||
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, people that work for the Boston-based company will not appear in your results, even if they match other parameters.
|
||||
|
||||
To find people based on their personal location, use the person_locations parameter.""",
|
||||
default=[],
|
||||
)
|
||||
q_organization_domains: list[str] = SchemaField(
|
||||
description="""The domain name for the person's employer. This can be the current employer or a previous employer. Do not include www., the @ symbol, or similar.
|
||||
|
||||
You can add multiple domains to search across companies.
|
||||
|
||||
Examples: apollo.io and microsoft.com""",
|
||||
default=[],
|
||||
)
|
||||
contact_email_statuses: list[ContactEmailStatuses] = SchemaField(
|
||||
description="""The email statuses for the people you want to find. You can add multiple statuses to expand your search.""",
|
||||
default=[],
|
||||
)
|
||||
organization_ids: list[str] = SchemaField(
|
||||
description="""The Apollo IDs for the companies (employers) you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
|
||||
|
||||
To find IDs, call the Organization Search endpoint and identify the values for organization_id.""",
|
||||
default=[],
|
||||
)
|
||||
organization_num_empoloyees_range: list[int] = SchemaField(
|
||||
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
|
||||
|
||||
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
|
||||
default=[],
|
||||
)
|
||||
q_keywords: str = SchemaField(
|
||||
description="""A string of words over which we want to filter the results""",
|
||||
default="",
|
||||
)
|
||||
page: int = SchemaField(
|
||||
description="""The page number of the Apollo data that you want to retrieve.
|
||||
|
||||
Use this parameter in combination with the per_page parameter to make search results for navigable and improve the performance of the endpoint.""",
|
||||
default=1,
|
||||
)
|
||||
per_page: int = SchemaField(
|
||||
description="""The number of search results that should be returned for each page. Limited the number of results per page improves the endpoint's performance.
|
||||
|
||||
Use the page parameter to search the different pages of data.""",
|
||||
default=100,
|
||||
)
|
||||
max_results: int = SchemaField(
|
||||
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
|
||||
default=100,
|
||||
ge=1,
|
||||
le=50000,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
|
||||
class SearchPeopleResponse(BaseModel):
|
||||
"""Response from Apollo's search people API"""
|
||||
|
||||
class Config:
|
||||
extra = "allow" # Allow extra fields
|
||||
arbitrary_types_allowed = True # Allow any type
|
||||
from_attributes = True # Allow from_orm
|
||||
populate_by_name = True # Allow field aliases to work both ways
|
||||
|
||||
breadcrumbs: list[Breadcrumb] = []
|
||||
partial_results_only: bool = True
|
||||
has_join: bool = True
|
||||
disable_eu_prospecting: bool = True
|
||||
partial_results_limit: int = 0
|
||||
pagination: Pagination = Pagination(
|
||||
page=0, per_page=0, total_entries=0, total_pages=0
|
||||
)
|
||||
contacts: list[Contact] = []
|
||||
people: list[Contact] = []
|
||||
model_ids: list[str] = []
|
||||
num_fetch_result: Optional[str] = "N/A"
|
||||
derived_params: Optional[str] = "N/A"
|
||||
@@ -1,219 +0,0 @@
|
||||
from backend.blocks.apollo._api import ApolloClient
|
||||
from backend.blocks.apollo._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
ApolloCredentials,
|
||||
ApolloCredentialsInput,
|
||||
)
|
||||
from backend.blocks.apollo.models import (
|
||||
Organization,
|
||||
PrimaryPhone,
|
||||
SearchOrganizationsRequest,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class SearchOrganizationsBlock(Block):
|
||||
"""Search for organizations in Apollo"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
organization_num_empoloyees_range: list[int] = SchemaField(
|
||||
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
|
||||
|
||||
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
|
||||
default=[0, 1000000],
|
||||
)
|
||||
|
||||
organization_locations: list[str] = SchemaField(
|
||||
description="""The location of the company headquarters. You can search across cities, US states, and countries.
|
||||
|
||||
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, any Boston-based companies will not appearch in your search results, even if they match other parameters.
|
||||
|
||||
To exclude companies based on location, use the organization_not_locations parameter.
|
||||
""",
|
||||
default=[],
|
||||
)
|
||||
organizations_not_locations: list[str] = SchemaField(
|
||||
description="""Exclude companies from search results based on the location of the company headquarters. You can use cities, US states, and countries as locations to exclude.
|
||||
|
||||
This parameter is useful for ensuring you do not prospect in an undesirable territory. For example, if you use ireland as a value, no Ireland-based companies will appear in your search results.
|
||||
""",
|
||||
default=[],
|
||||
)
|
||||
q_organization_keyword_tags: list[str] = SchemaField(
|
||||
description="""Filter search results based on keywords associated with companies. For example, you can enter mining as a value to return only companies that have an association with the mining industry.""",
|
||||
default=[],
|
||||
)
|
||||
q_organization_name: str = SchemaField(
|
||||
description="""Filter search results to include a specific company name.
|
||||
|
||||
If the value you enter for this parameter does not match with a company's name, the company will not appear in search results, even if it matches other parameters. Partial matches are accepted. For example, if you filter by the value marketing, a company called NY Marketing Unlimited would still be eligible as a search result, but NY Market Analysis would not be eligible.""",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
organization_ids: list[str] = SchemaField(
|
||||
description="""The Apollo IDs for the companies you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
|
||||
|
||||
To find IDs, identify the values for organization_id when you call this endpoint.""",
|
||||
default=[],
|
||||
)
|
||||
max_results: int = SchemaField(
|
||||
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
|
||||
default=100,
|
||||
ge=1,
|
||||
le=50000,
|
||||
advanced=True,
|
||||
)
|
||||
credentials: ApolloCredentialsInput = SchemaField(
|
||||
description="Apollo credentials",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
organizations: list[Organization] = SchemaField(
|
||||
description="List of organizations found",
|
||||
default=[],
|
||||
)
|
||||
organization: Organization = SchemaField(
|
||||
description="Each found organization, one at a time",
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the search failed",
|
||||
default="",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3d71270d-599e-4148-9b95-71b35d2f44f0",
|
||||
description="Search for organizations in Apollo",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=SearchOrganizationsBlock.Input,
|
||||
output_schema=SearchOrganizationsBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"query": "Google", "credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
(
|
||||
"organization",
|
||||
Organization(
|
||||
id="1",
|
||||
name="Google",
|
||||
website_url="https://google.com",
|
||||
blog_url="https://google.com/blog",
|
||||
angellist_url="https://angel.co/google",
|
||||
linkedin_url="https://linkedin.com/company/google",
|
||||
twitter_url="https://twitter.com/google",
|
||||
facebook_url="https://facebook.com/google",
|
||||
primary_phone=PrimaryPhone(
|
||||
source="google",
|
||||
number="1234567890",
|
||||
sanitized_number="1234567890",
|
||||
),
|
||||
languages=["en"],
|
||||
alexa_ranking=1000,
|
||||
phone="1234567890",
|
||||
linkedin_uid="1234567890",
|
||||
founded_year=2000,
|
||||
publicly_traded_symbol="GOOGL",
|
||||
publicly_traded_exchange="NASDAQ",
|
||||
logo_url="https://google.com/logo.png",
|
||||
chrunchbase_url="https://chrunchbase.com/google",
|
||||
primary_domain="google.com",
|
||||
sanitized_phone="1234567890",
|
||||
owned_by_organization_id="1",
|
||||
intent_strength="strong",
|
||||
show_intent=True,
|
||||
has_intent_signal_account=True,
|
||||
intent_signal_account="1",
|
||||
),
|
||||
),
|
||||
(
|
||||
"organizations",
|
||||
[
|
||||
Organization(
|
||||
id="1",
|
||||
name="Google",
|
||||
website_url="https://google.com",
|
||||
blog_url="https://google.com/blog",
|
||||
angellist_url="https://angel.co/google",
|
||||
linkedin_url="https://linkedin.com/company/google",
|
||||
twitter_url="https://twitter.com/google",
|
||||
facebook_url="https://facebook.com/google",
|
||||
primary_phone=PrimaryPhone(
|
||||
source="google",
|
||||
number="1234567890",
|
||||
sanitized_number="1234567890",
|
||||
),
|
||||
languages=["en"],
|
||||
alexa_ranking=1000,
|
||||
phone="1234567890",
|
||||
linkedin_uid="1234567890",
|
||||
founded_year=2000,
|
||||
publicly_traded_symbol="GOOGL",
|
||||
publicly_traded_exchange="NASDAQ",
|
||||
logo_url="https://google.com/logo.png",
|
||||
chrunchbase_url="https://chrunchbase.com/google",
|
||||
primary_domain="google.com",
|
||||
sanitized_phone="1234567890",
|
||||
owned_by_organization_id="1",
|
||||
intent_strength="strong",
|
||||
show_intent=True,
|
||||
has_intent_signal_account=True,
|
||||
intent_signal_account="1",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"search_organizations": lambda *args, **kwargs: [
|
||||
Organization(
|
||||
id="1",
|
||||
name="Google",
|
||||
website_url="https://google.com",
|
||||
blog_url="https://google.com/blog",
|
||||
angellist_url="https://angel.co/google",
|
||||
linkedin_url="https://linkedin.com/company/google",
|
||||
twitter_url="https://twitter.com/google",
|
||||
facebook_url="https://facebook.com/google",
|
||||
primary_phone=PrimaryPhone(
|
||||
source="google",
|
||||
number="1234567890",
|
||||
sanitized_number="1234567890",
|
||||
),
|
||||
languages=["en"],
|
||||
alexa_ranking=1000,
|
||||
phone="1234567890",
|
||||
linkedin_uid="1234567890",
|
||||
founded_year=2000,
|
||||
publicly_traded_symbol="GOOGL",
|
||||
publicly_traded_exchange="NASDAQ",
|
||||
logo_url="https://google.com/logo.png",
|
||||
chrunchbase_url="https://chrunchbase.com/google",
|
||||
primary_domain="google.com",
|
||||
sanitized_phone="1234567890",
|
||||
owned_by_organization_id="1",
|
||||
intent_strength="strong",
|
||||
show_intent=True,
|
||||
has_intent_signal_account=True,
|
||||
intent_signal_account="1",
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def search_organizations(
|
||||
query: SearchOrganizationsRequest, credentials: ApolloCredentials
|
||||
) -> list[Organization]:
|
||||
client = ApolloClient(credentials)
|
||||
return client.search_organizations(query)
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: ApolloCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
query = SearchOrganizationsRequest(
|
||||
**input_data.model_dump(exclude={"credentials"})
|
||||
)
|
||||
organizations = self.search_organizations(query, credentials)
|
||||
for organization in organizations:
|
||||
yield "organization", organization
|
||||
yield "organizations", organizations
|
||||
@@ -1,394 +0,0 @@
|
||||
from backend.blocks.apollo._api import ApolloClient
|
||||
from backend.blocks.apollo._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
ApolloCredentials,
|
||||
ApolloCredentialsInput,
|
||||
)
|
||||
from backend.blocks.apollo.models import (
|
||||
Contact,
|
||||
ContactEmailStatuses,
|
||||
SearchPeopleRequest,
|
||||
SenorityLevels,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class SearchPeopleBlock(Block):
|
||||
"""Search for people in Apollo"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
person_titles: list[str] = SchemaField(
|
||||
description="""Job titles held by the people you want to find. For a person to be included in search results, they only need to match 1 of the job titles you add. Adding more job titles expands your search results.
|
||||
|
||||
Results also include job titles with the same terms, even if they are not exact matches. For example, searching for marketing manager might return people with the job title content marketing manager.
|
||||
|
||||
Use this parameter in combination with the person_seniorities[] parameter to find people based on specific job functions and seniority levels.
|
||||
""",
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
person_locations: list[str] = SchemaField(
|
||||
description="""The location where people live. You can search across cities, US states, and countries.
|
||||
|
||||
To find people based on the headquarters locations of their current employer, use the organization_locations parameter.""",
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
person_seniorities: list[SenorityLevels] = SchemaField(
|
||||
description="""The job seniority that people hold within their current employer. This enables you to find people that currently hold positions at certain reporting levels, such as Director level or senior IC level.
|
||||
|
||||
For a person to be included in search results, they only need to match 1 of the seniorities you add. Adding more seniorities expands your search results.
|
||||
|
||||
Searches only return results based on their current job title, so searching for Director-level employees only returns people that currently hold a Director-level title. If someone was previously a Director, but is currently a VP, they would not be included in your search results.
|
||||
|
||||
Use this parameter in combination with the person_titles[] parameter to find people based on specific job functions and seniority levels.""",
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
organization_locations: list[str] = SchemaField(
|
||||
description="""The location of the company headquarters for a person's current employer. You can search across cities, US states, and countries.
|
||||
|
||||
If a company has several office locations, results are still based on the headquarters location. For example, if you search chicago but a company's HQ location is in boston, people that work for the Boston-based company will not appear in your results, even if they match other parameters.
|
||||
|
||||
To find people based on their personal location, use the person_locations parameter.""",
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
q_organization_domains: list[str] = SchemaField(
|
||||
description="""The domain name for the person's employer. This can be the current employer or a previous employer. Do not include www., the @ symbol, or similar.
|
||||
|
||||
You can add multiple domains to search across companies.
|
||||
|
||||
Examples: apollo.io and microsoft.com""",
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
contact_email_statuses: list[ContactEmailStatuses] = SchemaField(
|
||||
description="""The email statuses for the people you want to find. You can add multiple statuses to expand your search.""",
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
organization_ids: list[str] = SchemaField(
|
||||
description="""The Apollo IDs for the companies (employers) you want to include in your search results. Each company in the Apollo database is assigned a unique ID.
|
||||
|
||||
To find IDs, call the Organization Search endpoint and identify the values for organization_id.""",
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
organization_num_empoloyees_range: list[int] = SchemaField(
|
||||
description="""The number range of employees working for the company. This enables you to find companies based on headcount. You can add multiple ranges to expand your search results.
|
||||
|
||||
Each range you add needs to be a string, with the upper and lower numbers of the range separated only by a comma.""",
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
q_keywords: str = SchemaField(
|
||||
description="""A string of words over which we want to filter the results""",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
max_results: int = SchemaField(
|
||||
description="""The maximum number of results to return. If you don't specify this parameter, the default is 100.""",
|
||||
default=100,
|
||||
ge=1,
|
||||
le=50000,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
credentials: ApolloCredentialsInput = SchemaField(
|
||||
description="Apollo credentials",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
people: list[Contact] = SchemaField(
|
||||
description="List of people found",
|
||||
default=[],
|
||||
)
|
||||
person: Contact = SchemaField(
|
||||
description="Each found person, one at a time",
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the search failed",
|
||||
default="",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c2adb3aa-5aae-488d-8a6e-4eb8c23e2ed6",
|
||||
description="Search for people in Apollo",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=SearchPeopleBlock.Input,
|
||||
output_schema=SearchPeopleBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
(
|
||||
"person",
|
||||
Contact(
|
||||
contact_roles=[],
|
||||
id="1",
|
||||
name="John Doe",
|
||||
first_name="John",
|
||||
last_name="Doe",
|
||||
linkedin_url="https://www.linkedin.com/in/johndoe",
|
||||
title="Software Engineer",
|
||||
organization_name="Google",
|
||||
organization_id="123456",
|
||||
contact_stage_id="1",
|
||||
owner_id="1",
|
||||
creator_id="1",
|
||||
person_id="1",
|
||||
email_needs_tickling=True,
|
||||
source="apollo",
|
||||
original_source="apollo",
|
||||
headline="Software Engineer",
|
||||
photo_url="https://www.linkedin.com/in/johndoe",
|
||||
present_raw_address="123 Main St, Anytown, USA",
|
||||
linkededin_uid="123456",
|
||||
extrapolated_email_confidence=0.8,
|
||||
salesforce_id="123456",
|
||||
salesforce_lead_id="123456",
|
||||
salesforce_contact_id="123456",
|
||||
saleforce_account_id="123456",
|
||||
crm_owner_id="123456",
|
||||
created_at="2021-01-01",
|
||||
emailer_campaign_ids=[],
|
||||
direct_dial_status="active",
|
||||
direct_dial_enrichment_failed_at="2021-01-01",
|
||||
email_status="active",
|
||||
email_source="apollo",
|
||||
account_id="123456",
|
||||
last_activity_date="2021-01-01",
|
||||
hubspot_vid="123456",
|
||||
hubspot_company_id="123456",
|
||||
crm_id="123456",
|
||||
sanitized_phone="123456",
|
||||
merged_crm_ids="123456",
|
||||
updated_at="2021-01-01",
|
||||
queued_for_crm_push=True,
|
||||
suggested_from_rule_engine_config_id="123456",
|
||||
email_unsubscribed=None,
|
||||
label_ids=[],
|
||||
has_pending_email_arcgate_request=True,
|
||||
has_email_arcgate_request=True,
|
||||
existence_level=None,
|
||||
email=None,
|
||||
email_from_customer=None,
|
||||
typed_custom_fields=[],
|
||||
custom_field_errors=None,
|
||||
salesforce_record_id=None,
|
||||
crm_record_url=None,
|
||||
email_status_unavailable_reason=None,
|
||||
email_true_status=None,
|
||||
updated_email_true_status=True,
|
||||
contact_rule_config_statuses=[],
|
||||
source_display_name=None,
|
||||
twitter_url=None,
|
||||
contact_campaign_statuses=[],
|
||||
state=None,
|
||||
city=None,
|
||||
country=None,
|
||||
account=None,
|
||||
contact_emails=[],
|
||||
organization=None,
|
||||
employment_history=[],
|
||||
time_zone=None,
|
||||
intent_strength=None,
|
||||
show_intent=True,
|
||||
phone_numbers=[],
|
||||
account_phone_note=None,
|
||||
free_domain=True,
|
||||
is_likely_to_engage=True,
|
||||
email_domain_catchall=True,
|
||||
contact_job_change_event=None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"people",
|
||||
[
|
||||
Contact(
|
||||
contact_roles=[],
|
||||
id="1",
|
||||
name="John Doe",
|
||||
first_name="John",
|
||||
last_name="Doe",
|
||||
linkedin_url="https://www.linkedin.com/in/johndoe",
|
||||
title="Software Engineer",
|
||||
organization_name="Google",
|
||||
organization_id="123456",
|
||||
contact_stage_id="1",
|
||||
owner_id="1",
|
||||
creator_id="1",
|
||||
person_id="1",
|
||||
email_needs_tickling=True,
|
||||
source="apollo",
|
||||
original_source="apollo",
|
||||
headline="Software Engineer",
|
||||
photo_url="https://www.linkedin.com/in/johndoe",
|
||||
present_raw_address="123 Main St, Anytown, USA",
|
||||
linkededin_uid="123456",
|
||||
extrapolated_email_confidence=0.8,
|
||||
salesforce_id="123456",
|
||||
salesforce_lead_id="123456",
|
||||
salesforce_contact_id="123456",
|
||||
saleforce_account_id="123456",
|
||||
crm_owner_id="123456",
|
||||
created_at="2021-01-01",
|
||||
emailer_campaign_ids=[],
|
||||
direct_dial_status="active",
|
||||
direct_dial_enrichment_failed_at="2021-01-01",
|
||||
email_status="active",
|
||||
email_source="apollo",
|
||||
account_id="123456",
|
||||
last_activity_date="2021-01-01",
|
||||
hubspot_vid="123456",
|
||||
hubspot_company_id="123456",
|
||||
crm_id="123456",
|
||||
sanitized_phone="123456",
|
||||
merged_crm_ids="123456",
|
||||
updated_at="2021-01-01",
|
||||
queued_for_crm_push=True,
|
||||
suggested_from_rule_engine_config_id="123456",
|
||||
email_unsubscribed=None,
|
||||
label_ids=[],
|
||||
has_pending_email_arcgate_request=True,
|
||||
has_email_arcgate_request=True,
|
||||
existence_level=None,
|
||||
email=None,
|
||||
email_from_customer=None,
|
||||
typed_custom_fields=[],
|
||||
custom_field_errors=None,
|
||||
salesforce_record_id=None,
|
||||
crm_record_url=None,
|
||||
email_status_unavailable_reason=None,
|
||||
email_true_status=None,
|
||||
updated_email_true_status=True,
|
||||
contact_rule_config_statuses=[],
|
||||
source_display_name=None,
|
||||
twitter_url=None,
|
||||
contact_campaign_statuses=[],
|
||||
state=None,
|
||||
city=None,
|
||||
country=None,
|
||||
account=None,
|
||||
contact_emails=[],
|
||||
organization=None,
|
||||
employment_history=[],
|
||||
time_zone=None,
|
||||
intent_strength=None,
|
||||
show_intent=True,
|
||||
phone_numbers=[],
|
||||
account_phone_note=None,
|
||||
free_domain=True,
|
||||
is_likely_to_engage=True,
|
||||
email_domain_catchall=True,
|
||||
contact_job_change_event=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"search_people": lambda query, credentials: [
|
||||
Contact(
|
||||
id="1",
|
||||
name="John Doe",
|
||||
first_name="John",
|
||||
last_name="Doe",
|
||||
linkedin_url="https://www.linkedin.com/in/johndoe",
|
||||
title="Software Engineer",
|
||||
organization_name="Google",
|
||||
organization_id="123456",
|
||||
contact_stage_id="1",
|
||||
owner_id="1",
|
||||
creator_id="1",
|
||||
person_id="1",
|
||||
email_needs_tickling=True,
|
||||
source="apollo",
|
||||
original_source="apollo",
|
||||
headline="Software Engineer",
|
||||
photo_url="https://www.linkedin.com/in/johndoe",
|
||||
present_raw_address="123 Main St, Anytown, USA",
|
||||
linkededin_uid="123456",
|
||||
extrapolated_email_confidence=0.8,
|
||||
salesforce_id="123456",
|
||||
salesforce_lead_id="123456",
|
||||
salesforce_contact_id="123456",
|
||||
saleforce_account_id="123456",
|
||||
crm_owner_id="123456",
|
||||
created_at="2021-01-01",
|
||||
emailer_campaign_ids=[],
|
||||
direct_dial_status="active",
|
||||
direct_dial_enrichment_failed_at="2021-01-01",
|
||||
email_status="active",
|
||||
email_source="apollo",
|
||||
account_id="123456",
|
||||
last_activity_date="2021-01-01",
|
||||
hubspot_vid="123456",
|
||||
hubspot_company_id="123456",
|
||||
crm_id="123456",
|
||||
sanitized_phone="123456",
|
||||
merged_crm_ids="123456",
|
||||
updated_at="2021-01-01",
|
||||
queued_for_crm_push=True,
|
||||
suggested_from_rule_engine_config_id="123456",
|
||||
email_unsubscribed=None,
|
||||
label_ids=[],
|
||||
has_pending_email_arcgate_request=True,
|
||||
has_email_arcgate_request=True,
|
||||
existence_level=None,
|
||||
email=None,
|
||||
email_from_customer=None,
|
||||
typed_custom_fields=[],
|
||||
custom_field_errors=None,
|
||||
salesforce_record_id=None,
|
||||
crm_record_url=None,
|
||||
email_status_unavailable_reason=None,
|
||||
email_true_status=None,
|
||||
updated_email_true_status=True,
|
||||
contact_rule_config_statuses=[],
|
||||
source_display_name=None,
|
||||
twitter_url=None,
|
||||
contact_campaign_statuses=[],
|
||||
state=None,
|
||||
city=None,
|
||||
country=None,
|
||||
account=None,
|
||||
contact_emails=[],
|
||||
organization=None,
|
||||
employment_history=[],
|
||||
time_zone=None,
|
||||
intent_strength=None,
|
||||
show_intent=True,
|
||||
phone_numbers=[],
|
||||
account_phone_note=None,
|
||||
free_domain=True,
|
||||
is_likely_to_engage=True,
|
||||
email_domain_catchall=True,
|
||||
contact_job_change_event=None,
|
||||
),
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def search_people(
|
||||
query: SearchPeopleRequest, credentials: ApolloCredentials
|
||||
) -> list[Contact]:
|
||||
client = ApolloClient(credentials)
|
||||
return client.search_people(query)
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: ApolloCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
|
||||
query = SearchPeopleRequest(**input_data.model_dump(exclude={"credentials"}))
|
||||
people = self.search_people(query, credentials)
|
||||
for person in people:
|
||||
yield "person", person
|
||||
yield "people", people
|
||||
@@ -1,48 +1,11 @@
|
||||
import enum
|
||||
from typing import Any, List
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util import json
|
||||
from backend.util.file import store_media_file
|
||||
from backend.util.mock import MockObject
|
||||
from backend.util.type import MediaFileType, convert
|
||||
from backend.util.text import TextFormatter
|
||||
|
||||
|
||||
class FileStoreBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
file_in: MediaFileType = SchemaField(
|
||||
description="The file to store in the temporary directory, it can be a URL, data URI, or local path."
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
file_out: MediaFileType = SchemaField(
|
||||
description="The relative path to the stored file in the temporary directory."
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="cbb50872-625b-42f0-8203-a2ae78242d8a",
|
||||
description="Stores the input file in the temporary directory.",
|
||||
categories={BlockCategory.BASIC, BlockCategory.MULTIMEDIA},
|
||||
input_schema=FileStoreBlock.Input,
|
||||
output_schema=FileStoreBlock.Output,
|
||||
static_output=True,
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
graph_exec_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
file_path = store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.file_in,
|
||||
return_content=False,
|
||||
)
|
||||
yield "file_out", file_path
|
||||
formatter = TextFormatter()
|
||||
|
||||
|
||||
class StoreValueBlock(Block):
|
||||
@@ -88,6 +51,29 @@ class StoreValueBlock(Block):
|
||||
yield "output", input_data.data or input_data.input
|
||||
|
||||
|
||||
class PrintToConsoleBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
text: str = SchemaField(description="The text to print to the console.")
|
||||
|
||||
class Output(BlockSchema):
|
||||
status: str = SchemaField(description="The status of the print operation.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f3b1c1b2-4c4f-4f0d-8d2f-4c4f0d8d2f4c",
|
||||
description="Print the given text to the console, this is used for a debugging purpose.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=PrintToConsoleBlock.Input,
|
||||
output_schema=PrintToConsoleBlock.Output,
|
||||
test_input={"text": "Hello, World!"},
|
||||
test_output=("status", "printed"),
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
print(">>>>> Print: ", input_data.text)
|
||||
yield "status", "printed"
|
||||
|
||||
|
||||
class FindInDictionaryBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
input: Any = SchemaField(description="Dictionary to lookup from")
|
||||
@@ -128,9 +114,6 @@ class FindInDictionaryBlock(Block):
|
||||
obj = input_data.input
|
||||
key = input_data.key
|
||||
|
||||
if isinstance(obj, str):
|
||||
obj = json.loads(obj)
|
||||
|
||||
if isinstance(obj, dict) and key in obj:
|
||||
yield "output", obj[key]
|
||||
elif isinstance(obj, list) and isinstance(key, int) and 0 <= key < len(obj):
|
||||
@@ -148,6 +131,186 @@ class FindInDictionaryBlock(Block):
|
||||
yield "missing", input_data.input
|
||||
|
||||
|
||||
class AgentInputBlock(Block):
|
||||
"""
|
||||
This block is used to provide input to the graph.
|
||||
|
||||
It takes in a value, name, description, default values list and bool to limit selection to default values.
|
||||
|
||||
It Outputs the value passed as input.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
name: str = SchemaField(description="The name of the input.")
|
||||
value: Any = SchemaField(
|
||||
description="The value to be passed as input.",
|
||||
default=None,
|
||||
)
|
||||
title: str | None = SchemaField(
|
||||
description="The title of the input.", default=None, advanced=True
|
||||
)
|
||||
description: str | None = SchemaField(
|
||||
description="The description of the input.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
placeholder_values: List[Any] = SchemaField(
|
||||
description="The placeholder values to be passed as input.",
|
||||
default=[],
|
||||
advanced=True,
|
||||
)
|
||||
limit_to_placeholder_values: bool = SchemaField(
|
||||
description="Whether to limit the selection to placeholder values.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
advanced: bool = SchemaField(
|
||||
description="Whether to show the input in the advanced section, if the field is not required.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
secret: bool = SchemaField(
|
||||
description="Whether the input should be treated as a secret.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: Any = SchemaField(description="The value passed as input.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c0a8e994-ebf1-4a9c-a4d8-89d09c86741b",
|
||||
description="This block is used to provide input to the graph.",
|
||||
input_schema=AgentInputBlock.Input,
|
||||
output_schema=AgentInputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": "Hello, World!",
|
||||
"name": "input_1",
|
||||
"description": "This is a test input.",
|
||||
"placeholder_values": [],
|
||||
"limit_to_placeholder_values": False,
|
||||
},
|
||||
{
|
||||
"value": "Hello, World!",
|
||||
"name": "input_2",
|
||||
"description": "This is a test input.",
|
||||
"placeholder_values": ["Hello, World!"],
|
||||
"limit_to_placeholder_values": True,
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("result", "Hello, World!"),
|
||||
("result", "Hello, World!"),
|
||||
],
|
||||
categories={BlockCategory.INPUT, BlockCategory.BASIC},
|
||||
block_type=BlockType.INPUT,
|
||||
static_output=True,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
yield "result", input_data.value
|
||||
|
||||
|
||||
class AgentOutputBlock(Block):
|
||||
"""
|
||||
Records the output of the graph for users to see.
|
||||
|
||||
Behavior:
|
||||
If `format` is provided and the `value` is of a type that can be formatted,
|
||||
the block attempts to format the recorded_value using the `format`.
|
||||
If formatting fails or no `format` is provided, the raw `value` is output.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
value: Any = SchemaField(
|
||||
description="The value to be recorded as output.",
|
||||
default=None,
|
||||
advanced=False,
|
||||
)
|
||||
name: str = SchemaField(description="The name of the output.")
|
||||
title: str | None = SchemaField(
|
||||
description="The title of the output.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
description: str | None = SchemaField(
|
||||
description="The description of the output.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
format: str = SchemaField(
|
||||
description="The format string to be used to format the recorded_value. Use Jinja2 syntax.",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
advanced: bool = SchemaField(
|
||||
description="Whether to treat the output as advanced.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
secret: bool = SchemaField(
|
||||
description="Whether the output should be treated as a secret.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
output: Any = SchemaField(description="The value recorded as output.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="363ae599-353e-4804-937e-b2ee3cef3da4",
|
||||
description="Stores the output of the graph for users to see.",
|
||||
input_schema=AgentOutputBlock.Input,
|
||||
output_schema=AgentOutputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": "Hello, World!",
|
||||
"name": "output_1",
|
||||
"description": "This is a test output.",
|
||||
"format": "{{ output_1 }}!!",
|
||||
},
|
||||
{
|
||||
"value": "42",
|
||||
"name": "output_2",
|
||||
"description": "This is another test output.",
|
||||
"format": "{{ output_2 }}",
|
||||
},
|
||||
{
|
||||
"value": MockObject(value="!!", key="key"),
|
||||
"name": "output_3",
|
||||
"description": "This is a test output with a mock object.",
|
||||
"format": "{{ output_3 }}",
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("output", "Hello, World!!!"),
|
||||
("output", "42"),
|
||||
("output", MockObject(value="!!", key="key")),
|
||||
],
|
||||
categories={BlockCategory.OUTPUT, BlockCategory.BASIC},
|
||||
block_type=BlockType.OUTPUT,
|
||||
static_output=True,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
"""
|
||||
Attempts to format the recorded_value using the fmt_string if provided.
|
||||
If formatting fails or no fmt_string is given, returns the original recorded_value.
|
||||
"""
|
||||
if input_data.format:
|
||||
try:
|
||||
yield "output", formatter.format_string(
|
||||
input_data.format, {input_data.name: input_data.value}
|
||||
)
|
||||
except Exception as e:
|
||||
yield "output", f"Error: {e}, {input_data.value}"
|
||||
else:
|
||||
yield "output", input_data.value
|
||||
|
||||
|
||||
class AddToDictionaryBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
dictionary: dict[Any, Any] = SchemaField(
|
||||
@@ -306,48 +469,6 @@ class AddToListBlock(Block):
|
||||
yield "updated_list", updated_list
|
||||
|
||||
|
||||
class FindInListBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
list: List[Any] = SchemaField(description="The list to search in.")
|
||||
value: Any = SchemaField(description="The value to search for.")
|
||||
|
||||
class Output(BlockSchema):
|
||||
index: int = SchemaField(description="The index of the value in the list.")
|
||||
found: bool = SchemaField(
|
||||
description="Whether the value was found in the list."
|
||||
)
|
||||
not_found_value: Any = SchemaField(
|
||||
description="The value that was not found in the list."
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="5e2c6d0a-1e37-489f-b1d0-8e1812b23333",
|
||||
description="Finds the index of the value in the list.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=FindInListBlock.Input,
|
||||
output_schema=FindInListBlock.Output,
|
||||
test_input=[
|
||||
{"list": [1, 2, 3, 4, 5], "value": 3},
|
||||
{"list": [1, 2, 3, 4, 5], "value": 6},
|
||||
],
|
||||
test_output=[
|
||||
("index", 2),
|
||||
("found", True),
|
||||
("found", False),
|
||||
("not_found_value", 6),
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
yield "index", input_data.list.index(input_data.value)
|
||||
yield "found", True
|
||||
except ValueError:
|
||||
yield "found", False
|
||||
yield "not_found_value", input_data.value
|
||||
|
||||
|
||||
class NoteBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
text: str = SchemaField(description="The text to display in the sticky note.")
|
||||
@@ -469,47 +590,3 @@ class CreateListBlock(Block):
|
||||
yield "list", input_data.values
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to create list: {str(e)}"
|
||||
|
||||
|
||||
class TypeOptions(enum.Enum):
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
BOOLEAN = "boolean"
|
||||
LIST = "list"
|
||||
DICTIONARY = "dictionary"
|
||||
|
||||
|
||||
class UniversalTypeConverterBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
value: Any = SchemaField(
|
||||
description="The value to convert to a universal type."
|
||||
)
|
||||
type: TypeOptions = SchemaField(description="The type to convert the value to.")
|
||||
|
||||
class Output(BlockSchema):
|
||||
value: Any = SchemaField(description="The converted value.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="95d1b990-ce13-4d88-9737-ba5c2070c97b",
|
||||
description="This block is used to convert a value to a universal type.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=UniversalTypeConverterBlock.Input,
|
||||
output_schema=UniversalTypeConverterBlock.Output,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
converted_value = convert(
|
||||
input_data.value,
|
||||
{
|
||||
TypeOptions.STRING: str,
|
||||
TypeOptions.NUMBER: float,
|
||||
TypeOptions.BOOLEAN: bool,
|
||||
TypeOptions.LIST: list,
|
||||
TypeOptions.DICTIONARY: dict,
|
||||
}[input_data.type],
|
||||
)
|
||||
yield "value", converted_value
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to convert value: {str(e)}"
|
||||
|
||||
@@ -107,83 +107,3 @@ class ConditionBlock(Block):
|
||||
yield "yes_output", yes_value
|
||||
else:
|
||||
yield "no_output", no_value
|
||||
|
||||
|
||||
class IfInputMatchesBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
input: Any = SchemaField(
|
||||
description="The input to match against",
|
||||
placeholder="For example: 10 or 'hello' or True",
|
||||
)
|
||||
value: Any = SchemaField(
|
||||
description="The value to output if the input matches",
|
||||
placeholder="For example: 'Greater' or 20 or False",
|
||||
)
|
||||
yes_value: Any = SchemaField(
|
||||
description="The value to output if the input matches",
|
||||
placeholder="For example: 'Greater' or 20 or False",
|
||||
default=None,
|
||||
)
|
||||
no_value: Any = SchemaField(
|
||||
description="The value to output if the input does not match",
|
||||
placeholder="For example: 'Greater' or 20 or False",
|
||||
default=None,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: bool = SchemaField(
|
||||
description="The result of the condition evaluation (True or False)"
|
||||
)
|
||||
yes_output: Any = SchemaField(
|
||||
description="The output value if the condition is true"
|
||||
)
|
||||
no_output: Any = SchemaField(
|
||||
description="The output value if the condition is false"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="6dbbc4b3-ca6c-42b6-b508-da52d23e13f2",
|
||||
input_schema=IfInputMatchesBlock.Input,
|
||||
output_schema=IfInputMatchesBlock.Output,
|
||||
description="Handles conditional logic based on comparison operators",
|
||||
categories={BlockCategory.LOGIC},
|
||||
test_input=[
|
||||
{
|
||||
"input": 10,
|
||||
"value": 10,
|
||||
"yes_value": "Greater",
|
||||
"no_value": "Not greater",
|
||||
},
|
||||
{
|
||||
"input": 10,
|
||||
"value": 20,
|
||||
"yes_value": "Greater",
|
||||
"no_value": "Not greater",
|
||||
},
|
||||
{
|
||||
"input": 10,
|
||||
"value": None,
|
||||
"yes_value": "Yes",
|
||||
"no_value": "No",
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("result", True),
|
||||
("yes_output", "Greater"),
|
||||
("result", False),
|
||||
("no_output", "Not greater"),
|
||||
("result", False),
|
||||
("no_output", "No"),
|
||||
# ("result", True),
|
||||
# ("yes_output", "Yes"),
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
if input_data.input == input_data.value or input_data.input is input_data.value:
|
||||
yield "result", True
|
||||
yield "yes_output", input_data.yes_value
|
||||
else:
|
||||
yield "result", False
|
||||
yield "no_output", input_data.no_value
|
||||
|
||||
@@ -188,270 +188,3 @@ class CodeExecutionBlock(Block):
|
||||
yield "stderr_logs", stderr_logs
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class InstantiationBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.E2B], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description="Enter your api key for the E2B Sandbox. You can get it in here - https://e2b.dev/docs",
|
||||
)
|
||||
|
||||
# Todo : Option to run commond in background
|
||||
setup_commands: list[str] = SchemaField(
|
||||
description=(
|
||||
"Shell commands to set up the sandbox before running the code. "
|
||||
"You can use `curl` or `git` to install your desired Debian based "
|
||||
"package manager. `pip` and `npm` are pre-installed.\n\n"
|
||||
"These commands are executed with `sh`, in the foreground."
|
||||
),
|
||||
placeholder="pip install cowsay",
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
setup_code: str = SchemaField(
|
||||
description="Code to execute in the sandbox",
|
||||
placeholder="print('Hello, World!')",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
language: ProgrammingLanguage = SchemaField(
|
||||
description="Programming language to execute",
|
||||
default=ProgrammingLanguage.PYTHON,
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
timeout: int = SchemaField(
|
||||
description="Execution timeout in seconds", default=300
|
||||
)
|
||||
|
||||
template_id: str = SchemaField(
|
||||
description=(
|
||||
"You can use an E2B sandbox template by entering its ID here. "
|
||||
"Check out the E2B docs for more details: "
|
||||
"[E2B - Sandbox template](https://e2b.dev/docs/sandbox-template)"
|
||||
),
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
sandbox_id: str = SchemaField(description="ID of the sandbox instance")
|
||||
response: str = SchemaField(description="Response from code execution")
|
||||
stdout_logs: str = SchemaField(
|
||||
description="Standard output logs from execution"
|
||||
)
|
||||
stderr_logs: str = SchemaField(description="Standard error logs from execution")
|
||||
error: str = SchemaField(description="Error message if execution failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="ff0861c9-1726-4aec-9e5b-bf53f3622112",
|
||||
description="Instantiate an isolated sandbox environment with internet access where to execute code in.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=InstantiationBlock.Input,
|
||||
output_schema=InstantiationBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"setup_code": "print('Hello World')",
|
||||
"language": ProgrammingLanguage.PYTHON.value,
|
||||
"setup_commands": [],
|
||||
"timeout": 300,
|
||||
"template_id": "",
|
||||
},
|
||||
test_output=[
|
||||
("sandbox_id", str),
|
||||
("response", "Hello World"),
|
||||
("stdout_logs", "Hello World\n"),
|
||||
],
|
||||
test_mock={
|
||||
"execute_code": lambda setup_code, language, setup_commands, timeout, api_key, template_id: (
|
||||
"sandbox_id",
|
||||
"Hello World",
|
||||
"Hello World\n",
|
||||
"",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
sandbox_id, response, stdout_logs, stderr_logs = self.execute_code(
|
||||
input_data.setup_code,
|
||||
input_data.language,
|
||||
input_data.setup_commands,
|
||||
input_data.timeout,
|
||||
credentials.api_key.get_secret_value(),
|
||||
input_data.template_id,
|
||||
)
|
||||
if sandbox_id:
|
||||
yield "sandbox_id", sandbox_id
|
||||
else:
|
||||
yield "error", "Sandbox ID not found"
|
||||
if response:
|
||||
yield "response", response
|
||||
if stdout_logs:
|
||||
yield "stdout_logs", stdout_logs
|
||||
if stderr_logs:
|
||||
yield "stderr_logs", stderr_logs
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
def execute_code(
|
||||
self,
|
||||
code: str,
|
||||
language: ProgrammingLanguage,
|
||||
setup_commands: list[str],
|
||||
timeout: int,
|
||||
api_key: str,
|
||||
template_id: str,
|
||||
):
|
||||
try:
|
||||
sandbox = None
|
||||
if template_id:
|
||||
sandbox = Sandbox(
|
||||
template=template_id, api_key=api_key, timeout=timeout
|
||||
)
|
||||
else:
|
||||
sandbox = Sandbox(api_key=api_key, timeout=timeout)
|
||||
|
||||
if not sandbox:
|
||||
raise Exception("Sandbox not created")
|
||||
|
||||
# Running setup commands
|
||||
for cmd in setup_commands:
|
||||
sandbox.commands.run(cmd)
|
||||
|
||||
# Executing the code
|
||||
execution = sandbox.run_code(
|
||||
code,
|
||||
language=language.value,
|
||||
on_error=lambda e: sandbox.kill(), # Kill the sandbox if there is an error
|
||||
)
|
||||
|
||||
if execution.error:
|
||||
raise Exception(execution.error)
|
||||
|
||||
response = execution.text
|
||||
stdout_logs = "".join(execution.logs.stdout)
|
||||
stderr_logs = "".join(execution.logs.stderr)
|
||||
|
||||
return sandbox.sandbox_id, response, stdout_logs, stderr_logs
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
class StepExecutionBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.E2B], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description="Enter your api key for the E2B Sandbox. You can get it in here - https://e2b.dev/docs",
|
||||
)
|
||||
|
||||
sandbox_id: str = SchemaField(
|
||||
description="ID of the sandbox instance to execute the code in",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
step_code: str = SchemaField(
|
||||
description="Code to execute in the sandbox",
|
||||
placeholder="print('Hello, World!')",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
language: ProgrammingLanguage = SchemaField(
|
||||
description="Programming language to execute",
|
||||
default=ProgrammingLanguage.PYTHON,
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
response: str = SchemaField(description="Response from code execution")
|
||||
stdout_logs: str = SchemaField(
|
||||
description="Standard output logs from execution"
|
||||
)
|
||||
stderr_logs: str = SchemaField(description="Standard error logs from execution")
|
||||
error: str = SchemaField(description="Error message if execution failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="82b59b8e-ea10-4d57-9161-8b169b0adba6",
|
||||
description="Execute code in a previously instantiated sandbox environment.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=StepExecutionBlock.Input,
|
||||
output_schema=StepExecutionBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"sandbox_id": "sandbox_id",
|
||||
"step_code": "print('Hello World')",
|
||||
"language": ProgrammingLanguage.PYTHON.value,
|
||||
},
|
||||
test_output=[
|
||||
("response", "Hello World"),
|
||||
("stdout_logs", "Hello World\n"),
|
||||
],
|
||||
test_mock={
|
||||
"execute_step_code": lambda sandbox_id, step_code, language, api_key: (
|
||||
"Hello World",
|
||||
"Hello World\n",
|
||||
"",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
def execute_step_code(
|
||||
self,
|
||||
sandbox_id: str,
|
||||
code: str,
|
||||
language: ProgrammingLanguage,
|
||||
api_key: str,
|
||||
):
|
||||
try:
|
||||
sandbox = Sandbox.connect(sandbox_id=sandbox_id, api_key=api_key)
|
||||
if not sandbox:
|
||||
raise Exception("Sandbox not found")
|
||||
|
||||
# Executing the code
|
||||
execution = sandbox.run_code(code, language=language.value)
|
||||
|
||||
if execution.error:
|
||||
raise Exception(execution.error)
|
||||
|
||||
response = execution.text
|
||||
stdout_logs = "".join(execution.logs.stdout)
|
||||
stderr_logs = "".join(execution.logs.stderr)
|
||||
|
||||
return response, stdout_logs, stderr_logs
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
response, stdout_logs, stderr_logs = self.execute_step_code(
|
||||
input_data.sandbox_id,
|
||||
input_data.step_code,
|
||||
input_data.language,
|
||||
credentials.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
if response:
|
||||
yield "response", response
|
||||
if stdout_logs:
|
||||
yield "stdout_logs", stdout_logs
|
||||
if stderr_logs:
|
||||
yield "stderr_logs", stderr_logs
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
@@ -8,7 +8,6 @@ from backend.data.block import (
|
||||
BlockSchema,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks.compass import CompassWebhookType
|
||||
|
||||
|
||||
@@ -43,7 +42,7 @@ class CompassAITriggerBlock(Block):
|
||||
input_schema=CompassAITriggerBlock.Input,
|
||||
output_schema=CompassAITriggerBlock.Output,
|
||||
webhook_config=BlockManualWebhookConfig(
|
||||
provider=ProviderName.COMPASS,
|
||||
provider="compass",
|
||||
webhook_type=CompassWebhookType.TRANSCRIPTION,
|
||||
),
|
||||
test_input=[
|
||||
|
||||
@@ -1,53 +1,22 @@
|
||||
import smtplib
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, SecretStr
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import (
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
TEST_CREDENTIALS = UserPasswordCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="smtp",
|
||||
username=SecretStr("mock-smtp-username"),
|
||||
password=SecretStr("mock-smtp-password"),
|
||||
title="Mock SMTP credentials",
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
SMTPCredentials = UserPasswordCredentials
|
||||
SMTPCredentialsInput = CredentialsMetaInput[
|
||||
Literal[ProviderName.SMTP],
|
||||
Literal["user_password"],
|
||||
]
|
||||
from backend.data.model import BlockSecret, SchemaField, SecretField
|
||||
|
||||
|
||||
def SMTPCredentialsField() -> SMTPCredentialsInput:
|
||||
return CredentialsField(
|
||||
description="The SMTP integration requires a username and password.",
|
||||
)
|
||||
|
||||
|
||||
class SMTPConfig(BaseModel):
|
||||
class EmailCredentials(BaseModel):
|
||||
smtp_server: str = SchemaField(
|
||||
default="smtp.example.com", description="SMTP server address"
|
||||
default="smtp.gmail.com", description="SMTP server address"
|
||||
)
|
||||
smtp_port: int = SchemaField(default=25, description="SMTP port number")
|
||||
smtp_username: BlockSecret = SecretField(key="smtp_username")
|
||||
smtp_password: BlockSecret = SecretField(key="smtp_password")
|
||||
|
||||
model_config = ConfigDict(title="SMTP Config")
|
||||
model_config = ConfigDict(title="Email Credentials")
|
||||
|
||||
|
||||
class SendEmailBlock(Block):
|
||||
@@ -61,11 +30,10 @@ class SendEmailBlock(Block):
|
||||
body: str = SchemaField(
|
||||
description="Body of the email", placeholder="Enter the email body"
|
||||
)
|
||||
config: SMTPConfig = SchemaField(
|
||||
description="SMTP Config",
|
||||
default=SMTPConfig(),
|
||||
creds: EmailCredentials = SchemaField(
|
||||
description="SMTP credentials",
|
||||
default=EmailCredentials(),
|
||||
)
|
||||
credentials: SMTPCredentialsInput = SMTPCredentialsField()
|
||||
|
||||
class Output(BlockSchema):
|
||||
status: str = SchemaField(description="Status of the email sending operation")
|
||||
@@ -75,6 +43,7 @@ class SendEmailBlock(Block):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
disabled=True,
|
||||
id="4335878a-394e-4e67-adf2-919877ff49ae",
|
||||
description="This block sends an email using the provided SMTP credentials.",
|
||||
categories={BlockCategory.OUTPUT},
|
||||
@@ -84,29 +53,25 @@ class SendEmailBlock(Block):
|
||||
"to_email": "recipient@example.com",
|
||||
"subject": "Test Email",
|
||||
"body": "This is a test email.",
|
||||
"config": {
|
||||
"creds": {
|
||||
"smtp_server": "smtp.gmail.com",
|
||||
"smtp_port": 25,
|
||||
"smtp_username": "your-email@gmail.com",
|
||||
"smtp_password": "your-gmail-password",
|
||||
},
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Email sent successfully")],
|
||||
test_mock={"send_email": lambda *args, **kwargs: "Email sent successfully"},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def send_email(
|
||||
config: SMTPConfig,
|
||||
to_email: str,
|
||||
subject: str,
|
||||
body: str,
|
||||
credentials: SMTPCredentials,
|
||||
creds: EmailCredentials, to_email: str, subject: str, body: str
|
||||
) -> str:
|
||||
smtp_server = config.smtp_server
|
||||
smtp_port = config.smtp_port
|
||||
smtp_username = credentials.username.get_secret_value()
|
||||
smtp_password = credentials.password.get_secret_value()
|
||||
smtp_server = creds.smtp_server
|
||||
smtp_port = creds.smtp_port
|
||||
smtp_username = creds.smtp_username.get_secret_value()
|
||||
smtp_password = creds.smtp_password.get_secret_value()
|
||||
|
||||
msg = MIMEMultipart()
|
||||
msg["From"] = smtp_username
|
||||
@@ -121,13 +86,10 @@ class SendEmailBlock(Block):
|
||||
|
||||
return "Email sent successfully"
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: SMTPCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
yield "status", self.send_email(
|
||||
config=input_data.config,
|
||||
to_email=input_data.to_email,
|
||||
subject=input_data.subject,
|
||||
body=input_data.body,
|
||||
credentials=credentials,
|
||||
input_data.creds,
|
||||
input_data.to_email,
|
||||
input_data.subject,
|
||||
input_data.body,
|
||||
)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -51,7 +51,6 @@ class ExaContentsBlock(Block):
|
||||
description="List of document contents",
|
||||
default=[],
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from backend.blocks.github._auth import (
|
||||
GithubCredentials,
|
||||
GithubFineGrainedAPICredentials,
|
||||
)
|
||||
from backend.blocks.github._auth import GithubCredentials
|
||||
from backend.util.request import Requests
|
||||
|
||||
|
||||
@@ -33,68 +30,12 @@ def _convert_to_api_url(url: str) -> str:
|
||||
|
||||
def _get_headers(credentials: GithubCredentials) -> dict[str, str]:
|
||||
return {
|
||||
"Authorization": credentials.auth_header(),
|
||||
"Authorization": credentials.bearer(),
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
}
|
||||
|
||||
|
||||
def convert_comment_url_to_api_endpoint(comment_url: str) -> str:
|
||||
"""
|
||||
Converts a GitHub comment URL (web interface) to the appropriate API endpoint URL.
|
||||
|
||||
Handles:
|
||||
1. Issue/PR comments: #issuecomment-{id}
|
||||
2. PR review comments: #discussion_r{id}
|
||||
|
||||
Returns the appropriate API endpoint path for the comment.
|
||||
"""
|
||||
# First, check if this is already an API URL
|
||||
parsed_url = urlparse(comment_url)
|
||||
if parsed_url.hostname == "api.github.com":
|
||||
return comment_url
|
||||
|
||||
# Replace pull with issues for comment endpoints
|
||||
if "/pull/" in comment_url:
|
||||
comment_url = comment_url.replace("/pull/", "/issues/")
|
||||
|
||||
# Handle issue/PR comments (#issuecomment-xxx)
|
||||
if "#issuecomment-" in comment_url:
|
||||
base_url, comment_part = comment_url.split("#issuecomment-")
|
||||
comment_id = comment_part
|
||||
|
||||
# Extract repo information from base URL
|
||||
parsed_url = urlparse(base_url)
|
||||
path_parts = parsed_url.path.strip("/").split("/")
|
||||
owner, repo = path_parts[0], path_parts[1]
|
||||
|
||||
# Construct API URL for issue comments
|
||||
return (
|
||||
f"https://api.github.com/repos/{owner}/{repo}/issues/comments/{comment_id}"
|
||||
)
|
||||
|
||||
# Handle PR review comments (#discussion_r)
|
||||
elif "#discussion_r" in comment_url:
|
||||
base_url, comment_part = comment_url.split("#discussion_r")
|
||||
comment_id = comment_part
|
||||
|
||||
# Extract repo information from base URL
|
||||
parsed_url = urlparse(base_url)
|
||||
path_parts = parsed_url.path.strip("/").split("/")
|
||||
owner, repo = path_parts[0], path_parts[1]
|
||||
|
||||
# Construct API URL for PR review comments
|
||||
return (
|
||||
f"https://api.github.com/repos/{owner}/{repo}/pulls/comments/{comment_id}"
|
||||
)
|
||||
|
||||
# If no specific comment identifiers are found, use the general URL conversion
|
||||
return _convert_to_api_url(comment_url)
|
||||
|
||||
|
||||
def get_api(
|
||||
credentials: GithubCredentials | GithubFineGrainedAPICredentials,
|
||||
convert_urls: bool = True,
|
||||
) -> Requests:
|
||||
def get_api(credentials: GithubCredentials, convert_urls: bool = True) -> Requests:
|
||||
return Requests(
|
||||
trusted_origins=["https://api.github.com", "https://github.com"],
|
||||
extra_url_validator=_convert_to_api_url if convert_urls else None,
|
||||
|
||||
@@ -22,11 +22,6 @@ GithubCredentialsInput = CredentialsMetaInput[
|
||||
Literal["api_key", "oauth2"] if GITHUB_OAUTH_IS_CONFIGURED else Literal["api_key"],
|
||||
]
|
||||
|
||||
GithubFineGrainedAPICredentials = APIKeyCredentials
|
||||
GithubFineGrainedAPICredentialsInput = CredentialsMetaInput[
|
||||
Literal[ProviderName.GITHUB], Literal["api_key"]
|
||||
]
|
||||
|
||||
|
||||
def GithubCredentialsField(scope: str) -> GithubCredentialsInput:
|
||||
"""
|
||||
@@ -42,16 +37,6 @@ def GithubCredentialsField(scope: str) -> GithubCredentialsInput:
|
||||
)
|
||||
|
||||
|
||||
def GithubFineGrainedAPICredentialsField(
|
||||
scope: str,
|
||||
) -> GithubFineGrainedAPICredentialsInput:
|
||||
return CredentialsField(
|
||||
required_scopes={scope},
|
||||
description="The GitHub integration can be used with OAuth, "
|
||||
"or any API key with sufficient permissions for the blocks it is used on.",
|
||||
)
|
||||
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="github",
|
||||
@@ -65,18 +50,3 @@ TEST_CREDENTIALS_INPUT = {
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.type,
|
||||
}
|
||||
|
||||
TEST_FINE_GRAINED_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="github",
|
||||
api_key=SecretStr("mock-github-api-key"),
|
||||
title="Mock GitHub API key",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
TEST_FINE_GRAINED_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_FINE_GRAINED_CREDENTIALS.provider,
|
||||
"id": TEST_FINE_GRAINED_CREDENTIALS.id,
|
||||
"type": TEST_FINE_GRAINED_CREDENTIALS.type,
|
||||
"title": TEST_FINE_GRAINED_CREDENTIALS.type,
|
||||
}
|
||||
|
||||
@@ -1,360 +0,0 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
GithubCredentials,
|
||||
GithubCredentialsField,
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
|
||||
|
||||
# queued, in_progress, completed, waiting, requested, pending
|
||||
class ChecksStatus(Enum):
|
||||
QUEUED = "queued"
|
||||
IN_PROGRESS = "in_progress"
|
||||
COMPLETED = "completed"
|
||||
WAITING = "waiting"
|
||||
REQUESTED = "requested"
|
||||
PENDING = "pending"
|
||||
|
||||
|
||||
class ChecksConclusion(Enum):
|
||||
SUCCESS = "success"
|
||||
FAILURE = "failure"
|
||||
NEUTRAL = "neutral"
|
||||
CANCELLED = "cancelled"
|
||||
TIMED_OUT = "timed_out"
|
||||
ACTION_REQUIRED = "action_required"
|
||||
SKIPPED = "skipped"
|
||||
|
||||
|
||||
class GithubCreateCheckRunBlock(Block):
|
||||
"""Block for creating a new check run on a GitHub repository."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo:status")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
name: str = SchemaField(
|
||||
description="The name of the check run (e.g., 'code-coverage')",
|
||||
)
|
||||
head_sha: str = SchemaField(
|
||||
description="The SHA of the commit to check",
|
||||
)
|
||||
status: ChecksStatus = SchemaField(
|
||||
description="Current status of the check run",
|
||||
default=ChecksStatus.QUEUED,
|
||||
)
|
||||
conclusion: Optional[ChecksConclusion] = SchemaField(
|
||||
description="The final conclusion of the check (required if status is completed)",
|
||||
default=None,
|
||||
)
|
||||
details_url: str = SchemaField(
|
||||
description="The URL for the full details of the check",
|
||||
default="",
|
||||
)
|
||||
output_title: str = SchemaField(
|
||||
description="Title of the check run output",
|
||||
default="",
|
||||
)
|
||||
output_summary: str = SchemaField(
|
||||
description="Summary of the check run output",
|
||||
default="",
|
||||
)
|
||||
output_text: str = SchemaField(
|
||||
description="Detailed text of the check run output",
|
||||
default="",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class CheckRunResult(BaseModel):
|
||||
id: int
|
||||
html_url: str
|
||||
status: str
|
||||
|
||||
check_run: CheckRunResult = SchemaField(
|
||||
description="Details of the created check run"
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if check run creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="2f45e89a-3b7d-4f22-b89e-6c4f5c7e1234",
|
||||
description="Creates a new check run for a specific commit in a GitHub repository",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubCreateCheckRunBlock.Input,
|
||||
output_schema=GithubCreateCheckRunBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"name": "test-check",
|
||||
"head_sha": "ce587453ced02b1526dfb4cb910479d431683101",
|
||||
"status": ChecksStatus.COMPLETED.value,
|
||||
"conclusion": ChecksConclusion.SUCCESS.value,
|
||||
"output_title": "Test Results",
|
||||
"output_summary": "All tests passed",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
# requires a github app not available to oauth in our current system
|
||||
disabled=True,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"check_run",
|
||||
{
|
||||
"id": 4,
|
||||
"html_url": "https://github.com/owner/repo/runs/4",
|
||||
"status": "completed",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"create_check_run": lambda *args, **kwargs: {
|
||||
"id": 4,
|
||||
"html_url": "https://github.com/owner/repo/runs/4",
|
||||
"status": "completed",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_check_run(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
name: str,
|
||||
head_sha: str,
|
||||
status: ChecksStatus,
|
||||
conclusion: Optional[ChecksConclusion] = None,
|
||||
details_url: Optional[str] = None,
|
||||
output_title: Optional[str] = None,
|
||||
output_summary: Optional[str] = None,
|
||||
output_text: Optional[str] = None,
|
||||
) -> dict:
|
||||
api = get_api(credentials)
|
||||
|
||||
class CheckRunData(BaseModel):
|
||||
name: str
|
||||
head_sha: str
|
||||
status: str
|
||||
conclusion: Optional[str] = None
|
||||
details_url: Optional[str] = None
|
||||
output: Optional[dict[str, str]] = None
|
||||
|
||||
data = CheckRunData(
|
||||
name=name,
|
||||
head_sha=head_sha,
|
||||
status=status.value,
|
||||
)
|
||||
|
||||
if conclusion:
|
||||
data.conclusion = conclusion.value
|
||||
|
||||
if details_url:
|
||||
data.details_url = details_url
|
||||
|
||||
if output_title or output_summary or output_text:
|
||||
output_data = {
|
||||
"title": output_title or "",
|
||||
"summary": output_summary or "",
|
||||
"text": output_text or "",
|
||||
}
|
||||
data.output = output_data
|
||||
|
||||
check_runs_url = f"{repo_url}/check-runs"
|
||||
response = api.post(
|
||||
check_runs_url, data=data.model_dump_json(exclude_none=True)
|
||||
)
|
||||
result = response.json()
|
||||
|
||||
return {
|
||||
"id": result["id"],
|
||||
"html_url": result["html_url"],
|
||||
"status": result["status"],
|
||||
}
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
result = self.create_check_run(
|
||||
credentials=credentials,
|
||||
repo_url=input_data.repo_url,
|
||||
name=input_data.name,
|
||||
head_sha=input_data.head_sha,
|
||||
status=input_data.status,
|
||||
conclusion=input_data.conclusion,
|
||||
details_url=input_data.details_url,
|
||||
output_title=input_data.output_title,
|
||||
output_summary=input_data.output_summary,
|
||||
output_text=input_data.output_text,
|
||||
)
|
||||
yield "check_run", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GithubUpdateCheckRunBlock(Block):
|
||||
"""Block for updating an existing check run on a GitHub repository."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo:status")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
check_run_id: int = SchemaField(
|
||||
description="The ID of the check run to update",
|
||||
)
|
||||
status: ChecksStatus = SchemaField(
|
||||
description="New status of the check run",
|
||||
)
|
||||
conclusion: ChecksConclusion = SchemaField(
|
||||
description="The final conclusion of the check (required if status is completed)",
|
||||
)
|
||||
output_title: Optional[str] = SchemaField(
|
||||
description="New title of the check run output",
|
||||
default=None,
|
||||
)
|
||||
output_summary: Optional[str] = SchemaField(
|
||||
description="New summary of the check run output",
|
||||
default=None,
|
||||
)
|
||||
output_text: Optional[str] = SchemaField(
|
||||
description="New detailed text of the check run output",
|
||||
default=None,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class CheckRunResult(BaseModel):
|
||||
id: int
|
||||
html_url: str
|
||||
status: str
|
||||
conclusion: Optional[str]
|
||||
|
||||
check_run: CheckRunResult = SchemaField(
|
||||
description="Details of the updated check run"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if check run update failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8a23c567-9d01-4e56-b789-0c12d3e45678", # Generated UUID
|
||||
description="Updates an existing check run in a GitHub repository",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubUpdateCheckRunBlock.Input,
|
||||
output_schema=GithubUpdateCheckRunBlock.Output,
|
||||
# requires a github app not available to oauth in our current system
|
||||
disabled=True,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"check_run_id": 4,
|
||||
"status": ChecksStatus.COMPLETED.value,
|
||||
"conclusion": ChecksConclusion.SUCCESS.value,
|
||||
"output_title": "Updated Results",
|
||||
"output_summary": "All tests passed after retry",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"check_run",
|
||||
{
|
||||
"id": 4,
|
||||
"html_url": "https://github.com/owner/repo/runs/4",
|
||||
"status": "completed",
|
||||
"conclusion": "success",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"update_check_run": lambda *args, **kwargs: {
|
||||
"id": 4,
|
||||
"html_url": "https://github.com/owner/repo/runs/4",
|
||||
"status": "completed",
|
||||
"conclusion": "success",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def update_check_run(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
check_run_id: int,
|
||||
status: ChecksStatus,
|
||||
conclusion: Optional[ChecksConclusion] = None,
|
||||
output_title: Optional[str] = None,
|
||||
output_summary: Optional[str] = None,
|
||||
output_text: Optional[str] = None,
|
||||
) -> dict:
|
||||
api = get_api(credentials)
|
||||
|
||||
class UpdateCheckRunData(BaseModel):
|
||||
status: str
|
||||
conclusion: Optional[str] = None
|
||||
output: Optional[dict[str, str]] = None
|
||||
|
||||
data = UpdateCheckRunData(
|
||||
status=status.value,
|
||||
)
|
||||
|
||||
if conclusion:
|
||||
data.conclusion = conclusion.value
|
||||
|
||||
if output_title or output_summary or output_text:
|
||||
output_data = {
|
||||
"title": output_title or "",
|
||||
"summary": output_summary or "",
|
||||
"text": output_text or "",
|
||||
}
|
||||
data.output = output_data
|
||||
|
||||
check_run_url = f"{repo_url}/check-runs/{check_run_id}"
|
||||
response = api.patch(
|
||||
check_run_url, data=data.model_dump_json(exclude_none=True)
|
||||
)
|
||||
result = response.json()
|
||||
|
||||
return {
|
||||
"id": result["id"],
|
||||
"html_url": result["html_url"],
|
||||
"status": result["status"],
|
||||
"conclusion": result.get("conclusion"),
|
||||
}
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
result = self.update_check_run(
|
||||
credentials=credentials,
|
||||
repo_url=input_data.repo_url,
|
||||
check_run_id=input_data.check_run_id,
|
||||
status=input_data.status,
|
||||
conclusion=input_data.conclusion,
|
||||
output_title=input_data.output_title,
|
||||
output_summary=input_data.output_summary,
|
||||
output_text=input_data.output_text,
|
||||
)
|
||||
yield "check_run", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -1,4 +1,3 @@
|
||||
import logging
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
@@ -6,7 +5,7 @@ from typing_extensions import TypedDict
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from ._api import convert_comment_url_to_api_endpoint, get_api
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
@@ -15,8 +14,6 @@ from ._auth import (
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_github_url(url: str) -> bool:
|
||||
return urlparse(url).netloc == "github.com"
|
||||
@@ -111,228 +108,6 @@ class GithubCommentBlock(Block):
|
||||
# --8<-- [end:GithubCommentBlockExample]
|
||||
|
||||
|
||||
class GithubUpdateCommentBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
comment_url: str = SchemaField(
|
||||
description="URL of the GitHub comment",
|
||||
placeholder="https://github.com/owner/repo/issues/1#issuecomment-123456789",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
issue_url: str = SchemaField(
|
||||
description="URL of the GitHub issue or pull request",
|
||||
placeholder="https://github.com/owner/repo/issues/1",
|
||||
default="",
|
||||
)
|
||||
comment_id: str = SchemaField(
|
||||
description="ID of the GitHub comment",
|
||||
placeholder="123456789",
|
||||
default="",
|
||||
)
|
||||
comment: str = SchemaField(
|
||||
description="Comment to update",
|
||||
placeholder="Enter your comment",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
id: int = SchemaField(description="ID of the updated comment")
|
||||
url: str = SchemaField(description="URL to the comment on GitHub")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the comment update failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b3f4d747-10e3-4e69-8c51-f2be1d99c9a7",
|
||||
description="This block updates a comment on a specified GitHub issue or pull request.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubUpdateCommentBlock.Input,
|
||||
output_schema=GithubUpdateCommentBlock.Output,
|
||||
test_input={
|
||||
"comment_url": "https://github.com/owner/repo/issues/1#issuecomment-123456789",
|
||||
"comment": "This is an updated comment.",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("id", 123456789),
|
||||
(
|
||||
"url",
|
||||
"https://github.com/owner/repo/issues/1#issuecomment-123456789",
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"update_comment": lambda *args, **kwargs: (
|
||||
123456789,
|
||||
"https://github.com/owner/repo/issues/1#issuecomment-123456789",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def update_comment(
|
||||
credentials: GithubCredentials, comment_url: str, body_text: str
|
||||
) -> tuple[int, str]:
|
||||
api = get_api(credentials, convert_urls=False)
|
||||
data = {"body": body_text}
|
||||
url = convert_comment_url_to_api_endpoint(comment_url)
|
||||
|
||||
logger.info(url)
|
||||
response = api.patch(url, json=data)
|
||||
comment = response.json()
|
||||
return comment["id"], comment["html_url"]
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
if (
|
||||
not input_data.comment_url
|
||||
and input_data.comment_id
|
||||
and input_data.issue_url
|
||||
):
|
||||
parsed_url = urlparse(input_data.issue_url)
|
||||
path_parts = parsed_url.path.strip("/").split("/")
|
||||
owner, repo = path_parts[0], path_parts[1]
|
||||
|
||||
input_data.comment_url = f"https://api.github.com/repos/{owner}/{repo}/issues/comments/{input_data.comment_id}"
|
||||
|
||||
elif (
|
||||
not input_data.comment_url
|
||||
and not input_data.comment_id
|
||||
and input_data.issue_url
|
||||
):
|
||||
raise ValueError(
|
||||
"Must provide either comment_url or comment_id and issue_url"
|
||||
)
|
||||
id, url = self.update_comment(
|
||||
credentials,
|
||||
input_data.comment_url,
|
||||
input_data.comment,
|
||||
)
|
||||
yield "id", id
|
||||
yield "url", url
|
||||
|
||||
|
||||
class GithubListCommentsBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
issue_url: str = SchemaField(
|
||||
description="URL of the GitHub issue or pull request",
|
||||
placeholder="https://github.com/owner/repo/issues/1",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class CommentItem(TypedDict):
|
||||
id: int
|
||||
body: str
|
||||
user: str
|
||||
url: str
|
||||
|
||||
comment: CommentItem = SchemaField(
|
||||
title="Comment", description="Comments with their ID, body, user, and URL"
|
||||
)
|
||||
comments: list[CommentItem] = SchemaField(
|
||||
description="List of comments with their ID, body, user, and URL"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if listing comments failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c4b5fb63-0005-4a11-b35a-0c2467bd6b59",
|
||||
description="This block lists all comments for a specified GitHub issue or pull request.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListCommentsBlock.Input,
|
||||
output_schema=GithubListCommentsBlock.Output,
|
||||
test_input={
|
||||
"issue_url": "https://github.com/owner/repo/issues/1",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"comment",
|
||||
{
|
||||
"id": 123456789,
|
||||
"body": "This is a test comment.",
|
||||
"user": "test_user",
|
||||
"url": "https://github.com/owner/repo/issues/1#issuecomment-123456789",
|
||||
},
|
||||
),
|
||||
(
|
||||
"comments",
|
||||
[
|
||||
{
|
||||
"id": 123456789,
|
||||
"body": "This is a test comment.",
|
||||
"user": "test_user",
|
||||
"url": "https://github.com/owner/repo/issues/1#issuecomment-123456789",
|
||||
}
|
||||
],
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"list_comments": lambda *args, **kwargs: [
|
||||
{
|
||||
"id": 123456789,
|
||||
"body": "This is a test comment.",
|
||||
"user": "test_user",
|
||||
"url": "https://github.com/owner/repo/issues/1#issuecomment-123456789",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def list_comments(
|
||||
credentials: GithubCredentials, issue_url: str
|
||||
) -> list[Output.CommentItem]:
|
||||
parsed_url = urlparse(issue_url)
|
||||
path_parts = parsed_url.path.strip("/").split("/")
|
||||
|
||||
owner = path_parts[0]
|
||||
repo = path_parts[1]
|
||||
|
||||
# GitHub API uses 'issues' for both issues and pull requests when it comes to comments
|
||||
issue_number = path_parts[3] # Whether 'issues/123' or 'pull/123'
|
||||
|
||||
# Construct the proper API URL directly
|
||||
api_url = f"https://api.github.com/repos/{owner}/{repo}/issues/{issue_number}/comments"
|
||||
|
||||
# Set convert_urls=False since we're already providing an API URL
|
||||
api = get_api(credentials, convert_urls=False)
|
||||
response = api.get(api_url)
|
||||
comments = response.json()
|
||||
parsed_comments: list[GithubListCommentsBlock.Output.CommentItem] = [
|
||||
{
|
||||
"id": comment["id"],
|
||||
"body": comment["body"],
|
||||
"user": comment["user"]["login"],
|
||||
"url": comment["html_url"],
|
||||
}
|
||||
for comment in comments
|
||||
]
|
||||
return parsed_comments
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
comments = self.list_comments(
|
||||
credentials,
|
||||
input_data.issue_url,
|
||||
)
|
||||
yield from (("comment", comment) for comment in comments)
|
||||
yield "comments", comments
|
||||
|
||||
|
||||
class GithubMakeIssueBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
|
||||
@@ -200,7 +200,6 @@ class GithubReadPullRequestBlock(Block):
|
||||
include_pr_changes: bool = SchemaField(
|
||||
description="Whether to include the changes made in the pull request",
|
||||
default=False,
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
|
||||
@@ -1,180 +0,0 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
GithubFineGrainedAPICredentials,
|
||||
GithubFineGrainedAPICredentialsField,
|
||||
GithubFineGrainedAPICredentialsInput,
|
||||
)
|
||||
|
||||
|
||||
class StatusState(Enum):
|
||||
ERROR = "error"
|
||||
FAILURE = "failure"
|
||||
PENDING = "pending"
|
||||
SUCCESS = "success"
|
||||
|
||||
|
||||
class GithubCreateStatusBlock(Block):
|
||||
"""Block for creating a commit status on a GitHub repository."""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubFineGrainedAPICredentialsInput = (
|
||||
GithubFineGrainedAPICredentialsField("repo:status")
|
||||
)
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
sha: str = SchemaField(
|
||||
description="The SHA of the commit to set status for",
|
||||
)
|
||||
state: StatusState = SchemaField(
|
||||
description="The state of the status (error, failure, pending, success)",
|
||||
)
|
||||
target_url: Optional[str] = SchemaField(
|
||||
description="URL with additional details about this status",
|
||||
default=None,
|
||||
)
|
||||
description: Optional[str] = SchemaField(
|
||||
description="Short description of the status",
|
||||
default=None,
|
||||
)
|
||||
check_name: Optional[str] = SchemaField(
|
||||
description="Label to differentiate this status from others",
|
||||
default="AutoGPT Platform Checks",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class StatusResult(BaseModel):
|
||||
id: int
|
||||
url: str
|
||||
state: str
|
||||
context: str
|
||||
description: Optional[str]
|
||||
target_url: Optional[str]
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
status: StatusResult = SchemaField(description="Details of the created status")
|
||||
error: str = SchemaField(description="Error message if status creation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3d67f123-a4b5-4c89-9d01-2e34f5c67890", # Generated UUID
|
||||
description="Creates a new commit status in a GitHub repository",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubCreateStatusBlock.Input,
|
||||
output_schema=GithubCreateStatusBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"sha": "ce587453ced02b1526dfb4cb910479d431683101",
|
||||
"state": StatusState.SUCCESS.value,
|
||||
"target_url": "https://example.com/build/status",
|
||||
"description": "The build succeeded!",
|
||||
"check_name": "continuous-integration/jenkins",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"status",
|
||||
{
|
||||
"id": 1234567890,
|
||||
"url": "https://api.github.com/repos/owner/repo/statuses/ce587453ced02b1526dfb4cb910479d431683101",
|
||||
"state": "success",
|
||||
"context": "continuous-integration/jenkins",
|
||||
"description": "The build succeeded!",
|
||||
"target_url": "https://example.com/build/status",
|
||||
"created_at": "2024-01-21T10:00:00Z",
|
||||
"updated_at": "2024-01-21T10:00:00Z",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"create_status": lambda *args, **kwargs: {
|
||||
"id": 1234567890,
|
||||
"url": "https://api.github.com/repos/owner/repo/statuses/ce587453ced02b1526dfb4cb910479d431683101",
|
||||
"state": "success",
|
||||
"context": "continuous-integration/jenkins",
|
||||
"description": "The build succeeded!",
|
||||
"target_url": "https://example.com/build/status",
|
||||
"created_at": "2024-01-21T10:00:00Z",
|
||||
"updated_at": "2024-01-21T10:00:00Z",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_status(
|
||||
credentials: GithubFineGrainedAPICredentials,
|
||||
repo_url: str,
|
||||
sha: str,
|
||||
state: StatusState,
|
||||
target_url: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
context: str = "default",
|
||||
) -> dict:
|
||||
api = get_api(credentials)
|
||||
|
||||
class StatusData(BaseModel):
|
||||
state: str
|
||||
target_url: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
context: str
|
||||
|
||||
data = StatusData(
|
||||
state=state.value,
|
||||
context=context,
|
||||
)
|
||||
|
||||
if target_url:
|
||||
data.target_url = target_url
|
||||
|
||||
if description:
|
||||
data.description = description
|
||||
|
||||
status_url = f"{repo_url}/statuses/{sha}"
|
||||
response = api.post(status_url, data=data.model_dump_json(exclude_none=True))
|
||||
result = response.json()
|
||||
|
||||
return {
|
||||
"id": result["id"],
|
||||
"url": result["url"],
|
||||
"state": result["state"],
|
||||
"context": result["context"],
|
||||
"description": result.get("description"),
|
||||
"target_url": result.get("target_url"),
|
||||
"created_at": result["created_at"],
|
||||
"updated_at": result["updated_at"],
|
||||
}
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubFineGrainedAPICredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
result = self.create_status(
|
||||
credentials=credentials,
|
||||
repo_url=input_data.repo_url,
|
||||
sha=input_data.sha,
|
||||
state=input_data.state,
|
||||
target_url=input_data.target_url,
|
||||
description=input_data.description,
|
||||
context=input_data.check_name or "AutoGPT Platform Checks",
|
||||
)
|
||||
yield "status", result
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -12,7 +12,6 @@ from backend.data.block import (
|
||||
BlockWebhookConfig,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
@@ -124,7 +123,7 @@ class GithubPullRequestTriggerBlock(GitHubTriggerBase, Block):
|
||||
output_schema=GithubPullRequestTriggerBlock.Output,
|
||||
# --8<-- [start:example-webhook_config]
|
||||
webhook_config=BlockWebhookConfig(
|
||||
provider=ProviderName.GITHUB,
|
||||
provider="github",
|
||||
webhook_type=GithubWebhookType.REPO,
|
||||
resource_format="{repo}",
|
||||
event_filter_input="events",
|
||||
|
||||
@@ -8,7 +8,6 @@ from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from ._auth import (
|
||||
GOOGLE_OAUTH_IS_CONFIGURED,
|
||||
@@ -151,8 +150,8 @@ class GmailReadBlock(Block):
|
||||
else None
|
||||
),
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
client_id=Settings().secrets.google_client_id,
|
||||
client_secret=Settings().secrets.google_client_secret,
|
||||
client_id=kwargs.get("client_id"),
|
||||
client_secret=kwargs.get("client_secret"),
|
||||
scopes=credentials.scopes,
|
||||
)
|
||||
return build("gmail", "v1", credentials=creds)
|
||||
|
||||
@@ -3,7 +3,6 @@ from googleapiclient.discovery import build
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from ._auth import (
|
||||
GOOGLE_OAUTH_IS_CONFIGURED,
|
||||
@@ -87,8 +86,8 @@ class GoogleSheetsReadBlock(Block):
|
||||
else None
|
||||
),
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
client_id=Settings().secrets.google_client_id,
|
||||
client_secret=Settings().secrets.google_client_secret,
|
||||
client_id=kwargs.get("client_id"),
|
||||
client_secret=kwargs.get("client_secret"),
|
||||
scopes=credentials.scopes,
|
||||
)
|
||||
return build("sheets", "v4", credentials=creds)
|
||||
|
||||
@@ -1,16 +1,11 @@
|
||||
import json
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from requests.exceptions import HTTPError, RequestException
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.request import requests
|
||||
|
||||
logger = logging.getLogger(name=__name__)
|
||||
|
||||
|
||||
class HttpMethod(Enum):
|
||||
GET = "GET"
|
||||
@@ -48,9 +43,8 @@ class SendWebRequestBlock(Block):
|
||||
|
||||
class Output(BlockSchema):
|
||||
response: object = SchemaField(description="The response from the server")
|
||||
client_error: object = SchemaField(description="Errors on 4xx status codes")
|
||||
server_error: object = SchemaField(description="Errors on 5xx status codes")
|
||||
error: str = SchemaField(description="Errors for all other exceptions")
|
||||
client_error: object = SchemaField(description="The error on 4xx status codes")
|
||||
server_error: object = SchemaField(description="The error on 5xx status codes")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -74,40 +68,20 @@ class SendWebRequestBlock(Block):
|
||||
# we should send it as plain text instead
|
||||
input_data.json_format = False
|
||||
|
||||
try:
|
||||
response = requests.request(
|
||||
input_data.method.value,
|
||||
input_data.url,
|
||||
headers=input_data.headers,
|
||||
json=body if input_data.json_format else None,
|
||||
data=body if not input_data.json_format else None,
|
||||
)
|
||||
result = response.json() if input_data.json_format else response.text
|
||||
response = requests.request(
|
||||
input_data.method.value,
|
||||
input_data.url,
|
||||
headers=input_data.headers,
|
||||
json=body if input_data.json_format else None,
|
||||
data=body if not input_data.json_format else None,
|
||||
)
|
||||
result = response.json() if input_data.json_format else response.text
|
||||
|
||||
if response.status_code // 100 == 2:
|
||||
yield "response", result
|
||||
|
||||
except HTTPError as e:
|
||||
# Handle error responses
|
||||
try:
|
||||
result = e.response.json() if input_data.json_format else str(e)
|
||||
except json.JSONDecodeError:
|
||||
result = str(e)
|
||||
|
||||
if 400 <= e.response.status_code < 500:
|
||||
yield "client_error", result
|
||||
elif 500 <= e.response.status_code < 600:
|
||||
yield "server_error", result
|
||||
else:
|
||||
error_msg = (
|
||||
"Unexpected status code "
|
||||
f"{e.response.status_code} '{e.response.reason}'"
|
||||
)
|
||||
logger.warning(error_msg)
|
||||
yield "error", error_msg
|
||||
|
||||
except RequestException as e:
|
||||
# Handle other request-related exceptions
|
||||
yield "error", str(e)
|
||||
|
||||
except Exception as e:
|
||||
# Catch any other unexpected exceptions
|
||||
yield "error", str(e)
|
||||
elif response.status_code // 100 == 4:
|
||||
yield "client_error", result
|
||||
elif response.status_code // 100 == 5:
|
||||
yield "server_error", result
|
||||
else:
|
||||
raise ValueError(f"Unexpected status code: {response.status_code}")
|
||||
|
||||
@@ -142,16 +142,6 @@ class IdeogramModelBlock(Block):
|
||||
title="Color Palette Preset",
|
||||
advanced=True,
|
||||
)
|
||||
custom_color_palette: Optional[list[str]] = SchemaField(
|
||||
description=(
|
||||
"Only available for model version V_2 or V_2_TURBO. Provide one or more color hex codes "
|
||||
"(e.g., ['#000030', '#1C0C47', '#9900FF', '#4285F4', '#FFFFFF']) to define a custom color "
|
||||
"palette. Only used if 'color_palette_name' is 'NONE'."
|
||||
),
|
||||
default=None,
|
||||
title="Custom Color Palette",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: str = SchemaField(description="Generated image URL")
|
||||
@@ -161,7 +151,7 @@ class IdeogramModelBlock(Block):
|
||||
super().__init__(
|
||||
id="6ab085e2-20b3-4055-bc3e-08036e01eca6",
|
||||
description="This block runs Ideogram models with both simple and advanced settings.",
|
||||
categories={BlockCategory.AI, BlockCategory.MULTIMEDIA},
|
||||
categories={BlockCategory.AI},
|
||||
input_schema=IdeogramModelBlock.Input,
|
||||
output_schema=IdeogramModelBlock.Output,
|
||||
test_input={
|
||||
@@ -174,13 +164,6 @@ class IdeogramModelBlock(Block):
|
||||
"style_type": StyleType.AUTO,
|
||||
"negative_prompt": None,
|
||||
"color_palette_name": ColorPalettePreset.NONE,
|
||||
"custom_color_palette": [
|
||||
"#000030",
|
||||
"#1C0C47",
|
||||
"#9900FF",
|
||||
"#4285F4",
|
||||
"#FFFFFF",
|
||||
],
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
@@ -190,7 +173,7 @@ class IdeogramModelBlock(Block):
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"run_model": lambda api_key, model_name, prompt, seed, aspect_ratio, magic_prompt_option, style_type, negative_prompt, color_palette_name, custom_colors: "https://ideogram.ai/api/images/test-generated-image-url.png",
|
||||
"run_model": lambda api_key, model_name, prompt, seed, aspect_ratio, magic_prompt_option, style_type, negative_prompt, color_palette_name: "https://ideogram.ai/api/images/test-generated-image-url.png",
|
||||
"upscale_image": lambda api_key, image_url: "https://ideogram.ai/api/images/test-upscaled-image-url.png",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
@@ -212,7 +195,6 @@ class IdeogramModelBlock(Block):
|
||||
style_type=input_data.style_type.value,
|
||||
negative_prompt=input_data.negative_prompt,
|
||||
color_palette_name=input_data.color_palette_name.value,
|
||||
custom_colors=input_data.custom_color_palette,
|
||||
)
|
||||
|
||||
# Step 2: Upscale the image if requested
|
||||
@@ -235,7 +217,6 @@ class IdeogramModelBlock(Block):
|
||||
style_type: str,
|
||||
negative_prompt: Optional[str],
|
||||
color_palette_name: str,
|
||||
custom_colors: Optional[list[str]],
|
||||
):
|
||||
url = "https://api.ideogram.ai/generate"
|
||||
headers = {
|
||||
@@ -260,11 +241,7 @@ class IdeogramModelBlock(Block):
|
||||
data["image_request"]["negative_prompt"] = negative_prompt
|
||||
|
||||
if color_palette_name != "NONE":
|
||||
data["color_palette"] = {"name": color_palette_name}
|
||||
elif custom_colors:
|
||||
data["color_palette"] = {
|
||||
"members": [{"color_hex": color} for color in custom_colors]
|
||||
}
|
||||
data["image_request"]["color_palette"] = {"name": color_palette_name}
|
||||
|
||||
try:
|
||||
response = requests.post(url, json=data, headers=headers)
|
||||
@@ -290,7 +267,9 @@ class IdeogramModelBlock(Block):
|
||||
response = requests.post(
|
||||
url,
|
||||
headers=headers,
|
||||
data={"image_request": "{}"},
|
||||
data={
|
||||
"image_request": "{}", # Empty JSON object
|
||||
},
|
||||
files=files,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,555 +0,0 @@
|
||||
from datetime import date, time
|
||||
from typing import Any, Optional
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import store_media_file
|
||||
from backend.util.mock import MockObject
|
||||
from backend.util.settings import Config
|
||||
from backend.util.text import TextFormatter
|
||||
from backend.util.type import LongTextType, MediaFileType, ShortTextType
|
||||
|
||||
formatter = TextFormatter()
|
||||
config = Config()
|
||||
|
||||
|
||||
class AgentInputBlock(Block):
|
||||
"""
|
||||
This block is used to provide input to the graph.
|
||||
|
||||
It takes in a value, name, description, default values list and bool to limit selection to default values.
|
||||
|
||||
It Outputs the value passed as input.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
name: str = SchemaField(description="The name of the input.")
|
||||
value: Any = SchemaField(
|
||||
description="The value to be passed as input.",
|
||||
default=None,
|
||||
)
|
||||
title: str | None = SchemaField(
|
||||
description="The title of the input.", default=None, advanced=True
|
||||
)
|
||||
description: str | None = SchemaField(
|
||||
description="The description of the input.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
placeholder_values: list = SchemaField(
|
||||
description="The placeholder values to be passed as input.",
|
||||
default=[],
|
||||
advanced=True,
|
||||
hidden=True,
|
||||
)
|
||||
advanced: bool = SchemaField(
|
||||
description="Whether to show the input in the advanced section, if the field is not required.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
secret: bool = SchemaField(
|
||||
description="Whether the input should be treated as a secret.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
def generate_schema(self):
|
||||
schema = self.get_field_schema("value")
|
||||
if possible_values := self.placeholder_values:
|
||||
schema["enum"] = possible_values
|
||||
return schema
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: Any = SchemaField(description="The value passed as input.")
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(
|
||||
**{
|
||||
"id": "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b",
|
||||
"description": "Base block for user inputs.",
|
||||
"input_schema": AgentInputBlock.Input,
|
||||
"output_schema": AgentInputBlock.Output,
|
||||
"test_input": [
|
||||
{
|
||||
"value": "Hello, World!",
|
||||
"name": "input_1",
|
||||
"description": "Example test input.",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
{
|
||||
"value": "Hello, World!",
|
||||
"name": "input_2",
|
||||
"description": "Example test input with placeholders.",
|
||||
"placeholder_values": ["Hello, World!"],
|
||||
},
|
||||
],
|
||||
"test_output": [
|
||||
("result", "Hello, World!"),
|
||||
("result", "Hello, World!"),
|
||||
],
|
||||
"categories": {BlockCategory.INPUT, BlockCategory.BASIC},
|
||||
"block_type": BlockType.INPUT,
|
||||
"static_output": True,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, *args, **kwargs) -> BlockOutput:
|
||||
if input_data.value is not None:
|
||||
yield "result", input_data.value
|
||||
|
||||
|
||||
class AgentOutputBlock(Block):
|
||||
"""
|
||||
Records the output of the graph for users to see.
|
||||
|
||||
Behavior:
|
||||
If `format` is provided and the `value` is of a type that can be formatted,
|
||||
the block attempts to format the recorded_value using the `format`.
|
||||
If formatting fails or no `format` is provided, the raw `value` is output.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
value: Any = SchemaField(
|
||||
description="The value to be recorded as output.",
|
||||
default=None,
|
||||
advanced=False,
|
||||
)
|
||||
name: str = SchemaField(description="The name of the output.")
|
||||
title: str | None = SchemaField(
|
||||
description="The title of the output.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
description: str | None = SchemaField(
|
||||
description="The description of the output.",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
format: str = SchemaField(
|
||||
description="The format string to be used to format the recorded_value. Use Jinja2 syntax.",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
advanced: bool = SchemaField(
|
||||
description="Whether to treat the output as advanced.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
secret: bool = SchemaField(
|
||||
description="Whether the output should be treated as a secret.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
def generate_schema(self):
|
||||
return self.get_field_schema("value")
|
||||
|
||||
class Output(BlockSchema):
|
||||
output: Any = SchemaField(description="The value recorded as output.")
|
||||
name: Any = SchemaField(description="The name of the value recorded as output.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="363ae599-353e-4804-937e-b2ee3cef3da4",
|
||||
description="Stores the output of the graph for users to see.",
|
||||
input_schema=AgentOutputBlock.Input,
|
||||
output_schema=AgentOutputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": "Hello, World!",
|
||||
"name": "output_1",
|
||||
"description": "This is a test output.",
|
||||
"format": "{{ output_1 }}!!",
|
||||
},
|
||||
{
|
||||
"value": "42",
|
||||
"name": "output_2",
|
||||
"description": "This is another test output.",
|
||||
"format": "{{ output_2 }}",
|
||||
},
|
||||
{
|
||||
"value": MockObject(value="!!", key="key"),
|
||||
"name": "output_3",
|
||||
"description": "This is a test output with a mock object.",
|
||||
"format": "{{ output_3 }}",
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("output", "Hello, World!!!"),
|
||||
("output", "42"),
|
||||
("output", MockObject(value="!!", key="key")),
|
||||
],
|
||||
categories={BlockCategory.OUTPUT, BlockCategory.BASIC},
|
||||
block_type=BlockType.OUTPUT,
|
||||
static_output=True,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, *args, **kwargs) -> BlockOutput:
|
||||
"""
|
||||
Attempts to format the recorded_value using the fmt_string if provided.
|
||||
If formatting fails or no fmt_string is given, returns the original recorded_value.
|
||||
"""
|
||||
if input_data.format:
|
||||
try:
|
||||
yield "output", formatter.format_string(
|
||||
input_data.format, {input_data.name: input_data.value}
|
||||
)
|
||||
except Exception as e:
|
||||
yield "output", f"Error: {e}, {input_data.value}"
|
||||
else:
|
||||
yield "output", input_data.value
|
||||
yield "name", input_data.name
|
||||
|
||||
|
||||
class AgentShortTextInputBlock(AgentInputBlock):
|
||||
class Input(AgentInputBlock.Input):
|
||||
value: Optional[ShortTextType] = SchemaField(
|
||||
description="Short text input.",
|
||||
default=None,
|
||||
advanced=False,
|
||||
title="Default Value",
|
||||
)
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: str = SchemaField(description="Short text result.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="7fcd3bcb-8e1b-4e69-903d-32d3d4a92158",
|
||||
description="Block for short text input (single-line).",
|
||||
disabled=not config.enable_agent_input_subtype_blocks,
|
||||
input_schema=AgentShortTextInputBlock.Input,
|
||||
output_schema=AgentShortTextInputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": "Hello",
|
||||
"name": "short_text_1",
|
||||
"description": "Short text example 1",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
{
|
||||
"value": "Quick test",
|
||||
"name": "short_text_2",
|
||||
"description": "Short text example 2",
|
||||
"placeholder_values": ["Quick test", "Another option"],
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("result", "Hello"),
|
||||
("result", "Quick test"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class AgentLongTextInputBlock(AgentInputBlock):
|
||||
class Input(AgentInputBlock.Input):
|
||||
value: Optional[LongTextType] = SchemaField(
|
||||
description="Long text input (potentially multi-line).",
|
||||
default=None,
|
||||
advanced=False,
|
||||
title="Default Value",
|
||||
)
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: str = SchemaField(description="Long text result.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="90a56ffb-7024-4b2b-ab50-e26c5e5ab8ba",
|
||||
description="Block for long text input (multi-line).",
|
||||
disabled=not config.enable_agent_input_subtype_blocks,
|
||||
input_schema=AgentLongTextInputBlock.Input,
|
||||
output_schema=AgentLongTextInputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": "Lorem ipsum dolor sit amet...",
|
||||
"name": "long_text_1",
|
||||
"description": "Long text example 1",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
{
|
||||
"value": "Another multiline text input.",
|
||||
"name": "long_text_2",
|
||||
"description": "Long text example 2",
|
||||
"placeholder_values": ["Another multiline text input."],
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("result", "Lorem ipsum dolor sit amet..."),
|
||||
("result", "Another multiline text input."),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class AgentNumberInputBlock(AgentInputBlock):
|
||||
class Input(AgentInputBlock.Input):
|
||||
value: Optional[int] = SchemaField(
|
||||
description="Number input.",
|
||||
default=None,
|
||||
advanced=False,
|
||||
title="Default Value",
|
||||
)
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: int = SchemaField(description="Number result.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="96dae2bb-97a2-41c2-bd2f-13a3b5a8ea98",
|
||||
description="Block for number input.",
|
||||
disabled=not config.enable_agent_input_subtype_blocks,
|
||||
input_schema=AgentNumberInputBlock.Input,
|
||||
output_schema=AgentNumberInputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": 42,
|
||||
"name": "number_input_1",
|
||||
"description": "Number example 1",
|
||||
"placeholder_values": [],
|
||||
},
|
||||
{
|
||||
"value": 314,
|
||||
"name": "number_input_2",
|
||||
"description": "Number example 2",
|
||||
"placeholder_values": [314, 2718],
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("result", 42),
|
||||
("result", 314),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class AgentDateInputBlock(AgentInputBlock):
|
||||
class Input(AgentInputBlock.Input):
|
||||
value: Optional[date] = SchemaField(
|
||||
description="Date input (YYYY-MM-DD).",
|
||||
default=None,
|
||||
advanced=False,
|
||||
title="Default Value",
|
||||
)
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: date = SchemaField(description="Date result.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="7e198b09-4994-47db-8b4d-952d98241817",
|
||||
description="Block for date input.",
|
||||
disabled=not config.enable_agent_input_subtype_blocks,
|
||||
input_schema=AgentDateInputBlock.Input,
|
||||
output_schema=AgentDateInputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
# If your system can parse JSON date strings to date objects
|
||||
"value": str(date(2025, 3, 19)),
|
||||
"name": "date_input_1",
|
||||
"description": "Example date input 1",
|
||||
},
|
||||
{
|
||||
"value": str(date(2023, 12, 31)),
|
||||
"name": "date_input_2",
|
||||
"description": "Example date input 2",
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("result", date(2025, 3, 19)),
|
||||
("result", date(2023, 12, 31)),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class AgentTimeInputBlock(AgentInputBlock):
|
||||
class Input(AgentInputBlock.Input):
|
||||
value: Optional[time] = SchemaField(
|
||||
description="Time input (HH:MM:SS).",
|
||||
default=None,
|
||||
advanced=False,
|
||||
title="Default Value",
|
||||
)
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: time = SchemaField(description="Time result.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="2a1c757e-86cf-4c7e-aacf-060dc382e434",
|
||||
description="Block for time input.",
|
||||
disabled=not config.enable_agent_input_subtype_blocks,
|
||||
input_schema=AgentTimeInputBlock.Input,
|
||||
output_schema=AgentTimeInputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": str(time(9, 30, 0)),
|
||||
"name": "time_input_1",
|
||||
"description": "Time example 1",
|
||||
},
|
||||
{
|
||||
"value": str(time(23, 59, 59)),
|
||||
"name": "time_input_2",
|
||||
"description": "Time example 2",
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("result", time(9, 30, 0)),
|
||||
("result", time(23, 59, 59)),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class AgentFileInputBlock(AgentInputBlock):
|
||||
"""
|
||||
A simplified file-upload block. In real usage, you might have a custom
|
||||
file type or handle binary data. Here, we'll store a string path as the example.
|
||||
"""
|
||||
|
||||
class Input(AgentInputBlock.Input):
|
||||
value: Optional[MediaFileType] = SchemaField(
|
||||
description="Path or reference to an uploaded file.",
|
||||
default=None,
|
||||
advanced=False,
|
||||
title="Default Value",
|
||||
)
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: str = SchemaField(description="File reference/path result.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="95ead23f-8283-4654-aef3-10c053b74a31",
|
||||
description="Block for file upload input (string path for example).",
|
||||
disabled=not config.enable_agent_input_subtype_blocks,
|
||||
input_schema=AgentFileInputBlock.Input,
|
||||
output_schema=AgentFileInputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": "data:image/png;base64,MQ==",
|
||||
"name": "file_upload_1",
|
||||
"description": "Example file upload 1",
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("result", str),
|
||||
],
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
graph_exec_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
if not input_data.value:
|
||||
return
|
||||
|
||||
file_path = store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.value,
|
||||
return_content=False,
|
||||
)
|
||||
yield "result", file_path
|
||||
|
||||
|
||||
class AgentDropdownInputBlock(AgentInputBlock):
|
||||
"""
|
||||
A specialized text input block that relies on placeholder_values to present a dropdown.
|
||||
"""
|
||||
|
||||
class Input(AgentInputBlock.Input):
|
||||
value: Optional[str] = SchemaField(
|
||||
description="Text selected from a dropdown.",
|
||||
default=None,
|
||||
advanced=False,
|
||||
title="Default Value",
|
||||
)
|
||||
placeholder_values: list = SchemaField(
|
||||
description="Possible values for the dropdown.",
|
||||
default=[],
|
||||
advanced=False,
|
||||
title="Dropdown Options",
|
||||
)
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: str = SchemaField(description="Selected dropdown value.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="655d6fdf-a334-421c-b733-520549c07cd1",
|
||||
description="Block for dropdown text selection.",
|
||||
disabled=not config.enable_agent_input_subtype_blocks,
|
||||
input_schema=AgentDropdownInputBlock.Input,
|
||||
output_schema=AgentDropdownInputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": "Option A",
|
||||
"name": "dropdown_1",
|
||||
"placeholder_values": ["Option A", "Option B", "Option C"],
|
||||
"description": "Dropdown example 1",
|
||||
},
|
||||
{
|
||||
"value": "Option C",
|
||||
"name": "dropdown_2",
|
||||
"placeholder_values": ["Option A", "Option B", "Option C"],
|
||||
"description": "Dropdown example 2",
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("result", "Option A"),
|
||||
("result", "Option C"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class AgentToggleInputBlock(AgentInputBlock):
|
||||
class Input(AgentInputBlock.Input):
|
||||
value: bool = SchemaField(
|
||||
description="Boolean toggle input.",
|
||||
default=False,
|
||||
advanced=False,
|
||||
title="Default Value",
|
||||
)
|
||||
|
||||
class Output(AgentInputBlock.Output):
|
||||
result: bool = SchemaField(description="Boolean toggle result.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="cbf36ab5-df4a-43b6-8a7f-f7ed8652116e",
|
||||
description="Block for boolean toggle input.",
|
||||
disabled=not config.enable_agent_input_subtype_blocks,
|
||||
input_schema=AgentToggleInputBlock.Input,
|
||||
output_schema=AgentToggleInputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": True,
|
||||
"name": "toggle_1",
|
||||
"description": "Toggle example 1",
|
||||
},
|
||||
{
|
||||
"value": False,
|
||||
"name": "toggle_2",
|
||||
"description": "Toggle example 2",
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("result", True),
|
||||
("result", False),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
IO_BLOCK_IDs = [
|
||||
AgentInputBlock().id,
|
||||
AgentOutputBlock().id,
|
||||
AgentShortTextInputBlock().id,
|
||||
AgentLongTextInputBlock().id,
|
||||
AgentNumberInputBlock().id,
|
||||
AgentDateInputBlock().id,
|
||||
AgentTimeInputBlock().id,
|
||||
AgentFileInputBlock().id,
|
||||
AgentDropdownInputBlock().id,
|
||||
AgentToggleInputBlock().id,
|
||||
]
|
||||
@@ -1,272 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from backend.blocks.linear._auth import LinearCredentials
|
||||
from backend.blocks.linear.models import (
|
||||
CreateCommentResponse,
|
||||
CreateIssueResponse,
|
||||
Issue,
|
||||
Project,
|
||||
)
|
||||
from backend.util.request import Requests
|
||||
|
||||
|
||||
class LinearAPIException(Exception):
|
||||
def __init__(self, message: str, status_code: int):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class LinearClient:
|
||||
"""Client for the Linear API
|
||||
|
||||
If you're looking for the schema: https://studio.apollographql.com/public/Linear-API/variant/current/schema
|
||||
"""
|
||||
|
||||
API_URL = "https://api.linear.app/graphql"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
credentials: LinearCredentials | None = None,
|
||||
custom_requests: Optional[Requests] = None,
|
||||
):
|
||||
if custom_requests:
|
||||
self._requests = custom_requests
|
||||
else:
|
||||
|
||||
headers: Dict[str, str] = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if credentials:
|
||||
headers["Authorization"] = credentials.auth_header()
|
||||
|
||||
self._requests = Requests(
|
||||
extra_headers=headers,
|
||||
trusted_origins=["https://api.linear.app"],
|
||||
raise_for_status=False,
|
||||
)
|
||||
|
||||
def _execute_graphql_request(
|
||||
self, query: str, variables: dict | None = None
|
||||
) -> Any:
|
||||
"""
|
||||
Executes a GraphQL request against the Linear API and returns the response data.
|
||||
|
||||
Args:
|
||||
query: The GraphQL query string.
|
||||
variables (optional): Any GraphQL query variables
|
||||
|
||||
Returns:
|
||||
The parsed JSON response data, or raises a LinearAPIException on error.
|
||||
"""
|
||||
payload: Dict[str, Any] = {"query": query}
|
||||
if variables:
|
||||
payload["variables"] = variables
|
||||
|
||||
response = self._requests.post(self.API_URL, json=payload)
|
||||
|
||||
if not response.ok:
|
||||
|
||||
try:
|
||||
error_data = response.json()
|
||||
error_message = error_data.get("errors", [{}])[0].get("message", "")
|
||||
except json.JSONDecodeError:
|
||||
error_message = response.text
|
||||
|
||||
raise LinearAPIException(
|
||||
f"Linear API request failed ({response.status_code}): {error_message}",
|
||||
response.status_code,
|
||||
)
|
||||
|
||||
response_data = response.json()
|
||||
if "errors" in response_data:
|
||||
|
||||
error_messages = [
|
||||
error.get("message", "") for error in response_data["errors"]
|
||||
]
|
||||
raise LinearAPIException(
|
||||
f"Linear API returned errors: {', '.join(error_messages)}",
|
||||
response.status_code,
|
||||
)
|
||||
|
||||
return response_data["data"]
|
||||
|
||||
def query(self, query: str, variables: Optional[dict] = None) -> dict:
|
||||
"""Executes a GraphQL query.
|
||||
|
||||
Args:
|
||||
query: The GraphQL query string.
|
||||
variables: Query variables, if any.
|
||||
|
||||
Returns:
|
||||
The response data.
|
||||
"""
|
||||
return self._execute_graphql_request(query, variables)
|
||||
|
||||
def mutate(self, mutation: str, variables: Optional[dict] = None) -> dict:
|
||||
"""Executes a GraphQL mutation.
|
||||
|
||||
Args:
|
||||
mutation: The GraphQL mutation string.
|
||||
variables: Query variables, if any.
|
||||
|
||||
Returns:
|
||||
The response data.
|
||||
"""
|
||||
return self._execute_graphql_request(mutation, variables)
|
||||
|
||||
def try_create_comment(self, issue_id: str, comment: str) -> CreateCommentResponse:
|
||||
try:
|
||||
mutation = """
|
||||
mutation CommentCreate($input: CommentCreateInput!) {
|
||||
commentCreate(input: $input) {
|
||||
success
|
||||
comment {
|
||||
id
|
||||
body
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
variables = {
|
||||
"input": {
|
||||
"body": comment,
|
||||
"issueId": issue_id,
|
||||
}
|
||||
}
|
||||
|
||||
added_comment = self.mutate(mutation, variables)
|
||||
# Select the commentCreate field from the mutation response
|
||||
return CreateCommentResponse(**added_comment["commentCreate"])
|
||||
except LinearAPIException as e:
|
||||
raise e
|
||||
|
||||
def try_get_team_by_name(self, team_name: str) -> str:
|
||||
try:
|
||||
query = """
|
||||
query GetTeamId($searchTerm: String!) {
|
||||
teams(filter: {
|
||||
or: [
|
||||
{ name: { eqIgnoreCase: $searchTerm } },
|
||||
{ key: { eqIgnoreCase: $searchTerm } }
|
||||
]
|
||||
}) {
|
||||
nodes {
|
||||
id
|
||||
name
|
||||
key
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
variables: dict[str, Any] = {
|
||||
"searchTerm": team_name,
|
||||
}
|
||||
|
||||
team_id = self.query(query, variables)
|
||||
return team_id["teams"]["nodes"][0]["id"]
|
||||
except LinearAPIException as e:
|
||||
raise e
|
||||
|
||||
def try_create_issue(
|
||||
self,
|
||||
team_id: str,
|
||||
title: str,
|
||||
description: str | None = None,
|
||||
priority: int | None = None,
|
||||
project_id: str | None = None,
|
||||
) -> CreateIssueResponse:
|
||||
try:
|
||||
mutation = """
|
||||
mutation IssueCreate($input: IssueCreateInput!) {
|
||||
issueCreate(input: $input) {
|
||||
issue {
|
||||
title
|
||||
description
|
||||
id
|
||||
identifier
|
||||
priority
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
variables: dict[str, Any] = {
|
||||
"input": {
|
||||
"teamId": team_id,
|
||||
"title": title,
|
||||
}
|
||||
}
|
||||
|
||||
if project_id:
|
||||
variables["input"]["projectId"] = project_id
|
||||
|
||||
if description:
|
||||
variables["input"]["description"] = description
|
||||
|
||||
if priority:
|
||||
variables["input"]["priority"] = priority
|
||||
|
||||
added_issue = self.mutate(mutation, variables)
|
||||
return CreateIssueResponse(**added_issue["issueCreate"])
|
||||
except LinearAPIException as e:
|
||||
raise e
|
||||
|
||||
def try_search_projects(self, term: str) -> list[Project]:
|
||||
try:
|
||||
query = """
|
||||
query SearchProjects($term: String!, $includeComments: Boolean!) {
|
||||
searchProjects(term: $term, includeComments: $includeComments) {
|
||||
nodes {
|
||||
id
|
||||
name
|
||||
description
|
||||
priority
|
||||
progress
|
||||
content
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
variables: dict[str, Any] = {
|
||||
"term": term,
|
||||
"includeComments": True,
|
||||
}
|
||||
|
||||
projects = self.query(query, variables)
|
||||
return [
|
||||
Project(**project) for project in projects["searchProjects"]["nodes"]
|
||||
]
|
||||
except LinearAPIException as e:
|
||||
raise e
|
||||
|
||||
def try_search_issues(self, term: str) -> list[Issue]:
|
||||
try:
|
||||
query = """
|
||||
query SearchIssues($term: String!, $includeComments: Boolean!) {
|
||||
searchIssues(term: $term, includeComments: $includeComments) {
|
||||
nodes {
|
||||
id
|
||||
identifier
|
||||
title
|
||||
description
|
||||
priority
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
variables: dict[str, Any] = {
|
||||
"term": term,
|
||||
"includeComments": True,
|
||||
}
|
||||
|
||||
issues = self.query(query, variables)
|
||||
return [Issue(**issue) for issue in issues["searchIssues"]["nodes"]]
|
||||
except LinearAPIException as e:
|
||||
raise e
|
||||
@@ -1,101 +0,0 @@
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
OAuth2Credentials,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.settings import Secrets
|
||||
|
||||
secrets = Secrets()
|
||||
LINEAR_OAUTH_IS_CONFIGURED = bool(
|
||||
secrets.linear_client_id and secrets.linear_client_secret
|
||||
)
|
||||
|
||||
LinearCredentials = OAuth2Credentials | APIKeyCredentials
|
||||
# LinearCredentialsInput = CredentialsMetaInput[
|
||||
# Literal[ProviderName.LINEAR],
|
||||
# Literal["oauth2", "api_key"] if LINEAR_OAUTH_IS_CONFIGURED else Literal["oauth2"],
|
||||
# ]
|
||||
LinearCredentialsInput = CredentialsMetaInput[
|
||||
Literal[ProviderName.LINEAR], Literal["oauth2"]
|
||||
]
|
||||
|
||||
|
||||
# (required) Comma separated list of scopes:
|
||||
|
||||
# read - (Default) Read access for the user's account. This scope will always be present.
|
||||
|
||||
# write - Write access for the user's account. If your application only needs to create comments, use a more targeted scope
|
||||
|
||||
# issues:create - Allows creating new issues and their attachments
|
||||
|
||||
# comments:create - Allows creating new issue comments
|
||||
|
||||
# timeSchedule:write - Allows creating and modifying time schedules
|
||||
|
||||
|
||||
# admin - Full access to admin level endpoints. You should never ask for this permission unless it's absolutely needed
|
||||
class LinearScope(str, Enum):
|
||||
READ = "read"
|
||||
WRITE = "write"
|
||||
ISSUES_CREATE = "issues:create"
|
||||
COMMENTS_CREATE = "comments:create"
|
||||
TIME_SCHEDULE_WRITE = "timeSchedule:write"
|
||||
ADMIN = "admin"
|
||||
|
||||
|
||||
def LinearCredentialsField(scopes: list[LinearScope]) -> LinearCredentialsInput:
|
||||
"""
|
||||
Creates a Linear credentials input on a block.
|
||||
|
||||
Params:
|
||||
scope: The authorization scope needed for the block to work. ([list of available scopes](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/scopes-for-oauth-apps#available-scopes))
|
||||
""" # noqa
|
||||
return CredentialsField(
|
||||
required_scopes=set([LinearScope.READ.value]).union(
|
||||
set([scope.value for scope in scopes])
|
||||
),
|
||||
description="The Linear integration can be used with OAuth, "
|
||||
"or any API key with sufficient permissions for the blocks it is used on.",
|
||||
)
|
||||
|
||||
|
||||
TEST_CREDENTIALS_OAUTH = OAuth2Credentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="linear",
|
||||
title="Mock Linear API key",
|
||||
username="mock-linear-username",
|
||||
access_token=SecretStr("mock-linear-access-token"),
|
||||
access_token_expires_at=None,
|
||||
refresh_token=SecretStr("mock-linear-refresh-token"),
|
||||
refresh_token_expires_at=None,
|
||||
scopes=["mock-linear-scopes"],
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS_API_KEY = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="linear",
|
||||
title="Mock Linear API key",
|
||||
api_key=SecretStr("mock-linear-api-key"),
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS_INPUT_OAUTH = {
|
||||
"provider": TEST_CREDENTIALS_OAUTH.provider,
|
||||
"id": TEST_CREDENTIALS_OAUTH.id,
|
||||
"type": TEST_CREDENTIALS_OAUTH.type,
|
||||
"title": TEST_CREDENTIALS_OAUTH.type,
|
||||
}
|
||||
|
||||
TEST_CREDENTIALS_INPUT_API_KEY = {
|
||||
"provider": TEST_CREDENTIALS_API_KEY.provider,
|
||||
"id": TEST_CREDENTIALS_API_KEY.id,
|
||||
"type": TEST_CREDENTIALS_API_KEY.type,
|
||||
"title": TEST_CREDENTIALS_API_KEY.type,
|
||||
}
|
||||
@@ -1,83 +0,0 @@
|
||||
from backend.blocks.linear._api import LinearAPIException, LinearClient
|
||||
from backend.blocks.linear._auth import (
|
||||
LINEAR_OAUTH_IS_CONFIGURED,
|
||||
TEST_CREDENTIALS_INPUT_OAUTH,
|
||||
TEST_CREDENTIALS_OAUTH,
|
||||
LinearCredentials,
|
||||
LinearCredentialsField,
|
||||
LinearCredentialsInput,
|
||||
LinearScope,
|
||||
)
|
||||
from backend.blocks.linear.models import CreateCommentResponse
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class LinearCreateCommentBlock(Block):
|
||||
"""Block for creating comments on Linear issues"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: LinearCredentialsInput = LinearCredentialsField(
|
||||
scopes=[LinearScope.COMMENTS_CREATE],
|
||||
)
|
||||
issue_id: str = SchemaField(description="ID of the issue to comment on")
|
||||
comment: str = SchemaField(description="Comment text to add to the issue")
|
||||
|
||||
class Output(BlockSchema):
|
||||
comment_id: str = SchemaField(description="ID of the created comment")
|
||||
comment_body: str = SchemaField(
|
||||
description="Text content of the created comment"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if comment creation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8f7d3a2e-9b5c-4c6a-8f1d-7c8b3e4a5d6c",
|
||||
description="Creates a new comment on a Linear issue",
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
categories={BlockCategory.PRODUCTIVITY, BlockCategory.ISSUE_TRACKING},
|
||||
test_input={
|
||||
"issue_id": "TEST-123",
|
||||
"comment": "Test comment",
|
||||
"credentials": TEST_CREDENTIALS_INPUT_OAUTH,
|
||||
},
|
||||
disabled=not LINEAR_OAUTH_IS_CONFIGURED,
|
||||
test_credentials=TEST_CREDENTIALS_OAUTH,
|
||||
test_output=[("comment_id", "abc123"), ("comment_body", "Test comment")],
|
||||
test_mock={
|
||||
"create_comment": lambda *args, **kwargs: (
|
||||
"abc123",
|
||||
"Test comment",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_comment(
|
||||
credentials: LinearCredentials, issue_id: str, comment: str
|
||||
) -> tuple[str, str]:
|
||||
client = LinearClient(credentials=credentials)
|
||||
response: CreateCommentResponse = client.try_create_comment(
|
||||
issue_id=issue_id, comment=comment
|
||||
)
|
||||
return response.comment.id, response.comment.body
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: LinearCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
"""Execute the comment creation"""
|
||||
try:
|
||||
comment_id, comment_body = self.create_comment(
|
||||
credentials=credentials,
|
||||
issue_id=input_data.issue_id,
|
||||
comment=input_data.comment,
|
||||
)
|
||||
|
||||
yield "comment_id", comment_id
|
||||
yield "comment_body", comment_body
|
||||
|
||||
except LinearAPIException as e:
|
||||
yield "error", str(e)
|
||||
except Exception as e:
|
||||
yield "error", f"Unexpected error: {str(e)}"
|
||||
@@ -1,189 +0,0 @@
|
||||
from backend.blocks.linear._api import LinearAPIException, LinearClient
|
||||
from backend.blocks.linear._auth import (
|
||||
LINEAR_OAUTH_IS_CONFIGURED,
|
||||
TEST_CREDENTIALS_INPUT_OAUTH,
|
||||
TEST_CREDENTIALS_OAUTH,
|
||||
LinearCredentials,
|
||||
LinearCredentialsField,
|
||||
LinearCredentialsInput,
|
||||
LinearScope,
|
||||
)
|
||||
from backend.blocks.linear.models import CreateIssueResponse, Issue
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class LinearCreateIssueBlock(Block):
|
||||
"""Block for creating issues on Linear"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: LinearCredentialsInput = LinearCredentialsField(
|
||||
scopes=[LinearScope.ISSUES_CREATE],
|
||||
)
|
||||
title: str = SchemaField(description="Title of the issue")
|
||||
description: str | None = SchemaField(description="Description of the issue")
|
||||
team_name: str = SchemaField(
|
||||
description="Name of the team to create the issue on"
|
||||
)
|
||||
priority: int | None = SchemaField(
|
||||
description="Priority of the issue",
|
||||
default=None,
|
||||
minimum=0,
|
||||
maximum=4,
|
||||
)
|
||||
project_name: str | None = SchemaField(
|
||||
description="Name of the project to create the issue on",
|
||||
default=None,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
issue_id: str = SchemaField(description="ID of the created issue")
|
||||
issue_title: str = SchemaField(description="Title of the created issue")
|
||||
error: str = SchemaField(description="Error message if issue creation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f9c68f55-dcca-40a8-8771-abf9601680aa",
|
||||
description="Creates a new issue on Linear",
|
||||
disabled=not LINEAR_OAUTH_IS_CONFIGURED,
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
categories={BlockCategory.PRODUCTIVITY, BlockCategory.ISSUE_TRACKING},
|
||||
test_input={
|
||||
"title": "Test issue",
|
||||
"description": "Test description",
|
||||
"team_name": "Test team",
|
||||
"project_name": "Test project",
|
||||
"credentials": TEST_CREDENTIALS_INPUT_OAUTH,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS_OAUTH,
|
||||
test_output=[("issue_id", "abc123"), ("issue_title", "Test issue")],
|
||||
test_mock={
|
||||
"create_issue": lambda *args, **kwargs: (
|
||||
"abc123",
|
||||
"Test issue",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_issue(
|
||||
credentials: LinearCredentials,
|
||||
team_name: str,
|
||||
title: str,
|
||||
description: str | None = None,
|
||||
priority: int | None = None,
|
||||
project_name: str | None = None,
|
||||
) -> tuple[str, str]:
|
||||
client = LinearClient(credentials=credentials)
|
||||
team_id = client.try_get_team_by_name(team_name=team_name)
|
||||
project_id: str | None = None
|
||||
if project_name:
|
||||
projects = client.try_search_projects(term=project_name)
|
||||
if projects:
|
||||
project_id = projects[0].id
|
||||
else:
|
||||
raise LinearAPIException("Project not found", status_code=404)
|
||||
response: CreateIssueResponse = client.try_create_issue(
|
||||
team_id=team_id,
|
||||
title=title,
|
||||
description=description,
|
||||
priority=priority,
|
||||
project_id=project_id,
|
||||
)
|
||||
return response.issue.identifier, response.issue.title
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: LinearCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
"""Execute the issue creation"""
|
||||
try:
|
||||
issue_id, issue_title = self.create_issue(
|
||||
credentials=credentials,
|
||||
team_name=input_data.team_name,
|
||||
title=input_data.title,
|
||||
description=input_data.description,
|
||||
priority=input_data.priority,
|
||||
project_name=input_data.project_name,
|
||||
)
|
||||
|
||||
yield "issue_id", issue_id
|
||||
yield "issue_title", issue_title
|
||||
|
||||
except LinearAPIException as e:
|
||||
yield "error", str(e)
|
||||
except Exception as e:
|
||||
yield "error", f"Unexpected error: {str(e)}"
|
||||
|
||||
|
||||
class LinearSearchIssuesBlock(Block):
|
||||
"""Block for searching issues on Linear"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
term: str = SchemaField(description="Term to search for issues")
|
||||
credentials: LinearCredentialsInput = LinearCredentialsField(
|
||||
scopes=[LinearScope.READ],
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
issues: list[Issue] = SchemaField(description="List of issues")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b5a2a0e6-26b4-4c5b-8a42-bc79e9cb65c2",
|
||||
description="Searches for issues on Linear",
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
disabled=not LINEAR_OAUTH_IS_CONFIGURED,
|
||||
test_input={
|
||||
"term": "Test issue",
|
||||
"credentials": TEST_CREDENTIALS_INPUT_OAUTH,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS_OAUTH,
|
||||
test_output=[
|
||||
(
|
||||
"issues",
|
||||
[
|
||||
Issue(
|
||||
id="abc123",
|
||||
identifier="abc123",
|
||||
title="Test issue",
|
||||
description="Test description",
|
||||
priority=1,
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
test_mock={
|
||||
"search_issues": lambda *args, **kwargs: [
|
||||
Issue(
|
||||
id="abc123",
|
||||
identifier="abc123",
|
||||
title="Test issue",
|
||||
description="Test description",
|
||||
priority=1,
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def search_issues(
|
||||
credentials: LinearCredentials,
|
||||
term: str,
|
||||
) -> list[Issue]:
|
||||
client = LinearClient(credentials=credentials)
|
||||
response: list[Issue] = client.try_search_issues(term=term)
|
||||
return response
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: LinearCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
"""Execute the issue search"""
|
||||
try:
|
||||
issues = self.search_issues(credentials=credentials, term=input_data.term)
|
||||
yield "issues", issues
|
||||
except LinearAPIException as e:
|
||||
yield "error", str(e)
|
||||
except Exception as e:
|
||||
yield "error", f"Unexpected error: {str(e)}"
|
||||
@@ -1,41 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Comment(BaseModel):
|
||||
id: str
|
||||
body: str
|
||||
|
||||
|
||||
class CreateCommentInput(BaseModel):
|
||||
body: str
|
||||
issueId: str
|
||||
|
||||
|
||||
class CreateCommentResponse(BaseModel):
|
||||
success: bool
|
||||
comment: Comment
|
||||
|
||||
|
||||
class CreateCommentResponseWrapper(BaseModel):
|
||||
commentCreate: CreateCommentResponse
|
||||
|
||||
|
||||
class Issue(BaseModel):
|
||||
id: str
|
||||
identifier: str
|
||||
title: str
|
||||
description: str | None
|
||||
priority: int
|
||||
|
||||
|
||||
class CreateIssueResponse(BaseModel):
|
||||
issue: Issue
|
||||
|
||||
|
||||
class Project(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
priority: int
|
||||
progress: int
|
||||
content: str
|
||||
@@ -1,95 +0,0 @@
|
||||
from backend.blocks.linear._api import LinearAPIException, LinearClient
|
||||
from backend.blocks.linear._auth import (
|
||||
LINEAR_OAUTH_IS_CONFIGURED,
|
||||
TEST_CREDENTIALS_INPUT_OAUTH,
|
||||
TEST_CREDENTIALS_OAUTH,
|
||||
LinearCredentials,
|
||||
LinearCredentialsField,
|
||||
LinearCredentialsInput,
|
||||
LinearScope,
|
||||
)
|
||||
from backend.blocks.linear.models import Project
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class LinearSearchProjectsBlock(Block):
|
||||
"""Block for searching projects on Linear"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: LinearCredentialsInput = LinearCredentialsField(
|
||||
scopes=[LinearScope.READ],
|
||||
)
|
||||
term: str = SchemaField(description="Term to search for projects")
|
||||
|
||||
class Output(BlockSchema):
|
||||
projects: list[Project] = SchemaField(description="List of projects")
|
||||
error: str = SchemaField(description="Error message if issue creation failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="446a1d35-9d8f-4ac5-83ea-7684ec50e6af",
|
||||
description="Searches for projects on Linear",
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
categories={BlockCategory.PRODUCTIVITY, BlockCategory.ISSUE_TRACKING},
|
||||
test_input={
|
||||
"term": "Test project",
|
||||
"credentials": TEST_CREDENTIALS_INPUT_OAUTH,
|
||||
},
|
||||
disabled=not LINEAR_OAUTH_IS_CONFIGURED,
|
||||
test_credentials=TEST_CREDENTIALS_OAUTH,
|
||||
test_output=[
|
||||
(
|
||||
"projects",
|
||||
[
|
||||
Project(
|
||||
id="abc123",
|
||||
name="Test project",
|
||||
description="Test description",
|
||||
priority=1,
|
||||
progress=1,
|
||||
content="Test content",
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
test_mock={
|
||||
"search_projects": lambda *args, **kwargs: [
|
||||
Project(
|
||||
id="abc123",
|
||||
name="Test project",
|
||||
description="Test description",
|
||||
priority=1,
|
||||
progress=1,
|
||||
content="Test content",
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def search_projects(
|
||||
credentials: LinearCredentials,
|
||||
term: str,
|
||||
) -> list[Project]:
|
||||
client = LinearClient(credentials=credentials)
|
||||
response: list[Project] = client.try_search_projects(term=term)
|
||||
return response
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: LinearCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
"""Execute the project search"""
|
||||
try:
|
||||
projects = self.search_projects(
|
||||
credentials=credentials,
|
||||
term=input_data.term,
|
||||
)
|
||||
|
||||
yield "projects", projects
|
||||
|
||||
except LinearAPIException as e:
|
||||
yield "error", str(e)
|
||||
except Exception as e:
|
||||
yield "error", f"Unexpected error: {str(e)}"
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,245 +0,0 @@
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Literal, Optional
|
||||
|
||||
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
||||
from moviepy.video.fx.Loop import Loop
|
||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||
|
||||
|
||||
class MediaDurationBlock(Block):
|
||||
|
||||
class Input(BlockSchema):
|
||||
media_in: MediaFileType = SchemaField(
|
||||
description="Media input (URL, data URI, or local path)."
|
||||
)
|
||||
is_video: bool = SchemaField(
|
||||
description="Whether the media is a video (True) or audio (False).",
|
||||
default=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
duration: float = SchemaField(
|
||||
description="Duration of the media file (in seconds)."
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if something fails.", default=""
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d8b91fd4-da26-42d4-8ecb-8b196c6d84b6",
|
||||
description="Block to get the duration of a media file.",
|
||||
categories={BlockCategory.MULTIMEDIA},
|
||||
input_schema=MediaDurationBlock.Input,
|
||||
output_schema=MediaDurationBlock.Output,
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
graph_exec_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# 1) Store the input media locally
|
||||
local_media_path = store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.media_in,
|
||||
return_content=False,
|
||||
)
|
||||
media_abspath = get_exec_file_path(graph_exec_id, local_media_path)
|
||||
|
||||
# 2) Load the clip
|
||||
if input_data.is_video:
|
||||
clip = VideoFileClip(media_abspath)
|
||||
else:
|
||||
clip = AudioFileClip(media_abspath)
|
||||
|
||||
yield "duration", clip.duration
|
||||
|
||||
|
||||
class LoopVideoBlock(Block):
|
||||
"""
|
||||
Block for looping (repeating) a video clip until a given duration or number of loops.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
video_in: MediaFileType = SchemaField(
|
||||
description="The input video (can be a URL, data URI, or local path)."
|
||||
)
|
||||
# Provide EITHER a `duration` or `n_loops` or both. We'll demonstrate `duration`.
|
||||
duration: Optional[float] = SchemaField(
|
||||
description="Target duration (in seconds) to loop the video to. If omitted, defaults to no looping.",
|
||||
default=None,
|
||||
ge=0.0,
|
||||
)
|
||||
n_loops: Optional[int] = SchemaField(
|
||||
description="Number of times to repeat the video. If omitted, defaults to 1 (no repeat).",
|
||||
default=None,
|
||||
ge=1,
|
||||
)
|
||||
output_return_type: Literal["file_path", "data_uri"] = SchemaField(
|
||||
description="How to return the output video. Either a relative path or base64 data URI.",
|
||||
default="file_path",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
video_out: str = SchemaField(
|
||||
description="Looped video returned either as a relative path or a data URI."
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if something fails.", default=""
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8bf9eef6-5451-4213-b265-25306446e94b",
|
||||
description="Block to loop a video to a given duration or number of repeats.",
|
||||
categories={BlockCategory.MULTIMEDIA},
|
||||
input_schema=LoopVideoBlock.Input,
|
||||
output_schema=LoopVideoBlock.Output,
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# 1) Store the input video locally
|
||||
local_video_path = store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.video_in,
|
||||
return_content=False,
|
||||
)
|
||||
input_abspath = get_exec_file_path(graph_exec_id, local_video_path)
|
||||
|
||||
# 2) Load the clip
|
||||
clip = VideoFileClip(input_abspath)
|
||||
|
||||
# 3) Apply the loop effect
|
||||
looped_clip = clip
|
||||
if input_data.duration:
|
||||
# Loop until we reach the specified duration
|
||||
looped_clip = looped_clip.with_effects([Loop(duration=input_data.duration)])
|
||||
elif input_data.n_loops:
|
||||
looped_clip = looped_clip.with_effects([Loop(n=input_data.n_loops)])
|
||||
else:
|
||||
raise ValueError("Either 'duration' or 'n_loops' must be provided.")
|
||||
|
||||
assert isinstance(looped_clip, VideoFileClip)
|
||||
|
||||
# 4) Save the looped output
|
||||
output_filename = MediaFileType(
|
||||
f"{node_exec_id}_looped_{os.path.basename(local_video_path)}"
|
||||
)
|
||||
output_abspath = get_exec_file_path(graph_exec_id, output_filename)
|
||||
|
||||
looped_clip = looped_clip.with_audio(clip.audio)
|
||||
looped_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
|
||||
|
||||
# Return as data URI
|
||||
video_out = store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=output_filename,
|
||||
return_content=input_data.output_return_type == "data_uri",
|
||||
)
|
||||
|
||||
yield "video_out", video_out
|
||||
|
||||
|
||||
class AddAudioToVideoBlock(Block):
|
||||
"""
|
||||
Block that adds (attaches) an audio track to an existing video.
|
||||
Optionally scale the volume of the new track.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
video_in: MediaFileType = SchemaField(
|
||||
description="Video input (URL, data URI, or local path)."
|
||||
)
|
||||
audio_in: MediaFileType = SchemaField(
|
||||
description="Audio input (URL, data URI, or local path)."
|
||||
)
|
||||
volume: float = SchemaField(
|
||||
description="Volume scale for the newly attached audio track (1.0 = original).",
|
||||
default=1.0,
|
||||
)
|
||||
output_return_type: Literal["file_path", "data_uri"] = SchemaField(
|
||||
description="Return the final output as a relative path or base64 data URI.",
|
||||
default="file_path",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
video_out: MediaFileType = SchemaField(
|
||||
description="Final video (with attached audio), as a path or data URI."
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if something fails.", default=""
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3503748d-62b6-4425-91d6-725b064af509",
|
||||
description="Block to attach an audio file to a video file using moviepy.",
|
||||
categories={BlockCategory.MULTIMEDIA},
|
||||
input_schema=AddAudioToVideoBlock.Input,
|
||||
output_schema=AddAudioToVideoBlock.Output,
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# 1) Store the inputs locally
|
||||
local_video_path = store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.video_in,
|
||||
return_content=False,
|
||||
)
|
||||
local_audio_path = store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.audio_in,
|
||||
return_content=False,
|
||||
)
|
||||
|
||||
abs_temp_dir = os.path.join(tempfile.gettempdir(), "exec_file", graph_exec_id)
|
||||
video_abspath = os.path.join(abs_temp_dir, local_video_path)
|
||||
audio_abspath = os.path.join(abs_temp_dir, local_audio_path)
|
||||
|
||||
# 2) Load video + audio with moviepy
|
||||
video_clip = VideoFileClip(video_abspath)
|
||||
audio_clip = AudioFileClip(audio_abspath)
|
||||
# Optionally scale volume
|
||||
if input_data.volume != 1.0:
|
||||
audio_clip = audio_clip.with_volume_scaled(input_data.volume)
|
||||
|
||||
# 3) Attach the new audio track
|
||||
final_clip = video_clip.with_audio(audio_clip)
|
||||
|
||||
# 4) Write to output file
|
||||
output_filename = MediaFileType(
|
||||
f"{node_exec_id}_audio_attached_{os.path.basename(local_video_path)}"
|
||||
)
|
||||
output_abspath = os.path.join(abs_temp_dir, output_filename)
|
||||
final_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
|
||||
|
||||
# 5) Return either path or data URI
|
||||
video_out = store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=output_filename,
|
||||
return_content=input_data.output_return_type == "data_uri",
|
||||
)
|
||||
|
||||
yield "video_out", video_out
|
||||
@@ -1,338 +0,0 @@
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
from mem0 import MemoryClient
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
from backend.data.block import Block, BlockOutput, BlockSchema
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="ed55ac19-356e-4243-a6cb-bc599e9b716f",
|
||||
provider="mem0",
|
||||
api_key=SecretStr("mock-mem0-api-key"),
|
||||
title="Mock Mem0 API key",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
class Mem0Base:
|
||||
"""Base class with shared utilities for Mem0 blocks"""
|
||||
|
||||
@staticmethod
|
||||
def _get_client(credentials: APIKeyCredentials) -> MemoryClient:
|
||||
"""Get initialized Mem0 client"""
|
||||
return MemoryClient(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
|
||||
Filter = dict[str, list[dict[str, str | dict[str, list[str]]]]]
|
||||
|
||||
|
||||
class Conversation(BaseModel):
|
||||
discriminator: Literal["conversation"]
|
||||
messages: list[dict[str, str]]
|
||||
|
||||
|
||||
class Content(BaseModel):
|
||||
discriminator: Literal["content"]
|
||||
content: str
|
||||
|
||||
|
||||
class AddMemoryBlock(Block, Mem0Base):
|
||||
"""Block for adding memories to Mem0
|
||||
|
||||
Always limited by user_id and optional graph_id and graph_exec_id"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.MEM0], Literal["api_key"]
|
||||
] = CredentialsField(description="Mem0 API key credentials")
|
||||
content: Union[Content, Conversation] = SchemaField(
|
||||
discriminator="discriminator",
|
||||
description="Content to add - either a string or list of message objects as output from an AI block",
|
||||
default=Content(discriminator="content", content="I'm a vegetarian"),
|
||||
)
|
||||
metadata: dict[str, Any] = SchemaField(
|
||||
description="Optional metadata for the memory", default={}
|
||||
)
|
||||
|
||||
limit_memory_to_run: bool = SchemaField(
|
||||
description="Limit the memory to the run", default=False
|
||||
)
|
||||
limit_memory_to_agent: bool = SchemaField(
|
||||
description="Limit the memory to the agent", default=False
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
action: str = SchemaField(description="Action of the operation")
|
||||
memory: str = SchemaField(description="Memory created")
|
||||
error: str = SchemaField(description="Error message if operation fails")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="dce97578-86be-45a4-ae50-f6de33fc935a",
|
||||
description="Add new memories to Mem0 with user segmentation",
|
||||
input_schema=AddMemoryBlock.Input,
|
||||
output_schema=AddMemoryBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"content": {
|
||||
"discriminator": "conversation",
|
||||
"messages": [{"role": "user", "content": "I'm a vegetarian"}],
|
||||
},
|
||||
"metadata": {"food": "vegetarian"},
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
{
|
||||
"content": {
|
||||
"discriminator": "content",
|
||||
"content": "I am a vegetarian",
|
||||
},
|
||||
"metadata": {"food": "vegetarian"},
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
],
|
||||
test_output=[("action", "NO_CHANGE"), ("action", "NO_CHANGE")],
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_mock={"_get_client": lambda credentials: MockMemoryClient()},
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
graph_exec_id: str,
|
||||
**kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
client = self._get_client(credentials)
|
||||
|
||||
if isinstance(input_data.content, Conversation):
|
||||
messages = input_data.content.messages
|
||||
else:
|
||||
messages = [{"role": "user", "content": input_data.content}]
|
||||
|
||||
params = {
|
||||
"user_id": user_id,
|
||||
"output_format": "v1.1",
|
||||
"metadata": input_data.metadata,
|
||||
}
|
||||
|
||||
if input_data.limit_memory_to_run:
|
||||
params["run_id"] = graph_exec_id
|
||||
if input_data.limit_memory_to_agent:
|
||||
params["agent_id"] = graph_id
|
||||
|
||||
# Use the client to add memory
|
||||
result = client.add(
|
||||
messages,
|
||||
**params,
|
||||
)
|
||||
|
||||
if len(result.get("results", [])) > 0:
|
||||
for result in result.get("results", []):
|
||||
yield "action", result["event"]
|
||||
yield "memory", result["memory"]
|
||||
else:
|
||||
yield "action", "NO_CHANGE"
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(object=e)
|
||||
|
||||
|
||||
class SearchMemoryBlock(Block, Mem0Base):
|
||||
"""Block for searching memories in Mem0"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.MEM0], Literal["api_key"]
|
||||
] = CredentialsField(description="Mem0 API key credentials")
|
||||
query: str = SchemaField(
|
||||
description="Search query",
|
||||
advanced=False,
|
||||
)
|
||||
trigger: bool = SchemaField(
|
||||
description="An unused field that is used to (re-)trigger the block when you have no other inputs",
|
||||
default=False,
|
||||
advanced=False,
|
||||
)
|
||||
categories_filter: list[str] = SchemaField(
|
||||
description="Categories to filter by",
|
||||
default=[],
|
||||
advanced=True,
|
||||
)
|
||||
limit_memory_to_run: bool = SchemaField(
|
||||
description="Limit the memory to the run", default=False
|
||||
)
|
||||
limit_memory_to_agent: bool = SchemaField(
|
||||
description="Limit the memory to the agent", default=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
memories: Any = SchemaField(description="List of matching memories")
|
||||
error: str = SchemaField(description="Error message if operation fails")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="bd7c84e3-e073-4b75-810c-600886ec8a5b",
|
||||
description="Search memories in Mem0 by user",
|
||||
input_schema=SearchMemoryBlock.Input,
|
||||
output_schema=SearchMemoryBlock.Output,
|
||||
test_input={
|
||||
"query": "vegetarian preferences",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"top_k": 10,
|
||||
"rerank": True,
|
||||
},
|
||||
test_output=[
|
||||
("memories", [{"id": "test-memory", "content": "test content"}])
|
||||
],
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_mock={"_get_client": lambda credentials: MockMemoryClient()},
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
graph_exec_id: str,
|
||||
**kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
client = self._get_client(credentials)
|
||||
|
||||
filters: Filter = {
|
||||
# This works with only one filter, so we can allow others to add on later
|
||||
"AND": [
|
||||
{"user_id": user_id},
|
||||
]
|
||||
}
|
||||
if input_data.categories_filter:
|
||||
filters["AND"].append(
|
||||
{"categories": {"contains": input_data.categories_filter}}
|
||||
)
|
||||
if input_data.limit_memory_to_run:
|
||||
filters["AND"].append({"run_id": graph_exec_id})
|
||||
if input_data.limit_memory_to_agent:
|
||||
filters["AND"].append({"agent_id": graph_id})
|
||||
|
||||
result: list[dict[str, Any]] = client.search(
|
||||
input_data.query, version="v2", filters=filters
|
||||
)
|
||||
yield "memories", result
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class GetAllMemoriesBlock(Block, Mem0Base):
|
||||
"""Block for retrieving all memories from Mem0"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.MEM0], Literal["api_key"]
|
||||
] = CredentialsField(description="Mem0 API key credentials")
|
||||
trigger: bool = SchemaField(
|
||||
description="An unused field that is used to trigger the block when you have no other inputs",
|
||||
default=False,
|
||||
advanced=False,
|
||||
)
|
||||
categories: Optional[list[str]] = SchemaField(
|
||||
description="Filter by categories", default=None
|
||||
)
|
||||
limit_memory_to_run: bool = SchemaField(
|
||||
description="Limit the memory to the run", default=False
|
||||
)
|
||||
limit_memory_to_agent: bool = SchemaField(
|
||||
description="Limit the memory to the agent", default=False
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
memories: Any = SchemaField(description="List of memories")
|
||||
error: str = SchemaField(description="Error message if operation fails")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="45aee5bf-4767-45d1-a28b-e01c5aae9fc1",
|
||||
description="Retrieve all memories from Mem0 with pagination",
|
||||
input_schema=GetAllMemoriesBlock.Input,
|
||||
output_schema=GetAllMemoriesBlock.Output,
|
||||
test_input={
|
||||
"user_id": "test_user",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
("memories", [{"id": "test-memory", "content": "test content"}]),
|
||||
],
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_mock={"_get_client": lambda credentials: MockMemoryClient()},
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
graph_exec_id: str,
|
||||
**kwargs
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
client = self._get_client(credentials)
|
||||
|
||||
filters: Filter = {
|
||||
"AND": [
|
||||
{"user_id": user_id},
|
||||
]
|
||||
}
|
||||
if input_data.limit_memory_to_run:
|
||||
filters["AND"].append({"run_id": graph_exec_id})
|
||||
if input_data.limit_memory_to_agent:
|
||||
filters["AND"].append({"agent_id": graph_id})
|
||||
if input_data.categories:
|
||||
filters["AND"].append(
|
||||
{"categories": {"contains": input_data.categories}}
|
||||
)
|
||||
|
||||
memories: list[dict[str, Any]] = client.get_all(
|
||||
filters=filters,
|
||||
version="v2",
|
||||
)
|
||||
|
||||
yield "memories", memories
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
# Mock client for testing
|
||||
class MockMemoryClient:
|
||||
"""Mock Mem0 client for testing"""
|
||||
|
||||
def add(self, *args, **kwargs):
|
||||
return {"memory_id": "test-memory-id", "status": "success"}
|
||||
|
||||
def search(self, *args, **kwargs) -> list[dict[str, str]]:
|
||||
return [{"id": "test-memory", "content": "test content"}]
|
||||
|
||||
def get_all(self, *args, **kwargs) -> list[dict[str, str]]:
|
||||
return [{"id": "test-memory", "content": "test content"}]
|
||||
@@ -6,14 +6,13 @@ from backend.blocks.nvidia._auth import (
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.request import requests
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
|
||||
class NvidiaDeepfakeDetectBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: NvidiaCredentialsInput = NvidiaCredentialsField()
|
||||
image_base64: MediaFileType = SchemaField(
|
||||
description="Image to analyze for deepfakes",
|
||||
image_base64: str = SchemaField(
|
||||
description="Image to analyze for deepfakes", image_upload=True
|
||||
)
|
||||
return_image: bool = SchemaField(
|
||||
description="Whether to return the processed image with markings",
|
||||
@@ -23,12 +22,16 @@ class NvidiaDeepfakeDetectBlock(Block):
|
||||
class Output(BlockSchema):
|
||||
status: str = SchemaField(
|
||||
description="Detection status (SUCCESS, ERROR, CONTENT_FILTERED)",
|
||||
default="",
|
||||
)
|
||||
image: MediaFileType = SchemaField(
|
||||
image: str = SchemaField(
|
||||
description="Processed image with detection markings (if return_image=True)",
|
||||
default="",
|
||||
image_output=True,
|
||||
)
|
||||
is_deepfake: float = SchemaField(
|
||||
description="Probability that the image is a deepfake (0-1)",
|
||||
default=0.0,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@@ -1,48 +1,22 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Iterator, Literal
|
||||
from typing import Iterator
|
||||
|
||||
import praw
|
||||
from pydantic import BaseModel, SecretStr
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import (
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
UserPasswordCredentials,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.data.model import BlockSecret, SchemaField, SecretField
|
||||
from backend.util.mock import MockObject
|
||||
from backend.util.settings import Settings
|
||||
|
||||
RedditCredentials = UserPasswordCredentials
|
||||
RedditCredentialsInput = CredentialsMetaInput[
|
||||
Literal[ProviderName.REDDIT],
|
||||
Literal["user_password"],
|
||||
]
|
||||
|
||||
|
||||
def RedditCredentialsField() -> RedditCredentialsInput:
|
||||
"""Creates a Reddit credentials input on a block."""
|
||||
return CredentialsField(
|
||||
description="The Reddit integration requires a username and password.",
|
||||
)
|
||||
class RedditCredentials(BaseModel):
|
||||
client_id: BlockSecret = SecretField(key="reddit_client_id")
|
||||
client_secret: BlockSecret = SecretField(key="reddit_client_secret")
|
||||
username: BlockSecret = SecretField(key="reddit_username")
|
||||
password: BlockSecret = SecretField(key="reddit_password")
|
||||
user_agent: str = "AutoGPT:1.0 (by /u/autogpt)"
|
||||
|
||||
|
||||
TEST_CREDENTIALS = UserPasswordCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="reddit",
|
||||
username=SecretStr("mock-reddit-username"),
|
||||
password=SecretStr("mock-reddit-password"),
|
||||
title="Mock Reddit credentials",
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
model_config = ConfigDict(title="Reddit Credentials")
|
||||
|
||||
|
||||
class RedditPost(BaseModel):
|
||||
@@ -57,16 +31,13 @@ class RedditComment(BaseModel):
|
||||
comment: str
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
def get_praw(creds: RedditCredentials) -> praw.Reddit:
|
||||
client = praw.Reddit(
|
||||
client_id=settings.secrets.reddit_client_id,
|
||||
client_secret=settings.secrets.reddit_client_secret,
|
||||
client_id=creds.client_id.get_secret_value(),
|
||||
client_secret=creds.client_secret.get_secret_value(),
|
||||
username=creds.username.get_secret_value(),
|
||||
password=creds.password.get_secret_value(),
|
||||
user_agent=settings.config.reddit_user_agent,
|
||||
user_agent=creds.user_agent,
|
||||
)
|
||||
me = client.user.me()
|
||||
if not me:
|
||||
@@ -77,11 +48,11 @@ def get_praw(creds: RedditCredentials) -> praw.Reddit:
|
||||
|
||||
class GetRedditPostsBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
subreddit: str = SchemaField(
|
||||
description="Subreddit name, excluding the /r/ prefix",
|
||||
default="writingprompts",
|
||||
subreddit: str = SchemaField(description="Subreddit name")
|
||||
creds: RedditCredentials = SchemaField(
|
||||
description="Reddit credentials",
|
||||
default=RedditCredentials(),
|
||||
)
|
||||
credentials: RedditCredentialsInput = RedditCredentialsField()
|
||||
last_minutes: int | None = SchemaField(
|
||||
description="Post time to stop minutes ago while fetching posts",
|
||||
default=None,
|
||||
@@ -99,18 +70,20 @@ class GetRedditPostsBlock(Block):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
disabled=True,
|
||||
id="c6731acb-4285-4ee1-bc9b-03d0766c370f",
|
||||
description="This block fetches Reddit posts from a defined subreddit name.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
disabled=(
|
||||
not settings.secrets.reddit_client_id
|
||||
or not settings.secrets.reddit_client_secret
|
||||
),
|
||||
input_schema=GetRedditPostsBlock.Input,
|
||||
output_schema=GetRedditPostsBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"creds": {
|
||||
"client_id": "client_id",
|
||||
"client_secret": "client_secret",
|
||||
"username": "username",
|
||||
"password": "password",
|
||||
"user_agent": "user_agent",
|
||||
},
|
||||
"subreddit": "subreddit",
|
||||
"last_post": "id3",
|
||||
"post_limit": 2,
|
||||
@@ -130,7 +103,7 @@ class GetRedditPostsBlock(Block):
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_posts": lambda input_data, credentials: [
|
||||
"get_posts": lambda _: [
|
||||
MockObject(id="id1", title="title1", selftext="body1"),
|
||||
MockObject(id="id2", title="title2", selftext="body2"),
|
||||
MockObject(id="id3", title="title2", selftext="body2"),
|
||||
@@ -139,18 +112,14 @@ class GetRedditPostsBlock(Block):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_posts(
|
||||
input_data: Input, *, credentials: RedditCredentials
|
||||
) -> Iterator[praw.reddit.Submission]:
|
||||
client = get_praw(credentials)
|
||||
def get_posts(input_data: Input) -> Iterator[praw.reddit.Submission]:
|
||||
client = get_praw(input_data.creds)
|
||||
subreddit = client.subreddit(input_data.subreddit)
|
||||
return subreddit.new(limit=input_data.post_limit or 10)
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: RedditCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
current_time = datetime.now(tz=timezone.utc)
|
||||
for post in self.get_posts(input_data=input_data, credentials=credentials):
|
||||
for post in self.get_posts(input_data):
|
||||
if input_data.last_minutes:
|
||||
post_datetime = datetime.fromtimestamp(
|
||||
post.created_utc, tz=timezone.utc
|
||||
@@ -172,7 +141,9 @@ class GetRedditPostsBlock(Block):
|
||||
|
||||
class PostRedditCommentBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: RedditCredentialsInput = RedditCredentialsField()
|
||||
creds: RedditCredentials = SchemaField(
|
||||
description="Reddit credentials", default=RedditCredentials()
|
||||
)
|
||||
data: RedditComment = SchemaField(description="Reddit comment")
|
||||
|
||||
class Output(BlockSchema):
|
||||
@@ -185,15 +156,7 @@ class PostRedditCommentBlock(Block):
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=PostRedditCommentBlock.Input,
|
||||
output_schema=PostRedditCommentBlock.Output,
|
||||
disabled=(
|
||||
not settings.secrets.reddit_client_id
|
||||
or not settings.secrets.reddit_client_secret
|
||||
),
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"data": {"post_id": "id", "comment": "comment"},
|
||||
},
|
||||
test_input={"data": {"post_id": "id", "comment": "comment"}},
|
||||
test_output=[("comment_id", "dummy_comment_id")],
|
||||
test_mock={"reply_post": lambda creds, comment: "dummy_comment_id"},
|
||||
)
|
||||
@@ -207,7 +170,5 @@ class PostRedditCommentBlock(Block):
|
||||
raise ValueError("Failed to post comment.")
|
||||
return new_comment.id
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: RedditCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
yield "comment_id", self.reply_post(credentials, input_data.data)
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
yield "comment_id", self.reply_post(input_data.creds, input_data.data)
|
||||
|
||||
@@ -131,7 +131,7 @@ class ReplicateFluxAdvancedModelBlock(Block):
|
||||
super().__init__(
|
||||
id="90f8c45e-e983-4644-aa0b-b4ebe2f531bc",
|
||||
description="This block runs Flux models on Replicate with advanced settings.",
|
||||
categories={BlockCategory.AI, BlockCategory.MULTIMEDIA},
|
||||
categories={BlockCategory.AI},
|
||||
input_schema=ReplicateFluxAdvancedModelBlock.Input,
|
||||
output_schema=ReplicateFluxAdvancedModelBlock.Output,
|
||||
test_input={
|
||||
|
||||
@@ -1,176 +0,0 @@
|
||||
from base64 import b64encode
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.file import MediaFileType, store_media_file
|
||||
from backend.util.request import Requests
|
||||
|
||||
|
||||
class Format(str, Enum):
|
||||
PNG = "png"
|
||||
JPEG = "jpeg"
|
||||
WEBP = "webp"
|
||||
|
||||
|
||||
class ScreenshotWebPageBlock(Block):
|
||||
"""Block for taking screenshots using ScreenshotOne API"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.SCREENSHOTONE], Literal["api_key"]
|
||||
] = CredentialsField(description="The ScreenshotOne API key")
|
||||
url: str = SchemaField(
|
||||
description="URL of the website to screenshot",
|
||||
placeholder="https://example.com",
|
||||
)
|
||||
viewport_width: int = SchemaField(
|
||||
description="Width of the viewport in pixels", default=1920
|
||||
)
|
||||
viewport_height: int = SchemaField(
|
||||
description="Height of the viewport in pixels", default=1080
|
||||
)
|
||||
full_page: bool = SchemaField(
|
||||
description="Whether to capture the full page length", default=False
|
||||
)
|
||||
format: Format = SchemaField(
|
||||
description="Output format (png, jpeg, webp)", default=Format.PNG
|
||||
)
|
||||
block_ads: bool = SchemaField(description="Whether to block ads", default=True)
|
||||
block_cookie_banners: bool = SchemaField(
|
||||
description="Whether to block cookie banners", default=True
|
||||
)
|
||||
block_chats: bool = SchemaField(
|
||||
description="Whether to block chat widgets", default=True
|
||||
)
|
||||
cache: bool = SchemaField(
|
||||
description="Whether to enable caching", default=False
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
image: MediaFileType = SchemaField(description="The screenshot image data")
|
||||
error: str = SchemaField(description="Error message if the screenshot failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3a7c4b8d-6e2f-4a5d-b9c1-f8d23c5a9b0e", # Generated UUID
|
||||
description="Takes a screenshot of a specified website using ScreenshotOne API",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=ScreenshotWebPageBlock.Input,
|
||||
output_schema=ScreenshotWebPageBlock.Output,
|
||||
test_input={
|
||||
"url": "https://example.com",
|
||||
"viewport_width": 1920,
|
||||
"viewport_height": 1080,
|
||||
"full_page": False,
|
||||
"format": "png",
|
||||
"block_ads": True,
|
||||
"block_cookie_banners": True,
|
||||
"block_chats": True,
|
||||
"cache": False,
|
||||
"credentials": {
|
||||
"provider": "screenshotone",
|
||||
"type": "api_key",
|
||||
"id": "test-id",
|
||||
"title": "Test API Key",
|
||||
},
|
||||
},
|
||||
test_credentials=APIKeyCredentials(
|
||||
id="test-id",
|
||||
provider="screenshotone",
|
||||
api_key=SecretStr("test-key"),
|
||||
title="Test API Key",
|
||||
expires_at=None,
|
||||
),
|
||||
test_output=[
|
||||
(
|
||||
"image",
|
||||
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAB5JREFUOE9jZPjP8J+BAsA4agDDaBgwjIYBw7AIAwCV5B/xAsMbygAAAABJRU5ErkJggg==",
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"take_screenshot": lambda *args, **kwargs: {
|
||||
"image": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAB5JREFUOE9jZPjP8J+BAsA4agDDaBgwjIYBw7AIAwCV5B/xAsMbygAAAABJRU5ErkJggg==",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def take_screenshot(
|
||||
credentials: APIKeyCredentials,
|
||||
graph_exec_id: str,
|
||||
url: str,
|
||||
viewport_width: int,
|
||||
viewport_height: int,
|
||||
full_page: bool,
|
||||
format: Format,
|
||||
block_ads: bool,
|
||||
block_cookie_banners: bool,
|
||||
block_chats: bool,
|
||||
cache: bool,
|
||||
) -> dict:
|
||||
"""
|
||||
Takes a screenshot using the ScreenshotOne API
|
||||
"""
|
||||
api = Requests(trusted_origins=["https://api.screenshotone.com"])
|
||||
|
||||
# Build API URL with parameters
|
||||
params = {
|
||||
"access_key": credentials.api_key.get_secret_value(),
|
||||
"url": url,
|
||||
"viewport_width": viewport_width,
|
||||
"viewport_height": viewport_height,
|
||||
"full_page": str(full_page).lower(),
|
||||
"format": format.value,
|
||||
"block_ads": str(block_ads).lower(),
|
||||
"block_cookie_banners": str(block_cookie_banners).lower(),
|
||||
"block_chats": str(block_chats).lower(),
|
||||
"cache": str(cache).lower(),
|
||||
}
|
||||
|
||||
response = api.get("https://api.screenshotone.com/take", params=params)
|
||||
|
||||
return {
|
||||
"image": store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=MediaFileType(
|
||||
f"data:image/{format.value};base64,{b64encode(response.content).decode('utf-8')}"
|
||||
),
|
||||
return_content=True,
|
||||
)
|
||||
}
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
graph_exec_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
screenshot_data = self.take_screenshot(
|
||||
credentials=credentials,
|
||||
graph_exec_id=graph_exec_id,
|
||||
url=input_data.url,
|
||||
viewport_width=input_data.viewport_width,
|
||||
viewport_height=input_data.viewport_height,
|
||||
full_page=input_data.full_page,
|
||||
format=input_data.format,
|
||||
block_ads=input_data.block_ads,
|
||||
block_cookie_banners=input_data.block_cookie_banners,
|
||||
block_chats=input_data.block_chats,
|
||||
cache=input_data.cache,
|
||||
)
|
||||
yield "image", screenshot_data["image"]
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -8,7 +8,6 @@ from backend.data.block import (
|
||||
BlockWebhookConfig,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util import settings
|
||||
from backend.util.settings import AppEnvironment, BehaveAs
|
||||
|
||||
@@ -83,7 +82,7 @@ class Slant3DOrderWebhookBlock(Slant3DTriggerBase, Block):
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
webhook_config=BlockWebhookConfig(
|
||||
provider=ProviderName.SLANT3D,
|
||||
provider="slant3d",
|
||||
webhook_type="orders", # Only one type for now
|
||||
resource_format="", # No resource format needed
|
||||
event_filter_input="events",
|
||||
|
||||
@@ -1,511 +0,0 @@
|
||||
import logging
|
||||
import re
|
||||
from collections import Counter
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
|
||||
import backend.blocks.llm as llm
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockInput,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
get_block,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util import json
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_database_manager_client():
|
||||
from backend.executor import DatabaseManager
|
||||
from backend.util.service import get_service_client
|
||||
|
||||
return get_service_client(DatabaseManager)
|
||||
|
||||
|
||||
def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
|
||||
"""
|
||||
Return a list of tool_call_ids if the entry is a tool request.
|
||||
Supports both OpenAI and Anthropics formats.
|
||||
"""
|
||||
tool_call_ids = []
|
||||
if entry.get("role") != "assistant":
|
||||
return tool_call_ids
|
||||
|
||||
# OpenAI: check for tool_calls in the entry.
|
||||
calls = entry.get("tool_calls")
|
||||
if isinstance(calls, list):
|
||||
for call in calls:
|
||||
if tool_id := call.get("id"):
|
||||
tool_call_ids.append(tool_id)
|
||||
|
||||
# Anthropics: check content items for tool_use type.
|
||||
content = entry.get("content")
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if item.get("type") != "tool_use":
|
||||
continue
|
||||
if tool_id := item.get("id"):
|
||||
tool_call_ids.append(tool_id)
|
||||
|
||||
return tool_call_ids
|
||||
|
||||
|
||||
def _get_tool_responses(entry: dict[str, Any]) -> list[str]:
|
||||
"""
|
||||
Return a list of tool_call_ids if the entry is a tool response.
|
||||
Supports both OpenAI and Anthropics formats.
|
||||
"""
|
||||
tool_call_ids: list[str] = []
|
||||
|
||||
# OpenAI: a tool response message with role "tool" and key "tool_call_id".
|
||||
if entry.get("role") == "tool":
|
||||
if tool_call_id := entry.get("tool_call_id"):
|
||||
tool_call_ids.append(str(tool_call_id))
|
||||
|
||||
# Anthropics: check content items for tool_result type.
|
||||
if entry.get("role") == "user":
|
||||
content = entry.get("content")
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if item.get("type") != "tool_result":
|
||||
continue
|
||||
if tool_call_id := item.get("tool_use_id"):
|
||||
tool_call_ids.append(tool_call_id)
|
||||
|
||||
return tool_call_ids
|
||||
|
||||
|
||||
def _create_tool_response(call_id: str, output: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Create a tool response message for either OpenAI or Anthropics,
|
||||
based on the tool_id format.
|
||||
"""
|
||||
content = output if isinstance(output, str) else json.dumps(output)
|
||||
|
||||
# Anthropics format: tool IDs typically start with "toolu_"
|
||||
if call_id.startswith("toolu_"):
|
||||
return {
|
||||
"role": "user",
|
||||
"type": "message",
|
||||
"content": [
|
||||
{"tool_use_id": call_id, "type": "tool_result", "content": content}
|
||||
],
|
||||
}
|
||||
|
||||
# OpenAI format: tool IDs typically start with "call_".
|
||||
# Or default fallback (if the tool_id doesn't match any known prefix)
|
||||
return {"role": "tool", "tool_call_id": call_id, "content": content}
|
||||
|
||||
|
||||
def get_pending_tool_calls(conversation_history: list[Any]) -> dict[str, int]:
|
||||
"""
|
||||
All the tool calls entry in the conversation history requires a response.
|
||||
This function returns the pending tool calls that has not generated an output yet.
|
||||
|
||||
Return: dict[str, int] - A dictionary of pending tool call IDs with their count.
|
||||
"""
|
||||
pending_calls = Counter()
|
||||
for history in conversation_history:
|
||||
for call_id in _get_tool_requests(history):
|
||||
pending_calls[call_id] += 1
|
||||
|
||||
for call_id in _get_tool_responses(history):
|
||||
pending_calls[call_id] -= 1
|
||||
|
||||
return {call_id: count for call_id, count in pending_calls.items() if count > 0}
|
||||
|
||||
|
||||
class SmartDecisionMakerBlock(Block):
|
||||
"""
|
||||
A block that uses a language model to make smart decisions based on a given prompt.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
prompt: str = SchemaField(
|
||||
description="The prompt to send to the language model.",
|
||||
placeholder="Enter your prompt here...",
|
||||
)
|
||||
model: llm.LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=llm.LlmModel.GPT4O,
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
)
|
||||
credentials: llm.AICredentials = llm.AICredentialsField()
|
||||
sys_prompt: str = SchemaField(
|
||||
title="System Prompt",
|
||||
default="Thinking carefully step by step decide which function to call. "
|
||||
"Always choose a function call from the list of function signatures, "
|
||||
"and always provide the complete argument provided with the type "
|
||||
"matching the required jsonschema signature, no missing argument is allowed. "
|
||||
"If you have already completed the task objective, you can end the task "
|
||||
"by providing the end result of your work as a finish message. "
|
||||
"Only provide EXACTLY one function call, multiple tool calls is strictly prohibited.",
|
||||
description="The system prompt to provide additional context to the model.",
|
||||
)
|
||||
conversation_history: list[dict] = SchemaField(
|
||||
default=[],
|
||||
description="The conversation history to provide context for the prompt.",
|
||||
)
|
||||
last_tool_output: Any = SchemaField(
|
||||
default=None,
|
||||
description="The output of the last tool that was called.",
|
||||
)
|
||||
retry: int = SchemaField(
|
||||
title="Retry Count",
|
||||
default=3,
|
||||
description="Number of times to retry the LLM call if the response does not match the expected format.",
|
||||
)
|
||||
prompt_values: dict[str, str] = SchemaField(
|
||||
advanced=False,
|
||||
default={},
|
||||
description="Values used to fill in the prompt. The values can be used in the prompt by putting them in a double curly braces, e.g. {{variable_name}}.",
|
||||
)
|
||||
max_tokens: int | None = SchemaField(
|
||||
advanced=True,
|
||||
default=None,
|
||||
description="The maximum number of tokens to generate in the chat completion.",
|
||||
)
|
||||
ollama_host: str = SchemaField(
|
||||
advanced=True,
|
||||
default="localhost:11434",
|
||||
description="Ollama host for local models",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]:
|
||||
# conversation_history & last_tool_output validation is handled differently
|
||||
missing_links = super().get_missing_links(
|
||||
data,
|
||||
[
|
||||
link
|
||||
for link in links
|
||||
if link.sink_name
|
||||
not in ["conversation_history", "last_tool_output"]
|
||||
],
|
||||
)
|
||||
|
||||
# Avoid executing the block if the last_tool_output is connected to a static
|
||||
# link, like StoreValueBlock or AgentInputBlock.
|
||||
if any(link.sink_name == "conversation_history" for link in links) and any(
|
||||
link.sink_name == "last_tool_output" and link.is_static
|
||||
for link in links
|
||||
):
|
||||
raise ValueError(
|
||||
"Last Tool Output can't be connected to a static (dashed line) "
|
||||
"link like the output of `StoreValue` or `AgentInput` block"
|
||||
)
|
||||
|
||||
return missing_links
|
||||
|
||||
@classmethod
|
||||
def get_missing_input(cls, data: BlockInput) -> set[str]:
|
||||
if missing_input := super().get_missing_input(data):
|
||||
return missing_input
|
||||
|
||||
conversation_history = data.get("conversation_history", [])
|
||||
pending_tool_calls = get_pending_tool_calls(conversation_history)
|
||||
last_tool_output = data.get("last_tool_output")
|
||||
if not last_tool_output and pending_tool_calls:
|
||||
return {"last_tool_output"}
|
||||
return set()
|
||||
|
||||
class Output(BlockSchema):
|
||||
error: str = SchemaField(description="Error message if the API call failed.")
|
||||
tools: Any = SchemaField(description="The tools that are available to use.")
|
||||
finished: str = SchemaField(
|
||||
description="The finished message to display to the user."
|
||||
)
|
||||
conversations: list[Any] = SchemaField(
|
||||
description="The conversation history to provide context for the prompt."
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3b191d9f-356f-482d-8238-ba04b6d18381",
|
||||
description="Uses AI to intelligently decide what tool to use.",
|
||||
categories={BlockCategory.AI},
|
||||
block_type=BlockType.AI,
|
||||
input_schema=SmartDecisionMakerBlock.Input,
|
||||
output_schema=SmartDecisionMakerBlock.Output,
|
||||
test_input={
|
||||
"prompt": "Hello, World!",
|
||||
"credentials": llm.TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[],
|
||||
test_credentials=llm.TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_block_function_signature(
|
||||
sink_node: "Node", links: list["Link"]
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Creates a function signature for a block node.
|
||||
|
||||
Args:
|
||||
sink_node: The node for which to create a function signature.
|
||||
links: The list of links connected to the sink node.
|
||||
|
||||
Returns:
|
||||
A dictionary representing the function signature in the format expected by LLM tools.
|
||||
|
||||
Raises:
|
||||
ValueError: If the block specified by sink_node.block_id is not found.
|
||||
"""
|
||||
block = get_block(sink_node.block_id)
|
||||
if not block:
|
||||
raise ValueError(f"Block not found: {sink_node.block_id}")
|
||||
|
||||
tool_function: dict[str, Any] = {
|
||||
"name": re.sub(r"[^a-zA-Z0-9_-]", "_", block.name).lower(),
|
||||
"description": block.description,
|
||||
}
|
||||
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
for link in links:
|
||||
sink_block_input_schema = block.input_schema
|
||||
description = (
|
||||
sink_block_input_schema.model_fields[link.sink_name].description
|
||||
if link.sink_name in sink_block_input_schema.model_fields
|
||||
and sink_block_input_schema.model_fields[link.sink_name].description
|
||||
else f"The {link.sink_name} of the tool"
|
||||
)
|
||||
properties[link.sink_name.lower()] = {
|
||||
"type": "string",
|
||||
"description": description,
|
||||
}
|
||||
|
||||
tool_function["parameters"] = {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
"additionalProperties": False,
|
||||
"strict": True,
|
||||
}
|
||||
|
||||
return {"type": "function", "function": tool_function}
|
||||
|
||||
@staticmethod
|
||||
def _create_agent_function_signature(
|
||||
sink_node: "Node", links: list["Link"]
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Creates a function signature for an agent node.
|
||||
|
||||
Args:
|
||||
sink_node: The agent node for which to create a function signature.
|
||||
links: The list of links connected to the sink node.
|
||||
|
||||
Returns:
|
||||
A dictionary representing the function signature in the format expected by LLM tools.
|
||||
|
||||
Raises:
|
||||
ValueError: If the graph metadata for the specified graph_id and graph_version is not found.
|
||||
"""
|
||||
graph_id = sink_node.input_default.get("graph_id")
|
||||
graph_version = sink_node.input_default.get("graph_version")
|
||||
if not graph_id or not graph_version:
|
||||
raise ValueError("Graph ID or Graph Version not found in sink node.")
|
||||
|
||||
db_client = get_database_manager_client()
|
||||
sink_graph_meta = db_client.get_graph_metadata(graph_id, graph_version)
|
||||
if not sink_graph_meta:
|
||||
raise ValueError(
|
||||
f"Sink graph metadata not found: {graph_id} {graph_version}"
|
||||
)
|
||||
|
||||
tool_function: dict[str, Any] = {
|
||||
"name": re.sub(r"[^a-zA-Z0-9_-]", "_", sink_graph_meta.name).lower(),
|
||||
"description": sink_graph_meta.description,
|
||||
}
|
||||
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
for link in links:
|
||||
sink_block_input_schema = sink_node.input_default["input_schema"]
|
||||
description = (
|
||||
sink_block_input_schema["properties"][link.sink_name]["description"]
|
||||
if "description"
|
||||
in sink_block_input_schema["properties"][link.sink_name]
|
||||
else f"The {link.sink_name} of the tool"
|
||||
)
|
||||
properties[link.sink_name.lower()] = {
|
||||
"type": "string",
|
||||
"description": description,
|
||||
}
|
||||
|
||||
tool_function["parameters"] = {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
"additionalProperties": False,
|
||||
"strict": True,
|
||||
}
|
||||
|
||||
return {"type": "function", "function": tool_function}
|
||||
|
||||
@staticmethod
|
||||
def _create_function_signature(node_id: str) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Creates function signatures for tools linked to a specified node within a graph.
|
||||
|
||||
This method filters the graph links to identify those that are tools and are
|
||||
connected to the given node_id. It then constructs function signatures for each
|
||||
tool based on the metadata and input schema of the linked nodes.
|
||||
|
||||
Args:
|
||||
node_id: The node_id for which to create function signatures.
|
||||
|
||||
Returns:
|
||||
list[dict[str, Any]]: A list of dictionaries, each representing a function signature
|
||||
for a tool, including its name, description, and parameters.
|
||||
|
||||
Raises:
|
||||
ValueError: If no tool links are found for the specified node_id, or if a sink node
|
||||
or its metadata cannot be found.
|
||||
"""
|
||||
db_client = get_database_manager_client()
|
||||
tools = [
|
||||
(link, node)
|
||||
for link, node in db_client.get_connected_output_nodes(node_id)
|
||||
if link.source_name.startswith("tools_^_") and link.source_id == node_id
|
||||
]
|
||||
if not tools:
|
||||
raise ValueError("There is no next node to execute.")
|
||||
|
||||
return_tool_functions = []
|
||||
|
||||
grouped_tool_links: dict[str, tuple["Node", list["Link"]]] = {}
|
||||
for link, node in tools:
|
||||
if link.sink_id not in grouped_tool_links:
|
||||
grouped_tool_links[link.sink_id] = (node, [link])
|
||||
else:
|
||||
grouped_tool_links[link.sink_id][1].append(link)
|
||||
|
||||
for sink_node, links in grouped_tool_links.values():
|
||||
if not sink_node:
|
||||
raise ValueError(f"Sink node not found: {links[0].sink_id}")
|
||||
|
||||
if sink_node.block_id == AgentExecutorBlock().id:
|
||||
return_tool_functions.append(
|
||||
SmartDecisionMakerBlock._create_agent_function_signature(
|
||||
sink_node, links
|
||||
)
|
||||
)
|
||||
else:
|
||||
return_tool_functions.append(
|
||||
SmartDecisionMakerBlock._create_block_function_signature(
|
||||
sink_node, links
|
||||
)
|
||||
)
|
||||
|
||||
return return_tool_functions
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: llm.APIKeyCredentials,
|
||||
graph_id: str,
|
||||
node_id: str,
|
||||
graph_exec_id: str,
|
||||
node_exec_id: str,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
tool_functions = self._create_function_signature(node_id)
|
||||
|
||||
input_data.conversation_history = input_data.conversation_history or []
|
||||
prompt = [json.to_dict(p) for p in input_data.conversation_history if p]
|
||||
|
||||
pending_tool_calls = get_pending_tool_calls(input_data.conversation_history)
|
||||
if pending_tool_calls and not input_data.last_tool_output:
|
||||
raise ValueError(f"Tool call requires an output for {pending_tool_calls}")
|
||||
|
||||
# Prefill all missing tool calls with the last tool output/
|
||||
# TODO: we need a better way to handle this.
|
||||
tool_output = [
|
||||
_create_tool_response(pending_call_id, input_data.last_tool_output)
|
||||
for pending_call_id, count in pending_tool_calls.items()
|
||||
for _ in range(count)
|
||||
]
|
||||
|
||||
# If the SDM block only calls 1 tool at a time, this should not happen.
|
||||
if len(tool_output) > 1:
|
||||
logger.warning(
|
||||
f"[SmartDecisionMakerBlock-node_exec_id={node_exec_id}] "
|
||||
f"Multiple pending tool calls are prefilled using a single output. "
|
||||
f"Execution may not be accurate."
|
||||
)
|
||||
|
||||
# Fallback on adding tool output in the conversation history as user prompt.
|
||||
if len(tool_output) == 0 and input_data.last_tool_output:
|
||||
logger.warning(
|
||||
f"[SmartDecisionMakerBlock-node_exec_id={node_exec_id}] "
|
||||
f"No pending tool calls found. This may indicate an issue with the "
|
||||
f"conversation history, or an LLM calling two tools at the same time."
|
||||
)
|
||||
tool_output.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Last tool output: {json.dumps(input_data.last_tool_output)}",
|
||||
}
|
||||
)
|
||||
|
||||
prompt.extend(tool_output)
|
||||
|
||||
values = input_data.prompt_values
|
||||
if values:
|
||||
input_data.prompt = llm.fmt.format_string(input_data.prompt, values)
|
||||
input_data.sys_prompt = llm.fmt.format_string(input_data.sys_prompt, values)
|
||||
|
||||
prefix = "[Main Objective Prompt]: "
|
||||
|
||||
if input_data.sys_prompt and not any(
|
||||
p["role"] == "system" and p["content"].startswith(prefix) for p in prompt
|
||||
):
|
||||
prompt.append({"role": "system", "content": prefix + input_data.sys_prompt})
|
||||
|
||||
if input_data.prompt and not any(
|
||||
p["role"] == "user" and p["content"].startswith(prefix) for p in prompt
|
||||
):
|
||||
prompt.append({"role": "user", "content": prefix + input_data.prompt})
|
||||
|
||||
response = llm.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=input_data.model,
|
||||
prompt=prompt,
|
||||
json_format=False,
|
||||
max_tokens=input_data.max_tokens,
|
||||
tools=tool_functions,
|
||||
ollama_host=input_data.ollama_host,
|
||||
)
|
||||
|
||||
if not response.tool_calls:
|
||||
yield "finished", response.response
|
||||
return
|
||||
|
||||
for tool_call in response.tool_calls:
|
||||
tool_name = tool_call.function.name
|
||||
tool_args = json.loads(tool_call.function.arguments)
|
||||
|
||||
for arg_name, arg_value in tool_args.items():
|
||||
yield f"tools_^_{tool_name}_{arg_name}".lower(), arg_value
|
||||
|
||||
response.prompt.append(response.raw_response)
|
||||
yield "conversations", response.prompt
|
||||
@@ -1,97 +0,0 @@
|
||||
from backend.blocks.smartlead.models import (
|
||||
AddLeadsRequest,
|
||||
AddLeadsToCampaignResponse,
|
||||
CreateCampaignRequest,
|
||||
CreateCampaignResponse,
|
||||
SaveSequencesRequest,
|
||||
SaveSequencesResponse,
|
||||
)
|
||||
from backend.util.request import Requests
|
||||
|
||||
|
||||
class SmartLeadClient:
|
||||
"""Client for the SmartLead API"""
|
||||
|
||||
# This api is stupid and requires your api key in the url. DO NOT RAISE ERRORS FOR BAD REQUESTS.
|
||||
# FILTER OUT THE API KEY FROM THE ERROR MESSAGE.
|
||||
|
||||
API_URL = "https://server.smartlead.ai/api/v1"
|
||||
|
||||
def __init__(self, api_key: str):
|
||||
self.api_key = api_key
|
||||
self.requests = Requests()
|
||||
|
||||
def _add_auth_to_url(self, url: str) -> str:
|
||||
return f"{url}?api_key={self.api_key}"
|
||||
|
||||
def _handle_error(self, e: Exception) -> str:
|
||||
return e.__str__().replace(self.api_key, "API KEY")
|
||||
|
||||
def create_campaign(self, request: CreateCampaignRequest) -> CreateCampaignResponse:
|
||||
try:
|
||||
response = self.requests.post(
|
||||
self._add_auth_to_url(f"{self.API_URL}/campaigns/create"),
|
||||
json=request.model_dump(),
|
||||
)
|
||||
response_data = response.json()
|
||||
return CreateCampaignResponse(**response_data)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid response format: {str(e)}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to create campaign: {self._handle_error(e)}")
|
||||
|
||||
def add_leads_to_campaign(
|
||||
self, request: AddLeadsRequest
|
||||
) -> AddLeadsToCampaignResponse:
|
||||
try:
|
||||
response = self.requests.post(
|
||||
self._add_auth_to_url(
|
||||
f"{self.API_URL}/campaigns/{request.campaign_id}/leads"
|
||||
),
|
||||
json=request.model_dump(exclude={"campaign_id"}),
|
||||
)
|
||||
response_data = response.json()
|
||||
response_parsed = AddLeadsToCampaignResponse(**response_data)
|
||||
if not response_parsed.ok:
|
||||
raise ValueError(
|
||||
f"Failed to add leads to campaign: {response_parsed.error}"
|
||||
)
|
||||
return response_parsed
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid response format: {str(e)}")
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Failed to add leads to campaign: {self._handle_error(e)}"
|
||||
)
|
||||
|
||||
def save_campaign_sequences(
|
||||
self, campaign_id: int, request: SaveSequencesRequest
|
||||
) -> SaveSequencesResponse:
|
||||
"""
|
||||
Save sequences within a campaign.
|
||||
|
||||
Args:
|
||||
campaign_id: ID of the campaign to save sequences for
|
||||
request: SaveSequencesRequest containing the sequences configuration
|
||||
|
||||
Returns:
|
||||
SaveSequencesResponse with the result of the operation
|
||||
|
||||
Note:
|
||||
For variant_distribution_type:
|
||||
- MANUAL_EQUAL: Equally distributes variants across leads
|
||||
- AI_EQUAL: Requires winning_metric_property and lead_distribution_percentage
|
||||
- MANUAL_PERCENTAGE: Requires variant_distribution_percentage in seq_variants
|
||||
"""
|
||||
try:
|
||||
response = self.requests.post(
|
||||
self._add_auth_to_url(
|
||||
f"{self.API_URL}/campaigns/{campaign_id}/sequences"
|
||||
),
|
||||
json=request.model_dump(exclude_none=True),
|
||||
)
|
||||
return SaveSequencesResponse(**response.json())
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Failed to save campaign sequences: {e.__str__().replace(self.api_key, 'API KEY')}"
|
||||
)
|
||||
@@ -1,35 +0,0 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
SmartLeadCredentials = APIKeyCredentials
|
||||
SmartLeadCredentialsInput = CredentialsMetaInput[
|
||||
Literal[ProviderName.SMARTLEAD],
|
||||
Literal["api_key"],
|
||||
]
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="smartlead",
|
||||
api_key=SecretStr("mock-smartlead-api-key"),
|
||||
title="Mock SmartLead API key",
|
||||
expires_at=None,
|
||||
)
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
def SmartLeadCredentialsField() -> SmartLeadCredentialsInput:
|
||||
"""
|
||||
Creates a SmartLead credentials input on a block.
|
||||
"""
|
||||
return CredentialsField(
|
||||
description="The SmartLead integration can be used with an API Key.",
|
||||
)
|
||||
@@ -1,326 +0,0 @@
|
||||
from backend.blocks.smartlead._api import SmartLeadClient
|
||||
from backend.blocks.smartlead._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
SmartLeadCredentials,
|
||||
SmartLeadCredentialsInput,
|
||||
)
|
||||
from backend.blocks.smartlead.models import (
|
||||
AddLeadsRequest,
|
||||
AddLeadsToCampaignResponse,
|
||||
CreateCampaignRequest,
|
||||
CreateCampaignResponse,
|
||||
LeadInput,
|
||||
LeadUploadSettings,
|
||||
SaveSequencesRequest,
|
||||
SaveSequencesResponse,
|
||||
Sequence,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class CreateCampaignBlock(Block):
|
||||
"""Create a campaign in SmartLead"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
name: str = SchemaField(
|
||||
description="The name of the campaign",
|
||||
)
|
||||
credentials: SmartLeadCredentialsInput = SchemaField(
|
||||
description="SmartLead credentials",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
id: int = SchemaField(
|
||||
description="The ID of the created campaign",
|
||||
)
|
||||
name: str = SchemaField(
|
||||
description="The name of the created campaign",
|
||||
)
|
||||
created_at: str = SchemaField(
|
||||
description="The date and time the campaign was created",
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the search failed",
|
||||
default="",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8865699f-9188-43c4-89b0-79c84cfaa03e",
|
||||
description="Create a campaign in SmartLead",
|
||||
categories={BlockCategory.CRM},
|
||||
input_schema=CreateCampaignBlock.Input,
|
||||
output_schema=CreateCampaignBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={"name": "Test Campaign", "credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_output=[
|
||||
(
|
||||
"id",
|
||||
1,
|
||||
),
|
||||
(
|
||||
"name",
|
||||
"Test Campaign",
|
||||
),
|
||||
(
|
||||
"created_at",
|
||||
"2024-01-01T00:00:00Z",
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"create_campaign": lambda name, credentials: CreateCampaignResponse(
|
||||
ok=True,
|
||||
id=1,
|
||||
name=name,
|
||||
created_at="2024-01-01T00:00:00Z",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_campaign(
|
||||
name: str, credentials: SmartLeadCredentials
|
||||
) -> CreateCampaignResponse:
|
||||
client = SmartLeadClient(credentials.api_key.get_secret_value())
|
||||
return client.create_campaign(CreateCampaignRequest(name=name))
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: SmartLeadCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
response = self.create_campaign(input_data.name, credentials)
|
||||
|
||||
yield "id", response.id
|
||||
yield "name", response.name
|
||||
yield "created_at", response.created_at
|
||||
if not response.ok:
|
||||
yield "error", "Failed to create campaign"
|
||||
|
||||
|
||||
class AddLeadToCampaignBlock(Block):
|
||||
"""Add a lead to a campaign in SmartLead"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
campaign_id: int = SchemaField(
|
||||
description="The ID of the campaign to add the lead to",
|
||||
)
|
||||
lead_list: list[LeadInput] = SchemaField(
|
||||
description="An array of JSON objects, each representing a lead's details. Can hold max 100 leads.",
|
||||
max_length=100,
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
settings: LeadUploadSettings = SchemaField(
|
||||
description="Settings for lead upload",
|
||||
default=LeadUploadSettings(),
|
||||
)
|
||||
credentials: SmartLeadCredentialsInput = SchemaField(
|
||||
description="SmartLead credentials",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
campaign_id: int = SchemaField(
|
||||
description="The ID of the campaign the lead was added to (passed through)",
|
||||
)
|
||||
upload_count: int = SchemaField(
|
||||
description="The number of leads added to the campaign",
|
||||
)
|
||||
already_added_to_campaign: int = SchemaField(
|
||||
description="The number of leads that were already added to the campaign",
|
||||
)
|
||||
duplicate_count: int = SchemaField(
|
||||
description="The number of emails that were duplicates",
|
||||
)
|
||||
invalid_email_count: int = SchemaField(
|
||||
description="The number of emails that were invalidly formatted",
|
||||
)
|
||||
is_lead_limit_exhausted: bool = SchemaField(
|
||||
description="Whether the lead limit was exhausted",
|
||||
)
|
||||
lead_import_stopped_count: int = SchemaField(
|
||||
description="The number of leads that were not added to the campaign because the lead import was stopped",
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the lead was not added to the campaign",
|
||||
default="",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="fb8106a4-1a8f-42f9-a502-f6d07e6fe0ec",
|
||||
description="Add a lead to a campaign in SmartLead",
|
||||
categories={BlockCategory.CRM},
|
||||
input_schema=AddLeadToCampaignBlock.Input,
|
||||
output_schema=AddLeadToCampaignBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"campaign_id": 1,
|
||||
"lead_list": [],
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"campaign_id",
|
||||
1,
|
||||
),
|
||||
(
|
||||
"upload_count",
|
||||
1,
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"add_leads_to_campaign": lambda campaign_id, lead_list, credentials: AddLeadsToCampaignResponse(
|
||||
ok=True,
|
||||
upload_count=1,
|
||||
already_added_to_campaign=0,
|
||||
duplicate_count=0,
|
||||
invalid_email_count=0,
|
||||
is_lead_limit_exhausted=False,
|
||||
lead_import_stopped_count=0,
|
||||
error="",
|
||||
total_leads=1,
|
||||
block_count=0,
|
||||
invalid_emails=[],
|
||||
unsubscribed_leads=[],
|
||||
bounce_count=0,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def add_leads_to_campaign(
|
||||
campaign_id: int, lead_list: list[LeadInput], credentials: SmartLeadCredentials
|
||||
) -> AddLeadsToCampaignResponse:
|
||||
client = SmartLeadClient(credentials.api_key.get_secret_value())
|
||||
return client.add_leads_to_campaign(
|
||||
AddLeadsRequest(
|
||||
campaign_id=campaign_id,
|
||||
lead_list=lead_list,
|
||||
settings=LeadUploadSettings(
|
||||
ignore_global_block_list=False,
|
||||
ignore_unsubscribe_list=False,
|
||||
ignore_community_bounce_list=False,
|
||||
ignore_duplicate_leads_in_other_campaign=False,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: SmartLeadCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
response = self.add_leads_to_campaign(
|
||||
input_data.campaign_id, input_data.lead_list, credentials
|
||||
)
|
||||
|
||||
yield "campaign_id", input_data.campaign_id
|
||||
yield "upload_count", response.upload_count
|
||||
if response.already_added_to_campaign:
|
||||
yield "already_added_to_campaign", response.already_added_to_campaign
|
||||
if response.duplicate_count:
|
||||
yield "duplicate_count", response.duplicate_count
|
||||
if response.invalid_email_count:
|
||||
yield "invalid_email_count", response.invalid_email_count
|
||||
if response.is_lead_limit_exhausted:
|
||||
yield "is_lead_limit_exhausted", response.is_lead_limit_exhausted
|
||||
if response.lead_import_stopped_count:
|
||||
yield "lead_import_stopped_count", response.lead_import_stopped_count
|
||||
if response.error:
|
||||
yield "error", response.error
|
||||
if not response.ok:
|
||||
yield "error", "Failed to add leads to campaign"
|
||||
|
||||
|
||||
class SaveCampaignSequencesBlock(Block):
|
||||
"""Save sequences within a campaign"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
campaign_id: int = SchemaField(
|
||||
description="The ID of the campaign to save sequences for",
|
||||
)
|
||||
sequences: list[Sequence] = SchemaField(
|
||||
description="The sequences to save",
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
credentials: SmartLeadCredentialsInput = SchemaField(
|
||||
description="SmartLead credentials",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
data: dict | str | None = SchemaField(
|
||||
description="Data from the API",
|
||||
default=None,
|
||||
)
|
||||
message: str = SchemaField(
|
||||
description="Message from the API",
|
||||
default="",
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the sequences were not saved",
|
||||
default="",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="e7d9f41c-dc10-4f39-98ba-a432abd128c0",
|
||||
description="Save sequences within a campaign",
|
||||
categories={BlockCategory.CRM},
|
||||
input_schema=SaveCampaignSequencesBlock.Input,
|
||||
output_schema=SaveCampaignSequencesBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"campaign_id": 1,
|
||||
"sequences": [],
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"message",
|
||||
"Sequences saved successfully",
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"save_campaign_sequences": lambda campaign_id, sequences, credentials: SaveSequencesResponse(
|
||||
ok=True,
|
||||
message="Sequences saved successfully",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def save_campaign_sequences(
|
||||
campaign_id: int, sequences: list[Sequence], credentials: SmartLeadCredentials
|
||||
) -> SaveSequencesResponse:
|
||||
client = SmartLeadClient(credentials.api_key.get_secret_value())
|
||||
return client.save_campaign_sequences(
|
||||
campaign_id=campaign_id, request=SaveSequencesRequest(sequences=sequences)
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: SmartLeadCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
response = self.save_campaign_sequences(
|
||||
input_data.campaign_id, input_data.sequences, credentials
|
||||
)
|
||||
|
||||
if response.data:
|
||||
yield "data", response.data
|
||||
if response.message:
|
||||
yield "message", response.message
|
||||
if response.error:
|
||||
yield "error", response.error
|
||||
if not response.ok:
|
||||
yield "error", "Failed to save sequences"
|
||||
@@ -1,147 +0,0 @@
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class CreateCampaignResponse(BaseModel):
|
||||
ok: bool
|
||||
id: int
|
||||
name: str
|
||||
created_at: str
|
||||
|
||||
|
||||
class CreateCampaignRequest(BaseModel):
|
||||
name: str
|
||||
client_id: str | None = None
|
||||
|
||||
|
||||
class AddLeadsToCampaignResponse(BaseModel):
|
||||
ok: bool
|
||||
upload_count: int
|
||||
total_leads: int
|
||||
block_count: int
|
||||
duplicate_count: int
|
||||
invalid_email_count: int
|
||||
invalid_emails: list[str]
|
||||
already_added_to_campaign: int
|
||||
unsubscribed_leads: list[str]
|
||||
is_lead_limit_exhausted: bool
|
||||
lead_import_stopped_count: int
|
||||
bounce_count: int
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class LeadCustomFields(BaseModel):
|
||||
"""Custom fields for a lead (max 20 fields)"""
|
||||
|
||||
fields: dict[str, str] = SchemaField(
|
||||
description="Custom fields for a lead (max 20 fields)",
|
||||
max_length=20,
|
||||
default={},
|
||||
)
|
||||
|
||||
|
||||
class LeadInput(BaseModel):
|
||||
"""Single lead input data"""
|
||||
|
||||
first_name: str
|
||||
last_name: str
|
||||
email: str
|
||||
phone_number: str | None = None # Changed from int to str for phone numbers
|
||||
company_name: str | None = None
|
||||
website: str | None = None
|
||||
location: str | None = None
|
||||
custom_fields: LeadCustomFields | None = None
|
||||
linkedin_profile: str | None = None
|
||||
company_url: str | None = None
|
||||
|
||||
|
||||
class LeadUploadSettings(BaseModel):
|
||||
"""Settings for lead upload"""
|
||||
|
||||
ignore_global_block_list: bool = SchemaField(
|
||||
description="Ignore the global block list",
|
||||
default=False,
|
||||
)
|
||||
ignore_unsubscribe_list: bool = SchemaField(
|
||||
description="Ignore the unsubscribe list",
|
||||
default=False,
|
||||
)
|
||||
ignore_community_bounce_list: bool = SchemaField(
|
||||
description="Ignore the community bounce list",
|
||||
default=False,
|
||||
)
|
||||
ignore_duplicate_leads_in_other_campaign: bool = SchemaField(
|
||||
description="Ignore duplicate leads in other campaigns",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class AddLeadsRequest(BaseModel):
|
||||
"""Request body for adding leads to a campaign"""
|
||||
|
||||
lead_list: list[LeadInput] = SchemaField(
|
||||
description="List of leads to add to the campaign",
|
||||
max_length=100,
|
||||
default=[],
|
||||
)
|
||||
settings: LeadUploadSettings
|
||||
campaign_id: int
|
||||
|
||||
|
||||
class VariantDistributionType(str, Enum):
|
||||
MANUAL_EQUAL = "MANUAL_EQUAL"
|
||||
MANUAL_PERCENTAGE = "MANUAL_PERCENTAGE"
|
||||
AI_EQUAL = "AI_EQUAL"
|
||||
|
||||
|
||||
class WinningMetricProperty(str, Enum):
|
||||
OPEN_RATE = "OPEN_RATE"
|
||||
CLICK_RATE = "CLICK_RATE"
|
||||
REPLY_RATE = "REPLY_RATE"
|
||||
POSITIVE_REPLY_RATE = "POSITIVE_REPLY_RATE"
|
||||
|
||||
|
||||
class SequenceDelayDetails(BaseModel):
|
||||
delay_in_days: int
|
||||
|
||||
|
||||
class SequenceVariant(BaseModel):
|
||||
subject: str
|
||||
email_body: str
|
||||
variant_label: str
|
||||
id: int | None = None # Optional for creation, required for updates
|
||||
variant_distribution_percentage: int | None = None
|
||||
|
||||
|
||||
class Sequence(BaseModel):
|
||||
seq_number: int = SchemaField(
|
||||
description="The sequence number",
|
||||
default=1,
|
||||
)
|
||||
seq_delay_details: SequenceDelayDetails
|
||||
id: int | None = None
|
||||
variant_distribution_type: VariantDistributionType | None = None
|
||||
lead_distribution_percentage: int | None = SchemaField(
|
||||
None, ge=20, le=100
|
||||
) # >= 20% for fair calculation
|
||||
winning_metric_property: WinningMetricProperty | None = None
|
||||
seq_variants: list[SequenceVariant] | None = None
|
||||
subject: str = "" # blank makes the follow up in the same thread
|
||||
email_body: str | None = None
|
||||
|
||||
|
||||
class SaveSequencesRequest(BaseModel):
|
||||
sequences: list[Sequence]
|
||||
|
||||
|
||||
class SaveSequencesResponse(BaseModel):
|
||||
ok: bool
|
||||
message: str = SchemaField(
|
||||
description="Message from the API",
|
||||
default="",
|
||||
)
|
||||
data: dict | str | None = None
|
||||
error: str | None = None
|
||||
@@ -78,7 +78,7 @@ class CreateTalkingAvatarVideoBlock(Block):
|
||||
super().__init__(
|
||||
id="98c6f503-8c47-4b1c-a96d-351fc7c87dab",
|
||||
description="This block integrates with D-ID to create video clips and retrieve their URLs.",
|
||||
categories={BlockCategory.AI, BlockCategory.MULTIMEDIA},
|
||||
categories={BlockCategory.AI},
|
||||
input_schema=CreateTalkingAvatarVideoBlock.Input,
|
||||
output_schema=CreateTalkingAvatarVideoBlock.Output,
|
||||
test_input={
|
||||
|
||||
@@ -76,8 +76,6 @@ class ExtractTextInformationBlock(Block):
|
||||
class Output(BlockSchema):
|
||||
positive: str = SchemaField(description="Extracted text")
|
||||
negative: str = SchemaField(description="Original text")
|
||||
matched_results: list[str] = SchemaField(description="List of matched results")
|
||||
matched_count: int = SchemaField(description="Number of matched results")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -105,31 +103,13 @@ class ExtractTextInformationBlock(Block):
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
# Test case 1
|
||||
("positive", "World!"),
|
||||
("matched_results", ["World!"]),
|
||||
("matched_count", 1),
|
||||
# Test case 2
|
||||
("positive", "Hello, World!"),
|
||||
("matched_results", ["Hello, World!"]),
|
||||
("matched_count", 1),
|
||||
# Test case 3
|
||||
("negative", "Hello, World!"),
|
||||
("matched_results", []),
|
||||
("matched_count", 0),
|
||||
# Test case 4
|
||||
("positive", "Hello,"),
|
||||
("matched_results", ["Hello,"]),
|
||||
("matched_count", 1),
|
||||
# Test case 5
|
||||
("positive", "World!!"),
|
||||
("matched_results", ["World!!"]),
|
||||
("matched_count", 1),
|
||||
# Test case 6
|
||||
("positive", "World!!"),
|
||||
("positive", "Earth!!"),
|
||||
("matched_results", ["World!!", "Earth!!"]),
|
||||
("matched_count", 2),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -150,16 +130,13 @@ class ExtractTextInformationBlock(Block):
|
||||
for match in re.finditer(input_data.pattern, txt, flags)
|
||||
if input_data.group <= len(match.groups())
|
||||
]
|
||||
if not input_data.find_all:
|
||||
matches = matches[:1]
|
||||
for match in matches:
|
||||
yield "positive", match
|
||||
if not input_data.find_all:
|
||||
return
|
||||
if not matches:
|
||||
yield "negative", input_data.text
|
||||
|
||||
yield "matched_results", matches
|
||||
yield "matched_count", len(matches)
|
||||
|
||||
|
||||
class FillTextTemplateBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
@@ -235,71 +212,3 @@ class CombineTextsBlock(Block):
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
combined_text = input_data.delimiter.join(input_data.input)
|
||||
yield "output", combined_text
|
||||
|
||||
|
||||
class TextSplitBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
text: str = SchemaField(description="The text to split.")
|
||||
delimiter: str = SchemaField(description="The delimiter to split the text by.")
|
||||
strip: bool = SchemaField(
|
||||
description="Whether to strip the text.", default=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
texts: list[str] = SchemaField(
|
||||
description="The text split into a list of strings."
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d5ea33c8-a575-477a-b42f-2fe3be5055ec",
|
||||
description="This block is used to split a text into a list of strings.",
|
||||
categories={BlockCategory.TEXT},
|
||||
input_schema=TextSplitBlock.Input,
|
||||
output_schema=TextSplitBlock.Output,
|
||||
test_input=[
|
||||
{"text": "Hello, World!", "delimiter": ","},
|
||||
{"text": "Hello, World!", "delimiter": ",", "strip": False},
|
||||
],
|
||||
test_output=[
|
||||
("texts", ["Hello", "World!"]),
|
||||
("texts", ["Hello", " World!"]),
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
if len(input_data.text) == 0:
|
||||
yield "texts", []
|
||||
else:
|
||||
texts = input_data.text.split(input_data.delimiter)
|
||||
if input_data.strip:
|
||||
texts = [text.strip() for text in texts]
|
||||
yield "texts", texts
|
||||
|
||||
|
||||
class TextReplaceBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
text: str = SchemaField(description="The text to replace.")
|
||||
old: str = SchemaField(description="The old text to replace.")
|
||||
new: str = SchemaField(description="The new text to replace with.")
|
||||
|
||||
class Output(BlockSchema):
|
||||
output: str = SchemaField(description="The text with the replaced text.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="7e7c87ab-3469-4bcc-9abe-67705091b713",
|
||||
description="This block is used to replace a text with a new text.",
|
||||
categories={BlockCategory.TEXT},
|
||||
input_schema=TextReplaceBlock.Input,
|
||||
output_schema=TextReplaceBlock.Output,
|
||||
test_input=[
|
||||
{"text": "Hello, World!", "old": "Hello", "new": "Hi"},
|
||||
],
|
||||
test_output=[
|
||||
("output", "Hi, World!"),
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
yield "output", input_data.text.replace(input_data.old, input_data.new)
|
||||
|
||||
@@ -53,7 +53,7 @@ class UnrealTextToSpeechBlock(Block):
|
||||
super().__init__(
|
||||
id="4ff1ff6d-cc40-4caa-ae69-011daa20c378",
|
||||
description="Converts text to speech using the Unreal Speech API",
|
||||
categories={BlockCategory.AI, BlockCategory.TEXT, BlockCategory.MULTIMEDIA},
|
||||
categories={BlockCategory.AI, BlockCategory.TEXT},
|
||||
input_schema=UnrealTextToSpeechBlock.Input,
|
||||
output_schema=UnrealTextToSpeechBlock.Output,
|
||||
test_input={
|
||||
|
||||
@@ -156,10 +156,6 @@ class CountdownTimerBlock(Block):
|
||||
days: Union[int, str] = SchemaField(
|
||||
advanced=False, description="Duration in days", default=0
|
||||
)
|
||||
repeat: int = SchemaField(
|
||||
description="Number of times to repeat the timer",
|
||||
default=1,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
output_message: Any = SchemaField(
|
||||
@@ -191,6 +187,5 @@ class CountdownTimerBlock(Block):
|
||||
|
||||
total_seconds = seconds + minutes * 60 + hours * 3600 + days * 86400
|
||||
|
||||
for _ in range(input_data.repeat):
|
||||
time.sleep(total_seconds)
|
||||
yield "output_message", input_data.input_message
|
||||
time.sleep(total_seconds)
|
||||
yield "output_message", input_data.input_message
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import (
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
OAuth2Credentials,
|
||||
ProviderName,
|
||||
)
|
||||
from backend.integrations.oauth.todoist import TodoistOAuthHandler
|
||||
from backend.util.settings import Secrets
|
||||
|
||||
secrets = Secrets()
|
||||
TODOIST_OAUTH_IS_CONFIGURED = bool(
|
||||
secrets.todoist_client_id and secrets.todoist_client_secret
|
||||
)
|
||||
|
||||
TodoistCredentials = OAuth2Credentials
|
||||
TodoistCredentialsInput = CredentialsMetaInput[
|
||||
Literal[ProviderName.TODOIST], Literal["oauth2"]
|
||||
]
|
||||
|
||||
|
||||
def TodoistCredentialsField(scopes: list[str]) -> TodoistCredentialsInput:
|
||||
"""
|
||||
Creates a Todoist credentials input on a block.
|
||||
|
||||
Params:
|
||||
scopes: The authorization scopes needed for the block to work.
|
||||
"""
|
||||
return CredentialsField(
|
||||
required_scopes=set(TodoistOAuthHandler.DEFAULT_SCOPES + scopes),
|
||||
description="The Todoist integration requires OAuth2 authentication.",
|
||||
)
|
||||
|
||||
|
||||
TEST_CREDENTIALS = OAuth2Credentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="todoist",
|
||||
access_token=SecretStr("mock-todoist-access-token"),
|
||||
refresh_token=None,
|
||||
access_token_expires_at=None,
|
||||
scopes=[
|
||||
"task:add",
|
||||
"data:read",
|
||||
"data:read_write",
|
||||
"data:delete",
|
||||
"project:delete",
|
||||
],
|
||||
title="Mock Todoist OAuth2 Credentials",
|
||||
username="mock-todoist-username",
|
||||
refresh_token_expires_at=None,
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class Colors(Enum):
|
||||
berry_red = "berry_red"
|
||||
red = "red"
|
||||
orange = "orange"
|
||||
yellow = "yellow"
|
||||
olive_green = "olive_green"
|
||||
lime_green = "lime_green"
|
||||
green = "green"
|
||||
mint_green = "mint_green"
|
||||
teal = "teal"
|
||||
sky_blue = "sky_blue"
|
||||
light_blue = "light_blue"
|
||||
blue = "blue"
|
||||
grape = "grape"
|
||||
violet = "violet"
|
||||
lavender = "lavender"
|
||||
magenta = "magenta"
|
||||
salmon = "salmon"
|
||||
charcoal = "charcoal"
|
||||
grey = "grey"
|
||||
taupe = "taupe"
|
||||
@@ -1,439 +0,0 @@
|
||||
from typing import Literal, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from todoist_api_python.api import TodoistAPI
|
||||
from typing_extensions import Optional
|
||||
|
||||
from backend.blocks.todoist._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
TodoistCredentials,
|
||||
TodoistCredentialsField,
|
||||
TodoistCredentialsInput,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class TaskId(BaseModel):
|
||||
discriminator: Literal["task"]
|
||||
task_id: str
|
||||
|
||||
|
||||
class ProjectId(BaseModel):
|
||||
discriminator: Literal["project"]
|
||||
project_id: str
|
||||
|
||||
|
||||
class TodoistCreateCommentBlock(Block):
|
||||
"""Creates a new comment on a Todoist task or project"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TodoistCredentialsInput = TodoistCredentialsField([])
|
||||
content: str = SchemaField(description="Comment content")
|
||||
id_type: Union[TaskId, ProjectId] = SchemaField(
|
||||
discriminator="discriminator",
|
||||
description="Specify either task_id or project_id to comment on",
|
||||
default=TaskId(discriminator="task", task_id=""),
|
||||
advanced=False,
|
||||
)
|
||||
attachment: Optional[dict] = SchemaField(
|
||||
description="Optional file attachment", default=None
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
id: str = SchemaField(description="ID of created comment")
|
||||
content: str = SchemaField(description="Comment content")
|
||||
posted_at: str = SchemaField(description="Comment timestamp")
|
||||
task_id: Optional[str] = SchemaField(
|
||||
description="Associated task ID", default=None
|
||||
)
|
||||
project_id: Optional[str] = SchemaField(
|
||||
description="Associated project ID", default=None
|
||||
)
|
||||
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="1bba7e54-2310-4a31-8e6f-54d5f9ab7459",
|
||||
description="Creates a new comment on a Todoist task or project",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=TodoistCreateCommentBlock.Input,
|
||||
output_schema=TodoistCreateCommentBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"content": "Test comment",
|
||||
"id_type": {"discriminator": "task", "task_id": "2995104339"},
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("id", "2992679862"),
|
||||
("content", "Test comment"),
|
||||
("posted_at", "2016-09-22T07:00:00.000000Z"),
|
||||
("task_id", "2995104339"),
|
||||
("project_id", None),
|
||||
],
|
||||
test_mock={
|
||||
"create_comment": lambda content, credentials, task_id=None, project_id=None, attachment=None: {
|
||||
"id": "2992679862",
|
||||
"content": "Test comment",
|
||||
"posted_at": "2016-09-22T07:00:00.000000Z",
|
||||
"task_id": "2995104339",
|
||||
"project_id": None,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_comment(
|
||||
credentials: TodoistCredentials,
|
||||
content: str,
|
||||
task_id: Optional[str] = None,
|
||||
project_id: Optional[str] = None,
|
||||
attachment: Optional[dict] = None,
|
||||
):
|
||||
try:
|
||||
api = TodoistAPI(credentials.access_token.get_secret_value())
|
||||
comment = api.add_comment(
|
||||
content=content,
|
||||
task_id=task_id,
|
||||
project_id=project_id,
|
||||
attachment=attachment,
|
||||
)
|
||||
return comment.__dict__
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TodoistCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
task_id = None
|
||||
project_id = None
|
||||
|
||||
if isinstance(input_data.id_type, TaskId):
|
||||
task_id = input_data.id_type.task_id
|
||||
else:
|
||||
project_id = input_data.id_type.project_id
|
||||
|
||||
comment_data = self.create_comment(
|
||||
credentials,
|
||||
input_data.content,
|
||||
task_id=task_id,
|
||||
project_id=project_id,
|
||||
attachment=input_data.attachment,
|
||||
)
|
||||
|
||||
if comment_data:
|
||||
yield "id", comment_data["id"]
|
||||
yield "content", comment_data["content"]
|
||||
yield "posted_at", comment_data["posted_at"]
|
||||
yield "task_id", comment_data["task_id"]
|
||||
yield "project_id", comment_data["project_id"]
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class TodoistGetCommentsBlock(Block):
|
||||
"""Get all comments for a Todoist task or project"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TodoistCredentialsInput = TodoistCredentialsField([])
|
||||
id_type: Union[TaskId, ProjectId] = SchemaField(
|
||||
discriminator="discriminator",
|
||||
description="Specify either task_id or project_id to get comments for",
|
||||
default=TaskId(discriminator="task", task_id=""),
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
comments: list = SchemaField(description="List of comments")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="9972d8ae-ddf2-11ef-a9b8-32d3674e8b7e",
|
||||
description="Get all comments for a Todoist task or project",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=TodoistGetCommentsBlock.Input,
|
||||
output_schema=TodoistGetCommentsBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"id_type": {"discriminator": "task", "task_id": "2995104339"},
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"comments",
|
||||
[
|
||||
{
|
||||
"id": "2992679862",
|
||||
"content": "Test comment",
|
||||
"posted_at": "2016-09-22T07:00:00.000000Z",
|
||||
"task_id": "2995104339",
|
||||
"project_id": None,
|
||||
"attachment": None,
|
||||
}
|
||||
],
|
||||
)
|
||||
],
|
||||
test_mock={
|
||||
"get_comments": lambda credentials, task_id=None, project_id=None: [
|
||||
{
|
||||
"id": "2992679862",
|
||||
"content": "Test comment",
|
||||
"posted_at": "2016-09-22T07:00:00.000000Z",
|
||||
"task_id": "2995104339",
|
||||
"project_id": None,
|
||||
"attachment": None,
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_comments(
|
||||
credentials: TodoistCredentials,
|
||||
task_id: Optional[str] = None,
|
||||
project_id: Optional[str] = None,
|
||||
):
|
||||
try:
|
||||
api = TodoistAPI(credentials.access_token.get_secret_value())
|
||||
comments = api.get_comments(task_id=task_id, project_id=project_id)
|
||||
return [comment.__dict__ for comment in comments]
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TodoistCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
task_id = None
|
||||
project_id = None
|
||||
|
||||
if isinstance(input_data.id_type, TaskId):
|
||||
task_id = input_data.id_type.task_id
|
||||
else:
|
||||
project_id = input_data.id_type.project_id
|
||||
|
||||
comments = self.get_comments(
|
||||
credentials, task_id=task_id, project_id=project_id
|
||||
)
|
||||
|
||||
yield "comments", comments
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class TodoistGetCommentBlock(Block):
|
||||
"""Get a single comment from Todoist using comment ID"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TodoistCredentialsInput = TodoistCredentialsField([])
|
||||
comment_id: str = SchemaField(description="Comment ID to retrieve")
|
||||
|
||||
class Output(BlockSchema):
|
||||
content: str = SchemaField(description="Comment content")
|
||||
id: str = SchemaField(description="Comment ID")
|
||||
posted_at: str = SchemaField(description="Comment timestamp")
|
||||
project_id: Optional[str] = SchemaField(
|
||||
description="Associated project ID", default=None
|
||||
)
|
||||
task_id: Optional[str] = SchemaField(
|
||||
description="Associated task ID", default=None
|
||||
)
|
||||
attachment: Optional[dict] = SchemaField(
|
||||
description="Optional file attachment", default=None
|
||||
)
|
||||
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a809d264-ddf2-11ef-9764-32d3674e8b7e",
|
||||
description="Get a single comment from Todoist",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=TodoistGetCommentBlock.Input,
|
||||
output_schema=TodoistGetCommentBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"comment_id": "2992679862",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("content", "Test comment"),
|
||||
("id", "2992679862"),
|
||||
("posted_at", "2016-09-22T07:00:00.000000Z"),
|
||||
("project_id", None),
|
||||
("task_id", "2995104339"),
|
||||
("attachment", None),
|
||||
],
|
||||
test_mock={
|
||||
"get_comment": lambda credentials, comment_id: {
|
||||
"content": "Test comment",
|
||||
"id": "2992679862",
|
||||
"posted_at": "2016-09-22T07:00:00.000000Z",
|
||||
"project_id": None,
|
||||
"task_id": "2995104339",
|
||||
"attachment": None,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_comment(credentials: TodoistCredentials, comment_id: str):
|
||||
try:
|
||||
api = TodoistAPI(credentials.access_token.get_secret_value())
|
||||
comment = api.get_comment(comment_id=comment_id)
|
||||
return comment.__dict__
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TodoistCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
comment_data = self.get_comment(
|
||||
credentials, comment_id=input_data.comment_id
|
||||
)
|
||||
|
||||
if comment_data:
|
||||
yield "content", comment_data["content"]
|
||||
yield "id", comment_data["id"]
|
||||
yield "posted_at", comment_data["posted_at"]
|
||||
yield "project_id", comment_data["project_id"]
|
||||
yield "task_id", comment_data["task_id"]
|
||||
yield "attachment", comment_data["attachment"]
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class TodoistUpdateCommentBlock(Block):
|
||||
"""Updates a Todoist comment"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TodoistCredentialsInput = TodoistCredentialsField([])
|
||||
comment_id: str = SchemaField(description="Comment ID to update")
|
||||
content: str = SchemaField(description="New content for the comment")
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(description="Whether the update was successful")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b773c520-ddf2-11ef-9f34-32d3674e8b7e",
|
||||
description="Updates a Todoist comment",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=TodoistUpdateCommentBlock.Input,
|
||||
output_schema=TodoistUpdateCommentBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"comment_id": "2992679862",
|
||||
"content": "Need one bottle of milk",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("success", True)],
|
||||
test_mock={"update_comment": lambda credentials, comment_id, content: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def update_comment(credentials: TodoistCredentials, comment_id: str, content: str):
|
||||
try:
|
||||
api = TodoistAPI(credentials.access_token.get_secret_value())
|
||||
api.update_comment(comment_id=comment_id, content=content)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TodoistCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = self.update_comment(
|
||||
credentials,
|
||||
comment_id=input_data.comment_id,
|
||||
content=input_data.content,
|
||||
)
|
||||
|
||||
yield "success", success
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class TodoistDeleteCommentBlock(Block):
|
||||
"""Deletes a Todoist comment"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TodoistCredentialsInput = TodoistCredentialsField([])
|
||||
comment_id: str = SchemaField(description="Comment ID to delete")
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(description="Whether the deletion was successful")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="bda4c020-ddf2-11ef-b114-32d3674e8b7e",
|
||||
description="Deletes a Todoist comment",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=TodoistDeleteCommentBlock.Input,
|
||||
output_schema=TodoistDeleteCommentBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"comment_id": "2992679862",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("success", True)],
|
||||
test_mock={"delete_comment": lambda credentials, comment_id: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def delete_comment(credentials: TodoistCredentials, comment_id: str):
|
||||
try:
|
||||
api = TodoistAPI(credentials.access_token.get_secret_value())
|
||||
success = api.delete_comment(comment_id=comment_id)
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TodoistCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = self.delete_comment(credentials, comment_id=input_data.comment_id)
|
||||
|
||||
yield "success", success
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -1,557 +0,0 @@
|
||||
from todoist_api_python.api import TodoistAPI
|
||||
from typing_extensions import Optional
|
||||
|
||||
from backend.blocks.todoist._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
TodoistCredentials,
|
||||
TodoistCredentialsField,
|
||||
TodoistCredentialsInput,
|
||||
)
|
||||
from backend.blocks.todoist._types import Colors
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class TodoistCreateLabelBlock(Block):
|
||||
"""Creates a new label in Todoist"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TodoistCredentialsInput = TodoistCredentialsField([])
|
||||
name: str = SchemaField(description="Name of the label")
|
||||
order: Optional[int] = SchemaField(description="Label order", default=None)
|
||||
color: Optional[Colors] = SchemaField(
|
||||
description="The color of the label icon", default=Colors.charcoal
|
||||
)
|
||||
is_favorite: bool = SchemaField(
|
||||
description="Whether the label is a favorite", default=False
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
id: str = SchemaField(description="ID of the created label")
|
||||
name: str = SchemaField(description="Name of the label")
|
||||
color: str = SchemaField(description="Color of the label")
|
||||
order: int = SchemaField(description="Label order")
|
||||
is_favorite: bool = SchemaField(description="Favorite status")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="7288a968-de14-11ef-8997-32d3674e8b7e",
|
||||
description="Creates a new label in Todoist, It will not work if same name already exists",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=TodoistCreateLabelBlock.Input,
|
||||
output_schema=TodoistCreateLabelBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"name": "Test Label",
|
||||
"color": Colors.charcoal.value,
|
||||
"order": 1,
|
||||
"is_favorite": False,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("id", "2156154810"),
|
||||
("name", "Test Label"),
|
||||
("color", "charcoal"),
|
||||
("order", 1),
|
||||
("is_favorite", False),
|
||||
],
|
||||
test_mock={
|
||||
"create_label": lambda *args, **kwargs: {
|
||||
"id": "2156154810",
|
||||
"name": "Test Label",
|
||||
"color": "charcoal",
|
||||
"order": 1,
|
||||
"is_favorite": False,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_label(credentials: TodoistCredentials, name: str, **kwargs):
|
||||
try:
|
||||
api = TodoistAPI(credentials.access_token.get_secret_value())
|
||||
label = api.add_label(name=name, **kwargs)
|
||||
return label.__dict__
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TodoistCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
label_args = {
|
||||
"order": input_data.order,
|
||||
"color": (
|
||||
input_data.color.value if input_data.color is not None else None
|
||||
),
|
||||
"is_favorite": input_data.is_favorite,
|
||||
}
|
||||
|
||||
label_data = self.create_label(
|
||||
credentials,
|
||||
input_data.name,
|
||||
**{k: v for k, v in label_args.items() if v is not None},
|
||||
)
|
||||
|
||||
if label_data:
|
||||
yield "id", label_data["id"]
|
||||
yield "name", label_data["name"]
|
||||
yield "color", label_data["color"]
|
||||
yield "order", label_data["order"]
|
||||
yield "is_favorite", label_data["is_favorite"]
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class TodoistListLabelsBlock(Block):
|
||||
"""Gets all personal labels from Todoist"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TodoistCredentialsInput = TodoistCredentialsField([])
|
||||
|
||||
class Output(BlockSchema):
|
||||
labels: list = SchemaField(description="List of complete label data")
|
||||
label_ids: list = SchemaField(description="List of label IDs")
|
||||
label_names: list = SchemaField(description="List of label names")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="776dd750-de14-11ef-b927-32d3674e8b7e",
|
||||
description="Gets all personal labels from Todoist",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=TodoistListLabelsBlock.Input,
|
||||
output_schema=TodoistListLabelsBlock.Output,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"labels",
|
||||
[
|
||||
{
|
||||
"id": "2156154810",
|
||||
"name": "Test Label",
|
||||
"color": "charcoal",
|
||||
"order": 1,
|
||||
"is_favorite": False,
|
||||
}
|
||||
],
|
||||
),
|
||||
("label_ids", ["2156154810"]),
|
||||
("label_names", ["Test Label"]),
|
||||
],
|
||||
test_mock={
|
||||
"get_labels": lambda *args, **kwargs: [
|
||||
{
|
||||
"id": "2156154810",
|
||||
"name": "Test Label",
|
||||
"color": "charcoal",
|
||||
"order": 1,
|
||||
"is_favorite": False,
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_labels(credentials: TodoistCredentials):
|
||||
try:
|
||||
api = TodoistAPI(credentials.access_token.get_secret_value())
|
||||
labels = api.get_labels()
|
||||
return [label.__dict__ for label in labels]
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TodoistCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
labels = self.get_labels(credentials)
|
||||
yield "labels", labels
|
||||
yield "label_ids", [label["id"] for label in labels]
|
||||
yield "label_names", [label["name"] for label in labels]
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class TodoistGetLabelBlock(Block):
|
||||
"""Gets a personal label from Todoist by ID"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TodoistCredentialsInput = TodoistCredentialsField([])
|
||||
label_id: str = SchemaField(description="ID of the label to retrieve")
|
||||
|
||||
class Output(BlockSchema):
|
||||
id: str = SchemaField(description="ID of the label")
|
||||
name: str = SchemaField(description="Name of the label")
|
||||
color: str = SchemaField(description="Color of the label")
|
||||
order: int = SchemaField(description="Label order")
|
||||
is_favorite: bool = SchemaField(description="Favorite status")
|
||||
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="7f236514-de14-11ef-bd7a-32d3674e8b7e",
|
||||
description="Gets a personal label from Todoist by ID",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=TodoistGetLabelBlock.Input,
|
||||
output_schema=TodoistGetLabelBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"label_id": "2156154810",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("id", "2156154810"),
|
||||
("name", "Test Label"),
|
||||
("color", "charcoal"),
|
||||
("order", 1),
|
||||
("is_favorite", False),
|
||||
],
|
||||
test_mock={
|
||||
"get_label": lambda *args, **kwargs: {
|
||||
"id": "2156154810",
|
||||
"name": "Test Label",
|
||||
"color": "charcoal",
|
||||
"order": 1,
|
||||
"is_favorite": False,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_label(credentials: TodoistCredentials, label_id: str):
|
||||
try:
|
||||
api = TodoistAPI(credentials.access_token.get_secret_value())
|
||||
label = api.get_label(label_id=label_id)
|
||||
return label.__dict__
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TodoistCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
label_data = self.get_label(credentials, input_data.label_id)
|
||||
|
||||
if label_data:
|
||||
yield "id", label_data["id"]
|
||||
yield "name", label_data["name"]
|
||||
yield "color", label_data["color"]
|
||||
yield "order", label_data["order"]
|
||||
yield "is_favorite", label_data["is_favorite"]
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class TodoistUpdateLabelBlock(Block):
|
||||
"""Updates a personal label in Todoist using ID"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TodoistCredentialsInput = TodoistCredentialsField([])
|
||||
label_id: str = SchemaField(description="ID of the label to update")
|
||||
name: Optional[str] = SchemaField(
|
||||
description="New name of the label", default=None
|
||||
)
|
||||
order: Optional[int] = SchemaField(description="Label order", default=None)
|
||||
color: Optional[Colors] = SchemaField(
|
||||
description="The color of the label icon", default=None
|
||||
)
|
||||
is_favorite: bool = SchemaField(
|
||||
description="Whether the label is a favorite (true/false)", default=False
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(description="Whether the update was successful")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8755614c-de14-11ef-9b56-32d3674e8b7e",
|
||||
description="Updates a personal label in Todoist",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=TodoistUpdateLabelBlock.Input,
|
||||
output_schema=TodoistUpdateLabelBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"label_id": "2156154810",
|
||||
"name": "Updated Label",
|
||||
"color": Colors.charcoal.value,
|
||||
"order": 2,
|
||||
"is_favorite": True,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("success", True)],
|
||||
test_mock={"update_label": lambda *args, **kwargs: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def update_label(credentials: TodoistCredentials, label_id: str, **kwargs):
|
||||
try:
|
||||
api = TodoistAPI(credentials.access_token.get_secret_value())
|
||||
api.update_label(label_id=label_id, **kwargs)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TodoistCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
label_args = {}
|
||||
if input_data.name is not None:
|
||||
label_args["name"] = input_data.name
|
||||
if input_data.order is not None:
|
||||
label_args["order"] = input_data.order
|
||||
if input_data.color is not None:
|
||||
label_args["color"] = input_data.color.value
|
||||
if input_data.is_favorite is not None:
|
||||
label_args["is_favorite"] = input_data.is_favorite
|
||||
|
||||
success = self.update_label(
|
||||
credentials,
|
||||
input_data.label_id,
|
||||
**{k: v for k, v in label_args.items() if v is not None},
|
||||
)
|
||||
|
||||
yield "success", success
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class TodoistDeleteLabelBlock(Block):
|
||||
"""Deletes a personal label in Todoist"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TodoistCredentialsInput = TodoistCredentialsField([])
|
||||
label_id: str = SchemaField(description="ID of the label to delete")
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(description="Whether the deletion was successful")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="901b8f86-de14-11ef-98b8-32d3674e8b7e",
|
||||
description="Deletes a personal label in Todoist",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=TodoistDeleteLabelBlock.Input,
|
||||
output_schema=TodoistDeleteLabelBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"label_id": "2156154810",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("success", True)],
|
||||
test_mock={"delete_label": lambda *args, **kwargs: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def delete_label(credentials: TodoistCredentials, label_id: str):
|
||||
try:
|
||||
api = TodoistAPI(credentials.access_token.get_secret_value())
|
||||
success = api.delete_label(label_id=label_id)
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TodoistCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = self.delete_label(credentials, input_data.label_id)
|
||||
yield "success", success
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class TodoistGetSharedLabelsBlock(Block):
|
||||
"""Gets all shared labels from Todoist"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TodoistCredentialsInput = TodoistCredentialsField([])
|
||||
|
||||
class Output(BlockSchema):
|
||||
labels: list = SchemaField(description="List of shared label names")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="55fba510-de15-11ef-aed2-32d3674e8b7e",
|
||||
description="Gets all shared labels from Todoist",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=TodoistGetSharedLabelsBlock.Input,
|
||||
output_schema=TodoistGetSharedLabelsBlock.Output,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("labels", ["Label1", "Label2", "Label3"])],
|
||||
test_mock={
|
||||
"get_shared_labels": lambda *args, **kwargs: [
|
||||
"Label1",
|
||||
"Label2",
|
||||
"Label3",
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_shared_labels(credentials: TodoistCredentials):
|
||||
try:
|
||||
api = TodoistAPI(credentials.access_token.get_secret_value())
|
||||
labels = api.get_shared_labels()
|
||||
return labels
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TodoistCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
labels = self.get_shared_labels(credentials)
|
||||
yield "labels", labels
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class TodoistRenameSharedLabelsBlock(Block):
|
||||
"""Renames all instances of a shared label"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TodoistCredentialsInput = TodoistCredentialsField([])
|
||||
name: str = SchemaField(description="The name of the existing label to rename")
|
||||
new_name: str = SchemaField(description="The new name for the label")
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(description="Whether the rename was successful")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="9d63ad9a-de14-11ef-ab3f-32d3674e8b7e",
|
||||
description="Renames all instances of a shared label",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=TodoistRenameSharedLabelsBlock.Input,
|
||||
output_schema=TodoistRenameSharedLabelsBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"name": "OldLabel",
|
||||
"new_name": "NewLabel",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("success", True)],
|
||||
test_mock={"rename_shared_labels": lambda *args, **kwargs: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def rename_shared_labels(credentials: TodoistCredentials, name: str, new_name: str):
|
||||
try:
|
||||
api = TodoistAPI(credentials.access_token.get_secret_value())
|
||||
success = api.rename_shared_label(name=name, new_name=new_name)
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TodoistCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = self.rename_shared_labels(
|
||||
credentials, input_data.name, input_data.new_name
|
||||
)
|
||||
yield "success", success
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class TodoistRemoveSharedLabelsBlock(Block):
|
||||
"""Removes all instances of a shared label"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TodoistCredentialsInput = TodoistCredentialsField([])
|
||||
name: str = SchemaField(description="The name of the label to remove")
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(description="Whether the removal was successful")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a6c5cbde-de14-11ef-8863-32d3674e8b7e",
|
||||
description="Removes all instances of a shared label",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=TodoistRemoveSharedLabelsBlock.Input,
|
||||
output_schema=TodoistRemoveSharedLabelsBlock.Output,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "name": "LabelToRemove"},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("success", True)],
|
||||
test_mock={"remove_shared_label": lambda *args, **kwargs: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def remove_shared_label(credentials: TodoistCredentials, name: str):
|
||||
try:
|
||||
api = TodoistAPI(credentials.access_token.get_secret_value())
|
||||
success = api.remove_shared_label(name=name)
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TodoistCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = self.remove_shared_label(credentials, input_data.name)
|
||||
yield "success", success
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -1,566 +0,0 @@
|
||||
from todoist_api_python.api import TodoistAPI
|
||||
from typing_extensions import Optional
|
||||
|
||||
from backend.blocks.todoist._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
TodoistCredentials,
|
||||
TodoistCredentialsField,
|
||||
TodoistCredentialsInput,
|
||||
)
|
||||
from backend.blocks.todoist._types import Colors
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class TodoistListProjectsBlock(Block):
|
||||
"""Gets all projects for a Todoist user"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TodoistCredentialsInput = TodoistCredentialsField([])
|
||||
|
||||
class Output(BlockSchema):
|
||||
names_list: list[str] = SchemaField(description="List of project names")
|
||||
ids_list: list[str] = SchemaField(description="List of project IDs")
|
||||
url_list: list[str] = SchemaField(description="List of project URLs")
|
||||
complete_data: list[dict] = SchemaField(
|
||||
description="Complete project data including all fields"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="5f3e1d5b-6bc5-40e3-97ee-1318b3f38813",
|
||||
description="Gets all projects and their details from Todoist",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=TodoistListProjectsBlock.Input,
|
||||
output_schema=TodoistListProjectsBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("names_list", ["Inbox"]),
|
||||
("ids_list", ["220474322"]),
|
||||
("url_list", ["https://todoist.com/showProject?id=220474322"]),
|
||||
(
|
||||
"complete_data",
|
||||
[
|
||||
{
|
||||
"id": "220474322",
|
||||
"name": "Inbox",
|
||||
"url": "https://todoist.com/showProject?id=220474322",
|
||||
}
|
||||
],
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_project_lists": lambda *args, **kwargs: (
|
||||
["Inbox"],
|
||||
["220474322"],
|
||||
["https://todoist.com/showProject?id=220474322"],
|
||||
[
|
||||
{
|
||||
"id": "220474322",
|
||||
"name": "Inbox",
|
||||
"url": "https://todoist.com/showProject?id=220474322",
|
||||
}
|
||||
],
|
||||
None,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_project_lists(credentials: TodoistCredentials):
|
||||
try:
|
||||
api = TodoistAPI(credentials.access_token.get_secret_value())
|
||||
projects = api.get_projects()
|
||||
|
||||
names = []
|
||||
ids = []
|
||||
urls = []
|
||||
complete_data = []
|
||||
|
||||
for project in projects:
|
||||
names.append(project.name)
|
||||
ids.append(project.id)
|
||||
urls.append(project.url)
|
||||
complete_data.append(project.__dict__)
|
||||
|
||||
return names, ids, urls, complete_data, None
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TodoistCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
names, ids, urls, data, error = self.get_project_lists(credentials)
|
||||
|
||||
if names:
|
||||
yield "names_list", names
|
||||
if ids:
|
||||
yield "ids_list", ids
|
||||
if urls:
|
||||
yield "url_list", urls
|
||||
if data:
|
||||
yield "complete_data", data
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class TodoistCreateProjectBlock(Block):
|
||||
"""Creates a new project in Todoist"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TodoistCredentialsInput = TodoistCredentialsField([])
|
||||
name: str = SchemaField(description="Name of the project", advanced=False)
|
||||
parent_id: Optional[str] = SchemaField(
|
||||
description="Parent project ID", default=None, advanced=True
|
||||
)
|
||||
color: Optional[Colors] = SchemaField(
|
||||
description="Color of the project icon",
|
||||
default=Colors.charcoal,
|
||||
advanced=True,
|
||||
)
|
||||
is_favorite: bool = SchemaField(
|
||||
description="Whether the project is a favorite",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
view_style: Optional[str] = SchemaField(
|
||||
description="Display style (list or board)", default=None, advanced=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(description="Whether the creation was successful")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="ade60136-de14-11ef-b5e5-32d3674e8b7e",
|
||||
description="Creates a new project in Todoist",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=TodoistCreateProjectBlock.Input,
|
||||
output_schema=TodoistCreateProjectBlock.Output,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "name": "Test Project"},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("success", True)],
|
||||
test_mock={"create_project": lambda *args, **kwargs: (True)},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_project(
|
||||
credentials: TodoistCredentials,
|
||||
name: str,
|
||||
parent_id: Optional[str],
|
||||
color: Optional[Colors],
|
||||
is_favorite: bool,
|
||||
view_style: Optional[str],
|
||||
):
|
||||
try:
|
||||
api = TodoistAPI(credentials.access_token.get_secret_value())
|
||||
params = {"name": name, "is_favorite": is_favorite}
|
||||
|
||||
if parent_id is not None:
|
||||
params["parent_id"] = parent_id
|
||||
if color is not None:
|
||||
params["color"] = color.value
|
||||
if view_style is not None:
|
||||
params["view_style"] = view_style
|
||||
|
||||
api.add_project(**params)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TodoistCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = self.create_project(
|
||||
credentials=credentials,
|
||||
name=input_data.name,
|
||||
parent_id=input_data.parent_id,
|
||||
color=input_data.color,
|
||||
is_favorite=input_data.is_favorite,
|
||||
view_style=input_data.view_style,
|
||||
)
|
||||
|
||||
yield "success", success
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class TodoistGetProjectBlock(Block):
|
||||
"""Gets details for a specific Todoist project"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TodoistCredentialsInput = TodoistCredentialsField([])
|
||||
project_id: str = SchemaField(
|
||||
description="ID of the project to get details for", advanced=False
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
project_id: str = SchemaField(description="ID of project")
|
||||
project_name: str = SchemaField(description="Name of project")
|
||||
project_url: str = SchemaField(description="URL of project")
|
||||
complete_data: dict = SchemaField(
|
||||
description="Complete project data including all fields"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b435b5ea-de14-11ef-8b51-32d3674e8b7e",
|
||||
description="Gets details for a specific Todoist project",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=TodoistGetProjectBlock.Input,
|
||||
output_schema=TodoistGetProjectBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"project_id": "2203306141",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("project_id", "2203306141"),
|
||||
("project_name", "Shopping List"),
|
||||
("project_url", "https://todoist.com/showProject?id=2203306141"),
|
||||
(
|
||||
"complete_data",
|
||||
{
|
||||
"id": "2203306141",
|
||||
"name": "Shopping List",
|
||||
"url": "https://todoist.com/showProject?id=2203306141",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_project": lambda *args, **kwargs: (
|
||||
"2203306141",
|
||||
"Shopping List",
|
||||
"https://todoist.com/showProject?id=2203306141",
|
||||
{
|
||||
"id": "2203306141",
|
||||
"name": "Shopping List",
|
||||
"url": "https://todoist.com/showProject?id=2203306141",
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_project(credentials: TodoistCredentials, project_id: str):
|
||||
try:
|
||||
api = TodoistAPI(credentials.access_token.get_secret_value())
|
||||
project = api.get_project(project_id=project_id)
|
||||
|
||||
return project.id, project.name, project.url, project.__dict__
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TodoistCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
project_id, project_name, project_url, data = self.get_project(
|
||||
credentials=credentials, project_id=input_data.project_id
|
||||
)
|
||||
|
||||
if project_id:
|
||||
yield "project_id", project_id
|
||||
if project_name:
|
||||
yield "project_name", project_name
|
||||
if project_url:
|
||||
yield "project_url", project_url
|
||||
if data:
|
||||
yield "complete_data", data
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class TodoistUpdateProjectBlock(Block):
|
||||
"""Updates an existing project in Todoist"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TodoistCredentialsInput = TodoistCredentialsField([])
|
||||
project_id: str = SchemaField(
|
||||
description="ID of project to update", advanced=False
|
||||
)
|
||||
name: Optional[str] = SchemaField(
|
||||
description="New name for the project", default=None, advanced=False
|
||||
)
|
||||
color: Optional[Colors] = SchemaField(
|
||||
description="New color for the project icon", default=None, advanced=True
|
||||
)
|
||||
is_favorite: Optional[bool] = SchemaField(
|
||||
description="Whether the project should be a favorite",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
view_style: Optional[str] = SchemaField(
|
||||
description="Display style (list or board)", default=None, advanced=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(description="Whether the update was successful")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="ba41a20a-de14-11ef-91d7-32d3674e8b7e",
|
||||
description="Updates an existing project in Todoist",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=TodoistUpdateProjectBlock.Input,
|
||||
output_schema=TodoistUpdateProjectBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"project_id": "2203306141",
|
||||
"name": "Things To Buy",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("success", True)],
|
||||
test_mock={"update_project": lambda *args, **kwargs: (True)},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def update_project(
|
||||
credentials: TodoistCredentials,
|
||||
project_id: str,
|
||||
name: Optional[str],
|
||||
color: Optional[Colors],
|
||||
is_favorite: Optional[bool],
|
||||
view_style: Optional[str],
|
||||
):
|
||||
try:
|
||||
api = TodoistAPI(credentials.access_token.get_secret_value())
|
||||
params = {}
|
||||
|
||||
if name is not None:
|
||||
params["name"] = name
|
||||
if color is not None:
|
||||
params["color"] = color.value
|
||||
if is_favorite is not None:
|
||||
params["is_favorite"] = is_favorite
|
||||
if view_style is not None:
|
||||
params["view_style"] = view_style
|
||||
|
||||
api.update_project(project_id=project_id, **params)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TodoistCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = self.update_project(
|
||||
credentials=credentials,
|
||||
project_id=input_data.project_id,
|
||||
name=input_data.name,
|
||||
color=input_data.color,
|
||||
is_favorite=input_data.is_favorite,
|
||||
view_style=input_data.view_style,
|
||||
)
|
||||
|
||||
yield "success", success
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class TodoistDeleteProjectBlock(Block):
|
||||
"""Deletes a project and all of its sections and tasks"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TodoistCredentialsInput = TodoistCredentialsField([])
|
||||
project_id: str = SchemaField(
|
||||
description="ID of project to delete", advanced=False
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(description="Whether the deletion was successful")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c2893acc-de14-11ef-a113-32d3674e8b7e",
|
||||
description="Deletes a Todoist project and all its contents",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=TodoistDeleteProjectBlock.Input,
|
||||
output_schema=TodoistDeleteProjectBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"project_id": "2203306141",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("success", True)],
|
||||
test_mock={"delete_project": lambda *args, **kwargs: (True)},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def delete_project(credentials: TodoistCredentials, project_id: str):
|
||||
try:
|
||||
api = TodoistAPI(credentials.access_token.get_secret_value())
|
||||
success = api.delete_project(project_id=project_id)
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TodoistCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = self.delete_project(
|
||||
credentials=credentials, project_id=input_data.project_id
|
||||
)
|
||||
|
||||
yield "success", success
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class TodoistListCollaboratorsBlock(Block):
|
||||
"""Gets all collaborators for a Todoist project"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TodoistCredentialsInput = TodoistCredentialsField([])
|
||||
project_id: str = SchemaField(
|
||||
description="ID of the project to get collaborators for", advanced=False
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
collaborator_ids: list[str] = SchemaField(
|
||||
description="List of collaborator IDs"
|
||||
)
|
||||
collaborator_names: list[str] = SchemaField(
|
||||
description="List of collaborator names"
|
||||
)
|
||||
collaborator_emails: list[str] = SchemaField(
|
||||
description="List of collaborator email addresses"
|
||||
)
|
||||
complete_data: list[dict] = SchemaField(
|
||||
description="Complete collaborator data including all fields"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c99c804e-de14-11ef-9f47-32d3674e8b7e",
|
||||
description="Gets all collaborators for a specific Todoist project",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=TodoistListCollaboratorsBlock.Input,
|
||||
output_schema=TodoistListCollaboratorsBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"project_id": "2203306141",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("collaborator_ids", ["2671362", "2671366"]),
|
||||
("collaborator_names", ["Alice", "Bob"]),
|
||||
("collaborator_emails", ["alice@example.com", "bob@example.com"]),
|
||||
(
|
||||
"complete_data",
|
||||
[
|
||||
{
|
||||
"id": "2671362",
|
||||
"name": "Alice",
|
||||
"email": "alice@example.com",
|
||||
},
|
||||
{"id": "2671366", "name": "Bob", "email": "bob@example.com"},
|
||||
],
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_collaborators": lambda *args, **kwargs: (
|
||||
["2671362", "2671366"],
|
||||
["Alice", "Bob"],
|
||||
["alice@example.com", "bob@example.com"],
|
||||
[
|
||||
{
|
||||
"id": "2671362",
|
||||
"name": "Alice",
|
||||
"email": "alice@example.com",
|
||||
},
|
||||
{"id": "2671366", "name": "Bob", "email": "bob@example.com"},
|
||||
],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_collaborators(credentials: TodoistCredentials, project_id: str):
|
||||
try:
|
||||
api = TodoistAPI(credentials.access_token.get_secret_value())
|
||||
collaborators = api.get_collaborators(project_id=project_id)
|
||||
|
||||
ids = []
|
||||
names = []
|
||||
emails = []
|
||||
complete_data = []
|
||||
|
||||
for collaborator in collaborators:
|
||||
ids.append(collaborator.id)
|
||||
names.append(collaborator.name)
|
||||
emails.append(collaborator.email)
|
||||
complete_data.append(collaborator.__dict__)
|
||||
|
||||
return ids, names, emails, complete_data
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TodoistCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
ids, names, emails, data = self.get_collaborators(
|
||||
credentials=credentials, project_id=input_data.project_id
|
||||
)
|
||||
|
||||
if ids:
|
||||
yield "collaborator_ids", ids
|
||||
if names:
|
||||
yield "collaborator_names", names
|
||||
if emails:
|
||||
yield "collaborator_emails", emails
|
||||
if data:
|
||||
yield "complete_data", data
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -1,306 +0,0 @@
|
||||
from todoist_api_python.api import TodoistAPI
|
||||
from typing_extensions import Optional
|
||||
|
||||
from backend.blocks.todoist._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
TodoistCredentials,
|
||||
TodoistCredentialsField,
|
||||
TodoistCredentialsInput,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class TodoistListSectionsBlock(Block):
|
||||
"""Gets all sections for a Todoist project"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TodoistCredentialsInput = TodoistCredentialsField([])
|
||||
project_id: Optional[str] = SchemaField(
|
||||
description="Optional project ID to filter sections"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
names_list: list[str] = SchemaField(description="List of section names")
|
||||
ids_list: list[str] = SchemaField(description="List of section IDs")
|
||||
complete_data: list[dict] = SchemaField(
|
||||
description="Complete section data including all fields"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d6a116d8-de14-11ef-a94c-32d3674e8b7e",
|
||||
description="Gets all sections and their details from Todoist",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=TodoistListSectionsBlock.Input,
|
||||
output_schema=TodoistListSectionsBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"project_id": "2203306141",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("names_list", ["Groceries"]),
|
||||
("ids_list", ["7025"]),
|
||||
(
|
||||
"complete_data",
|
||||
[
|
||||
{
|
||||
"id": "7025",
|
||||
"project_id": "2203306141",
|
||||
"order": 1,
|
||||
"name": "Groceries",
|
||||
}
|
||||
],
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_section_lists": lambda *args, **kwargs: (
|
||||
["Groceries"],
|
||||
["7025"],
|
||||
[
|
||||
{
|
||||
"id": "7025",
|
||||
"project_id": "2203306141",
|
||||
"order": 1,
|
||||
"name": "Groceries",
|
||||
}
|
||||
],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_section_lists(
|
||||
credentials: TodoistCredentials, project_id: Optional[str] = None
|
||||
):
|
||||
try:
|
||||
api = TodoistAPI(credentials.access_token.get_secret_value())
|
||||
sections = api.get_sections(project_id=project_id)
|
||||
|
||||
names = []
|
||||
ids = []
|
||||
complete_data = []
|
||||
|
||||
for section in sections:
|
||||
names.append(section.name)
|
||||
ids.append(section.id)
|
||||
complete_data.append(section.__dict__)
|
||||
|
||||
return names, ids, complete_data
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TodoistCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
names, ids, data = self.get_section_lists(
|
||||
credentials, input_data.project_id
|
||||
)
|
||||
|
||||
if names:
|
||||
yield "names_list", names
|
||||
if ids:
|
||||
yield "ids_list", ids
|
||||
if data:
|
||||
yield "complete_data", data
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
# Error in official todoist SDK. Will add this block using sync_api
|
||||
# class TodoistCreateSectionBlock(Block):
|
||||
# """Creates a new section in a Todoist project"""
|
||||
|
||||
# class Input(BlockSchema):
|
||||
# credentials: TodoistCredentialsInput = TodoistCredentialsField([])
|
||||
# name: str = SchemaField(description="Section name")
|
||||
# project_id: str = SchemaField(description="Project ID this section should belong to")
|
||||
# order: Optional[int] = SchemaField(description="Optional order among other sections", default=None)
|
||||
|
||||
# class Output(BlockSchema):
|
||||
# success: bool = SchemaField(description="Whether section was successfully created")
|
||||
# error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
# def __init__(self):
|
||||
# super().__init__(
|
||||
# id="e3025cfc-de14-11ef-b9f2-32d3674e8b7e",
|
||||
# description="Creates a new section in a Todoist project",
|
||||
# categories={BlockCategory.PRODUCTIVITY},
|
||||
# input_schema=TodoistCreateSectionBlock.Input,
|
||||
# output_schema=TodoistCreateSectionBlock.Output,
|
||||
# test_input={
|
||||
# "credentials": TEST_CREDENTIALS_INPUT,
|
||||
# "name": "Groceries",
|
||||
# "project_id": "2203306141"
|
||||
# },
|
||||
# test_credentials=TEST_CREDENTIALS,
|
||||
# test_output=[
|
||||
# ("success", True)
|
||||
# ],
|
||||
# test_mock={
|
||||
# "create_section": lambda *args, **kwargs: (
|
||||
# {"id": "7025", "project_id": "2203306141", "order": 1, "name": "Groceries"},
|
||||
# )
|
||||
# },
|
||||
# )
|
||||
|
||||
# @staticmethod
|
||||
# def create_section(credentials: TodoistCredentials, name: str, project_id: str, order: Optional[int] = None):
|
||||
# try:
|
||||
# api = TodoistAPI(credentials.access_token.get_secret_value())
|
||||
# section = api.add_section(name=name, project_id=project_id, order=order)
|
||||
# return section.__dict__
|
||||
|
||||
# except Exception as e:
|
||||
# raise e
|
||||
|
||||
# def run(
|
||||
# self,
|
||||
# input_data: Input,
|
||||
# *,
|
||||
# credentials: TodoistCredentials,
|
||||
# **kwargs,
|
||||
# ) -> BlockOutput:
|
||||
# try:
|
||||
# section_data = self.create_section(
|
||||
# credentials,
|
||||
# input_data.name,
|
||||
# input_data.project_id,
|
||||
# input_data.order
|
||||
# )
|
||||
|
||||
# if section_data:
|
||||
# yield "success", True
|
||||
|
||||
# except Exception as e:
|
||||
# yield "error", str(e)
|
||||
|
||||
|
||||
class TodoistGetSectionBlock(Block):
|
||||
"""Gets a single section from Todoist by ID"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TodoistCredentialsInput = TodoistCredentialsField([])
|
||||
section_id: str = SchemaField(description="ID of section to fetch")
|
||||
|
||||
class Output(BlockSchema):
|
||||
id: str = SchemaField(description="ID of section")
|
||||
project_id: str = SchemaField(description="Project ID the section belongs to")
|
||||
order: int = SchemaField(description="Order of the section")
|
||||
name: str = SchemaField(description="Name of the section")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="ea5580e2-de14-11ef-a5d3-32d3674e8b7e",
|
||||
description="Gets a single section by ID from Todoist",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=TodoistGetSectionBlock.Input,
|
||||
output_schema=TodoistGetSectionBlock.Output,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "section_id": "7025"},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("id", "7025"),
|
||||
("project_id", "2203306141"),
|
||||
("order", 1),
|
||||
("name", "Groceries"),
|
||||
],
|
||||
test_mock={
|
||||
"get_section": lambda *args, **kwargs: {
|
||||
"id": "7025",
|
||||
"project_id": "2203306141",
|
||||
"order": 1,
|
||||
"name": "Groceries",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_section(credentials: TodoistCredentials, section_id: str):
|
||||
try:
|
||||
api = TodoistAPI(credentials.access_token.get_secret_value())
|
||||
section = api.get_section(section_id=section_id)
|
||||
return section.__dict__
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TodoistCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
section_data = self.get_section(credentials, input_data.section_id)
|
||||
|
||||
if section_data:
|
||||
yield "id", section_data["id"]
|
||||
yield "project_id", section_data["project_id"]
|
||||
yield "order", section_data["order"]
|
||||
yield "name", section_data["name"]
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class TodoistDeleteSectionBlock(Block):
|
||||
"""Deletes a section and all its tasks from Todoist"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TodoistCredentialsInput = TodoistCredentialsField([])
|
||||
section_id: str = SchemaField(description="ID of section to delete")
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(
|
||||
description="Whether section was successfully deleted"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f0e52eee-de14-11ef-9b12-32d3674e8b7e",
|
||||
description="Deletes a section and all its tasks from Todoist",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=TodoistDeleteSectionBlock.Input,
|
||||
output_schema=TodoistDeleteSectionBlock.Output,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "section_id": "7025"},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("success", True)],
|
||||
test_mock={"delete_section": lambda *args, **kwargs: (True)},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def delete_section(credentials: TodoistCredentials, section_id: str):
|
||||
try:
|
||||
api = TodoistAPI(credentials.access_token.get_secret_value())
|
||||
success = api.delete_section(section_id=section_id)
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TodoistCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = self.delete_section(credentials, input_data.section_id)
|
||||
yield "success", success
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -1,660 +0,0 @@
|
||||
from datetime import datetime
|
||||
|
||||
from todoist_api_python.api import TodoistAPI
|
||||
from todoist_api_python.models import Task
|
||||
from typing_extensions import Optional
|
||||
|
||||
from backend.blocks.todoist._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
TodoistCredentials,
|
||||
TodoistCredentialsField,
|
||||
TodoistCredentialsInput,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class TodoistCreateTaskBlock(Block):
|
||||
"""Creates a new task in a Todoist project"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TodoistCredentialsInput = TodoistCredentialsField([])
|
||||
content: str = SchemaField(description="Task content", advanced=False)
|
||||
description: Optional[str] = SchemaField(
|
||||
description="Task description", default=None, advanced=False
|
||||
)
|
||||
project_id: Optional[str] = SchemaField(
|
||||
description="Project ID this task should belong to",
|
||||
default=None,
|
||||
advanced=False,
|
||||
)
|
||||
section_id: Optional[str] = SchemaField(
|
||||
description="Section ID this task should belong to",
|
||||
default=None,
|
||||
advanced=False,
|
||||
)
|
||||
parent_id: Optional[str] = SchemaField(
|
||||
description="Parent task ID", default=None, advanced=True
|
||||
)
|
||||
order: Optional[int] = SchemaField(
|
||||
description="Optional order among other tasks,[Non-zero integer value used by clients to sort tasks under the same parent]",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
labels: Optional[list[str]] = SchemaField(
|
||||
description="Task labels", default=None, advanced=True
|
||||
)
|
||||
priority: Optional[int] = SchemaField(
|
||||
description="Task priority from 1 (normal) to 4 (urgent)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
due_date: Optional[datetime] = SchemaField(
|
||||
description="Due date in YYYY-MM-DD format", advanced=True, default=None
|
||||
)
|
||||
deadline_date: Optional[datetime] = SchemaField(
|
||||
description="Specific date in YYYY-MM-DD format relative to user's timezone",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
assignee_id: Optional[str] = SchemaField(
|
||||
description="Responsible user ID", default=None, advanced=True
|
||||
)
|
||||
duration_unit: Optional[str] = SchemaField(
|
||||
description="Task duration unit (minute/day)", default=None, advanced=True
|
||||
)
|
||||
duration: Optional[int] = SchemaField(
|
||||
description="Task duration amount, You need to selecct the duration unit first",
|
||||
depends_on=["duration_unit"],
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
id: str = SchemaField(description="Task ID")
|
||||
url: str = SchemaField(description="Task URL")
|
||||
complete_data: dict = SchemaField(
|
||||
description="Complete task data as dictionary"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="fde4f458-de14-11ef-bf0c-32d3674e8b7e",
|
||||
description="Creates a new task in a Todoist project",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=TodoistCreateTaskBlock.Input,
|
||||
output_schema=TodoistCreateTaskBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"content": "Buy groceries",
|
||||
"project_id": "2203306141",
|
||||
"priority": 4,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("id", "2995104339"),
|
||||
("url", "https://todoist.com/showTask?id=2995104339"),
|
||||
(
|
||||
"complete_data",
|
||||
{
|
||||
"id": "2995104339",
|
||||
"project_id": "2203306141",
|
||||
"url": "https://todoist.com/showTask?id=2995104339",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"create_task": lambda *args, **kwargs: (
|
||||
"2995104339",
|
||||
"https://todoist.com/showTask?id=2995104339",
|
||||
{
|
||||
"id": "2995104339",
|
||||
"project_id": "2203306141",
|
||||
"url": "https://todoist.com/showTask?id=2995104339",
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_task(credentials: TodoistCredentials, content: str, **kwargs):
|
||||
try:
|
||||
api = TodoistAPI(credentials.access_token.get_secret_value())
|
||||
task = api.add_task(content=content, **kwargs)
|
||||
task_dict = Task.to_dict(task)
|
||||
return task.id, task.url, task_dict
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TodoistCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
due_date = (
|
||||
input_data.due_date.strftime("%Y-%m-%d")
|
||||
if input_data.due_date
|
||||
else None
|
||||
)
|
||||
deadline_date = (
|
||||
input_data.deadline_date.strftime("%Y-%m-%d")
|
||||
if input_data.deadline_date
|
||||
else None
|
||||
)
|
||||
|
||||
task_args = {
|
||||
"description": input_data.description,
|
||||
"project_id": input_data.project_id,
|
||||
"section_id": input_data.section_id,
|
||||
"parent_id": input_data.parent_id,
|
||||
"order": input_data.order,
|
||||
"labels": input_data.labels,
|
||||
"priority": input_data.priority,
|
||||
"due_date": due_date,
|
||||
"deadline_date": deadline_date,
|
||||
"assignee_id": input_data.assignee_id,
|
||||
"duration": input_data.duration,
|
||||
"duration_unit": input_data.duration_unit,
|
||||
}
|
||||
|
||||
id, url, complete_data = self.create_task(
|
||||
credentials,
|
||||
input_data.content,
|
||||
**{k: v for k, v in task_args.items() if v is not None},
|
||||
)
|
||||
|
||||
yield "id", id
|
||||
yield "url", url
|
||||
yield "complete_data", complete_data
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class TodoistGetTasksBlock(Block):
|
||||
"""Get active tasks from Todoist"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TodoistCredentialsInput = TodoistCredentialsField([])
|
||||
project_id: Optional[str] = SchemaField(
|
||||
description="Filter tasks by project ID", default=None, advanced=False
|
||||
)
|
||||
section_id: Optional[str] = SchemaField(
|
||||
description="Filter tasks by section ID", default=None, advanced=True
|
||||
)
|
||||
label: Optional[str] = SchemaField(
|
||||
description="Filter tasks by label name", default=None, advanced=True
|
||||
)
|
||||
filter: Optional[str] = SchemaField(
|
||||
description="Filter by any supported filter, You can see How to use filters or create one of your one here - https://todoist.com/help/articles/introduction-to-filters-V98wIH",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
lang: Optional[str] = SchemaField(
|
||||
description="IETF language tag for filter language", default=None
|
||||
)
|
||||
ids: Optional[list[str]] = SchemaField(
|
||||
description="List of task IDs to retrieve", default=None, advanced=False
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
ids: list[str] = SchemaField(description="Task IDs")
|
||||
urls: list[str] = SchemaField(description="Task URLs")
|
||||
complete_data: list[dict] = SchemaField(
|
||||
description="Complete task data as dictionary"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0b706e86-de15-11ef-a113-32d3674e8b7e",
|
||||
description="Get active tasks from Todoist",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=TodoistGetTasksBlock.Input,
|
||||
output_schema=TodoistGetTasksBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"project_id": "2203306141",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("ids", ["2995104339"]),
|
||||
("urls", ["https://todoist.com/showTask?id=2995104339"]),
|
||||
(
|
||||
"complete_data",
|
||||
[
|
||||
{
|
||||
"id": "2995104339",
|
||||
"project_id": "2203306141",
|
||||
"url": "https://todoist.com/showTask?id=2995104339",
|
||||
"is_completed": False,
|
||||
}
|
||||
],
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_tasks": lambda *args, **kwargs: [
|
||||
{
|
||||
"id": "2995104339",
|
||||
"project_id": "2203306141",
|
||||
"url": "https://todoist.com/showTask?id=2995104339",
|
||||
"is_completed": False,
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_tasks(credentials: TodoistCredentials, **kwargs):
|
||||
try:
|
||||
api = TodoistAPI(credentials.access_token.get_secret_value())
|
||||
tasks = api.get_tasks(**kwargs)
|
||||
return [Task.to_dict(task) for task in tasks]
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TodoistCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
task_filters = {
|
||||
"project_id": input_data.project_id,
|
||||
"section_id": input_data.section_id,
|
||||
"label": input_data.label,
|
||||
"filter": input_data.filter,
|
||||
"lang": input_data.lang,
|
||||
"ids": input_data.ids,
|
||||
}
|
||||
|
||||
tasks = self.get_tasks(
|
||||
credentials, **{k: v for k, v in task_filters.items() if v is not None}
|
||||
)
|
||||
|
||||
yield "ids", [task["id"] for task in tasks]
|
||||
yield "urls", [task["url"] for task in tasks]
|
||||
yield "complete_data", tasks
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class TodoistGetTaskBlock(Block):
|
||||
"""Get an active task from Todoist"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TodoistCredentialsInput = TodoistCredentialsField([])
|
||||
task_id: str = SchemaField(description="Task ID to retrieve")
|
||||
|
||||
class Output(BlockSchema):
|
||||
project_id: str = SchemaField(description="Project ID containing the task")
|
||||
url: str = SchemaField(description="Task URL")
|
||||
complete_data: dict = SchemaField(
|
||||
description="Complete task data as dictionary"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="16d7dc8c-de15-11ef-8ace-32d3674e8b7e",
|
||||
description="Get an active task from Todoist",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=TodoistGetTaskBlock.Input,
|
||||
output_schema=TodoistGetTaskBlock.Output,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "task_id": "2995104339"},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("project_id", "2203306141"),
|
||||
("url", "https://todoist.com/showTask?id=2995104339"),
|
||||
(
|
||||
"complete_data",
|
||||
{
|
||||
"id": "2995104339",
|
||||
"project_id": "2203306141",
|
||||
"url": "https://todoist.com/showTask?id=2995104339",
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_task": lambda *args, **kwargs: {
|
||||
"project_id": "2203306141",
|
||||
"id": "2995104339",
|
||||
"url": "https://todoist.com/showTask?id=2995104339",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_task(credentials: TodoistCredentials, task_id: str):
|
||||
try:
|
||||
api = TodoistAPI(credentials.access_token.get_secret_value())
|
||||
task = api.get_task(task_id=task_id)
|
||||
return Task.to_dict(task)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TodoistCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
task_data = self.get_task(credentials, input_data.task_id)
|
||||
|
||||
if task_data:
|
||||
yield "project_id", task_data["project_id"]
|
||||
yield "url", task_data["url"]
|
||||
yield "complete_data", task_data
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class TodoistUpdateTaskBlock(Block):
|
||||
"""Updates an existing task in Todoist"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TodoistCredentialsInput = TodoistCredentialsField([])
|
||||
task_id: str = SchemaField(description="Task ID to update")
|
||||
content: str = SchemaField(description="Task content", advanced=False)
|
||||
description: Optional[str] = SchemaField(
|
||||
description="Task description", default=None, advanced=False
|
||||
)
|
||||
project_id: Optional[str] = SchemaField(
|
||||
description="Project ID this task should belong to",
|
||||
default=None,
|
||||
advanced=False,
|
||||
)
|
||||
section_id: Optional[str] = SchemaField(
|
||||
description="Section ID this task should belong to",
|
||||
default=None,
|
||||
advanced=False,
|
||||
)
|
||||
parent_id: Optional[str] = SchemaField(
|
||||
description="Parent task ID", default=None, advanced=True
|
||||
)
|
||||
order: Optional[int] = SchemaField(
|
||||
description="Optional order among other tasks,[Non-zero integer value used by clients to sort tasks under the same parent]",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
labels: Optional[list[str]] = SchemaField(
|
||||
description="Task labels", default=None, advanced=True
|
||||
)
|
||||
priority: Optional[int] = SchemaField(
|
||||
description="Task priority from 1 (normal) to 4 (urgent)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
due_date: Optional[datetime] = SchemaField(
|
||||
description="Due date in YYYY-MM-DD format", advanced=True, default=None
|
||||
)
|
||||
deadline_date: Optional[datetime] = SchemaField(
|
||||
description="Specific date in YYYY-MM-DD format relative to user's timezone",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
assignee_id: Optional[str] = SchemaField(
|
||||
description="Responsible user ID", default=None, advanced=True
|
||||
)
|
||||
duration_unit: Optional[str] = SchemaField(
|
||||
description="Task duration unit (minute/day)", default=None, advanced=True
|
||||
)
|
||||
duration: Optional[int] = SchemaField(
|
||||
description="Task duration amount, You need to selecct the duration unit first",
|
||||
depends_on=["duration_unit"],
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(description="Whether the update was successful")
|
||||
error: str = SchemaField(description="Error message if request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="1eee6d32-de15-11ef-a2ff-32d3674e8b7e",
|
||||
description="Updates an existing task in Todoist",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=TodoistUpdateTaskBlock.Input,
|
||||
output_schema=TodoistUpdateTaskBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"task_id": "2995104339",
|
||||
"content": "Buy Coffee",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("success", True)],
|
||||
test_mock={"update_task": lambda *args, **kwargs: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def update_task(credentials: TodoistCredentials, task_id: str, **kwargs):
|
||||
try:
|
||||
api = TodoistAPI(credentials.access_token.get_secret_value())
|
||||
is_success = api.update_task(task_id=task_id, **kwargs)
|
||||
return is_success
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TodoistCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
due_date = (
|
||||
input_data.due_date.strftime("%Y-%m-%d")
|
||||
if input_data.due_date
|
||||
else None
|
||||
)
|
||||
deadline_date = (
|
||||
input_data.deadline_date.strftime("%Y-%m-%d")
|
||||
if input_data.deadline_date
|
||||
else None
|
||||
)
|
||||
|
||||
task_updates = {}
|
||||
if input_data.content is not None:
|
||||
task_updates["content"] = input_data.content
|
||||
if input_data.description is not None:
|
||||
task_updates["description"] = input_data.description
|
||||
if input_data.project_id is not None:
|
||||
task_updates["project_id"] = input_data.project_id
|
||||
if input_data.section_id is not None:
|
||||
task_updates["section_id"] = input_data.section_id
|
||||
if input_data.parent_id is not None:
|
||||
task_updates["parent_id"] = input_data.parent_id
|
||||
if input_data.order is not None:
|
||||
task_updates["order"] = input_data.order
|
||||
if input_data.labels is not None:
|
||||
task_updates["labels"] = input_data.labels
|
||||
if input_data.priority is not None:
|
||||
task_updates["priority"] = input_data.priority
|
||||
if due_date is not None:
|
||||
task_updates["due_date"] = due_date
|
||||
if deadline_date is not None:
|
||||
task_updates["deadline_date"] = deadline_date
|
||||
if input_data.assignee_id is not None:
|
||||
task_updates["assignee_id"] = input_data.assignee_id
|
||||
if input_data.duration is not None:
|
||||
task_updates["duration"] = input_data.duration
|
||||
if input_data.duration_unit is not None:
|
||||
task_updates["duration_unit"] = input_data.duration_unit
|
||||
|
||||
self.update_task(
|
||||
credentials,
|
||||
input_data.task_id,
|
||||
**{k: v for k, v in task_updates.items() if v is not None},
|
||||
)
|
||||
|
||||
yield "success", True
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class TodoistCloseTaskBlock(Block):
|
||||
"""Closes a task in Todoist"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TodoistCredentialsInput = TodoistCredentialsField([])
|
||||
task_id: str = SchemaField(description="Task ID to close")
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(
|
||||
description="Whether the task was successfully closed"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="29fac798-de15-11ef-b839-32d3674e8b7e",
|
||||
description="Closes a task in Todoist",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=TodoistCloseTaskBlock.Input,
|
||||
output_schema=TodoistCloseTaskBlock.Output,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "task_id": "2995104339"},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("success", True)],
|
||||
test_mock={"close_task": lambda *args, **kwargs: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def close_task(credentials: TodoistCredentials, task_id: str):
|
||||
try:
|
||||
api = TodoistAPI(credentials.access_token.get_secret_value())
|
||||
is_success = api.close_task(task_id=task_id)
|
||||
return is_success
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TodoistCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
is_success = self.close_task(credentials, input_data.task_id)
|
||||
yield "success", is_success
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class TodoistReopenTaskBlock(Block):
|
||||
"""Reopens a task in Todoist"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TodoistCredentialsInput = TodoistCredentialsField([])
|
||||
task_id: str = SchemaField(description="Task ID to reopen")
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(
|
||||
description="Whether the task was successfully reopened"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="2e6bf6f8-de15-11ef-ae7c-32d3674e8b7e",
|
||||
description="Reopens a task in Todoist",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=TodoistReopenTaskBlock.Input,
|
||||
output_schema=TodoistReopenTaskBlock.Output,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "task_id": "2995104339"},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("success", True),
|
||||
],
|
||||
test_mock={"reopen_task": lambda *args, **kwargs: (True)},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def reopen_task(credentials: TodoistCredentials, task_id: str):
|
||||
try:
|
||||
api = TodoistAPI(credentials.access_token.get_secret_value())
|
||||
is_success = api.reopen_task(task_id=task_id)
|
||||
return is_success
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TodoistCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
is_success = self.reopen_task(credentials, input_data.task_id)
|
||||
yield "success", is_success
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
class TodoistDeleteTaskBlock(Block):
|
||||
"""Deletes a task in Todoist"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TodoistCredentialsInput = TodoistCredentialsField([])
|
||||
task_id: str = SchemaField(description="Task ID to delete")
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(
|
||||
description="Whether the task was successfully deleted"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="33c29ada-de15-11ef-bcbb-32d3674e8b7e",
|
||||
description="Deletes a task in Todoist",
|
||||
categories={BlockCategory.PRODUCTIVITY},
|
||||
input_schema=TodoistDeleteTaskBlock.Input,
|
||||
output_schema=TodoistDeleteTaskBlock.Output,
|
||||
test_input={"credentials": TEST_CREDENTIALS_INPUT, "task_id": "2995104339"},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("success", True),
|
||||
],
|
||||
test_mock={"delete_task": lambda *args, **kwargs: (True)},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def delete_task(credentials: TodoistCredentials, task_id: str):
|
||||
try:
|
||||
api = TodoistAPI(credentials.access_token.get_secret_value())
|
||||
is_success = api.delete_task(task_id=task_id)
|
||||
return is_success
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TodoistCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
is_success = self.delete_task(credentials, input_data.task_id)
|
||||
yield "success", is_success
|
||||
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
@@ -92,8 +92,7 @@ class TwitterPostTweetBlock(Block):
|
||||
attachment: Union[Media, DeepLink, Poll, Place, Quote] | None = SchemaField(
|
||||
discriminator="discriminator",
|
||||
description="Additional tweet data (media, deep link, poll, place or quote)",
|
||||
advanced=False,
|
||||
default=Media(discriminator="media"),
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
exclude_reply_user_ids: Optional[List[str]] = SchemaField(
|
||||
|
||||
@@ -23,6 +23,71 @@ from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class TwitterUnblockUserBlock(Block):
|
||||
"""
|
||||
Unblock a specific user on Twitter. The request succeeds with no action when the user sends a request to a user they're not blocking or have already unblocked.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["block.write", "users.read", "offline.access"]
|
||||
)
|
||||
|
||||
target_user_id: str = SchemaField(
|
||||
description="The user ID of the user that you would like to unblock",
|
||||
placeholder="Enter target user ID",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(description="Whether the unblock was successful")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0f1b6570-a631-11ef-a3ea-230cbe9650dd",
|
||||
description="This block unblocks a specific user on Twitter.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterUnblockUserBlock.Input,
|
||||
output_schema=TwitterUnblockUserBlock.Output,
|
||||
test_input={
|
||||
"target_user_id": "12345",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("success", True),
|
||||
],
|
||||
test_mock={"unblock_user": lambda *args, **kwargs: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def unblock_user(credentials: TwitterCredentials, target_user_id: str):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
client.unblock(target_user_id=target_user_id, user_auth=False)
|
||||
|
||||
return True
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = self.unblock_user(credentials, input_data.target_user_id)
|
||||
yield "success", success
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterGetBlockedUsersBlock(Block):
|
||||
"""
|
||||
Get a list of users who are blocked by the authenticating user
|
||||
@@ -173,3 +238,68 @@ class TwitterGetBlockedUsersBlock(Block):
|
||||
yield "next_token", next_token
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterBlockUserBlock(Block):
|
||||
"""
|
||||
Block a specific user on Twitter
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["block.write", "users.read", "offline.access"]
|
||||
)
|
||||
|
||||
target_user_id: str = SchemaField(
|
||||
description="The user ID of the user that you would like to block",
|
||||
placeholder="Enter target user ID",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(description="Whether the block was successful")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="fc258b94-a630-11ef-abc3-df050b75b816",
|
||||
description="This block blocks a specific user on Twitter.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterBlockUserBlock.Input,
|
||||
output_schema=TwitterBlockUserBlock.Output,
|
||||
test_input={
|
||||
"target_user_id": "12345",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("success", True),
|
||||
],
|
||||
test_mock={"block_user": lambda *args, **kwargs: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def block_user(credentials: TwitterCredentials, target_user_id: str):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
client.block(target_user_id=target_user_id, user_auth=False)
|
||||
|
||||
return True
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = self.block_user(credentials, input_data.target_user_id)
|
||||
yield "success", success
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
from gravitasml.parser import Parser
|
||||
from gravitasml.token import tokenize
|
||||
|
||||
from backend.data.block import Block, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class XMLParserBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
input_xml: str = SchemaField(description="input xml to be parsed")
|
||||
|
||||
class Output(BlockSchema):
|
||||
parsed_xml: dict = SchemaField(description="output parsed xml to dict")
|
||||
error: str = SchemaField(description="Error in parsing")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="286380af-9529-4b55-8be0-1d7c854abdb5",
|
||||
description="Parses XML using gravitasml to tokenize and coverts it to dict",
|
||||
input_schema=XMLParserBlock.Input,
|
||||
output_schema=XMLParserBlock.Output,
|
||||
test_input={"input_xml": "<tag1><tag2>content</tag2></tag1>"},
|
||||
test_output=[
|
||||
("parsed_xml", {"tag1": {"tag2": "content"}}),
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
tokens = tokenize(input_data.input_xml)
|
||||
parser = Parser(tokens)
|
||||
parsed_result = parser.parse()
|
||||
yield "parsed_xml", parsed_result
|
||||
except ValueError as val_e:
|
||||
raise ValueError(f"Validation error for dict:{val_e}") from val_e
|
||||
except SyntaxError as syn_e:
|
||||
raise SyntaxError(f"Error in input xml syntax: {syn_e}") from syn_e
|
||||
@@ -1,10 +0,0 @@
|
||||
from zerobouncesdk import ZBValidateResponse, ZeroBounce
|
||||
|
||||
|
||||
class ZeroBounceClient:
|
||||
def __init__(self, api_key: str):
|
||||
self.api_key = api_key
|
||||
self.client = ZeroBounce(api_key)
|
||||
|
||||
def validate_email(self, email: str, ip_address: str) -> ZBValidateResponse:
|
||||
return self.client.validate(email, ip_address)
|
||||
@@ -1,35 +0,0 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
ZeroBounceCredentials = APIKeyCredentials
|
||||
ZeroBounceCredentialsInput = CredentialsMetaInput[
|
||||
Literal[ProviderName.ZEROBOUNCE],
|
||||
Literal["api_key"],
|
||||
]
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="zerobounce",
|
||||
api_key=SecretStr("mock-zerobounce-api-key"),
|
||||
title="Mock ZeroBounce API key",
|
||||
expires_at=None,
|
||||
)
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
def ZeroBounceCredentialsField() -> ZeroBounceCredentialsInput:
|
||||
"""
|
||||
Creates a ZeroBounce credentials input on a block.
|
||||
"""
|
||||
return CredentialsField(
|
||||
description="The ZeroBounce integration can be used with an API Key.",
|
||||
)
|
||||
@@ -1,175 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from zerobouncesdk.zb_validate_response import (
|
||||
ZBValidateResponse,
|
||||
ZBValidateStatus,
|
||||
ZBValidateSubStatus,
|
||||
)
|
||||
|
||||
from backend.blocks.zerobounce._api import ZeroBounceClient
|
||||
from backend.blocks.zerobounce._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
ZeroBounceCredentials,
|
||||
ZeroBounceCredentialsInput,
|
||||
)
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class Response(BaseModel):
|
||||
address: str = SchemaField(
|
||||
description="The email address you are validating.", default="N/A"
|
||||
)
|
||||
status: ZBValidateStatus = SchemaField(
|
||||
description="The status of the email address.", default=ZBValidateStatus.unknown
|
||||
)
|
||||
sub_status: ZBValidateSubStatus = SchemaField(
|
||||
description="The sub-status of the email address.",
|
||||
default=ZBValidateSubStatus.none,
|
||||
)
|
||||
account: Optional[str] = SchemaField(
|
||||
description="The portion of the email address before the '@' symbol.",
|
||||
default="N/A",
|
||||
)
|
||||
domain: Optional[str] = SchemaField(
|
||||
description="The portion of the email address after the '@' symbol."
|
||||
)
|
||||
did_you_mean: Optional[str] = SchemaField(
|
||||
description="Suggestive Fix for an email typo",
|
||||
default=None,
|
||||
)
|
||||
domain_age_days: Optional[str] = SchemaField(
|
||||
description="Age of the email domain in days or [null].",
|
||||
default=None,
|
||||
)
|
||||
free_email: Optional[bool] = SchemaField(
|
||||
description="Whether the email address is a free email provider.", default=False
|
||||
)
|
||||
mx_found: Optional[bool] = SchemaField(
|
||||
description="Whether the MX record was found.", default=False
|
||||
)
|
||||
mx_record: Optional[str] = SchemaField(
|
||||
description="The MX record of the email address.", default=None
|
||||
)
|
||||
smtp_provider: Optional[str] = SchemaField(
|
||||
description="The SMTP provider of the email address.", default=None
|
||||
)
|
||||
firstname: Optional[str] = SchemaField(
|
||||
description="The first name of the email address.", default=None
|
||||
)
|
||||
lastname: Optional[str] = SchemaField(
|
||||
description="The last name of the email address.", default=None
|
||||
)
|
||||
gender: Optional[str] = SchemaField(
|
||||
description="The gender of the email address.", default=None
|
||||
)
|
||||
city: Optional[str] = SchemaField(
|
||||
description="The city of the email address.", default=None
|
||||
)
|
||||
region: Optional[str] = SchemaField(
|
||||
description="The region of the email address.", default=None
|
||||
)
|
||||
zipcode: Optional[str] = SchemaField(
|
||||
description="The zipcode of the email address.", default=None
|
||||
)
|
||||
country: Optional[str] = SchemaField(
|
||||
description="The country of the email address.", default=None
|
||||
)
|
||||
|
||||
|
||||
class ValidateEmailsBlock(Block):
|
||||
"""Search for people in Apollo"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
email: str = SchemaField(
|
||||
description="Email to validate",
|
||||
)
|
||||
ip_address: str = SchemaField(
|
||||
description="IP address to validate",
|
||||
default="",
|
||||
)
|
||||
credentials: ZeroBounceCredentialsInput = SchemaField(
|
||||
description="ZeroBounce credentials",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
response: Response = SchemaField(
|
||||
description="Response from ZeroBounce",
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the search failed",
|
||||
default="",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="e3950439-fa0b-40e8-b19f-e0dca0bf5853",
|
||||
description="Validate emails",
|
||||
categories={BlockCategory.SEARCH},
|
||||
input_schema=ValidateEmailsBlock.Input,
|
||||
output_schema=ValidateEmailsBlock.Output,
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"email": "test@test.com",
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"response",
|
||||
Response(
|
||||
address="test@test.com",
|
||||
status=ZBValidateStatus.valid,
|
||||
sub_status=ZBValidateSubStatus.allowed,
|
||||
account="test",
|
||||
domain="test.com",
|
||||
did_you_mean=None,
|
||||
domain_age_days=None,
|
||||
free_email=False,
|
||||
mx_found=False,
|
||||
mx_record=None,
|
||||
smtp_provider=None,
|
||||
),
|
||||
)
|
||||
],
|
||||
test_mock={
|
||||
"validate_email": lambda email, ip_address, credentials: ZBValidateResponse(
|
||||
data={
|
||||
"address": email,
|
||||
"status": ZBValidateStatus.valid,
|
||||
"sub_status": ZBValidateSubStatus.allowed,
|
||||
"account": "test",
|
||||
"domain": "test.com",
|
||||
"did_you_mean": None,
|
||||
"domain_age_days": None,
|
||||
"free_email": False,
|
||||
"mx_found": False,
|
||||
"mx_record": None,
|
||||
"smtp_provider": None,
|
||||
}
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def validate_email(
|
||||
email: str, ip_address: str, credentials: ZeroBounceCredentials
|
||||
) -> ZBValidateResponse:
|
||||
client = ZeroBounceClient(credentials.api_key.get_secret_value())
|
||||
return client.validate_email(email, ip_address)
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: ZeroBounceCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
response: ZBValidateResponse = self.validate_email(
|
||||
input_data.email, input_data.ip_address, credentials
|
||||
)
|
||||
|
||||
response_model = Response(**response.__dict__)
|
||||
|
||||
yield "response", response_model
|
||||
@@ -220,8 +220,8 @@ def event():
|
||||
|
||||
@test.command()
|
||||
@click.argument("server_address")
|
||||
@click.argument("graph_exec_id")
|
||||
def websocket(server_address: str, graph_exec_id: str):
|
||||
@click.argument("graph_id")
|
||||
def websocket(server_address: str, graph_id: str):
|
||||
"""
|
||||
Tests the websocket connection.
|
||||
"""
|
||||
@@ -229,21 +229,15 @@ def websocket(server_address: str, graph_exec_id: str):
|
||||
|
||||
import websockets.asyncio.client
|
||||
|
||||
from backend.server.ws_api import (
|
||||
WSMessage,
|
||||
WSMethod,
|
||||
WSSubscribeGraphExecutionRequest,
|
||||
)
|
||||
from backend.server.ws_api import ExecutionSubscription, Methods, WsMessage
|
||||
|
||||
async def send_message(server_address: str):
|
||||
uri = f"ws://{server_address}"
|
||||
async with websockets.asyncio.client.connect(uri) as websocket:
|
||||
try:
|
||||
msg = WSMessage(
|
||||
method=WSMethod.SUBSCRIBE_GRAPH_EXEC,
|
||||
data=WSSubscribeGraphExecutionRequest(
|
||||
graph_exec_id=graph_exec_id,
|
||||
).model_dump(),
|
||||
msg = WsMessage(
|
||||
method=Methods.SUBSCRIBE,
|
||||
data=ExecutionSubscription(graph_id=graph_id).model_dump(),
|
||||
).model_dump_json()
|
||||
await websocket.send(msg)
|
||||
print(f"Sending: {msg}")
|
||||
|
||||
@@ -2,7 +2,6 @@ import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
ClassVar,
|
||||
Generator,
|
||||
@@ -19,8 +18,6 @@ import jsonschema
|
||||
from prisma.models import AgentBlock
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util import json
|
||||
from backend.util.settings import Config
|
||||
|
||||
@@ -31,9 +28,6 @@ from .model import (
|
||||
is_credentials_field_name,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .graph import Link
|
||||
|
||||
app_config = Config()
|
||||
|
||||
BlockData = tuple[str, Any] # Input & Output data should be a tuple of (name, data).
|
||||
@@ -50,7 +44,6 @@ class BlockType(Enum):
|
||||
WEBHOOK = "Webhook"
|
||||
WEBHOOK_MANUAL = "Webhook (manual)"
|
||||
AGENT = "Agent"
|
||||
AI = "AI"
|
||||
|
||||
|
||||
class BlockCategory(Enum):
|
||||
@@ -71,9 +64,6 @@ class BlockCategory(Enum):
|
||||
SAFETY = (
|
||||
"Block that provides AI safety mechanisms such as detecting harmful content"
|
||||
)
|
||||
PRODUCTIVITY = "Block that helps with productivity"
|
||||
ISSUE_TRACKING = "Block that helps with issue tracking"
|
||||
MULTIMEDIA = "Block that interacts with multimedia content"
|
||||
|
||||
def dict(self) -> dict[str, str]:
|
||||
return {"category": self.name, "description": self.value}
|
||||
@@ -116,30 +106,21 @@ class BlockSchema(BaseModel):
|
||||
def validate_data(cls, data: BlockInput) -> str | None:
|
||||
return json.validate_with_jsonschema(schema=cls.jsonschema(), data=data)
|
||||
|
||||
@classmethod
|
||||
def get_mismatch_error(cls, data: BlockInput) -> str | None:
|
||||
return cls.validate_data(data)
|
||||
|
||||
@classmethod
|
||||
def get_field_schema(cls, field_name: str) -> dict[str, Any]:
|
||||
model_schema = cls.jsonschema().get("properties", {})
|
||||
if not model_schema:
|
||||
raise ValueError(f"Invalid model schema {cls}")
|
||||
|
||||
property_schema = model_schema.get(field_name)
|
||||
if not property_schema:
|
||||
raise ValueError(f"Invalid property name {field_name}")
|
||||
|
||||
return property_schema
|
||||
|
||||
@classmethod
|
||||
def validate_field(cls, field_name: str, data: BlockInput) -> str | None:
|
||||
"""
|
||||
Validate the data against a specific property (one of the input/output name).
|
||||
Returns the validation error message if the data does not match the schema.
|
||||
"""
|
||||
model_schema = cls.jsonschema().get("properties", {})
|
||||
if not model_schema:
|
||||
return f"Invalid model schema {cls}"
|
||||
|
||||
property_schema = model_schema.get(field_name)
|
||||
if not property_schema:
|
||||
return f"Invalid property name {field_name}"
|
||||
|
||||
try:
|
||||
property_schema = cls.get_field_schema(field_name)
|
||||
jsonschema.validate(json.to_dict(data), property_schema)
|
||||
return None
|
||||
except jsonschema.ValidationError as e:
|
||||
@@ -202,19 +183,6 @@ class BlockSchema(BaseModel):
|
||||
)
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
|
||||
return data # Return as is, by default.
|
||||
|
||||
@classmethod
|
||||
def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]:
|
||||
input_fields_from_nodes = {link.sink_name for link in links}
|
||||
return input_fields_from_nodes - set(data)
|
||||
|
||||
@classmethod
|
||||
def get_missing_input(cls, data: BlockInput) -> set[str]:
|
||||
return cls.get_required_fields() - set(data)
|
||||
|
||||
|
||||
BlockSchemaInputType = TypeVar("BlockSchemaInputType", bound=BlockSchema)
|
||||
BlockSchemaOutputType = TypeVar("BlockSchemaOutputType", bound=BlockSchema)
|
||||
@@ -231,7 +199,7 @@ class BlockManualWebhookConfig(BaseModel):
|
||||
the user has to manually set up the webhook at the provider.
|
||||
"""
|
||||
|
||||
provider: ProviderName
|
||||
provider: str
|
||||
"""The service provider that the webhook connects to"""
|
||||
|
||||
webhook_type: str
|
||||
@@ -323,7 +291,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
self.static_output = static_output
|
||||
self.block_type = block_type
|
||||
self.webhook_config = webhook_config
|
||||
self.execution_stats: NodeExecutionStats = NodeExecutionStats()
|
||||
self.execution_stats = {}
|
||||
|
||||
if self.webhook_config:
|
||||
if isinstance(self.webhook_config, BlockWebhookConfig):
|
||||
@@ -380,14 +348,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
Run the block with the given input data.
|
||||
Args:
|
||||
input_data: The input data with the structure of input_schema.
|
||||
|
||||
Kwargs: Currently 14/02/2025 these include
|
||||
graph_id: The ID of the graph.
|
||||
node_id: The ID of the node.
|
||||
graph_exec_id: The ID of the graph execution.
|
||||
node_exec_id: The ID of the node execution.
|
||||
user_id: The ID of the user.
|
||||
|
||||
Returns:
|
||||
A Generator that yields (output_name, output_data).
|
||||
output_name: One of the output name defined in Block's output_schema.
|
||||
@@ -401,29 +361,18 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
return data
|
||||
raise ValueError(f"{self.name} did not produce any output for {output}")
|
||||
|
||||
def merge_stats(self, stats: NodeExecutionStats) -> NodeExecutionStats:
|
||||
stats_dict = stats.model_dump()
|
||||
current_stats = self.execution_stats.model_dump()
|
||||
|
||||
for key, value in stats_dict.items():
|
||||
if key not in current_stats:
|
||||
# Field doesn't exist yet, just set it, but this will probably
|
||||
# not happen, just in case though so we throw for invalid when
|
||||
# converting back in
|
||||
current_stats[key] = value
|
||||
elif isinstance(value, dict) and isinstance(current_stats[key], dict):
|
||||
current_stats[key].update(value)
|
||||
elif isinstance(value, (int, float)) and isinstance(
|
||||
current_stats[key], (int, float)
|
||||
):
|
||||
current_stats[key] += value
|
||||
elif isinstance(value, list) and isinstance(current_stats[key], list):
|
||||
current_stats[key].extend(value)
|
||||
def merge_stats(self, stats: dict[str, Any]) -> dict[str, Any]:
|
||||
for key, value in stats.items():
|
||||
if isinstance(value, dict):
|
||||
self.execution_stats.setdefault(key, {}).update(value)
|
||||
elif isinstance(value, (int, float)):
|
||||
self.execution_stats.setdefault(key, 0)
|
||||
self.execution_stats[key] += value
|
||||
elif isinstance(value, list):
|
||||
self.execution_stats.setdefault(key, [])
|
||||
self.execution_stats[key].extend(value)
|
||||
else:
|
||||
current_stats[key] = value
|
||||
|
||||
self.execution_stats = NodeExecutionStats(**current_stats)
|
||||
|
||||
self.execution_stats[key] = value
|
||||
return self.execution_stats
|
||||
|
||||
@property
|
||||
@@ -467,9 +416,9 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
|
||||
|
||||
def get_blocks() -> dict[str, Type[Block]]:
|
||||
from backend.blocks import load_all_blocks
|
||||
from backend.blocks import AVAILABLE_BLOCKS # noqa: E402
|
||||
|
||||
return load_all_blocks()
|
||||
return AVAILABLE_BLOCKS
|
||||
|
||||
|
||||
async def initialize_blocks() -> None:
|
||||
|
||||
@@ -15,7 +15,6 @@ from backend.blocks.llm import (
|
||||
LlmModel,
|
||||
)
|
||||
from backend.blocks.replicate_flux_advanced import ReplicateFluxAdvancedModelBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
|
||||
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
|
||||
from backend.data.block import Block
|
||||
@@ -36,8 +35,6 @@ from backend.integrations.credentials_store import (
|
||||
# =============== Configure the cost for each LLM Model call =============== #
|
||||
|
||||
MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.O3_MINI: 2, # $1.10 / $4.40
|
||||
LlmModel.O1: 16, # $15 / $60
|
||||
LlmModel.O1_PREVIEW: 16,
|
||||
LlmModel.O1_MINI: 4,
|
||||
LlmModel.GPT4O_MINI: 1,
|
||||
@@ -45,21 +42,20 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.GPT4_TURBO: 10,
|
||||
LlmModel.GPT3_5_TURBO: 1,
|
||||
LlmModel.CLAUDE_3_5_SONNET: 4,
|
||||
LlmModel.CLAUDE_3_5_HAIKU: 1, # $0.80 / $4.00
|
||||
LlmModel.CLAUDE_3_HAIKU: 1,
|
||||
LlmModel.LLAMA3_8B: 1,
|
||||
LlmModel.LLAMA3_70B: 1,
|
||||
LlmModel.MIXTRAL_8X7B: 1,
|
||||
LlmModel.GEMMA_7B: 1,
|
||||
LlmModel.GEMMA2_9B: 1,
|
||||
LlmModel.LLAMA3_3_70B: 1, # $0.59 / $0.79
|
||||
LlmModel.LLAMA3_1_405B: 1,
|
||||
LlmModel.LLAMA3_1_70B: 1,
|
||||
LlmModel.LLAMA3_1_8B: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_3: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_2: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_8B: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_405B: 1,
|
||||
LlmModel.DEEPSEEK_LLAMA_70B: 1, # ? / ?
|
||||
LlmModel.OLLAMA_DOLPHIN: 1,
|
||||
LlmModel.GEMINI_FLASH_1_5: 1,
|
||||
LlmModel.GEMINI_FLASH_1_5_8B: 1,
|
||||
LlmModel.GROK_BETA: 5,
|
||||
LlmModel.MISTRAL_NEMO: 1,
|
||||
LlmModel.COHERE_COMMAND_R_08_2024: 1,
|
||||
@@ -266,5 +262,4 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
|
||||
},
|
||||
)
|
||||
],
|
||||
SmartDecisionMakerBlock: LLM_COST,
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ class BlockCostType(str, Enum):
|
||||
RUN = "run" # cost X credits per run
|
||||
BYTE = "byte" # cost X credits per byte
|
||||
SECOND = "second" # cost X credits per second
|
||||
DOLLAR = "dollar" # cost X dollars per run
|
||||
|
||||
|
||||
class BlockCost(BaseModel):
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,8 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
import zlib
|
||||
from contextlib import asynccontextmanager
|
||||
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
|
||||
from uuid import uuid4
|
||||
|
||||
from dotenv import load_dotenv
|
||||
@@ -16,36 +14,7 @@ load_dotenv()
|
||||
PRISMA_SCHEMA = os.getenv("PRISMA_SCHEMA", "schema.prisma")
|
||||
os.environ["PRISMA_SCHEMA_PATH"] = PRISMA_SCHEMA
|
||||
|
||||
|
||||
def add_param(url: str, key: str, value: str) -> str:
|
||||
p = urlparse(url)
|
||||
qs = dict(parse_qsl(p.query))
|
||||
qs[key] = value
|
||||
return urlunparse(p._replace(query=urlencode(qs)))
|
||||
|
||||
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
|
||||
|
||||
CONN_LIMIT = os.getenv("DB_CONNECTION_LIMIT")
|
||||
if CONN_LIMIT:
|
||||
DATABASE_URL = add_param(DATABASE_URL, "connection_limit", CONN_LIMIT)
|
||||
|
||||
CONN_TIMEOUT = os.getenv("DB_CONNECT_TIMEOUT")
|
||||
if CONN_TIMEOUT:
|
||||
DATABASE_URL = add_param(DATABASE_URL, "connect_timeout", CONN_TIMEOUT)
|
||||
|
||||
POOL_TIMEOUT = os.getenv("DB_POOL_TIMEOUT")
|
||||
if POOL_TIMEOUT:
|
||||
DATABASE_URL = add_param(DATABASE_URL, "pool_timeout", POOL_TIMEOUT)
|
||||
|
||||
HTTP_TIMEOUT = int(POOL_TIMEOUT) if POOL_TIMEOUT else None
|
||||
|
||||
prisma = Prisma(
|
||||
auto_register=True,
|
||||
http={"timeout": HTTP_TIMEOUT},
|
||||
datasource={"url": DATABASE_URL},
|
||||
)
|
||||
|
||||
prisma = Prisma(auto_register=True)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -85,14 +54,6 @@ async def transaction():
|
||||
yield tx
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def locked_transaction(key: str):
|
||||
lock_key = zlib.crc32(key.encode("utf-8"))
|
||||
async with transaction() as tx:
|
||||
await tx.execute_raw(f"SELECT pg_advisory_xact_lock({lock_key})")
|
||||
yield tx
|
||||
|
||||
|
||||
class BaseDbModel(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
|
||||
|
||||
@@ -1,199 +1,65 @@
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from multiprocessing import Manager
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Generator,
|
||||
Generic,
|
||||
Literal,
|
||||
Optional,
|
||||
TypeVar,
|
||||
overload,
|
||||
)
|
||||
from typing import Any, AsyncGenerator, Generator, Generic, TypeVar
|
||||
|
||||
from prisma import Json
|
||||
from prisma.enums import AgentExecutionStatus
|
||||
from prisma.models import (
|
||||
AgentGraphExecution,
|
||||
AgentNodeExecution,
|
||||
AgentNodeExecutionInputOutput,
|
||||
)
|
||||
from prisma.types import (
|
||||
AgentGraphExecutionWhereInput,
|
||||
AgentNodeExecutionUpdateInput,
|
||||
AgentNodeExecutionWhereInput,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import Field
|
||||
|
||||
from backend.server.v2.store.exceptions import DatabaseError
|
||||
from backend.util import mock
|
||||
from backend.util import type as type_utils
|
||||
from backend.data.block import BlockData, BlockInput, CompletedBlockOutput
|
||||
from backend.data.includes import EXECUTION_RESULT_INCLUDE, GRAPH_EXECUTION_INCLUDE
|
||||
from backend.data.queue import AsyncRedisEventBus, RedisEventBus
|
||||
from backend.util import json, mock
|
||||
from backend.util.settings import Config
|
||||
|
||||
from .block import BlockData, BlockInput, BlockType, CompletedBlockOutput, get_block
|
||||
from .db import BaseDbModel
|
||||
from .includes import (
|
||||
EXECUTION_RESULT_INCLUDE,
|
||||
GRAPH_EXECUTION_INCLUDE,
|
||||
GRAPH_EXECUTION_INCLUDE_WITH_NODES,
|
||||
)
|
||||
from .model import GraphExecutionStats, NodeExecutionStats
|
||||
from .queue import AsyncRedisEventBus, RedisEventBus
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = Config()
|
||||
class GraphExecutionEntry(BaseModel):
|
||||
user_id: str
|
||||
graph_exec_id: str
|
||||
graph_id: str
|
||||
start_node_execs: list["NodeExecutionEntry"]
|
||||
|
||||
|
||||
# -------------------------- Models -------------------------- #
|
||||
class NodeExecutionEntry(BaseModel):
|
||||
user_id: str
|
||||
graph_exec_id: str
|
||||
graph_id: str
|
||||
node_exec_id: str
|
||||
node_id: str
|
||||
data: BlockInput
|
||||
|
||||
|
||||
ExecutionStatus = AgentExecutionStatus
|
||||
|
||||
|
||||
class GraphExecutionMeta(BaseDbModel):
|
||||
user_id: str
|
||||
started_at: datetime
|
||||
ended_at: datetime
|
||||
cost: Optional[int] = Field(..., description="Execution cost in credits")
|
||||
duration: float = Field(..., description="Seconds from start to end of run")
|
||||
total_run_time: float = Field(..., description="Seconds of node runtime")
|
||||
status: ExecutionStatus
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
preset_id: Optional[str] = None
|
||||
|
||||
@staticmethod
|
||||
def from_db(_graph_exec: AgentGraphExecution):
|
||||
now = datetime.now(timezone.utc)
|
||||
start_time = _graph_exec.startedAt or _graph_exec.createdAt
|
||||
end_time = _graph_exec.updatedAt or now
|
||||
duration = (end_time - start_time).total_seconds()
|
||||
total_run_time = duration
|
||||
|
||||
try:
|
||||
stats = GraphExecutionStats.model_validate(_graph_exec.stats)
|
||||
except ValueError as e:
|
||||
if _graph_exec.stats is not None:
|
||||
logger.warning(
|
||||
"Failed to parse invalid graph execution stats "
|
||||
f"{_graph_exec.stats}: {e}"
|
||||
)
|
||||
stats = None
|
||||
|
||||
duration = stats.walltime if stats else duration
|
||||
total_run_time = stats.nodes_walltime if stats else total_run_time
|
||||
|
||||
return GraphExecutionMeta(
|
||||
id=_graph_exec.id,
|
||||
user_id=_graph_exec.userId,
|
||||
started_at=start_time,
|
||||
ended_at=end_time,
|
||||
cost=stats.cost if stats else None,
|
||||
duration=duration,
|
||||
total_run_time=total_run_time,
|
||||
status=ExecutionStatus(_graph_exec.executionStatus),
|
||||
graph_id=_graph_exec.agentGraphId,
|
||||
graph_version=_graph_exec.agentGraphVersion,
|
||||
preset_id=_graph_exec.agentPresetId,
|
||||
)
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class GraphExecution(GraphExecutionMeta):
|
||||
inputs: BlockInput
|
||||
outputs: CompletedBlockOutput
|
||||
class ExecutionQueue(Generic[T]):
|
||||
"""
|
||||
Queue for managing the execution of agents.
|
||||
This will be shared between different processes
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def from_db(_graph_exec: AgentGraphExecution):
|
||||
if _graph_exec.AgentNodeExecutions is None:
|
||||
raise ValueError("Node executions must be included in query")
|
||||
def __init__(self):
|
||||
self.queue = Manager().Queue()
|
||||
|
||||
graph_exec = GraphExecutionMeta.from_db(_graph_exec)
|
||||
def add(self, execution: T) -> T:
|
||||
self.queue.put(execution)
|
||||
return execution
|
||||
|
||||
node_executions = sorted(
|
||||
[
|
||||
NodeExecutionResult.from_db(ne, _graph_exec.userId)
|
||||
for ne in _graph_exec.AgentNodeExecutions
|
||||
],
|
||||
key=lambda ne: (ne.queue_time is None, ne.queue_time or ne.add_time),
|
||||
)
|
||||
def get(self) -> T:
|
||||
return self.queue.get()
|
||||
|
||||
inputs = {
|
||||
**{
|
||||
# inputs from Agent Input Blocks
|
||||
exec.input_data["name"]: exec.input_data.get("value")
|
||||
for exec in node_executions
|
||||
if (
|
||||
(block := get_block(exec.block_id))
|
||||
and block.block_type == BlockType.INPUT
|
||||
)
|
||||
},
|
||||
**{
|
||||
# input from webhook-triggered block
|
||||
"payload": exec.input_data["payload"]
|
||||
for exec in node_executions
|
||||
if (
|
||||
(block := get_block(exec.block_id))
|
||||
and block.block_type
|
||||
in [BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL]
|
||||
)
|
||||
},
|
||||
}
|
||||
|
||||
outputs: CompletedBlockOutput = defaultdict(list)
|
||||
for exec in node_executions:
|
||||
if (
|
||||
block := get_block(exec.block_id)
|
||||
) and block.block_type == BlockType.OUTPUT:
|
||||
outputs[exec.input_data["name"]].append(
|
||||
exec.input_data.get("value", None)
|
||||
)
|
||||
|
||||
return GraphExecution(
|
||||
**{
|
||||
field_name: getattr(graph_exec, field_name)
|
||||
for field_name in graph_exec.model_fields
|
||||
},
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
)
|
||||
def empty(self) -> bool:
|
||||
return self.queue.empty()
|
||||
|
||||
|
||||
class GraphExecutionWithNodes(GraphExecution):
|
||||
node_executions: list["NodeExecutionResult"]
|
||||
|
||||
@staticmethod
|
||||
def from_db(_graph_exec: AgentGraphExecution):
|
||||
if _graph_exec.AgentNodeExecutions is None:
|
||||
raise ValueError("Node executions must be included in query")
|
||||
|
||||
graph_exec_with_io = GraphExecution.from_db(_graph_exec)
|
||||
|
||||
node_executions = sorted(
|
||||
[
|
||||
NodeExecutionResult.from_db(ne, _graph_exec.userId)
|
||||
for ne in _graph_exec.AgentNodeExecutions
|
||||
],
|
||||
key=lambda ne: (ne.queue_time is None, ne.queue_time or ne.add_time),
|
||||
)
|
||||
|
||||
return GraphExecutionWithNodes(
|
||||
**{
|
||||
field_name: getattr(graph_exec_with_io, field_name)
|
||||
for field_name in graph_exec_with_io.model_fields
|
||||
},
|
||||
node_executions=node_executions,
|
||||
)
|
||||
|
||||
|
||||
class NodeExecutionResult(BaseModel):
|
||||
user_id: str
|
||||
class ExecutionResult(BaseModel):
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
graph_exec_id: str
|
||||
@@ -209,30 +75,43 @@ class NodeExecutionResult(BaseModel):
|
||||
end_time: datetime | None
|
||||
|
||||
@staticmethod
|
||||
def from_db(execution: AgentNodeExecution, user_id: Optional[str] = None):
|
||||
def from_graph(graph: AgentGraphExecution):
|
||||
return ExecutionResult(
|
||||
graph_id=graph.agentGraphId,
|
||||
graph_version=graph.agentGraphVersion,
|
||||
graph_exec_id=graph.id,
|
||||
node_exec_id="",
|
||||
node_id="",
|
||||
block_id="",
|
||||
status=graph.executionStatus,
|
||||
# TODO: Populate input_data & output_data from AgentNodeExecutions
|
||||
# Input & Output comes AgentInputBlock & AgentOutputBlock.
|
||||
input_data={},
|
||||
output_data={},
|
||||
add_time=graph.createdAt,
|
||||
queue_time=graph.createdAt,
|
||||
start_time=graph.startedAt,
|
||||
end_time=graph.updatedAt,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_db(execution: AgentNodeExecution):
|
||||
if execution.executionData:
|
||||
# Execution that has been queued for execution will persist its data.
|
||||
input_data = type_utils.convert(execution.executionData, dict[str, Any])
|
||||
input_data = json.loads(execution.executionData, target_type=dict[str, Any])
|
||||
else:
|
||||
# For incomplete execution, executionData will not be yet available.
|
||||
input_data: BlockInput = defaultdict()
|
||||
for data in execution.Input or []:
|
||||
input_data[data.name] = type_utils.convert(data.data, type[Any])
|
||||
input_data[data.name] = json.loads(data.data)
|
||||
|
||||
output_data: CompletedBlockOutput = defaultdict(list)
|
||||
for data in execution.Output or []:
|
||||
output_data[data.name].append(type_utils.convert(data.data, type[Any]))
|
||||
output_data[data.name].append(json.loads(data.data))
|
||||
|
||||
graph_execution: AgentGraphExecution | None = execution.AgentGraphExecution
|
||||
if graph_execution:
|
||||
user_id = graph_execution.userId
|
||||
elif not user_id:
|
||||
raise ValueError(
|
||||
"AgentGraphExecution must be included or user_id passed in"
|
||||
)
|
||||
|
||||
return NodeExecutionResult(
|
||||
user_id=user_id,
|
||||
return ExecutionResult(
|
||||
graph_id=graph_execution.agentGraphId if graph_execution else "",
|
||||
graph_version=graph_execution.agentGraphVersion if graph_execution else 0,
|
||||
graph_exec_id=execution.agentGraphExecutionId,
|
||||
@@ -252,88 +131,12 @@ class NodeExecutionResult(BaseModel):
|
||||
# --------------------- Model functions --------------------- #
|
||||
|
||||
|
||||
async def get_graph_executions(
|
||||
graph_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> list[GraphExecutionMeta]:
|
||||
where_filter: AgentGraphExecutionWhereInput = {
|
||||
"isDeleted": False,
|
||||
}
|
||||
if user_id:
|
||||
where_filter["userId"] = user_id
|
||||
if graph_id:
|
||||
where_filter["agentGraphId"] = graph_id
|
||||
|
||||
executions = await AgentGraphExecution.prisma().find_many(
|
||||
where=where_filter,
|
||||
order={"createdAt": "desc"},
|
||||
)
|
||||
return [GraphExecutionMeta.from_db(execution) for execution in executions]
|
||||
|
||||
|
||||
async def get_graph_execution_meta(
|
||||
user_id: str, execution_id: str
|
||||
) -> GraphExecutionMeta | None:
|
||||
execution = await AgentGraphExecution.prisma().find_first(
|
||||
where={"id": execution_id, "isDeleted": False, "userId": user_id}
|
||||
)
|
||||
return GraphExecutionMeta.from_db(execution) if execution else None
|
||||
|
||||
|
||||
@overload
|
||||
async def get_graph_execution(
|
||||
user_id: str,
|
||||
execution_id: str,
|
||||
include_node_executions: Literal[True],
|
||||
) -> GraphExecutionWithNodes | None: ...
|
||||
|
||||
|
||||
@overload
|
||||
async def get_graph_execution(
|
||||
user_id: str,
|
||||
execution_id: str,
|
||||
include_node_executions: Literal[False] = False,
|
||||
) -> GraphExecution | None: ...
|
||||
|
||||
|
||||
@overload
|
||||
async def get_graph_execution(
|
||||
user_id: str,
|
||||
execution_id: str,
|
||||
include_node_executions: bool = False,
|
||||
) -> GraphExecution | GraphExecutionWithNodes | None: ...
|
||||
|
||||
|
||||
async def get_graph_execution(
|
||||
user_id: str,
|
||||
execution_id: str,
|
||||
include_node_executions: bool = False,
|
||||
) -> GraphExecution | GraphExecutionWithNodes | None:
|
||||
execution = await AgentGraphExecution.prisma().find_first(
|
||||
where={"id": execution_id, "isDeleted": False, "userId": user_id},
|
||||
include=(
|
||||
GRAPH_EXECUTION_INCLUDE_WITH_NODES
|
||||
if include_node_executions
|
||||
else GRAPH_EXECUTION_INCLUDE
|
||||
),
|
||||
)
|
||||
if not execution:
|
||||
return None
|
||||
|
||||
return (
|
||||
GraphExecutionWithNodes.from_db(execution)
|
||||
if include_node_executions
|
||||
else GraphExecution.from_db(execution)
|
||||
)
|
||||
|
||||
|
||||
async def create_graph_execution(
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
nodes_input: list[tuple[str, BlockInput]],
|
||||
user_id: str,
|
||||
preset_id: str | None = None,
|
||||
) -> GraphExecutionWithNodes:
|
||||
) -> tuple[str, list[ExecutionResult]]:
|
||||
"""
|
||||
Create a new AgentGraphExecution record.
|
||||
Returns:
|
||||
@@ -348,11 +151,10 @@ async def create_graph_execution(
|
||||
"create": [ # type: ignore
|
||||
{
|
||||
"agentNodeId": node_id,
|
||||
"executionStatus": ExecutionStatus.QUEUED,
|
||||
"queuedTime": datetime.now(tz=timezone.utc),
|
||||
"executionStatus": ExecutionStatus.INCOMPLETE,
|
||||
"Input": {
|
||||
"create": [
|
||||
{"name": name, "data": Json(data)}
|
||||
{"name": name, "data": json.dumps(data)}
|
||||
for name, data in node_input.items()
|
||||
]
|
||||
},
|
||||
@@ -361,12 +163,14 @@ async def create_graph_execution(
|
||||
]
|
||||
},
|
||||
"userId": user_id,
|
||||
"agentPresetId": preset_id,
|
||||
},
|
||||
include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
|
||||
include=GRAPH_EXECUTION_INCLUDE,
|
||||
)
|
||||
|
||||
return GraphExecutionWithNodes.from_db(result)
|
||||
return result.id, [
|
||||
ExecutionResult.from_db(execution)
|
||||
for execution in result.AgentNodeExecutions or []
|
||||
]
|
||||
|
||||
|
||||
async def upsert_execution_input(
|
||||
@@ -388,24 +192,21 @@ async def upsert_execution_input(
|
||||
node_exec_id: [Optional] The id of the AgentNodeExecution that has no `input_name` as input. If not provided, it will find the eligible incomplete AgentNodeExecution or create a new one.
|
||||
|
||||
Returns:
|
||||
str: The id of the created or existing AgentNodeExecution.
|
||||
dict[str, Any]: Node input data; key is the input name, value is the input data.
|
||||
* The id of the created or existing AgentNodeExecution.
|
||||
* Dict of node input data, key is the input name, value is the input data.
|
||||
"""
|
||||
existing_exec_query_filter: AgentNodeExecutionWhereInput = {
|
||||
"agentNodeId": node_id,
|
||||
"agentGraphExecutionId": graph_exec_id,
|
||||
"executionStatus": ExecutionStatus.INCOMPLETE,
|
||||
"Input": {"every": {"name": {"not": input_name}}},
|
||||
}
|
||||
if node_exec_id:
|
||||
existing_exec_query_filter["id"] = node_exec_id
|
||||
|
||||
existing_execution = await AgentNodeExecution.prisma().find_first(
|
||||
where=existing_exec_query_filter,
|
||||
where={ # type: ignore
|
||||
**({"id": node_exec_id} if node_exec_id else {}),
|
||||
"agentNodeId": node_id,
|
||||
"agentGraphExecutionId": graph_exec_id,
|
||||
"executionStatus": ExecutionStatus.INCOMPLETE,
|
||||
"Input": {"every": {"name": {"not": input_name}}},
|
||||
},
|
||||
order={"addedTime": "asc"},
|
||||
include={"Input": True},
|
||||
)
|
||||
json_input_data = Json(input_data)
|
||||
json_input_data = json.dumps(input_data)
|
||||
|
||||
if existing_execution:
|
||||
await AgentNodeExecutionInputOutput.prisma().create(
|
||||
@@ -417,7 +218,7 @@ async def upsert_execution_input(
|
||||
)
|
||||
return existing_execution.id, {
|
||||
**{
|
||||
input_data.name: type_utils.convert(input_data.data, type[Any])
|
||||
input_data.name: json.loads(input_data.data)
|
||||
for input_data in existing_execution.Input or []
|
||||
},
|
||||
input_name: input_data,
|
||||
@@ -451,255 +252,89 @@ async def upsert_execution_output(
|
||||
await AgentNodeExecutionInputOutput.prisma().create(
|
||||
data={
|
||||
"name": output_name,
|
||||
"data": Json(output_data),
|
||||
"data": json.dumps(output_data),
|
||||
"referencedByOutputExecId": node_exec_id,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def update_graph_execution_start_time(graph_exec_id: str) -> GraphExecution:
|
||||
res = await AgentGraphExecution.prisma().update(
|
||||
async def update_graph_execution_start_time(graph_exec_id: str):
|
||||
await AgentGraphExecution.prisma().update(
|
||||
where={"id": graph_exec_id},
|
||||
data={
|
||||
"executionStatus": ExecutionStatus.RUNNING,
|
||||
"startedAt": datetime.now(tz=timezone.utc),
|
||||
},
|
||||
include=GRAPH_EXECUTION_INCLUDE,
|
||||
)
|
||||
if not res:
|
||||
raise ValueError(f"Graph execution #{graph_exec_id} not found")
|
||||
|
||||
return GraphExecution.from_db(res)
|
||||
|
||||
|
||||
async def update_graph_execution_stats(
|
||||
graph_exec_id: str,
|
||||
status: ExecutionStatus,
|
||||
stats: GraphExecutionStats | None = None,
|
||||
) -> GraphExecution | None:
|
||||
data = stats.model_dump() if stats else {}
|
||||
if isinstance(data.get("error"), Exception):
|
||||
data["error"] = str(data["error"])
|
||||
stats: dict[str, Any],
|
||||
) -> ExecutionResult:
|
||||
res = await AgentGraphExecution.prisma().update(
|
||||
where={
|
||||
"id": graph_exec_id,
|
||||
"OR": [
|
||||
{"executionStatus": ExecutionStatus.RUNNING},
|
||||
{"executionStatus": ExecutionStatus.QUEUED},
|
||||
],
|
||||
},
|
||||
where={"id": graph_exec_id},
|
||||
data={
|
||||
"executionStatus": status,
|
||||
"stats": Json(data),
|
||||
"stats": json.dumps(stats),
|
||||
},
|
||||
include=GRAPH_EXECUTION_INCLUDE,
|
||||
)
|
||||
if not res:
|
||||
raise ValueError(f"Execution {graph_exec_id} not found.")
|
||||
|
||||
return GraphExecution.from_db(res) if res else None
|
||||
return ExecutionResult.from_graph(res)
|
||||
|
||||
|
||||
async def update_node_execution_stats(node_exec_id: str, stats: NodeExecutionStats):
|
||||
data = stats.model_dump()
|
||||
if isinstance(data["error"], Exception):
|
||||
data["error"] = str(data["error"])
|
||||
async def update_node_execution_stats(node_exec_id: str, stats: dict[str, Any]):
|
||||
await AgentNodeExecution.prisma().update(
|
||||
where={"id": node_exec_id},
|
||||
data={"stats": Json(data)},
|
||||
data={"stats": json.dumps(stats)},
|
||||
)
|
||||
|
||||
|
||||
async def update_node_execution_status_batch(
|
||||
node_exec_ids: list[str],
|
||||
status: ExecutionStatus,
|
||||
stats: dict[str, Any] | None = None,
|
||||
):
|
||||
await AgentNodeExecution.prisma().update_many(
|
||||
where={"id": {"in": node_exec_ids}},
|
||||
data=_get_update_status_data(status, None, stats),
|
||||
)
|
||||
|
||||
|
||||
async def update_node_execution_status(
|
||||
async def update_execution_status(
|
||||
node_exec_id: str,
|
||||
status: ExecutionStatus,
|
||||
execution_data: BlockInput | None = None,
|
||||
stats: dict[str, Any] | None = None,
|
||||
) -> NodeExecutionResult:
|
||||
) -> ExecutionResult:
|
||||
if status == ExecutionStatus.QUEUED and execution_data is None:
|
||||
raise ValueError("Execution data must be provided when queuing an execution.")
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
data = {
|
||||
**({"executionStatus": status}),
|
||||
**({"queuedTime": now} if status == ExecutionStatus.QUEUED else {}),
|
||||
**({"startedTime": now} if status == ExecutionStatus.RUNNING else {}),
|
||||
**({"endedTime": now} if status == ExecutionStatus.FAILED else {}),
|
||||
**({"endedTime": now} if status == ExecutionStatus.COMPLETED else {}),
|
||||
**({"executionData": json.dumps(execution_data)} if execution_data else {}),
|
||||
**({"stats": json.dumps(stats)} if stats else {}),
|
||||
}
|
||||
|
||||
res = await AgentNodeExecution.prisma().update(
|
||||
where={"id": node_exec_id},
|
||||
data=_get_update_status_data(status, execution_data, stats),
|
||||
data=data, # type: ignore
|
||||
include=EXECUTION_RESULT_INCLUDE,
|
||||
)
|
||||
if not res:
|
||||
raise ValueError(f"Execution {node_exec_id} not found.")
|
||||
|
||||
return NodeExecutionResult.from_db(res)
|
||||
return ExecutionResult.from_db(res)
|
||||
|
||||
|
||||
def _get_update_status_data(
|
||||
status: ExecutionStatus,
|
||||
execution_data: BlockInput | None = None,
|
||||
stats: dict[str, Any] | None = None,
|
||||
) -> AgentNodeExecutionUpdateInput:
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
update_data: AgentNodeExecutionUpdateInput = {"executionStatus": status}
|
||||
|
||||
if status == ExecutionStatus.QUEUED:
|
||||
update_data["queuedTime"] = now
|
||||
elif status == ExecutionStatus.RUNNING:
|
||||
update_data["startedTime"] = now
|
||||
elif status in (ExecutionStatus.FAILED, ExecutionStatus.COMPLETED):
|
||||
update_data["endedTime"] = now
|
||||
|
||||
if execution_data:
|
||||
update_data["executionData"] = Json(execution_data)
|
||||
if stats:
|
||||
update_data["stats"] = Json(stats)
|
||||
|
||||
return update_data
|
||||
|
||||
|
||||
async def delete_graph_execution(
|
||||
graph_exec_id: str, user_id: str, soft_delete: bool = True
|
||||
) -> None:
|
||||
if soft_delete:
|
||||
deleted_count = await AgentGraphExecution.prisma().update_many(
|
||||
where={"id": graph_exec_id, "userId": user_id}, data={"isDeleted": True}
|
||||
)
|
||||
else:
|
||||
deleted_count = await AgentGraphExecution.prisma().delete_many(
|
||||
where={"id": graph_exec_id, "userId": user_id}
|
||||
)
|
||||
if deleted_count < 1:
|
||||
raise DatabaseError(
|
||||
f"Could not delete graph execution #{graph_exec_id}: not found"
|
||||
)
|
||||
|
||||
|
||||
async def get_node_execution_results(
|
||||
graph_exec_id: str,
|
||||
block_ids: list[str] | None = None,
|
||||
statuses: list[ExecutionStatus] | None = None,
|
||||
limit: int | None = None,
|
||||
) -> list[NodeExecutionResult]:
|
||||
where_clause: AgentNodeExecutionWhereInput = {
|
||||
"agentGraphExecutionId": graph_exec_id,
|
||||
}
|
||||
if block_ids:
|
||||
where_clause["AgentNode"] = {"is": {"agentBlockId": {"in": block_ids}}}
|
||||
if statuses:
|
||||
where_clause["OR"] = [{"executionStatus": status} for status in statuses]
|
||||
|
||||
async def get_execution_results(graph_exec_id: str) -> list[ExecutionResult]:
|
||||
executions = await AgentNodeExecution.prisma().find_many(
|
||||
where=where_clause,
|
||||
where={"agentGraphExecutionId": graph_exec_id},
|
||||
include=EXECUTION_RESULT_INCLUDE,
|
||||
take=limit,
|
||||
)
|
||||
res = [NodeExecutionResult.from_db(execution) for execution in executions]
|
||||
return res
|
||||
|
||||
|
||||
async def get_graph_executions_in_timerange(
|
||||
user_id: str, start_time: str, end_time: str
|
||||
) -> list[GraphExecution]:
|
||||
try:
|
||||
executions = await AgentGraphExecution.prisma().find_many(
|
||||
where={
|
||||
"startedAt": {
|
||||
"gte": datetime.fromisoformat(start_time),
|
||||
"lte": datetime.fromisoformat(end_time),
|
||||
},
|
||||
"userId": user_id,
|
||||
"isDeleted": False,
|
||||
},
|
||||
include=GRAPH_EXECUTION_INCLUDE,
|
||||
)
|
||||
return [GraphExecution.from_db(execution) for execution in executions]
|
||||
except Exception as e:
|
||||
raise DatabaseError(
|
||||
f"Failed to get executions in timerange {start_time} to {end_time} for user {user_id}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
async def get_latest_node_execution(
|
||||
node_id: str, graph_eid: str
|
||||
) -> NodeExecutionResult | None:
|
||||
execution = await AgentNodeExecution.prisma().find_first(
|
||||
where={
|
||||
"agentNodeId": node_id,
|
||||
"agentGraphExecutionId": graph_eid,
|
||||
"executionStatus": {"not": ExecutionStatus.INCOMPLETE}, # type: ignore
|
||||
},
|
||||
order=[
|
||||
{"queuedTime": "desc"},
|
||||
{"addedTime": "desc"},
|
||||
{"queuedTime": "asc"},
|
||||
{"addedTime": "asc"}, # Fallback: Incomplete execs has no queuedTime.
|
||||
],
|
||||
include=EXECUTION_RESULT_INCLUDE,
|
||||
)
|
||||
if not execution:
|
||||
return None
|
||||
return NodeExecutionResult.from_db(execution)
|
||||
|
||||
|
||||
async def get_incomplete_node_executions(
|
||||
node_id: str, graph_eid: str
|
||||
) -> list[NodeExecutionResult]:
|
||||
executions = await AgentNodeExecution.prisma().find_many(
|
||||
where={
|
||||
"agentNodeId": node_id,
|
||||
"agentGraphExecutionId": graph_eid,
|
||||
"executionStatus": ExecutionStatus.INCOMPLETE,
|
||||
},
|
||||
include=EXECUTION_RESULT_INCLUDE,
|
||||
)
|
||||
return [NodeExecutionResult.from_db(execution) for execution in executions]
|
||||
|
||||
|
||||
# ----------------- Execution Infrastructure ----------------- #
|
||||
|
||||
|
||||
class GraphExecutionEntry(BaseModel):
|
||||
user_id: str
|
||||
graph_exec_id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
start_node_execs: list["NodeExecutionEntry"]
|
||||
|
||||
|
||||
class NodeExecutionEntry(BaseModel):
|
||||
user_id: str
|
||||
graph_exec_id: str
|
||||
graph_id: str
|
||||
node_exec_id: str
|
||||
node_id: str
|
||||
block_id: str
|
||||
data: BlockInput
|
||||
|
||||
|
||||
class ExecutionQueue(Generic[T]):
|
||||
"""
|
||||
Queue for managing the execution of agents.
|
||||
This will be shared between different processes
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.queue = Manager().Queue()
|
||||
|
||||
def add(self, execution: T) -> T:
|
||||
self.queue.put(execution)
|
||||
return execution
|
||||
|
||||
def get(self) -> T:
|
||||
return self.queue.get()
|
||||
|
||||
def empty(self) -> bool:
|
||||
return self.queue.empty()
|
||||
|
||||
|
||||
# ------------------- Execution Utilities -------------------- #
|
||||
res = [ExecutionResult.from_db(execution) for execution in executions]
|
||||
return res
|
||||
|
||||
|
||||
LIST_SPLIT = "_$_"
|
||||
@@ -708,38 +343,7 @@ OBJC_SPLIT = "_@_"
|
||||
|
||||
|
||||
def parse_execution_output(output: BlockData, name: str) -> Any | None:
|
||||
"""
|
||||
Extracts partial output data by name from a given BlockData.
|
||||
|
||||
The function supports extracting data from lists, dictionaries, and objects
|
||||
using specific naming conventions:
|
||||
- For lists: <output_name>_$_<index>
|
||||
- For dictionaries: <output_name>_#_<key>
|
||||
- For objects: <output_name>_@_<attribute>
|
||||
|
||||
Args:
|
||||
output (BlockData): A tuple containing the output name and data.
|
||||
name (str): The name used to extract specific data from the output.
|
||||
|
||||
Returns:
|
||||
Any | None: The extracted data if found, otherwise None.
|
||||
|
||||
Examples:
|
||||
>>> output = ("result", [10, 20, 30])
|
||||
>>> parse_execution_output(output, "result_$_1")
|
||||
20
|
||||
|
||||
>>> output = ("config", {"key1": "value1", "key2": "value2"})
|
||||
>>> parse_execution_output(output, "config_#_key1")
|
||||
'value1'
|
||||
|
||||
>>> class Sample:
|
||||
... attr1 = "value1"
|
||||
... attr2 = "value2"
|
||||
>>> output = ("object", Sample())
|
||||
>>> parse_execution_output(output, "object_@_attr1")
|
||||
'value1'
|
||||
"""
|
||||
# Allow extracting partial output data by name.
|
||||
output_name, output_data = output
|
||||
|
||||
if name == output_name:
|
||||
@@ -768,37 +372,11 @@ def parse_execution_output(output: BlockData, name: str) -> Any | None:
|
||||
|
||||
def merge_execution_input(data: BlockInput) -> BlockInput:
|
||||
"""
|
||||
Merges dynamic input pins into a single list, dictionary, or object based on naming patterns.
|
||||
|
||||
This function processes input keys that follow specific patterns to merge them into a unified structure:
|
||||
- `<input_name>_$_<index>` for list inputs.
|
||||
- `<input_name>_#_<index>` for dictionary inputs.
|
||||
- `<input_name>_@_<index>` for object inputs.
|
||||
|
||||
Args:
|
||||
data (BlockInput): A dictionary containing input keys and their corresponding values.
|
||||
|
||||
Returns:
|
||||
BlockInput: A dictionary with merged inputs.
|
||||
|
||||
Raises:
|
||||
ValueError: If a list index is not an integer.
|
||||
|
||||
Examples:
|
||||
>>> data = {
|
||||
... "list_$_0": "a",
|
||||
... "list_$_1": "b",
|
||||
... "dict_#_key1": "value1",
|
||||
... "dict_#_key2": "value2",
|
||||
... "object_@_attr1": "value1",
|
||||
... "object_@_attr2": "value2"
|
||||
... }
|
||||
>>> merge_execution_input(data)
|
||||
{
|
||||
"list": ["a", "b"],
|
||||
"dict": {"key1": "value1", "key2": "value2"},
|
||||
"object": <MockObject attr1="value1" attr2="value2">
|
||||
}
|
||||
Merge all dynamic input pins which described by the following pattern:
|
||||
- <input_name>_$_<index> for list input.
|
||||
- <input_name>_#_<index> for dict input.
|
||||
- <input_name>_@_<index> for object input.
|
||||
This function will construct pins with the same name into a single list/dict/object.
|
||||
"""
|
||||
|
||||
# Merge all input with <input_name>_$_<index> into a single list.
|
||||
@@ -837,84 +415,70 @@ def merge_execution_input(data: BlockInput) -> BlockInput:
|
||||
return data
|
||||
|
||||
|
||||
async def get_latest_execution(node_id: str, graph_eid: str) -> ExecutionResult | None:
|
||||
execution = await AgentNodeExecution.prisma().find_first(
|
||||
where={
|
||||
"agentNodeId": node_id,
|
||||
"agentGraphExecutionId": graph_eid,
|
||||
"executionStatus": {"not": ExecutionStatus.INCOMPLETE},
|
||||
"executionData": {"not": None}, # type: ignore
|
||||
},
|
||||
order={"queuedTime": "desc"},
|
||||
include=EXECUTION_RESULT_INCLUDE,
|
||||
)
|
||||
if not execution:
|
||||
return None
|
||||
return ExecutionResult.from_db(execution)
|
||||
|
||||
|
||||
async def get_incomplete_executions(
|
||||
node_id: str, graph_eid: str
|
||||
) -> list[ExecutionResult]:
|
||||
executions = await AgentNodeExecution.prisma().find_many(
|
||||
where={
|
||||
"agentNodeId": node_id,
|
||||
"agentGraphExecutionId": graph_eid,
|
||||
"executionStatus": ExecutionStatus.INCOMPLETE,
|
||||
},
|
||||
include=EXECUTION_RESULT_INCLUDE,
|
||||
)
|
||||
return [ExecutionResult.from_db(execution) for execution in executions]
|
||||
|
||||
|
||||
# --------------------- Event Bus --------------------- #
|
||||
|
||||
|
||||
class ExecutionEventType(str, Enum):
|
||||
GRAPH_EXEC_UPDATE = "graph_execution_update"
|
||||
NODE_EXEC_UPDATE = "node_execution_update"
|
||||
config = Config()
|
||||
|
||||
|
||||
class GraphExecutionEvent(GraphExecution):
|
||||
event_type: Literal[ExecutionEventType.GRAPH_EXEC_UPDATE] = (
|
||||
ExecutionEventType.GRAPH_EXEC_UPDATE
|
||||
)
|
||||
|
||||
|
||||
class NodeExecutionEvent(NodeExecutionResult):
|
||||
event_type: Literal[ExecutionEventType.NODE_EXEC_UPDATE] = (
|
||||
ExecutionEventType.NODE_EXEC_UPDATE
|
||||
)
|
||||
|
||||
|
||||
ExecutionEvent = Annotated[
|
||||
GraphExecutionEvent | NodeExecutionEvent, Field(discriminator="event_type")
|
||||
]
|
||||
|
||||
|
||||
class RedisExecutionEventBus(RedisEventBus[ExecutionEvent]):
|
||||
Model = ExecutionEvent # type: ignore
|
||||
class RedisExecutionEventBus(RedisEventBus[ExecutionResult]):
|
||||
Model = ExecutionResult
|
||||
|
||||
@property
|
||||
def event_bus_name(self) -> str:
|
||||
return config.execution_event_bus_name
|
||||
|
||||
def publish(self, res: GraphExecution | NodeExecutionResult):
|
||||
if isinstance(res, GraphExecution):
|
||||
self.publish_graph_exec_update(res)
|
||||
else:
|
||||
self.publish_node_exec_update(res)
|
||||
|
||||
def publish_node_exec_update(self, res: NodeExecutionResult):
|
||||
event = NodeExecutionEvent.model_validate(res.model_dump())
|
||||
self.publish_event(event, f"{res.user_id}/{res.graph_id}/{res.graph_exec_id}")
|
||||
|
||||
def publish_graph_exec_update(self, res: GraphExecution):
|
||||
event = GraphExecutionEvent.model_validate(res.model_dump())
|
||||
self.publish_event(event, f"{res.user_id}/{res.graph_id}/{res.id}")
|
||||
def publish(self, res: ExecutionResult):
|
||||
self.publish_event(res, f"{res.graph_id}/{res.graph_exec_id}")
|
||||
|
||||
def listen(
|
||||
self, user_id: str, graph_id: str = "*", graph_exec_id: str = "*"
|
||||
) -> Generator[ExecutionEvent, None, None]:
|
||||
for event in self.listen_events(f"{user_id}/{graph_id}/{graph_exec_id}"):
|
||||
yield event
|
||||
self, graph_id: str = "*", graph_exec_id: str = "*"
|
||||
) -> Generator[ExecutionResult, None, None]:
|
||||
for execution_result in self.listen_events(f"{graph_id}/{graph_exec_id}"):
|
||||
yield execution_result
|
||||
|
||||
|
||||
class AsyncRedisExecutionEventBus(AsyncRedisEventBus[ExecutionEvent]):
|
||||
Model = ExecutionEvent # type: ignore
|
||||
class AsyncRedisExecutionEventBus(AsyncRedisEventBus[ExecutionResult]):
|
||||
Model = ExecutionResult
|
||||
|
||||
@property
|
||||
def event_bus_name(self) -> str:
|
||||
return config.execution_event_bus_name
|
||||
|
||||
async def publish(self, res: GraphExecutionMeta | NodeExecutionResult):
|
||||
if isinstance(res, GraphExecutionMeta):
|
||||
await self.publish_graph_exec_update(res)
|
||||
else:
|
||||
await self.publish_node_exec_update(res)
|
||||
|
||||
async def publish_node_exec_update(self, res: NodeExecutionResult):
|
||||
event = NodeExecutionEvent.model_validate(res.model_dump())
|
||||
await self.publish_event(
|
||||
event, f"{res.user_id}/{res.graph_id}/{res.graph_exec_id}"
|
||||
)
|
||||
|
||||
async def publish_graph_exec_update(self, res: GraphExecutionMeta):
|
||||
event = GraphExecutionEvent.model_validate(res.model_dump())
|
||||
await self.publish_event(event, f"{res.user_id}/{res.graph_id}/{res.id}")
|
||||
async def publish(self, res: ExecutionResult):
|
||||
await self.publish_event(res, f"{res.graph_id}/{res.graph_exec_id}")
|
||||
|
||||
async def listen(
|
||||
self, user_id: str, graph_id: str = "*", graph_exec_id: str = "*"
|
||||
) -> AsyncGenerator[ExecutionEvent, None]:
|
||||
async for event in self.listen_events(f"{user_id}/{graph_id}/{graph_exec_id}"):
|
||||
yield event
|
||||
self, graph_id: str = "*", graph_exec_id: str = "*"
|
||||
) -> AsyncGenerator[ExecutionResult, None]:
|
||||
async for execution_result in self.listen_events(f"{graph_id}/{graph_exec_id}"):
|
||||
yield execution_result
|
||||
|
||||
@@ -1,23 +1,22 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal, Optional, Type
|
||||
|
||||
import prisma
|
||||
from prisma import Json
|
||||
from prisma.enums import SubmissionStatus
|
||||
from prisma.models import AgentGraph, AgentNode, AgentNodeLink, StoreListingVersion
|
||||
from prisma.models import AgentGraph, AgentGraphExecution, AgentNode, AgentNodeLink
|
||||
from prisma.types import AgentGraphWhereInput
|
||||
from pydantic.fields import computed_field
|
||||
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.db import prisma as db
|
||||
from backend.util import type as type_utils
|
||||
from backend.blocks.basic import AgentInputBlock, AgentOutputBlock
|
||||
from backend.util import json
|
||||
|
||||
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
|
||||
from .block import BlockInput, BlockType, get_block, get_blocks
|
||||
from .db import BaseDbModel, transaction
|
||||
from .execution import ExecutionStatus
|
||||
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE
|
||||
from .integrations import Webhook
|
||||
|
||||
@@ -62,20 +61,15 @@ class NodeModel(Node):
|
||||
|
||||
webhook: Optional[Webhook] = None
|
||||
|
||||
@property
|
||||
def block(self) -> Block[BlockSchema, BlockSchema]:
|
||||
block = get_block(self.block_id)
|
||||
if not block:
|
||||
raise ValueError(f"Block #{self.block_id} does not exist")
|
||||
return block
|
||||
|
||||
@staticmethod
|
||||
def from_db(node: AgentNode, for_export: bool = False) -> "NodeModel":
|
||||
def from_db(node: AgentNode):
|
||||
if not node.AgentBlock:
|
||||
raise ValueError(f"Invalid node {node.id}, invalid AgentBlock.")
|
||||
obj = NodeModel(
|
||||
id=node.id,
|
||||
block_id=node.agentBlockId,
|
||||
input_default=type_utils.convert(node.constantInput, dict[str, Any]),
|
||||
metadata=type_utils.convert(node.metadata, dict[str, Any]),
|
||||
block_id=node.AgentBlock.id,
|
||||
input_default=json.loads(node.constantInput, target_type=dict[str, Any]),
|
||||
metadata=json.loads(node.metadata, target_type=dict[str, Any]),
|
||||
graph_id=node.agentGraphId,
|
||||
graph_version=node.agentGraphVersion,
|
||||
webhook_id=node.webhookId,
|
||||
@@ -83,8 +77,6 @@ class NodeModel(Node):
|
||||
)
|
||||
obj.input_links = [Link.from_db(link) for link in node.Input or []]
|
||||
obj.output_links = [Link.from_db(link) for link in node.Output or []]
|
||||
if for_export:
|
||||
return obj.stripped_for_export()
|
||||
return obj
|
||||
|
||||
def is_triggered_by_event_type(self, event_type: str) -> bool:
|
||||
@@ -103,59 +95,54 @@ class NodeModel(Node):
|
||||
if event_filter[k] is True
|
||||
]
|
||||
|
||||
def stripped_for_export(self) -> "NodeModel":
|
||||
"""
|
||||
Returns a copy of the node model, stripped of any non-transferable properties
|
||||
"""
|
||||
stripped_node = self.model_copy(deep=True)
|
||||
# Remove credentials from node input
|
||||
if stripped_node.input_default:
|
||||
stripped_node.input_default = NodeModel._filter_secrets_from_node_input(
|
||||
stripped_node.input_default, self.block.input_schema.jsonschema()
|
||||
)
|
||||
|
||||
if (
|
||||
stripped_node.block.block_type == BlockType.INPUT
|
||||
and "value" in stripped_node.input_default
|
||||
):
|
||||
stripped_node.input_default["value"] = ""
|
||||
|
||||
# Remove webhook info
|
||||
stripped_node.webhook_id = None
|
||||
stripped_node.webhook = None
|
||||
|
||||
return stripped_node
|
||||
|
||||
@staticmethod
|
||||
def _filter_secrets_from_node_input(
|
||||
input_data: dict[str, Any], schema: dict[str, Any] | None
|
||||
) -> dict[str, Any]:
|
||||
sensitive_keys = ["credentials", "api_key", "password", "token", "secret"]
|
||||
field_schemas = schema.get("properties", {}) if schema else {}
|
||||
result = {}
|
||||
for key, value in input_data.items():
|
||||
field_schema: dict | None = field_schemas.get(key)
|
||||
if (field_schema and field_schema.get("secret", False)) or any(
|
||||
sensitive_key in key.lower() for sensitive_key in sensitive_keys
|
||||
):
|
||||
# This is a secret value -> filter this key-value pair out
|
||||
continue
|
||||
elif isinstance(value, dict):
|
||||
result[key] = NodeModel._filter_secrets_from_node_input(
|
||||
value, field_schema
|
||||
)
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
|
||||
# Fix 2-way reference Node <-> Webhook
|
||||
Webhook.model_rebuild()
|
||||
|
||||
|
||||
class BaseGraph(BaseDbModel):
|
||||
class GraphExecution(BaseDbModel):
|
||||
execution_id: str
|
||||
started_at: datetime
|
||||
ended_at: datetime
|
||||
duration: float
|
||||
total_run_time: float
|
||||
status: ExecutionStatus
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
|
||||
@staticmethod
|
||||
def from_db(execution: AgentGraphExecution):
|
||||
now = datetime.now(timezone.utc)
|
||||
start_time = execution.startedAt or execution.createdAt
|
||||
end_time = execution.updatedAt or now
|
||||
duration = (end_time - start_time).total_seconds()
|
||||
total_run_time = duration
|
||||
|
||||
try:
|
||||
stats = json.loads(execution.stats or "{}", target_type=dict[str, Any])
|
||||
except ValueError:
|
||||
stats = {}
|
||||
|
||||
duration = stats.get("walltime", duration)
|
||||
total_run_time = stats.get("nodes_walltime", total_run_time)
|
||||
|
||||
return GraphExecution(
|
||||
id=execution.id,
|
||||
execution_id=execution.id,
|
||||
started_at=start_time,
|
||||
ended_at=end_time,
|
||||
duration=duration,
|
||||
total_run_time=total_run_time,
|
||||
status=ExecutionStatus(execution.executionStatus),
|
||||
graph_id=execution.agentGraphId,
|
||||
graph_version=execution.agentGraphVersion,
|
||||
)
|
||||
|
||||
|
||||
class Graph(BaseDbModel):
|
||||
version: int = 1
|
||||
is_active: bool = True
|
||||
is_template: bool = False
|
||||
name: str
|
||||
description: str
|
||||
nodes: list[Node] = []
|
||||
@@ -165,48 +152,46 @@ class BaseGraph(BaseDbModel):
|
||||
@property
|
||||
def input_schema(self) -> dict[str, Any]:
|
||||
return self._generate_schema(
|
||||
*(
|
||||
(b.input_schema, node.input_default)
|
||||
AgentInputBlock.Input,
|
||||
[
|
||||
node.input_default
|
||||
for node in self.nodes
|
||||
if (b := get_block(node.block_id))
|
||||
and b.block_type == BlockType.INPUT
|
||||
and issubclass(b.input_schema, AgentInputBlock.Input)
|
||||
)
|
||||
and "name" in node.input_default
|
||||
],
|
||||
)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def output_schema(self) -> dict[str, Any]:
|
||||
return self._generate_schema(
|
||||
*(
|
||||
(b.input_schema, node.input_default)
|
||||
AgentOutputBlock.Input,
|
||||
[
|
||||
node.input_default
|
||||
for node in self.nodes
|
||||
if (b := get_block(node.block_id))
|
||||
and b.block_type == BlockType.OUTPUT
|
||||
and issubclass(b.input_schema, AgentOutputBlock.Input)
|
||||
)
|
||||
and "name" in node.input_default
|
||||
],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _generate_schema(
|
||||
*props: tuple[Type[AgentInputBlock.Input] | Type[AgentOutputBlock.Input], dict],
|
||||
type_class: Type[AgentInputBlock.Input] | Type[AgentOutputBlock.Input],
|
||||
data: list[dict],
|
||||
) -> dict[str, Any]:
|
||||
schema = []
|
||||
for type_class, input_default in props:
|
||||
props = []
|
||||
for p in data:
|
||||
try:
|
||||
schema.append(type_class(**input_default))
|
||||
props.append(type_class(**p))
|
||||
except Exception as e:
|
||||
logger.warning(f"Invalid {type_class}: {input_default}, {e}")
|
||||
logger.warning(f"Invalid {type_class}: {p}, {e}")
|
||||
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
p.name: {
|
||||
**{
|
||||
k: v
|
||||
for k, v in p.generate_schema().items()
|
||||
if k not in ["description", "default"]
|
||||
},
|
||||
"secret": p.secret,
|
||||
# Default value has to be set for advanced fields.
|
||||
"advanced": p.advanced and p.value is not None,
|
||||
@@ -214,16 +199,12 @@ class BaseGraph(BaseDbModel):
|
||||
**({"description": p.description} if p.description else {}),
|
||||
**({"default": p.value} if p.value is not None else {}),
|
||||
}
|
||||
for p in schema
|
||||
for p in props
|
||||
},
|
||||
"required": [p.name for p in schema if p.value is None],
|
||||
"required": [p.name for p in props if p.value is None],
|
||||
}
|
||||
|
||||
|
||||
class Graph(BaseGraph):
|
||||
sub_graphs: list[BaseGraph] = [] # Flattened sub-graphs, only used in export
|
||||
|
||||
|
||||
class GraphModel(Graph):
|
||||
user_id: str
|
||||
nodes: list[NodeModel] = [] # type: ignore
|
||||
@@ -247,88 +228,42 @@ class GraphModel(Graph):
|
||||
Reassigns all IDs in the graph to new UUIDs.
|
||||
This method can be used before storing a new graph to the database.
|
||||
"""
|
||||
if reassign_graph_id:
|
||||
graph_id_map = {
|
||||
self.id: str(uuid.uuid4()),
|
||||
**{sub_graph.id: str(uuid.uuid4()) for sub_graph in self.sub_graphs},
|
||||
}
|
||||
else:
|
||||
graph_id_map = {}
|
||||
|
||||
self._reassign_ids(self, user_id, graph_id_map)
|
||||
for sub_graph in self.sub_graphs:
|
||||
self._reassign_ids(sub_graph, user_id, graph_id_map)
|
||||
|
||||
@staticmethod
|
||||
def _reassign_ids(
|
||||
graph: BaseGraph,
|
||||
user_id: str,
|
||||
graph_id_map: dict[str, str],
|
||||
):
|
||||
# Reassign Graph ID
|
||||
if graph.id in graph_id_map:
|
||||
graph.id = graph_id_map[graph.id]
|
||||
id_map = {node.id: str(uuid.uuid4()) for node in self.nodes}
|
||||
if reassign_graph_id:
|
||||
self.id = str(uuid.uuid4())
|
||||
|
||||
# Reassign Node IDs
|
||||
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
|
||||
for node in graph.nodes:
|
||||
for node in self.nodes:
|
||||
node.id = id_map[node.id]
|
||||
|
||||
# Reassign Link IDs
|
||||
for link in graph.links:
|
||||
for link in self.links:
|
||||
link.source_id = id_map[link.source_id]
|
||||
link.sink_id = id_map[link.sink_id]
|
||||
|
||||
# Reassign User IDs for agent blocks
|
||||
for node in graph.nodes:
|
||||
for node in self.nodes:
|
||||
if node.block_id != AgentExecutorBlock().id:
|
||||
continue
|
||||
node.input_default["user_id"] = user_id
|
||||
node.input_default.setdefault("data", {})
|
||||
if (graph_id := node.input_default.get("graph_id")) in graph_id_map:
|
||||
node.input_default["graph_id"] = graph_id_map[graph_id]
|
||||
|
||||
self.validate_graph()
|
||||
|
||||
def validate_graph(self, for_run: bool = False):
|
||||
self._validate_graph(self, for_run)
|
||||
for sub_graph in self.sub_graphs:
|
||||
self._validate_graph(sub_graph, for_run)
|
||||
|
||||
@staticmethod
|
||||
def _validate_graph(graph: BaseGraph, for_run: bool = False):
|
||||
def sanitize(name):
|
||||
sanitized_name = name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
|
||||
if sanitized_name.startswith("tools_^_"):
|
||||
return sanitized_name.split("_^_")[0]
|
||||
return sanitized_name
|
||||
|
||||
# Validate smart decision maker nodes
|
||||
smart_decision_maker_nodes = set()
|
||||
agent_nodes = set()
|
||||
nodes_block = {
|
||||
node.id: block
|
||||
for node in graph.nodes
|
||||
if (block := get_block(node.block_id)) is not None
|
||||
}
|
||||
|
||||
for node in graph.nodes:
|
||||
if (block := nodes_block.get(node.id)) is None:
|
||||
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
|
||||
|
||||
# Smart decision maker nodes
|
||||
if block.block_type == BlockType.AI:
|
||||
smart_decision_maker_nodes.add(node.id)
|
||||
# Agent nodes
|
||||
elif block.block_type == BlockType.AGENT:
|
||||
agent_nodes.add(node.id)
|
||||
return name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
|
||||
|
||||
input_links = defaultdict(list)
|
||||
|
||||
for link in graph.links:
|
||||
for link in self.links:
|
||||
input_links[link.sink_id].append(link)
|
||||
|
||||
# Nodes: required fields are filled or connected and dependencies are satisfied
|
||||
for node in graph.nodes:
|
||||
if (block := nodes_block.get(node.id)) is None:
|
||||
for node in self.nodes:
|
||||
block = get_block(node.block_id)
|
||||
if block is None:
|
||||
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
|
||||
|
||||
provided_inputs = set(
|
||||
@@ -345,12 +280,9 @@ class GraphModel(Graph):
|
||||
)
|
||||
and (
|
||||
for_run # Skip input completion validation, unless when executing.
|
||||
or block.block_type
|
||||
in [
|
||||
BlockType.INPUT,
|
||||
BlockType.OUTPUT,
|
||||
BlockType.AGENT,
|
||||
]
|
||||
or block.block_type == BlockType.INPUT
|
||||
or block.block_type == BlockType.OUTPUT
|
||||
or block.block_type == BlockType.AGENT
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
@@ -388,7 +320,7 @@ class GraphModel(Graph):
|
||||
f"Node {block.name} #{node.id}: Field `{field_name}` requires [{', '.join(missing_deps)}] to be set"
|
||||
)
|
||||
|
||||
node_map = {v.id: v for v in graph.nodes}
|
||||
node_map = {v.id: v for v in self.nodes}
|
||||
|
||||
def is_static_output_block(nid: str) -> bool:
|
||||
bid = node_map[nid].block_id
|
||||
@@ -396,23 +328,23 @@ class GraphModel(Graph):
|
||||
return b.static_output if b else False
|
||||
|
||||
# Links: links are connected and the connected pin data type are compatible.
|
||||
for link in graph.links:
|
||||
for link in self.links:
|
||||
source = (link.source_id, link.source_name)
|
||||
sink = (link.sink_id, link.sink_name)
|
||||
prefix = f"Link {source} <-> {sink}"
|
||||
suffix = f"Link {source} <-> {sink}"
|
||||
|
||||
for i, (node_id, name) in enumerate([source, sink]):
|
||||
node = node_map.get(node_id)
|
||||
if not node:
|
||||
raise ValueError(
|
||||
f"{prefix}, {node_id} is invalid node id, available nodes: {node_map.keys()}"
|
||||
f"{suffix}, {node_id} is invalid node id, available nodes: {node_map.keys()}"
|
||||
)
|
||||
|
||||
block = get_block(node.block_id)
|
||||
if not block:
|
||||
blocks = {v().id: v().name for v in get_blocks().values()}
|
||||
raise ValueError(
|
||||
f"{prefix}, {node.block_id} is invalid block id, available blocks: {blocks}"
|
||||
f"{suffix}, {node.block_id} is invalid block id, available blocks: {blocks}"
|
||||
)
|
||||
|
||||
sanitized_name = sanitize(name)
|
||||
@@ -420,37 +352,35 @@ class GraphModel(Graph):
|
||||
if i == 0:
|
||||
fields = (
|
||||
block.output_schema.get_fields()
|
||||
if block.block_type not in [BlockType.AGENT]
|
||||
if block.block_type != BlockType.AGENT
|
||||
else vals.get("output_schema", {}).get("properties", {}).keys()
|
||||
)
|
||||
else:
|
||||
fields = (
|
||||
block.input_schema.get_fields()
|
||||
if block.block_type not in [BlockType.AGENT]
|
||||
if block.block_type != BlockType.AGENT
|
||||
else vals.get("input_schema", {}).get("properties", {}).keys()
|
||||
)
|
||||
if sanitized_name not in fields and not name.startswith("tools_^_"):
|
||||
if sanitized_name not in fields:
|
||||
fields_msg = f"Allowed fields: {fields}"
|
||||
raise ValueError(f"{prefix}, `{name}` invalid, {fields_msg}")
|
||||
raise ValueError(f"{suffix}, `{name}` invalid, {fields_msg}")
|
||||
|
||||
if is_static_output_block(link.source_id):
|
||||
link.is_static = True # Each value block output should be static.
|
||||
|
||||
@staticmethod
|
||||
def from_db(
|
||||
graph: AgentGraph,
|
||||
for_export: bool = False,
|
||||
sub_graphs: list[AgentGraph] | None = None,
|
||||
):
|
||||
def from_db(graph: AgentGraph, for_export: bool = False):
|
||||
return GraphModel(
|
||||
id=graph.id,
|
||||
user_id=graph.userId if not for_export else "",
|
||||
user_id=graph.userId,
|
||||
version=graph.version,
|
||||
is_active=graph.isActive,
|
||||
is_template=graph.isTemplate,
|
||||
name=graph.name or "",
|
||||
description=graph.description or "",
|
||||
nodes=[
|
||||
NodeModel.from_db(node, for_export) for node in graph.AgentNodes or []
|
||||
NodeModel.from_db(GraphModel._process_node(node, for_export))
|
||||
for node in graph.AgentNodes or []
|
||||
],
|
||||
links=list(
|
||||
{
|
||||
@@ -459,12 +389,61 @@ class GraphModel(Graph):
|
||||
for link in (node.Input or []) + (node.Output or [])
|
||||
}
|
||||
),
|
||||
sub_graphs=[
|
||||
GraphModel.from_db(sub_graph, for_export)
|
||||
for sub_graph in sub_graphs or []
|
||||
],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _process_node(node: AgentNode, for_export: bool) -> AgentNode:
|
||||
if for_export:
|
||||
# Remove credentials from node input
|
||||
if node.constantInput:
|
||||
constant_input = json.loads(
|
||||
node.constantInput, target_type=dict[str, Any]
|
||||
)
|
||||
constant_input = GraphModel._hide_node_input_credentials(constant_input)
|
||||
node.constantInput = json.dumps(constant_input)
|
||||
|
||||
# Remove webhook info
|
||||
node.webhookId = None
|
||||
node.Webhook = None
|
||||
|
||||
return node
|
||||
|
||||
@staticmethod
|
||||
def _hide_node_input_credentials(input_data: dict[str, Any]) -> dict[str, Any]:
|
||||
sensitive_keys = ["credentials", "api_key", "password", "token", "secret"]
|
||||
result = {}
|
||||
for key, value in input_data.items():
|
||||
if isinstance(value, dict):
|
||||
result[key] = GraphModel._hide_node_input_credentials(value)
|
||||
elif isinstance(value, str) and any(
|
||||
sensitive_key in key.lower() for sensitive_key in sensitive_keys
|
||||
):
|
||||
# Skip this key-value pair in the result
|
||||
continue
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
def clean_graph(self):
|
||||
blocks = [block() for block in get_blocks().values()]
|
||||
|
||||
input_blocks = [
|
||||
node
|
||||
for node in self.nodes
|
||||
if next(
|
||||
(
|
||||
b
|
||||
for b in blocks
|
||||
if b.id == node.block_id and b.block_type == BlockType.INPUT
|
||||
),
|
||||
None,
|
||||
)
|
||||
]
|
||||
|
||||
for node in self.nodes:
|
||||
if any(input_block.id == node.id for input_block in input_blocks):
|
||||
node.input_default["value"] = ""
|
||||
|
||||
|
||||
# --------------------- CRUD functions --------------------- #
|
||||
|
||||
@@ -494,14 +473,14 @@ async def set_node_webhook(node_id: str, webhook_id: str | None) -> NodeModel:
|
||||
|
||||
async def get_graphs(
|
||||
user_id: str,
|
||||
filter_by: Literal["active"] | None = "active",
|
||||
filter_by: Literal["active", "template"] | None = "active",
|
||||
) -> list[GraphModel]:
|
||||
"""
|
||||
Retrieves graph metadata objects.
|
||||
Default behaviour is to get all currently active graphs.
|
||||
|
||||
Args:
|
||||
filter_by: An optional filter to either select graphs.
|
||||
filter_by: An optional filter to either select templates or active graphs.
|
||||
user_id: The ID of the user that owns the graph.
|
||||
|
||||
Returns:
|
||||
@@ -511,6 +490,8 @@ async def get_graphs(
|
||||
|
||||
if filter_by == "active":
|
||||
where_clause["isActive"] = True
|
||||
elif filter_by == "template":
|
||||
where_clause["isTemplate"] = True
|
||||
|
||||
graphs = await AgentGraph.prisma().find_many(
|
||||
where=where_clause,
|
||||
@@ -530,138 +511,53 @@ async def get_graphs(
|
||||
return graph_models
|
||||
|
||||
|
||||
async def get_graph_metadata(graph_id: str, version: int | None = None) -> Graph | None:
|
||||
where_clause: AgentGraphWhereInput = {
|
||||
"id": graph_id,
|
||||
}
|
||||
|
||||
if version is not None:
|
||||
where_clause["version"] = version
|
||||
|
||||
graph = await AgentGraph.prisma().find_first(
|
||||
where=where_clause,
|
||||
order={"version": "desc"},
|
||||
async def get_executions(user_id: str) -> list[GraphExecution]:
|
||||
executions = await AgentGraphExecution.prisma().find_many(
|
||||
where={"userId": user_id},
|
||||
order={"createdAt": "desc"},
|
||||
)
|
||||
return [GraphExecution.from_db(execution) for execution in executions]
|
||||
|
||||
if not graph:
|
||||
return None
|
||||
|
||||
return Graph(
|
||||
id=graph.id,
|
||||
name=graph.name or "",
|
||||
description=graph.description or "",
|
||||
version=graph.version,
|
||||
is_active=graph.isActive,
|
||||
async def get_execution(user_id: str, execution_id: str) -> GraphExecution | None:
|
||||
execution = await AgentGraphExecution.prisma().find_first(
|
||||
where={"id": execution_id, "userId": user_id}
|
||||
)
|
||||
return GraphExecution.from_db(execution) if execution else None
|
||||
|
||||
|
||||
async def get_graph(
|
||||
graph_id: str,
|
||||
version: int | None = None,
|
||||
template: bool = False,
|
||||
user_id: str | None = None,
|
||||
for_export: bool = False,
|
||||
) -> GraphModel | None:
|
||||
"""
|
||||
Retrieves a graph from the DB.
|
||||
Defaults to the version with `is_active` if `version` is not passed.
|
||||
Defaults to the version with `is_active` if `version` is not passed,
|
||||
or the latest version with `is_template` if `template=True`.
|
||||
|
||||
Returns `None` if the record is not found.
|
||||
"""
|
||||
where_clause: AgentGraphWhereInput = {
|
||||
"id": graph_id,
|
||||
}
|
||||
|
||||
if version is not None:
|
||||
where_clause["version"] = version
|
||||
elif not template:
|
||||
where_clause["isActive"] = True
|
||||
|
||||
# TODO: Fix hack workaround to get adding store agents to work
|
||||
if user_id is not None and not template:
|
||||
where_clause["userId"] = user_id
|
||||
|
||||
graph = await AgentGraph.prisma().find_first(
|
||||
where=where_clause,
|
||||
include=AGENT_GRAPH_INCLUDE,
|
||||
order={"version": "desc"},
|
||||
)
|
||||
|
||||
# For access, the graph must be owned by the user or listed in the store
|
||||
if graph is None or (
|
||||
graph.userId != user_id
|
||||
and not (
|
||||
await StoreListingVersion.prisma().find_first(
|
||||
where={
|
||||
"agentId": graph_id,
|
||||
"agentVersion": version or graph.version,
|
||||
"isDeleted": False,
|
||||
"submissionStatus": SubmissionStatus.APPROVED,
|
||||
}
|
||||
)
|
||||
)
|
||||
):
|
||||
return None
|
||||
|
||||
if for_export:
|
||||
sub_graphs = await get_sub_graphs(graph)
|
||||
return GraphModel.from_db(
|
||||
graph=graph,
|
||||
sub_graphs=sub_graphs,
|
||||
for_export=for_export,
|
||||
)
|
||||
|
||||
return GraphModel.from_db(graph, for_export)
|
||||
|
||||
|
||||
async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
|
||||
"""
|
||||
Iteratively fetches all sub-graphs of a given graph, and flattens them into a list.
|
||||
This call involves a DB fetch in batch, breadth-first, per-level of graph depth.
|
||||
On each DB fetch we will only fetch the sub-graphs that are not already in the list.
|
||||
"""
|
||||
sub_graphs = {graph.id: graph}
|
||||
search_graphs = [graph]
|
||||
agent_block_id = AgentExecutorBlock().id
|
||||
|
||||
while search_graphs:
|
||||
sub_graph_ids = [
|
||||
(graph_id, graph_version)
|
||||
for graph in search_graphs
|
||||
for node in graph.AgentNodes or []
|
||||
if (
|
||||
node.AgentBlock
|
||||
and node.AgentBlock.id == agent_block_id
|
||||
and (graph_id := dict(node.constantInput).get("graph_id"))
|
||||
and (graph_version := dict(node.constantInput).get("graph_version"))
|
||||
)
|
||||
]
|
||||
if not sub_graph_ids:
|
||||
break
|
||||
|
||||
graphs = await AgentGraph.prisma().find_many(
|
||||
where={
|
||||
"OR": [
|
||||
{
|
||||
"id": graph_id,
|
||||
"version": graph_version,
|
||||
"userId": graph.userId, # Ensure the sub-graph is owned by the same user
|
||||
}
|
||||
for graph_id, graph_version in sub_graph_ids
|
||||
] # type: ignore
|
||||
},
|
||||
include=AGENT_GRAPH_INCLUDE,
|
||||
)
|
||||
|
||||
search_graphs = [graph for graph in graphs if graph.id not in sub_graphs]
|
||||
sub_graphs.update({graph.id: graph for graph in search_graphs})
|
||||
|
||||
return [g for g in sub_graphs.values() if g.id != graph.id]
|
||||
|
||||
|
||||
async def get_connected_output_nodes(node_id: str) -> list[tuple[Link, Node]]:
|
||||
links = await AgentNodeLink.prisma().find_many(
|
||||
where={"agentNodeSourceId": node_id},
|
||||
include={"AgentNodeSink": {"include": AGENT_NODE_INCLUDE}}, # type: ignore
|
||||
)
|
||||
return [
|
||||
(Link.from_db(link), NodeModel.from_db(link.AgentNodeSink))
|
||||
for link in links
|
||||
if link.AgentNodeSink
|
||||
]
|
||||
return GraphModel.from_db(graph, for_export) if graph else None
|
||||
|
||||
|
||||
async def set_graph_active_version(graph_id: str, version: int, user_id: str) -> None:
|
||||
@@ -715,56 +611,50 @@ async def create_graph(graph: Graph, user_id: str) -> GraphModel:
|
||||
async with transaction() as tx:
|
||||
await __create_graph(tx, graph, user_id)
|
||||
|
||||
if created_graph := await get_graph(graph.id, graph.version, user_id=user_id):
|
||||
if created_graph := await get_graph(
|
||||
graph.id, graph.version, graph.is_template, user_id=user_id
|
||||
):
|
||||
return created_graph
|
||||
|
||||
raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")
|
||||
|
||||
|
||||
async def __create_graph(tx, graph: Graph, user_id: str):
|
||||
graphs = [graph] + graph.sub_graphs
|
||||
|
||||
await AgentGraph.prisma(tx).create_many(
|
||||
data=[
|
||||
{
|
||||
"id": graph.id,
|
||||
"version": graph.version,
|
||||
"name": graph.name,
|
||||
"description": graph.description,
|
||||
"isActive": graph.is_active,
|
||||
"userId": user_id,
|
||||
}
|
||||
for graph in graphs
|
||||
]
|
||||
await AgentGraph.prisma(tx).create(
|
||||
data={
|
||||
"id": graph.id,
|
||||
"version": graph.version,
|
||||
"name": graph.name,
|
||||
"description": graph.description,
|
||||
"isTemplate": graph.is_template,
|
||||
"isActive": graph.is_active,
|
||||
"userId": user_id,
|
||||
"AgentNodes": {
|
||||
"create": [
|
||||
{
|
||||
"id": node.id,
|
||||
"agentBlockId": node.block_id,
|
||||
"constantInput": json.dumps(node.input_default),
|
||||
"metadata": json.dumps(node.metadata),
|
||||
}
|
||||
for node in graph.nodes
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
await AgentNode.prisma(tx).create_many(
|
||||
data=[
|
||||
{
|
||||
"id": node.id,
|
||||
"agentGraphId": graph.id,
|
||||
"agentGraphVersion": graph.version,
|
||||
"agentBlockId": node.block_id,
|
||||
"constantInput": Json(node.input_default),
|
||||
"metadata": Json(node.metadata),
|
||||
"webhookId": node.webhook_id,
|
||||
}
|
||||
for graph in graphs
|
||||
for node in graph.nodes
|
||||
]
|
||||
)
|
||||
|
||||
await AgentNodeLink.prisma(tx).create_many(
|
||||
data=[
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"sourceName": link.source_name,
|
||||
"sinkName": link.sink_name,
|
||||
"agentNodeSourceId": link.source_id,
|
||||
"agentNodeSinkId": link.sink_id,
|
||||
"isStatic": link.is_static,
|
||||
}
|
||||
for graph in graphs
|
||||
await asyncio.gather(
|
||||
*[
|
||||
AgentNodeLink.prisma(tx).create(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"sourceName": link.source_name,
|
||||
"sinkName": link.sink_name,
|
||||
"agentNodeSourceId": link.source_id,
|
||||
"agentNodeSinkId": link.sink_id,
|
||||
"isStatic": link.is_static,
|
||||
}
|
||||
)
|
||||
for link in graph.links
|
||||
]
|
||||
)
|
||||
@@ -807,11 +697,9 @@ async def fix_llm_provider_credentials():
|
||||
|
||||
store = IntegrationCredentialsStore()
|
||||
|
||||
broken_nodes = []
|
||||
try:
|
||||
broken_nodes = await prisma.get_client().query_raw(
|
||||
"""
|
||||
SELECT graph."userId" user_id,
|
||||
broken_nodes = await prisma.get_client().query_raw(
|
||||
"""
|
||||
SELECT graph."userId" user_id,
|
||||
node.id node_id,
|
||||
node."constantInput" node_preset_input
|
||||
FROM platform."AgentNode" node
|
||||
@@ -820,10 +708,8 @@ async def fix_llm_provider_credentials():
|
||||
WHERE node."constantInput"::jsonb->'credentials'->>'provider' = 'llm'
|
||||
ORDER BY graph."userId";
|
||||
"""
|
||||
)
|
||||
logger.info(f"Fixing LLM credential inputs on {len(broken_nodes)} nodes")
|
||||
except Exception as e:
|
||||
logger.error(f"Error fixing LLM credential inputs: {e}")
|
||||
)
|
||||
logger.info(f"Fixing LLM credential inputs on {len(broken_nodes)} nodes")
|
||||
|
||||
user_id: str = ""
|
||||
user_integrations = None
|
||||
@@ -836,7 +722,7 @@ async def fix_llm_provider_credentials():
|
||||
raise RuntimeError(f"Impossible state while processing node {node}")
|
||||
|
||||
node_id: str = node["node_id"]
|
||||
node_preset_input: dict = node["node_preset_input"]
|
||||
node_preset_input: dict = json.loads(node["node_preset_input"])
|
||||
credentials_meta: dict = node_preset_input["credentials"]
|
||||
|
||||
credentials = next(
|
||||
@@ -872,42 +758,5 @@ async def fix_llm_provider_credentials():
|
||||
store.update_creds(user_id, credentials)
|
||||
await AgentNode.prisma().update(
|
||||
where={"id": node_id},
|
||||
data={"constantInput": Json(node_preset_input)},
|
||||
data={"constantInput": json.dumps(node_preset_input)},
|
||||
)
|
||||
|
||||
|
||||
async def migrate_llm_models(migrate_to: LlmModel):
|
||||
"""
|
||||
Update all LLM models in all AI blocks that don't exist in the enum.
|
||||
Note: Only updates top level LlmModel SchemaFields of blocks (won't update nested fields).
|
||||
"""
|
||||
logger.info("Migrating LLM models")
|
||||
# Scan all blocks and search for LlmModel fields
|
||||
llm_model_fields: dict[str, str] = {} # {block_id: field_name}
|
||||
|
||||
# Search for all LlmModel fields
|
||||
for block_type in get_blocks().values():
|
||||
block = block_type()
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
fields: dict[str, FieldInfo] = block.input_schema.model_fields
|
||||
|
||||
# Collect top-level LlmModel fields
|
||||
for field_name, field in fields.items():
|
||||
if field.annotation == LlmModel:
|
||||
llm_model_fields[block.id] = field_name
|
||||
|
||||
# Update each block
|
||||
for id, path in llm_model_fields.items():
|
||||
# Convert enum values to a list of strings for the SQL query
|
||||
enum_values = [v.value for v in LlmModel.__members__.values()]
|
||||
|
||||
query = f"""
|
||||
UPDATE "AgentNode"
|
||||
SET "constantInput" = jsonb_set("constantInput", '{{{path}}}', '"{migrate_to.value}"', true)
|
||||
WHERE "agentBlockId" = '{id}'
|
||||
AND "constantInput" ? '{path}'
|
||||
AND "constantInput"->>'{path}' NOT IN ({','.join(f"'{value}'" for value in enum_values)})
|
||||
"""
|
||||
|
||||
await db.execute_raw(query)
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import prisma
|
||||
|
||||
from backend.blocks.io import IO_BLOCK_IDs
|
||||
|
||||
AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
|
||||
"Input": True,
|
||||
"Output": True,
|
||||
@@ -20,49 +18,17 @@ EXECUTION_RESULT_INCLUDE: prisma.types.AgentNodeExecutionInclude = {
|
||||
"AgentGraphExecution": True,
|
||||
}
|
||||
|
||||
MAX_NODE_EXECUTIONS_FETCH = 1000
|
||||
|
||||
GRAPH_EXECUTION_INCLUDE_WITH_NODES: prisma.types.AgentGraphExecutionInclude = {
|
||||
GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
|
||||
"AgentNodeExecutions": {
|
||||
"include": {
|
||||
"Input": True,
|
||||
"Output": True,
|
||||
"AgentNode": True,
|
||||
"AgentGraphExecution": True,
|
||||
},
|
||||
"order_by": [
|
||||
{"queuedTime": "desc"},
|
||||
# Fallback: Incomplete execs has no queuedTime.
|
||||
{"addedTime": "desc"},
|
||||
],
|
||||
"take": MAX_NODE_EXECUTIONS_FETCH, # Avoid loading excessive node executions.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
|
||||
"AgentNodeExecutions": {
|
||||
**GRAPH_EXECUTION_INCLUDE_WITH_NODES["AgentNodeExecutions"], # type: ignore
|
||||
"where": {
|
||||
"AgentNode": {
|
||||
"AgentBlock": {"id": {"in": IO_BLOCK_IDs}}, # type: ignore
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
|
||||
"AgentNodes": {"include": AGENT_NODE_INCLUDE} # type: ignore
|
||||
}
|
||||
|
||||
|
||||
def library_agent_include(user_id: str) -> prisma.types.LibraryAgentInclude:
|
||||
return {
|
||||
"Agent": {
|
||||
"include": {
|
||||
**AGENT_GRAPH_INCLUDE,
|
||||
"AgentGraphExecution": {"where": {"userId": user_id}},
|
||||
}
|
||||
},
|
||||
"Creator": True,
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ from backend.data.includes import INTEGRATION_WEBHOOK_INCLUDE
|
||||
from backend.data.queue import AsyncRedisEventBus
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks.utils import webhook_ingress_url
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from .db import BaseDbModel
|
||||
|
||||
@@ -83,18 +82,11 @@ async def create_webhook(webhook: Webhook) -> Webhook:
|
||||
|
||||
|
||||
async def get_webhook(webhook_id: str) -> Webhook:
|
||||
"""
|
||||
⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints.
|
||||
|
||||
Raises:
|
||||
NotFoundError: if no record with the given ID exists
|
||||
"""
|
||||
webhook = await IntegrationWebhook.prisma().find_unique(
|
||||
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
|
||||
webhook = await IntegrationWebhook.prisma().find_unique_or_raise(
|
||||
where={"id": webhook_id},
|
||||
include=INTEGRATION_WEBHOOK_INCLUDE,
|
||||
)
|
||||
if not webhook:
|
||||
raise NotFoundError(f"Webhook #{webhook_id} not found")
|
||||
return Webhook.from_db(webhook)
|
||||
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user