Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-01-13 00:58:16 -05:00)
Compare commits
347 Commits
| SHA1 |
|---|
| 05afc1e1bb |
| a597507680 |
| 14c223435b |
| f6ad194306 |
| e5c21ceda9 |
| 1a4a3e72ec |
| 94a312a279 |
| de3c096e23 |
| f090f4ca4a |
| 29c771ba1b |
| 582e12c766 |
| 584b121a16 |
| d67c7a3704 |
| dec9dfad9d |
| 4af2297ca4 |
| 72e9dd6b37 |
| 987fe6d3b8 |
| e3cf605e9b |
| 3c024be10c |
| e508d90fa7 |
| abf73e8d66 |
| b16bf42fa3 |
| 33b9eef376 |
| b8a3ffc04a |
| 3fd2b7ce4a |
| 6490b4e188 |
| 91fb4f5c56 |
| 32c15434d2 |
| 6fcceeabdb |
| b1468f779c |
| ac03211c59 |
| 7a9115db18 |
| c35453f41e |
| 07d732ece7 |
| 54baa84c28 |
| 6307ca1841 |
| d827d4f9e4 |
| 0fb33f6cb7 |
| afd809f68a |
| a8ae006967 |
| b7603f6053 |
| c91099a5ef |
| c7ae4cfbda |
| b5e1c8075e |
| d14cecb954 |
| 9bc72305c5 |
| 980c90c07c |
| 006e843461 |
| 452fe314f6 |
| 984d42234c |
| b86b0c758e |
| f1053ff8b4 |
| 7a1176073d |
| 79c0c314e2 |
| e6d728b081 |
| a7a526e820 |
| df431d71ff |
| 3145f3d59d |
| 89e8fe3854 |
| 6c34f8bd96 |
| 6ccf62c3a6 |
| 0217b75d65 |
| 6a712fd6bb |
| f4244e5038 |
| 9c0e8750b0 |
| 14ecaf861e |
| 6d6ed348fa |
| 281cd2910b |
| 6997e2a170 |
| 1a85eb1dcf |
| 1d81c61b77 |
| 3da5fc2cac |
| 8d8664a3ce |
| b62f411518 |
| eb79c04855 |
| d7c9742d7e |
| ea6c9a1152 |
| dcfad263cb |
| 9ad9dd9fe1 |
| e2904136bd |
| 6dba31e021 |
| ffc3eff7e2 |
| 73eafa37c6 |
| c621226554 |
| 227806aef9 |
| 0272d87af3 |
| d201b653c8 |
| f77172ec82 |
| ec072ad52a |
| 9b456612aa |
| 64f5e60d12 |
| 3d0ec9c52a |
| 50d08654f8 |
| fa871073ca |
| ab4f4549d6 |
| 6b742d1a8c |
| 4492995d6b |
| 6a736cc60f |
| 60417545f0 |
| 4f20f868c1 |
| 2ed8347430 |
| 12eb3b2937 |
| d5580f8a94 |
| 5b0ecfd26b |
| 3b24b0ac0d |
| 715d425579 |
| 0eda63fa15 |
| 4e886bd6e9 |
| f3168ea187 |
| 7fdfffdfcc |
| 8f4d552909 |
| d4edb9371d |
| 89011aabe0 |
| 303a55145d |
| f54cfee4a7 |
| 43bd5c89d7 |
| 0a604a5746 |
| 5ccfb8e4c6 |
| 96bba3c1bd |
| b3bd0f5d54 |
| de1cd6c295 |
| 91aa371220 |
| 3bca279b35 |
| 1ab19dcc56 |
| 7c2e371f23 |
| 30a047eac3 |
| f2387147c7 |
| 0f77f931ab |
| fdb82eda38 |
| dce9bdd488 |
| 30bb9a3d72 |
| 758edaec9e |
| be7f9123bb |
| 53eee63161 |
| 5c49fc87fd |
| 48d27c91d4 |
| 39c9e4a76c |
| b3fd8bbfb9 |
| a4b186cf81 |
| a971d59974 |
| 9db8832d6b |
| acb35c3926 |
| cd9c4218b0 |
| 3ccf0138b1 |
| 19ff8f324d |
| f445918abf |
| 5a4083d542 |
| 1e66137c7e |
| bb13157d18 |
| 37e6b6f385 |
| fabf742601 |
| 100b667afc |
| da10f1a2df |
| 65ada3fb72 |
| bc277acf57 |
| 52d19d084c |
| 86a10858bd |
| f297e42be4 |
| 2bbac5714f |
| e063f4bcda |
| 6c64c5b98f |
| 5530db63de |
| 4278ae61b0 |
| 0c78edb592 |
| 93d3bd3773 |
| c1c3fd4982 |
| d4b69d864f |
| db3284830a |
| d16cf6cfeb |
| 3f3919a843 |
| 244171d748 |
| faa683b6e4 |
| 9255759e1e |
| abbed4051d |
| 8c5380d4f9 |
| cacc6e1f86 |
| 9616baf695 |
| 518f196e6b |
| a9693b582f |
| 64fcba3f3a |
| a53f3f0e0a |
| 84cdf189f4 |
| 610a5b9943 |
| 18b5f2047c |
| 17db193faa |
| 78476630cd |
| 7a41f36b13 |
| c315b8e700 |
| f6c1bdccac |
| afe5c12afb |
| 65344b9783 |
| 9f71fb940d |
| 28a327c57a |
| 6aba1bce62 |
| 5db220c568 |
| 8e63a4a8d7 |
| 175f17b131 |
| 9aec1f51ed |
| 760e2ff592 |
| d17ea2d62a |
| 6351ba7f5d |
| a1ba3b1ac3 |
| d78b4d9ab4 |
| 1292c85d2a |
| c44fd7332c |
| 45a2826df8 |
| abb8134761 |
| c879599871 |
| a71b2a1de6 |
| 38b20e6158 |
| a8a0da1e3c |
| d3e7aab796 |
| f539c24571 |
| 2068073e8c |
| 04d36194c9 |
| e8eda51b27 |
| aa5d304a2e |
| 9dab7c9132 |
| 4562606c54 |
| fa933ada85 |
| bc1d11bf42 |
| 5ac2e7044e |
| 7f741468dd |
| 715e4c7d73 |
| ef473bbc8d |
| 43ac5e0343 |
| da15408a35 |
| bd00338690 |
| c63ccb5bd9 |
| 0e9906ea65 |
| ff0e786202 |
| deee943c3a |
| 40f38fcb46 |
| 35ec676f35 |
| c345a79962 |
| b9c7d1a115 |
| bf459a17ba |
| 15befae65f |
| 2340e9b3f5 |
| 7044782689 |
| ef4776f697 |
| 946ba02969 |
| 9a6ff408b4 |
| c07ea4f63d |
| 2af974b381 |
| 7e6bdf8b04 |
| 707c485212 |
| 5145aa7609 |
| 496990c096 |
| 52de22469f |
| c27f163623 |
| 29a61abfe3 |
| 1580fb5fa7 |
| b9763aa28a |
| f9eefae1ad |
| d16c1f259b |
| f0650d1f76 |
| d9710ce1af |
| e5a4b9a5ac |
| ebad48481e |
| 4503dab267 |
| d73acd13cb |
| a955794acd |
| 03d754cb50 |
| 8dc22e2a63 |
| bacdc190e7 |
| 5943c75873 |
| 8db695932e |
| a70d6a5193 |
| 4869a8ce22 |
| d5d9ecd71c |
| 52119eadc2 |
| 334b4d5ef9 |
| d7fc2dfb46 |
| 120469c8bf |
| 259d6f2b69 |
| dce436fb30 |
| 44cb4e8e77 |
| 88767a84d1 |
| 8b03477d2d |
| 10a2b36dc9 |
| b1ccdacd98 |
| 8811011286 |
| b8c764ad70 |
| 4d69b2eb75 |
| 5adc6c0a46 |
| 915b08d8a7 |
| 2997b12367 |
| 82f0ee2240 |
| 26bef8b918 |
| 82f553ec0d |
| 4ea85b5eaf |
| 45efe5a947 |
| c2b320dd6a |
| 6b2d264414 |
| 9a742cbe93 |
| f96f2f101b |
| 6318a976b5 |
| c2c39a0cd6 |
| 657e64d903 |
| a2e681a09f |
| 1fc3b2aa0a |
| 5465296ba6 |
| fbfb8838fd |
| 130261c75f |
| c8c954d862 |
| abfa707f69 |
| a01f326f7e |
| e377711c7e |
| 4c40b5f187 |
| a8c5264e17 |
| 9e8f76b749 |
| 43f8ed55ea |
| 5d5f14c799 |
| 9db01a7836 |
| 1e4a96883f |
| a34dc25b34 |
| 90d3954e8c |
| aeb43b7d37 |
| 7dedcaddb6 |
| 02463a5cb2 |
| 4c18763e55 |
| 8a9a1b59a4 |
| 8d9b282376 |
| ffdc457dea |
| 015ac85a83 |
| ee01a602ff |
| f7f4207902 |
| 37bae48bd9 |
| 2ef21f929f |
| 6b77d71b88 |
| e2946e0cd1 |
| 887e7a4f0a |
| c4a8ab8a19 |
| a3dda5d5a2 |
| 948279a67d |
| cf97e25b4a |
| 72e4eb2418 |
| e241a4cc2a |
| 3accc65d44 |
| fd6f23f4a6 |
| 5ad56c553f |
| 81471d8ffe |
| e46bf6300e |
| a0e87867b7 |
| 00d5d843a2 |
| 3ce7cf2713 |
.github/dependabot.yml (vendored), 24 changed lines
```diff
@@ -7,6 +7,9 @@ updates:
       interval: "weekly"
     open-pull-requests-limit: 10
     target-branch: "dev"
+    commit-message:
+      prefix: "chore(libs/deps)"
+      prefix-development: "chore(libs/deps-dev)"
     groups:
       production-dependencies:
         dependency-type: "production"
@@ -26,6 +29,9 @@ updates:
       interval: "weekly"
     open-pull-requests-limit: 10
     target-branch: "dev"
+    commit-message:
+      prefix: "chore(backend/deps)"
+      prefix-development: "chore(backend/deps-dev)"
     groups:
       production-dependencies:
         dependency-type: "production"
@@ -38,7 +44,6 @@ updates:
           - "minor"
           - "patch"
-
 
   # frontend (Next.js project)
   - package-ecosystem: "npm"
     directory: "autogpt_platform/frontend"
@@ -46,6 +51,9 @@ updates:
       interval: "weekly"
     open-pull-requests-limit: 10
     target-branch: "dev"
+    commit-message:
+      prefix: "chore(frontend/deps)"
+      prefix-development: "chore(frontend/deps-dev)"
     groups:
       production-dependencies:
         dependency-type: "production"
@@ -58,7 +66,6 @@ updates:
           - "minor"
           - "patch"
-
 
   # infra (Terraform)
   - package-ecosystem: "terraform"
     directory: "autogpt_platform/infra"
@@ -66,6 +73,10 @@ updates:
       interval: "weekly"
     open-pull-requests-limit: 5
     target-branch: "dev"
+    commit-message:
+      prefix: "chore(infra/deps)"
+      prefix-development: "chore(infra/deps-dev)"
+
     groups:
       production-dependencies:
         dependency-type: "production"
@@ -78,7 +89,6 @@ updates:
           - "minor"
          - "patch"
-
 
   # market (Poetry project)
   - package-ecosystem: "pip"
     directory: "autogpt_platform/market"
@@ -86,6 +96,9 @@ updates:
       interval: "weekly"
     open-pull-requests-limit: 10
     target-branch: "dev"
+    commit-message:
+      prefix: "chore(market/deps)"
+      prefix-development: "chore(market/deps-dev)"
     groups:
       production-dependencies:
         dependency-type: "production"
@@ -146,6 +159,9 @@ updates:
       interval: "weekly"
     open-pull-requests-limit: 1
     target-branch: "dev"
+    commit-message:
+      prefix: "chore(platform/deps)"
+      prefix-development: "chore(platform/deps-dev)"
     groups:
       production-dependencies:
         dependency-type: "production"
@@ -166,6 +182,8 @@ updates:
       interval: "weekly"
     open-pull-requests-limit: 1
     target-branch: "dev"
+    commit-message:
+      prefix: "chore(docs/deps)"
     groups:
       production-dependencies:
         dependency-type: "production"
```
.github/workflows/platform-backend-ci.yml (vendored), 13 changed lines
```diff
@@ -6,11 +6,13 @@ on:
     paths:
       - ".github/workflows/platform-backend-ci.yml"
       - "autogpt_platform/backend/**"
+      - "autogpt_platform/autogpt_libs/**"
   pull_request:
     branches: [master, dev, release-*]
     paths:
       - ".github/workflows/platform-backend-ci.yml"
       - "autogpt_platform/backend/**"
+      - "autogpt_platform/autogpt_libs/**"
   merge_group:
 
 concurrency:
@@ -77,6 +79,17 @@ jobs:
           echo "$HOME/.local/bin" >> $GITHUB_PATH
         fi
 
+      - name: Check poetry.lock
+        run: |
+          poetry lock --no-update
+
+          if ! git diff --quiet poetry.lock; then
+            echo "Error: poetry.lock not up to date."
+            echo
+            git diff poetry.lock
+            exit 1
+          fi
+
       - name: Install Python dependencies
         run: poetry install
```
.github/workflows/platform-frontend-ci.yml (vendored), 35 changed lines
```diff
@@ -23,6 +23,7 @@ jobs:
 
     steps:
       - uses: actions/checkout@v4
+
       - name: Set up Node.js
         uses: actions/setup-node@v4
         with:
@@ -38,24 +39,12 @@ jobs:
 
   test:
     runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        browser: [chromium, firefox, webkit]
 
     steps:
-      - name: Free Disk Space (Ubuntu)
-        uses: jlumbroso/free-disk-space@main
-        with:
-          # this might remove tools that are actually needed,
-          # if set to "true" but frees about 6 GB
-          tool-cache: false
-
-          # all of these default to true, but feel free to set to
-          # "false" if necessary for your workflow
-          android: false
-          dotnet: false
-          haskell: false
-          large-packages: true
-          docker-images: true
-          swap-storage: true
-
       - name: Checkout repository
         uses: actions/checkout@v4
         with:
@@ -66,6 +55,12 @@ jobs:
         with:
           node-version: "21"
 
+      - name: Free Disk Space (Ubuntu)
+        uses: jlumbroso/free-disk-space@main
+        with:
+          large-packages: false # slow
+          docker-images: false # limited benefit
+
       - name: Copy default supabase .env
         run: |
           cp ../supabase/docker/.env.example ../.env
@@ -86,16 +81,16 @@ jobs:
         run: |
           cp .env.example .env
 
-      - name: Install Playwright Browsers
-        run: yarn playwright install --with-deps
+      - name: Install Browser '${{ matrix.browser }}'
+        run: yarn playwright install --with-deps ${{ matrix.browser }}
 
       - name: Run tests
         run: |
-          yarn test
+          yarn test --project=${{ matrix.browser }}
 
       - uses: actions/upload-artifact@v4
         if: ${{ !cancelled() }}
         with:
-          name: playwright-report
+          name: playwright-report-${{ matrix.browser }}
           path: playwright-report/
           retention-days: 30
```
.gitignore (vendored), 3 changed lines
```diff
@@ -173,3 +173,6 @@ LICENSE.rtf
 autogpt_platform/backend/settings.py
 /.auth
 /autogpt_platform/frontend/.auth
+
+*.ign.*
+.test-contents
```
```diff
@@ -98,6 +98,11 @@ repos:
         files: ^autogpt_platform/autogpt_libs/
         args: [--fix]
 
+      - id: ruff-format
+        name: Format (Ruff) - AutoGPT Platform - Libs
+        alias: ruff-lint-platform-libs
+        files: ^autogpt_platform/autogpt_libs/
+
   - repo: local
     # isort needs the context of which packages are installed to function, so we
     # can't use a vendored isort pre-commit hook (which runs in its own isolated venv).
@@ -140,7 +145,7 @@ repos:
     # everything in .gitignore, so it works fine without any config or arguments.
     hooks:
       - id: black
-        name: Lint (Black)
+        name: Format (Black)
 
   - repo: https://github.com/PyCQA/flake8
     rev: 7.0.0
```
```diff
@@ -8,7 +8,7 @@ We take the security of our project seriously. If you believe you have found a s
 
 Instead, please report them via:
 - [GitHub Security Advisory](https://github.com/Significant-Gravitas/AutoGPT/security/advisories/new)
-- [Huntr.dev](https://huntr.com/repos/significant-gravitas/autogpt) - where you may be eligible for a bounty
+<!--- [Huntr.dev](https://huntr.com/repos/significant-gravitas/autogpt) - where you may be eligible for a bounty-->
 
 ### Reporting Process
 1. **Submit Report**: Use one of the above channels to submit your report
```
```diff
@@ -35,3 +35,12 @@ def verify_user(payload: dict | None, admin_only: bool) -> User:
         raise fastapi.HTTPException(status_code=403, detail="Admin access required")
 
     return User.from_payload(payload)
+
+
+def get_user_id(payload: dict = fastapi.Depends(auth_middleware)) -> str:
+    user_id = payload.get("sub")
+    if not user_id:
+        raise fastapi.HTTPException(
+            status_code=401, detail="User ID not found in token"
+        )
+    return user_id
```
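The new `get_user_id` helper is meant to be used as a FastAPI dependency. A minimal usage sketch follows; the app, route, and import path are illustrative assumptions, not part of this diff:

```python
# Hypothetical wiring: a route that resolves the caller's user ID through the
# new get_user_id dependency. Import path and endpoint are assumptions.
import fastapi

from autogpt_libs.auth.depends import get_user_id  # assumed module path

app = fastapi.FastAPI()


@app.get("/me")
def read_own_profile(user_id: str = fastapi.Depends(get_user_id)) -> dict:
    # get_user_id has already raised a 401 if the JWT payload had no "sub" claim,
    # so the handler can treat user_id as always present.
    return {"user_id": user_id}
```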
```diff
@@ -72,7 +72,7 @@ def feature_flag(
     """
 
     def decorator(
-        func: Callable[P, Union[T, Awaitable[T]]]
+        func: Callable[P, Union[T, Awaitable[T]]],
     ) -> Callable[P, Union[T, Awaitable[T]]]:
         @wraps(func)
         async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
```
```diff
@@ -1,7 +1,8 @@
 import pytest
-from autogpt_libs.feature_flag.client import feature_flag, mock_flag_variation
 from ldclient import LDClient
 
+from autogpt_libs.feature_flag.client import feature_flag, mock_flag_variation
+
 
 @pytest.fixture
 def ld_client(mocker):
```
```diff
@@ -6,7 +6,7 @@ class Settings(BaseSettings):
     launch_darkly_sdk_key: str = Field(
         default="",
         description="The Launch Darkly SDK key",
-        validation_alias="LAUNCH_DARKLY_SDK_KEY"
+        validation_alias="LAUNCH_DARKLY_SDK_KEY",
     )
 
     model_config = SettingsConfigDict(case_sensitive=True, extra="ignore")
```
```diff
@@ -23,7 +23,6 @@ DEBUG_LOG_FORMAT = (
 
 
 class LoggingConfig(BaseSettings):
-
     level: str = Field(
         default="INFO",
         description="Logging level",
```
```diff
@@ -24,10 +24,10 @@ from .utils import remove_color_codes
         ),
         ("", ""),
         ("hello", "hello"),
-        ("hello\x1B[31m world", "hello world"),
-        ("\x1B[36mHello,\x1B[32m World!", "Hello, World!"),
+        ("hello\x1b[31m world", "hello world"),
+        ("\x1b[36mHello,\x1b[32m World!", "Hello, World!"),
         (
-            "\x1B[1m\x1B[31mError:\x1B[0m\x1B[31m file not found",
+            "\x1b[1m\x1b[31mError:\x1b[0m\x1b[31m file not found",
             "Error: file not found",
         ),
     ],
```
```diff
@@ -0,0 +1,31 @@
+from pydantic import Field
+from pydantic_settings import BaseSettings, SettingsConfigDict
+
+
+class RateLimitSettings(BaseSettings):
+    redis_host: str = Field(
+        default="redis://localhost:6379",
+        description="Redis host",
+        validation_alias="REDIS_HOST",
+    )
+
+    redis_port: str = Field(
+        default="6379", description="Redis port", validation_alias="REDIS_PORT"
+    )
+
+    redis_password: str = Field(
+        default="password",
+        description="Redis password",
+        validation_alias="REDIS_PASSWORD",
+    )
+
+    requests_per_minute: int = Field(
+        default=60,
+        description="Maximum number of requests allowed per minute per API key",
+        validation_alias="RATE_LIMIT_REQUESTS_PER_MINUTE",
+    )
+
+    model_config = SettingsConfigDict(case_sensitive=True, extra="ignore")
+
+
+RATE_LIMIT_SETTINGS = RateLimitSettings()
```
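Since each field declares a `validation_alias`, the settings object can be driven entirely from the environment, which is standard pydantic-settings behavior. A small sketch of that (values and module path are assumptions):

```python
# Illustrative only: pydantic-settings resolves each field from the matching
# validation_alias environment variable before falling back to the default.
import os

os.environ["REDIS_HOST"] = "redis.internal"            # hypothetical value
os.environ["RATE_LIMIT_REQUESTS_PER_MINUTE"] = "120"   # hypothetical value

from autogpt_libs.rate_limit.config import RateLimitSettings  # assumed path

settings = RateLimitSettings()
assert settings.redis_host == "redis.internal"
assert settings.requests_per_minute == 120  # coerced from the env string
```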
```diff
@@ -0,0 +1,51 @@
+import time
+from typing import Tuple
+
+from redis import Redis
+
+from .config import RATE_LIMIT_SETTINGS
+
+
+class RateLimiter:
+    def __init__(
+        self,
+        redis_host: str = RATE_LIMIT_SETTINGS.redis_host,
+        redis_port: str = RATE_LIMIT_SETTINGS.redis_port,
+        redis_password: str = RATE_LIMIT_SETTINGS.redis_password,
+        requests_per_minute: int = RATE_LIMIT_SETTINGS.requests_per_minute,
+    ):
+        self.redis = Redis(
+            host=redis_host,
+            port=int(redis_port),
+            password=redis_password,
+            decode_responses=True,
+        )
+        self.window = 60
+        self.max_requests = requests_per_minute
+
+    async def check_rate_limit(self, api_key_id: str) -> Tuple[bool, int, int]:
+        """
+        Check if request is within rate limits.
+
+        Args:
+            api_key_id: The API key identifier to check
+
+        Returns:
+            Tuple of (is_allowed, remaining_requests, reset_time)
+        """
+        now = time.time()
+        window_start = now - self.window
+        key = f"ratelimit:{api_key_id}:1min"
+
+        pipe = self.redis.pipeline()
+        pipe.zremrangebyscore(key, 0, window_start)
+        pipe.zadd(key, {str(now): now})
+        pipe.zcount(key, window_start, now)
+        pipe.expire(key, self.window)
+
+        _, _, request_count, _ = pipe.execute()
+
+        remaining = max(0, self.max_requests - request_count)
+        reset_time = int(now + self.window)
+
+        return request_count <= self.max_requests, remaining, reset_time
```
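Taken together, the pipeline implements a sliding one-minute window per API key in a Redis sorted set: expired entries are pruned, the current request is recorded, and the allowance is derived from the window population. A small sketch of calling it directly (assumes a reachable Redis with the configured credentials; the key name and limit are arbitrary):

```python
# Sketch: exercising RateLimiter against a live Redis. check_rate_limit is
# declared async, so it must be awaited even though the redis-py client used
# inside it is the synchronous one.
import asyncio

from autogpt_libs.rate_limit.limiter import RateLimiter  # assumed module path


async def main() -> None:
    limiter = RateLimiter(requests_per_minute=2)
    for attempt in range(1, 4):
        allowed, remaining, reset_time = await limiter.check_rate_limit("demo-key")
        print(f"attempt={attempt} allowed={allowed} remaining={remaining}")
    # The third call puts 3 entries in the 60 s window, so allowed is False.


asyncio.run(main())
```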
```diff
@@ -0,0 +1,32 @@
+from fastapi import HTTPException, Request
+from starlette.middleware.base import RequestResponseEndpoint
+
+from .limiter import RateLimiter
+
+
+async def rate_limit_middleware(request: Request, call_next: RequestResponseEndpoint):
+    """FastAPI middleware for rate limiting API requests."""
+    limiter = RateLimiter()
+
+    if not request.url.path.startswith("/api"):
+        return await call_next(request)
+
+    api_key = request.headers.get("Authorization")
+    if not api_key:
+        return await call_next(request)
+
+    api_key = api_key.replace("Bearer ", "")
+
+    is_allowed, remaining, reset_time = await limiter.check_rate_limit(api_key)
+
+    if not is_allowed:
+        raise HTTPException(
+            status_code=429, detail="Rate limit exceeded. Please try again later."
+        )
+
+    response = await call_next(request)
+    response.headers["X-RateLimit-Limit"] = str(limiter.max_requests)
+    response.headers["X-RateLimit-Remaining"] = str(remaining)
+    response.headers["X-RateLimit-Reset"] = str(reset_time)
+
+    return response
```
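The function matches the `(request, call_next)` signature that FastAPI's `app.middleware("http")` decorator expects, so it can be registered as-is. A wiring sketch (the app and route are assumptions for illustration):

```python
# Sketch: registering rate_limit_middleware on an app. Only /api/* paths that
# carry an Authorization header are limited; their responses gain the
# X-RateLimit-Limit / -Remaining / -Reset headers set by the middleware.
from fastapi import FastAPI

from autogpt_libs.rate_limit.middleware import rate_limit_middleware  # assumed path

app = FastAPI()
app.middleware("http")(rate_limit_middleware)


@app.get("/api/ping")
def ping() -> dict:
    return {"pong": True}
```

Note that the limiter keys on the raw bearer token, so each distinct API key gets its own sliding window.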
autogpt_platform/autogpt_libs/poetry.lock (generated), 250 changed lines
```diff
@@ -1091,19 +1091,22 @@ pyasn1 = ">=0.4.6,<0.7.0"
 
 [[package]]
 name = "pydantic"
-version = "2.10.2"
+version = "2.9.2"
 description = "Data validation using Python type hints"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "pydantic-2.10.2-py3-none-any.whl", hash = "sha256:cfb96e45951117c3024e6b67b25cdc33a3cb7b2fa62e239f7af1378358a1d99e"},
-    {file = "pydantic-2.10.2.tar.gz", hash = "sha256:2bc2d7f17232e0841cbba4641e65ba1eb6fafb3a08de3a091ff3ce14a197c4fa"},
+    {file = "pydantic-2.9.2-py3-none-any.whl", hash = "sha256:f048cec7b26778210e28a0459867920654d48e5e62db0958433636cde4254f12"},
+    {file = "pydantic-2.9.2.tar.gz", hash = "sha256:d155cef71265d1e9807ed1c32b4c8deec042a44a50a4188b25ac67ecd81a9c0f"},
 ]
 
 [package.dependencies]
 annotated-types = ">=0.6.0"
-pydantic-core = "2.27.1"
-typing-extensions = ">=4.12.2"
+pydantic-core = "2.23.4"
+typing-extensions = [
+    {version = ">=4.12.2", markers = "python_version >= \"3.13\""},
+    {version = ">=4.6.1", markers = "python_version < \"3.13\""},
+]
 
 [package.extras]
 email = ["email-validator (>=2.0.0)"]
@@ -1111,111 +1114,100 @@ timezone = ["tzdata"]
 
 [[package]]
 name = "pydantic-core"
-version = "2.27.1"
+version = "2.23.4"
 description = "Core functionality for Pydantic validation and serialization"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "pydantic_core-2.27.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:71a5e35c75c021aaf400ac048dacc855f000bdfed91614b4a726f7432f1f3d6a"},
-    {file = "pydantic_core-2.27.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f82d068a2d6ecfc6e054726080af69a6764a10015467d7d7b9f66d6ed5afa23b"},
-    {file = "pydantic_core-2.27.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:121ceb0e822f79163dd4699e4c54f5ad38b157084d97b34de8b232bcaad70278"},
-    {file = "pydantic_core-2.27.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4603137322c18eaf2e06a4495f426aa8d8388940f3c457e7548145011bb68e05"},
-    {file = "pydantic_core-2.27.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a33cd6ad9017bbeaa9ed78a2e0752c5e250eafb9534f308e7a5f7849b0b1bfb4"},
-    {file = "pydantic_core-2.27.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:15cc53a3179ba0fcefe1e3ae50beb2784dede4003ad2dfd24f81bba4b23a454f"},
-    {file = "pydantic_core-2.27.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:45d9c5eb9273aa50999ad6adc6be5e0ecea7e09dbd0d31bd0c65a55a2592ca08"},
-    {file = "pydantic_core-2.27.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8bf7b66ce12a2ac52d16f776b31d16d91033150266eb796967a7e4621707e4f6"},
-    {file = "pydantic_core-2.27.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:655d7dd86f26cb15ce8a431036f66ce0318648f8853d709b4167786ec2fa4807"},
-    {file = "pydantic_core-2.27.1-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:5556470f1a2157031e676f776c2bc20acd34c1990ca5f7e56f1ebf938b9ab57c"},
-    {file = "pydantic_core-2.27.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f69ed81ab24d5a3bd93861c8c4436f54afdf8e8cc421562b0c7504cf3be58206"},
-    {file = "pydantic_core-2.27.1-cp310-none-win32.whl", hash = "sha256:f5a823165e6d04ccea61a9f0576f345f8ce40ed533013580e087bd4d7442b52c"},
-    {file = "pydantic_core-2.27.1-cp310-none-win_amd64.whl", hash = "sha256:57866a76e0b3823e0b56692d1a0bf722bffb324839bb5b7226a7dbd6c9a40b17"},
-    {file = "pydantic_core-2.27.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:ac3b20653bdbe160febbea8aa6c079d3df19310d50ac314911ed8cc4eb7f8cb8"},
-    {file = "pydantic_core-2.27.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a5a8e19d7c707c4cadb8c18f5f60c843052ae83c20fa7d44f41594c644a1d330"},
-    {file = "pydantic_core-2.27.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7f7059ca8d64fea7f238994c97d91f75965216bcbe5f695bb44f354893f11d52"},
-    {file = "pydantic_core-2.27.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bed0f8a0eeea9fb72937ba118f9db0cb7e90773462af7962d382445f3005e5a4"},
-    {file = "pydantic_core-2.27.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a3cb37038123447cf0f3ea4c74751f6a9d7afef0eb71aa07bf5f652b5e6a132c"},
-    {file = "pydantic_core-2.27.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:84286494f6c5d05243456e04223d5a9417d7f443c3b76065e75001beb26f88de"},
-    {file = "pydantic_core-2.27.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:acc07b2cfc5b835444b44a9956846b578d27beeacd4b52e45489e93276241025"},
-    {file = "pydantic_core-2.27.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4fefee876e07a6e9aad7a8c8c9f85b0cdbe7df52b8a9552307b09050f7512c7e"},
-    {file = "pydantic_core-2.27.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:258c57abf1188926c774a4c94dd29237e77eda19462e5bb901d88adcab6af919"},
-    {file = "pydantic_core-2.27.1-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:35c14ac45fcfdf7167ca76cc80b2001205a8d5d16d80524e13508371fb8cdd9c"},
-    {file = "pydantic_core-2.27.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d1b26e1dff225c31897696cab7d4f0a315d4c0d9e8666dbffdb28216f3b17fdc"},
-    {file = "pydantic_core-2.27.1-cp311-none-win32.whl", hash = "sha256:2cdf7d86886bc6982354862204ae3b2f7f96f21a3eb0ba5ca0ac42c7b38598b9"},
-    {file = "pydantic_core-2.27.1-cp311-none-win_amd64.whl", hash = "sha256:3af385b0cee8df3746c3f406f38bcbfdc9041b5c2d5ce3e5fc6637256e60bbc5"},
-    {file = "pydantic_core-2.27.1-cp311-none-win_arm64.whl", hash = "sha256:81f2ec23ddc1b476ff96563f2e8d723830b06dceae348ce02914a37cb4e74b89"},
-    {file = "pydantic_core-2.27.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:9cbd94fc661d2bab2bc702cddd2d3370bbdcc4cd0f8f57488a81bcce90c7a54f"},
-    {file = "pydantic_core-2.27.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5f8c4718cd44ec1580e180cb739713ecda2bdee1341084c1467802a417fe0f02"},
-    {file = "pydantic_core-2.27.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:15aae984e46de8d376df515f00450d1522077254ef6b7ce189b38ecee7c9677c"},
-    {file = "pydantic_core-2.27.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1ba5e3963344ff25fc8c40da90f44b0afca8cfd89d12964feb79ac1411a260ac"},
-    {file = "pydantic_core-2.27.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:992cea5f4f3b29d6b4f7f1726ed8ee46c8331c6b4eed6db5b40134c6fe1768bb"},
-    {file = "pydantic_core-2.27.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0325336f348dbee6550d129b1627cb8f5351a9dc91aad141ffb96d4937bd9529"},
-    {file = "pydantic_core-2.27.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7597c07fbd11515f654d6ece3d0e4e5093edc30a436c63142d9a4b8e22f19c35"},
-    {file = "pydantic_core-2.27.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:3bbd5d8cc692616d5ef6fbbbd50dbec142c7e6ad9beb66b78a96e9c16729b089"},
-    {file = "pydantic_core-2.27.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:dc61505e73298a84a2f317255fcc72b710b72980f3a1f670447a21efc88f8381"},
-    {file = "pydantic_core-2.27.1-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:e1f735dc43da318cad19b4173dd1ffce1d84aafd6c9b782b3abc04a0d5a6f5bb"},
-    {file = "pydantic_core-2.27.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:f4e5658dbffe8843a0f12366a4c2d1c316dbe09bb4dfbdc9d2d9cd6031de8aae"},
-    {file = "pydantic_core-2.27.1-cp312-none-win32.whl", hash = "sha256:672ebbe820bb37988c4d136eca2652ee114992d5d41c7e4858cdd90ea94ffe5c"},
-    {file = "pydantic_core-2.27.1-cp312-none-win_amd64.whl", hash = "sha256:66ff044fd0bb1768688aecbe28b6190f6e799349221fb0de0e6f4048eca14c16"},
-    {file = "pydantic_core-2.27.1-cp312-none-win_arm64.whl", hash = "sha256:9a3b0793b1bbfd4146304e23d90045f2a9b5fd5823aa682665fbdaf2a6c28f3e"},
-    {file = "pydantic_core-2.27.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:f216dbce0e60e4d03e0c4353c7023b202d95cbaeff12e5fd2e82ea0a66905073"},
-    {file = "pydantic_core-2.27.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a2e02889071850bbfd36b56fd6bc98945e23670773bc7a76657e90e6b6603c08"},
-    {file = "pydantic_core-2.27.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42b0e23f119b2b456d07ca91b307ae167cc3f6c846a7b169fca5326e32fdc6cf"},
-    {file = "pydantic_core-2.27.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:764be71193f87d460a03f1f7385a82e226639732214b402f9aa61f0d025f0737"},
-    {file = "pydantic_core-2.27.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1c00666a3bd2f84920a4e94434f5974d7bbc57e461318d6bb34ce9cdbbc1f6b2"},
-    {file = "pydantic_core-2.27.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3ccaa88b24eebc0f849ce0a4d09e8a408ec5a94afff395eb69baf868f5183107"},
-    {file = "pydantic_core-2.27.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c65af9088ac534313e1963443d0ec360bb2b9cba6c2909478d22c2e363d98a51"},
-    {file = "pydantic_core-2.27.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:206b5cf6f0c513baffaeae7bd817717140770c74528f3e4c3e1cec7871ddd61a"},
-    {file = "pydantic_core-2.27.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:062f60e512fc7fff8b8a9d680ff0ddaaef0193dba9fa83e679c0c5f5fbd018bc"},
-    {file = "pydantic_core-2.27.1-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:a0697803ed7d4af5e4c1adf1670af078f8fcab7a86350e969f454daf598c4960"},
-    {file = "pydantic_core-2.27.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:58ca98a950171f3151c603aeea9303ef6c235f692fe555e883591103da709b23"},
-    {file = "pydantic_core-2.27.1-cp313-none-win32.whl", hash = "sha256:8065914ff79f7eab1599bd80406681f0ad08f8e47c880f17b416c9f8f7a26d05"},
-    {file = "pydantic_core-2.27.1-cp313-none-win_amd64.whl", hash = "sha256:ba630d5e3db74c79300d9a5bdaaf6200172b107f263c98a0539eeecb857b2337"},
-    {file = "pydantic_core-2.27.1-cp313-none-win_arm64.whl", hash = "sha256:45cf8588c066860b623cd11c4ba687f8d7175d5f7ef65f7129df8a394c502de5"},
-    {file = "pydantic_core-2.27.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:5897bec80a09b4084aee23f9b73a9477a46c3304ad1d2d07acca19723fb1de62"},
-    {file = "pydantic_core-2.27.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d0165ab2914379bd56908c02294ed8405c252250668ebcb438a55494c69f44ab"},
-    {file = "pydantic_core-2.27.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b9af86e1d8e4cfc82c2022bfaa6f459381a50b94a29e95dcdda8442d6d83864"},
-    {file = "pydantic_core-2.27.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5f6c8a66741c5f5447e047ab0ba7a1c61d1e95580d64bce852e3df1f895c4067"},
-    {file = "pydantic_core-2.27.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9a42d6a8156ff78981f8aa56eb6394114e0dedb217cf8b729f438f643608cbcd"},
-    {file = "pydantic_core-2.27.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64c65f40b4cd8b0e049a8edde07e38b476da7e3aaebe63287c899d2cff253fa5"},
-    {file = "pydantic_core-2.27.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fdcf339322a3fae5cbd504edcefddd5a50d9ee00d968696846f089b4432cf78"},
-    {file = "pydantic_core-2.27.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:bf99c8404f008750c846cb4ac4667b798a9f7de673ff719d705d9b2d6de49c5f"},
-    {file = "pydantic_core-2.27.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8f1edcea27918d748c7e5e4d917297b2a0ab80cad10f86631e488b7cddf76a36"},
-    {file = "pydantic_core-2.27.1-cp38-cp38-musllinux_1_1_armv7l.whl", hash = "sha256:159cac0a3d096f79ab6a44d77a961917219707e2a130739c64d4dd46281f5c2a"},
-    {file = "pydantic_core-2.27.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:029d9757eb621cc6e1848fa0b0310310de7301057f623985698ed7ebb014391b"},
-    {file = "pydantic_core-2.27.1-cp38-none-win32.whl", hash = "sha256:a28af0695a45f7060e6f9b7092558a928a28553366519f64083c63a44f70e618"},
-    {file = "pydantic_core-2.27.1-cp38-none-win_amd64.whl", hash = "sha256:2d4567c850905d5eaaed2f7a404e61012a51caf288292e016360aa2b96ff38d4"},
-    {file = "pydantic_core-2.27.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:e9386266798d64eeb19dd3677051f5705bf873e98e15897ddb7d76f477131967"},
-    {file = "pydantic_core-2.27.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4228b5b646caa73f119b1ae756216b59cc6e2267201c27d3912b592c5e323b60"},
-    {file = "pydantic_core-2.27.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b3dfe500de26c52abe0477dde16192ac39c98f05bf2d80e76102d394bd13854"},
-    {file = "pydantic_core-2.27.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:aee66be87825cdf72ac64cb03ad4c15ffef4143dbf5c113f64a5ff4f81477bf9"},
-    {file = "pydantic_core-2.27.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3b748c44bb9f53031c8cbc99a8a061bc181c1000c60a30f55393b6e9c45cc5bd"},
-    {file = "pydantic_core-2.27.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ca038c7f6a0afd0b2448941b6ef9d5e1949e999f9e5517692eb6da58e9d44be"},
-    {file = "pydantic_core-2.27.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e0bd57539da59a3e4671b90a502da9a28c72322a4f17866ba3ac63a82c4498e"},
-    {file = "pydantic_core-2.27.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ac6c2c45c847bbf8f91930d88716a0fb924b51e0c6dad329b793d670ec5db792"},
-    {file = "pydantic_core-2.27.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b94d4ba43739bbe8b0ce4262bcc3b7b9f31459ad120fb595627eaeb7f9b9ca01"},
-    {file = "pydantic_core-2.27.1-cp39-cp39-musllinux_1_1_armv7l.whl", hash = "sha256:00e6424f4b26fe82d44577b4c842d7df97c20be6439e8e685d0d715feceb9fb9"},
-    {file = "pydantic_core-2.27.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:38de0a70160dd97540335b7ad3a74571b24f1dc3ed33f815f0880682e6880131"},
-    {file = "pydantic_core-2.27.1-cp39-none-win32.whl", hash = "sha256:7ccebf51efc61634f6c2344da73e366c75e735960b5654b63d7e6f69a5885fa3"},
-    {file = "pydantic_core-2.27.1-cp39-none-win_amd64.whl", hash = "sha256:a57847b090d7892f123726202b7daa20df6694cbd583b67a592e856bff603d6c"},
-    {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:3fa80ac2bd5856580e242dbc202db873c60a01b20309c8319b5c5986fbe53ce6"},
-    {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d950caa237bb1954f1b8c9227b5065ba6875ac9771bb8ec790d956a699b78676"},
-    {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0e4216e64d203e39c62df627aa882f02a2438d18a5f21d7f721621f7a5d3611d"},
-    {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02a3d637bd387c41d46b002f0e49c52642281edacd2740e5a42f7017feea3f2c"},
-    {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:161c27ccce13b6b0c8689418da3885d3220ed2eae2ea5e9b2f7f3d48f1d52c27"},
-    {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:19910754e4cc9c63bc1c7f6d73aa1cfee82f42007e407c0f413695c2f7ed777f"},
-    {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:e173486019cc283dc9778315fa29a363579372fe67045e971e89b6365cc035ed"},
-    {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:af52d26579b308921b73b956153066481f064875140ccd1dfd4e77db89dbb12f"},
-    {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:981fb88516bd1ae8b0cbbd2034678a39dedc98752f264ac9bc5839d3923fa04c"},
-    {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5fde892e6c697ce3e30c61b239330fc5d569a71fefd4eb6512fc6caec9dd9e2f"},
-    {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:816f5aa087094099fff7edabb5e01cc370eb21aa1a1d44fe2d2aefdfb5599b31"},
-    {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c10c309e18e443ddb108f0ef64e8729363adbfd92d6d57beec680f6261556f3"},
-    {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98476c98b02c8e9b2eec76ac4156fd006628b1b2d0ef27e548ffa978393fd154"},
-    {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c3027001c28434e7ca5a6e1e527487051136aa81803ac812be51802150d880dd"},
-    {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:7699b1df36a48169cdebda7ab5a2bac265204003f153b4bd17276153d997670a"},
-    {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:1c39b07d90be6b48968ddc8c19e7585052088fd7ec8d568bb31ff64c70ae3c97"},
-    {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:46ccfe3032b3915586e469d4972973f893c0a2bb65669194a5bdea9bacc088c2"},
-    {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:62ba45e21cf6571d7f716d903b5b7b6d2617e2d5d67c0923dc47b9d41369f840"},
-    {file = "pydantic_core-2.27.1.tar.gz", hash = "sha256:62a763352879b84aa31058fc931884055fd75089cccbd9d58bb6afd01141b235"},
+    {file = "pydantic_core-2.23.4-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:b10bd51f823d891193d4717448fab065733958bdb6a6b351967bd349d48d5c9b"},
+    {file = "pydantic_core-2.23.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4fc714bdbfb534f94034efaa6eadd74e5b93c8fa6315565a222f7b6f42ca1166"},
+    {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63e46b3169866bd62849936de036f901a9356e36376079b05efa83caeaa02ceb"},
+    {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ed1a53de42fbe34853ba90513cea21673481cd81ed1be739f7f2efb931b24916"},
+    {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cfdd16ab5e59fc31b5e906d1a3f666571abc367598e3e02c83403acabc092e07"},
+    {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:255a8ef062cbf6674450e668482456abac99a5583bbafb73f9ad469540a3a232"},
+    {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a7cd62e831afe623fbb7aabbb4fe583212115b3ef38a9f6b71869ba644624a2"},
+    {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f09e2ff1f17c2b51f2bc76d1cc33da96298f0a036a137f5440ab3ec5360b624f"},
+    {file = "pydantic_core-2.23.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e38e63e6f3d1cec5a27e0afe90a085af8b6806ee208b33030e65b6516353f1a3"},
+    {file = "pydantic_core-2.23.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0dbd8dbed2085ed23b5c04afa29d8fd2771674223135dc9bc937f3c09284d071"},
+    {file = "pydantic_core-2.23.4-cp310-none-win32.whl", hash = "sha256:6531b7ca5f951d663c339002e91aaebda765ec7d61b7d1e3991051906ddde119"},
+    {file = "pydantic_core-2.23.4-cp310-none-win_amd64.whl", hash = "sha256:7c9129eb40958b3d4500fa2467e6a83356b3b61bfff1b414c7361d9220f9ae8f"},
+    {file = "pydantic_core-2.23.4-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:77733e3892bb0a7fa797826361ce8a9184d25c8dffaec60b7ffe928153680ba8"},
+    {file = "pydantic_core-2.23.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1b84d168f6c48fabd1f2027a3d1bdfe62f92cade1fb273a5d68e621da0e44e6d"},
+    {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:df49e7a0861a8c36d089c1ed57d308623d60416dab2647a4a17fe050ba85de0e"},
+    {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ff02b6d461a6de369f07ec15e465a88895f3223eb75073ffea56b84d9331f607"},
+    {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:996a38a83508c54c78a5f41456b0103c30508fed9abcad0a59b876d7398f25fd"},
+    {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d97683ddee4723ae8c95d1eddac7c192e8c552da0c73a925a89fa8649bf13eea"},
+    {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:216f9b2d7713eb98cb83c80b9c794de1f6b7e3145eef40400c62e86cee5f4e1e"},
+    {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6f783e0ec4803c787bcea93e13e9932edab72068f68ecffdf86a99fd5918878b"},
+    {file = "pydantic_core-2.23.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d0776dea117cf5272382634bd2a5c1b6eb16767c223c6a5317cd3e2a757c61a0"},
+    {file = "pydantic_core-2.23.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d5f7a395a8cf1621939692dba2a6b6a830efa6b3cee787d82c7de1ad2930de64"},
+    {file = "pydantic_core-2.23.4-cp311-none-win32.whl", hash = "sha256:74b9127ffea03643e998e0c5ad9bd3811d3dac8c676e47db17b0ee7c3c3bf35f"},
+    {file = "pydantic_core-2.23.4-cp311-none-win_amd64.whl", hash = "sha256:98d134c954828488b153d88ba1f34e14259284f256180ce659e8d83e9c05eaa3"},
+    {file = "pydantic_core-2.23.4-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:f3e0da4ebaef65158d4dfd7d3678aad692f7666877df0002b8a522cdf088f231"},
+    {file = "pydantic_core-2.23.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f69a8e0b033b747bb3e36a44e7732f0c99f7edd5cea723d45bc0d6e95377ffee"},
+    {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:723314c1d51722ab28bfcd5240d858512ffd3116449c557a1336cbe3919beb87"},
+    {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bb2802e667b7051a1bebbfe93684841cc9351004e2badbd6411bf357ab8d5ac8"},
+    {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d18ca8148bebe1b0a382a27a8ee60350091a6ddaf475fa05ef50dc35b5df6327"},
+    {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:33e3d65a85a2a4a0dc3b092b938a4062b1a05f3a9abde65ea93b233bca0e03f2"},
+    {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:128585782e5bfa515c590ccee4b727fb76925dd04a98864182b22e89a4e6ed36"},
+    {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:68665f4c17edcceecc112dfed5dbe6f92261fb9d6054b47d01bf6371a6196126"},
+    {file = "pydantic_core-2.23.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:20152074317d9bed6b7a95ade3b7d6054845d70584216160860425f4fbd5ee9e"},
+    {file = "pydantic_core-2.23.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:9261d3ce84fa1d38ed649c3638feefeae23d32ba9182963e465d58d62203bd24"},
+    {file = "pydantic_core-2.23.4-cp312-none-win32.whl", hash = "sha256:4ba762ed58e8d68657fc1281e9bb72e1c3e79cc5d464be146e260c541ec12d84"},
+    {file = "pydantic_core-2.23.4-cp312-none-win_amd64.whl", hash = "sha256:97df63000f4fea395b2824da80e169731088656d1818a11b95f3b173747b6cd9"},
+    {file = "pydantic_core-2.23.4-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:7530e201d10d7d14abce4fb54cfe5b94a0aefc87da539d0346a484ead376c3cc"},
+    {file = "pydantic_core-2.23.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:df933278128ea1cd77772673c73954e53a1c95a4fdf41eef97c2b779271bd0bd"},
+    {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0cb3da3fd1b6a5d0279a01877713dbda118a2a4fc6f0d821a57da2e464793f05"},
+    {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:42c6dcb030aefb668a2b7009c85b27f90e51e6a3b4d5c9bc4c57631292015b0d"},
+    {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:696dd8d674d6ce621ab9d45b205df149399e4bb9aa34102c970b721554828510"},
+    {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2971bb5ffe72cc0f555c13e19b23c85b654dd2a8f7ab493c262071377bfce9f6"},
+    {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8394d940e5d400d04cad4f75c0598665cbb81aecefaca82ca85bd28264af7f9b"},
+    {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0dff76e0602ca7d4cdaacc1ac4c005e0ce0dcfe095d5b5259163a80d3a10d327"},
+    {file = "pydantic_core-2.23.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:7d32706badfe136888bdea71c0def994644e09fff0bfe47441deaed8e96fdbc6"},
+    {file = "pydantic_core-2.23.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:ed541d70698978a20eb63d8c5d72f2cc6d7079d9d90f6b50bad07826f1320f5f"},
+    {file = "pydantic_core-2.23.4-cp313-none-win32.whl", hash = "sha256:3d5639516376dce1940ea36edf408c554475369f5da2abd45d44621cb616f769"},
+    {file = "pydantic_core-2.23.4-cp313-none-win_amd64.whl", hash = "sha256:5a1504ad17ba4210df3a045132a7baeeba5a200e930f57512ee02909fc5c4cb5"},
+    {file = "pydantic_core-2.23.4-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:d4488a93b071c04dc20f5cecc3631fc78b9789dd72483ba15d423b5b3689b555"},
+    {file = "pydantic_core-2.23.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:81965a16b675b35e1d09dd14df53f190f9129c0202356ed44ab2728b1c905658"},
+    {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ffa2ebd4c8530079140dd2d7f794a9d9a73cbb8e9d59ffe24c63436efa8f271"},
+    {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:61817945f2fe7d166e75fbfb28004034b48e44878177fc54d81688e7b85a3665"},
+    {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:29d2c342c4bc01b88402d60189f3df065fb0dda3654744d5a165a5288a657368"},
+    {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5e11661ce0fd30a6790e8bcdf263b9ec5988e95e63cf901972107efc49218b13"},
+    {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d18368b137c6295db49ce7218b1a9ba15c5bc254c96d7c9f9e924a9bc7825ad"},
+    {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ec4e55f79b1c4ffb2eecd8a0cfba9955a2588497d96851f4c8f99aa4a1d39b12"},
+    {file = "pydantic_core-2.23.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:374a5e5049eda9e0a44c696c7ade3ff355f06b1fe0bb945ea3cac2bc336478a2"},
+    {file = "pydantic_core-2.23.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5c364564d17da23db1106787675fc7af45f2f7b58b4173bfdd105564e132e6fb"},
+    {file = "pydantic_core-2.23.4-cp38-none-win32.whl", hash = "sha256:d7a80d21d613eec45e3d41eb22f8f94ddc758a6c4720842dc74c0581f54993d6"},
+    {file = "pydantic_core-2.23.4-cp38-none-win_amd64.whl", hash = "sha256:5f5ff8d839f4566a474a969508fe1c5e59c31c80d9e140566f9a37bba7b8d556"},
+    {file = "pydantic_core-2.23.4-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:a4fa4fc04dff799089689f4fd502ce7d59de529fc2f40a2c8836886c03e0175a"},
+    {file = "pydantic_core-2.23.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0a7df63886be5e270da67e0966cf4afbae86069501d35c8c1b3b6c168f42cb36"},
+    {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dcedcd19a557e182628afa1d553c3895a9f825b936415d0dbd3cd0bbcfd29b4b"},
+    {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5f54b118ce5de9ac21c363d9b3caa6c800341e8c47a508787e5868c6b79c9323"},
+    {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:86d2f57d3e1379a9525c5ab067b27dbb8a0642fb5d454e17a9ac434f9ce523e3"},
+    {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:de6d1d1b9e5101508cb37ab0d972357cac5235f5c6533d1071964c47139257df"},
+    {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1278e0d324f6908e872730c9102b0112477a7f7cf88b308e4fc36ce1bdb6d58c"},
+    {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9a6b5099eeec78827553827f4c6b8615978bb4b6a88e5d9b93eddf8bb6790f55"},
+    {file = "pydantic_core-2.23.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:e55541f756f9b3ee346b840103f32779c695a19826a4c442b7954550a0972040"},
+    {file = "pydantic_core-2.23.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a5c7ba8ffb6d6f8f2ab08743be203654bb1aaa8c9dcb09f82ddd34eadb695605"},
+    {file = "pydantic_core-2.23.4-cp39-none-win32.whl", hash = "sha256:37b0fe330e4a58d3c58b24d91d1eb102aeec675a3db4c292ec3928ecd892a9a6"},
+    {file = "pydantic_core-2.23.4-cp39-none-win_amd64.whl", hash = "sha256:1498bec4c05c9c787bde9125cfdcc63a41004ff167f495063191b863399b1a29"},
+    {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f455ee30a9d61d3e1a15abd5068827773d6e4dc513e795f380cdd59932c782d5"},
+    {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:1e90d2e3bd2c3863d48525d297cd143fe541be8bbf6f579504b9712cb6b643ec"},
+    {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e203fdf807ac7e12ab59ca2bfcabb38c7cf0b33c41efeb00f8e5da1d86af480"},
+    {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e08277a400de01bc72436a0ccd02bdf596631411f592ad985dcee21445bd0068"},
+    {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f220b0eea5965dec25480b6333c788fb72ce5f9129e8759ef876a1d805d00801"},
+    {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:d06b0c8da4f16d1d1e352134427cb194a0a6e19ad5db9161bf32b2113409e728"},
+    {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:ba1a0996f6c2773bd83e63f18914c1de3c9dd26d55f4ac302a7efe93fb8e7433"},
+    {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:9a5bce9d23aac8f0cf0836ecfc033896aa8443b501c58d0602dbfd5bd5b37753"},
+    {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:78ddaaa81421a29574a682b3179d4cf9e6d405a09b99d93ddcf7e5239c742e21"},
+    {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:883a91b5dd7d26492ff2f04f40fbb652de40fcc0afe07e8129e8ae779c2110eb"},
+    {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88ad334a15b32a791ea935af224b9de1bf99bcd62fabf745d5f3442199d86d59"},
+    {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:233710f069d251feb12a56da21e14cca67994eab08362207785cf8c598e74577"},
+    {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:19442362866a753485ba5e4be408964644dd6a09123d9416c54cd49171f50744"},
+    {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:624e278a7d29b6445e4e813af92af37820fafb6dcc55c012c834f9e26f9aaaef"},
+    {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f5ef8f42bec47f21d07668a043f077d507e5bf4e668d5c6dfe6aaba89de1a5b8"},
+    {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:aea443fffa9fbe3af1a9ba721a87f926fe548d32cab71d188a6ede77d0ff244e"},
+    {file = "pydantic_core-2.23.4.tar.gz", hash = "sha256:2584f7cf844ac4d970fba483a717dbe10c1c1c96a969bf65d61ffe94df1b2863"},
 ]
 
 [package.dependencies]
@@ -1362,13 +1354,13 @@ websockets = ">=11,<13"
 
 [[package]]
 name = "redis"
-version = "5.2.0"
+version = "5.2.1"
 description = "Python client for Redis database and key-value store"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "redis-5.2.0-py3-none-any.whl", hash = "sha256:ae174f2bb3b1bf2b09d54bf3e51fbc1469cf6c10aa03e21141f51969801a7897"},
-    {file = "redis-5.2.0.tar.gz", hash = "sha256:0b1087665a771b1ff2e003aa5bdd354f15a70c9e25d5a7dbf9c722c16528a7b0"},
+    {file = "redis-5.2.1-py3-none-any.whl", hash = "sha256:ee7e1056b9aea0f04c6c2ed59452947f34c4940ee025f5dd83e6a6418b6989e4"},
+    {file = "redis-5.2.1.tar.gz", hash = "sha256:16f2e22dff21d5125e8481515e386711a34cbec50f0e44413dd7d9c060a54e0f"},
 ]
 
 [package.dependencies]
@@ -1415,29 +1407,29 @@ pyasn1 = ">=0.1.3"
 
 [[package]]
 name = "ruff"
-version = "0.8.0"
+version = "0.8.2"
 description = "An extremely fast Python linter and code formatter, written in Rust."
 optional = false
 python-versions = ">=3.7"
 files = [
-    {file = "ruff-0.8.0-py3-none-linux_armv6l.whl", hash = "sha256:fcb1bf2cc6706adae9d79c8d86478677e3bbd4ced796ccad106fd4776d395fea"},
-    {file = "ruff-0.8.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:295bb4c02d58ff2ef4378a1870c20af30723013f441c9d1637a008baaf928c8b"},
-    {file = "ruff-0.8.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:7b1f1c76b47c18fa92ee78b60d2d20d7e866c55ee603e7d19c1e991fad933a9a"},
-    {file = "ruff-0.8.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eb0d4f250a7711b67ad513fde67e8870109e5ce590a801c3722580fe98c33a99"},
-    {file = "ruff-0.8.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0e55cce9aa93c5d0d4e3937e47b169035c7e91c8655b0974e61bb79cf398d49c"},
-    {file = "ruff-0.8.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3f4cd64916d8e732ce6b87f3f5296a8942d285bbbc161acee7fe561134af64f9"},
-    {file = "ruff-0.8.0-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:c5c1466be2a2ebdf7c5450dd5d980cc87c8ba6976fb82582fea18823da6fa362"},
-    {file = "ruff-0.8.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2dabfd05b96b7b8f2da00d53c514eea842bff83e41e1cceb08ae1966254a51df"},
-    {file = "ruff-0.8.0-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:facebdfe5a5af6b1588a1d26d170635ead6892d0e314477e80256ef4a8470cf3"},
-    {file = "ruff-0.8.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87a8e86bae0dbd749c815211ca11e3a7bd559b9710746c559ed63106d382bd9c"},
-    {file = "ruff-0.8.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:85e654f0ded7befe2d61eeaf3d3b1e4ef3894469cd664ffa85006c7720f1e4a2"},
-    {file = "ruff-0.8.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:83a55679c4cb449fa527b8497cadf54f076603cc36779b2170b24f704171ce70"},
-    {file = "ruff-0.8.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:812e2052121634cf13cd6fddf0c1871d0ead1aad40a1a258753c04c18bb71bbd"},
-    {file = "ruff-0.8.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:780d5d8523c04202184405e60c98d7595bdb498c3c6abba3b6d4cdf2ca2af426"},
-    {file = "ruff-0.8.0-py3-none-win32.whl", hash = "sha256:5fdb6efecc3eb60bba5819679466471fd7d13c53487df7248d6e27146e985468"},
-    {file = "ruff-0.8.0-py3-none-win_amd64.whl", hash = "sha256:582891c57b96228d146725975fbb942e1f30a0c4ba19722e692ca3eb25cc9b4f"},
-    {file = "ruff-0.8.0-py3-none-win_arm64.whl", hash = "sha256:ba93e6294e9a737cd726b74b09a6972e36bb511f9a102f1d9a7e1ce94dd206a6"},
-    {file = "ruff-0.8.0.tar.gz", hash = "sha256:a7ccfe6331bf8c8dad715753e157457faf7351c2b69f62f32c165c2dbcbacd44"},
+    {file = "ruff-0.8.2-py3-none-linux_armv6l.whl", hash = "sha256:c49ab4da37e7c457105aadfd2725e24305ff9bc908487a9bf8d548c6dad8bb3d"},
+    {file = "ruff-0.8.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:ec016beb69ac16be416c435828be702ee694c0d722505f9c1f35e1b9c0cc1bf5"},
+    {file = "ruff-0.8.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:f05cdf8d050b30e2ba55c9b09330b51f9f97d36d4673213679b965d25a785f3c"},
+    {file = "ruff-0.8.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:60f578c11feb1d3d257b2fb043ddb47501ab4816e7e221fbb0077f0d5d4e7b6f"},
+    {file = "ruff-0.8.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cbd5cf9b0ae8f30eebc7b360171bd50f59ab29d39f06a670b3e4501a36ba5897"},
+    {file = "ruff-0.8.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b402ddee3d777683de60ff76da801fa7e5e8a71038f57ee53e903afbcefdaa58"},
+    {file = "ruff-0.8.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:705832cd7d85605cb7858d8a13d75993c8f3ef1397b0831289109e953d833d29"},
+    {file = "ruff-0.8.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:32096b41aaf7a5cc095fa45b4167b890e4c8d3fd217603f3634c92a541de7248"},
+    {file = "ruff-0.8.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e769083da9439508833cfc7c23e351e1809e67f47c50248250ce1ac52c21fb93"},
+    {file = "ruff-0.8.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fe716592ae8a376c2673fdfc1f5c0c193a6d0411f90a496863c99cd9e2ae25d"},
+    {file = "ruff-0.8.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:81c148825277e737493242b44c5388a300584d73d5774defa9245aaef55448b0"},
+    {file = "ruff-0.8.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:d261d7850c8367704874847d95febc698a950bf061c9475d4a8b7689adc4f7fa"},
+    {file = "ruff-0.8.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:1ca4e3a87496dc07d2427b7dd7ffa88a1e597c28dad65ae6433ecb9f2e4f022f"},
+    {file = "ruff-0.8.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:729850feed82ef2440aa27946ab39c18cb4a8889c1128a6d589ffa028ddcfc22"},
+    {file = "ruff-0.8.2-py3-none-win32.whl", hash = "sha256:ac42caaa0411d6a7d9594363294416e0e48fc1279e1b0e948391695db2b3d5b1"},
+    {file = "ruff-0.8.2-py3-none-win_amd64.whl", hash = "sha256:2aae99ec70abf43372612a838d97bfe77d45146254568d94926e8ed5bbb409ea"},
+    {file = "ruff-0.8.2-py3-none-win_arm64.whl", hash = "sha256:fb88e2a506b70cfbc2de6fae6681c4f944f7dd5f2fe87233a7233d888bad73e8"},
+    {file = "ruff-0.8.2.tar.gz", hash = "sha256:b84f4f414dda8ac7f75075c1fa0b905ac0ff25361f42e6d5da681a465e0f78e5"},
 ]
 
 [[package]]
@@ -1852,4 +1844,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<4.0"
 content-hash = "54bf6e076ec4d09be2307f07240018459dd6594efdc55a2dc2dc1d673184587e"
```
|
||||
content-hash = "2d92dded4ebeff76a3d1dd0bb6b348e8177f053918141b4c4828442da318a907"
|
||||
|
||||
@@ -10,7 +10,7 @@ packages = [{ include = "autogpt_libs" }]
colorama = "^0.4.6"
expiringdict = "^1.2.2"
google-cloud-logging = "^3.11.3"
pydantic = "^2.10.2"
pydantic = "^2.9.2"
pydantic-settings = "^2.6.1"
pyjwt = "^2.10.0"
pytest-asyncio = "^0.24.0"
@@ -20,8 +20,8 @@ python-dotenv = "^1.0.1"
supabase = "^2.10.0"

[tool.poetry.group.dev.dependencies]
redis = "^5.2.0"
ruff = "^0.8.0"
redis = "^5.2.1"
ruff = "^0.8.2"

[build-system]
requires = ["poetry-core"]

@@ -6,18 +6,23 @@ ENV PYTHONUNBUFFERED 1

WORKDIR /app

RUN echo 'Acquire::http::Pipeline-Depth 0;\nAcquire::http::No-Cache true;\nAcquire::BrokenProxy true;\n' > /etc/apt/apt.conf.d/99fixbadproxy

RUN apt-get update --allow-releaseinfo-change --fix-missing

# Install build dependencies
RUN apt-get update \
    && apt-get install -y build-essential curl ffmpeg wget libcurl4-gnutls-dev libexpat1-dev libpq5 gettext libz-dev libssl-dev postgresql-client git \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

ENV POETRY_VERSION=1.8.3 \
    POETRY_HOME="/opt/poetry" \
    POETRY_NO_INTERACTION=1 \
    POETRY_VIRTUALENVS_CREATE=false \
    PATH="$POETRY_HOME/bin:$PATH"
RUN apt-get install -y build-essential
RUN apt-get install -y libpq5
RUN apt-get install -y libz-dev
RUN apt-get install -y libssl-dev
RUN apt-get install -y postgresql-client

ENV POETRY_VERSION=1.8.3
ENV POETRY_HOME=/opt/poetry
ENV POETRY_NO_INTERACTION=1
ENV POETRY_VIRTUALENVS_CREATE=false
ENV PATH=/opt/poetry/bin:$PATH

# Upgrade pip and setuptools to fix security vulnerabilities
RUN pip3 install --upgrade pip setuptools

@@ -39,11 +44,11 @@ FROM python:3.11.10-slim-bookworm AS server_dependencies

WORKDIR /app

ENV POETRY_VERSION=1.8.3 \
    POETRY_HOME="/opt/poetry" \
    POETRY_NO_INTERACTION=1 \
    POETRY_VIRTUALENVS_CREATE=false \
    PATH="$POETRY_HOME/bin:$PATH"
ENV POETRY_VERSION=1.8.3
ENV POETRY_HOME=/opt/poetry
ENV POETRY_NO_INTERACTION=1
ENV POETRY_VIRTUALENVS_CREATE=false
ENV PATH=/opt/poetry/bin:$PATH


# Upgrade pip and setuptools to fix security vulnerabilities

@@ -200,4 +200,4 @@ To add a new agent block, you need to create a new class that inherits from `Block`
* `run` method: the main logic of the block.
* `test_input` & `test_output`: the sample input and output data for the block, which will be used to auto-test the block.
* You can mock the functions declared in the block using the `test_mock` field for your unit tests.
* Once you finish creating the block, you can test it by running `pytest -s test/block/test_block.py`.
* Once you finish creating the block, you can test it by running `poetry run pytest -s test/block/test_block.py`.

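For reference, the pieces listed above fit together roughly as follows. This is a minimal sketch, not part of the diff; the block id, names, category, and echo logic are illustrative placeholders:

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class EchoBlock(Block):
    class Input(BlockSchema):
        text: str = SchemaField(description="Text to echo back")

    class Output(BlockSchema):
        output: str = SchemaField(description="The echoed text")

    def __init__(self):
        super().__init__(
            id="00000000-0000-0000-0000-000000000000",  # placeholder UUID
            description="Echoes its input back; a minimal example block.",
            categories={BlockCategory.BASIC},
            input_schema=EchoBlock.Input,
            output_schema=EchoBlock.Output,
            test_input={"text": "hello"},       # sample input used by the auto-test
            test_output=[("output", "hello")],  # expected (name, value) pairs
        )

    def run(self, input_data: Input, **kwargs) -> BlockOutput:
        # The main logic of the block: yield named outputs one at a time.
        yield "output", input_data.text
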
@@ -15,10 +15,10 @@ modules = [
    if f.is_file() and f.name != "__init__.py"
]
for module in modules:
    if not re.match("^[a-z_.]+$", module):
    if not re.match("^[a-z0-9_.]+$", module):
        raise ValueError(
            f"Block module {module} error: module name must be lowercase, "
            "separated by underscores, and contain only alphabet characters"
            "and contain only alphanumeric characters and underscores."
        )

    importlib.import_module(f".{module}", package=__name__)

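The effect of the widened pattern, shown on a couple of illustrative module names:

import re

# The old pattern rejected digits, so a module such as "text_v2" failed to load:
assert not re.match("^[a-z_.]+$", "text_v2")
# The new pattern accepts lowercase letters, digits, underscores, and dots:
assert re.match("^[a-z0-9_.]+$", "text_v2")
assert not re.match("^[a-z0-9_.]+$", "TextBlocks")  # uppercase is still rejected
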
@@ -12,6 +12,7 @@ from backend.data.model import (
    CredentialsMetaInput,
    SchemaField,
)
from backend.integrations.providers import ProviderName


class ImageSize(str, Enum):
@@ -101,12 +102,10 @@ class ImageGenModel(str, Enum):

class AIImageGeneratorBlock(Block):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput[Literal["replicate"], Literal["api_key"]] = (
            CredentialsField(
                provider="replicate",
                supported_credential_types={"api_key"},
                description="Enter your Replicate API key to access the image generation API. You can obtain an API key from https://replicate.com/account/api-tokens.",
            )
        credentials: CredentialsMetaInput[
            Literal[ProviderName.REPLICATE], Literal["api_key"]
        ] = CredentialsField(
            description="Enter your Replicate API key to access the image generation API. You can obtain an API key from https://replicate.com/account/api-tokens.",
        )
        prompt: str = SchemaField(
            description="Text prompt for image generation",

@@ -13,6 +13,7 @@ from backend.data.model import (
    CredentialsMetaInput,
    SchemaField,
)
from backend.integrations.providers import ProviderName

logger = logging.getLogger(__name__)

@@ -54,13 +55,11 @@ class NormalizationStrategy(str, Enum):

class AIMusicGeneratorBlock(Block):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput[Literal["replicate"], Literal["api_key"]] = (
            CredentialsField(
                provider="replicate",
                supported_credential_types={"api_key"},
                description="The Replicate integration can be used with "
                "any API key with sufficient permissions for the blocks it is used on.",
            )
        credentials: CredentialsMetaInput[
            Literal[ProviderName.REPLICATE], Literal["api_key"]
        ] = CredentialsField(
            description="The Replicate integration can be used with "
            "any API key with sufficient permissions for the blocks it is used on.",
        )
        prompt: str = SchemaField(
            description="A description of the music you want to generate",

@@ -12,6 +12,7 @@ from backend.data.model import (
    CredentialsMetaInput,
    SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.request import requests

TEST_CREDENTIALS = APIKeyCredentials(
@@ -140,13 +141,11 @@ logger = logging.getLogger(__name__)

class AIShortformVideoCreatorBlock(Block):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput[Literal["revid"], Literal["api_key"]] = (
            CredentialsField(
                provider="revid",
                supported_credential_types={"api_key"},
                description="The revid.ai integration can be used with "
                "any API key with sufficient permissions for the blocks it is used on.",
            )
        credentials: CredentialsMetaInput[
            Literal[ProviderName.REVID], Literal["api_key"]
        ] = CredentialsField(
            description="The revid.ai integration can be used with "
            "any API key with sufficient permissions for the blocks it is used on.",
        )
        script: str = SchemaField(
            description="""1. Use short and punctuated sentences\n\n2. Use linebreaks to create a new clip\n\n3. Text outside of brackets is spoken by the AI, and [text between brackets] will be used to guide the visual generation. For example, [close-up of a cat] will show a close-up of a cat.""",

@@ -1,13 +1,11 @@
import re
from typing import Any, List

from jinja2 import BaseLoader, Environment

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
from backend.data.model import SchemaField
from backend.util.mock import MockObject
from backend.util.text import TextFormatter

jinja = Environment(loader=BaseLoader())
formatter = TextFormatter()


class StoreValueBlock(Block):
@@ -304,9 +302,9 @@ class AgentOutputBlock(Block):
        """
        if input_data.format:
            try:
                fmt = re.sub(r"(?<!{){[ a-zA-Z0-9_]+}", r"{\g<0>}", input_data.format)
                template = jinja.from_string(fmt)
                yield "output", template.render({input_data.name: input_data.value})
                yield "output", formatter.format_string(
                    input_data.format, {input_data.name: input_data.value}
                )
            except Exception as e:
                yield "output", f"Error: {e}, {input_data.value}"
        else:
@@ -494,3 +492,101 @@ class NoteBlock(Block):

    def run(self, input_data: Input, **kwargs) -> BlockOutput:
        yield "output", input_data.text


class CreateDictionaryBlock(Block):
    class Input(BlockSchema):
        values: dict[str, Any] = SchemaField(
            description="Key-value pairs to create the dictionary with",
            placeholder="e.g., {'name': 'Alice', 'age': 25}",
        )

    class Output(BlockSchema):
        dictionary: dict[str, Any] = SchemaField(
            description="The created dictionary containing the specified key-value pairs"
        )
        error: str = SchemaField(
            description="Error message if dictionary creation failed"
        )

    def __init__(self):
        super().__init__(
            id="b924ddf4-de4f-4b56-9a85-358930dcbc91",
            description="Creates a dictionary with the specified key-value pairs. Use this when you know all the values you want to add upfront.",
            categories={BlockCategory.DATA},
            input_schema=CreateDictionaryBlock.Input,
            output_schema=CreateDictionaryBlock.Output,
            test_input=[
                {
                    "values": {"name": "Alice", "age": 25, "city": "New York"},
                },
                {
                    "values": {"numbers": [1, 2, 3], "active": True, "score": 95.5},
                },
            ],
            test_output=[
                (
                    "dictionary",
                    {"name": "Alice", "age": 25, "city": "New York"},
                ),
                (
                    "dictionary",
                    {"numbers": [1, 2, 3], "active": True, "score": 95.5},
                ),
            ],
        )

    def run(self, input_data: Input, **kwargs) -> BlockOutput:
        try:
            # The values are already validated by Pydantic schema
            yield "dictionary", input_data.values
        except Exception as e:
            yield "error", f"Failed to create dictionary: {str(e)}"


class CreateListBlock(Block):
    class Input(BlockSchema):
        values: List[Any] = SchemaField(
            description="A list of values to be combined into a new list.",
            placeholder="e.g., ['Alice', 25, True]",
        )

    class Output(BlockSchema):
        list: List[Any] = SchemaField(
            description="The created list containing the specified values."
        )
        error: str = SchemaField(description="Error message if list creation failed.")

    def __init__(self):
        super().__init__(
            id="a912d5c7-6e00-4542-b2a9-8034136930e4",
            description="Creates a list with the specified values. Use this when you know all the values you want to add upfront.",
            categories={BlockCategory.DATA},
            input_schema=CreateListBlock.Input,
            output_schema=CreateListBlock.Output,
            test_input=[
                {
                    "values": ["Alice", 25, True],
                },
                {
                    "values": [1, 2, 3, "four", {"key": "value"}],
                },
            ],
            test_output=[
                (
                    "list",
                    ["Alice", 25, True],
                ),
                (
                    "list",
                    [1, 2, 3, "four", {"key": "value"}],
                ),
            ],
        )

    def run(self, input_data: Input, **kwargs) -> BlockOutput:
        try:
            # The values are already validated by Pydantic schema
            yield "list", input_data.values
        except Exception as e:
            yield "error", f"Failed to create list: {str(e)}"

190
autogpt_platform/backend/backend/blocks/code_executor.py
Normal file
@@ -0,0 +1,190 @@
from enum import Enum
from typing import Literal

from e2b_code_interpreter import Sandbox
from pydantic import SecretStr

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import (
    APIKeyCredentials,
    CredentialsField,
    CredentialsMetaInput,
    SchemaField,
)
from backend.integrations.providers import ProviderName

TEST_CREDENTIALS = APIKeyCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="e2b",
    api_key=SecretStr("mock-e2b-api-key"),
    title="Mock E2B API key",
    expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
    "provider": TEST_CREDENTIALS.provider,
    "id": TEST_CREDENTIALS.id,
    "type": TEST_CREDENTIALS.type,
    "title": TEST_CREDENTIALS.title,
}


class ProgrammingLanguage(Enum):
    PYTHON = "python"
    JAVASCRIPT = "js"
    BASH = "bash"
    R = "r"
    JAVA = "java"


class CodeExecutionBlock(Block):
    # TODO: Add support to upload and download files
    # Currently, you can customize the CPU and memory only by creating a pre-customized sandbox template
    class Input(BlockSchema):
        credentials: CredentialsMetaInput[
            Literal[ProviderName.E2B], Literal["api_key"]
        ] = CredentialsField(
            description="Enter your API key for the E2B Sandbox. You can get one here: https://e2b.dev/docs",
        )

        # TODO: Option to run commands in the background
        setup_commands: list[str] = SchemaField(
            description=(
                "Shell commands to set up the sandbox before running the code. "
                "You can use `curl` or `git` to install your desired Debian based "
                "package manager. `pip` and `npm` are pre-installed.\n\n"
                "These commands are executed with `sh`, in the foreground."
            ),
            placeholder="pip install cowsay",
            default=[],
            advanced=False,
        )

        code: str = SchemaField(
            description="Code to execute in the sandbox",
            placeholder="print('Hello, World!')",
            default="",
            advanced=False,
        )

        language: ProgrammingLanguage = SchemaField(
            description="Programming language to execute",
            default=ProgrammingLanguage.PYTHON,
            advanced=False,
        )

        timeout: int = SchemaField(
            description="Execution timeout in seconds", default=300
        )

        template_id: str = SchemaField(
            description=(
                "You can use an E2B sandbox template by entering its ID here. "
                "Check out the E2B docs for more details: "
                "[E2B - Sandbox template](https://e2b.dev/docs/sandbox-template)"
            ),
            default="",
            advanced=True,
        )

    class Output(BlockSchema):
        response: str = SchemaField(description="Response from code execution")
        stdout_logs: str = SchemaField(
            description="Standard output logs from execution"
        )
        stderr_logs: str = SchemaField(description="Standard error logs from execution")
        error: str = SchemaField(description="Error message if execution failed")

    def __init__(self):
        super().__init__(
            id="0b02b072-abe7-11ef-8372-fb5d162dd712",
            description="Executes code in an isolated sandbox environment with internet access.",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=CodeExecutionBlock.Input,
            output_schema=CodeExecutionBlock.Output,
            test_credentials=TEST_CREDENTIALS,
            test_input={
                "credentials": TEST_CREDENTIALS_INPUT,
                "code": "print('Hello World')",
                "language": ProgrammingLanguage.PYTHON.value,
                "setup_commands": [],
                "timeout": 300,
                "template_id": "",
            },
            test_output=[
                ("response", "Hello World"),
                ("stdout_logs", "Hello World\n"),
            ],
            test_mock={
                "execute_code": lambda code, language, setup_commands, timeout, api_key, template_id: (
                    "Hello World",
                    "Hello World\n",
                    "",
                ),
            },
        )

    def execute_code(
        self,
        code: str,
        language: ProgrammingLanguage,
        setup_commands: list[str],
        timeout: int,
        api_key: str,
        template_id: str,
    ):
        try:
            sandbox = None
            if template_id:
                sandbox = Sandbox(
                    template=template_id, api_key=api_key, timeout=timeout
                )
            else:
                sandbox = Sandbox(api_key=api_key, timeout=timeout)

            if not sandbox:
                raise Exception("Sandbox not created")

            # Running setup commands
            for cmd in setup_commands:
                sandbox.commands.run(cmd)

            # Executing the code
            execution = sandbox.run_code(
                code,
                language=language.value,
                on_error=lambda e: sandbox.kill(),  # Kill the sandbox if there is an error
            )

            if execution.error:
                raise Exception(execution.error)

            response = execution.text
            stdout_logs = "".join(execution.logs.stdout)
            stderr_logs = "".join(execution.logs.stderr)

            return response, stdout_logs, stderr_logs

        except Exception as e:
            raise e

    def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        try:
            response, stdout_logs, stderr_logs = self.execute_code(
                input_data.code,
                input_data.language,
                input_data.setup_commands,
                input_data.timeout,
                credentials.api_key.get_secret_value(),
                input_data.template_id,
            )

            if response:
                yield "response", response
            if stdout_logs:
                yield "stdout_logs", stdout_logs
            if stderr_logs:
                yield "stderr_logs", stderr_logs
        except Exception as e:
            yield "error", str(e)
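Stripped of the block machinery, the sandbox flow in execute_code above reduces to plain e2b_code_interpreter calls. A rough sketch (the API key and commands are placeholders):

from e2b_code_interpreter import Sandbox

sandbox = Sandbox(api_key="your-e2b-api-key", timeout=300)  # placeholder key
sandbox.commands.run("pip install cowsay")                  # a setup command
execution = sandbox.run_code("print('Hello, World!')", language="python")
print(execution.text)                  # -> the block's "response" output
print("".join(execution.logs.stdout))  # -> the block's "stdout_logs" output
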
110
autogpt_platform/backend/backend/blocks/code_extraction_block.py
Normal file
@@ -0,0 +1,110 @@
import re

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class CodeExtractionBlock(Block):
    class Input(BlockSchema):
        text: str = SchemaField(
            description="Text containing code blocks to extract (e.g., AI response)",
            placeholder="Enter text containing code blocks",
        )

    class Output(BlockSchema):
        html: str = SchemaField(description="Extracted HTML code")
        css: str = SchemaField(description="Extracted CSS code")
        javascript: str = SchemaField(description="Extracted JavaScript code")
        python: str = SchemaField(description="Extracted Python code")
        sql: str = SchemaField(description="Extracted SQL code")
        java: str = SchemaField(description="Extracted Java code")
        cpp: str = SchemaField(description="Extracted C++ code")
        csharp: str = SchemaField(description="Extracted C# code")
        json_code: str = SchemaField(description="Extracted JSON code")
        bash: str = SchemaField(description="Extracted Bash code")
        php: str = SchemaField(description="Extracted PHP code")
        ruby: str = SchemaField(description="Extracted Ruby code")
        yaml: str = SchemaField(description="Extracted YAML code")
        markdown: str = SchemaField(description="Extracted Markdown code")
        typescript: str = SchemaField(description="Extracted TypeScript code")
        xml: str = SchemaField(description="Extracted XML code")
        remaining_text: str = SchemaField(
            description="Remaining text after code extraction"
        )

    def __init__(self):
        super().__init__(
            id="d3a7d896-3b78-4f44-8b4b-48fbf4f0bcd8",
            description="Extracts code blocks from text and identifies their programming languages",
            categories={BlockCategory.TEXT},
            input_schema=CodeExtractionBlock.Input,
            output_schema=CodeExtractionBlock.Output,
            test_input={
                "text": "Here's a Python example:\n```python\nprint('Hello World')\n```\nAnd some HTML:\n```html\n<h1>Title</h1>\n```"
            },
            test_output=[
                ("html", "<h1>Title</h1>"),
                ("python", "print('Hello World')"),
                ("remaining_text", "Here's a Python example:\nAnd some HTML:"),
            ],
        )

    def run(self, input_data: Input, **kwargs) -> BlockOutput:
        # List of supported programming languages with mapped aliases
        language_aliases = {
            "html": ["html", "htm"],
            "css": ["css"],
            "javascript": ["javascript", "js"],
            "python": ["python", "py"],
            "sql": ["sql"],
            "java": ["java"],
            "cpp": ["cpp", "c++"],
            "csharp": ["csharp", "c#", "cs"],
            "json_code": ["json"],
            "bash": ["bash", "shell", "sh"],
            "php": ["php"],
            "ruby": ["ruby", "rb"],
            "yaml": ["yaml", "yml"],
            "markdown": ["markdown", "md"],
            "typescript": ["typescript", "ts"],
            "xml": ["xml"],
        }

        # Extract code for each language
        for canonical_name, aliases in language_aliases.items():
            code = ""
            # Try each alias for the language
            for alias in aliases:
                code_for_alias = self.extract_code(input_data.text, alias)
                if code_for_alias:
                    code = code + "\n\n" + code_for_alias if code else code_for_alias

            if code:  # Only yield if there's actual code content
                yield canonical_name, code

        # Remove all code blocks from the text to get remaining text
        pattern = (
            r"```(?:"
            + "|".join(
                re.escape(alias)
                for aliases in language_aliases.values()
                for alias in aliases
            )
            + r")\s+[\s\S]*?```"
        )

        remaining_text = re.sub(pattern, "", input_data.text).strip()
        remaining_text = re.sub(r"\n\s*\n", "\n", remaining_text)

        if remaining_text:  # Only yield if there's remaining text
            yield "remaining_text", remaining_text

    def extract_code(self, text: str, language: str) -> str:
        # Escape special regex characters in the language string
        language = re.escape(language)
        # Extract all code blocks enclosed in ```language``` blocks
        pattern = re.compile(rf"```{language}\s+(.*?)```", re.DOTALL | re.IGNORECASE)
        matches = pattern.finditer(text)
        # Combine all code blocks for this language with newlines between them
        code_blocks = [match.group(1).strip() for match in matches]
        return "\n\n".join(code_blocks) if code_blocks else ""
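Fed part of its own test input, extract_code behaves like this (a usage sketch, not part of the diff):

block = CodeExtractionBlock()
text = "Here's a Python example:\n```python\nprint('Hello World')\n```"
print(block.extract_code(text, "python"))  # -> print('Hello World')
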
@@ -12,16 +12,15 @@ from backend.data.model import (
    CredentialsMetaInput,
    SchemaField,
)
from backend.integrations.providers import ProviderName

DiscordCredentials = CredentialsMetaInput[Literal["discord"], Literal["api_key"]]
DiscordCredentials = CredentialsMetaInput[
    Literal[ProviderName.DISCORD], Literal["api_key"]
]


def DiscordCredentialsField() -> DiscordCredentials:
    return CredentialsField(
        description="Discord bot token",
        provider="discord",
        supported_credential_types={"api_key"},
    )
    return CredentialsField(description="Discord bot token")


TEST_CREDENTIALS = APIKeyCredentials(

32
autogpt_platform/backend/backend/blocks/exa/_auth.py
Normal file
@@ -0,0 +1,32 @@
from typing import Literal

from pydantic import SecretStr

from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName

ExaCredentials = APIKeyCredentials
ExaCredentialsInput = CredentialsMetaInput[
    Literal[ProviderName.EXA],
    Literal["api_key"],
]

TEST_CREDENTIALS = APIKeyCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="exa",
    api_key=SecretStr("mock-exa-api-key"),
    title="Mock Exa API key",
    expires_at=None,
)

TEST_CREDENTIALS_INPUT = {
    "provider": TEST_CREDENTIALS.provider,
    "id": TEST_CREDENTIALS.id,
    "type": TEST_CREDENTIALS.type,
    "title": TEST_CREDENTIALS.title,
}


def ExaCredentialsField() -> ExaCredentialsInput:
    """Creates an Exa credentials input on a block."""
    return CredentialsField(description="The Exa integration requires an API Key.")
157
autogpt_platform/backend/backend/blocks/exa/search.py
Normal file
@@ -0,0 +1,157 @@
from datetime import datetime
from typing import List

from pydantic import BaseModel

from backend.blocks.exa._auth import (
    ExaCredentials,
    ExaCredentialsField,
    ExaCredentialsInput,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import requests


class ContentSettings(BaseModel):
    text: dict = SchemaField(
        description="Text content settings",
        default={"maxCharacters": 1000, "includeHtmlTags": False},
    )
    highlights: dict = SchemaField(
        description="Highlight settings",
        default={"numSentences": 3, "highlightsPerUrl": 3},
    )
    summary: dict = SchemaField(
        description="Summary settings",
        default={"query": ""},
    )


class ExaSearchBlock(Block):
    class Input(BlockSchema):
        credentials: ExaCredentialsInput = ExaCredentialsField()
        query: str = SchemaField(description="The search query")
        useAutoprompt: bool = SchemaField(
            description="Whether to use autoprompt",
            default=True,
        )
        type: str = SchemaField(
            description="Type of search",
            default="",
        )
        category: str = SchemaField(
            description="Category to search within",
            default="",
        )
        numResults: int = SchemaField(
            description="Number of results to return",
            default=10,
        )
        includeDomains: List[str] = SchemaField(
            description="Domains to include in search",
            default=[],
        )
        excludeDomains: List[str] = SchemaField(
            description="Domains to exclude from search",
            default=[],
        )
        startCrawlDate: datetime = SchemaField(
            description="Start date for crawled content",
        )
        endCrawlDate: datetime = SchemaField(
            description="End date for crawled content",
        )
        startPublishedDate: datetime = SchemaField(
            description="Start date for published content",
        )
        endPublishedDate: datetime = SchemaField(
            description="End date for published content",
        )
        includeText: List[str] = SchemaField(
            description="Text patterns to include",
            default=[],
        )
        excludeText: List[str] = SchemaField(
            description="Text patterns to exclude",
            default=[],
        )
        contents: ContentSettings = SchemaField(
            description="Content retrieval settings",
            default=ContentSettings(),
        )

    class Output(BlockSchema):
        results: list = SchemaField(
            description="List of search results",
            default=[],
        )

    def __init__(self):
        super().__init__(
            id="996cec64-ac40-4dde-982f-b0dc60a5824d",
            description="Searches the web using Exa's advanced search API",
            categories={BlockCategory.SEARCH},
            input_schema=ExaSearchBlock.Input,
            output_schema=ExaSearchBlock.Output,
        )

    def run(
        self, input_data: Input, *, credentials: ExaCredentials, **kwargs
    ) -> BlockOutput:
        url = "https://api.exa.ai/search"
        headers = {
            "Content-Type": "application/json",
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        payload = {
            "query": input_data.query,
            "useAutoprompt": input_data.useAutoprompt,
            "numResults": input_data.numResults,
            "contents": {
                "text": {"maxCharacters": 1000, "includeHtmlTags": False},
                "highlights": {
                    "numSentences": 3,
                    "highlightsPerUrl": 3,
                },
                "summary": {"query": ""},
            },
        }

        # Add dates if they exist
        date_fields = [
            "startCrawlDate",
            "endCrawlDate",
            "startPublishedDate",
            "endPublishedDate",
        ]
        for field in date_fields:
            value = getattr(input_data, field, None)
            if value:
                payload[field] = value.strftime("%Y-%m-%dT%H:%M:%S.000Z")

        # Add other fields
        optional_fields = [
            "type",
            "category",
            "includeDomains",
            "excludeDomains",
            "includeText",
            "excludeText",
        ]

        for field in optional_fields:
            value = getattr(input_data, field)
            if value:  # Only add non-empty values
                payload[field] = value

        try:
            response = requests.post(url, headers=headers, json=payload)
            response.raise_for_status()
            data = response.json()
            # Extract just the results array from the response
            yield "results", data.get("results", [])
        except Exception as e:
            yield "error", str(e)
            yield "results", []
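Note how the date filters are serialized above: each datetime is formatted into the ISO-like string the Exa API expects. For example (a sketch with an arbitrary date):

from datetime import datetime

datetime(2024, 1, 15).strftime("%Y-%m-%dT%H:%M:%S.000Z")
# -> '2024-01-15T00:00:00.000Z'
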
@@ -3,10 +3,11 @@ from typing import Literal
from pydantic import SecretStr

from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName

FalCredentials = APIKeyCredentials
FalCredentialsInput = CredentialsMetaInput[
    Literal["fal"],
    Literal[ProviderName.FAL],
    Literal["api_key"],
]

@@ -30,7 +31,5 @@ def FalCredentialsField() -> FalCredentialsInput:
    Creates a FAL credentials input on a block.
    """
    return CredentialsField(
        provider="fal",
        supported_credential_types={"api_key"},
        description="The FAL integration can be used with an API Key.",
    )

@@ -1,7 +1,7 @@
import logging
import time
from enum import Enum
from typing import Any, Dict
from typing import Any

import httpx

@@ -64,7 +64,7 @@ class AIVideoGeneratorBlock(Block):
        },
    )

    def _get_headers(self, api_key: str) -> Dict[str, str]:
    def _get_headers(self, api_key: str) -> dict[str, str]:
        """Get headers for FAL API requests."""
        return {
            "Authorization": f"Key {api_key}",
@@ -72,8 +72,8 @@ class AIVideoGeneratorBlock(Block):
        }

    def _submit_request(
        self, url: str, headers: Dict[str, str], data: Dict[str, Any]
    ) -> Dict[str, Any]:
        self, url: str, headers: dict[str, str], data: dict[str, Any]
    ) -> dict[str, Any]:
        """Submit a request to the FAL API."""
        try:
            response = httpx.post(url, headers=headers, json=data)
@@ -83,7 +83,7 @@ class AIVideoGeneratorBlock(Block):
            logger.error(f"FAL API request failed: {str(e)}")
            raise RuntimeError(f"Failed to submit request: {str(e)}")

    def _poll_status(self, status_url: str, headers: Dict[str, str]) -> Dict[str, Any]:
    def _poll_status(self, status_url: str, headers: dict[str, str]) -> dict[str, Any]:
        """Poll the status endpoint until completion or failure."""
        try:
            response = httpx.get(status_url, headers=headers)

|
||||
CredentialsMetaInput,
|
||||
OAuth2Credentials,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.settings import Secrets
|
||||
|
||||
secrets = Secrets()
|
||||
@@ -17,7 +18,7 @@ GITHUB_OAUTH_IS_CONFIGURED = bool(
|
||||
|
||||
GithubCredentials = APIKeyCredentials | OAuth2Credentials
|
||||
GithubCredentialsInput = CredentialsMetaInput[
|
||||
Literal["github"],
|
||||
Literal[ProviderName.GITHUB],
|
||||
Literal["api_key", "oauth2"] if GITHUB_OAUTH_IS_CONFIGURED else Literal["api_key"],
|
||||
]
|
||||
|
||||
@@ -30,10 +31,6 @@ def GithubCredentialsField(scope: str) -> GithubCredentialsInput:
|
||||
scope: The authorization scope needed for the block to work. ([list of available scopes](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/scopes-for-oauth-apps#available-scopes))
|
||||
""" # noqa
|
||||
return CredentialsField(
|
||||
provider="github",
|
||||
supported_credential_types=(
|
||||
{"api_key", "oauth2"} if GITHUB_OAUTH_IS_CONFIGURED else {"api_key"}
|
||||
),
|
||||
required_scopes={scope},
|
||||
description="The GitHub integration can be used with OAuth, "
|
||||
"or any API key with sufficient permissions for the blocks it is used on.",
|
||||
|
||||
@@ -1,3 +1,5 @@
import re

from typing_extensions import TypedDict

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
@@ -253,7 +255,7 @@ class GithubReadPullRequestBlock(Block):
    @staticmethod
    def read_pr_changes(credentials: GithubCredentials, pr_url: str) -> str:
        api = get_api(credentials)
        files_url = pr_url + "/files"
        files_url = prepare_pr_api_url(pr_url=pr_url, path="files")
        response = api.get(files_url)
        files = response.json()
        changes = []
@@ -331,7 +333,7 @@ class GithubAssignPRReviewerBlock(Block):
        credentials: GithubCredentials, pr_url: str, reviewer: str
    ) -> str:
        api = get_api(credentials)
        reviewers_url = pr_url + "/requested_reviewers"
        reviewers_url = prepare_pr_api_url(pr_url=pr_url, path="requested_reviewers")
        data = {"reviewers": [reviewer]}
        api.post(reviewers_url, json=data)
        return "Reviewer assigned successfully"
@@ -398,7 +400,7 @@ class GithubUnassignPRReviewerBlock(Block):
        credentials: GithubCredentials, pr_url: str, reviewer: str
    ) -> str:
        api = get_api(credentials)
        reviewers_url = pr_url + "/requested_reviewers"
        reviewers_url = prepare_pr_api_url(pr_url=pr_url, path="requested_reviewers")
        data = {"reviewers": [reviewer]}
        api.delete(reviewers_url, json=data)
        return "Reviewer unassigned successfully"
@@ -478,7 +480,7 @@ class GithubListPRReviewersBlock(Block):
        credentials: GithubCredentials, pr_url: str
    ) -> list[Output.ReviewerItem]:
        api = get_api(credentials)
        reviewers_url = pr_url + "/requested_reviewers"
        reviewers_url = prepare_pr_api_url(pr_url=pr_url, path="requested_reviewers")
        response = api.get(reviewers_url)
        data = response.json()
        reviewers: list[GithubListPRReviewersBlock.Output.ReviewerItem] = [
@@ -499,3 +501,14 @@ class GithubListPRReviewersBlock(Block):
            input_data.pr_url,
        )
        yield from (("reviewer", reviewer) for reviewer in reviewers)


def prepare_pr_api_url(pr_url: str, path: str) -> str:
    # Pattern to capture the base repository URL and the pull request number
    pattern = r"^(?:https?://)?([^/]+/[^/]+/[^/]+)/pull/(\d+)"
    match = re.match(pattern, pr_url)
    if not match:
        return pr_url

    base_url, pr_number = match.groups()
    return f"{base_url}/pulls/{pr_number}/{path}"

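On a concrete (illustrative) PR URL, the rewrite looks like this; note that because the scheme group is non-capturing, the https:// prefix is dropped from the result:

prepare_pr_api_url(pr_url="https://github.com/OWNER/REPO/pull/123", path="files")
# -> "github.com/OWNER/REPO/pulls/123/files"
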
@@ -111,7 +111,9 @@ class GithubPullRequestTriggerBlock(GitHubTriggerBase, Block):
    def __init__(self):
        from backend.integrations.webhooks.github import GithubWebhookType

        example_payload = json.loads(self.EXAMPLE_PAYLOAD_FILE.read_text())
        example_payload = json.loads(
            self.EXAMPLE_PAYLOAD_FILE.read_text(encoding="utf-8")
        )

        super().__init__(
            id="6c60ec01-8128-419e-988f-96a063ee2fea",

@@ -3,6 +3,7 @@ from typing import Literal
from pydantic import SecretStr

from backend.data.model import CredentialsField, CredentialsMetaInput, OAuth2Credentials
from backend.integrations.providers import ProviderName
from backend.util.settings import Secrets

# --8<-- [start:GoogleOAuthIsConfigured]
@@ -12,7 +13,9 @@ GOOGLE_OAUTH_IS_CONFIGURED = bool(
)
# --8<-- [end:GoogleOAuthIsConfigured]
GoogleCredentials = OAuth2Credentials
GoogleCredentialsInput = CredentialsMetaInput[Literal["google"], Literal["oauth2"]]
GoogleCredentialsInput = CredentialsMetaInput[
    Literal[ProviderName.GOOGLE], Literal["oauth2"]
]


def GoogleCredentialsField(scopes: list[str]) -> GoogleCredentialsInput:
@@ -23,8 +26,6 @@ def GoogleCredentialsField(scopes: list[str]) -> GoogleCredentialsInput:
    scopes: The authorization scopes needed for the block to work.
    """
    return CredentialsField(
        provider="google",
        supported_credential_types={"oauth2"},
        required_scopes=set(scopes),
        description="The Google integration requires OAuth2 authentication.",
    )

@@ -10,6 +10,7 @@ from backend.data.model import (
    CredentialsMetaInput,
    SchemaField,
)
from backend.integrations.providers import ProviderName

TEST_CREDENTIALS = APIKeyCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
@@ -38,12 +39,8 @@ class Place(BaseModel):
class GoogleMapsSearchBlock(Block):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput[
            Literal["google_maps"], Literal["api_key"]
        ] = CredentialsField(
            provider="google_maps",
            supported_credential_types={"api_key"},
            description="Google Maps API Key",
        )
            Literal[ProviderName.GOOGLE_MAPS], Literal["api_key"]
        ] = CredentialsField(description="Google Maps API Key")
        query: str = SchemaField(
            description="Search query for local businesses",
            placeholder="e.g., 'restaurants in New York'",

@@ -3,10 +3,11 @@ from typing import Literal
from pydantic import SecretStr

from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName

HubSpotCredentials = APIKeyCredentials
HubSpotCredentialsInput = CredentialsMetaInput[
    Literal["hubspot"],
    Literal[ProviderName.HUBSPOT],
    Literal["api_key"],
]

@@ -14,8 +15,6 @@ HubSpotCredentialsInput = CredentialsMetaInput[
def HubSpotCredentialsField() -> HubSpotCredentialsInput:
    """Creates a HubSpot credentials input on a block."""
    return CredentialsField(
        provider="hubspot",
        supported_credential_types={"api_key"},
        description="The HubSpot integration requires an API Key.",
    )


@@ -11,6 +11,7 @@ from backend.data.model import (
    CredentialsMetaInput,
    SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.request import requests

TEST_CREDENTIALS = APIKeyCredentials(
@@ -83,13 +84,10 @@ class UpscaleOption(str, Enum):

class IdeogramModelBlock(Block):
    class Input(BlockSchema):

        credentials: CredentialsMetaInput[Literal["ideogram"], Literal["api_key"]] = (
            CredentialsField(
                provider="ideogram",
                supported_credential_types={"api_key"},
                description="The Ideogram integration can be used with any API key with sufficient permissions for the blocks it is used on.",
            )
        credentials: CredentialsMetaInput[
            Literal[ProviderName.IDEOGRAM], Literal["api_key"]
        ] = CredentialsField(
            description="The Ideogram integration can be used with any API key with sufficient permissions for the blocks it is used on.",
        )
        prompt: str = SchemaField(
            description="Text prompt for image generation",

@@ -3,27 +3,14 @@ from typing import Literal
from pydantic import SecretStr

from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName

JinaCredentials = APIKeyCredentials
JinaCredentialsInput = CredentialsMetaInput[
    Literal["jina"],
    Literal[ProviderName.JINA],
    Literal["api_key"],
]

TEST_CREDENTIALS = APIKeyCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="jina",
    api_key=SecretStr("mock-jina-api-key"),
    title="Mock Jina API key",
    expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
    "provider": TEST_CREDENTIALS.provider,
    "id": TEST_CREDENTIALS.id,
    "type": TEST_CREDENTIALS.type,
    "title": TEST_CREDENTIALS.type,
}


def JinaCredentialsField() -> JinaCredentialsInput:
    """
@@ -31,8 +18,6 @@ def JinaCredentialsField() -> JinaCredentialsInput:

    """
    return CredentialsField(
        provider="jina",
        supported_credential_types={"api_key"},
        description="The Jina integration can be used with an API Key.",
    )


59
autogpt_platform/backend/backend/blocks/jina/fact_checker.py
Normal file
@@ -0,0 +1,59 @@
from urllib.parse import quote

import requests

from backend.blocks.jina._auth import (
    JinaCredentials,
    JinaCredentialsField,
    JinaCredentialsInput,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class FactCheckerBlock(Block):
    class Input(BlockSchema):
        statement: str = SchemaField(
            description="The statement to check for factuality"
        )
        credentials: JinaCredentialsInput = JinaCredentialsField()

    class Output(BlockSchema):
        factuality: float = SchemaField(
            description="The factuality score of the statement"
        )
        result: bool = SchemaField(description="The result of the factuality check")
        reason: str = SchemaField(description="The reason for the factuality result")
        error: str = SchemaField(description="Error message if the check fails")

    def __init__(self):
        super().__init__(
            id="d38b6c5e-9968-4271-8423-6cfe60d6e7e6",
            description="This block checks the factuality of a given statement using Jina AI's Grounding API.",
            categories={BlockCategory.SEARCH},
            input_schema=FactCheckerBlock.Input,
            output_schema=FactCheckerBlock.Output,
        )

    def run(
        self, input_data: Input, *, credentials: JinaCredentials, **kwargs
    ) -> BlockOutput:
        encoded_statement = quote(input_data.statement)
        url = f"https://g.jina.ai/{encoded_statement}"

        headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {credentials.api_key.get_secret_value()}",
        }

        response = requests.get(url, headers=headers)
        response.raise_for_status()
        data = response.json()

        if "data" in data:
            data = data["data"]
            yield "factuality", data["factuality"]
            yield "result", data["result"]
            yield "reason", data["reason"]
        else:
            raise RuntimeError(f"Expected 'data' key not found in response: {data}")
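Outside the block, the same Grounding API call is a single GET (a sketch; the statement and token are placeholders):

from urllib.parse import quote
import requests

statement = "The Eiffel Tower is in Paris"
response = requests.get(
    f"https://g.jina.ai/{quote(statement)}",
    headers={
        "Accept": "application/json",
        "Authorization": "Bearer <your-jina-api-key>",  # placeholder token
    },
)
print(response.json()["data"]["factuality"])
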
@@ -7,6 +7,8 @@ from typing import TYPE_CHECKING, Any, List, Literal, NamedTuple

from pydantic import SecretStr

from backend.integrations.providers import ProviderName

if TYPE_CHECKING:
    from enum import _EnumMemberT

@@ -27,7 +29,13 @@ from backend.util.settings import BehaveAs, Settings

logger = logging.getLogger(__name__)

LLMProviderName = Literal["anthropic", "groq", "openai", "ollama", "open_router"]
LLMProviderName = Literal[
    ProviderName.ANTHROPIC,
    ProviderName.GROQ,
    ProviderName.OLLAMA,
    ProviderName.OPENAI,
    ProviderName.OPEN_ROUTER,
]
AICredentials = CredentialsMetaInput[LLMProviderName, Literal["api_key"]]

TEST_CREDENTIALS = APIKeyCredentials(
@@ -48,8 +56,6 @@ TEST_CREDENTIALS_INPUT = {
def AICredentialsField() -> AICredentials:
    return CredentialsField(
        description="API key for the LLM provider.",
        provider=["anthropic", "groq", "openai", "ollama", "open_router"],
        supported_credential_types={"api_key"},
        discriminator="model",
        discriminator_mapping={
            model.value: model.metadata.provider for model in LlmModel
@@ -105,9 +111,9 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
    # Ollama models
    OLLAMA_LLAMA3_8B = "llama3"
    OLLAMA_LLAMA3_405B = "llama3.1:405b"
    OLLAMA_DOLPHIN = "dolphin-mistral:latest"
    # OpenRouter models
    GEMINI_FLASH_1_5_8B = "google/gemini-flash-1.5"
    GEMINI_FLASH_1_5_EXP = "google/gemini-flash-1.5-exp"
    GROK_BETA = "x-ai/grok-beta"
    MISTRAL_NEMO = "mistralai/mistral-nemo"
    COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
@@ -117,6 +123,14 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
    PERPLEXITY_LLAMA_3_1_SONAR_LARGE_128K_ONLINE = (
        "perplexity/llama-3.1-sonar-large-128k-online"
    )
    QWEN_QWQ_32B_PREVIEW = "qwen/qwq-32b-preview"
    NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
    NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
    AMAZON_NOVA_LITE_V1 = "amazon/nova-lite-v1"
    AMAZON_NOVA_MICRO_V1 = "amazon/nova-micro-v1"
    AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
    MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
    GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"

    @property
    def metadata(self) -> ModelMetadata:
@@ -151,8 +165,8 @@ MODEL_METADATA = {
    LlmModel.LLAMA3_1_8B: ModelMetadata("groq", 131072),
    LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata("ollama", 8192),
    LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata("ollama", 8192),
    LlmModel.OLLAMA_DOLPHIN: ModelMetadata("ollama", 32768),
    LlmModel.GEMINI_FLASH_1_5_8B: ModelMetadata("open_router", 8192),
    LlmModel.GEMINI_FLASH_1_5_EXP: ModelMetadata("open_router", 8192),
    LlmModel.GROK_BETA: ModelMetadata("open_router", 8192),
    LlmModel.MISTRAL_NEMO: ModelMetadata("open_router", 4000),
    LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata("open_router", 4000),
@@ -162,6 +176,14 @@ MODEL_METADATA = {
    LlmModel.PERPLEXITY_LLAMA_3_1_SONAR_LARGE_128K_ONLINE: ModelMetadata(
        "open_router", 8192
    ),
    LlmModel.QWEN_QWQ_32B_PREVIEW: ModelMetadata("open_router", 4000),
    LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: ModelMetadata("open_router", 4000),
    LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: ModelMetadata("open_router", 4000),
    LlmModel.AMAZON_NOVA_LITE_V1: ModelMetadata("open_router", 4000),
    LlmModel.AMAZON_NOVA_MICRO_V1: ModelMetadata("open_router", 4000),
    LlmModel.AMAZON_NOVA_PRO_V1: ModelMetadata("open_router", 4000),
    LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata("open_router", 4000),
    LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata("open_router", 4000),
}

for model in LlmModel:
@@ -220,6 +242,12 @@ class AIStructuredResponseGeneratorBlock(Block):
            description="The maximum number of tokens to generate in the chat completion.",
        )

        ollama_host: str = SchemaField(
            advanced=True,
            default="localhost:11434",
            description="Ollama host for local models",
        )

    class Output(BlockSchema):
        response: dict[str, Any] = SchemaField(
            description="The response object generated by the language model."
@@ -265,6 +293,7 @@ class AIStructuredResponseGeneratorBlock(Block):
        prompt: list[dict],
        json_format: bool,
        max_tokens: int | None = None,
        ollama_host: str = "localhost:11434",
    ) -> tuple[str, int, int]:
        """
        Args:
@@ -273,6 +302,7 @@ class AIStructuredResponseGeneratorBlock(Block):
            prompt: The prompt to send to the LLM.
            json_format: Whether the response should be in JSON format.
            max_tokens: The maximum number of tokens to generate in the chat completion.
            ollama_host: The host for ollama to use

        Returns:
            The response from the LLM.
@@ -362,9 +392,10 @@ class AIStructuredResponseGeneratorBlock(Block):
                response.usage.completion_tokens if response.usage else 0,
            )
        elif provider == "ollama":
            client = ollama.Client(host=ollama_host)
            sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
            usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
            response = ollama.generate(
            response = client.generate(
                model=llm_model.value,
                prompt=f"{sys_messages}\n\n{usr_messages}",
                stream=False,
@@ -464,6 +495,7 @@ class AIStructuredResponseGeneratorBlock(Block):
                llm_model=llm_model,
                prompt=prompt,
                json_format=bool(input_data.expected_format),
                ollama_host=input_data.ollama_host,
                max_tokens=input_data.max_tokens,
            )
            self.merge_stats(
@@ -546,6 +578,11 @@ class AITextGeneratorBlock(Block):
        prompt_values: dict[str, str] = SchemaField(
            advanced=False, default={}, description="Values used to fill in the prompt."
        )
        ollama_host: str = SchemaField(
            advanced=True,
            default="localhost:11434",
            description="Ollama host for local models",
        )
        max_tokens: int | None = SchemaField(
            advanced=True,
            default=None,
@@ -636,6 +673,11 @@ class AITextSummarizerBlock(Block):
            description="The number of overlapping tokens between chunks to maintain context.",
            ge=0,
        )
        ollama_host: str = SchemaField(
            advanced=True,
            default="localhost:11434",
            description="Ollama host for local models",
        )

    class Output(BlockSchema):
        summary: str = SchemaField(description="The final summary of the text.")
@@ -774,6 +816,11 @@ class AIConversationBlock(Block):
            default=None,
            description="The maximum number of tokens to generate in the chat completion.",
        )
        ollama_host: str = SchemaField(
            advanced=True,
            default="localhost:11434",
            description="Ollama host for local models",
        )

    class Output(BlockSchema):
        response: str = SchemaField(
@@ -871,6 +918,11 @@ class AIListGeneratorBlock(Block):
            default=None,
            description="The maximum number of tokens to generate in the chat completion.",
        )
        ollama_host: str = SchemaField(
            advanced=True,
            default="localhost:11434",
            description="Ollama host for local models",
        )

    class Output(BlockSchema):
        generated_list: List[str] = SchemaField(description="The generated list.")
@@ -1022,6 +1074,7 @@ class AIListGeneratorBlock(Block):
            credentials=input_data.credentials,
            model=input_data.model,
            expected_format={},  # Do not use structured response
            ollama_host=input_data.ollama_host,
        ),
        credentials=credentials,
    )

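The Ollama change above swaps the module-level ollama.generate call for a client bound to the configurable host. In isolation, the new call shape is roughly (a sketch, assuming an Ollama server is reachable at the default host and the llama3 model is pulled):

import ollama

client = ollama.Client(host="localhost:11434")  # per-host client, as in the diff
result = client.generate(model="llama3", prompt="Hello", stream=False)
print(result["response"])
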
@@ -12,6 +12,7 @@ from backend.data.model import (
    SchemaField,
    SecretField,
)
from backend.integrations.providers import ProviderName
from backend.util.request import requests

TEST_CREDENTIALS = APIKeyCredentials(
@@ -77,12 +78,10 @@ class PublishToMediumBlock(Block):
            description="Whether to notify followers that the user has published",
            placeholder="False",
        )
        credentials: CredentialsMetaInput[Literal["medium"], Literal["api_key"]] = (
            CredentialsField(
                provider="medium",
                supported_credential_types={"api_key"},
                description="The Medium integration can be used with any API key with sufficient permissions for the blocks it is used on.",
            )
        credentials: CredentialsMetaInput[
            Literal[ProviderName.MEDIUM], Literal["api_key"]
        ] = CredentialsField(
            description="The Medium integration can be used with any API key with sufficient permissions for the blocks it is used on.",
        )

    class Output(BlockSchema):

@@ -10,22 +10,18 @@ from backend.data.model import (
    CredentialsMetaInput,
    SchemaField,
)
from backend.integrations.providers import ProviderName

PineconeCredentials = APIKeyCredentials
PineconeCredentialsInput = CredentialsMetaInput[
    Literal["pinecone"],
    Literal[ProviderName.PINECONE],
    Literal["api_key"],
]


def PineconeCredentialsField() -> PineconeCredentialsInput:
    """
    Creates a Pinecone credentials input on a block.

    """
    """Creates a Pinecone credentials input on a block."""
    return CredentialsField(
        provider="pinecone",
        supported_credential_types={"api_key"},
        description="The Pinecone integration can be used with an API Key.",
    )

@@ -147,7 +143,7 @@ class PineconeQueryBlock(Block):
            top_k=input_data.top_k,
            include_values=input_data.include_values,
            include_metadata=input_data.include_metadata,
        ).to_dict()
        ).to_dict()  # type: ignore
        combined_text = ""
        if results["matches"]:
            texts = [

@@ -13,6 +13,7 @@ from backend.data.model import (
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
@@ -54,13 +55,11 @@ class ImageType(str, Enum):
|
||||
|
||||
class ReplicateFluxAdvancedModelBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[Literal["replicate"], Literal["api_key"]] = (
|
||||
CredentialsField(
|
||||
provider="replicate",
|
||||
supported_credential_types={"api_key"},
|
||||
description="The Replicate integration can be used with "
|
||||
"any API key with sufficient permissions for the blocks it is used on.",
|
||||
)
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REPLICATE], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
description="The Replicate integration can be used with "
|
||||
"any API key with sufficient permissions for the blocks it is used on.",
|
||||
)
|
||||
prompt: str = SchemaField(
|
||||
description="Text prompt for image generation",
|
||||
|
||||
@@ -11,6 +11,7 @@ from backend.data.model import (
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
|
||||
class GetWikipediaSummaryBlock(Block, GetRequest):
|
||||
@@ -65,10 +66,8 @@ class GetWeatherInformationBlock(Block, GetRequest):
|
||||
description="Location to get weather information for"
|
||||
)
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal["openweathermap"], Literal["api_key"]
|
||||
Literal[ProviderName.OPENWEATHERMAP], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
provider="openweathermap",
|
||||
supported_credential_types={"api_key"},
|
||||
description="The OpenWeatherMap integration can be used with "
|
||||
"any API key with sufficient permissions for the blocks it is used on.",
|
||||
)
|
||||
|
||||
autogpt_platform/backend/backend/blocks/slant3d/_api.py (new file, +70)
@@ -0,0 +1,70 @@
from enum import Enum
from typing import Literal

from pydantic import BaseModel, SecretStr

from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName

Slant3DCredentialsInput = CredentialsMetaInput[
    Literal[ProviderName.SLANT3D], Literal["api_key"]
]


def Slant3DCredentialsField() -> Slant3DCredentialsInput:
    return CredentialsField(description="Slant3D API key for authentication")


TEST_CREDENTIALS = APIKeyCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="slant3d",
    api_key=SecretStr("mock-slant3d-api-key"),
    title="Mock Slant3D API key",
    expires_at=None,
)

TEST_CREDENTIALS_INPUT = {
    "provider": TEST_CREDENTIALS.provider,
    "id": TEST_CREDENTIALS.id,
    "type": TEST_CREDENTIALS.type,
    "title": TEST_CREDENTIALS.title,
}


class CustomerDetails(BaseModel):
    name: str
    email: str
    phone: str
    address: str
    city: str
    state: str
    zip: str
    country_iso: str = "US"
    is_residential: bool = True


class Color(Enum):
    WHITE = "white"
    BLACK = "black"


class Profile(Enum):
    PLA = "PLA"
    PETG = "PETG"


class OrderItem(BaseModel):
    # filename: str
    file_url: str
    quantity: str  # String as per API spec
    color: Color = Color.WHITE
    profile: Profile = Profile.PLA
    # image_url: str = ""
    # sku: str = ""


class Filament(BaseModel):
    filament: str
    hexColor: str
    colorTag: str
    profile: str
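For orientation, a minimal sketch of constructing the request models this file defines (the URL and contact details below are illustrative, not from the diff):

    # Illustrative use of the models above.
    item = OrderItem(
        file_url="https://example.com/model.stl",
        quantity="2",  # the API expects a string, not an int
        color=Color.BLACK,
        profile=Profile.PETG,
    )
    customer = CustomerDetails(
        name="Jane Doe", email="jane@example.com", phone="555-0100",
        address="1 Main St", city="Springfield", state="IL", zip="62701",
    )  # country_iso defaults to "US", is_residential to True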
autogpt_platform/backend/backend/blocks/slant3d/base.py (new file, +94)
@@ -0,0 +1,94 @@
from typing import Any, Dict

from backend.data.block import Block
from backend.util.request import requests

from ._api import Color, CustomerDetails, OrderItem, Profile


class Slant3DBlockBase(Block):
    """Base block class for Slant3D API interactions"""

    BASE_URL = "https://www.slant3dapi.com/api"

    def _get_headers(self, api_key: str) -> Dict[str, str]:
        return {"api-key": api_key, "Content-Type": "application/json"}

    def _make_request(self, method: str, endpoint: str, api_key: str, **kwargs) -> Dict:
        url = f"{self.BASE_URL}/{endpoint}"
        response = requests.request(
            method=method, url=url, headers=self._get_headers(api_key), **kwargs
        )

        if not response.ok:
            error_msg = response.json().get("error", "Unknown error")
            raise RuntimeError(f"API request failed: {error_msg}")

        return response.json()

    def _check_valid_color(self, profile: Profile, color: Color, api_key: str) -> str:
        response = self._make_request(
            "GET",
            "filament",
            api_key,
            params={"profile": profile.value, "color": color.value},
        )
        if profile == Profile.PLA:
            color_tag = color.value
        else:
            color_tag = f"{profile.value.lower()}{color.value.capitalize()}"
        valid_tags = [filament["colorTag"] for filament in response["filaments"]]

        if color_tag not in valid_tags:
            raise ValueError(
                f"""Invalid color profile combination {color_tag}.
Valid colors for {profile.value} are:
{','.join([filament['colorTag'].replace(profile.value.lower(), '') for filament in response['filaments'] if filament['profile'] == profile.value])}
"""
            )
        return color_tag

    def _convert_to_color(self, profile: Profile, color: Color, api_key: str) -> str:
        return self._check_valid_color(profile, color, api_key)

    def _format_order_data(
        self,
        customer: CustomerDetails,
        order_number: str,
        items: list[OrderItem],
        api_key: str,
    ) -> list[dict[str, Any]]:
        """Helper function to format order data for API requests"""
        orders = []
        for item in items:
            order_data = {
                "email": customer.email,
                "phone": customer.phone,
                "name": customer.name,
                "orderNumber": order_number,
                "filename": item.file_url,
                "fileURL": item.file_url,
                "bill_to_street_1": customer.address,
                "bill_to_city": customer.city,
                "bill_to_state": customer.state,
                "bill_to_zip": customer.zip,
                "bill_to_country_as_iso": customer.country_iso,
                "bill_to_is_US_residential": str(customer.is_residential).lower(),
                "ship_to_name": customer.name,
                "ship_to_street_1": customer.address,
                "ship_to_city": customer.city,
                "ship_to_state": customer.state,
                "ship_to_zip": customer.zip,
                "ship_to_country_as_iso": customer.country_iso,
                "ship_to_is_US_residential": str(customer.is_residential).lower(),
                "order_item_name": item.file_url,
                "order_quantity": item.quantity,
                "order_image_url": "",
                "order_sku": "NOT_USED",
                "order_item_color": self._convert_to_color(
                    item.profile, item.color, api_key
                ),
                "profile": item.profile.value,
            }
            orders.append(order_data)
        return orders
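Note how `_check_valid_color` composes the tag it validates: PLA colors keep their bare name, while other profiles get a lowercase-profile prefix. Restating the branch above as a quick illustration (not new behavior):

    # Profile.PLA,  Color.BLACK -> "black"
    # Profile.PETG, Color.BLACK -> "petgBlack"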
autogpt_platform/backend/backend/blocks/slant3d/filament.py (new file, +85)
@@ -0,0 +1,85 @@
from typing import List

from backend.data.block import BlockOutput, BlockSchema
from backend.data.model import APIKeyCredentials, SchemaField

from ._api import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    Filament,
    Slant3DCredentialsField,
    Slant3DCredentialsInput,
)
from .base import Slant3DBlockBase


class Slant3DFilamentBlock(Slant3DBlockBase):
    """Block for retrieving available filaments"""

    class Input(BlockSchema):
        credentials: Slant3DCredentialsInput = Slant3DCredentialsField()

    class Output(BlockSchema):
        filaments: List[Filament] = SchemaField(
            description="List of available filaments"
        )
        error: str = SchemaField(description="Error message if request failed")

    def __init__(self):
        super().__init__(
            id="7cc416f4-f305-4606-9b3b-452b8a81031c",
            description="Get list of available filaments",
            input_schema=self.Input,
            output_schema=self.Output,
            test_input={"credentials": TEST_CREDENTIALS_INPUT},
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                (
                    "filaments",
                    [
                        {
                            "filament": "PLA BLACK",
                            "hexColor": "000000",
                            "colorTag": "black",
                            "profile": "PLA",
                        },
                        {
                            "filament": "PLA WHITE",
                            "hexColor": "ffffff",
                            "colorTag": "white",
                            "profile": "PLA",
                        },
                    ],
                )
            ],
            test_mock={
                "_make_request": lambda *args, **kwargs: {
                    "filaments": [
                        {
                            "filament": "PLA BLACK",
                            "hexColor": "000000",
                            "colorTag": "black",
                            "profile": "PLA",
                        },
                        {
                            "filament": "PLA WHITE",
                            "hexColor": "ffffff",
                            "colorTag": "white",
                            "profile": "PLA",
                        },
                    ]
                }
            },
        )

    def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        try:
            result = self._make_request(
                "GET", "filament", credentials.api_key.get_secret_value()
            )
            yield "filaments", result["filaments"]
        except Exception as e:
            yield "error", str(e)
            raise
autogpt_platform/backend/backend/blocks/slant3d/order.py (new file, +418)
@@ -0,0 +1,418 @@
import uuid
from typing import List

import requests as baserequests

from backend.data.block import BlockOutput, BlockSchema
from backend.data.model import APIKeyCredentials, SchemaField
from backend.util import settings
from backend.util.settings import BehaveAs

from ._api import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    CustomerDetails,
    OrderItem,
    Slant3DCredentialsField,
    Slant3DCredentialsInput,
)
from .base import Slant3DBlockBase


class Slant3DCreateOrderBlock(Slant3DBlockBase):
    """Block for creating new orders"""

    class Input(BlockSchema):
        credentials: Slant3DCredentialsInput = Slant3DCredentialsField()
        order_number: str = SchemaField(
            description="Your custom order number (or leave blank for a random one)",
            default_factory=lambda: str(uuid.uuid4()),
        )
        customer: CustomerDetails = SchemaField(
            description="Customer details for where to ship the item",
            advanced=False,
        )
        items: List[OrderItem] = SchemaField(
            description="List of items to print",
            advanced=False,
        )

    class Output(BlockSchema):
        order_id: str = SchemaField(description="Slant3D order ID")
        error: str = SchemaField(description="Error message if order failed")

    def __init__(self):
        super().__init__(
            id="f73007d6-f48f-4aaf-9e6b-6883998a09b4",
            description="Create a new print order",
            input_schema=self.Input,
            output_schema=self.Output,
            test_input={
                "credentials": TEST_CREDENTIALS_INPUT,
                "order_number": "TEST-001",
                "customer": {
                    "name": "John Doe",
                    "email": "john@example.com",
                    "phone": "123-456-7890",
                    "address": "123 Test St",
                    "city": "Test City",
                    "state": "TS",
                    "zip": "12345",
                },
                "items": [
                    {
                        "file_url": "https://example.com/model.stl",
                        "quantity": "1",
                        "color": "black",
                        "profile": "PLA",
                    }
                ],
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[("order_id", "314144241")],
            test_mock={
                "_make_request": lambda *args, **kwargs: {"orderId": "314144241"},
                "_convert_to_color": lambda *args, **kwargs: "black",
            },
        )

    def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        try:
            order_data = self._format_order_data(
                input_data.customer,
                input_data.order_number,
                input_data.items,
                credentials.api_key.get_secret_value(),
            )
            result = self._make_request(
                "POST", "order", credentials.api_key.get_secret_value(), json=order_data
            )
            yield "order_id", result["orderId"]
        except Exception as e:
            yield "error", str(e)
            raise


class Slant3DEstimateOrderBlock(Slant3DBlockBase):
    """Block for getting order cost estimates"""

    class Input(BlockSchema):
        credentials: Slant3DCredentialsInput = Slant3DCredentialsField()
        order_number: str = SchemaField(
            description="Your custom order number (or leave blank for a random one)",
            default_factory=lambda: str(uuid.uuid4()),
        )
        customer: CustomerDetails = SchemaField(
            description="Customer details for where to ship the item",
            advanced=False,
        )
        items: List[OrderItem] = SchemaField(
            description="List of items to print",
            advanced=False,
        )

    class Output(BlockSchema):
        total_price: float = SchemaField(description="Total price in USD")
        shipping_cost: float = SchemaField(description="Shipping cost")
        printing_cost: float = SchemaField(description="Printing cost")
        error: str = SchemaField(description="Error message if estimation failed")

    def __init__(self):
        super().__init__(
            id="bf8823d6-b42a-48c7-b558-d7c117f2ae85",
            description="Get order cost estimate",
            input_schema=self.Input,
            output_schema=self.Output,
            test_input={
                "credentials": TEST_CREDENTIALS_INPUT,
                "order_number": "TEST-001",
                "customer": {
                    "name": "John Doe",
                    "email": "john@example.com",
                    "phone": "123-456-7890",
                    "address": "123 Test St",
                    "city": "Test City",
                    "state": "TS",
                    "zip": "12345",
                },
                "items": [
                    {
                        "file_url": "https://example.com/model.stl",
                        "quantity": "1",
                        "color": "black",
                        "profile": "PLA",
                    }
                ],
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("total_price", 9.31),
                ("shipping_cost", 5.56),
                ("printing_cost", 3.75),
            ],
            test_mock={
                "_make_request": lambda *args, **kwargs: {
                    "totalPrice": 9.31,
                    "shippingCost": 5.56,
                    "printingCost": 3.75,
                },
                "_convert_to_color": lambda *args, **kwargs: "black",
            },
        )

    def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        order_data = self._format_order_data(
            input_data.customer,
            input_data.order_number,
            input_data.items,
            credentials.api_key.get_secret_value(),
        )
        try:
            result = self._make_request(
                "POST",
                "order/estimate",
                credentials.api_key.get_secret_value(),
                json=order_data,
            )
            yield "total_price", result["totalPrice"]
            yield "shipping_cost", result["shippingCost"]
            yield "printing_cost", result["printingCost"]
        except baserequests.HTTPError as e:
            yield "error", str(f"Error estimating order: {e} {e.response.text}")
            raise


class Slant3DEstimateShippingBlock(Slant3DBlockBase):
    """Block for getting shipping cost estimates"""

    class Input(BlockSchema):
        credentials: Slant3DCredentialsInput = Slant3DCredentialsField()
        order_number: str = SchemaField(
            description="Your custom order number (or leave blank for a random one)",
            default_factory=lambda: str(uuid.uuid4()),
        )
        customer: CustomerDetails = SchemaField(
            description="Customer details for where to ship the item"
        )
        items: List[OrderItem] = SchemaField(
            description="List of items to print",
            advanced=False,
        )

    class Output(BlockSchema):
        shipping_cost: float = SchemaField(description="Estimated shipping cost")
        currency_code: str = SchemaField(description="Currency code (e.g., 'usd')")
        error: str = SchemaField(description="Error message if estimation failed")

    def __init__(self):
        super().__init__(
            id="00aae2a1-caf6-4a74-8175-39a0615d44e1",
            description="Get shipping cost estimate",
            input_schema=self.Input,
            output_schema=self.Output,
            test_input={
                "credentials": TEST_CREDENTIALS_INPUT,
                "order_number": "TEST-001",
                "customer": {
                    "name": "John Doe",
                    "email": "john@example.com",
                    "phone": "123-456-7890",
                    "address": "123 Test St",
                    "city": "Test City",
                    "state": "TS",
                    "zip": "12345",
                },
                "items": [
                    {
                        "file_url": "https://example.com/model.stl",
                        "quantity": "1",
                        "color": "black",
                        "profile": "PLA",
                    }
                ],
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[("shipping_cost", 4.81), ("currency_code", "usd")],
            test_mock={
                "_make_request": lambda *args, **kwargs: {
                    "shippingCost": 4.81,
                    "currencyCode": "usd",
                },
                "_convert_to_color": lambda *args, **kwargs: "black",
            },
        )

    def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        try:
            order_data = self._format_order_data(
                input_data.customer,
                input_data.order_number,
                input_data.items,
                credentials.api_key.get_secret_value(),
            )
            result = self._make_request(
                "POST",
                "order/estimateShipping",
                credentials.api_key.get_secret_value(),
                json=order_data,
            )
            yield "shipping_cost", result["shippingCost"]
            yield "currency_code", result["currencyCode"]
        except Exception as e:
            yield "error", str(e)
            raise


class Slant3DGetOrdersBlock(Slant3DBlockBase):
    """Block for retrieving all orders"""

    class Input(BlockSchema):
        credentials: Slant3DCredentialsInput = Slant3DCredentialsField()

    class Output(BlockSchema):
        orders: List[str] = SchemaField(description="List of orders with their details")
        error: str = SchemaField(description="Error message if request failed")

    def __init__(self):
        super().__init__(
            id="42283bf5-8a32-4fb4-92a2-60a9ea48e105",
            description="Get all orders for the account",
            input_schema=self.Input,
            output_schema=self.Output,
            # This block is disabled for cloud-hosted deployments because it
            # exposes all orders for the account
            disabled=settings.Settings().config.behave_as == BehaveAs.CLOUD,
            test_input={"credentials": TEST_CREDENTIALS_INPUT},
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                (
                    "orders",
                    [
                        "1234567890",
                    ],
                )
            ],
            test_mock={
                "_make_request": lambda *args, **kwargs: {
                    "ordersData": [
                        {
                            "orderId": 1234567890,
                            "orderTimestamp": {
                                "_seconds": 1719510986,
                                "_nanoseconds": 710000000,
                            },
                        }
                    ]
                }
            },
        )

    def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        try:
            result = self._make_request(
                "GET", "order", credentials.api_key.get_secret_value()
            )
            yield "orders", [str(order["orderId"]) for order in result["ordersData"]]
        except Exception as e:
            yield "error", str(e)
            raise


class Slant3DTrackingBlock(Slant3DBlockBase):
    """Block for tracking order status and shipping"""

    class Input(BlockSchema):
        credentials: Slant3DCredentialsInput = Slant3DCredentialsField()
        order_id: str = SchemaField(description="Slant3D order ID to track")

    class Output(BlockSchema):
        status: str = SchemaField(description="Order status")
        tracking_numbers: List[str] = SchemaField(
            description="List of tracking numbers"
        )
        error: str = SchemaField(description="Error message if tracking failed")

    def __init__(self):
        super().__init__(
            id="dd7c0293-c5af-4551-ba3e-fc162fb1fb89",
            description="Track order status and shipping",
            input_schema=self.Input,
            output_schema=self.Output,
            test_input={
                "credentials": TEST_CREDENTIALS_INPUT,
                "order_id": "314144241",
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[("status", "awaiting_shipment"), ("tracking_numbers", [])],
            test_mock={
                "_make_request": lambda *args, **kwargs: {
                    "status": "awaiting_shipment",
                    "trackingNumbers": [],
                }
            },
        )

    def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        try:
            result = self._make_request(
                "GET",
                f"order/{input_data.order_id}/get-tracking",
                credentials.api_key.get_secret_value(),
            )
            yield "status", result["status"]
            yield "tracking_numbers", result["trackingNumbers"]
        except Exception as e:
            yield "error", str(e)
            raise


class Slant3DCancelOrderBlock(Slant3DBlockBase):
    """Block for canceling orders"""

    class Input(BlockSchema):
        credentials: Slant3DCredentialsInput = Slant3DCredentialsField()
        order_id: str = SchemaField(description="Slant3D order ID to cancel")

    class Output(BlockSchema):
        status: str = SchemaField(description="Cancellation status message")
        error: str = SchemaField(description="Error message if cancellation failed")

    def __init__(self):
        super().__init__(
            id="54de35e1-407f-450b-b5fa-3b5e2eba8185",
            description="Cancel an existing order",
            input_schema=self.Input,
            output_schema=self.Output,
            test_input={
                "credentials": TEST_CREDENTIALS_INPUT,
                "order_id": "314144241",
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[("status", "Order cancelled")],
            test_mock={
                "_make_request": lambda *args, **kwargs: {"status": "Order cancelled"}
            },
        )

    def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        try:
            result = self._make_request(
                "DELETE",
                f"order/{input_data.order_id}",
                credentials.api_key.get_secret_value(),
            )
            yield "status", result["status"]
        except Exception as e:
            yield "error", str(e)
            raise
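Taken together, the order blocks share one request pipeline: format the order rows, then POST them to the relevant endpoint. A condensed sketch of that flow using the helpers above (values are illustrative only):

    # Illustrative only: what Slant3DCreateOrderBlock.run does under the hood.
    order_data = block._format_order_data(customer, "ORDER-42", [item], api_key)
    result = block._make_request("POST", "order", api_key, json=order_data)
    order_id = result["orderId"]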
autogpt_platform/backend/backend/blocks/slant3d/slicing.py (new file, +61)
@@ -0,0 +1,61 @@
from backend.data.block import BlockOutput, BlockSchema
from backend.data.model import APIKeyCredentials, SchemaField

from ._api import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    Slant3DCredentialsField,
    Slant3DCredentialsInput,
)
from .base import Slant3DBlockBase


class Slant3DSlicerBlock(Slant3DBlockBase):
    """Block for slicing 3D model files"""

    class Input(BlockSchema):
        credentials: Slant3DCredentialsInput = Slant3DCredentialsField()
        file_url: str = SchemaField(
            description="URL of the 3D model file to slice (STL)"
        )

    class Output(BlockSchema):
        message: str = SchemaField(description="Response message")
        price: float = SchemaField(description="Calculated price for printing")
        error: str = SchemaField(description="Error message if slicing failed")

    def __init__(self):
        super().__init__(
            id="f8a12c8d-3e4b-4d5f-b6a7-8c9d0e1f2g3h",
            description="Slice a 3D model file and get pricing information",
            input_schema=self.Input,
            output_schema=self.Output,
            test_input={
                "credentials": TEST_CREDENTIALS_INPUT,
                "file_url": "https://example.com/model.stl",
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[("message", "Slicing successful"), ("price", 8.23)],
            test_mock={
                "_make_request": lambda *args, **kwargs: {
                    "message": "Slicing successful",
                    "data": {"price": 8.23},
                }
            },
        )

    def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        try:
            result = self._make_request(
                "POST",
                "slicer",
                credentials.api_key.get_secret_value(),
                json={"fileURL": input_data.file_url},
            )
            yield "message", result["message"]
            yield "price", result["data"]["price"]
        except Exception as e:
            yield "error", str(e)
            raise
autogpt_platform/backend/backend/blocks/slant3d/webhook.py (new file, +125)
@@ -0,0 +1,125 @@
from pydantic import BaseModel

from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchema,
    BlockWebhookConfig,
)
from backend.data.model import SchemaField
from backend.util import settings
from backend.util.settings import AppEnvironment, BehaveAs

from ._api import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    Slant3DCredentialsField,
    Slant3DCredentialsInput,
)


class Slant3DTriggerBase:
    """Base class for Slant3D webhook triggers"""

    class Input(BlockSchema):
        credentials: Slant3DCredentialsInput = Slant3DCredentialsField()
        # Webhook URL is handled by the webhook system
        payload: dict = SchemaField(hidden=True, default={})

    class Output(BlockSchema):
        payload: dict = SchemaField(
            description="The complete webhook payload received from Slant3D"
        )
        order_id: str = SchemaField(description="The ID of the affected order")
        error: str = SchemaField(
            description="Error message if payload processing failed"
        )

    def run(self, input_data: Input, **kwargs) -> BlockOutput:
        yield "payload", input_data.payload
        yield "order_id", input_data.payload["orderId"]


class Slant3DOrderWebhookBlock(Slant3DTriggerBase, Block):
    """Block for handling Slant3D order webhooks"""

    class Input(Slant3DTriggerBase.Input):
        class EventsFilter(BaseModel):
            """
            Currently Slant3D only supports 'SHIPPED' status updates;
            this could be expanded in the future with more status types.
            """

            shipped: bool = True

        events: EventsFilter = SchemaField(
            title="Events",
            description="Order status events to subscribe to",
            default=EventsFilter(shipped=True),
        )

    class Output(Slant3DTriggerBase.Output):
        status: str = SchemaField(description="The new status of the order")
        tracking_number: str = SchemaField(
            description="The tracking number for the shipment"
        )
        carrier_code: str = SchemaField(description="The carrier code (e.g., 'usps')")

    def __init__(self):
        super().__init__(
            id="8a74c2ad-0104-4640-962f-26c6b69e58cd",
            description=(
                "This block triggers on Slant3D order status updates and outputs "
                "the event details, including tracking information when orders are shipped."
            ),
            # All webhooks currently subscribe to all orders. This works for
            # self-hosted deployments, but not for cloud-hosted production.
            disabled=(
                settings.Settings().config.behave_as == BehaveAs.CLOUD
                and settings.Settings().config.app_env != AppEnvironment.LOCAL
            ),
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=self.Input,
            output_schema=self.Output,
            webhook_config=BlockWebhookConfig(
                provider="slant3d",
                webhook_type="orders",  # Only one type for now
                resource_format="",  # No resource format needed
                event_filter_input="events",
                event_format="order.{event}",
            ),
            test_input={
                "credentials": TEST_CREDENTIALS_INPUT,
                "events": {"shipped": True},
                "payload": {
                    "orderId": "1234567890",
                    "status": "SHIPPED",
                    "trackingNumber": "ABCDEF123456",
                    "carrierCode": "usps",
                },
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                (
                    "payload",
                    {
                        "orderId": "1234567890",
                        "status": "SHIPPED",
                        "trackingNumber": "ABCDEF123456",
                        "carrierCode": "usps",
                    },
                ),
                ("order_id", "1234567890"),
                ("status", "SHIPPED"),
                ("tracking_number", "ABCDEF123456"),
                ("carrier_code", "usps"),
            ],
        )

    def run(self, input_data: Input, **kwargs) -> BlockOutput:  # type: ignore
        yield from super().run(input_data, **kwargs)

        # Extract and normalize values from the payload
        yield "status", input_data.payload["status"]
        yield "tracking_number", input_data.payload["trackingNumber"]
        yield "carrier_code", input_data.payload["carrierCode"]
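With event_filter_input="events" and event_format="order.{event}", the enabled flags on EventsFilter map directly to subscribed event names. Restating the configuration above as a quick illustration (not new behavior):

    # EventsFilter(shipped=True)   -> enabled event keys: ["shipped"]
    # event_format "order.{event}" -> subscribed event name: "order.shipped"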
@@ -10,6 +10,7 @@ from backend.data.model import (
    CredentialsMetaInput,
    SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.request import requests

TEST_CREDENTIALS = APIKeyCredentials(
@@ -29,13 +30,11 @@ TEST_CREDENTIALS_INPUT = {

class CreateTalkingAvatarVideoBlock(Block):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput[Literal["d_id"], Literal["api_key"]] = (
            CredentialsField(
                provider="d_id",
                supported_credential_types={"api_key"},
                description="The D-ID integration can be used with "
                "any API key with sufficient permissions for the blocks it is used on.",
            )
        credentials: CredentialsMetaInput[
            Literal[ProviderName.D_ID], Literal["api_key"]
        ] = CredentialsField(
            description="The D-ID integration can be used with "
            "any API key with sufficient permissions for the blocks it is used on.",
        )
        script_input: str = SchemaField(
            description="The text input for the script",

@@ -1,13 +1,11 @@
import re
from typing import Any

from jinja2 import BaseLoader, Environment

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util import json
from backend.util import json, text

jinja = Environment(loader=BaseLoader())
formatter = text.TextFormatter()


class MatchTextPatternBlock(Block):
@@ -73,6 +71,7 @@ class ExtractTextInformationBlock(Block):
            description="Case sensitive match", default=True
        )
        dot_all: bool = SchemaField(description="Dot matches all", default=True)
        find_all: bool = SchemaField(description="Find all matches", default=False)

    class Output(BlockSchema):
        positive: str = SchemaField(description="Extracted text")
@@ -90,12 +89,27 @@ class ExtractTextInformationBlock(Block):
                {"text": "Hello, World!", "pattern": "Hello, (.+)", "group": 0},
                {"text": "Hello, World!", "pattern": "Hello, (.+)", "group": 2},
                {"text": "Hello, World!", "pattern": "hello,", "case_sensitive": False},
                {
                    "text": "Hello, World!! Hello, Earth!!",
                    "pattern": "Hello, (\\S+)",
                    "group": 1,
                    "find_all": False,
                },
                {
                    "text": "Hello, World!! Hello, Earth!!",
                    "pattern": "Hello, (\\S+)",
                    "group": 1,
                    "find_all": True,
                },
            ],
            test_output=[
                ("positive", "World!"),
                ("positive", "Hello, World!"),
                ("negative", "Hello, World!"),
                ("positive", "Hello,"),
                ("positive", "World!!"),
                ("positive", "World!!"),
                ("positive", "Earth!!"),
            ],
        )

@@ -107,15 +121,21 @@ class ExtractTextInformationBlock(Block):
            flags = flags | re.DOTALL

        if isinstance(input_data.text, str):
            text = input_data.text
            txt = input_data.text
        else:
            text = json.dumps(input_data.text)
            txt = json.dumps(input_data.text)

        match = re.search(input_data.pattern, text, flags)
        if match and input_data.group <= len(match.groups()):
            yield "positive", match.group(input_data.group)
        else:
            yield "negative", text
        matches = [
            match.group(input_data.group)
            for match in re.finditer(input_data.pattern, txt, flags)
            if input_data.group <= len(match.groups())
        ]
        for match in matches:
            yield "positive", match
            if not input_data.find_all:
                return
        if not matches:
            yield "negative", input_data.text

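The reworked run() above swaps a single re.search for re.finditer with an early exit. The new test cases pin the behavior down; restated briefly:

    # pattern "Hello, (\\S+)" on "Hello, World!! Hello, Earth!!", group=1
    # find_all=False -> yields "World!!" once, then stops
    # find_all=True  -> yields "World!!", then "Earth!!"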
class FillTextTemplateBlock(Block):
@@ -146,19 +166,20 @@ class FillTextTemplateBlock(Block):
                    "values": {"list": ["Hello", " World!"]},
                    "format": "{% for item in list %}{{ item }}{% endfor %}",
                },
                {
                    "values": {},
                    "format": "{% set name = 'Alice' %}Hello, World! {{ name }}",
                },
            ],
            test_output=[
                ("output", "Hello, World! Alice"),
                ("output", "Hello World!"),
                ("output", "Hello, World! Alice"),
            ],
        )

    def run(self, input_data: Input, **kwargs) -> BlockOutput:
        # For python.format compatibility: replace all {...} with {{..}}.
        # But avoid replacing {{...}} to {{{...}}}.
        fmt = re.sub(r"(?<!{){[ a-zA-Z0-9_]+}", r"{\g<0>}", input_data.format)
        template = jinja.from_string(fmt)
        yield "output", template.render(**input_data.values)
        yield "output", formatter.format_string(input_data.format, input_data.values)

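The inlined regex-plus-Jinja rendering moves behind backend.util.text.TextFormatter. Assuming format_string keeps the old semantics (Jinja templates plus Python-style {var} placeholders), the new call reduces to:

    # Sketch under that assumption; mirrors the second test case above.
    formatter.format_string(
        "{% set name = 'Alice' %}Hello, World! {{ name }}", {}
    )  # -> "Hello, World! Alice"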
class CombineTextsBlock(Block):

@@ -9,6 +9,7 @@ from backend.data.model import (
    CredentialsMetaInput,
    SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.request import requests

TEST_CREDENTIALS = APIKeyCredentials(
@@ -38,10 +39,8 @@ class UnrealTextToSpeechBlock(Block):
            default="Scarlett",
        )
        credentials: CredentialsMetaInput[
            Literal["unreal_speech"], Literal["api_key"]
            Literal[ProviderName.UNREAL_SPEECH], Literal["api_key"]
        ] = CredentialsField(
            provider="unreal_speech",
            supported_credential_types={"api_key"},
            description="The Unreal Speech integration can be used with "
            "any API key with sufficient permissions for the blocks it is used on.",
        )

@@ -65,7 +65,7 @@ class BlockCategory(Enum):


class BlockSchema(BaseModel):
    cached_jsonschema: ClassVar[dict[str, Any]] = {}
    cached_jsonschema: ClassVar[dict[str, Any]]

    @classmethod
    def jsonschema(cls) -> dict[str, Any]:
@@ -145,6 +145,10 @@ class BlockSchema(BaseModel):
        - A field that is called `credentials` MUST be a `CredentialsMetaInput`.
        """
        super().__pydantic_init_subclass__(**kwargs)

        # Reset cached JSON schema to prevent inheriting it from parent class
        cls.cached_jsonschema = {}

        credentials_fields = [
            field_name
            for field_name, info in cls.model_fields.items()
@@ -176,6 +180,11 @@ class BlockSchema(BaseModel):
                f"Field 'credentials' on {cls.__qualname__} "
                f"must be of type {CredentialsMetaInput.__name__}"
            )
        if credentials_field := cls.model_fields.get(CREDENTIALS_FIELD_NAME):
            credentials_input_type = cast(
                CredentialsMetaInput, credentials_field.annotation
            )
            credentials_input_type.validate_credentials_field_schema(cls)


BlockSchemaInputType = TypeVar("BlockSchemaInputType", bound=BlockSchema)
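The cached_jsonschema change above removes a shared mutable class default: with `= {}`, every subclass inherited the same dict object, so one block's cached schema could leak into another. A minimal illustration of the underlying Python behavior (plain demo, hypothetical class names):

    from typing import Any, ClassVar

    class Base:
        cache: ClassVar[dict[str, Any]] = {}

    class Child(Base):
        pass

    Child.cache["k"] = 1
    print(Base.cache)  # {'k': 1} - both classes point at the same dict

Resetting `cls.cached_jsonschema = {}` in `__pydantic_init_subclass__` gives each subclass its own dict.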
@@ -53,8 +53,8 @@ MODEL_COST: dict[LlmModel, int] = {
    LlmModel.LLAMA3_1_8B: 1,
    LlmModel.OLLAMA_LLAMA3_8B: 1,
    LlmModel.OLLAMA_LLAMA3_405B: 1,
    LlmModel.OLLAMA_DOLPHIN: 1,
    LlmModel.GEMINI_FLASH_1_5_8B: 1,
    LlmModel.GEMINI_FLASH_1_5_EXP: 1,
    LlmModel.GROK_BETA: 5,
    LlmModel.MISTRAL_NEMO: 1,
    LlmModel.COHERE_COMMAND_R_08_2024: 1,
@@ -62,6 +62,14 @@ MODEL_COST: dict[LlmModel, int] = {
    LlmModel.EVA_QWEN_2_5_32B: 1,
    LlmModel.DEEPSEEK_CHAT: 2,
    LlmModel.PERPLEXITY_LLAMA_3_1_SONAR_LARGE_128K_ONLINE: 1,
    LlmModel.QWEN_QWQ_32B_PREVIEW: 2,
    LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: 1,
    LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: 1,
    LlmModel.AMAZON_NOVA_LITE_V1: 1,
    LlmModel.AMAZON_NOVA_MICRO_V1: 1,
    LlmModel.AMAZON_NOVA_PRO_V1: 1,
    LlmModel.MICROSOFT_WIZARDLM_2_8X22B: 1,
    LlmModel.GRYPHE_MYTHOMAX_L2_13B: 1,
}

for model in LlmModel:

@@ -10,6 +10,7 @@ class BlockCostType(str, Enum):
    RUN = "run"  # cost X credits per run
    BYTE = "byte"  # cost X credits per byte
    SECOND = "second"  # cost X credits per second
    DOLLAR = "dollar"  # cost X dollars per run


class BlockCost(BaseModel):

@@ -2,9 +2,9 @@ from abc import ABC, abstractmethod
from datetime import datetime, timezone

from prisma import Json
from prisma.enums import UserBlockCreditType
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import UserBlockCredit
from prisma.models import CreditTransaction

from backend.data.block import Block, BlockInput, get_block
from backend.data.block_cost_config import BLOCK_COSTS
@@ -76,7 +76,7 @@ class UserCredit(UserCreditBase):
            else cur_month.replace(year=cur_month.year + 1, month=1)
        )

        user_credit = await UserBlockCredit.prisma().group_by(
        user_credit = await CreditTransaction.prisma().group_by(
            by=["userId"],
            sum={"amount": True},
            where={
@@ -93,10 +93,10 @@ class UserCredit(UserCreditBase):
        key = f"MONTHLY-CREDIT-TOP-UP-{cur_month}"

        try:
            await UserBlockCredit.prisma().create(
            await CreditTransaction.prisma().create(
                data={
                    "amount": self.num_user_credits_refill,
                    "type": UserBlockCreditType.TOP_UP,
                    "type": CreditTransactionType.TOP_UP,
                    "userId": user_id,
                    "transactionKey": key,
                    "createdAt": self.time_now(),
@@ -184,11 +184,11 @@ class UserCredit(UserCreditBase):
        if validate_balance and user_credit < cost:
            raise ValueError(f"Insufficient credit: {user_credit} < {cost}")

        await UserBlockCredit.prisma().create(
        await CreditTransaction.prisma().create(
            data={
                "userId": user_id,
                "amount": -cost,
                "type": UserBlockCreditType.USAGE,
                "type": CreditTransactionType.USAGE,
                "blockId": block.id,
                "metadata": Json(
                    {
@@ -202,11 +202,11 @@ class UserCredit(UserCreditBase):
        return cost

    async def top_up_credits(self, user_id: str, amount: int):
        await UserBlockCredit.prisma().create(
        await CreditTransaction.prisma().create(
            data={
                "userId": user_id,
                "amount": amount,
                "type": UserBlockCreditType.TOP_UP,
                "type": CreditTransactionType.TOP_UP,
                "createdAt": self.time_now(),
            }
        )
@@ -29,6 +29,13 @@ async def connect():
    if not prisma.is_connected():
        raise ConnectionError("Failed to connect to Prisma.")

    # A connection acquired from a pooler (e.g. Supabase's) can be handed to the
    # client yet still reject queries afterward, so verify it with a test query.
    try:
        await prisma.execute_raw("SELECT 1")
    except Exception as e:
        raise ConnectionError("Failed to connect to Prisma.") from e


@conn_retry("Prisma", "Releasing connection")
async def disconnect():
@@ -9,7 +9,6 @@ from prisma.models import (
    AgentNodeExecution,
    AgentNodeExecutionInputOutput,
)
from prisma.types import AgentGraphExecutionWhereInput
from pydantic import BaseModel

from backend.data.block import BlockData, BlockInput, CompletedBlockOutput
@@ -19,14 +18,14 @@ from backend.util import json, mock
from backend.util.settings import Config


class GraphExecution(BaseModel):
class GraphExecutionEntry(BaseModel):
    user_id: str
    graph_exec_id: str
    graph_id: str
    start_node_execs: list["NodeExecution"]
    start_node_execs: list["NodeExecutionEntry"]


class NodeExecution(BaseModel):
class NodeExecutionEntry(BaseModel):
    user_id: str
    graph_exec_id: str
    graph_id: str
@@ -325,34 +324,6 @@ async def update_execution_status(
    return ExecutionResult.from_db(res)


async def get_graph_execution(
    graph_exec_id: str, user_id: str
) -> AgentGraphExecution | None:
    """
    Retrieve a specific graph execution by its ID.

    Args:
        graph_exec_id (str): The ID of the graph execution to retrieve.
        user_id (str): The ID of the user to whom the graph (execution) belongs.

    Returns:
        AgentGraphExecution | None: The graph execution if found, None otherwise.
    """
    execution = await AgentGraphExecution.prisma().find_first(
        where={"id": graph_exec_id, "userId": user_id},
        include=GRAPH_EXECUTION_INCLUDE,
    )
    return execution


async def list_executions(graph_id: str, graph_version: int | None = None) -> list[str]:
    where: AgentGraphExecutionWhereInput = {"agentGraphId": graph_id}
    if graph_version is not None:
        where["agentGraphVersion"] = graph_version
    executions = await AgentGraphExecution.prisma().find_many(where=where)
    return [execution.id for execution in executions]


async def get_execution_results(graph_exec_id: str) -> list[ExecutionResult]:
    executions = await AgentNodeExecution.prisma().find_many(
        where={"agentGraphExecutionId": graph_exec_id},

@@ -105,6 +105,8 @@ class GraphExecution(BaseDbModel):
    duration: float
    total_run_time: float
    status: ExecutionStatus
    graph_id: str
    graph_version: int

    @staticmethod
    def from_db(execution: AgentGraphExecution):
@@ -130,6 +132,8 @@ class GraphExecution(BaseDbModel):
            duration=duration,
            total_run_time=total_run_time,
            status=ExecutionStatus(execution.executionStatus),
            graph_id=execution.agentGraphId,
            graph_version=execution.agentGraphVersion,
        )


@@ -139,7 +143,6 @@ class Graph(BaseDbModel):
    is_template: bool = False
    name: str
    description: str
    executions: list[GraphExecution] = []
    nodes: list[Node] = []
    links: list[Link] = []

@@ -254,7 +257,7 @@ class GraphModel(Graph):
        for link in self.links:
            input_links[link.sink_id].append(link)

        # Nodes: required fields are filled or connected and dependencies are satisfied
        # Nodes: required fields are filled or connected
        for node in self.nodes:
            block = get_block(node.block_id)
            if block is None:
@@ -275,38 +278,6 @@ class GraphModel(Graph):
                        f"Node {block.name} #{node.id} required input missing: `{name}`"
                    )

            # Get input schema properties and check dependencies
            input_schema = block.input_schema.model_fields
            required_fields = block.input_schema.get_required_fields()

            def has_value(name):
                return (
                    node is not None
                    and name in node.input_default
                    and node.input_default[name] is not None
                    and str(node.input_default[name]).strip() != ""
                ) or (name in input_schema and input_schema[name].default is not None)

            # Validate dependencies between fields
            for field_name, field_info in input_schema.items():

                # Apply input dependency validation only on run & field with depends_on
                json_schema_extra = field_info.json_schema_extra or {}
                dependencies = json_schema_extra.get("depends_on", [])
                if not for_run or not dependencies:
                    continue

                # Check if dependent field has value in input_default
                field_has_value = has_value(field_name)
                field_is_required = field_name in required_fields

                # Check for missing dependencies when dependent field is present
                missing_deps = [dep for dep in dependencies if not has_value(dep)]
                if missing_deps and (field_has_value or field_is_required):
                    raise ValueError(
                        f"Node {block.name} #{node.id}: Field `{field_name}` requires [{', '.join(missing_deps)}] to be set"
                    )

        node_map = {v.id: v for v in self.nodes}

        def is_static_output_block(nid: str) -> bool:
@@ -357,11 +328,6 @@ class GraphModel(Graph):

    @staticmethod
    def from_db(graph: AgentGraph, hide_credentials: bool = False):
        executions = [
            GraphExecution.from_db(execution)
            for execution in graph.AgentGraphExecution or []
        ]

        return GraphModel(
            id=graph.id,
            user_id=graph.userId,
@@ -370,7 +336,6 @@ class GraphModel(Graph):
            is_template=graph.isTemplate,
            name=graph.name or "",
            description=graph.description or "",
            executions=executions,
            nodes=[
                GraphModel._process_node(node, hide_credentials)
                for node in graph.AgentNodes or []
@@ -440,7 +405,6 @@ async def set_node_webhook(node_id: str, webhook_id: str | None) -> NodeModel:

async def get_graphs(
    user_id: str,
    include_executions: bool = False,
    filter_by: Literal["active", "template"] | None = "active",
) -> list[GraphModel]:
    """
@@ -448,35 +412,44 @@ async def get_graphs(
    Default behaviour is to get all currently active graphs.

    Args:
        include_executions: Whether to include executions in the graph metadata.
        filter_by: An optional filter to either select templates or active graphs.
        user_id: The ID of the user that owns the graph.

    Returns:
        list[GraphModel]: A list of objects representing the retrieved graphs.
    """
    where_clause: AgentGraphWhereInput = {}
    where_clause: AgentGraphWhereInput = {"userId": user_id}

    if filter_by == "active":
        where_clause["isActive"] = True
    elif filter_by == "template":
        where_clause["isTemplate"] = True

    where_clause["userId"] = user_id

    graph_include = AGENT_GRAPH_INCLUDE
    graph_include["AgentGraphExecution"] = include_executions

    graphs = await AgentGraph.prisma().find_many(
        where=where_clause,
        distinct=["id"],
        order={"version": "desc"},
        include=graph_include,
        include=AGENT_GRAPH_INCLUDE,
    )

    return [GraphModel.from_db(graph) for graph in graphs]


async def get_executions(user_id: str) -> list[GraphExecution]:
    executions = await AgentGraphExecution.prisma().find_many(
        where={"userId": user_id},
        order={"createdAt": "desc"},
    )
    return [GraphExecution.from_db(execution) for execution in executions]


async def get_execution(user_id: str, execution_id: str) -> GraphExecution | None:
    execution = await AgentGraphExecution.prisma().find_first(
        where={"id": execution_id, "userId": user_id}
    )
    return GraphExecution.from_db(execution) if execution else None


async def get_graph(
    graph_id: str,
    version: int | None = None,

@@ -7,6 +7,7 @@ from pydantic import Field

from backend.data.includes import INTEGRATION_WEBHOOK_INCLUDE
from backend.data.queue import AsyncRedisEventBus
from backend.integrations.providers import ProviderName

from .db import BaseDbModel

@@ -18,7 +19,7 @@ logger = logging.getLogger(__name__)

class Webhook(BaseDbModel):
    user_id: str
    provider: str
    provider: ProviderName
    credentials_id: str
    webhook_type: str
    resource: str
@@ -37,7 +38,7 @@ class Webhook(BaseDbModel):
        return Webhook(
            id=webhook.id,
            user_id=webhook.userId,
            provider=webhook.provider,
            provider=ProviderName(webhook.provider),
            credentials_id=webhook.credentialsId,
            webhook_type=webhook.webhookType,
            resource=webhook.resource,
@@ -61,7 +62,7 @@ async def create_webhook(webhook: Webhook) -> Webhook:
        data={
            "id": webhook.id,
            "userId": webhook.user_id,
            "provider": webhook.provider,
            "provider": webhook.provider.value,
            "credentialsId": webhook.credentials_id,
            "webhookType": webhook.webhook_type,
            "resource": webhook.resource,
@@ -144,25 +145,28 @@ class WebhookEventBus(AsyncRedisEventBus[WebhookEvent]):
    def event_bus_name(self) -> str:
        return "webhooks"

    async def publish(self, event: WebhookEvent):
        await self.publish_event(event, f"{event.webhook_id}/{event.event_type}")

    async def listen(
        self, webhook_id: str, event_type: Optional[str] = None
    ) -> AsyncGenerator[WebhookEvent, None]:
        async for event in self.listen_events(f"{webhook_id}/{event_type or '*'}"):
            yield event


event_bus = WebhookEventBus()
_webhook_event_bus = WebhookEventBus()


async def publish_webhook_event(event: WebhookEvent):
    await event_bus.publish(event)
    await _webhook_event_bus.publish_event(
        event, f"{event.webhook_id}/{event.event_type}"
    )


async def listen_for_webhook_event(
async def listen_for_webhook_events(
    webhook_id: str, event_type: Optional[str] = None
) -> AsyncGenerator[WebhookEvent, None]:
    async for event in _webhook_event_bus.listen_events(
        f"{webhook_id}/{event_type or '*'}"
    ):
        yield event


async def wait_for_webhook_event(
    webhook_id: str, event_type: Optional[str] = None, timeout: Optional[float] = None
) -> WebhookEvent | None:
    async for event in event_bus.listen(webhook_id, event_type):
        return event  # Only one event is expected
    return await _webhook_event_bus.wait_for_event(
        f"{webhook_id}/{event_type or '*'}", timeout
    )
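The event-bus refactor replaces the ad-hoc listen-and-return loop with wait_for_event plus an optional timeout. A hedged usage sketch of the new helper (the webhook id and event name below are made up):

    # Wait up to 30s for a single matching event; returns None on timeout.
    event = await wait_for_webhook_event(
        webhook_id="wh-123", event_type="order.shipped", timeout=30
    )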
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
Any,
|
||||
Callable,
|
||||
@@ -11,19 +12,32 @@ from typing import (
|
||||
Optional,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
get_args,
|
||||
)
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field, GetCoreSchemaHandler, SecretStr, field_serializer
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
GetCoreSchemaHandler,
|
||||
SecretStr,
|
||||
field_serializer,
|
||||
)
|
||||
from pydantic_core import (
|
||||
CoreSchema,
|
||||
PydanticUndefined,
|
||||
PydanticUndefinedType,
|
||||
ValidationError,
|
||||
core_schema,
|
||||
)
|
||||
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.settings import Secrets
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.block import BlockSchema
|
||||
|
||||
T = TypeVar("T")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -120,11 +134,10 @@ def SchemaField(
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
placeholder: Optional[str] = None,
|
||||
advanced: Optional[bool] = None,
|
||||
advanced: Optional[bool] = False,
|
||||
secret: bool = False,
|
||||
exclude: bool = False,
|
||||
hidden: Optional[bool] = None,
|
||||
depends_on: list[str] | None = None,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
json_extra = {
|
||||
@@ -134,7 +147,6 @@ def SchemaField(
|
||||
"secret": secret,
|
||||
"advanced": advanced,
|
||||
"hidden": hidden,
|
||||
"depends_on": depends_on,
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
@@ -148,7 +160,7 @@ def SchemaField(
|
||||
exclude=exclude,
|
||||
json_schema_extra=json_extra,
|
||||
**kwargs,
|
||||
)
|
||||
) # type: ignore
|
||||
|
||||
|
||||
class _BaseCredentials(BaseModel):
|
||||
@@ -222,7 +234,7 @@ class UserIntegrations(BaseModel):
|
||||
oauth_states: list[OAuthState] = Field(default_factory=list)
|
||||
|
||||
|
||||
CP = TypeVar("CP", bound=str)
|
||||
CP = TypeVar("CP", bound=ProviderName)
|
||||
CT = TypeVar("CT", bound=CredentialsType)
|
||||
|
||||
|
||||
@@ -235,19 +247,51 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
provider: CP
|
||||
type: CT
|
||||
|
||||
@staticmethod
|
||||
def _add_json_schema_extra(schema, cls: CredentialsMetaInput):
|
||||
schema["credentials_provider"] = get_args(
|
||||
cls.model_fields["provider"].annotation
|
||||
)
|
||||
schema["credentials_types"] = get_args(cls.model_fields["type"].annotation)
|
||||
|
||||
class CredentialsFieldSchemaExtra(BaseModel, Generic[CP, CT]):
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra=_add_json_schema_extra, # type: ignore
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_credentials_field_schema(cls, model: type["BlockSchema"]):
|
||||
"""Validates the schema of a `credentials` field"""
|
||||
field_schema = model.jsonschema()["properties"][CREDENTIALS_FIELD_NAME]
|
||||
try:
|
||||
schema_extra = _CredentialsFieldSchemaExtra[CP, CT].model_validate(
|
||||
field_schema
|
||||
)
|
||||
except ValidationError as e:
|
||||
if "Field required [type=missing" not in str(e):
|
||||
raise
|
||||
|
||||
raise TypeError(
|
||||
"Field 'credentials' JSON schema lacks required extra items: "
|
||||
f"{field_schema}"
|
||||
) from e
|
||||
|
||||
if (
|
||||
len(schema_extra.credentials_provider) > 1
|
||||
and not schema_extra.discriminator
|
||||
):
|
||||
raise TypeError("Multi-provider CredentialsField requires discriminator!")
|
||||
|
||||
|
||||
class _CredentialsFieldSchemaExtra(BaseModel, Generic[CP, CT]):
|
||||
# TODO: move discrimination mechanism out of CredentialsField (frontend + backend)
|
||||
credentials_provider: list[CP]
|
||||
credentials_scopes: Optional[list[str]]
|
||||
credentials_scopes: Optional[list[str]] = None
|
||||
credentials_types: list[CT]
|
||||
discriminator: Optional[str] = None
|
||||
discriminator_mapping: Optional[dict[str, CP]] = None
|
||||
|
||||
|
||||
def CredentialsField(
|
||||
provider: CP | list[CP],
|
||||
supported_credential_types: set[CT],
|
||||
required_scopes: set[str] = set(),
|
||||
*,
|
||||
discriminator: Optional[str] = None,
|
||||
@@ -255,26 +299,26 @@ def CredentialsField(
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> CredentialsMetaInput[CP, CT]:
|
||||
) -> CredentialsMetaInput:
|
||||
"""
|
||||
`CredentialsField` must and can only be used on fields named `credentials`.
|
||||
This is enforced by the `BlockSchema` base class.
|
||||
"""
|
||||
if not isinstance(provider, str) and len(provider) > 1 and not discriminator:
|
||||
raise TypeError("Multi-provider CredentialsField requires discriminator!")
|
||||
|
||||
field_schema_extra = CredentialsFieldSchemaExtra[CP, CT](
|
||||
credentials_provider=[provider] if isinstance(provider, str) else provider,
|
||||
credentials_scopes=list(required_scopes) or None, # omit if empty
|
||||
credentials_types=list(supported_credential_types),
|
||||
discriminator=discriminator,
|
||||
discriminator_mapping=discriminator_mapping,
|
||||
)
|
||||
field_schema_extra = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"credentials_scopes": list(required_scopes) or None,
|
||||
"discriminator": discriminator,
|
||||
"discriminator_mapping": discriminator_mapping,
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
return Field(
|
||||
title=title,
|
||||
description=description,
|
||||
json_schema_extra=field_schema_extra.model_dump(exclude_none=True),
|
||||
json_schema_extra=field_schema_extra, # validated on BlockSchema init
|
||||
**kwargs,
|
||||
)
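
For context, a minimal sketch of how a block schema would declare such a field under the new signature; the provider, credential type, and description here are illustrative assumptions, not taken from this diff:

    class Input(BlockSchema):
        credentials: CredentialsMetaInput = CredentialsField(
            provider=ProviderName.GITHUB,            # assumed provider
            supported_credential_types={"api_key"},  # assumed credential type
            description="GitHub credentials for this block",  # assumed text
        )

Note the design shift: `field_schema_extra` now only carries the optional keys, while the provider and type lists are injected into the JSON schema by `CredentialsMetaInput._add_json_schema_extra` and checked when the `BlockSchema` is initialized.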

@@ -1,8 +1,9 @@
import asyncio
import json
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, AsyncGenerator, Generator, Generic, TypeVar
from typing import Any, AsyncGenerator, Generator, Generic, Optional, TypeVar

from pydantic import BaseModel
from redis.asyncio.client import PubSub as AsyncPubSub

@@ -48,12 +49,12 @@ class BaseRedisEventBus(Generic[M], ABC):
        except Exception as e:
            logger.error(f"Failed to parse event result from Redis {msg} {e}")

    def _subscribe(
    def _get_pubsub_channel(
        self, connection: redis.Redis | redis.AsyncRedis, channel_key: str
    ) -> tuple[PubSub | AsyncPubSub, str]:
        channel_name = f"{self.event_bus_name}/{channel_key}"
        full_channel_name = f"{self.event_bus_name}/{channel_key}"
        pubsub = connection.pubsub()
        return pubsub, channel_name
        return pubsub, full_channel_name


class RedisEventBus(BaseRedisEventBus[M], ABC):

@@ -64,17 +65,19 @@ class RedisEventBus(BaseRedisEventBus[M], ABC):
        return redis.get_redis()

    def publish_event(self, event: M, channel_key: str):
        message, channel_name = self._serialize_message(event, channel_key)
        self.connection.publish(channel_name, message)
        message, full_channel_name = self._serialize_message(event, channel_key)
        self.connection.publish(full_channel_name, message)

    def listen_events(self, channel_key: str) -> Generator[M, None, None]:
        pubsub, channel_name = self._subscribe(self.connection, channel_key)
        pubsub, full_channel_name = self._get_pubsub_channel(
            self.connection, channel_key
        )
        assert isinstance(pubsub, PubSub)

        if "*" in channel_key:
            pubsub.psubscribe(channel_name)
            pubsub.psubscribe(full_channel_name)
        else:
            pubsub.subscribe(channel_name)
            pubsub.subscribe(full_channel_name)

        for message in pubsub.listen():
            if event := self._deserialize_message(message, channel_key):

@@ -89,19 +92,31 @@ class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
        return await redis.get_redis_async()

    async def publish_event(self, event: M, channel_key: str):
        message, channel_name = self._serialize_message(event, channel_key)
        message, full_channel_name = self._serialize_message(event, channel_key)
        connection = await self.connection
        await connection.publish(channel_name, message)
        await connection.publish(full_channel_name, message)

    async def listen_events(self, channel_key: str) -> AsyncGenerator[M, None]:
        pubsub, channel_name = self._subscribe(await self.connection, channel_key)
        pubsub, full_channel_name = self._get_pubsub_channel(
            await self.connection, channel_key
        )
        assert isinstance(pubsub, AsyncPubSub)

        if "*" in channel_key:
            await pubsub.psubscribe(channel_name)
            await pubsub.psubscribe(full_channel_name)
        else:
            await pubsub.subscribe(channel_name)
            await pubsub.subscribe(full_channel_name)

        async for message in pubsub.listen():
            if event := self._deserialize_message(message, channel_key):
                yield event

    async def wait_for_event(
        self, channel_key: str, timeout: Optional[float] = None
    ) -> M | None:
        try:
            return await asyncio.wait_for(
                anext(aiter(self.listen_events(channel_key))), timeout
            )
        except TimeoutError:
            return None
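
The new `wait_for_event` simply takes the first event from `listen_events` and bounds the wait with `asyncio.wait_for`. A sketch of the call site, with a hypothetical bus subclass and channel key:

    bus = MyEventBus()  # hypothetical AsyncRedisEventBus[MyEvent] subclass
    event = await bus.wait_for_event("graph-123/*", timeout=5.0)
    if event is None:
        ...  # nothing was published on the channel within 5 seconds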

@@ -25,8 +25,8 @@ from backend.data.execution import (
    ExecutionQueue,
    ExecutionResult,
    ExecutionStatus,
    GraphExecution,
    NodeExecution,
    GraphExecutionEntry,
    NodeExecutionEntry,
    merge_execution_input,
    parse_execution_output,
)

@@ -96,13 +96,13 @@ class LogMetadata:


T = TypeVar("T")
ExecutionStream = Generator[NodeExecution, None, None]
ExecutionStream = Generator[NodeExecutionEntry, None, None]


def execute_node(
    db_client: "DatabaseManager",
    creds_manager: IntegrationCredentialsManager,
    data: NodeExecution,
    data: NodeExecutionEntry,
    execution_stats: dict[str, Any] | None = None,
) -> ExecutionStream:
    """

@@ -252,15 +252,15 @@ def _enqueue_next_nodes(
    graph_exec_id: str,
    graph_id: str,
    log_metadata: LogMetadata,
) -> list[NodeExecution]:
) -> list[NodeExecutionEntry]:
    def add_enqueued_execution(
        node_exec_id: str, node_id: str, data: BlockInput
    ) -> NodeExecution:
    ) -> NodeExecutionEntry:
        exec_update = db_client.update_execution_status(
            node_exec_id, ExecutionStatus.QUEUED, data
        )
        db_client.send_execution_update(exec_update)
        return NodeExecution(
        return NodeExecutionEntry(
            user_id=user_id,
            graph_exec_id=graph_exec_id,
            graph_id=graph_id,

@@ -269,7 +269,7 @@ def _enqueue_next_nodes(
            data=data,
        )

    def register_next_executions(node_link: Link) -> list[NodeExecution]:
    def register_next_executions(node_link: Link) -> list[NodeExecutionEntry]:
        enqueued_executions = []
        next_output_name = node_link.source_name
        next_input_name = node_link.sink_name

@@ -501,8 +501,8 @@ class Executor:
    @error_logged
    def on_node_execution(
        cls,
        q: ExecutionQueue[NodeExecution],
        node_exec: NodeExecution,
        q: ExecutionQueue[NodeExecutionEntry],
        node_exec: NodeExecutionEntry,
    ) -> dict[str, Any]:
        log_metadata = LogMetadata(
            user_id=node_exec.user_id,

@@ -529,8 +529,8 @@ class Executor:
    @time_measured
    def _on_node_execution(
        cls,
        q: ExecutionQueue[NodeExecution],
        node_exec: NodeExecution,
        q: ExecutionQueue[NodeExecutionEntry],
        node_exec: NodeExecutionEntry,
        log_metadata: LogMetadata,
        stats: dict[str, Any] | None = None,
    ):

@@ -580,7 +580,9 @@ class Executor:

    @classmethod
    @error_logged
    def on_graph_execution(cls, graph_exec: GraphExecution, cancel: threading.Event):
    def on_graph_execution(
        cls, graph_exec: GraphExecutionEntry, cancel: threading.Event
    ):
        log_metadata = LogMetadata(
            user_id=graph_exec.user_id,
            graph_eid=graph_exec.graph_exec_id,

@@ -605,7 +607,7 @@ class Executor:
    @time_measured
    def _on_graph_execution(
        cls,
        graph_exec: GraphExecution,
        graph_exec: GraphExecutionEntry,
        cancel: threading.Event,
        log_metadata: LogMetadata,
    ) -> tuple[dict[str, Any], Exception | None]:

@@ -636,13 +638,13 @@ class Executor:
        cancel_thread.start()

        try:
            queue = ExecutionQueue[NodeExecution]()
            queue = ExecutionQueue[NodeExecutionEntry]()
            for node_exec in graph_exec.start_node_execs:
                queue.add(node_exec)

            running_executions: dict[str, AsyncResult] = {}

            def make_exec_callback(exec_data: NodeExecution):
            def make_exec_callback(exec_data: NodeExecutionEntry):
                node_id = exec_data.node_id

                def callback(result: object):

@@ -717,7 +719,7 @@ class ExecutionManager(AppService):
        self.use_redis = True
        self.use_supabase = True
        self.pool_size = settings.config.num_graph_workers
        self.queue = ExecutionQueue[GraphExecution]()
        self.queue = ExecutionQueue[GraphExecutionEntry]()
        self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}

    @classmethod

@@ -768,7 +770,7 @@ class ExecutionManager(AppService):
        data: BlockInput,
        user_id: str,
        graph_version: int | None = None,
    ) -> GraphExecution:
    ) -> GraphExecutionEntry:
        graph: GraphModel | None = self.db_client.get_graph(
            graph_id=graph_id, user_id=user_id, version=graph_version
        )

@@ -818,7 +820,7 @@ class ExecutionManager(AppService):
        starting_node_execs = []
        for node_exec in node_execs:
            starting_node_execs.append(
                NodeExecution(
                NodeExecutionEntry(
                    user_id=user_id,
                    graph_exec_id=node_exec.graph_exec_id,
                    graph_id=node_exec.graph_id,

@@ -832,7 +834,7 @@ class ExecutionManager(AppService):
            )
            self.db_client.send_execution_update(exec_update)

        graph_exec = GraphExecution(
        graph_exec = GraphExecutionEntry(
            user_id=user_id,
            graph_id=graph_id,
            graph_exec_id=graph_exec_id,

@@ -1,6 +1,7 @@
import logging
from contextlib import contextmanager
from datetime import datetime
from typing import TYPE_CHECKING

from autogpt_libs.utils.synchronize import RedisKeyedMutex
from redis.lock import Lock as RedisLock

@@ -8,10 +9,13 @@ from redis.lock import Lock as RedisLock
from backend.data import redis
from backend.data.model import Credentials
from backend.integrations.credentials_store import IntegrationCredentialsStore
from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.util.exceptions import MissingConfigError
from backend.util.settings import Settings

if TYPE_CHECKING:
    from backend.integrations.oauth import BaseOAuthHandler

logger = logging.getLogger(__name__)
settings = Settings()

@@ -148,7 +152,7 @@ class IntegrationCredentialsManager:
        self.store.locks.release_all_locks()


def _get_provider_oauth_handler(provider_name: str) -> BaseOAuthHandler:
def _get_provider_oauth_handler(provider_name: str) -> "BaseOAuthHandler":
    if provider_name not in HANDLERS_BY_NAME:
        raise KeyError(f"Unknown provider '{provider_name}'")

@@ -1,10 +1,15 @@
from .base import BaseOAuthHandler
from typing import TYPE_CHECKING

from .github import GitHubOAuthHandler
from .google import GoogleOAuthHandler
from .notion import NotionOAuthHandler

if TYPE_CHECKING:
    from ..providers import ProviderName
    from .base import BaseOAuthHandler

# --8<-- [start:HANDLERS_BY_NAMEExample]
HANDLERS_BY_NAME: dict[str, type[BaseOAuthHandler]] = {
HANDLERS_BY_NAME: dict["ProviderName", type["BaseOAuthHandler"]] = {
    handler.PROVIDER_NAME: handler
    for handler in [
        GitHubOAuthHandler,

@@ -4,13 +4,14 @@ from abc import ABC, abstractmethod
from typing import ClassVar

from backend.data.model import OAuth2Credentials
from backend.integrations.providers import ProviderName

logger = logging.getLogger(__name__)


class BaseOAuthHandler(ABC):
    # --8<-- [start:BaseOAuthHandler1]
    PROVIDER_NAME: ClassVar[str]
    PROVIDER_NAME: ClassVar[ProviderName]
    DEFAULT_SCOPES: ClassVar[list[str]] = []
    # --8<-- [end:BaseOAuthHandler1]

@@ -76,6 +77,8 @@ class BaseOAuthHandler(ABC):
        """Handles the default scopes for the provider"""
        # If scopes are empty, use the default scopes for the provider
        if not scopes:
            logger.debug(f"Using default scopes for provider {self.PROVIDER_NAME}")
            logger.debug(
                f"Using default scopes for provider {self.PROVIDER_NAME.value}"
            )
            scopes = self.DEFAULT_SCOPES
        return scopes

@@ -3,6 +3,7 @@ from typing import Optional
from urllib.parse import urlencode

from backend.data.model import OAuth2Credentials
from backend.integrations.providers import ProviderName
from backend.util.request import requests

from .base import BaseOAuthHandler

@@ -23,7 +24,7 @@ class GitHubOAuthHandler(BaseOAuthHandler):
    access token *with no refresh token*.
    """ # noqa

    PROVIDER_NAME = "github"
    PROVIDER_NAME = ProviderName.GITHUB

    def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
        self.client_id = client_id

@@ -9,6 +9,7 @@ from google_auth_oauthlib.flow import Flow
from pydantic import SecretStr

from backend.data.model import OAuth2Credentials
from backend.integrations.providers import ProviderName

from .base import BaseOAuthHandler

@@ -21,7 +22,7 @@ class GoogleOAuthHandler(BaseOAuthHandler):
    Based on the documentation at https://developers.google.com/identity/protocols/oauth2/web-server
    """ # noqa

    PROVIDER_NAME = "google"
    PROVIDER_NAME = ProviderName.GOOGLE
    EMAIL_ENDPOINT = "https://www.googleapis.com/oauth2/v2/userinfo"
    DEFAULT_SCOPES = [
        "https://www.googleapis.com/auth/userinfo.email",

@@ -2,6 +2,7 @@ from base64 import b64encode
from urllib.parse import urlencode

from backend.data.model import OAuth2Credentials
from backend.integrations.providers import ProviderName
from backend.util.request import requests

from .base import BaseOAuthHandler

@@ -16,7 +17,7 @@ class NotionOAuthHandler(BaseOAuthHandler):
    - Notion doesn't use scopes
    """

    PROVIDER_NAME = "notion"
    PROVIDER_NAME = ProviderName.NOTION

    def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
        self.client_id = client_id

@@ -1,7 +1,30 @@
from enum import Enum


# --8<-- [start:ProviderName]
class ProviderName(str, Enum):
    ANTHROPIC = "anthropic"
    DISCORD = "discord"
    D_ID = "d_id"
    E2B = "e2b"
    EXA = "exa"
    FAL = "fal"
    GITHUB = "github"
    GOOGLE = "google"
    GOOGLE_MAPS = "google_maps"
    GROQ = "groq"
    HUBSPOT = "hubspot"
    IDEOGRAM = "ideogram"
    JINA = "jina"
    MEDIUM = "medium"
    NOTION = "notion"
    OLLAMA = "ollama"
    OPENAI = "openai"
    OPENWEATHERMAP = "openweathermap"
    OPEN_ROUTER = "open_router"
    PINECONE = "pinecone"
    REPLICATE = "replicate"
    REVID = "revid"
    SLANT3D = "slant3d"
    UNREAL_SPEECH = "unreal_speech"
# --8<-- [end:ProviderName]
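
Because `ProviderName` subclasses both `str` and `Enum`, members compare equal to their raw strings, while `.value` yields the plain string used in the URL, log, and settings-key changes throughout this changeset. A quick illustration:

    assert ProviderName.GITHUB == "github"
    assert ProviderName("slant3d") is ProviderName.SLANT3D
    assert f"integrations/{ProviderName.GOOGLE.value}" == "integrations/google"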

@@ -1,15 +1,18 @@
from typing import TYPE_CHECKING

from .github import GithubWebhooksManager
from .slant3d import Slant3DWebhooksManager

if TYPE_CHECKING:
    from ..providers import ProviderName
    from .base import BaseWebhooksManager

# --8<-- [start:WEBHOOK_MANAGERS_BY_NAME]
WEBHOOK_MANAGERS_BY_NAME: dict[str, type["BaseWebhooksManager"]] = {
WEBHOOK_MANAGERS_BY_NAME: dict["ProviderName", type["BaseWebhooksManager"]] = {
    handler.PROVIDER_NAME: handler
    for handler in [
        GithubWebhooksManager,
        Slant3DWebhooksManager,
    ]
}
# --8<-- [end:WEBHOOK_MANAGERS_BY_NAME]
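
Lookups into this registry are now keyed by `ProviderName` and guarded before instantiating, mirroring the pattern `on_node_activate`/`on_node_deactivate` use further down; a minimal sketch (GITHUB chosen only as an example key):

    provider = ProviderName.GITHUB
    if provider not in WEBHOOK_MANAGERS_BY_NAME:
        raise ValueError(f"Provider {provider.value} does not support webhooks")
    webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()  # instantiate the manager class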

@@ -9,6 +9,7 @@ from strenum import StrEnum

from backend.data import integrations
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.util.exceptions import MissingConfigError
from backend.util.settings import Config

@@ -20,7 +21,7 @@ WT = TypeVar("WT", bound=StrEnum)

class BaseWebhooksManager(ABC, Generic[WT]):
    # --8<-- [start:BaseWebhooksManager1]
    PROVIDER_NAME: ClassVar[str]
    PROVIDER_NAME: ClassVar[ProviderName]
    # --8<-- [end:BaseWebhooksManager1]

    WebhookType: WT

@@ -81,7 +82,9 @@ class BaseWebhooksManager(ABC, Generic[WT]):
    # --8<-- [end:BaseWebhooksManager3]

    # --8<-- [start:BaseWebhooksManager5]
    async def trigger_ping(self, webhook: integrations.Webhook) -> None:
    async def trigger_ping(
        self, webhook: integrations.Webhook, credentials: Credentials | None
    ) -> None:
        """
        Triggers a ping to the given webhook.

@@ -141,7 +144,7 @@ class BaseWebhooksManager(ABC, Generic[WT]):
        secret = secrets.token_hex(32)
        provider_name = self.PROVIDER_NAME
        ingress_url = (
            f"{app_config.platform_base_url}/api/integrations/{provider_name}"
            f"{app_config.platform_base_url}/api/integrations/{provider_name.value}"
            f"/webhooks/{id}/ingress"
        )
        provider_webhook_id, config = await self._register_webhook(

@@ -8,6 +8,7 @@ from strenum import StrEnum

from backend.data import integrations
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName

from .base import BaseWebhooksManager

@@ -20,7 +21,7 @@ class GithubWebhookType(StrEnum):


class GithubWebhooksManager(BaseWebhooksManager):
    PROVIDER_NAME = "github"
    PROVIDER_NAME = ProviderName.GITHUB

    WebhookType = GithubWebhookType

@@ -58,10 +59,15 @@ class GithubWebhooksManager(BaseWebhooksManager):

        return payload, event_type

    async def trigger_ping(self, webhook: integrations.Webhook) -> None:
    async def trigger_ping(
        self, webhook: integrations.Webhook, credentials: Credentials | None
    ) -> None:
        if not credentials:
            raise ValueError("Credentials are required but were not passed")

        headers = {
            **self.GITHUB_API_DEFAULT_HEADERS,
            "Authorization": f"Bearer {webhook.config.get('access_token')}",
            "Authorization": credentials.bearer(),
        }

        repo, github_hook_id = webhook.resource, webhook.provider_webhook_id

@@ -95,11 +95,18 @@ async def on_node_activate(
    if not block.webhook_config:
        return node

    provider = block.webhook_config.provider
    if provider not in WEBHOOK_MANAGERS_BY_NAME:
        raise ValueError(
            f"Block #{block.id} has webhook_config for provider {provider} "
            "which does not support webhooks"
        )

    logger.debug(
        f"Activating webhook node #{node.id} with config {block.webhook_config}"
    )

    webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[block.webhook_config.provider]()
    webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()

    try:
        resource = block.webhook_config.resource_format.format(**node.input_default)

@@ -167,7 +174,14 @@ async def on_node_deactivate(
    if not block.webhook_config:
        return node

    webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[block.webhook_config.provider]()
    provider = block.webhook_config.provider
    if provider not in WEBHOOK_MANAGERS_BY_NAME:
        raise ValueError(
            f"Block #{block.id} has webhook_config for provider {provider} "
            "which does not support webhooks"
        )

    webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()

    if node.webhook_id:
        logger.debug(f"Node #{node.id} has webhook_id {node.webhook_id}")

@@ -189,7 +203,7 @@ async def on_node_deactivate(
            logger.warning(
                f"Cannot deregister webhook #{webhook.id}: credentials "
                f"#{webhook.credentials_id} not available "
                f"({webhook.provider} webhook ID: {webhook.provider_webhook_id})"
                f"({webhook.provider.value} webhook ID: {webhook.provider_webhook_id})"
            )
        return updated_node

@@ -0,0 +1,99 @@
import logging

import requests
from fastapi import Request

from backend.data import integrations
from backend.data.model import APIKeyCredentials, Credentials
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks.base import BaseWebhooksManager

logger = logging.getLogger(__name__)


class Slant3DWebhooksManager(BaseWebhooksManager):
    """Manager for Slant3D webhooks"""

    PROVIDER_NAME = ProviderName.SLANT3D
    BASE_URL = "https://www.slant3dapi.com/api"

    async def _register_webhook(
        self,
        credentials: Credentials,
        webhook_type: str,
        resource: str,
        events: list[str],
        ingress_url: str,
        secret: str,
    ) -> tuple[str, dict]:
        """Register a new webhook with Slant3D"""

        if not isinstance(credentials, APIKeyCredentials):
            raise ValueError("API key is required to register a webhook")

        headers = {
            "api-key": credentials.api_key.get_secret_value(),
            "Content-Type": "application/json",
        }

        # Slant3D's API doesn't use events list, just register for all order updates
        payload = {"endPoint": ingress_url}

        response = requests.post(
            f"{self.BASE_URL}/customer/webhookSubscribe", headers=headers, json=payload
        )

        if not response.ok:
            error = response.json().get("error", "Unknown error")
            raise RuntimeError(f"Failed to register webhook: {error}")

        webhook_config = {
            "endpoint": ingress_url,
            "provider": self.PROVIDER_NAME,
            "events": ["order.shipped"], # Currently the only supported event
            "type": webhook_type,
        }

        return "", webhook_config

    @classmethod
    async def validate_payload(
        cls, webhook: integrations.Webhook, request: Request
    ) -> tuple[dict, str]:
        """Validate incoming webhook payload from Slant3D"""

        payload = await request.json()

        # Validate required fields from Slant3D API spec
        required_fields = ["orderId", "status", "trackingNumber", "carrierCode"]
        missing_fields = [field for field in required_fields if field not in payload]

        if missing_fields:
            raise ValueError(f"Missing required fields: {', '.join(missing_fields)}")

        # Normalize payload structure
        normalized_payload = {
            "orderId": payload["orderId"],
            "status": payload["status"],
            "trackingNumber": payload["trackingNumber"],
            "carrierCode": payload["carrierCode"],
        }

        # Currently Slant3D only sends shipping notifications
        # Convert status to lowercase for event format compatibility
        event_type = f"order.{payload['status'].lower()}"

        return normalized_payload, event_type
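
For reference, a hypothetical payload that would pass this validation, and the event type it maps to (all field values here are made up):

    payload = {
        "orderId": "1234567890",                  # made-up values
        "status": "SHIPPED",
        "trackingNumber": "9400100000000000000000",
        "carrierCode": "usps",
    }
    # f"order.{payload['status'].lower()}" -> "order.shipped"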

    async def _deregister_webhook(
        self, webhook: integrations.Webhook, credentials: Credentials
    ) -> None:
        """
        Note: Slant3D API currently doesn't provide a deregistration endpoint.
        This would need to be handled through support.
        """
        # Log warning since we can't properly deregister
        logger.warning(
            f"Warning: Manual deregistration required for webhook {webhook.id}"
        )
        pass

@@ -1,5 +1,5 @@
import logging
from typing import Annotated, Literal
from typing import TYPE_CHECKING, Annotated, Literal

from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
from pydantic import BaseModel, Field, SecretStr

@@ -9,8 +9,8 @@ from backend.data.integrations import (
    WebhookEvent,
    get_all_webhooks,
    get_webhook,
    listen_for_webhook_event,
    publish_webhook_event,
    wait_for_webhook_event,
)
from backend.data.model import (
    APIKeyCredentials,

@@ -20,12 +20,16 @@ from backend.data.model import (
)
from backend.executor.manager import ExecutionManager
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import WEBHOOK_MANAGERS_BY_NAME
from backend.util.exceptions import NeedConfirmation
from backend.util.service import get_service_client
from backend.util.settings import Settings

if TYPE_CHECKING:
    from backend.integrations.oauth import BaseOAuthHandler

from ..utils import get_user_id

logger = logging.getLogger(__name__)

@@ -42,7 +46,9 @@ class LoginResponse(BaseModel):

@router.get("/{provider}/login")
def login(
    provider: Annotated[str, Path(title="The provider to initiate an OAuth flow for")],
    provider: Annotated[
        ProviderName, Path(title="The provider to initiate an OAuth flow for")
    ],
    user_id: Annotated[str, Depends(get_user_id)],
    request: Request,
    scopes: Annotated[

@@ -74,7 +80,9 @@ class CredentialsMetaResponse(BaseModel):

@router.post("/{provider}/callback")
def callback(
    provider: Annotated[str, Path(title="The target provider for this OAuth exchange")],
    provider: Annotated[
        ProviderName, Path(title="The target provider for this OAuth exchange")
    ],
    code: Annotated[str, Body(title="Authorization code acquired by user login")],
    state_token: Annotated[str, Body(title="Anti-CSRF nonce")],
    user_id: Annotated[str, Depends(get_user_id)],

@@ -103,11 +111,12 @@ def callback(
        if not set(scopes).issubset(set(credentials.scopes)):
            # For now, we'll just log the warning and continue
            logger.warning(
                f"Granted scopes {credentials.scopes} for {provider}do not include all requested scopes {scopes}"
                f"Granted scopes {credentials.scopes} for provider {provider.value} "
                f"do not include all requested scopes {scopes}"
            )

    except Exception as e:
        logger.error(f"Code->Token exchange failed for provider {provider}: {e}")
        logger.error(f"Code->Token exchange failed for provider {provider.value}: {e}")
        raise HTTPException(
            status_code=400, detail=f"Failed to exchange code for tokens: {str(e)}"
        )

@@ -116,7 +125,8 @@ def callback(
    creds_manager.create(user_id, credentials)

    logger.debug(
        f"Successfully processed OAuth callback for user {user_id} and provider {provider}"
        f"Successfully processed OAuth callback for user {user_id} "
        f"and provider {provider.value}"
    )
    return CredentialsMetaResponse(
        id=credentials.id,

@@ -148,7 +158,9 @@ def list_credentials(

@router.get("/{provider}/credentials")
def list_credentials_by_provider(
    provider: Annotated[str, Path(title="The provider to list credentials for")],
    provider: Annotated[
        ProviderName, Path(title="The provider to list credentials for")
    ],
    user_id: Annotated[str, Depends(get_user_id)],
) -> list[CredentialsMetaResponse]:
    credentials = creds_manager.store.get_creds_by_provider(user_id, provider)

@@ -167,7 +179,9 @@ def list_credentials_by_provider(

@router.get("/{provider}/credentials/{cred_id}")
def get_credential(
    provider: Annotated[str, Path(title="The provider to retrieve credentials for")],
    provider: Annotated[
        ProviderName, Path(title="The provider to retrieve credentials for")
    ],
    cred_id: Annotated[str, Path(title="The ID of the credentials to retrieve")],
    user_id: Annotated[str, Depends(get_user_id)],
) -> Credentials:

@@ -184,7 +198,9 @@ def get_credential(
@router.post("/{provider}/credentials", status_code=201)
def create_api_key_credentials(
    user_id: Annotated[str, Depends(get_user_id)],
    provider: Annotated[str, Path(title="The provider to create credentials for")],
    provider: Annotated[
        ProviderName, Path(title="The provider to create credentials for")
    ],
    api_key: Annotated[str, Body(title="The API key to store")],
    title: Annotated[str, Body(title="Optional title for the credentials")],
    expires_at: Annotated[

@@ -225,7 +241,9 @@ class CredentialsDeletionNeedsConfirmationResponse(BaseModel):
@router.delete("/{provider}/credentials/{cred_id}")
async def delete_credentials(
    request: Request,
    provider: Annotated[str, Path(title="The provider to delete credentials for")],
    provider: Annotated[
        ProviderName, Path(title="The provider to delete credentials for")
    ],
    cred_id: Annotated[str, Path(title="The ID of the credentials to delete")],
    user_id: Annotated[str, Depends(get_user_id)],
    force: Annotated[

@@ -264,15 +282,20 @@ async def delete_credentials(
@router.post("/{provider}/webhooks/{webhook_id}/ingress")
async def webhook_ingress_generic(
    request: Request,
    provider: Annotated[str, Path(title="Provider where the webhook was registered")],
    provider: Annotated[
        ProviderName, Path(title="Provider where the webhook was registered")
    ],
    webhook_id: Annotated[str, Path(title="Our ID for the webhook")],
):
    logger.debug(f"Received {provider} webhook ingress for ID {webhook_id}")
    logger.debug(f"Received {provider.value} webhook ingress for ID {webhook_id}")
    webhook_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()
    webhook = await get_webhook(webhook_id)
    logger.debug(f"Webhook #{webhook_id}: {webhook}")
    payload, event_type = await webhook_manager.validate_payload(webhook, request)
    logger.debug(f"Validated {provider} {event_type} event with payload {payload}")
    logger.debug(
        f"Validated {provider.value} {webhook.webhook_type} {event_type} event "
        f"with payload {payload}"
    )

    webhook_event = WebhookEvent(
        provider=provider,

@@ -300,18 +323,28 @@ async def webhook_ingress_generic(
    )


@router.post("/{provider}/webhooks/{webhook_id}/ping")
@router.post("/webhooks/{webhook_id}/ping")
async def webhook_ping(
    provider: Annotated[str, Path(title="Provider where the webhook was registered")],
    webhook_id: Annotated[str, Path(title="Our ID for the webhook")],
    user_id: Annotated[str, Depends(get_user_id)], # require auth
):
    webhook_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()
    webhook = await get_webhook(webhook_id)
    webhook_manager = WEBHOOK_MANAGERS_BY_NAME[webhook.provider]()

    await webhook_manager.trigger_ping(webhook)
    if not await listen_for_webhook_event(webhook_id, event_type="ping"):
        raise HTTPException(status_code=500, detail="Webhook ping event not received")
    credentials = (
        creds_manager.get(user_id, webhook.credentials_id)
        if webhook.credentials_id
        else None
    )
    try:
        await webhook_manager.trigger_ping(webhook, credentials)
    except NotImplementedError:
        return False

    if not await wait_for_webhook_event(webhook_id, event_type="ping", timeout=10):
        raise HTTPException(status_code=504, detail="Webhook ping timed out")

    return True


# --------------------------- UTILITIES ---------------------------- #

@@ -331,6 +364,14 @@ async def remove_all_webhooks_for_credentials(
        NeedConfirmation: If any of the webhooks are still in use and `force` is `False`
    """
    webhooks = await get_all_webhooks(credentials.id)
    if credentials.provider not in WEBHOOK_MANAGERS_BY_NAME:
        if webhooks:
            logger.error(
                f"Credentials #{credentials.id} for provider {credentials.provider} "
                f"are attached to {len(webhooks)} webhooks, "
                f"but there is no available WebhooksHandler for {credentials.provider}"
            )
        return
    if any(w.attached_nodes for w in webhooks) and not force:
        raise NeedConfirmation(
            "Some webhooks linked to these credentials are still in use by an agent"

@@ -349,18 +390,23 @@ async def remove_all_webhooks_for_credentials(
        logger.warning(f"Webhook #{webhook.id} failed to prune")


def _get_provider_oauth_handler(req: Request, provider_name: str) -> BaseOAuthHandler:
def _get_provider_oauth_handler(
    req: Request, provider_name: ProviderName
) -> "BaseOAuthHandler":
    if provider_name not in HANDLERS_BY_NAME:
        raise HTTPException(
            status_code=404, detail=f"Unknown provider '{provider_name}'"
            status_code=404,
            detail=f"Provider '{provider_name.value}' does not support OAuth",
        )

    client_id = getattr(settings.secrets, f"{provider_name}_client_id")
    client_secret = getattr(settings.secrets, f"{provider_name}_client_secret")
    client_id = getattr(settings.secrets, f"{provider_name.value}_client_id")
    client_secret = getattr(settings.secrets, f"{provider_name.value}_client_secret")
    if not (client_id and client_secret):
        raise HTTPException(
            status_code=501,
            detail=f"Integration with provider '{provider_name}' is not configured",
            detail=(
                f"Integration with provider '{provider_name.value}' is not configured"
            ),
        )

    handler_class = HANDLERS_BY_NAME[provider_name]
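
The `getattr` calls above rely on a `<provider>_client_id` / `<provider>_client_secret` naming convention on `settings.secrets`; for example, for `ProviderName.GITHUB` (a sketch, the actual values come from the environment):

    client_id = getattr(settings.secrets, "github_client_id")
    client_secret = getattr(settings.secrets, "github_client_secret")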

@@ -16,6 +16,7 @@ import backend.data.db
import backend.data.graph
import backend.data.user
import backend.server.routers.v1
import backend.server.v2.store.routes
import backend.util.service
import backend.util.settings

@@ -25,15 +26,26 @@ logger = logging.getLogger(__name__)
logging.getLogger("autogpt_libs").setLevel(logging.INFO)


@contextlib.contextmanager
def launch_darkly_context():
    if settings.config.app_env != backend.util.settings.AppEnvironment.LOCAL:
        initialize_launchdarkly()
        try:
            yield
        finally:
            shutdown_launchdarkly()
    else:
        yield


@contextlib.asynccontextmanager
async def lifespan_context(app: fastapi.FastAPI):
    await backend.data.db.connect()
    await backend.data.block.initialize_blocks()
    await backend.data.user.migrate_and_encrypt_user_integrations()
    await backend.data.graph.fix_llm_provider_credentials()
    initialize_launchdarkly()
    yield
    shutdown_launchdarkly()
    with launch_darkly_context():
        yield
    await backend.data.db.disconnect()

@@ -73,7 +85,10 @@ def handle_internal_http_error(status_code: int = 500, log_error: bool = True):

app.add_exception_handler(ValueError, handle_internal_http_error(400))
app.add_exception_handler(Exception, handle_internal_http_error(500))
app.include_router(backend.server.routers.v1.v1_router, tags=["v1"])
app.include_router(backend.server.routers.v1.v1_router, tags=["v1"], prefix="/api")
app.include_router(
    backend.server.v2.store.routes.router, tags=["v2"], prefix="/api/store"
)


@app.get(path="/health", tags=["health"], dependencies=[])

@@ -106,17 +121,17 @@ class AgentServer(backend.util.service.AppProcess):
    async def test_create_graph(
        create_graph: backend.server.routers.v1.CreateGraph,
        user_id: str,
        is_template=False,
    ):
        return await backend.server.routers.v1.create_new_graph(create_graph, user_id)

    @staticmethod
    async def test_get_graph_run_status(
        graph_id: str, graph_exec_id: str, user_id: str
    ):
        return await backend.server.routers.v1.get_graph_run_status(
            graph_id, graph_exec_id, user_id
    async def test_get_graph_run_status(graph_exec_id: str, user_id: str):
        execution = await backend.data.graph.get_execution(
            user_id=user_id, execution_id=graph_exec_id
        )
        if not execution:
            raise ValueError(f"Execution {graph_exec_id} not found")
        return execution.status

    @staticmethod
    async def test_get_graph_run_node_execution_results(

@@ -69,8 +69,7 @@ integration_creds_manager = IntegrationCredentialsManager()
_user_credit_model = get_user_credit_model()

# Define the API routes
v1_router = APIRouter(prefix="/api")

v1_router = APIRouter()

v1_router.include_router(
    backend.server.integrations.router.router,

@@ -132,7 +131,7 @@ def execute_graph_block(block_id: str, data: BlockInput) -> CompletedBlockOutput

@v1_router.get(path="/credits", dependencies=[Depends(auth_middleware)])
async def get_user_credits(
    user_id: Annotated[str, Depends(get_user_id)]
    user_id: Annotated[str, Depends(get_user_id)],
) -> dict[str, int]:
    # Credits can go negative, so ensure it's at least 0 for user to see.
    return {"credits": max(await _user_credit_model.get_or_refill_credit(user_id), 0)}

@@ -149,12 +148,9 @@ class DeleteGraphResponse(TypedDict):

@v1_router.get(path="/graphs", tags=["graphs"], dependencies=[Depends(auth_middleware)])
async def get_graphs(
    user_id: Annotated[str, Depends(get_user_id)],
    with_runs: bool = False,
    user_id: Annotated[str, Depends(get_user_id)]
) -> Sequence[graph_db.Graph]:
    return await graph_db.get_graphs(
        include_executions=with_runs, filter_by="active", user_id=user_id
    )
    return await graph_db.get_graphs(filter_by="active", user_id=user_id)


@v1_router.get(

@@ -252,6 +248,13 @@ async def do_create_graph(
async def delete_graph(
    graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> DeleteGraphResponse:
    if active_version := await graph_db.get_graph(graph_id, user_id=user_id):

        def get_credentials(credentials_id: str) -> "Credentials | None":
            return integration_creds_manager.get(user_id, credentials_id)

        await on_graph_deactivate(active_version, get_credentials)

    return {"version_counts": await graph_db.delete_graph(graph_id, user_id=user_id)}


@@ -386,7 +389,7 @@ def execute_graph(
async def stop_graph_run(
    graph_exec_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> Sequence[execution_db.ExecutionResult]:
    if not await execution_db.get_graph_execution(graph_exec_id, user_id):
    if not await graph_db.get_execution(user_id=user_id, execution_id=graph_exec_id):
        raise HTTPException(404, detail=f"Agent execution #{graph_exec_id} not found")

    await asyncio.to_thread(

@@ -398,23 +401,14 @@ async def stop_graph_run(


@v1_router.get(
    path="/graphs/{graph_id}/executions",
    path="/executions",
    tags=["graphs"],
    dependencies=[Depends(auth_middleware)],
)
async def list_graph_runs(
    graph_id: str,
async def get_executions(
    user_id: Annotated[str, Depends(get_user_id)],
    graph_version: int | None = None,
) -> Sequence[str]:
    graph = await graph_db.get_graph(graph_id, graph_version, user_id=user_id)
    if not graph:
        rev = "" if graph_version is None else f" v{graph_version}"
        raise HTTPException(
            status_code=404, detail=f"Agent #{graph_id}{rev} not found."
        )

    return await execution_db.list_executions(graph_id, graph_version)
) -> list[graph_db.GraphExecution]:
    return await graph_db.get_executions(user_id=user_id)


@v1_router.get(

@@ -434,25 +428,6 @@ async def get_graph_run_node_execution_results(
    return await execution_db.get_execution_results(graph_exec_id)


# NOTE: This is used for testing
async def get_graph_run_status(
    graph_id: str,
    graph_exec_id: str,
    user_id: Annotated[str, Depends(get_user_id)],
) -> execution_db.ExecutionStatus:
    graph = await graph_db.get_graph(graph_id, user_id=user_id)
    if not graph:
        raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")

    execution = await execution_db.get_graph_execution(graph_exec_id, user_id)
    if not execution:
        raise HTTPException(
            status_code=404, detail=f"Execution #{graph_exec_id} not found."
        )

    return execution.executionStatus


########################################################
##################### Templates ########################
########################################################
autogpt_platform/backend/backend/server/v2/store/db.py (new file, 709 lines)
@@ -0,0 +1,709 @@
import logging
from datetime import datetime
import random

import prisma.enums
import prisma.errors
import prisma.models
import prisma.types

import backend.server.v2.store.exceptions
import backend.server.v2.store.model

logger = logging.getLogger(__name__)


async def get_store_agents(
    featured: bool = False,
    creator: str | None = None,
    sorted_by: str | None = None,
    search_query: str | None = None,
    category: str | None = None,
    page: int = 1,
    page_size: int = 20,
) -> backend.server.v2.store.model.StoreAgentsResponse:
    logger.debug(
        f"Getting store agents. featured={featured}, creator={creator}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}"
    )

    where_clause = {}
    if featured:
        where_clause["featured"] = featured
    if creator:
        where_clause["creator_username"] = creator
    if category:
        where_clause["categories"] = {"has": category}
    if search_query:
        where_clause["OR"] = [
            {"agent_name": {"contains": search_query, "mode": "insensitive"}},
            {"description": {"contains": search_query, "mode": "insensitive"}},
        ]

    order_by = []
    if sorted_by == "rating":
        order_by.append({"rating": "desc"})
    elif sorted_by == "runs":
        order_by.append({"runs": "desc"})
    elif sorted_by == "name":
        order_by.append({"agent_name": "asc"})

    try:
        agents = await prisma.models.StoreAgent.prisma().find_many(
            where=prisma.types.StoreAgentWhereInput(**where_clause),
            order=order_by,
            skip=(page - 1) * page_size,
            take=page_size,
        )

        total = await prisma.models.StoreAgent.prisma().count(
            where=prisma.types.StoreAgentWhereInput(**where_clause)
        )
        total_pages = (total + page_size - 1) // page_size

        store_agents = [
            backend.server.v2.store.model.StoreAgent(
                slug=agent.slug,
                agent_name=agent.agent_name,
                agent_image=agent.agent_image[0] if agent.agent_image else "",
                creator=agent.creator_username,
                creator_avatar=agent.creator_avatar,
                sub_heading=agent.sub_heading,
                description=agent.description,
                runs=agent.runs,
                rating=agent.rating,
            )
            for agent in agents
        ]

        logger.debug(f"Found {len(store_agents)} agents")
        return backend.server.v2.store.model.StoreAgentsResponse(
            agents=store_agents,
            pagination=backend.server.v2.store.model.Pagination(
                current_page=page,
                total_items=total,
                total_pages=total_pages,
                page_size=page_size,
            ),
        )
    except Exception as e:
        logger.error(f"Error getting store agents: {str(e)}")
        raise backend.server.v2.store.exceptions.DatabaseError(
            "Failed to fetch store agents"
        ) from e
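
The `total_pages` expression is integer ceiling division, used the same way by the other paginated queries in this file; a worked example:

    total, page_size = 45, 20
    total_pages = (total + page_size - 1) // page_size  # (45 + 19) // 20 == 3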
|
||||
|
||||
|
||||
async def get_store_agent_details(
|
||||
username: str, agent_name: str
|
||||
) -> backend.server.v2.store.model.StoreAgentDetails:
|
||||
logger.debug(f"Getting store agent details for {username}/{agent_name}")
|
||||
|
||||
try:
|
||||
agent = await prisma.models.StoreAgent.prisma().find_first(
|
||||
where={"creator_username": username, "slug": agent_name}
|
||||
)
|
||||
|
||||
if not agent:
|
||||
logger.warning(f"Agent not found: {username}/{agent_name}")
|
||||
raise backend.server.v2.store.exceptions.AgentNotFoundError(
|
||||
f"Agent {username}/{agent_name} not found"
|
||||
)
|
||||
|
||||
logger.debug(f"Found agent details for {username}/{agent_name}")
|
||||
return backend.server.v2.store.model.StoreAgentDetails(
|
||||
store_listing_version_id=agent.storeListingVersionId,
|
||||
slug=agent.slug,
|
||||
agent_name=agent.agent_name,
|
||||
agent_video=agent.agent_video or "",
|
||||
agent_image=agent.agent_image,
|
||||
creator=agent.creator_username,
|
||||
creator_avatar=agent.creator_avatar,
|
||||
sub_heading=agent.sub_heading,
|
||||
description=agent.description,
|
||||
categories=agent.categories,
|
||||
runs=agent.runs,
|
||||
rating=agent.rating,
|
||||
versions=agent.versions,
|
||||
last_updated=agent.updated_at,
|
||||
)
|
||||
except backend.server.v2.store.exceptions.AgentNotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting store agent details: {str(e)}")
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to fetch agent details"
|
||||
) from e
|
||||
|
||||
|
||||
async def get_store_creators(
|
||||
featured: bool = False,
|
||||
search_query: str | None = None,
|
||||
sorted_by: str | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> backend.server.v2.store.model.CreatorsResponse:
|
||||
logger.debug(
|
||||
f"Getting store creators. featured={featured}, search={search_query}, sorted_by={sorted_by}, page={page}"
|
||||
)
|
||||
|
||||
# Build where clause
|
||||
where = {}
|
||||
|
||||
# Add search filter if provided
|
||||
if search_query:
|
||||
where["OR"] = [
|
||||
{"username": {"contains": search_query, "mode": "insensitive"}},
|
||||
{"name": {"contains": search_query, "mode": "insensitive"}},
|
||||
{"description": {"contains": search_query, "mode": "insensitive"}},
|
||||
]
|
||||
|
||||
try:
|
||||
# Get total count for pagination
|
||||
total = await prisma.models.Creator.prisma().count(
|
||||
where=prisma.types.CreatorWhereInput(**where)
|
||||
)
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
# Add pagination
|
||||
skip = (page - 1) * page_size
|
||||
take = page_size
|
||||
|
||||
# Add sorting
|
||||
order = []
|
||||
if sorted_by == "agent_rating":
|
||||
order.append({"agent_rating": "desc"})
|
||||
elif sorted_by == "agent_runs":
|
||||
order.append({"agent_runs": "desc"})
|
||||
elif sorted_by == "num_agents":
|
||||
order.append({"num_agents": "desc"})
|
||||
else:
|
||||
order.append({"username": "asc"})
|
||||
|
||||
# Execute query
|
||||
creators = await prisma.models.Creator.prisma().find_many(
|
||||
where=prisma.types.CreatorWhereInput(**where),
|
||||
skip=skip,
|
||||
take=take,
|
||||
order=order,
|
||||
)
|
||||
|
||||
# Convert to response model
|
||||
creator_models = [
|
||||
backend.server.v2.store.model.Creator(
|
||||
username=creator.username,
|
||||
name=creator.name,
|
||||
description=creator.description,
|
||||
avatar_url=creator.avatar_url,
|
||||
num_agents=creator.num_agents,
|
||||
agent_rating=creator.agent_rating,
|
||||
agent_runs=creator.agent_runs,
|
||||
)
|
||||
for creator in creators
|
||||
]
|
||||
|
||||
logger.debug(f"Found {len(creator_models)} creators")
|
||||
return backend.server.v2.store.model.CreatorsResponse(
|
||||
creators=creator_models,
|
||||
pagination=backend.server.v2.store.model.Pagination(
|
||||
current_page=page,
|
||||
total_items=total,
|
||||
total_pages=total_pages,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting store creators: {str(e)}")
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to fetch store creators"
|
||||
) from e
|
||||
|
||||
|
||||
async def get_store_creator_details(
|
||||
username: str,
|
||||
) -> backend.server.v2.store.model.CreatorDetails:
|
||||
logger.debug(f"Getting store creator details for {username}")
|
||||
|
||||
try:
|
||||
# Query creator details from database
|
||||
creator = await prisma.models.Creator.prisma().find_unique(
|
||||
where={"username": username}
|
||||
)
|
||||
|
||||
if not creator:
|
||||
logger.warning(f"Creator not found: {username}")
|
||||
raise backend.server.v2.store.exceptions.CreatorNotFoundError(
|
||||
f"Creator {username} not found"
|
||||
)
|
||||
|
||||
logger.debug(f"Found creator details for {username}")
|
||||
return backend.server.v2.store.model.CreatorDetails(
|
||||
name=creator.name,
|
||||
username=creator.username,
|
||||
description=creator.description,
|
||||
links=creator.links,
|
||||
avatar_url=creator.avatar_url,
|
||||
agent_rating=creator.agent_rating,
|
||||
agent_runs=creator.agent_runs,
|
||||
top_categories=creator.top_categories,
|
||||
)
|
||||
except backend.server.v2.store.exceptions.CreatorNotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting store creator details: {str(e)}")
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to fetch creator details"
|
||||
) from e
|
||||
|
||||
|
||||
async def get_store_submissions(
|
||||
user_id: str, page: int = 1, page_size: int = 20
|
||||
) -> backend.server.v2.store.model.StoreSubmissionsResponse:
|
||||
logger.debug(f"Getting store submissions for user {user_id}, page={page}")
|
||||
|
||||
try:
|
||||
# Calculate pagination values
|
||||
skip = (page - 1) * page_size
|
||||
|
||||
where = prisma.types.StoreSubmissionWhereInput(user_id=user_id)
|
||||
# Query submissions from database
|
||||
submissions = await prisma.models.StoreSubmission.prisma().find_many(
|
||||
where=where, skip=skip, take=page_size, order=[{"date_submitted": "desc"}]
|
||||
)
|
||||
|
||||
# Get total count for pagination
|
||||
total = await prisma.models.StoreSubmission.prisma().count(where=where)
|
||||
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
# Convert to response models
|
||||
submission_models = [
|
||||
backend.server.v2.store.model.StoreSubmission(
|
||||
agent_id=sub.agent_id,
|
||||
agent_version=sub.agent_version,
|
||||
name=sub.name,
|
||||
sub_heading=sub.sub_heading,
|
||||
slug=sub.slug,
|
||||
description=sub.description,
|
||||
image_urls=sub.image_urls or [],
|
||||
date_submitted=sub.date_submitted or datetime.now(),
|
||||
status=sub.status,
|
||||
runs=sub.runs or 0,
|
||||
rating=sub.rating or 0.0,
|
||||
)
|
||||
for sub in submissions
|
||||
]
|
||||
|
||||
logger.debug(f"Found {len(submission_models)} submissions")
|
||||
return backend.server.v2.store.model.StoreSubmissionsResponse(
|
||||
submissions=submission_models,
|
||||
pagination=backend.server.v2.store.model.Pagination(
|
||||
current_page=page,
|
||||
total_items=total,
|
||||
total_pages=total_pages,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching store submissions: {str(e)}")
|
||||
# Return empty response rather than exposing internal errors
|
||||
return backend.server.v2.store.model.StoreSubmissionsResponse(
|
||||
submissions=[],
|
||||
pagination=backend.server.v2.store.model.Pagination(
|
||||
current_page=page,
|
||||
total_items=0,
|
||||
total_pages=0,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def delete_store_submission(
|
||||
user_id: str,
|
||||
submission_id: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Delete a store listing submission.
|
||||
|
||||
Args:
|
||||
user_id: ID of the authenticated user
|
||||
submission_id: ID of the submission to be deleted
|
||||
|
||||
Returns:
|
||||
bool: True if the submission was successfully deleted, False otherwise
|
||||
"""
|
||||
logger.debug(f"Deleting store submission {submission_id} for user {user_id}")
|
||||
|
||||
try:
|
||||
# Verify the submission belongs to this user
|
||||
submission = await prisma.models.StoreListing.prisma().find_first(
|
||||
where={"agentId": submission_id, "owningUserId": user_id}
|
||||
)
|
||||
|
||||
if not submission:
|
||||
logger.warning(f"Submission not found for user {user_id}: {submission_id}")
|
||||
raise backend.server.v2.store.exceptions.SubmissionNotFoundError(
|
||||
f"Submission not found for this user. User ID: {user_id}, Submission ID: {submission_id}"
|
||||
)
|
||||
|
||||
# Delete the submission
|
||||
await prisma.models.StoreListing.prisma().delete(
|
||||
where=prisma.types.StoreListingWhereUniqueInput(id=submission.id)
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Successfully deleted submission {submission_id} for user {user_id}"
|
||||
)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting store submission: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def create_store_submission(
    user_id: str,
    agent_id: str,
    agent_version: int,
    slug: str,
    name: str,
    video_url: str | None = None,
    image_urls: list[str] = [],
    description: str = "",
    sub_heading: str = "",
    categories: list[str] = [],
) -> backend.server.v2.store.model.StoreSubmission:
    """
    Create a new store listing submission.

    Args:
        user_id: ID of the authenticated user submitting the listing
        agent_id: ID of the agent being submitted
        agent_version: Version of the agent being submitted
        slug: URL slug for the listing
        name: Name of the agent
        video_url: Optional URL to video demo
        image_urls: List of image URLs for the listing
        description: Description of the agent
        sub_heading: Sub-heading shown under the agent name
        categories: List of categories for the agent

    Returns:
        StoreSubmission: The created store submission
    """
    logger.debug(
        f"Creating store submission for user {user_id}, agent {agent_id} v{agent_version}"
    )

    try:
        # First verify the agent belongs to this user
        agent = await prisma.models.AgentGraph.prisma().find_first(
            where=prisma.types.AgentGraphWhereInput(
                id=agent_id, version=agent_version, userId=user_id
            )
        )

        if not agent:
            logger.warning(
                f"Agent not found for user {user_id}: {agent_id} v{agent_version}"
            )
            raise backend.server.v2.store.exceptions.AgentNotFoundError(
                f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
            )

        listing = await prisma.models.StoreListing.prisma().find_first(
            where=prisma.types.StoreListingWhereInput(
                agentId=agent_id, owningUserId=user_id
            )
        )
        if listing is not None:
            logger.warning(f"Listing already exists for agent {agent_id}")
            raise backend.server.v2.store.exceptions.ListingExistsError(
                "Listing already exists for this agent"
            )

        # Create the store listing together with its first version
        listing = await prisma.models.StoreListing.prisma().create(
            data={
                "agentId": agent_id,
                "agentVersion": agent_version,
                "owningUserId": user_id,
                "createdAt": datetime.now(),
                "StoreListingVersions": {
                    "create": {
                        "agentId": agent_id,
                        "agentVersion": agent_version,
                        "slug": slug,
                        "name": name,
                        "videoUrl": video_url,
                        "imageUrls": image_urls,
                        "description": description,
                        "categories": categories,
                        "subHeading": sub_heading,
                    }
                },
            }
        )

        logger.debug(f"Created store listing for agent {agent_id}")
        # Return submission details
        return backend.server.v2.store.model.StoreSubmission(
            agent_id=agent_id,
            agent_version=agent_version,
            name=name,
            slug=slug,
            sub_heading=sub_heading,
            description=description,
            image_urls=image_urls,
            date_submitted=listing.createdAt,
            status=prisma.enums.SubmissionStatus.PENDING,
            runs=0,
            rating=0.0,
        )

    except (
        backend.server.v2.store.exceptions.AgentNotFoundError,
        backend.server.v2.store.exceptions.ListingExistsError,
    ):
        raise
    except prisma.errors.PrismaError as e:
        logger.error(f"Database error creating store submission: {str(e)}")
        raise backend.server.v2.store.exceptions.DatabaseError(
            "Failed to create store submission"
        ) from e


async def create_store_review(
    user_id: str,
    store_listing_version_id: str,
    score: int,
    comments: str | None = None,
) -> backend.server.v2.store.model.StoreReview:
    try:
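        # Upsert keyed on the (storeListingVersionId, reviewByUserId) composite
        # unique key, so a second review by the same user updates in place.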
        review = await prisma.models.StoreListingReview.prisma().upsert(
            where={
                "storeListingVersionId_reviewByUserId": {
                    "storeListingVersionId": store_listing_version_id,
                    "reviewByUserId": user_id,
                }
            },
            data={
                "create": {
                    "reviewByUserId": user_id,
                    "storeListingVersionId": store_listing_version_id,
                    "score": score,
                    "comments": comments,
                },
                "update": {
                    "score": score,
                    "comments": comments,
                },
            },
        )

        return backend.server.v2.store.model.StoreReview(
            score=review.score,
            comments=review.comments,
        )

    except prisma.errors.PrismaError as e:
        logger.error(f"Database error creating store review: {str(e)}")
        raise backend.server.v2.store.exceptions.DatabaseError(
            "Failed to create store review"
        ) from e


async def get_user_profile(
    user_id: str,
) -> backend.server.v2.store.model.ProfileDetails:
    logger.debug(f"Getting user profile for {user_id}")

    try:
        profile = await prisma.models.Profile.prisma().find_first(
            where={"userId": user_id}  # type: ignore
        )

        if not profile:
            logger.warning(f"Profile not found for user {user_id}")
            await prisma.models.Profile.prisma().create(
                data=prisma.types.ProfileCreateInput(
                    userId=user_id,
                    name="No Profile Data",
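                    # Random placeholder username of the form
                    # "adjective-animal_NNNN", e.g. "happy-fox_1234".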
username=f"{random.choice(['happy', 'clever', 'swift', 'bright', 'wise'])}-{random.choice(['fox', 'wolf', 'bear', 'eagle', 'owl'])}_{random.randint(1000,9999)}",
|
||||
description="No Profile Data",
|
||||
links=[],
|
||||
avatarUrl="",
|
||||
)
|
||||
)
|
||||
return backend.server.v2.store.model.ProfileDetails(
|
||||
name="No Profile Data",
|
||||
username="No Profile Data",
|
||||
description="No Profile Data",
|
||||
links=[],
|
||||
avatar_url="",
|
||||
)
|
||||
|
||||
return backend.server.v2.store.model.ProfileDetails(
|
||||
name=profile.name,
|
||||
username=profile.username,
|
||||
description=profile.description,
|
||||
links=profile.links,
|
||||
avatar_url=profile.avatarUrl,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user profile: {str(e)}")
|
||||
return backend.server.v2.store.model.ProfileDetails(
|
||||
name="No Profile Data",
|
||||
username="No Profile Data",
|
||||
description="No Profile Data",
|
||||
links=[],
|
||||
avatar_url="",
|
||||
)
|
||||
|
||||
|
||||
async def update_or_create_profile(
    user_id: str, profile: backend.server.v2.store.model.Profile
) -> backend.server.v2.store.model.CreatorDetails:
    """
    Update the store profile for a user. Creates a new profile if one doesn't exist.
    Only allows updating if the user_id matches the owning user.

    Args:
        user_id: ID of the authenticated user
        profile: Updated profile details

    Returns:
        CreatorDetails: The updated profile

    Raises:
        DatabaseError: If the profile cannot be created or updated
    """
    logger.debug(f"Updating profile for user {user_id}")

    try:
        # Check if a profile exists for this user
        existing_profile = await prisma.models.Profile.prisma().find_first(
            where={"userId": user_id}
        )

        # If no profile exists, create a new one
        if not existing_profile:
            logger.debug(f"Creating new profile for user {user_id}")
            new_profile = await prisma.models.Profile.prisma().create(
                data={
                    "userId": user_id,
                    "name": profile.name,
                    "username": profile.username,
                    "description": profile.description,
                    "links": profile.links,
                    "avatarUrl": profile.avatar_url,
                }
            )

            return backend.server.v2.store.model.CreatorDetails(
                name=new_profile.name,
                username=new_profile.username,
                description=new_profile.description,
                links=new_profile.links,
                avatar_url=new_profile.avatarUrl or "",
                agent_rating=0.0,
                agent_runs=0,
                top_categories=[],
            )
        else:
            logger.debug(f"Updating existing profile for user {user_id}")
            # Update the existing profile
            updated_profile = await prisma.models.Profile.prisma().update(
                where={"id": existing_profile.id},
                data=prisma.types.ProfileUpdateInput(
                    name=profile.name,
                    username=profile.username,
                    description=profile.description,
                    links=profile.links,
                    avatarUrl=profile.avatar_url,
                ),
            )
            if updated_profile is None:
                logger.error(f"Failed to update profile for user {user_id}")
                raise backend.server.v2.store.exceptions.DatabaseError(
                    "Failed to update profile"
                )

            return backend.server.v2.store.model.CreatorDetails(
                name=updated_profile.name,
                username=updated_profile.username,
                description=updated_profile.description,
                links=updated_profile.links,
                avatar_url=updated_profile.avatarUrl or "",
                agent_rating=0.0,
                agent_runs=0,
                top_categories=[],
            )

    except prisma.errors.PrismaError as e:
        logger.error(f"Database error updating profile: {str(e)}")
        raise backend.server.v2.store.exceptions.DatabaseError(
            "Failed to update profile"
        ) from e


async def get_my_agents(
    user_id: str,
    page: int = 1,
    page_size: int = 20,
) -> backend.server.v2.store.model.MyAgentsResponse:
    logger.debug(f"Getting my agents for user {user_id}, page={page}")

    try:
        agents_with_max_version = await prisma.models.AgentGraph.prisma().find_many(
            where=prisma.types.AgentGraphWhereInput(
                userId=user_id, StoreListing={"none": {"isDeleted": False}}
            ),
            order=[{"version": "desc"}],
            distinct=["id"],
            skip=(page - 1) * page_size,
            take=page_size,
        )

        # store_listings = await prisma.models.StoreListing.prisma().find_many(
        #     where=prisma.types.StoreListingWhereInput(
        #         isDeleted=False,
        #     ),
        # )

        total = len(
            await prisma.models.AgentGraph.prisma().find_many(
                where=prisma.types.AgentGraphWhereInput(
                    userId=user_id, StoreListing={"none": {"isDeleted": False}}
                ),
                order=[{"version": "desc"}],
                distinct=["id"],
            )
        )

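        # Ceiling division: e.g. 45 matching agents with page_size=20 -> 3 pages.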
        total_pages = (total + page_size - 1) // page_size

        agents = agents_with_max_version

        my_agents = [
            backend.server.v2.store.model.MyAgent(
                agent_id=agent.id,
                agent_version=agent.version,
                agent_name=agent.name or "",
                last_edited=agent.updatedAt or agent.createdAt,
            )
            for agent in agents
        ]

        return backend.server.v2.store.model.MyAgentsResponse(
            agents=my_agents,
            pagination=backend.server.v2.store.model.Pagination(
                current_page=page,
                total_items=total,
                total_pages=total_pages,
                page_size=page_size,
            ),
        )
    except Exception as e:
        logger.error(f"Error getting my agents: {str(e)}")
        raise backend.server.v2.store.exceptions.DatabaseError(
            "Failed to fetch my agents"
        ) from e
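A minimal usage sketch of the submission flow above (the IDs are hypothetical,
and this assumes a registered Prisma client plus an existing agent graph owned
by the user):

    import asyncio

    import backend.server.v2.store.db as db
    import backend.server.v2.store.exceptions as store_exceptions

    async def main() -> None:
        try:
            submission = await db.create_store_submission(
                user_id="user-123",    # hypothetical
                agent_id="agent-456",  # hypothetical
                agent_version=1,
                slug="my-agent",
                name="My Agent",
                sub_heading="Does useful things",
            )
            print(submission.status)  # SubmissionStatus.PENDING
        except store_exceptions.ListingExistsError:
            pass  # already submitted; nothing to do

    asyncio.run(main())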
260  autogpt_platform/backend/backend/server/v2/store/db_test.py  Normal file
@@ -0,0 +1,260 @@
from datetime import datetime

import prisma.errors
import prisma.models
import pytest
from prisma import Prisma

import backend.server.v2.store.db as db
from backend.server.v2.store.model import Profile


@pytest.fixture(autouse=True)
async def setup_prisma():
    # Don't register client if already registered
    try:
        Prisma()
    except prisma.errors.ClientAlreadyRegisteredError:
        pass
    yield


@pytest.mark.asyncio
async def test_get_store_agents(mocker):
    # Mock data
    mock_agents = [
        prisma.models.StoreAgent(
            listing_id="test-id",
            storeListingVersionId="version123",
            slug="test-agent",
            agent_name="Test Agent",
            agent_video=None,
            agent_image=["image.jpg"],
            featured=False,
            creator_username="creator",
            creator_avatar="avatar.jpg",
            sub_heading="Test heading",
            description="Test description",
            categories=[],
            runs=10,
            rating=4.5,
            versions=["1.0"],
            updated_at=datetime.now(),
        )
    ]

    # Mock prisma calls
    mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
    mock_store_agent.return_value.find_many = mocker.AsyncMock(return_value=mock_agents)
    mock_store_agent.return_value.count = mocker.AsyncMock(return_value=1)

    # Call function
    result = await db.get_store_agents()

    # Verify results
    assert len(result.agents) == 1
    assert result.agents[0].slug == "test-agent"
    assert result.pagination.total_items == 1

    # Verify mocks called correctly
    mock_store_agent.return_value.find_many.assert_called_once()
    mock_store_agent.return_value.count.assert_called_once()


@pytest.mark.asyncio
async def test_get_store_agent_details(mocker):
    # Mock data
    mock_agent = prisma.models.StoreAgent(
        listing_id="test-id",
        storeListingVersionId="version123",
        slug="test-agent",
        agent_name="Test Agent",
        agent_video="video.mp4",
        agent_image=["image.jpg"],
        featured=False,
        creator_username="creator",
        creator_avatar="avatar.jpg",
        sub_heading="Test heading",
        description="Test description",
        categories=["test"],
        runs=10,
        rating=4.5,
        versions=["1.0"],
        updated_at=datetime.now(),
    )

    # Mock prisma call
    mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
    mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)

    # Call function
    result = await db.get_store_agent_details("creator", "test-agent")

    # Verify results
    assert result.slug == "test-agent"
    assert result.agent_name == "Test Agent"

    # Verify mock called correctly
    mock_store_agent.return_value.find_first.assert_called_once_with(
        where={"creator_username": "creator", "slug": "test-agent"}
    )


@pytest.mark.asyncio
async def test_get_store_creator_details(mocker):
    # Mock data
    mock_creator_data = prisma.models.Creator(
        name="Test Creator",
        username="creator",
        description="Test description",
        links=["link1"],
        avatar_url="avatar.jpg",
        num_agents=1,
        agent_rating=4.5,
        agent_runs=10,
        top_categories=["test"],
    )

    # Mock prisma call
    mock_creator = mocker.patch("prisma.models.Creator.prisma")
    mock_creator.return_value.find_unique = mocker.AsyncMock()
    # Configure the mock to return values that will pass validation
    mock_creator.return_value.find_unique.return_value = mock_creator_data

    # Call function
    result = await db.get_store_creator_details("creator")

    # Verify results
    assert result.username == "creator"
    assert result.name == "Test Creator"
    assert result.description == "Test description"
    assert result.avatar_url == "avatar.jpg"

    # Verify mock called correctly
    mock_creator.return_value.find_unique.assert_called_once_with(
        where={"username": "creator"}
    )


@pytest.mark.asyncio
async def test_create_store_submission(mocker):
    # Mock data
    mock_agent = prisma.models.AgentGraph(
        id="agent-id",
        version=1,
        userId="user-id",
        createdAt=datetime.now(),
        isActive=True,
        isTemplate=False,
    )

    mock_listing = prisma.models.StoreListing(
        id="listing-id",
        createdAt=datetime.now(),
        updatedAt=datetime.now(),
        isDeleted=False,
        isApproved=False,
        agentId="agent-id",
        agentVersion=1,
        owningUserId="user-id",
    )

    # Mock prisma calls
    mock_agent_graph = mocker.patch("prisma.models.AgentGraph.prisma")
    mock_agent_graph.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)

    mock_store_listing = mocker.patch("prisma.models.StoreListing.prisma")
    mock_store_listing.return_value.find_first = mocker.AsyncMock(return_value=None)
    mock_store_listing.return_value.create = mocker.AsyncMock(return_value=mock_listing)

    # Call function
    result = await db.create_store_submission(
        user_id="user-id",
        agent_id="agent-id",
        agent_version=1,
        slug="test-agent",
        name="Test Agent",
        description="Test description",
    )

    # Verify results
    assert result.name == "Test Agent"
    assert result.description == "Test description"

    # Verify mocks called correctly
    mock_agent_graph.return_value.find_first.assert_called_once()
    mock_store_listing.return_value.find_first.assert_called_once()
    mock_store_listing.return_value.create.assert_called_once()


@pytest.mark.asyncio
async def test_update_profile(mocker):
    # Mock data
    mock_profile = prisma.models.Profile(
        id="profile-id",
        name="Test Creator",
        username="creator",
        description="Test description",
        links=["link1"],
        avatarUrl="avatar.jpg",
        createdAt=datetime.now(),
        updatedAt=datetime.now(),
    )

    # Mock prisma calls
    mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
    mock_profile_db.return_value.find_first = mocker.AsyncMock(
        return_value=mock_profile
    )
    mock_profile_db.return_value.update = mocker.AsyncMock(return_value=mock_profile)

    # Test data
    profile = Profile(
        name="Test Creator",
        username="creator",
        description="Test description",
        links=["link1"],
        avatar_url="avatar.jpg",
    )

    # Call function
    result = await db.update_or_create_profile("user-id", profile)

    # Verify results
    assert result.username == "creator"
    assert result.name == "Test Creator"

    # Verify mocks called correctly
    mock_profile_db.return_value.find_first.assert_called_once()
    mock_profile_db.return_value.update.assert_called_once()


@pytest.mark.asyncio
async def test_get_user_profile(mocker):
    # Mock data
    mock_profile = prisma.models.Profile(
        id="profile-id",
        name="No Profile Data",
        username="testuser",
        description="Test description",
        links=["link1", "link2"],
        avatarUrl="avatar.jpg",
        createdAt=datetime.now(),
        updatedAt=datetime.now(),
    )

    # Mock prisma calls
    mock_profile_db = mocker.patch("prisma.models.Profile.prisma")
    mock_profile_db.return_value.find_unique = mocker.AsyncMock(
        return_value=mock_profile
    )

    # Call function
    result = await db.get_user_profile("user-id")

    # Verify results
    assert result.name == "No Profile Data"
    assert result.username == "No Profile Data"
    assert result.description == "No Profile Data"
    assert result.links == []
    assert result.avatar_url == ""
76  autogpt_platform/backend/backend/server/v2/store/exceptions.py  Normal file
@@ -0,0 +1,76 @@
class MediaUploadError(Exception):
    """Base exception for media upload errors"""

    pass


class InvalidFileTypeError(MediaUploadError):
    """Raised when file type is not supported"""

    pass


class FileSizeTooLargeError(MediaUploadError):
    """Raised when file size exceeds maximum limit"""

    pass


class FileReadError(MediaUploadError):
    """Raised when there's an error reading the file"""

    pass


class StorageConfigError(MediaUploadError):
    """Raised when storage configuration is invalid"""

    pass


class StorageUploadError(MediaUploadError):
    """Raised when upload to storage fails"""

    pass


class StoreError(Exception):
    """Base exception for store-related errors"""

    pass


class AgentNotFoundError(StoreError):
    """Raised when an agent is not found"""

    pass


class CreatorNotFoundError(StoreError):
    """Raised when a creator is not found"""

    pass


class ListingExistsError(StoreError):
    """Raised when trying to create a listing that already exists"""

    pass


class DatabaseError(StoreError):
    """Raised when there is an error interacting with the database"""

    pass


class ProfileNotFoundError(StoreError):
    """Raised when a profile is not found"""

    pass


class SubmissionNotFoundError(StoreError):
    """Raised when a submission is not found"""

    pass
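The two hierarchies let callers catch at whatever granularity they need; a
small sketch (hypothetical handler):

    import backend.server.v2.store.exceptions as exc

    try:
        raise exc.ListingExistsError("Listing already exists for this agent")
    except exc.StoreError as e:
        # ListingExistsError subclasses StoreError, so this branch handles it.
        print(f"store error: {e}")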
101  autogpt_platform/backend/backend/server/v2/store/media.py  Normal file
@@ -0,0 +1,101 @@
import logging
import os
import uuid

import fastapi
from google.cloud import storage

import backend.server.v2.store.exceptions
from backend.util.settings import Settings

logger = logging.getLogger(__name__)

ALLOWED_IMAGE_TYPES = {"image/jpeg", "image/png", "image/gif", "image/webp"}
ALLOWED_VIDEO_TYPES = {"video/mp4", "video/webm"}
MAX_FILE_SIZE = 50 * 1024 * 1024  # 50MB


async def upload_media(user_id: str, file: fastapi.UploadFile) -> str:
    settings = Settings()

    # Check required settings first before doing any file processing
    if (
        not settings.config.media_gcs_bucket_name
        or not settings.config.google_application_credentials
    ):
        logger.error("Missing required GCS settings")
        raise backend.server.v2.store.exceptions.StorageConfigError(
            "Missing storage configuration"
        )

    try:
        # Validate file type
        content_type = file.content_type
        if (
            content_type not in ALLOWED_IMAGE_TYPES
            and content_type not in ALLOWED_VIDEO_TYPES
        ):
            logger.warning(f"Invalid file type attempted: {content_type}")
            raise backend.server.v2.store.exceptions.InvalidFileTypeError(
                f"File type not supported. Must be jpeg, png, gif, webp, mp4 or webm. Content type: {content_type}"
            )

        # Validate file size
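        # Stream the upload in small chunks so an oversized file can be
        # rejected without first buffering the whole body in memory.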
        file_size = 0
        chunk_size = 8192  # 8KB chunks

        try:
            while chunk := await file.read(chunk_size):
                file_size += len(chunk)
                if file_size > MAX_FILE_SIZE:
                    logger.warning(f"File size too large: {file_size} bytes")
                    raise backend.server.v2.store.exceptions.FileSizeTooLargeError(
                        "File too large. Maximum size is 50MB"
                    )
        except backend.server.v2.store.exceptions.FileSizeTooLargeError:
            raise
        except Exception as e:
            logger.error(f"Error reading file chunks: {str(e)}")
            raise backend.server.v2.store.exceptions.FileReadError(
                "Failed to read uploaded file"
            ) from e

        # Reset file pointer
        await file.seek(0)

        # Generate unique filename
        filename = file.filename or ""
        file_ext = os.path.splitext(filename)[1].lower()
        unique_filename = f"{uuid.uuid4()}{file_ext}"

        # Construct storage path
        media_type = "images" if content_type in ALLOWED_IMAGE_TYPES else "videos"
        storage_path = f"users/{user_id}/{media_type}/{unique_filename}"

        try:
            storage_client = storage.Client()
            bucket = storage_client.bucket(settings.config.media_gcs_bucket_name)
            blob = bucket.blob(storage_path)
            blob.content_type = content_type

            file_bytes = await file.read()
            blob.upload_from_string(file_bytes, content_type=content_type)

            public_url = blob.public_url

            logger.info(f"Successfully uploaded file to: {storage_path}")
            return public_url

        except Exception as e:
            logger.error(f"GCS storage error: {str(e)}")
            raise backend.server.v2.store.exceptions.StorageUploadError(
                "Failed to upload file to storage"
            ) from e

    except backend.server.v2.store.exceptions.MediaUploadError:
        raise
    except Exception as e:
        logger.exception("Unexpected error in upload_media")
        raise backend.server.v2.store.exceptions.MediaUploadError(
            "Unexpected error during media upload"
        ) from e
107  autogpt_platform/backend/backend/server/v2/store/media_test.py  Normal file
@@ -0,0 +1,107 @@
import io
import unittest.mock

import fastapi
import pytest
import starlette.datastructures

import backend.server.v2.store.exceptions
import backend.server.v2.store.media
from backend.util.settings import Settings


@pytest.fixture
def mock_settings(monkeypatch):
    settings = Settings()
    settings.config.media_gcs_bucket_name = "test-bucket"
    settings.config.google_application_credentials = "test-credentials"
    monkeypatch.setattr("backend.server.v2.store.media.Settings", lambda: settings)
    return settings


@pytest.fixture
def mock_storage_client(mocker):
    mock_client = unittest.mock.MagicMock()
    mock_bucket = unittest.mock.MagicMock()
    mock_blob = unittest.mock.MagicMock()

    mock_client.bucket.return_value = mock_bucket
    mock_bucket.blob.return_value = mock_blob
    mock_blob.public_url = "http://test-url/media/test.jpg"

    mocker.patch("google.cloud.storage.Client", return_value=mock_client)

    return mock_client


async def test_upload_media_success(mock_settings, mock_storage_client):
    test_file = fastapi.UploadFile(
        filename="test.jpeg",
        file=io.BytesIO(b"test data"),
        headers=starlette.datastructures.Headers({"content-type": "image/jpeg"}),
    )

    result = await backend.server.v2.store.media.upload_media("test-user", test_file)

    assert result == "http://test-url/media/test.jpg"
    mock_bucket = mock_storage_client.bucket.return_value
    mock_blob = mock_bucket.blob.return_value
    mock_blob.upload_from_string.assert_called_once()


async def test_upload_media_invalid_type(mock_settings, mock_storage_client):
    test_file = fastapi.UploadFile(
        filename="test.txt",
        file=io.BytesIO(b"test data"),
        headers=starlette.datastructures.Headers({"content-type": "text/plain"}),
    )

    with pytest.raises(backend.server.v2.store.exceptions.InvalidFileTypeError):
        await backend.server.v2.store.media.upload_media("test-user", test_file)

    mock_bucket = mock_storage_client.bucket.return_value
    mock_blob = mock_bucket.blob.return_value
    mock_blob.upload_from_string.assert_not_called()


async def test_upload_media_missing_credentials(monkeypatch):
    settings = Settings()
    settings.config.media_gcs_bucket_name = ""
    settings.config.google_application_credentials = ""
    monkeypatch.setattr("backend.server.v2.store.media.Settings", lambda: settings)

    test_file = fastapi.UploadFile(
        filename="test.jpeg",
        file=io.BytesIO(b"test data"),
        headers=starlette.datastructures.Headers({"content-type": "image/jpeg"}),
    )

    with pytest.raises(backend.server.v2.store.exceptions.StorageConfigError):
        await backend.server.v2.store.media.upload_media("test-user", test_file)


async def test_upload_media_video_type(mock_settings, mock_storage_client):
    test_file = fastapi.UploadFile(
        filename="test.mp4",
        file=io.BytesIO(b"test video data"),
        headers=starlette.datastructures.Headers({"content-type": "video/mp4"}),
    )

    result = await backend.server.v2.store.media.upload_media("test-user", test_file)

    assert result == "http://test-url/media/test.jpg"
    mock_bucket = mock_storage_client.bucket.return_value
    mock_blob = mock_bucket.blob.return_value
    mock_blob.upload_from_string.assert_called_once()


async def test_upload_media_file_too_large(mock_settings, mock_storage_client):
    large_data = b"x" * (50 * 1024 * 1024 + 1)  # 50MB + 1 byte
    test_file = fastapi.UploadFile(
        filename="test.jpeg",
        file=io.BytesIO(large_data),
        headers=starlette.datastructures.Headers({"content-type": "image/jpeg"}),
    )

    with pytest.raises(backend.server.v2.store.exceptions.FileSizeTooLargeError):
        await backend.server.v2.store.media.upload_media("test-user", test_file)
150  autogpt_platform/backend/backend/server/v2/store/model.py  Normal file
@@ -0,0 +1,150 @@
import datetime
from typing import List

import prisma.enums
import pydantic


class Pagination(pydantic.BaseModel):
    total_items: int = pydantic.Field(
        description="Total number of items.", examples=[42]
    )
    total_pages: int = pydantic.Field(
        description="Total number of pages.", examples=[97]
    )
    current_page: int = pydantic.Field(
        description="Current page number.", examples=[1]
    )
    page_size: int = pydantic.Field(
        description="Number of items per page.", examples=[25]
    )


class MyAgent(pydantic.BaseModel):
    agent_id: str
    agent_version: int
    agent_name: str
    last_edited: datetime.datetime


class MyAgentsResponse(pydantic.BaseModel):
    agents: list[MyAgent]
    pagination: Pagination


class StoreAgent(pydantic.BaseModel):
    slug: str
    agent_name: str
    agent_image: str
    creator: str
    creator_avatar: str
    sub_heading: str
    description: str
    runs: int
    rating: float


class StoreAgentsResponse(pydantic.BaseModel):
    agents: list[StoreAgent]
    pagination: Pagination


class StoreAgentDetails(pydantic.BaseModel):
    store_listing_version_id: str
    slug: str
    agent_name: str
    agent_video: str
    agent_image: list[str]
    creator: str
    creator_avatar: str
    sub_heading: str
    description: str
    categories: list[str]
    runs: int
    rating: float
    versions: list[str]
    last_updated: datetime.datetime


class Creator(pydantic.BaseModel):
    name: str
    username: str
    description: str
    avatar_url: str
    num_agents: int
    agent_rating: float
    agent_runs: int


class CreatorsResponse(pydantic.BaseModel):
    creators: List[Creator]
    pagination: Pagination


class CreatorDetails(pydantic.BaseModel):
    name: str
    username: str
    description: str
    links: list[str]
    avatar_url: str
    agent_rating: float
    agent_runs: int
    top_categories: list[str]


class Profile(pydantic.BaseModel):
    name: str
    username: str
    description: str
    links: list[str]
    avatar_url: str


class StoreSubmission(pydantic.BaseModel):
    agent_id: str
    agent_version: int
    name: str
    sub_heading: str
    slug: str
    description: str
    image_urls: list[str]
    date_submitted: datetime.datetime
    status: prisma.enums.SubmissionStatus
    runs: int
    rating: float


class StoreSubmissionsResponse(pydantic.BaseModel):
    submissions: list[StoreSubmission]
    pagination: Pagination


class StoreSubmissionRequest(pydantic.BaseModel):
    agent_id: str
    agent_version: int
    slug: str
    name: str
    sub_heading: str
    video_url: str | None = None
    image_urls: list[str] = []
    description: str = ""
    categories: list[str] = []


class ProfileDetails(pydantic.BaseModel):
    name: str
    username: str
    description: str
    links: list[str]
    avatar_url: str | None = None


class StoreReview(pydantic.BaseModel):
    score: int
    comments: str | None = None


class StoreReviewCreate(pydantic.BaseModel):
    store_listing_version_id: str
    score: int
    comments: str | None = None
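These are plain pydantic models, so construction validates field types and
`model_dump()` round-trips them to JSON-safe dicts; a small sketch with
hypothetical values:

    import backend.server.v2.store.model as model

    page = model.Pagination(
        total_items=42, total_pages=3, current_page=1, page_size=20
    )
    print(page.model_dump())
    # {'total_items': 42, 'total_pages': 3, 'current_page': 1, 'page_size': 20}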
193  autogpt_platform/backend/backend/server/v2/store/model_test.py  Normal file
@@ -0,0 +1,193 @@
import datetime

import prisma.enums

import backend.server.v2.store.model


def test_pagination():
    pagination = backend.server.v2.store.model.Pagination(
        total_items=100, total_pages=5, current_page=2, page_size=20
    )
    assert pagination.total_items == 100
    assert pagination.total_pages == 5
    assert pagination.current_page == 2
    assert pagination.page_size == 20


def test_store_agent():
    agent = backend.server.v2.store.model.StoreAgent(
        slug="test-agent",
        agent_name="Test Agent",
        agent_image="test.jpg",
        creator="creator1",
        creator_avatar="avatar.jpg",
        sub_heading="Test subheading",
        description="Test description",
        runs=50,
        rating=4.5,
    )
    assert agent.slug == "test-agent"
    assert agent.agent_name == "Test Agent"
    assert agent.runs == 50
    assert agent.rating == 4.5


def test_store_agents_response():
    response = backend.server.v2.store.model.StoreAgentsResponse(
        agents=[
            backend.server.v2.store.model.StoreAgent(
                slug="test-agent",
                agent_name="Test Agent",
                agent_image="test.jpg",
                creator="creator1",
                creator_avatar="avatar.jpg",
                sub_heading="Test subheading",
                description="Test description",
                runs=50,
                rating=4.5,
            )
        ],
        pagination=backend.server.v2.store.model.Pagination(
            total_items=1, total_pages=1, current_page=1, page_size=20
        ),
    )
    assert len(response.agents) == 1
    assert response.pagination.total_items == 1


def test_store_agent_details():
    details = backend.server.v2.store.model.StoreAgentDetails(
        store_listing_version_id="version123",
        slug="test-agent",
        agent_name="Test Agent",
        agent_video="video.mp4",
        agent_image=["image1.jpg", "image2.jpg"],
        creator="creator1",
        creator_avatar="avatar.jpg",
        sub_heading="Test subheading",
        description="Test description",
        categories=["cat1", "cat2"],
        runs=50,
        rating=4.5,
        versions=["1.0", "2.0"],
        last_updated=datetime.datetime.now(),
    )
    assert details.slug == "test-agent"
    assert len(details.agent_image) == 2
    assert len(details.categories) == 2
    assert len(details.versions) == 2


def test_creator():
    creator = backend.server.v2.store.model.Creator(
        agent_rating=4.8,
        agent_runs=1000,
        name="Test Creator",
        username="creator1",
        description="Test description",
        avatar_url="avatar.jpg",
        num_agents=5,
    )
    assert creator.name == "Test Creator"
    assert creator.num_agents == 5


def test_creators_response():
    response = backend.server.v2.store.model.CreatorsResponse(
        creators=[
            backend.server.v2.store.model.Creator(
                agent_rating=4.8,
                agent_runs=1000,
                name="Test Creator",
                username="creator1",
                description="Test description",
                avatar_url="avatar.jpg",
                num_agents=5,
            )
        ],
        pagination=backend.server.v2.store.model.Pagination(
            total_items=1, total_pages=1, current_page=1, page_size=20
        ),
    )
    assert len(response.creators) == 1
    assert response.pagination.total_items == 1


def test_creator_details():
    details = backend.server.v2.store.model.CreatorDetails(
        name="Test Creator",
        username="creator1",
        description="Test description",
        links=["link1.com", "link2.com"],
        avatar_url="avatar.jpg",
        agent_rating=4.8,
        agent_runs=1000,
        top_categories=["cat1", "cat2"],
    )
    assert details.name == "Test Creator"
    assert len(details.links) == 2
    assert details.agent_rating == 4.8
    assert len(details.top_categories) == 2


def test_store_submission():
    submission = backend.server.v2.store.model.StoreSubmission(
        agent_id="agent123",
        agent_version=1,
        sub_heading="Test subheading",
        name="Test Agent",
        slug="test-agent",
        description="Test description",
        image_urls=["image1.jpg", "image2.jpg"],
        date_submitted=datetime.datetime(2023, 1, 1),
        status=prisma.enums.SubmissionStatus.PENDING,
        runs=50,
        rating=4.5,
    )
    assert submission.name == "Test Agent"
    assert len(submission.image_urls) == 2
    assert submission.status == prisma.enums.SubmissionStatus.PENDING


def test_store_submissions_response():
    response = backend.server.v2.store.model.StoreSubmissionsResponse(
        submissions=[
            backend.server.v2.store.model.StoreSubmission(
                agent_id="agent123",
                agent_version=1,
                sub_heading="Test subheading",
                name="Test Agent",
                slug="test-agent",
                description="Test description",
                image_urls=["image1.jpg"],
                date_submitted=datetime.datetime(2023, 1, 1),
                status=prisma.enums.SubmissionStatus.PENDING,
                runs=50,
                rating=4.5,
            )
        ],
        pagination=backend.server.v2.store.model.Pagination(
            total_items=1, total_pages=1, current_page=1, page_size=20
        ),
    )
    assert len(response.submissions) == 1
    assert response.pagination.total_items == 1


def test_store_submission_request():
    request = backend.server.v2.store.model.StoreSubmissionRequest(
        agent_id="agent123",
        agent_version=1,
        slug="test-agent",
        name="Test Agent",
        sub_heading="Test subheading",
        video_url="video.mp4",
        image_urls=["image1.jpg", "image2.jpg"],
        description="Test description",
        categories=["cat1", "cat2"],
    )
    assert request.agent_id == "agent123"
    assert request.agent_version == 1
    assert len(request.image_urls) == 2
    assert len(request.categories) == 2
439  autogpt_platform/backend/backend/server/v2/store/routes.py  Normal file
@@ -0,0 +1,439 @@
import logging
import typing

import autogpt_libs.auth.depends
import autogpt_libs.auth.middleware
import fastapi
import fastapi.responses

import backend.server.v2.store.db
import backend.server.v2.store.media
import backend.server.v2.store.model

logger = logging.getLogger(__name__)

router = fastapi.APIRouter()


##############################################
############### Profile Endpoints ############
##############################################


@router.get("/profile", tags=["store", "private"])
async def get_profile(
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
    ]
) -> backend.server.v2.store.model.ProfileDetails:
    """
    Get the profile details for the authenticated user.
    """
    try:
        profile = await backend.server.v2.store.db.get_user_profile(user_id)
        return profile
    except Exception:
        logger.exception("Exception occurred whilst getting user profile")
        raise


@router.post(
    "/profile",
    tags=["store", "private"],
    dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def update_or_create_profile(
    profile: backend.server.v2.store.model.Profile,
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
    ],
) -> backend.server.v2.store.model.CreatorDetails:
    """
    Update the store profile for the authenticated user.

    Args:
        profile (Profile): The updated profile details
        user_id (str): ID of the authenticated user

    Returns:
        CreatorDetails: The updated profile

    Raises:
        HTTPException: If there is an error updating the profile
    """
    try:
        updated_profile = await backend.server.v2.store.db.update_or_create_profile(
            user_id=user_id, profile=profile
        )
        return updated_profile
    except Exception:
        logger.exception("Exception occurred whilst updating profile")
        raise


##############################################
############### Agent Endpoints ##############
##############################################


@router.get("/agents", tags=["store", "public"])
async def get_agents(
    featured: bool = False,
    creator: str | None = None,
    sorted_by: str | None = None,
    search_query: str | None = None,
    category: str | None = None,
    page: int = 1,
    page_size: int = 20,
) -> backend.server.v2.store.model.StoreAgentsResponse:
    """
    Get a paginated list of agents from the store with optional filtering and sorting.

    Args:
        featured (bool, optional): Filter to only show featured agents. Defaults to False.
        creator (str | None, optional): Filter agents by creator username. Defaults to None.
        sorted_by (str | None, optional): Sort agents by "runs" or "rating". Defaults to None.
        search_query (str | None, optional): Search agents by name, subheading and description. Defaults to None.
        category (str | None, optional): Filter agents by category. Defaults to None.
        page (int, optional): Page number for pagination. Defaults to 1.
        page_size (int, optional): Number of agents per page. Defaults to 20.

    Returns:
        StoreAgentsResponse: Paginated list of agents matching the filters

    Raises:
        HTTPException: If page or page_size are less than 1

    Used for:
        - Home Page Featured Agents
        - Home Page Top Agents
        - Search Results
        - Agent Details - Other Agents By Creator
        - Agent Details - Similar Agents
        - Creator Details - Agents By Creator
    """
    if page < 1:
        raise fastapi.HTTPException(
            status_code=422, detail="Page must be greater than 0"
        )

    if page_size < 1:
        raise fastapi.HTTPException(
            status_code=422, detail="Page size must be greater than 0"
        )

    try:
        agents = await backend.server.v2.store.db.get_store_agents(
            featured=featured,
            creator=creator,
            sorted_by=sorted_by,
            search_query=search_query,
            category=category,
            page=page,
            page_size=page_size,
        )
        return agents
    except Exception:
        logger.exception("Exception occurred whilst getting store agents")
        raise


@router.get("/agents/{username}/{agent_name}", tags=["store", "public"])
async def get_agent(
    username: str, agent_name: str
) -> backend.server.v2.store.model.StoreAgentDetails:
    """
    This is only used on the AgentDetails page.

    It returns the store listing agent's details.
    """
    try:
        agent = await backend.server.v2.store.db.get_store_agent_details(
            username=username, agent_name=agent_name
        )
        return agent
    except Exception:
        logger.exception("Exception occurred whilst getting store agent details")
        raise


@router.post(
    "/agents/{username}/{agent_name}/review",
    tags=["store"],
    dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def create_review(
    username: str,
    agent_name: str,
    review: backend.server.v2.store.model.StoreReviewCreate,
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
    ],
) -> backend.server.v2.store.model.StoreReview:
    """
    Create a review for a store agent.

    Args:
        username: Creator's username
        agent_name: Name/slug of the agent
        review: Review details including score and optional comments
        user_id: ID of authenticated user creating the review

    Returns:
        The created review
    """
    try:
        # Create the review
        created_review = await backend.server.v2.store.db.create_store_review(
            user_id=user_id,
            store_listing_version_id=review.store_listing_version_id,
            score=review.score,
            comments=review.comments,
        )

        return created_review
    except Exception:
        logger.exception("Exception occurred whilst creating store review")
        raise


##############################################
############# Creator Endpoints ##############
##############################################


@router.get("/creators", tags=["store", "public"])
async def get_creators(
    featured: bool = False,
    search_query: str | None = None,
    sorted_by: str | None = None,
    page: int = 1,
    page_size: int = 20,
) -> backend.server.v2.store.model.CreatorsResponse:
    """
    This is needed for:
    - Home Page Featured Creators
    - Search Results Page

    ---

    To support this functionality we need:
    - featured: bool - to limit the list to just featured creators
    - search_query: str - vector search based on the creators' profile description
    - sorted_by: [agent_rating, agent_runs] - to sort the results
    """
    if page < 1:
        raise fastapi.HTTPException(
            status_code=422, detail="Page must be greater than 0"
        )

    if page_size < 1:
        raise fastapi.HTTPException(
            status_code=422, detail="Page size must be greater than 0"
        )

    try:
        creators = await backend.server.v2.store.db.get_store_creators(
            featured=featured,
            search_query=search_query,
            sorted_by=sorted_by,
            page=page,
            page_size=page_size,
        )
        return creators
    except Exception:
        logger.exception("Exception occurred whilst getting store creators")
        raise


@router.get("/creator/{username}", tags=["store", "public"])
async def get_creator(username: str) -> backend.server.v2.store.model.CreatorDetails:
    """
    Get the details of a creator.
    - Creator Details Page
    """
    try:
        creator = await backend.server.v2.store.db.get_store_creator_details(
            username=username
        )
        return creator
    except Exception:
        logger.exception("Exception occurred whilst getting creator details")
        raise


##############################################
############# Store Submissions ##############
##############################################


@router.get(
    "/myagents",
    tags=["store", "private"],
    dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def get_my_agents(
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
    ]
) -> backend.server.v2.store.model.MyAgentsResponse:
    try:
        agents = await backend.server.v2.store.db.get_my_agents(user_id)
        return agents
    except Exception:
        logger.exception("Exception occurred whilst getting my agents")
        raise


@router.delete(
    "/submissions/{submission_id}",
    tags=["store", "private"],
    dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def delete_submission(
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
    ],
    submission_id: str,
) -> bool:
    """
    Delete a store listing submission.

    Args:
        user_id (str): ID of the authenticated user
        submission_id (str): ID of the submission to be deleted

    Returns:
        bool: True if the submission was successfully deleted, False otherwise
    """
    try:
        result = await backend.server.v2.store.db.delete_store_submission(
            user_id=user_id,
            submission_id=submission_id,
        )
        return result
    except Exception:
        logger.exception("Exception occurred whilst deleting store submission")
        raise


@router.get(
    "/submissions",
    tags=["store", "private"],
    dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def get_submissions(
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
    ],
    page: int = 1,
    page_size: int = 20,
) -> backend.server.v2.store.model.StoreSubmissionsResponse:
    """
    Get a paginated list of store submissions for the authenticated user.

    Args:
        user_id (str): ID of the authenticated user
        page (int, optional): Page number for pagination. Defaults to 1.
        page_size (int, optional): Number of submissions per page. Defaults to 20.

    Returns:
        StoreSubmissionsResponse: Paginated list of store submissions

    Raises:
        HTTPException: If page or page_size are less than 1
    """
    if page < 1:
        raise fastapi.HTTPException(
            status_code=422, detail="Page must be greater than 0"
        )

    if page_size < 1:
        raise fastapi.HTTPException(
            status_code=422, detail="Page size must be greater than 0"
        )
    try:
        listings = await backend.server.v2.store.db.get_store_submissions(
            user_id=user_id,
            page=page,
            page_size=page_size,
        )
        return listings
    except Exception:
        logger.exception("Exception occurred whilst getting store submissions")
        raise


@router.post(
    "/submissions",
    tags=["store", "private"],
    dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def create_submission(
    submission_request: backend.server.v2.store.model.StoreSubmissionRequest,
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
    ],
) -> backend.server.v2.store.model.StoreSubmission:
    """
    Create a new store listing submission.

    Args:
        submission_request (StoreSubmissionRequest): The submission details
        user_id (str): ID of the authenticated user submitting the listing

    Returns:
        StoreSubmission: The created store submission

    Raises:
        HTTPException: If there is an error creating the submission
    """
    try:
        submission = await backend.server.v2.store.db.create_store_submission(
            user_id=user_id,
            agent_id=submission_request.agent_id,
            agent_version=submission_request.agent_version,
            slug=submission_request.slug,
            name=submission_request.name,
            video_url=submission_request.video_url,
            image_urls=submission_request.image_urls,
            description=submission_request.description,
            sub_heading=submission_request.sub_heading,
            categories=submission_request.categories,
        )
        return submission
    except Exception:
        logger.exception("Exception occurred whilst creating store submission")
        raise


@router.post(
    "/submissions/media",
    tags=["store", "private"],
    dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def upload_submission_media(
    file: fastapi.UploadFile,
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
    ],
) -> str:
    """
    Upload media (images/videos) for a store listing submission.

    Args:
        file (UploadFile): The media file to upload
        user_id (str): ID of the authenticated user uploading the media

    Returns:
        str: URL of the uploaded media file

    Raises:
        HTTPException: If there is an error uploading the media
    """
    try:
        media_url = await backend.server.v2.store.media.upload_media(
            user_id=user_id, file=file
        )
        return media_url
    except Exception:
        logger.exception("Exception occurred whilst uploading submission media")
        raise
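A quick way to exercise the public endpoints above with FastAPI's test client,
mirroring the setup used in routes_test.py below:

    import fastapi
    import fastapi.testclient

    import backend.server.v2.store.routes as routes

    app = fastapi.FastAPI()
    app.include_router(routes.router)
    client = fastapi.testclient.TestClient(app)

    # page=0 trips the validation guard and returns 422.
    resp = client.get("/agents", params={"page": 0})
    print(resp.status_code)  # 422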
551  autogpt_platform/backend/backend/server/v2/store/routes_test.py  Normal file
@@ -0,0 +1,551 @@
import datetime

import autogpt_libs.auth.depends
import autogpt_libs.auth.middleware
import fastapi
import fastapi.testclient
import prisma.enums
import pytest_mock

import backend.server.v2.store.model
import backend.server.v2.store.routes

app = fastapi.FastAPI()
app.include_router(backend.server.v2.store.routes.router)

client = fastapi.testclient.TestClient(app)


def override_auth_middleware():
    """Override auth middleware for testing"""
    return {"sub": "test-user-id"}


def override_get_user_id():
    """Override get_user_id for testing"""
    return "test-user-id"


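# Bypass route-level auth in tests: both the auth middleware dependency and
# the user-id resolver are replaced with static stand-ins.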
app.dependency_overrides[autogpt_libs.auth.middleware.auth_middleware] = (
|
||||
override_auth_middleware
|
||||
)
|
||||
app.dependency_overrides[autogpt_libs.auth.depends.get_user_id] = override_get_user_id
|
||||
|
||||
|
||||
def test_get_agents_defaults(mocker: pytest_mock.MockFixture):
|
||||
mocked_value = backend.server.v2.store.model.StoreAgentsResponse(
|
||||
agents=[],
|
||||
pagination=backend.server.v2.store.model.Pagination(
|
||||
current_page=0,
|
||||
total_items=0,
|
||||
total_pages=0,
|
||||
page_size=10,
|
||||
),
|
||||
)
|
||||
mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
|
||||
mock_db_call.return_value = mocked_value
|
||||
response = client.get("/agents")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = backend.server.v2.store.model.StoreAgentsResponse.model_validate(
|
||||
response.json()
|
||||
)
|
||||
assert data.pagination.total_pages == 0
|
||||
assert data.agents == []
|
||||
mock_db_call.assert_called_once_with(
|
||||
featured=False,
|
||||
creator=None,
|
||||
sorted_by=None,
|
||||
search_query=None,
|
||||
category=None,
|
||||
page=1,
|
||||
page_size=20,
|
||||
)
|
||||
|
||||
|
||||
def test_get_agents_featured(mocker: pytest_mock.MockFixture):
    mocked_value = backend.server.v2.store.model.StoreAgentsResponse(
        agents=[
            backend.server.v2.store.model.StoreAgent(
                slug="featured-agent",
                agent_name="Featured Agent",
                agent_image="featured.jpg",
                creator="creator1",
                creator_avatar="avatar1.jpg",
                sub_heading="Featured agent subheading",
                description="Featured agent description",
                runs=100,
                rating=4.5,
            )
        ],
        pagination=backend.server.v2.store.model.Pagination(
            current_page=1,
            total_items=1,
            total_pages=1,
            page_size=20,
        ),
    )
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
    mock_db_call.return_value = mocked_value
    response = client.get("/agents?featured=true")
    assert response.status_code == 200
    data = backend.server.v2.store.model.StoreAgentsResponse.model_validate(
        response.json()
    )
    assert len(data.agents) == 1
    assert data.agents[0].slug == "featured-agent"
    mock_db_call.assert_called_once_with(
        featured=True,
        creator=None,
        sorted_by=None,
        search_query=None,
        category=None,
        page=1,
        page_size=20,
    )

def test_get_agents_by_creator(mocker: pytest_mock.MockFixture):
    mocked_value = backend.server.v2.store.model.StoreAgentsResponse(
        agents=[
            backend.server.v2.store.model.StoreAgent(
                slug="creator-agent",
                agent_name="Creator Agent",
                agent_image="agent.jpg",
                creator="specific-creator",
                creator_avatar="avatar.jpg",
                sub_heading="Creator agent subheading",
                description="Creator agent description",
                runs=50,
                rating=4.0,
            )
        ],
        pagination=backend.server.v2.store.model.Pagination(
            current_page=1,
            total_items=1,
            total_pages=1,
            page_size=20,
        ),
    )
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
    mock_db_call.return_value = mocked_value
    response = client.get("/agents?creator=specific-creator")
    assert response.status_code == 200
    data = backend.server.v2.store.model.StoreAgentsResponse.model_validate(
        response.json()
    )
    assert len(data.agents) == 1
    assert data.agents[0].creator == "specific-creator"
    mock_db_call.assert_called_once_with(
        featured=False,
        creator="specific-creator",
        sorted_by=None,
        search_query=None,
        category=None,
        page=1,
        page_size=20,
    )

def test_get_agents_sorted(mocker: pytest_mock.MockFixture):
    mocked_value = backend.server.v2.store.model.StoreAgentsResponse(
        agents=[
            backend.server.v2.store.model.StoreAgent(
                slug="top-agent",
                agent_name="Top Agent",
                agent_image="top.jpg",
                creator="creator1",
                creator_avatar="avatar1.jpg",
                sub_heading="Top agent subheading",
                description="Top agent description",
                runs=1000,
                rating=5.0,
            )
        ],
        pagination=backend.server.v2.store.model.Pagination(
            current_page=1,
            total_items=1,
            total_pages=1,
            page_size=20,
        ),
    )
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
    mock_db_call.return_value = mocked_value
    response = client.get("/agents?sorted_by=runs")
    assert response.status_code == 200
    data = backend.server.v2.store.model.StoreAgentsResponse.model_validate(
        response.json()
    )
    assert len(data.agents) == 1
    assert data.agents[0].runs == 1000
    mock_db_call.assert_called_once_with(
        featured=False,
        creator=None,
        sorted_by="runs",
        search_query=None,
        category=None,
        page=1,
        page_size=20,
    )

def test_get_agents_search(mocker: pytest_mock.MockFixture):
    mocked_value = backend.server.v2.store.model.StoreAgentsResponse(
        agents=[
            backend.server.v2.store.model.StoreAgent(
                slug="search-agent",
                agent_name="Search Agent",
                agent_image="search.jpg",
                creator="creator1",
                creator_avatar="avatar1.jpg",
                sub_heading="Search agent subheading",
                description="Specific search term description",
                runs=75,
                rating=4.2,
            )
        ],
        pagination=backend.server.v2.store.model.Pagination(
            current_page=1,
            total_items=1,
            total_pages=1,
            page_size=20,
        ),
    )
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
    mock_db_call.return_value = mocked_value
    response = client.get("/agents?search_query=specific")
    assert response.status_code == 200
    data = backend.server.v2.store.model.StoreAgentsResponse.model_validate(
        response.json()
    )
    assert len(data.agents) == 1
    assert "specific" in data.agents[0].description.lower()
    mock_db_call.assert_called_once_with(
        featured=False,
        creator=None,
        sorted_by=None,
        search_query="specific",
        category=None,
        page=1,
        page_size=20,
    )

def test_get_agents_category(mocker: pytest_mock.MockFixture):
    mocked_value = backend.server.v2.store.model.StoreAgentsResponse(
        agents=[
            backend.server.v2.store.model.StoreAgent(
                slug="category-agent",
                agent_name="Category Agent",
                agent_image="category.jpg",
                creator="creator1",
                creator_avatar="avatar1.jpg",
                sub_heading="Category agent subheading",
                description="Category agent description",
                runs=60,
                rating=4.1,
            )
        ],
        pagination=backend.server.v2.store.model.Pagination(
            current_page=1,
            total_items=1,
            total_pages=1,
            page_size=20,
        ),
    )
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
    mock_db_call.return_value = mocked_value
    response = client.get("/agents?category=test-category")
    assert response.status_code == 200
    data = backend.server.v2.store.model.StoreAgentsResponse.model_validate(
        response.json()
    )
    assert len(data.agents) == 1
    mock_db_call.assert_called_once_with(
        featured=False,
        creator=None,
        sorted_by=None,
        search_query=None,
        category="test-category",
        page=1,
        page_size=20,
    )

def test_get_agents_pagination(mocker: pytest_mock.MockFixture):
    mocked_value = backend.server.v2.store.model.StoreAgentsResponse(
        agents=[
            backend.server.v2.store.model.StoreAgent(
                slug=f"agent-{i}",
                agent_name=f"Agent {i}",
                agent_image=f"agent{i}.jpg",
                creator="creator1",
                creator_avatar="avatar1.jpg",
                sub_heading=f"Agent {i} subheading",
                description=f"Agent {i} description",
                runs=i * 10,
                rating=4.0,
            )
            for i in range(5)
        ],
        pagination=backend.server.v2.store.model.Pagination(
            current_page=2,
            total_items=15,
            total_pages=3,
            page_size=5,
        ),
    )
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
    mock_db_call.return_value = mocked_value
    response = client.get("/agents?page=2&page_size=5")
    assert response.status_code == 200
    data = backend.server.v2.store.model.StoreAgentsResponse.model_validate(
        response.json()
    )
    assert len(data.agents) == 5
    assert data.pagination.current_page == 2
    assert data.pagination.page_size == 5
    mock_db_call.assert_called_once_with(
        featured=False,
        creator=None,
        sorted_by=None,
        search_query=None,
        category=None,
        page=2,
        page_size=5,
    )

def test_get_agents_malformed_request(mocker: pytest_mock.MockFixture):
    # Test with invalid page number
    response = client.get("/agents?page=-1")
    assert response.status_code == 422

    # Test with invalid page size
    response = client.get("/agents?page_size=0")
    assert response.status_code == 422

    # Test with non-numeric values
    response = client.get("/agents?page=abc&page_size=def")
    assert response.status_code == 422

    # Verify no DB calls were made
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
    mock_db_call.assert_not_called()

def test_get_agent_details(mocker: pytest_mock.MockFixture):
    mocked_value = backend.server.v2.store.model.StoreAgentDetails(
        store_listing_version_id="test-version-id",
        slug="test-agent",
        agent_name="Test Agent",
        agent_video="video.mp4",
        agent_image=["image1.jpg", "image2.jpg"],
        creator="creator1",
        creator_avatar="avatar1.jpg",
        sub_heading="Test agent subheading",
        description="Test agent description",
        categories=["category1", "category2"],
        runs=100,
        rating=4.5,
        versions=["1.0.0", "1.1.0"],
        last_updated=datetime.datetime.now(),
    )
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agent_details")
    mock_db_call.return_value = mocked_value

    response = client.get("/agents/creator1/test-agent")
    assert response.status_code == 200

    data = backend.server.v2.store.model.StoreAgentDetails.model_validate(
        response.json()
    )
    assert data.agent_name == "Test Agent"
    assert data.creator == "creator1"
    mock_db_call.assert_called_once_with(username="creator1", agent_name="test-agent")

def test_get_creators_defaults(mocker: pytest_mock.MockFixture):
    mocked_value = backend.server.v2.store.model.CreatorsResponse(
        creators=[],
        pagination=backend.server.v2.store.model.Pagination(
            current_page=0,
            total_items=0,
            total_pages=0,
            page_size=10,
        ),
    )
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_creators")
    mock_db_call.return_value = mocked_value

    response = client.get("/creators")
    assert response.status_code == 200

    data = backend.server.v2.store.model.CreatorsResponse.model_validate(
        response.json()
    )
    assert data.pagination.total_pages == 0
    assert data.creators == []
    mock_db_call.assert_called_once_with(
        featured=False, search_query=None, sorted_by=None, page=1, page_size=20
    )

def test_get_creators_pagination(mocker: pytest_mock.MockFixture):
    mocked_value = backend.server.v2.store.model.CreatorsResponse(
        creators=[
            backend.server.v2.store.model.Creator(
                name=f"Creator {i}",
                username=f"creator{i}",
                description=f"Creator {i} description",
                avatar_url=f"avatar{i}.jpg",
                num_agents=1,
                agent_rating=4.5,
                agent_runs=100,
            )
            for i in range(5)
        ],
        pagination=backend.server.v2.store.model.Pagination(
            current_page=2,
            total_items=15,
            total_pages=3,
            page_size=5,
        ),
    )
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_creators")
    mock_db_call.return_value = mocked_value

    response = client.get("/creators?page=2&page_size=5")
    assert response.status_code == 200

    data = backend.server.v2.store.model.CreatorsResponse.model_validate(
        response.json()
    )
    assert len(data.creators) == 5
    assert data.pagination.current_page == 2
    assert data.pagination.page_size == 5
    mock_db_call.assert_called_once_with(
        featured=False, search_query=None, sorted_by=None, page=2, page_size=5
    )

def test_get_creators_malformed_request(mocker: pytest_mock.MockFixture):
    # Test with invalid page number
    response = client.get("/creators?page=-1")
    assert response.status_code == 422

    # Test with invalid page size
    response = client.get("/creators?page_size=0")
    assert response.status_code == 422

    # Test with non-numeric values
    response = client.get("/creators?page=abc&page_size=def")
    assert response.status_code == 422

    # Verify no DB calls were made
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_creators")
    mock_db_call.assert_not_called()

def test_get_creator_details(mocker: pytest_mock.MockFixture):
    mocked_value = backend.server.v2.store.model.CreatorDetails(
        name="Test User",
        username="creator1",
        description="Test creator description",
        links=["link1.com", "link2.com"],
        avatar_url="avatar.jpg",
        agent_rating=4.8,
        agent_runs=1000,
        top_categories=["category1", "category2"],
    )
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_creator_details")
    mock_db_call.return_value = mocked_value

    response = client.get("/creator/creator1")
    assert response.status_code == 200

    data = backend.server.v2.store.model.CreatorDetails.model_validate(response.json())
    assert data.username == "creator1"
    assert data.name == "Test User"
    mock_db_call.assert_called_once_with(username="creator1")

def test_get_submissions_success(mocker: pytest_mock.MockFixture):
    mocked_value = backend.server.v2.store.model.StoreSubmissionsResponse(
        submissions=[
            backend.server.v2.store.model.StoreSubmission(
                name="Test Agent",
                description="Test agent description",
                image_urls=["test.jpg"],
                date_submitted=datetime.datetime.now(),
                status=prisma.enums.SubmissionStatus.APPROVED,
                runs=50,
                rating=4.2,
                agent_id="test-agent-id",
                agent_version=1,
                sub_heading="Test agent subheading",
                slug="test-agent",
            )
        ],
        pagination=backend.server.v2.store.model.Pagination(
            current_page=1,
            total_items=1,
            total_pages=1,
            page_size=20,
        ),
    )
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_submissions")
    mock_db_call.return_value = mocked_value

    response = client.get("/submissions")
    assert response.status_code == 200

    data = backend.server.v2.store.model.StoreSubmissionsResponse.model_validate(
        response.json()
    )
    assert len(data.submissions) == 1
    assert data.submissions[0].name == "Test Agent"
    assert data.pagination.current_page == 1
    mock_db_call.assert_called_once_with(user_id="test-user-id", page=1, page_size=20)

def test_get_submissions_pagination(mocker: pytest_mock.MockFixture):
    mocked_value = backend.server.v2.store.model.StoreSubmissionsResponse(
        submissions=[],
        pagination=backend.server.v2.store.model.Pagination(
            current_page=2,
            total_items=10,
            total_pages=2,
            page_size=5,
        ),
    )
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_submissions")
    mock_db_call.return_value = mocked_value

    response = client.get("/submissions?page=2&page_size=5")
    assert response.status_code == 200

    data = backend.server.v2.store.model.StoreSubmissionsResponse.model_validate(
        response.json()
    )
    assert data.pagination.current_page == 2
    assert data.pagination.page_size == 5
    mock_db_call.assert_called_once_with(user_id="test-user-id", page=2, page_size=5)

def test_get_submissions_malformed_request(mocker: pytest_mock.MockFixture):
    # Test with invalid page number
    response = client.get("/submissions?page=-1")
    assert response.status_code == 422

    # Test with invalid page size
    response = client.get("/submissions?page_size=0")
    assert response.status_code == 422

    # Test with non-numeric values
    response = client.get("/submissions?page=abc&page_size=def")
    assert response.status_code == 422

    # Verify no DB calls were made
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_submissions")
    mock_db_call.assert_not_called()
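
Every test above follows the same pattern: stub the DB layer with mocker.patch, call the route through the shared TestClient, then round-trip the JSON response through the pydantic model before asserting on it. Purely as an illustration (this test is not part of the commit, and the filter combination is hypothetical), one more case in the same style would look like:

def test_get_agents_featured_by_creator(mocker: pytest_mock.MockFixture):
    # Illustrative sketch only: stub the DB call exactly like the tests above do.
    mocked_value = backend.server.v2.store.model.StoreAgentsResponse(
        agents=[],
        pagination=backend.server.v2.store.model.Pagination(
            current_page=1, total_items=0, total_pages=0, page_size=20
        ),
    )
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
    mock_db_call.return_value = mocked_value
    # Combine two query filters; the route should pass both through to the DB layer.
    response = client.get("/agents?featured=true&creator=creator1")
    assert response.status_code == 200
    mock_db_call.assert_called_once_with(
        featured=True,
        creator="creator1",
        sorted_by=None,
        search_query=None,
        category=None,
        page=1,
        page_size=20,
    )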
@@ -1,8 +1,10 @@
 import ipaddress
+import re
 import socket
 from typing import Callable
-from urllib.parse import urlparse
+from urllib.parse import urlparse, urlunparse
 
+import idna
 import requests as req
 
 from backend.util.settings import Config
@@ -21,8 +23,23 @@ BLOCKED_IP_NETWORKS = [
 # --8<-- [end:BLOCKED_IP_NETWORKS]
 ]
 
+ALLOWED_SCHEMES = ["http", "https"]
+HOSTNAME_REGEX = re.compile(r"^[A-Za-z0-9.-]+$")  # Basic DNS-safe hostname pattern
+
 
-def is_ip_blocked(ip: str) -> bool:
+def _canonicalize_url(url: str) -> str:
+    # Strip spaces and trailing slashes
+    url = url.strip().strip("/")
+    # Ensure the URL starts with http:// or https://
+    if not url.startswith(("http://", "https://")):
+        url = "http://" + url
+
+    # Replace backslashes with forward slashes to avoid parsing ambiguities
+    url = url.replace("\\", "/")
+    return url
+
+
+def _is_ip_blocked(ip: str) -> bool:
     """
     Checks if the IP address is in a blocked network.
     """
@@ -35,29 +52,51 @@ def validate_url(url: str, trusted_origins: list[str]) -> str:
     Validates the URL to prevent SSRF attacks by ensuring it does not point to a private
     or untrusted IP address, unless whitelisted.
     """
-    url = url.strip().strip("/")
-    if not url.startswith(("http://", "https://")):
-        url = "http://" + url
+    url = _canonicalize_url(url)
+    parsed = urlparse(url)
 
-    parsed_url = urlparse(url)
-    hostname = parsed_url.hostname
+    # Check scheme
+    if parsed.scheme not in ALLOWED_SCHEMES:
+        raise ValueError(
+            f"Scheme '{parsed.scheme}' is not allowed. Only HTTP/HTTPS are supported."
+        )
 
-    if not hostname:
-        raise ValueError(f"Invalid URL: Unable to determine hostname from {url}")
+    # Validate and IDNA encode the hostname
+    if not parsed.hostname:
+        raise ValueError("Invalid URL: No hostname found.")
 
-    if any(hostname == origin for origin in trusted_origins):
+    # IDNA encode to prevent Unicode domain attacks
+    try:
+        ascii_hostname = idna.encode(parsed.hostname).decode("ascii")
+    except idna.IDNAError:
+        raise ValueError("Invalid hostname with unsupported characters.")
+
+    # Check hostname characters
+    if not HOSTNAME_REGEX.match(ascii_hostname):
+        raise ValueError("Hostname contains invalid characters.")
+
+    # Rebuild the URL with the normalized, IDNA-encoded hostname
+    parsed = parsed._replace(netloc=ascii_hostname)
+    url = str(urlunparse(parsed))
+
+    # Check if hostname is a trusted origin (exact match)
+    if ascii_hostname in trusted_origins:
         return url
 
     # Resolve all IP addresses for the hostname
-    ip_addresses = {result[4][0] for result in socket.getaddrinfo(hostname, None)}
-    if not ip_addresses:
-        raise ValueError(f"Unable to resolve IP address for {hostname}")
+    try:
+        ip_addresses = {res[4][0] for res in socket.getaddrinfo(ascii_hostname, None)}
+    except socket.gaierror:
+        raise ValueError(f"Unable to resolve IP address for hostname {ascii_hostname}")
 
-    # Check if all IP addresses are global
+    if not ip_addresses:
+        raise ValueError(f"No IP addresses found for {ascii_hostname}")
+
+    # Check if any resolved IP address falls into blocked ranges
     for ip in ip_addresses:
-        if is_ip_blocked(ip):
+        if _is_ip_blocked(ip):
             raise ValueError(
-                f"Access to private IP address at {hostname}: {ip} is not allowed."
+                f"Access to private IP address {ip} for hostname {ascii_hostname} is not allowed."
            )
 
     return url
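
The hardened flow is therefore: canonicalize, enforce the scheme whitelist, IDNA-encode and regex-check the hostname, rebuild the URL on the normalized hostname, short-circuit on exact trusted-origin matches, then resolve and screen every IP. A minimal sketch of calling it; the module path is not shown in this hunk, so the import below is an assumption, as is the presence of 127.0.0.0/8 in the truncated BLOCKED_IP_NETWORKS list:

# Assumed import path; the diff does not name the module containing validate_url.
from backend.util.request import validate_url

# A bare public hostname is canonicalized to an http:// URL and returned.
print(validate_url("github.com/Significant-Gravitas/AutoGPT", trusted_origins=[]))

# Loopback resolves into a blocked network, so the call raises ValueError
# (assuming 127.0.0.0/8 appears in BLOCKED_IP_NETWORKS, which is truncated above).
try:
    validate_url("http://127.0.0.1/admin", trusted_origins=[])
except ValueError as err:
    print(f"Rejected: {err}")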
@@ -148,6 +148,16 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
         "This value is then used to generate redirect URLs for OAuth flows.",
     )
 
+    media_gcs_bucket_name: str = Field(
+        default="",
+        description="The name of the Google Cloud Storage bucket for media files",
+    )
+
+    google_application_credentials: str = Field(
+        default="",
+        description="The path to the Google Cloud credentials JSON file",
+    )
+
     @field_validator("platform_base_url", "frontend_base_url")
     @classmethod
     def validate_platform_base_url(cls, v: str, info: ValidationInfo) -> str:
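
Because Config extends BaseSettings, the two new fields can presumably be populated from the environment rather than from code. A sketch under that assumption; the env-variable names assume pydantic's default derivation from the field names (no custom env_prefix), which this hunk does not confirm:

import os

# Hypothetical values; variable names assume pydantic's default env mapping.
os.environ["MEDIA_GCS_BUCKET_NAME"] = "my-media-bucket"
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/secrets/gcs-service-account.json"

from backend.util.settings import Config

config = Config()
assert config.media_gcs_bucket_name == "my-media-bucket"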
@@ -60,9 +60,7 @@ async def wait_execution(
     timeout: int = 20,
 ) -> Sequence[ExecutionResult]:
     async def is_execution_completed():
-        status = await AgentServer().test_get_graph_run_status(
-            graph_id, graph_exec_id, user_id
-        )
+        status = await AgentServer().test_get_graph_run_status(graph_exec_id, user_id)
         log.info(f"Execution status: {status}")
         if status == ExecutionStatus.FAILED:
             log.info("Execution failed")
autogpt_platform/backend/backend/util/text.py (new file, 22 lines)
@@ -0,0 +1,22 @@
import re

from jinja2 import BaseLoader
from jinja2.sandbox import SandboxedEnvironment


class TextFormatter:
    def __init__(self):
        # Create a sandboxed environment
        self.env = SandboxedEnvironment(loader=BaseLoader(), autoescape=True)

        # Clear any registered filters, tests, and globals to minimize attack surface
        self.env.filters.clear()
        self.env.tests.clear()
        self.env.globals.clear()

    def format_string(self, template_str: str, values=None, **kwargs) -> str:
        # For python.format compatibility: replace all {...} with {{..}}.
        # But avoid replacing {{...}} to {{{...}}}.
        template_str = re.sub(r"(?<!{){[ a-zA-Z0-9_]+}", r"{\g<0>}", template_str)
        template = self.env.from_string(template_str)
        return template.render(values or {}, **kwargs)
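
What that regex rewrite buys: single-brace Python-format placeholders are promoted to Jinja2 double braces, while text that is already Jinja2 syntax is left alone, and everything renders inside the locked-down sandbox. A small usage sketch (the import path follows from the new file's location above):

from backend.util.text import TextFormatter

fmt = TextFormatter()

# {name} is rewritten to {{name}} before rendering:
print(fmt.format_string("Hello {name}!", {"name": "AutoGPT"}))  # Hello AutoGPT!

# Already-doubled braces are skipped by the lookbehind in the regex:
print(fmt.format_string("Hello {{ name }}!", name="AutoGPT"))   # Hello AutoGPT!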
@@ -1,20 +1,31 @@
 version: "3"
 services:
   postgres-test:
     image: ankane/pgvector:latest
     environment:
-      - POSTGRES_USER=agpt_user
-      - POSTGRES_PASSWORD=pass123
-      - POSTGRES_DB=agpt_local
+      - POSTGRES_USER=${DB_USER}
+      - POSTGRES_PASSWORD=${DB_PASS}
+      - POSTGRES_DB=${DB_NAME}
+    healthcheck:
+      test: pg_isready -U $$POSTGRES_USER -d $$POSTGRES_DB
+      interval: 10s
+      timeout: 5s
+      retries: 5
     ports:
-      - "5433:5432"
+      - "${DB_PORT}:5432"
     networks:
       - app-network-test
   redis-test:
     image: redis:latest
     command: redis-server --requirepass password
     ports:
       - "6379:6379"
     networks:
       - app-network-test
+    healthcheck:
+      test: ["CMD", "redis-cli", "ping"]
+      interval: 10s
+      timeout: 5s
+      retries: 5
+
 networks:
   app-network-test:
@@ -3,6 +3,10 @@ import subprocess
 
 directory = os.path.dirname(os.path.realpath(__file__))
 
+BACKEND_DIR = "."
+LIBS_DIR = "../autogpt_libs"
+TARGET_DIRS = [BACKEND_DIR, LIBS_DIR]
+
 
 def run(*command: str) -> None:
     print(f">>>>> Running poetry run {' '.join(command)}")
@@ -11,17 +15,19 @@ def run(*command: str) -> None:
 
 
 def lint():
     try:
-        run("ruff", "check", ".", "--exit-zero")
-        run("isort", "--diff", "--check", "--profile", "black", ".")
-        run("black", "--diff", "--check", ".")
-        run("pyright")
+        run("ruff", "check", *TARGET_DIRS, "--exit-zero")
+        run("ruff", "format", "--diff", "--check", LIBS_DIR)
+        run("isort", "--diff", "--check", "--profile", "black", BACKEND_DIR)
+        run("black", "--diff", "--check", BACKEND_DIR)
+        run("pyright", *TARGET_DIRS)
     except subprocess.CalledProcessError as e:
         print("Lint failed, try running `poetry run format` to fix the issues: ", e)
         raise e
 
 
 def format():
-    run("ruff", "check", "--fix", ".")
-    run("isort", "--profile", "black", ".")
-    run("black", ".")
-    run("pyright", ".")
+    run("ruff", "check", "--fix", *TARGET_DIRS)
+    run("ruff", "format", LIBS_DIR)
+    run("isort", "--profile", "black", BACKEND_DIR)
+    run("black", BACKEND_DIR)
+    run("pyright", *TARGET_DIRS)
@@ -0,0 +1,41 @@
-- CreateIndex
CREATE INDEX "AgentGraph_userId_isActive_idx" ON "AgentGraph"("userId", "isActive");

-- CreateIndex
CREATE INDEX "AgentGraphExecution_agentGraphId_agentGraphVersion_idx" ON "AgentGraphExecution"("agentGraphId", "agentGraphVersion");

-- CreateIndex
CREATE INDEX "AgentGraphExecution_userId_idx" ON "AgentGraphExecution"("userId");

-- CreateIndex
CREATE INDEX "AgentNode_agentGraphId_agentGraphVersion_idx" ON "AgentNode"("agentGraphId", "agentGraphVersion");

-- CreateIndex
CREATE INDEX "AgentNode_agentBlockId_idx" ON "AgentNode"("agentBlockId");

-- CreateIndex
CREATE INDEX "AgentNode_webhookId_idx" ON "AgentNode"("webhookId");

-- CreateIndex
CREATE INDEX "AgentNodeExecution_agentGraphExecutionId_idx" ON "AgentNodeExecution"("agentGraphExecutionId");

-- CreateIndex
CREATE INDEX "AgentNodeExecution_agentNodeId_idx" ON "AgentNodeExecution"("agentNodeId");

-- CreateIndex
CREATE INDEX "AgentNodeExecutionInputOutput_referencedByOutputExecId_idx" ON "AgentNodeExecutionInputOutput"("referencedByOutputExecId");

-- CreateIndex
CREATE INDEX "AgentNodeLink_agentNodeSourceId_idx" ON "AgentNodeLink"("agentNodeSourceId");

-- CreateIndex
CREATE INDEX "AgentNodeLink_agentNodeSinkId_idx" ON "AgentNodeLink"("agentNodeSinkId");

-- CreateIndex
CREATE INDEX "AnalyticsMetrics_userId_idx" ON "AnalyticsMetrics"("userId");

-- CreateIndex
CREATE INDEX "IntegrationWebhook_userId_idx" ON "IntegrationWebhook"("userId");

-- CreateIndex
CREATE INDEX "UserBlockCredit_userId_createdAt_idx" ON "UserBlockCredit"("userId", "createdAt");
@@ -0,0 +1,8 @@
-- AlterTable
ALTER TABLE "User" ADD COLUMN "stripeCustomerId" TEXT;

-- AlterEnum
ALTER TYPE "UserBlockCreditType" RENAME TO "CreditTransactionType";

-- AlterTable
ALTER TABLE "UserBlockCredit" RENAME TO "CreditTransaction";
Some files were not shown because too many files have changed in this diff.