From 9a8c6ad609f8f29087066d318cd5e1673f838e89 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 13 Feb 2026 10:10:11 +0100 Subject: [PATCH 01/16] chore(libs/deps): bump the production-dependencies group across 1 directory with 4 updates (#12056) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps the production-dependencies group with 4 updates in the /autogpt_platform/autogpt_libs directory: [cryptography](https://github.com/pyca/cryptography), [fastapi](https://github.com/fastapi/fastapi), [launchdarkly-server-sdk](https://github.com/launchdarkly/python-server-sdk) and [supabase](https://github.com/supabase/supabase-py). Updates `cryptography` from 46.0.4 to 46.0.5
Changelog

Sourced from cryptography's changelog.

46.0.5 - 2026-02-10


* An attacker could create a malicious public key that reveals portions
of your
private key when using certain uncommon elliptic curves (binary curves).
This version now includes additional security checks to prevent this
attack.
This issue only affects binary elliptic curves, which are rarely used in
real-world applications. Credit to **XlabAI Team of Tencent Xuanwu Lab
and
Atuin Automated Vulnerability Discovery Engine** for reporting the
issue.
  **CVE-2026-26007**
* Support for ``SECT*`` binary elliptic curves is deprecated and will be
  removed in the next release.

.. v46-0-4:

Commits

Updates `fastapi` from 0.128.0 to 0.128.7
Release notes

Sourced from fastapi's releases.

0.128.7

Features

Refactors

Docs

Internal

0.128.6

Fixes

Translations

Internal

0.128.5

Refactors

Internal

0.128.4

Refactors

... (truncated)

Commits

Updates `launchdarkly-server-sdk` from 9.14.1 to 9.15.0
Release notes

Sourced from launchdarkly-server-sdk's releases.

v9.15.0

9.15.0 (2026-02-10)

Features

Bug Fixes


This PR was generated with Release Please. See documentation.

Changelog

Sourced from launchdarkly-server-sdk's changelog.

9.15.0 (2026-02-10)

⚠ BREAKING CHANGES

Note: The following breaking changes apply only to FDv2 (Flag Delivery v2) early access features, which are not subject to semantic versioning and may change without a major version bump.

Features

Bug Fixes

Commits

Updates `supabase` from 2.27.2 to 2.28.0
Release notes

Sourced from supabase's releases.

v2.28.0

2.28.0 (2026-02-10)

Features

Bug Fixes

v2.27.3

2.27.3 (2026-02-03)

Bug Fixes

Changelog

Sourced from supabase's changelog.

2.28.0 (2026-02-10)

Features

Bug Fixes

2.27.3 (2026-02-03)

Bug Fixes

Commits

Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore major version` will close this group update PR and stop Dependabot creating any more for the specific dependency's major version (unless you unignore this specific dependency's major version or upgrade to it yourself) - `@dependabot ignore minor version` will close this group update PR and stop Dependabot creating any more for the specific dependency's minor version (unless you unignore this specific dependency's minor version or upgrade to it yourself) - `@dependabot ignore ` will close this group update PR and stop Dependabot creating any more for the specific dependency (unless you unignore this specific dependency or upgrade to it yourself) - `@dependabot unignore ` will remove all of the ignore conditions of the specified dependency - `@dependabot unignore ` will remove the ignore condition of the specified dependency and ignore conditions

Greptile Overview

Greptile Summary

Dependency update bumps 4 packages in the production-dependencies group, including a **critical security patch for `cryptography`** (CVE-2026-26007) that prevents malicious public key attacks on binary elliptic curves. The update also includes bug fixes for `fastapi`, `launchdarkly-server-sdk`, and `supabase`. - **cryptography** 46.0.4 → 46.0.5: patches CVE-2026-26007, deprecates SECT* binary curves - **fastapi** 0.128.0 → 0.128.7: bug fixes, improved error handling, relaxed Starlette constraint - **launchdarkly-server-sdk** 9.14.1 → 9.15.0: drops Python 3.9 support (requires >=3.10), fixes race conditions - **supabase** 2.27.2/2.27.3 → 2.28.0: realtime fixes, new User model fields The lock files correctly resolve all dependencies. Python 3.10+ requirement is already enforced in both packages. However, backend's `pyproject.toml` still specifies `launchdarkly-server-sdk = "^9.14.1"` while the lock file uses 9.15.0 (pulled from autogpt_libs dependency), creating a minor version constraint inconsistency.

Confidence Score: 4/5

- This PR is safe to merge with one minor style suggestion - Automated dependency update with critical security patch for cryptography. All updates are backwards-compatible within semver constraints. Lock files correctly resolve all dependencies. Python 3.10+ is already enforced. Only minor issue is version constraint inconsistency in backend's pyproject.toml for launchdarkly-server-sdk, which doesn't affect functionality but should be aligned for clarity. - autogpt_platform/backend/pyproject.toml needs launchdarkly-server-sdk version constraint updated to ^9.15.0
--------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Otto --- autogpt_platform/autogpt_libs/poetry.lock | 169 ++++++++++--------- autogpt_platform/autogpt_libs/pyproject.toml | 6 +- autogpt_platform/backend/poetry.lock | 68 ++++---- autogpt_platform/backend/pyproject.toml | 2 +- 4 files changed, 123 insertions(+), 122 deletions(-) diff --git a/autogpt_platform/autogpt_libs/poetry.lock b/autogpt_platform/autogpt_libs/poetry.lock index 0a421dda31..e1d599360e 100644 --- a/autogpt_platform/autogpt_libs/poetry.lock +++ b/autogpt_platform/autogpt_libs/poetry.lock @@ -448,61 +448,61 @@ toml = ["tomli ; python_full_version <= \"3.11.0a6\""] [[package]] name = "cryptography" -version = "46.0.4" +version = "46.0.5" description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers." optional = false python-versions = "!=3.9.0,!=3.9.1,>=3.8" groups = ["main"] files = [ - {file = "cryptography-46.0.4-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:281526e865ed4166009e235afadf3a4c4cba6056f99336a99efba65336fd5485"}, - {file = "cryptography-46.0.4-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5f14fba5bf6f4390d7ff8f086c566454bff0411f6d8aa7af79c88b6f9267aecc"}, - {file = "cryptography-46.0.4-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:47bcd19517e6389132f76e2d5303ded6cf3f78903da2158a671be8de024f4cd0"}, - {file = "cryptography-46.0.4-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:01df4f50f314fbe7009f54046e908d1754f19d0c6d3070df1e6268c5a4af09fa"}, - {file = "cryptography-46.0.4-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:5aa3e463596b0087b3da0dbe2b2487e9fc261d25da85754e30e3b40637d61f81"}, - {file = "cryptography-46.0.4-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:0a9ad24359fee86f131836a9ac3bffc9329e956624a2d379b613f8f8abaf5255"}, - {file = "cryptography-46.0.4-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:dc1272e25ef673efe72f2096e92ae39dea1a1a450dd44918b15351f72c5a168e"}, - {file = "cryptography-46.0.4-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:de0f5f4ec8711ebc555f54735d4c673fc34b65c44283895f1a08c2b49d2fd99c"}, - {file = "cryptography-46.0.4-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:eeeb2e33d8dbcccc34d64651f00a98cb41b2dc69cef866771a5717e6734dfa32"}, - {file = "cryptography-46.0.4-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:3d425eacbc9aceafd2cb429e42f4e5d5633c6f873f5e567077043ef1b9bbf616"}, - {file = "cryptography-46.0.4-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:91627ebf691d1ea3976a031b61fb7bac1ccd745afa03602275dda443e11c8de0"}, - {file = "cryptography-46.0.4-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:2d08bc22efd73e8854b0b7caff402d735b354862f1145d7be3b9c0f740fef6a0"}, - {file = "cryptography-46.0.4-cp311-abi3-win32.whl", hash = "sha256:82a62483daf20b8134f6e92898da70d04d0ef9a75829d732ea1018678185f4f5"}, - {file = "cryptography-46.0.4-cp311-abi3-win_amd64.whl", hash = "sha256:6225d3ebe26a55dbc8ead5ad1265c0403552a63336499564675b29eb3184c09b"}, - {file = "cryptography-46.0.4-cp314-cp314t-macosx_10_9_universal2.whl", hash = "sha256:485e2b65d25ec0d901bca7bcae0f53b00133bf3173916d8e421f6fddde103908"}, - {file = "cryptography-46.0.4-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:078e5f06bd2fa5aea5a324f2a09f914b1484f1d0c2a4d6a8a28c74e72f65f2da"}, - {file = "cryptography-46.0.4-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:dce1e4f068f03008da7fa51cc7abc6ddc5e5de3e3d1550334eaf8393982a5829"}, - {file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:2067461c80271f422ee7bdbe79b9b4be54a5162e90345f86a23445a0cf3fd8a2"}, - {file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:c92010b58a51196a5f41c3795190203ac52edfd5dc3ff99149b4659eba9d2085"}, - {file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:829c2b12bbc5428ab02d6b7f7e9bbfd53e33efd6672d21341f2177470171ad8b"}, - {file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:62217ba44bf81b30abaeda1488686a04a702a261e26f87db51ff61d9d3510abd"}, - {file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:9c2da296c8d3415b93e6053f5a728649a87a48ce084a9aaf51d6e46c87c7f2d2"}, - {file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:9b34d8ba84454641a6bf4d6762d15847ecbd85c1316c0a7984e6e4e9f748ec2e"}, - {file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:df4a817fa7138dd0c96c8c8c20f04b8aaa1fac3bbf610913dcad8ea82e1bfd3f"}, - {file = "cryptography-46.0.4-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:b1de0ebf7587f28f9190b9cb526e901bf448c9e6a99655d2b07fff60e8212a82"}, - {file = "cryptography-46.0.4-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:9b4d17bc7bd7cdd98e3af40b441feaea4c68225e2eb2341026c84511ad246c0c"}, - {file = "cryptography-46.0.4-cp314-cp314t-win32.whl", hash = "sha256:c411f16275b0dea722d76544a61d6421e2cc829ad76eec79280dbdc9ddf50061"}, - {file = "cryptography-46.0.4-cp314-cp314t-win_amd64.whl", hash = "sha256:728fedc529efc1439eb6107b677f7f7558adab4553ef8669f0d02d42d7b959a7"}, - {file = "cryptography-46.0.4-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:a9556ba711f7c23f77b151d5798f3ac44a13455cc68db7697a1096e6d0563cab"}, - {file = "cryptography-46.0.4-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8bf75b0259e87fa70bddc0b8b4078b76e7fd512fd9afae6c1193bcf440a4dbef"}, - {file = "cryptography-46.0.4-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3c268a3490df22270955966ba236d6bc4a8f9b6e4ffddb78aac535f1a5ea471d"}, - {file = "cryptography-46.0.4-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:812815182f6a0c1d49a37893a303b44eaac827d7f0d582cecfc81b6427f22973"}, - {file = "cryptography-46.0.4-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:a90e43e3ef65e6dcf969dfe3bb40cbf5aef0d523dff95bfa24256be172a845f4"}, - {file = "cryptography-46.0.4-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:a05177ff6296644ef2876fce50518dffb5bcdf903c85250974fc8bc85d54c0af"}, - {file = "cryptography-46.0.4-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:daa392191f626d50f1b136c9b4cf08af69ca8279d110ea24f5c2700054d2e263"}, - {file = "cryptography-46.0.4-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:e07ea39c5b048e085f15923511d8121e4a9dc45cee4e3b970ca4f0d338f23095"}, - {file = "cryptography-46.0.4-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:d5a45ddc256f492ce42a4e35879c5e5528c09cd9ad12420828c972951d8e016b"}, - {file = "cryptography-46.0.4-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:6bb5157bf6a350e5b28aee23beb2d84ae6f5be390b2f8ee7ea179cda077e1019"}, - {file = "cryptography-46.0.4-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:dd5aba870a2c40f87a3af043e0dee7d9eb02d4aff88a797b48f2b43eff8c3ab4"}, - {file = "cryptography-46.0.4-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:93d8291da8d71024379ab2cb0b5c57915300155ad42e07f76bea6ad838d7e59b"}, - {file = "cryptography-46.0.4-cp38-abi3-win32.whl", hash = "sha256:0563655cb3c6d05fb2afe693340bc050c30f9f34e15763361cf08e94749401fc"}, - {file = "cryptography-46.0.4-cp38-abi3-win_amd64.whl", hash = "sha256:fa0900b9ef9c49728887d1576fd8d9e7e3ea872fa9b25ef9b64888adc434e976"}, - {file = "cryptography-46.0.4-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:766330cce7416c92b5e90c3bb71b1b79521760cdcfc3a6a1a182d4c9fab23d2b"}, - {file = "cryptography-46.0.4-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c236a44acfb610e70f6b3e1c3ca20ff24459659231ef2f8c48e879e2d32b73da"}, - {file = "cryptography-46.0.4-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:8a15fb869670efa8f83cbffbc8753c1abf236883225aed74cd179b720ac9ec80"}, - {file = "cryptography-46.0.4-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:fdc3daab53b212472f1524d070735b2f0c214239df131903bae1d598016fa822"}, - {file = "cryptography-46.0.4-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:44cc0675b27cadb71bdbb96099cca1fa051cd11d2ade09e5cd3a2edb929ed947"}, - {file = "cryptography-46.0.4-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:be8c01a7d5a55f9a47d1888162b76c8f49d62b234d88f0ff91a9fbebe32ffbc3"}, - {file = "cryptography-46.0.4.tar.gz", hash = "sha256:bfd019f60f8abc2ed1b9be4ddc21cfef059c841d86d710bb69909a688cbb8f59"}, + {file = "cryptography-46.0.5-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:351695ada9ea9618b3500b490ad54c739860883df6c1f555e088eaf25b1bbaad"}, + {file = "cryptography-46.0.5-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:c18ff11e86df2e28854939acde2d003f7984f721eba450b56a200ad90eeb0e6b"}, + {file = "cryptography-46.0.5-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d7e3d356b8cd4ea5aff04f129d5f66ebdc7b6f8eae802b93739ed520c47c79b"}, + {file = "cryptography-46.0.5-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:50bfb6925eff619c9c023b967d5b77a54e04256c4281b0e21336a130cd7fc263"}, + {file = "cryptography-46.0.5-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:803812e111e75d1aa73690d2facc295eaefd4439be1023fefc4995eaea2af90d"}, + {file = "cryptography-46.0.5-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3ee190460e2fbe447175cda91b88b84ae8322a104fc27766ad09428754a618ed"}, + {file = "cryptography-46.0.5-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:f145bba11b878005c496e93e257c1e88f154d278d2638e6450d17e0f31e558d2"}, + {file = "cryptography-46.0.5-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:e9251e3be159d1020c4030bd2e5f84d6a43fe54b6c19c12f51cde9542a2817b2"}, + {file = "cryptography-46.0.5-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:47fb8a66058b80e509c47118ef8a75d14c455e81ac369050f20ba0d23e77fee0"}, + {file = "cryptography-46.0.5-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:4c3341037c136030cb46e4b1e17b7418ea4cbd9dd207e4a6f3b2b24e0d4ac731"}, + {file = "cryptography-46.0.5-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:890bcb4abd5a2d3f852196437129eb3667d62630333aacc13dfd470fad3aaa82"}, + {file = "cryptography-46.0.5-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:80a8d7bfdf38f87ca30a5391c0c9ce4ed2926918e017c29ddf643d0ed2778ea1"}, + {file = "cryptography-46.0.5-cp311-abi3-win32.whl", hash = "sha256:60ee7e19e95104d4c03871d7d7dfb3d22ef8a9b9c6778c94e1c8fcc8365afd48"}, + {file = "cryptography-46.0.5-cp311-abi3-win_amd64.whl", hash = "sha256:38946c54b16c885c72c4f59846be9743d699eee2b69b6988e0a00a01f46a61a4"}, + {file = "cryptography-46.0.5-cp314-cp314t-macosx_10_9_universal2.whl", hash = "sha256:94a76daa32eb78d61339aff7952ea819b1734b46f73646a07decb40e5b3448e2"}, + {file = "cryptography-46.0.5-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5be7bf2fb40769e05739dd0046e7b26f9d4670badc7b032d6ce4db64dddc0678"}, + {file = "cryptography-46.0.5-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fe346b143ff9685e40192a4960938545c699054ba11d4f9029f94751e3f71d87"}, + {file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:c69fd885df7d089548a42d5ec05be26050ebcd2283d89b3d30676eb32ff87dee"}, + {file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:8293f3dea7fc929ef7240796ba231413afa7b68ce38fd21da2995549f5961981"}, + {file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:1abfdb89b41c3be0365328a410baa9df3ff8a9110fb75e7b52e66803ddabc9a9"}, + {file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:d66e421495fdb797610a08f43b05269e0a5ea7f5e652a89bfd5a7d3c1dee3648"}, + {file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:4e817a8920bfbcff8940ecfd60f23d01836408242b30f1a708d93198393a80b4"}, + {file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:68f68d13f2e1cb95163fa3b4db4bf9a159a418f5f6e7242564fc75fcae667fd0"}, + {file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:a3d1fae9863299076f05cb8a778c467578262fae09f9dc0ee9b12eb4268ce663"}, + {file = "cryptography-46.0.5-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:c4143987a42a2397f2fc3b4d7e3a7d313fbe684f67ff443999e803dd75a76826"}, + {file = "cryptography-46.0.5-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:7d731d4b107030987fd61a7f8ab512b25b53cef8f233a97379ede116f30eb67d"}, + {file = "cryptography-46.0.5-cp314-cp314t-win32.whl", hash = "sha256:c3bcce8521d785d510b2aad26ae2c966092b7daa8f45dd8f44734a104dc0bc1a"}, + {file = "cryptography-46.0.5-cp314-cp314t-win_amd64.whl", hash = "sha256:4d8ae8659ab18c65ced284993c2265910f6c9e650189d4e3f68445ef82a810e4"}, + {file = "cryptography-46.0.5-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:4108d4c09fbbf2789d0c926eb4152ae1760d5a2d97612b92d508d96c861e4d31"}, + {file = "cryptography-46.0.5-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7d1f30a86d2757199cb2d56e48cce14deddf1f9c95f1ef1b64ee91ea43fe2e18"}, + {file = "cryptography-46.0.5-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:039917b0dc418bb9f6edce8a906572d69e74bd330b0b3fea4f79dab7f8ddd235"}, + {file = "cryptography-46.0.5-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:ba2a27ff02f48193fc4daeadf8ad2590516fa3d0adeeb34336b96f7fa64c1e3a"}, + {file = "cryptography-46.0.5-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:61aa400dce22cb001a98014f647dc21cda08f7915ceb95df0c9eaf84b4b6af76"}, + {file = "cryptography-46.0.5-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3ce58ba46e1bc2aac4f7d9290223cead56743fa6ab94a5d53292ffaac6a91614"}, + {file = "cryptography-46.0.5-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:420d0e909050490d04359e7fdb5ed7e667ca5c3c402b809ae2563d7e66a92229"}, + {file = "cryptography-46.0.5-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:582f5fcd2afa31622f317f80426a027f30dc792e9c80ffee87b993200ea115f1"}, + {file = "cryptography-46.0.5-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:bfd56bb4b37ed4f330b82402f6f435845a5f5648edf1ad497da51a8452d5d62d"}, + {file = "cryptography-46.0.5-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:a3d507bb6a513ca96ba84443226af944b0f7f47dcc9a399d110cd6146481d24c"}, + {file = "cryptography-46.0.5-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9f16fbdf4da055efb21c22d81b89f155f02ba420558db21288b3d0035bafd5f4"}, + {file = "cryptography-46.0.5-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:ced80795227d70549a411a4ab66e8ce307899fad2220ce5ab2f296e687eacde9"}, + {file = "cryptography-46.0.5-cp38-abi3-win32.whl", hash = "sha256:02f547fce831f5096c9a567fd41bc12ca8f11df260959ecc7c3202555cc47a72"}, + {file = "cryptography-46.0.5-cp38-abi3-win_amd64.whl", hash = "sha256:556e106ee01aa13484ce9b0239bca667be5004efb0aabbed28d353df86445595"}, + {file = "cryptography-46.0.5-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:3b4995dc971c9fb83c25aa44cf45f02ba86f71ee600d81091c2f0cbae116b06c"}, + {file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:bc84e875994c3b445871ea7181d424588171efec3e185dced958dad9e001950a"}, + {file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:2ae6971afd6246710480e3f15824ed3029a60fc16991db250034efd0b9fb4356"}, + {file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:d861ee9e76ace6cf36a6a89b959ec08e7bc2493ee39d07ffe5acb23ef46d27da"}, + {file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:2b7a67c9cd56372f3249b39699f2ad479f6991e62ea15800973b956f4b73e257"}, + {file = "cryptography-46.0.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:8456928655f856c6e1533ff59d5be76578a7157224dbd9ce6872f25055ab9ab7"}, + {file = "cryptography-46.0.5.tar.gz", hash = "sha256:abace499247268e3757271b2f1e244b36b06f8515cf27c4d49468fc9eb16e93d"}, ] [package.dependencies] @@ -516,7 +516,7 @@ nox = ["nox[uv] (>=2024.4.15)"] pep8test = ["check-sdist", "click (>=8.0.1)", "mypy (>=1.14)", "ruff (>=0.11.11)"] sdist = ["build (>=1.0.0)"] ssh = ["bcrypt (>=3.1.5)"] -test = ["certifi (>=2024)", "cryptography-vectors (==46.0.4)", "pretend (>=0.7)", "pytest (>=7.4.0)", "pytest-benchmark (>=4.0)", "pytest-cov (>=2.10.1)", "pytest-xdist (>=3.5.0)"] +test = ["certifi (>=2024)", "cryptography-vectors (==46.0.5)", "pretend (>=0.7)", "pytest (>=7.4.0)", "pytest-benchmark (>=4.0)", "pytest-cov (>=2.10.1)", "pytest-xdist (>=3.5.0)"] test-randomorder = ["pytest-randomly"] [[package]] @@ -570,24 +570,25 @@ tests = ["coverage", "coveralls", "dill", "mock", "nose"] [[package]] name = "fastapi" -version = "0.128.0" +version = "0.128.7" description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" optional = false python-versions = ">=3.9" groups = ["main"] files = [ - {file = "fastapi-0.128.0-py3-none-any.whl", hash = "sha256:aebd93f9716ee3b4f4fcfe13ffb7cf308d99c9f3ab5622d8877441072561582d"}, - {file = "fastapi-0.128.0.tar.gz", hash = "sha256:1cc179e1cef10a6be60ffe429f79b829dce99d8de32d7acb7e6c8dfdf7f2645a"}, + {file = "fastapi-0.128.7-py3-none-any.whl", hash = "sha256:6bd9bd31cb7047465f2d3fa3ba3f33b0870b17d4eaf7cdb36d1576ab060ad662"}, + {file = "fastapi-0.128.7.tar.gz", hash = "sha256:783c273416995486c155ad2c0e2b45905dedfaf20b9ef8d9f6a9124670639a24"}, ] [package.dependencies] annotated-doc = ">=0.0.2" pydantic = ">=2.7.0" -starlette = ">=0.40.0,<0.51.0" +starlette = ">=0.40.0,<1.0.0" typing-extensions = ">=4.8.0" +typing-inspection = ">=0.4.2" [package.extras] -all = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=3.1.5)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] +all = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=3.1.5)", "orjson (>=3.9.3)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "pyyaml (>=5.3.1)", "ujson (>=5.8.0)", "uvicorn[standard] (>=0.12.0)"] standard = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "jinja2 (>=3.1.5)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "uvicorn[standard] (>=0.12.0)"] standard-no-fastapi-cloud-cli = ["email-validator (>=2.0.0)", "fastapi-cli[standard-no-fastapi-cloud-cli] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "jinja2 (>=3.1.5)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "uvicorn[standard] (>=0.12.0)"] @@ -1062,14 +1063,14 @@ urllib3 = ">=1.26.0,<3" [[package]] name = "launchdarkly-server-sdk" -version = "9.14.1" +version = "9.15.0" description = "LaunchDarkly SDK for Python" optional = false -python-versions = ">=3.9" +python-versions = ">=3.10" groups = ["main"] files = [ - {file = "launchdarkly_server_sdk-9.14.1-py3-none-any.whl", hash = "sha256:a9e2bd9ecdef845cd631ae0d4334a1115e5b44257c42eb2349492be4bac7815c"}, - {file = "launchdarkly_server_sdk-9.14.1.tar.gz", hash = "sha256:1df44baf0a0efa74d8c1dad7a00592b98bce7d19edded7f770da8dbc49922213"}, + {file = "launchdarkly_server_sdk-9.15.0-py3-none-any.whl", hash = "sha256:c267e29bfa3fb5e2a06a208448ada6ed5557a2924979b8d79c970b45d227c668"}, + {file = "launchdarkly_server_sdk-9.15.0.tar.gz", hash = "sha256:f31441b74bc1a69c381db57c33116509e407a2612628ad6dff0a7dbb39d5020b"}, ] [package.dependencies] @@ -1478,14 +1479,14 @@ testing = ["coverage", "pytest", "pytest-benchmark"] [[package]] name = "postgrest" -version = "2.27.2" +version = "2.28.0" description = "PostgREST client for Python. This library provides an ORM interface to PostgREST." optional = false python-versions = ">=3.9" groups = ["main"] files = [ - {file = "postgrest-2.27.2-py3-none-any.whl", hash = "sha256:1666fef3de05ca097a314433dd5ae2f2d71c613cb7b233d0f468c4ffe37277da"}, - {file = "postgrest-2.27.2.tar.gz", hash = "sha256:55407d530b5af3d64e883a71fec1f345d369958f723ce4a8ab0b7d169e313242"}, + {file = "postgrest-2.28.0-py3-none-any.whl", hash = "sha256:7bca2f24dd1a1bf8a3d586c7482aba6cd41662da6733045fad585b63b7f7df75"}, + {file = "postgrest-2.28.0.tar.gz", hash = "sha256:c36b38646d25ea4255321d3d924ce70f8d20ec7799cb42c1221d6a818d4f6515"}, ] [package.dependencies] @@ -2248,14 +2249,14 @@ cli = ["click (>=5.0)"] [[package]] name = "realtime" -version = "2.27.2" +version = "2.28.0" description = "" optional = false python-versions = ">=3.9" groups = ["main"] files = [ - {file = "realtime-2.27.2-py3-none-any.whl", hash = "sha256:34a9cbb26a274e707e8fc9e3ee0a66de944beac0fe604dc336d1e985db2c830f"}, - {file = "realtime-2.27.2.tar.gz", hash = "sha256:b960a90294d2cea1b3f1275ecb89204304728e08fff1c393cc1b3150739556b3"}, + {file = "realtime-2.28.0-py3-none-any.whl", hash = "sha256:db1bd59bab9b1fcc9f9d3b1a073bed35bf4994d720e6751f10031a58d57a3836"}, + {file = "realtime-2.28.0.tar.gz", hash = "sha256:d18cedcebd6a8f22fcd509bc767f639761eb218b7b2b6f14fc4205b6259b50fc"}, ] [package.dependencies] @@ -2436,14 +2437,14 @@ full = ["httpx (>=0.27.0,<0.29.0)", "itsdangerous", "jinja2", "python-multipart [[package]] name = "storage3" -version = "2.27.2" +version = "2.28.0" description = "Supabase Storage client for Python." optional = false python-versions = ">=3.9" groups = ["main"] files = [ - {file = "storage3-2.27.2-py3-none-any.whl", hash = "sha256:e6f16e7a260729e7b1f46e9bf61746805a02e30f5e419ee1291007c432e3ec63"}, - {file = "storage3-2.27.2.tar.gz", hash = "sha256:cb4807b7f86b4bb1272ac6fdd2f3cfd8ba577297046fa5f88557425200275af5"}, + {file = "storage3-2.28.0-py3-none-any.whl", hash = "sha256:ecb50efd2ac71dabbdf97e99ad346eafa630c4c627a8e5a138ceb5fbbadae716"}, + {file = "storage3-2.28.0.tar.gz", hash = "sha256:bc1d008aff67de7a0f2bd867baee7aadbcdb6f78f5a310b4f7a38e8c13c19865"}, ] [package.dependencies] @@ -2487,35 +2488,35 @@ python-dateutil = ">=2.6.0" [[package]] name = "supabase" -version = "2.27.2" +version = "2.28.0" description = "Supabase client for Python." optional = false python-versions = ">=3.9" groups = ["main"] files = [ - {file = "supabase-2.27.2-py3-none-any.whl", hash = "sha256:d4dce00b3a418ee578017ec577c0e5be47a9a636355009c76f20ed2faa15bc54"}, - {file = "supabase-2.27.2.tar.gz", hash = "sha256:2aed40e4f3454438822442a1e94a47be6694c2c70392e7ae99b51a226d4293f7"}, + {file = "supabase-2.28.0-py3-none-any.whl", hash = "sha256:42776971c7d0ccca16034df1ab96a31c50228eb1eb19da4249ad2f756fc20272"}, + {file = "supabase-2.28.0.tar.gz", hash = "sha256:aea299aaab2a2eed3c57e0be7fc035c6807214194cce795a3575add20268ece1"}, ] [package.dependencies] httpx = ">=0.26,<0.29" -postgrest = "2.27.2" -realtime = "2.27.2" -storage3 = "2.27.2" -supabase-auth = "2.27.2" -supabase-functions = "2.27.2" +postgrest = "2.28.0" +realtime = "2.28.0" +storage3 = "2.28.0" +supabase-auth = "2.28.0" +supabase-functions = "2.28.0" yarl = ">=1.22.0" [[package]] name = "supabase-auth" -version = "2.27.2" +version = "2.28.0" description = "Python Client Library for Supabase Auth" optional = false python-versions = ">=3.9" groups = ["main"] files = [ - {file = "supabase_auth-2.27.2-py3-none-any.whl", hash = "sha256:78ec25b11314d0a9527a7205f3b1c72560dccdc11b38392f80297ef98664ee91"}, - {file = "supabase_auth-2.27.2.tar.gz", hash = "sha256:0f5bcc79b3677cb42e9d321f3c559070cfa40d6a29a67672cc8382fb7dc2fe97"}, + {file = "supabase_auth-2.28.0-py3-none-any.whl", hash = "sha256:2ac85026cc285054c7fa6d41924f3a333e9ec298c013e5b5e1754039ba7caec9"}, + {file = "supabase_auth-2.28.0.tar.gz", hash = "sha256:2bb8f18ff39934e44b28f10918db965659f3735cd6fbfcc022fe0b82dbf8233e"}, ] [package.dependencies] @@ -2525,14 +2526,14 @@ pyjwt = {version = ">=2.10.1", extras = ["crypto"]} [[package]] name = "supabase-functions" -version = "2.27.2" +version = "2.28.0" description = "Library for Supabase Functions" optional = false python-versions = ">=3.9" groups = ["main"] files = [ - {file = "supabase_functions-2.27.2-py3-none-any.whl", hash = "sha256:db480efc669d0bca07605b9b6f167312af43121adcc842a111f79bea416ef754"}, - {file = "supabase_functions-2.27.2.tar.gz", hash = "sha256:d0c8266207a94371cb3fd35ad3c7f025b78a97cf026861e04ccd35ac1775f80b"}, + {file = "supabase_functions-2.28.0-py3-none-any.whl", hash = "sha256:30bf2d586f8df285faf0621bb5d5bb3ec3157234fc820553ca156f009475e4ae"}, + {file = "supabase_functions-2.28.0.tar.gz", hash = "sha256:db3dddfc37aca5858819eb461130968473bd8c75bd284581013958526dac718b"}, ] [package.dependencies] @@ -2911,4 +2912,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4.0" -content-hash = "40eae94995dc0a388fa832ed4af9b6137f28d5b5ced3aaea70d5f91d4d9a179d" +content-hash = "9619cae908ad38fa2c48016a58bcf4241f6f5793aa0e6cc140276e91c433cbbb" diff --git a/autogpt_platform/autogpt_libs/pyproject.toml b/autogpt_platform/autogpt_libs/pyproject.toml index 8deb4d2169..2cfa742922 100644 --- a/autogpt_platform/autogpt_libs/pyproject.toml +++ b/autogpt_platform/autogpt_libs/pyproject.toml @@ -11,14 +11,14 @@ python = ">=3.10,<4.0" colorama = "^0.4.6" cryptography = "^46.0" expiringdict = "^1.2.2" -fastapi = "^0.128.0" +fastapi = "^0.128.7" google-cloud-logging = "^3.13.0" -launchdarkly-server-sdk = "^9.14.1" +launchdarkly-server-sdk = "^9.15.0" pydantic = "^2.12.5" pydantic-settings = "^2.12.0" pyjwt = { version = "^2.11.0", extras = ["crypto"] } redis = "^6.2.0" -supabase = "^2.27.2" +supabase = "^2.28.0" uvicorn = "^0.40.0" [tool.poetry.group.dev.dependencies] diff --git a/autogpt_platform/backend/poetry.lock b/autogpt_platform/backend/poetry.lock index 53b5030da6..d71cca7865 100644 --- a/autogpt_platform/backend/poetry.lock +++ b/autogpt_platform/backend/poetry.lock @@ -441,14 +441,14 @@ develop = true colorama = "^0.4.6" cryptography = "^46.0" expiringdict = "^1.2.2" -fastapi = "^0.128.0" +fastapi = "^0.128.7" google-cloud-logging = "^3.13.0" -launchdarkly-server-sdk = "^9.14.1" +launchdarkly-server-sdk = "^9.15.0" pydantic = "^2.12.5" pydantic-settings = "^2.12.0" pyjwt = {version = "^2.11.0", extras = ["crypto"]} redis = "^6.2.0" -supabase = "^2.27.2" +supabase = "^2.28.0" uvicorn = "^0.40.0" [package.source] @@ -1382,14 +1382,14 @@ tzdata = "*" [[package]] name = "fastapi" -version = "0.128.6" +version = "0.128.7" description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" optional = false python-versions = ">=3.9" groups = ["main"] files = [ - {file = "fastapi-0.128.6-py3-none-any.whl", hash = "sha256:bb1c1ef87d6086a7132d0ab60869d6f1ee67283b20fbf84ec0003bd335099509"}, - {file = "fastapi-0.128.6.tar.gz", hash = "sha256:0cb3946557e792d731b26a42b04912f16367e3c3135ea8290f620e234f2b604f"}, + {file = "fastapi-0.128.7-py3-none-any.whl", hash = "sha256:6bd9bd31cb7047465f2d3fa3ba3f33b0870b17d4eaf7cdb36d1576ab060ad662"}, + {file = "fastapi-0.128.7.tar.gz", hash = "sha256:783c273416995486c155ad2c0e2b45905dedfaf20b9ef8d9f6a9124670639a24"}, ] [package.dependencies] @@ -3117,14 +3117,14 @@ urllib3 = ">=1.26.0,<3" [[package]] name = "launchdarkly-server-sdk" -version = "9.14.1" +version = "9.15.0" description = "LaunchDarkly SDK for Python" optional = false -python-versions = ">=3.9" +python-versions = ">=3.10" groups = ["main"] files = [ - {file = "launchdarkly_server_sdk-9.14.1-py3-none-any.whl", hash = "sha256:a9e2bd9ecdef845cd631ae0d4334a1115e5b44257c42eb2349492be4bac7815c"}, - {file = "launchdarkly_server_sdk-9.14.1.tar.gz", hash = "sha256:1df44baf0a0efa74d8c1dad7a00592b98bce7d19edded7f770da8dbc49922213"}, + {file = "launchdarkly_server_sdk-9.15.0-py3-none-any.whl", hash = "sha256:c267e29bfa3fb5e2a06a208448ada6ed5557a2924979b8d79c970b45d227c668"}, + {file = "launchdarkly_server_sdk-9.15.0.tar.gz", hash = "sha256:f31441b74bc1a69c381db57c33116509e407a2612628ad6dff0a7dbb39d5020b"}, ] [package.dependencies] @@ -4728,14 +4728,14 @@ tests = ["coverage-conditional-plugin (>=0.9.0)", "portalocker[redis]", "pytest [[package]] name = "postgrest" -version = "2.27.3" +version = "2.28.0" description = "PostgREST client for Python. This library provides an ORM interface to PostgREST." optional = false python-versions = ">=3.9" groups = ["main"] files = [ - {file = "postgrest-2.27.3-py3-none-any.whl", hash = "sha256:ed79123af7127edd78d538bfe8351d277e45b1a36994a4dbf57ae27dde87a7b7"}, - {file = "postgrest-2.27.3.tar.gz", hash = "sha256:c2e2679addfc8eaab23197bad7ddaee6cbb4cbe8c483ebd2d2e5219543037cc3"}, + {file = "postgrest-2.28.0-py3-none-any.whl", hash = "sha256:7bca2f24dd1a1bf8a3d586c7482aba6cd41662da6733045fad585b63b7f7df75"}, + {file = "postgrest-2.28.0.tar.gz", hash = "sha256:c36b38646d25ea4255321d3d924ce70f8d20ec7799cb42c1221d6a818d4f6515"}, ] [package.dependencies] @@ -6260,14 +6260,14 @@ all = ["numpy"] [[package]] name = "realtime" -version = "2.27.3" +version = "2.28.0" description = "" optional = false python-versions = ">=3.9" groups = ["main"] files = [ - {file = "realtime-2.27.3-py3-none-any.whl", hash = "sha256:f571115f86988e33c41c895cb3fba2eaa1b693aeaede3617288f44274ca90f43"}, - {file = "realtime-2.27.3.tar.gz", hash = "sha256:02b082243107656a5ef3fb63e8e2ab4c40bc199abb45adb8a42ed63f089a1041"}, + {file = "realtime-2.28.0-py3-none-any.whl", hash = "sha256:db1bd59bab9b1fcc9f9d3b1a073bed35bf4994d720e6751f10031a58d57a3836"}, + {file = "realtime-2.28.0.tar.gz", hash = "sha256:d18cedcebd6a8f22fcd509bc767f639761eb218b7b2b6f14fc4205b6259b50fc"}, ] [package.dependencies] @@ -7024,14 +7024,14 @@ full = ["httpx (>=0.27.0,<0.29.0)", "itsdangerous", "jinja2", "python-multipart [[package]] name = "storage3" -version = "2.27.3" +version = "2.28.0" description = "Supabase Storage client for Python." optional = false python-versions = ">=3.9" groups = ["main"] files = [ - {file = "storage3-2.27.3-py3-none-any.whl", hash = "sha256:11a05b7da84bccabeeea12d940bca3760cf63fe6ca441868677335cfe4fdfbe0"}, - {file = "storage3-2.27.3.tar.gz", hash = "sha256:dc1a4a010cf36d5482c5cb6c1c28fc5f00e23284342b89e4ae43b5eae8501ddb"}, + {file = "storage3-2.28.0-py3-none-any.whl", hash = "sha256:ecb50efd2ac71dabbdf97e99ad346eafa630c4c627a8e5a138ceb5fbbadae716"}, + {file = "storage3-2.28.0.tar.gz", hash = "sha256:bc1d008aff67de7a0f2bd867baee7aadbcdb6f78f5a310b4f7a38e8c13c19865"}, ] [package.dependencies] @@ -7091,35 +7091,35 @@ typing-extensions = {version = ">=4.5.0", markers = "python_version >= \"3.7\""} [[package]] name = "supabase" -version = "2.27.3" +version = "2.28.0" description = "Supabase client for Python." optional = false python-versions = ">=3.9" groups = ["main"] files = [ - {file = "supabase-2.27.3-py3-none-any.whl", hash = "sha256:082a74642fcf9954693f1ce8c251baf23e4bda26ffdbc8dcd4c99c82e60d69ff"}, - {file = "supabase-2.27.3.tar.gz", hash = "sha256:5e5a348232ac4315c1032ddd687278f0b982465471f0cbb52bca7e6a66495ff3"}, + {file = "supabase-2.28.0-py3-none-any.whl", hash = "sha256:42776971c7d0ccca16034df1ab96a31c50228eb1eb19da4249ad2f756fc20272"}, + {file = "supabase-2.28.0.tar.gz", hash = "sha256:aea299aaab2a2eed3c57e0be7fc035c6807214194cce795a3575add20268ece1"}, ] [package.dependencies] httpx = ">=0.26,<0.29" -postgrest = "2.27.3" -realtime = "2.27.3" -storage3 = "2.27.3" -supabase-auth = "2.27.3" -supabase-functions = "2.27.3" +postgrest = "2.28.0" +realtime = "2.28.0" +storage3 = "2.28.0" +supabase-auth = "2.28.0" +supabase-functions = "2.28.0" yarl = ">=1.22.0" [[package]] name = "supabase-auth" -version = "2.27.3" +version = "2.28.0" description = "Python Client Library for Supabase Auth" optional = false python-versions = ">=3.9" groups = ["main"] files = [ - {file = "supabase_auth-2.27.3-py3-none-any.whl", hash = "sha256:82a4262eaad85383319d394dab0eea11fcf3ebd774062aef8ea3874ae2f02579"}, - {file = "supabase_auth-2.27.3.tar.gz", hash = "sha256:39894d4bc60b6f23b5cff4d0d7d4c1659e5d69563cadf014d4896f780ca8ca78"}, + {file = "supabase_auth-2.28.0-py3-none-any.whl", hash = "sha256:2ac85026cc285054c7fa6d41924f3a333e9ec298c013e5b5e1754039ba7caec9"}, + {file = "supabase_auth-2.28.0.tar.gz", hash = "sha256:2bb8f18ff39934e44b28f10918db965659f3735cd6fbfcc022fe0b82dbf8233e"}, ] [package.dependencies] @@ -7129,14 +7129,14 @@ pyjwt = {version = ">=2.10.1", extras = ["crypto"]} [[package]] name = "supabase-functions" -version = "2.27.3" +version = "2.28.0" description = "Library for Supabase Functions" optional = false python-versions = ">=3.9" groups = ["main"] files = [ - {file = "supabase_functions-2.27.3-py3-none-any.whl", hash = "sha256:9d14a931d49ede1c6cf5fbfceb11c44061535ba1c3f310f15384964d86a83d9e"}, - {file = "supabase_functions-2.27.3.tar.gz", hash = "sha256:e954f1646da8ca6e7e16accef58d0884a5f97b25956ee98e7d4927a210ed92f9"}, + {file = "supabase_functions-2.28.0-py3-none-any.whl", hash = "sha256:30bf2d586f8df285faf0621bb5d5bb3ec3157234fc820553ca156f009475e4ae"}, + {file = "supabase_functions-2.28.0.tar.gz", hash = "sha256:db3dddfc37aca5858819eb461130968473bd8c75bd284581013958526dac718b"}, ] [package.dependencies] @@ -8440,4 +8440,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt [metadata] lock-version = "2.1" python-versions = ">=3.10,<3.14" -content-hash = "c06e96ad49388ba7a46786e9ea55ea2c1a57408e15613237b4bee40a592a12af" +content-hash = "fa9c5deadf593e815dd2190f58e22152373900603f5f244b9616cd721de84d2f" diff --git a/autogpt_platform/backend/pyproject.toml b/autogpt_platform/backend/pyproject.toml index 317663ee98..32dfc547bc 100644 --- a/autogpt_platform/backend/pyproject.toml +++ b/autogpt_platform/backend/pyproject.toml @@ -65,7 +65,7 @@ sentry-sdk = {extras = ["anthropic", "fastapi", "launchdarkly", "openai", "sqlal sqlalchemy = "^2.0.40" strenum = "^0.4.9" stripe = "^11.5.0" -supabase = "2.27.3" +supabase = "2.28.0" tenacity = "^9.1.4" todoist-api-python = "^2.1.7" tweepy = "^4.16.0" From ab0b537cc7d1484dd2777b0d56f397601aba3e76 Mon Sep 17 00:00:00 2001 From: Swifty Date: Fri, 13 Feb 2026 11:08:51 +0100 Subject: [PATCH 02/16] refactor(backend): optimize find_block response size by removing raw JSON schemas (#12020) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Changes 🏗️ The `find_block` AutoPilot tool was returning ~90K characters per response (10 blocks). The bloat came from including full JSON Schema objects (`input_schema`, `output_schema`) with all nested `$defs`, `anyOf`, and type definitions for every block. **What changed:** - **`BlockInfoSummary` model**: Removed `input_schema` (raw JSON Schema), `output_schema` (raw JSON Schema), and `categories`. Added `output_fields` (compact field-level summaries matching the existing `required_inputs` format). - **`BlockListResponse` model**: Removed `usage_hint` (info now in `message`). - **`FindBlockTool._execute()`**: Now extracts compact `output_fields` from output schema properties instead of including the entire raw schema. Credentials handling is unchanged. - **Test**: Added `test_response_size_average_chars_per_block` with realistic block schemas (HTTP, Email, Claude Code) to measure and assert response size stays under 2K chars/block. - **`CLAUDE.md`**: Clarified `dev` vs `master` branching strategy. **Result:** Average response size reduced from ~9,000 to ~1,300 chars per block (~85% reduction). This directly reduces LLM token consumption, latency, and API costs for AutoPilot interactions. ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: - [x] Verified models import and serialize correctly - [x] Verified response size: 3,970 chars for 3 realistic blocks (avg 1,323/block) - [x] Lint (`ruff check`) and type check (`pyright`) pass on changed files - [x] Frontend compatibility preserved: `blocks[].name` and `count` fields retained for `block_list` handler --------- Co-authored-by: Claude Opus 4.6 Co-authored-by: Toran Bruce Richards --- autogpt_platform/CLAUDE.md | 5 + .../backend/api/features/chat/routes.py | 2 + .../api/features/chat/tools/find_block.py | 63 +---- .../features/chat/tools/find_block_test.py | 255 ++++++++++++++++- .../backend/api/features/chat/tools/models.py | 31 ++- .../api/features/chat/tools/run_block.py | 94 ++++++- .../api/features/chat/tools/run_block_test.py | 262 +++++++++++++++++- .../chat/tools/test_run_block_details.py | 153 ++++++++++ .../copilot/tools/RunBlock/RunBlock.tsx | 7 + .../BlockDetailsCard.stories.tsx | 188 +++++++++++++ .../BlockDetailsCard/BlockDetailsCard.tsx | 103 +++++++ .../copilot/tools/RunBlock/helpers.tsx | 58 +++- .../frontend/src/app/api/openapi.json | 114 ++++---- 13 files changed, 1194 insertions(+), 141 deletions(-) create mode 100644 autogpt_platform/backend/backend/api/features/chat/tools/test_run_block_details.py create mode 100644 autogpt_platform/frontend/src/app/(platform)/copilot/tools/RunBlock/components/BlockDetailsCard/BlockDetailsCard.stories.tsx create mode 100644 autogpt_platform/frontend/src/app/(platform)/copilot/tools/RunBlock/components/BlockDetailsCard/BlockDetailsCard.tsx diff --git a/autogpt_platform/CLAUDE.md b/autogpt_platform/CLAUDE.md index 62adbdaefa..021b7c27e4 100644 --- a/autogpt_platform/CLAUDE.md +++ b/autogpt_platform/CLAUDE.md @@ -45,6 +45,11 @@ AutoGPT Platform is a monorepo containing: - Backend/Frontend services use YAML anchors for consistent configuration - Supabase services (`db/docker/docker-compose.yml`) follow the same pattern +### Branching Strategy + +- **`dev`** is the main development branch. All PRs should target `dev`. +- **`master`** is the production branch. Only used for production releases. + ### Creating Pull Requests - Create the PR against the `dev` branch of the repository. diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py index c6f37569b7..0d8b12b0b7 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes.py @@ -24,6 +24,7 @@ from .tools.models import ( AgentPreviewResponse, AgentSavedResponse, AgentsFoundResponse, + BlockDetailsResponse, BlockListResponse, BlockOutputResponse, ClarificationNeededResponse, @@ -971,6 +972,7 @@ ToolResponseUnion = ( | AgentSavedResponse | ClarificationNeededResponse | BlockListResponse + | BlockDetailsResponse | BlockOutputResponse | DocSearchResultsResponse | DocPageResponse diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/find_block.py b/autogpt_platform/backend/backend/api/features/chat/tools/find_block.py index 6a8cfa9bbc..55b1c0d510 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/find_block.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/find_block.py @@ -7,7 +7,6 @@ from backend.api.features.chat.model import ChatSession from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase from backend.api.features.chat.tools.models import ( BlockInfoSummary, - BlockInputFieldInfo, BlockListResponse, ErrorResponse, NoResultsResponse, @@ -55,7 +54,8 @@ class FindBlockTool(BaseTool): "Blocks are reusable components that perform specific tasks like " "sending emails, making API calls, processing text, etc. " "IMPORTANT: Use this tool FIRST to get the block's 'id' before calling run_block. " - "The response includes each block's id, required_inputs, and input_schema." + "The response includes each block's id, name, and description. " + "Call run_block with the block's id **with no inputs** to see detailed inputs/outputs and execute it." ) @property @@ -124,7 +124,7 @@ class FindBlockTool(BaseTool): session_id=session_id, ) - # Enrich results with full block information + # Enrich results with block information blocks: list[BlockInfoSummary] = [] for result in results: block_id = result["content_id"] @@ -141,65 +141,11 @@ class FindBlockTool(BaseTool): ): continue - # Get input/output schemas - input_schema = {} - output_schema = {} - try: - input_schema = block.input_schema.jsonschema() - except Exception as e: - logger.debug( - "Failed to generate input schema for block %s: %s", - block_id, - e, - ) - try: - output_schema = block.output_schema.jsonschema() - except Exception as e: - logger.debug( - "Failed to generate output schema for block %s: %s", - block_id, - e, - ) - - # Get categories from block instance - categories = [] - if hasattr(block, "categories") and block.categories: - categories = [cat.value for cat in block.categories] - - # Extract required inputs for easier use - required_inputs: list[BlockInputFieldInfo] = [] - if input_schema: - properties = input_schema.get("properties", {}) - required_fields = set(input_schema.get("required", [])) - # Get credential field names to exclude from required inputs - credentials_fields = set( - block.input_schema.get_credentials_fields().keys() - ) - - for field_name, field_schema in properties.items(): - # Skip credential fields - they're handled separately - if field_name in credentials_fields: - continue - - required_inputs.append( - BlockInputFieldInfo( - name=field_name, - type=field_schema.get("type", "string"), - description=field_schema.get("description", ""), - required=field_name in required_fields, - default=field_schema.get("default"), - ) - ) - blocks.append( BlockInfoSummary( id=block_id, name=block.name, description=block.description or "", - categories=categories, - input_schema=input_schema, - output_schema=output_schema, - required_inputs=required_inputs, ) ) @@ -228,8 +174,7 @@ class FindBlockTool(BaseTool): return BlockListResponse( message=( f"Found {len(blocks)} block(s) matching '{query}'. " - "To execute a block, use run_block with the block's 'id' field " - "and provide 'input_data' matching the block's input_schema." + "To see a block's inputs/outputs and execute it, use run_block with the block's 'id' - providing no inputs." ), blocks=blocks, count=len(blocks), diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/find_block_test.py b/autogpt_platform/backend/backend/api/features/chat/tools/find_block_test.py index d567a89bbe..44606f81c3 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/find_block_test.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/find_block_test.py @@ -18,7 +18,13 @@ _TEST_USER_ID = "test-user-find-block" def make_mock_block( - block_id: str, name: str, block_type: BlockType, disabled: bool = False + block_id: str, + name: str, + block_type: BlockType, + disabled: bool = False, + input_schema: dict | None = None, + output_schema: dict | None = None, + credentials_fields: dict | None = None, ): """Create a mock block for testing.""" mock = MagicMock() @@ -28,10 +34,13 @@ def make_mock_block( mock.block_type = block_type mock.disabled = disabled mock.input_schema = MagicMock() - mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []} - mock.input_schema.get_credentials_fields.return_value = {} + mock.input_schema.jsonschema.return_value = input_schema or { + "properties": {}, + "required": [], + } + mock.input_schema.get_credentials_fields.return_value = credentials_fields or {} mock.output_schema = MagicMock() - mock.output_schema.jsonschema.return_value = {} + mock.output_schema.jsonschema.return_value = output_schema or {} mock.categories = [] return mock @@ -137,3 +146,241 @@ class TestFindBlockFiltering: assert isinstance(response, BlockListResponse) assert len(response.blocks) == 1 assert response.blocks[0].id == "normal-block-id" + + @pytest.mark.asyncio(loop_scope="session") + async def test_response_size_average_chars_per_block(self): + """Measure average chars per block in the serialized response.""" + session = make_session(user_id=_TEST_USER_ID) + + # Realistic block definitions modeled after real blocks + block_defs = [ + { + "id": "http-block-id", + "name": "Send Web Request", + "input_schema": { + "properties": { + "url": { + "type": "string", + "description": "The URL to send the request to", + }, + "method": { + "type": "string", + "description": "The HTTP method to use", + }, + "headers": { + "type": "object", + "description": "Headers to include in the request", + }, + "json_format": { + "type": "boolean", + "description": "If true, send the body as JSON", + }, + "body": { + "type": "object", + "description": "Form/JSON body payload", + }, + "credentials": { + "type": "object", + "description": "HTTP credentials", + }, + }, + "required": ["url", "method"], + }, + "output_schema": { + "properties": { + "response": { + "type": "object", + "description": "The response from the server", + }, + "client_error": { + "type": "object", + "description": "Errors on 4xx status codes", + }, + "server_error": { + "type": "object", + "description": "Errors on 5xx status codes", + }, + "error": { + "type": "string", + "description": "Errors for all other exceptions", + }, + }, + }, + "credentials_fields": {"credentials": True}, + }, + { + "id": "email-block-id", + "name": "Send Email", + "input_schema": { + "properties": { + "to_email": { + "type": "string", + "description": "Recipient email address", + }, + "subject": { + "type": "string", + "description": "Subject of the email", + }, + "body": { + "type": "string", + "description": "Body of the email", + }, + "config": { + "type": "object", + "description": "SMTP Config", + }, + "credentials": { + "type": "object", + "description": "SMTP credentials", + }, + }, + "required": ["to_email", "subject", "body", "credentials"], + }, + "output_schema": { + "properties": { + "status": { + "type": "string", + "description": "Status of the email sending operation", + }, + "error": { + "type": "string", + "description": "Error message if sending failed", + }, + }, + }, + "credentials_fields": {"credentials": True}, + }, + { + "id": "claude-code-block-id", + "name": "Claude Code", + "input_schema": { + "properties": { + "e2b_credentials": { + "type": "object", + "description": "API key for E2B platform", + }, + "anthropic_credentials": { + "type": "object", + "description": "API key for Anthropic", + }, + "prompt": { + "type": "string", + "description": "Task or instruction for Claude Code", + }, + "timeout": { + "type": "integer", + "description": "Sandbox timeout in seconds", + }, + "setup_commands": { + "type": "array", + "description": "Shell commands to run before execution", + }, + "working_directory": { + "type": "string", + "description": "Working directory for Claude Code", + }, + "session_id": { + "type": "string", + "description": "Session ID to resume a conversation", + }, + "sandbox_id": { + "type": "string", + "description": "Sandbox ID to reconnect to", + }, + "conversation_history": { + "type": "string", + "description": "Previous conversation history", + }, + "dispose_sandbox": { + "type": "boolean", + "description": "Whether to dispose sandbox after execution", + }, + }, + "required": [ + "e2b_credentials", + "anthropic_credentials", + "prompt", + ], + }, + "output_schema": { + "properties": { + "response": { + "type": "string", + "description": "Output from Claude Code execution", + }, + "files": { + "type": "array", + "description": "Files created/modified by Claude Code", + }, + "conversation_history": { + "type": "string", + "description": "Full conversation history", + }, + "session_id": { + "type": "string", + "description": "Session ID for this conversation", + }, + "sandbox_id": { + "type": "string", + "description": "ID of the sandbox instance", + }, + "error": { + "type": "string", + "description": "Error message if execution failed", + }, + }, + }, + "credentials_fields": { + "e2b_credentials": True, + "anthropic_credentials": True, + }, + }, + ] + + search_results = [ + {"content_id": d["id"], "score": 0.9 - i * 0.1} + for i, d in enumerate(block_defs) + ] + mock_blocks = { + d["id"]: make_mock_block( + block_id=d["id"], + name=d["name"], + block_type=BlockType.STANDARD, + input_schema=d["input_schema"], + output_schema=d["output_schema"], + credentials_fields=d["credentials_fields"], + ) + for d in block_defs + } + + with patch( + "backend.api.features.chat.tools.find_block.unified_hybrid_search", + new_callable=AsyncMock, + return_value=(search_results, len(search_results)), + ), patch( + "backend.api.features.chat.tools.find_block.get_block", + side_effect=lambda bid: mock_blocks.get(bid), + ): + tool = FindBlockTool() + response = await tool._execute( + user_id=_TEST_USER_ID, session=session, query="test" + ) + + assert isinstance(response, BlockListResponse) + assert response.count == len(block_defs) + + total_chars = len(response.model_dump_json()) + avg_chars = total_chars // response.count + + # Print for visibility in test output + print(f"\nTotal response size: {total_chars} chars") + print(f"Number of blocks: {response.count}") + print(f"Average chars per block: {avg_chars}") + + # The old response was ~90K for 10 blocks (~9K per block). + # Previous optimization reduced it to ~1.5K per block (no raw JSON schemas). + # Now with only id/name/description, we expect ~300 chars per block. + assert avg_chars < 500, ( + f"Average chars per block ({avg_chars}) exceeds 500. " + f"Total response: {total_chars} chars for {response.count} blocks." + ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/models.py b/autogpt_platform/backend/backend/api/features/chat/tools/models.py index 69c8c6c684..bd19d590a6 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/models.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/models.py @@ -25,6 +25,7 @@ class ResponseType(str, Enum): AGENT_SAVED = "agent_saved" CLARIFICATION_NEEDED = "clarification_needed" BLOCK_LIST = "block_list" + BLOCK_DETAILS = "block_details" BLOCK_OUTPUT = "block_output" DOC_SEARCH_RESULTS = "doc_search_results" DOC_PAGE = "doc_page" @@ -334,13 +335,6 @@ class BlockInfoSummary(BaseModel): id: str name: str description: str - categories: list[str] - input_schema: dict[str, Any] - output_schema: dict[str, Any] - required_inputs: list[BlockInputFieldInfo] = Field( - default_factory=list, - description="List of required input fields for this block", - ) class BlockListResponse(ToolResponseBase): @@ -350,10 +344,25 @@ class BlockListResponse(ToolResponseBase): blocks: list[BlockInfoSummary] count: int query: str - usage_hint: str = Field( - default="To execute a block, call run_block with block_id set to the block's " - "'id' field and input_data containing the required fields from input_schema." - ) + + +class BlockDetails(BaseModel): + """Detailed block information.""" + + id: str + name: str + description: str + inputs: dict[str, Any] = {} + outputs: dict[str, Any] = {} + credentials: list[CredentialsMetaInput] = [] + + +class BlockDetailsResponse(ToolResponseBase): + """Response for block details (first run_block attempt).""" + + type: ResponseType = ResponseType.BLOCK_DETAILS + block: BlockDetails + user_authenticated: bool = False class BlockOutputResponse(ToolResponseBase): diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py b/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py index 8c29820f8e..a55478326a 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py @@ -23,8 +23,11 @@ from backend.util.exceptions import BlockError from .base import BaseTool from .helpers import get_inputs_from_schema from .models import ( + BlockDetails, + BlockDetailsResponse, BlockOutputResponse, ErrorResponse, + InputValidationErrorResponse, SetupInfo, SetupRequirementsResponse, ToolResponseBase, @@ -51,8 +54,8 @@ class RunBlockTool(BaseTool): "Execute a specific block with the provided input data. " "IMPORTANT: You MUST call find_block first to get the block's 'id' - " "do NOT guess or make up block IDs. " - "Use the 'id' from find_block results and provide input_data " - "matching the block's required_inputs." + "On first attempt (without input_data), returns detailed schema showing " + "required inputs and outputs. Then call again with proper input_data to execute." ) @property @@ -67,11 +70,19 @@ class RunBlockTool(BaseTool): "NEVER guess this - always get it from find_block first." ), }, + "block_name": { + "type": "string", + "description": ( + "The block's human-readable name from find_block results. " + "Used for display purposes in the UI." + ), + }, "input_data": { "type": "object", "description": ( - "Input values for the block. Use the 'required_inputs' field " - "from find_block to see what fields are needed." + "Input values for the block. " + "First call with empty {} to see the block's schema, " + "then call again with proper values to execute." ), }, }, @@ -156,6 +167,34 @@ class RunBlockTool(BaseTool): await self._resolve_block_credentials(user_id, block, input_data) ) + # Get block schemas for details/validation + try: + input_schema: dict[str, Any] = block.input_schema.jsonschema() + except Exception as e: + logger.warning( + "Failed to generate input schema for block %s: %s", + block_id, + e, + ) + return ErrorResponse( + message=f"Block '{block.name}' has an invalid input schema", + error=str(e), + session_id=session_id, + ) + try: + output_schema: dict[str, Any] = block.output_schema.jsonschema() + except Exception as e: + logger.warning( + "Failed to generate output schema for block %s: %s", + block_id, + e, + ) + return ErrorResponse( + message=f"Block '{block.name}' has an invalid output schema", + error=str(e), + session_id=session_id, + ) + if missing_credentials: # Return setup requirements response with missing credentials credentials_fields_info = block.input_schema.get_credentials_fields_info() @@ -188,6 +227,53 @@ class RunBlockTool(BaseTool): graph_version=None, ) + # Check if this is a first attempt (required inputs missing) + # Return block details so user can see what inputs are needed + credentials_fields = set(block.input_schema.get_credentials_fields().keys()) + required_keys = set(input_schema.get("required", [])) + required_non_credential_keys = required_keys - credentials_fields + provided_input_keys = set(input_data.keys()) - credentials_fields + + # Check for unknown input fields + valid_fields = ( + set(input_schema.get("properties", {}).keys()) - credentials_fields + ) + unrecognized_fields = provided_input_keys - valid_fields + if unrecognized_fields: + return InputValidationErrorResponse( + message=( + f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. " + f"Block was not executed. Please use the correct field names from the schema." + ), + session_id=session_id, + unrecognized_fields=sorted(unrecognized_fields), + inputs=input_schema, + ) + + # Show details when not all required non-credential inputs are provided + if not (required_non_credential_keys <= provided_input_keys): + # Get credentials info for the response + credentials_meta = [] + for field_name, cred_meta in matched_credentials.items(): + credentials_meta.append(cred_meta) + + return BlockDetailsResponse( + message=( + f"Block '{block.name}' details. " + "Provide input_data matching the inputs schema to execute the block." + ), + session_id=session_id, + block=BlockDetails( + id=block_id, + name=block.name, + description=block.description or "", + inputs=input_schema, + outputs=output_schema, + credentials=credentials_meta, + ), + user_authenticated=True, + ) + try: # Get or create user's workspace for CoPilot file operations workspace = await get_or_create_workspace(user_id) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_block_test.py b/autogpt_platform/backend/backend/api/features/chat/tools/run_block_test.py index aadc161155..55efc38479 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/run_block_test.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/run_block_test.py @@ -1,10 +1,15 @@ -"""Tests for block execution guards in RunBlockTool.""" +"""Tests for block execution guards and input validation in RunBlockTool.""" -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest -from backend.api.features.chat.tools.models import ErrorResponse +from backend.api.features.chat.tools.models import ( + BlockDetailsResponse, + BlockOutputResponse, + ErrorResponse, + InputValidationErrorResponse, +) from backend.api.features.chat.tools.run_block import RunBlockTool from backend.blocks._base import BlockType @@ -28,6 +33,39 @@ def make_mock_block( return mock +def make_mock_block_with_schema( + block_id: str, + name: str, + input_properties: dict, + required_fields: list[str], + output_properties: dict | None = None, +): + """Create a mock block with a defined input/output schema for validation tests.""" + mock = MagicMock() + mock.id = block_id + mock.name = name + mock.block_type = BlockType.STANDARD + mock.disabled = False + mock.description = f"Test block: {name}" + + input_schema = { + "properties": input_properties, + "required": required_fields, + } + mock.input_schema = MagicMock() + mock.input_schema.jsonschema.return_value = input_schema + mock.input_schema.get_credentials_fields_info.return_value = {} + mock.input_schema.get_credentials_fields.return_value = {} + + output_schema = { + "properties": output_properties or {"result": {"type": "string"}}, + } + mock.output_schema = MagicMock() + mock.output_schema.jsonschema.return_value = output_schema + + return mock + + class TestRunBlockFiltering: """Tests for block execution guards in RunBlockTool.""" @@ -104,3 +142,221 @@ class TestRunBlockFiltering: # (may be other errors like missing credentials, but not the exclusion guard) if isinstance(response, ErrorResponse): assert "cannot be run directly in CoPilot" not in response.message + + +class TestRunBlockInputValidation: + """Tests for input field validation in RunBlockTool. + + run_block rejects unknown input field names with InputValidationErrorResponse, + preventing silent failures where incorrect keys would be ignored and the block + would execute with default values instead of the caller's intended values. + """ + + @pytest.mark.asyncio(loop_scope="session") + async def test_unknown_input_fields_are_rejected(self): + """run_block rejects unknown input fields instead of silently ignoring them. + + Scenario: The AI Text Generator block has a field called 'model' (for LLM model + selection), but the LLM calling the tool guesses wrong and sends 'LLM_Model' + instead. The block should reject the request and return the valid schema. + """ + session = make_session(user_id=_TEST_USER_ID) + + mock_block = make_mock_block_with_schema( + block_id="ai-text-gen-id", + name="AI Text Generator", + input_properties={ + "prompt": {"type": "string", "description": "The prompt to send"}, + "model": { + "type": "string", + "description": "The LLM model to use", + "default": "gpt-4o-mini", + }, + "sys_prompt": { + "type": "string", + "description": "System prompt", + "default": "", + }, + }, + required_fields=["prompt"], + output_properties={"response": {"type": "string"}}, + ) + + with patch( + "backend.api.features.chat.tools.run_block.get_block", + return_value=mock_block, + ): + tool = RunBlockTool() + + # Provide 'prompt' (correct) but 'LLM_Model' instead of 'model' (wrong key) + response = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + block_id="ai-text-gen-id", + input_data={ + "prompt": "Write a haiku about coding", + "LLM_Model": "claude-opus-4-6", # WRONG KEY - should be 'model' + }, + ) + + assert isinstance(response, InputValidationErrorResponse) + assert "LLM_Model" in response.unrecognized_fields + assert "Block was not executed" in response.message + assert "inputs" in response.model_dump() # valid schema included + + @pytest.mark.asyncio(loop_scope="session") + async def test_multiple_wrong_keys_are_all_reported(self): + """All unrecognized field names are reported in a single error response.""" + session = make_session(user_id=_TEST_USER_ID) + + mock_block = make_mock_block_with_schema( + block_id="ai-text-gen-id", + name="AI Text Generator", + input_properties={ + "prompt": {"type": "string"}, + "model": {"type": "string", "default": "gpt-4o-mini"}, + "sys_prompt": {"type": "string", "default": ""}, + "retry": {"type": "integer", "default": 3}, + }, + required_fields=["prompt"], + ) + + with patch( + "backend.api.features.chat.tools.run_block.get_block", + return_value=mock_block, + ): + tool = RunBlockTool() + + response = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + block_id="ai-text-gen-id", + input_data={ + "prompt": "Hello", # correct + "llm_model": "claude-opus-4-6", # WRONG - should be 'model' + "system_prompt": "Be helpful", # WRONG - should be 'sys_prompt' + "retries": 5, # WRONG - should be 'retry' + }, + ) + + assert isinstance(response, InputValidationErrorResponse) + assert set(response.unrecognized_fields) == { + "llm_model", + "system_prompt", + "retries", + } + assert "Block was not executed" in response.message + + @pytest.mark.asyncio(loop_scope="session") + async def test_unknown_fields_rejected_even_with_missing_required(self): + """Unknown fields are caught before the missing-required-fields check.""" + session = make_session(user_id=_TEST_USER_ID) + + mock_block = make_mock_block_with_schema( + block_id="ai-text-gen-id", + name="AI Text Generator", + input_properties={ + "prompt": {"type": "string"}, + "model": {"type": "string", "default": "gpt-4o-mini"}, + }, + required_fields=["prompt"], + ) + + with patch( + "backend.api.features.chat.tools.run_block.get_block", + return_value=mock_block, + ): + tool = RunBlockTool() + + # 'prompt' is missing AND 'LLM_Model' is an unknown field + response = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + block_id="ai-text-gen-id", + input_data={ + "LLM_Model": "claude-opus-4-6", # wrong key, and 'prompt' is missing + }, + ) + + # Unknown fields are caught first + assert isinstance(response, InputValidationErrorResponse) + assert "LLM_Model" in response.unrecognized_fields + + @pytest.mark.asyncio(loop_scope="session") + async def test_correct_inputs_still_execute(self): + """Correct input field names pass validation and the block executes.""" + session = make_session(user_id=_TEST_USER_ID) + + mock_block = make_mock_block_with_schema( + block_id="ai-text-gen-id", + name="AI Text Generator", + input_properties={ + "prompt": {"type": "string"}, + "model": {"type": "string", "default": "gpt-4o-mini"}, + }, + required_fields=["prompt"], + ) + + async def mock_execute(input_data, **kwargs): + yield "response", "Generated text" + + mock_block.execute = mock_execute + + with ( + patch( + "backend.api.features.chat.tools.run_block.get_block", + return_value=mock_block, + ), + patch( + "backend.api.features.chat.tools.run_block.get_or_create_workspace", + new_callable=AsyncMock, + return_value=MagicMock(id="test-workspace-id"), + ), + ): + tool = RunBlockTool() + + response = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + block_id="ai-text-gen-id", + input_data={ + "prompt": "Write a haiku", + "model": "gpt-4o-mini", # correct field name + }, + ) + + assert isinstance(response, BlockOutputResponse) + assert response.success is True + + @pytest.mark.asyncio(loop_scope="session") + async def test_missing_required_fields_returns_details(self): + """Missing required fields returns BlockDetailsResponse with schema.""" + session = make_session(user_id=_TEST_USER_ID) + + mock_block = make_mock_block_with_schema( + block_id="ai-text-gen-id", + name="AI Text Generator", + input_properties={ + "prompt": {"type": "string"}, + "model": {"type": "string", "default": "gpt-4o-mini"}, + }, + required_fields=["prompt"], + ) + + with patch( + "backend.api.features.chat.tools.run_block.get_block", + return_value=mock_block, + ): + tool = RunBlockTool() + + # Only provide valid optional field, missing required 'prompt' + response = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + block_id="ai-text-gen-id", + input_data={ + "model": "gpt-4o-mini", # valid but optional + }, + ) + + assert isinstance(response, BlockDetailsResponse) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/test_run_block_details.py b/autogpt_platform/backend/backend/api/features/chat/tools/test_run_block_details.py new file mode 100644 index 0000000000..fbab0b723d --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/test_run_block_details.py @@ -0,0 +1,153 @@ +"""Tests for BlockDetailsResponse in RunBlockTool.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from backend.api.features.chat.tools.models import BlockDetailsResponse +from backend.api.features.chat.tools.run_block import RunBlockTool +from backend.blocks._base import BlockType +from backend.data.model import CredentialsMetaInput +from backend.integrations.providers import ProviderName + +from ._test_data import make_session + +_TEST_USER_ID = "test-user-run-block-details" + + +def make_mock_block_with_inputs( + block_id: str, name: str, description: str = "Test description" +): + """Create a mock block with input/output schemas for testing.""" + mock = MagicMock() + mock.id = block_id + mock.name = name + mock.description = description + mock.block_type = BlockType.STANDARD + mock.disabled = False + + # Input schema with non-credential fields + mock.input_schema = MagicMock() + mock.input_schema.jsonschema.return_value = { + "properties": { + "url": {"type": "string", "description": "URL to fetch"}, + "method": {"type": "string", "description": "HTTP method"}, + }, + "required": ["url"], + } + mock.input_schema.get_credentials_fields.return_value = {} + mock.input_schema.get_credentials_fields_info.return_value = {} + + # Output schema + mock.output_schema = MagicMock() + mock.output_schema.jsonschema.return_value = { + "properties": { + "response": {"type": "object", "description": "HTTP response"}, + "error": {"type": "string", "description": "Error message"}, + } + } + + return mock + + +@pytest.mark.asyncio(loop_scope="session") +async def test_run_block_returns_details_when_no_input_provided(): + """When run_block is called without input_data, it should return BlockDetailsResponse.""" + session = make_session(user_id=_TEST_USER_ID) + + # Create a block with inputs + http_block = make_mock_block_with_inputs( + "http-block-id", "HTTP Request", "Send HTTP requests" + ) + + with patch( + "backend.api.features.chat.tools.run_block.get_block", + return_value=http_block, + ): + # Mock credentials check to return no missing credentials + with patch.object( + RunBlockTool, + "_resolve_block_credentials", + new_callable=AsyncMock, + return_value=({}, []), # (matched_credentials, missing_credentials) + ): + tool = RunBlockTool() + response = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + block_id="http-block-id", + input_data={}, # Empty input data + ) + + # Should return BlockDetailsResponse showing the schema + assert isinstance(response, BlockDetailsResponse) + assert response.block.id == "http-block-id" + assert response.block.name == "HTTP Request" + assert response.block.description == "Send HTTP requests" + assert "url" in response.block.inputs["properties"] + assert "method" in response.block.inputs["properties"] + assert "response" in response.block.outputs["properties"] + assert response.user_authenticated is True + + +@pytest.mark.asyncio(loop_scope="session") +async def test_run_block_returns_details_when_only_credentials_provided(): + """When only credentials are provided (no actual input), should return details.""" + session = make_session(user_id=_TEST_USER_ID) + + # Create a block with both credential and non-credential inputs + mock = MagicMock() + mock.id = "api-block-id" + mock.name = "API Call" + mock.description = "Make API calls" + mock.block_type = BlockType.STANDARD + mock.disabled = False + + mock.input_schema = MagicMock() + mock.input_schema.jsonschema.return_value = { + "properties": { + "credentials": {"type": "object", "description": "API credentials"}, + "endpoint": {"type": "string", "description": "API endpoint"}, + }, + "required": ["credentials", "endpoint"], + } + mock.input_schema.get_credentials_fields.return_value = {"credentials": True} + mock.input_schema.get_credentials_fields_info.return_value = {} + + mock.output_schema = MagicMock() + mock.output_schema.jsonschema.return_value = { + "properties": {"result": {"type": "object"}} + } + + with patch( + "backend.api.features.chat.tools.run_block.get_block", + return_value=mock, + ): + with patch.object( + RunBlockTool, + "_resolve_block_credentials", + new_callable=AsyncMock, + return_value=( + { + "credentials": CredentialsMetaInput( + id="cred-id", + provider=ProviderName("test_provider"), + type="api_key", + title="Test Credential", + ) + }, + [], + ), + ): + tool = RunBlockTool() + response = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + block_id="api-block-id", + input_data={"credentials": {"some": "cred"}}, # Only credential + ) + + # Should return details because no non-credential inputs provided + assert isinstance(response, BlockDetailsResponse) + assert response.block.id == "api-block-id" + assert response.block.name == "API Call" diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/RunBlock/RunBlock.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/RunBlock/RunBlock.tsx index e1cb030449..6e2cbe90d7 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/RunBlock/RunBlock.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/RunBlock/RunBlock.tsx @@ -3,6 +3,7 @@ import type { ToolUIPart } from "ai"; import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation"; import { ToolAccordion } from "../../components/ToolAccordion/ToolAccordion"; +import { BlockDetailsCard } from "./components/BlockDetailsCard/BlockDetailsCard"; import { BlockOutputCard } from "./components/BlockOutputCard/BlockOutputCard"; import { ErrorCard } from "./components/ErrorCard/ErrorCard"; import { SetupRequirementsCard } from "./components/SetupRequirementsCard/SetupRequirementsCard"; @@ -11,6 +12,7 @@ import { getAnimationText, getRunBlockToolOutput, isRunBlockBlockOutput, + isRunBlockDetailsOutput, isRunBlockErrorOutput, isRunBlockSetupRequirementsOutput, ToolIcon, @@ -41,6 +43,7 @@ export function RunBlockTool({ part }: Props) { part.state === "output-available" && !!output && (isRunBlockBlockOutput(output) || + isRunBlockDetailsOutput(output) || isRunBlockSetupRequirementsOutput(output) || isRunBlockErrorOutput(output)); @@ -58,6 +61,10 @@ export function RunBlockTool({ part }: Props) { {isRunBlockBlockOutput(output) && } + {isRunBlockDetailsOutput(output) && ( + + )} + {isRunBlockSetupRequirementsOutput(output) && ( )} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/RunBlock/components/BlockDetailsCard/BlockDetailsCard.stories.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/RunBlock/components/BlockDetailsCard/BlockDetailsCard.stories.tsx new file mode 100644 index 0000000000..6e133ca93b --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/RunBlock/components/BlockDetailsCard/BlockDetailsCard.stories.tsx @@ -0,0 +1,188 @@ +import type { Meta, StoryObj } from "@storybook/nextjs"; +import { ResponseType } from "@/app/api/__generated__/models/responseType"; +import type { BlockDetailsResponse } from "../../helpers"; +import { BlockDetailsCard } from "./BlockDetailsCard"; + +const meta: Meta = { + title: "Copilot/RunBlock/BlockDetailsCard", + component: BlockDetailsCard, + parameters: { + layout: "centered", + }, + tags: ["autodocs"], + decorators: [ + (Story) => ( +
+ +
+ ), + ], +}; + +export default meta; +type Story = StoryObj; + +const baseBlock: BlockDetailsResponse = { + type: ResponseType.block_details, + message: + "Here are the details for the GetWeather block. Provide the required inputs to run it.", + session_id: "session-123", + user_authenticated: true, + block: { + id: "block-abc-123", + name: "GetWeather", + description: "Fetches current weather data for a given location.", + inputs: { + type: "object", + properties: { + location: { + title: "Location", + type: "string", + description: + "City name or coordinates (e.g. 'London' or '51.5,-0.1')", + }, + units: { + title: "Units", + type: "string", + description: "Temperature units: 'metric' or 'imperial'", + }, + }, + required: ["location"], + }, + outputs: { + type: "object", + properties: { + temperature: { + title: "Temperature", + type: "number", + description: "Current temperature in the requested units", + }, + condition: { + title: "Condition", + type: "string", + description: "Weather condition description (e.g. 'Sunny', 'Rain')", + }, + }, + }, + credentials: [], + }, +}; + +export const Default: Story = { + args: { + output: baseBlock, + }, +}; + +export const InputsOnly: Story = { + args: { + output: { + ...baseBlock, + message: "This block requires inputs. No outputs are defined.", + block: { + ...baseBlock.block, + outputs: {}, + }, + }, + }, +}; + +export const OutputsOnly: Story = { + args: { + output: { + ...baseBlock, + message: "This block has no required inputs.", + block: { + ...baseBlock.block, + inputs: {}, + }, + }, + }, +}; + +export const ManyFields: Story = { + args: { + output: { + ...baseBlock, + message: "Block with many input and output fields.", + block: { + ...baseBlock.block, + name: "SendEmail", + description: "Sends an email via SMTP.", + inputs: { + type: "object", + properties: { + to: { + title: "To", + type: "string", + description: "Recipient email address", + }, + subject: { + title: "Subject", + type: "string", + description: "Email subject line", + }, + body: { + title: "Body", + type: "string", + description: "Email body content", + }, + cc: { + title: "CC", + type: "string", + description: "CC recipients (comma-separated)", + }, + bcc: { + title: "BCC", + type: "string", + description: "BCC recipients (comma-separated)", + }, + }, + required: ["to", "subject", "body"], + }, + outputs: { + type: "object", + properties: { + message_id: { + title: "Message ID", + type: "string", + description: "Unique ID of the sent email", + }, + status: { + title: "Status", + type: "string", + description: "Delivery status", + }, + }, + }, + }, + }, + }, +}; + +export const NoFieldDescriptions: Story = { + args: { + output: { + ...baseBlock, + message: "Fields without descriptions.", + block: { + ...baseBlock.block, + name: "SimpleBlock", + inputs: { + type: "object", + properties: { + input_a: { title: "Input A", type: "string" }, + input_b: { title: "Input B", type: "number" }, + }, + required: ["input_a"], + }, + outputs: { + type: "object", + properties: { + result: { title: "Result", type: "string" }, + }, + }, + }, + }, + }, +}; diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/RunBlock/components/BlockDetailsCard/BlockDetailsCard.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/RunBlock/components/BlockDetailsCard/BlockDetailsCard.tsx new file mode 100644 index 0000000000..fdbf115222 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/RunBlock/components/BlockDetailsCard/BlockDetailsCard.tsx @@ -0,0 +1,103 @@ +"use client"; + +import type { BlockDetailsResponse } from "../../helpers"; +import { + ContentBadge, + ContentCard, + ContentCardDescription, + ContentCardTitle, + ContentGrid, + ContentMessage, +} from "../../../../components/ToolAccordion/AccordionContent"; + +interface Props { + output: BlockDetailsResponse; +} + +function SchemaFieldList({ + title, + properties, + required, +}: { + title: string; + properties: Record; + required?: string[]; +}) { + const entries = Object.entries(properties); + if (entries.length === 0) return null; + + const requiredSet = new Set(required ?? []); + + return ( + + {title} +
+ {entries.map(([name, schema]) => { + const field = schema as Record | undefined; + const fieldTitle = + typeof field?.title === "string" ? field.title : name; + const fieldType = + typeof field?.type === "string" ? field.type : "unknown"; + const description = + typeof field?.description === "string" + ? field.description + : undefined; + + return ( +
+
+ + {fieldTitle} + +
+ {fieldType} + {requiredSet.has(name) && ( + Required + )} +
+
+ {description && ( + + {description} + + )} +
+ ); + })} +
+
+ ); +} + +export function BlockDetailsCard({ output }: Props) { + const inputs = output.block.inputs as { + properties?: Record; + required?: string[]; + } | null; + const outputs = output.block.outputs as { + properties?: Record; + required?: string[]; + } | null; + + return ( + + {output.message} + + {inputs?.properties && Object.keys(inputs.properties).length > 0 && ( + + )} + + {outputs?.properties && Object.keys(outputs.properties).length > 0 && ( + + )} + + ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/RunBlock/helpers.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/RunBlock/helpers.tsx index b8625988cd..6e56154a5e 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/RunBlock/helpers.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/RunBlock/helpers.tsx @@ -10,18 +10,37 @@ import { import type { ToolUIPart } from "ai"; import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader"; +/** Block details returned on first run_block attempt (before input_data provided). */ +export interface BlockDetailsResponse { + type: typeof ResponseType.block_details; + message: string; + session_id?: string | null; + block: { + id: string; + name: string; + description: string; + inputs: Record; + outputs: Record; + credentials: unknown[]; + }; + user_authenticated: boolean; +} + export interface RunBlockInput { block_id?: string; + block_name?: string; input_data?: Record; } export type RunBlockToolOutput = | SetupRequirementsResponse + | BlockDetailsResponse | BlockOutputResponse | ErrorResponse; const RUN_BLOCK_OUTPUT_TYPES = new Set([ ResponseType.setup_requirements, + ResponseType.block_details, ResponseType.block_output, ResponseType.error, ]); @@ -35,6 +54,15 @@ export function isRunBlockSetupRequirementsOutput( ); } +export function isRunBlockDetailsOutput( + output: RunBlockToolOutput, +): output is BlockDetailsResponse { + return ( + output.type === ResponseType.block_details || + ("block" in output && typeof output.block === "object") + ); +} + export function isRunBlockBlockOutput( output: RunBlockToolOutput, ): output is BlockOutputResponse { @@ -64,6 +92,7 @@ function parseOutput(output: unknown): RunBlockToolOutput | null { return output as RunBlockToolOutput; } if ("block_id" in output) return output as BlockOutputResponse; + if ("block" in output) return output as BlockDetailsResponse; if ("setup_info" in output) return output as SetupRequirementsResponse; if ("error" in output || "details" in output) return output as ErrorResponse; @@ -84,17 +113,25 @@ export function getAnimationText(part: { output?: unknown; }): string { const input = part.input as RunBlockInput | undefined; + const blockName = input?.block_name?.trim(); const blockId = input?.block_id?.trim(); - const blockText = blockId ? ` "${blockId}"` : ""; + // Prefer block_name if available, otherwise fall back to block_id + const blockText = blockName + ? ` "${blockName}"` + : blockId + ? ` "${blockId}"` + : ""; switch (part.state) { case "input-streaming": case "input-available": - return `Running the block${blockText}`; + return `Running${blockText}`; case "output-available": { const output = parseOutput(part.output); - if (!output) return `Running the block${blockText}`; + if (!output) return `Running${blockText}`; if (isRunBlockBlockOutput(output)) return `Ran "${output.block_name}"`; + if (isRunBlockDetailsOutput(output)) + return `Details for "${output.block.name}"`; if (isRunBlockSetupRequirementsOutput(output)) { return `Setup needed for "${output.setup_info.agent_name}"`; } @@ -158,6 +195,21 @@ export function getAccordionMeta(output: RunBlockToolOutput): { }; } + if (isRunBlockDetailsOutput(output)) { + const inputKeys = Object.keys( + (output.block.inputs as { properties?: Record }) + ?.properties ?? {}, + ); + return { + icon, + title: output.block.name, + description: + inputKeys.length > 0 + ? `${inputKeys.length} input field${inputKeys.length === 1 ? "" : "s"} available` + : output.message, + }; + } + if (isRunBlockSetupRequirementsOutput(output)) { const missingCredsCount = Object.keys( (output.setup_info.user_readiness?.missing_credentials ?? {}) as Record< diff --git a/autogpt_platform/frontend/src/app/api/openapi.json b/autogpt_platform/frontend/src/app/api/openapi.json index 5d2cb83f7c..496a714ba5 100644 --- a/autogpt_platform/frontend/src/app/api/openapi.json +++ b/autogpt_platform/frontend/src/app/api/openapi.json @@ -1053,6 +1053,7 @@ "$ref": "#/components/schemas/ClarificationNeededResponse" }, { "$ref": "#/components/schemas/BlockListResponse" }, + { "$ref": "#/components/schemas/BlockDetailsResponse" }, { "$ref": "#/components/schemas/BlockOutputResponse" }, { "$ref": "#/components/schemas/DocSearchResultsResponse" }, { "$ref": "#/components/schemas/DocPageResponse" }, @@ -6958,6 +6959,58 @@ "enum": ["run", "byte", "second"], "title": "BlockCostType" }, + "BlockDetails": { + "properties": { + "id": { "type": "string", "title": "Id" }, + "name": { "type": "string", "title": "Name" }, + "description": { "type": "string", "title": "Description" }, + "inputs": { + "additionalProperties": true, + "type": "object", + "title": "Inputs", + "default": {} + }, + "outputs": { + "additionalProperties": true, + "type": "object", + "title": "Outputs", + "default": {} + }, + "credentials": { + "items": { "$ref": "#/components/schemas/CredentialsMetaInput" }, + "type": "array", + "title": "Credentials", + "default": [] + } + }, + "type": "object", + "required": ["id", "name", "description"], + "title": "BlockDetails", + "description": "Detailed block information." + }, + "BlockDetailsResponse": { + "properties": { + "type": { + "$ref": "#/components/schemas/ResponseType", + "default": "block_details" + }, + "message": { "type": "string", "title": "Message" }, + "session_id": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Session Id" + }, + "block": { "$ref": "#/components/schemas/BlockDetails" }, + "user_authenticated": { + "type": "boolean", + "title": "User Authenticated", + "default": false + } + }, + "type": "object", + "required": ["message", "block"], + "title": "BlockDetailsResponse", + "description": "Response for block details (first run_block attempt)." + }, "BlockInfo": { "properties": { "id": { "type": "string", "title": "Id" }, @@ -7013,62 +7066,13 @@ "properties": { "id": { "type": "string", "title": "Id" }, "name": { "type": "string", "title": "Name" }, - "description": { "type": "string", "title": "Description" }, - "categories": { - "items": { "type": "string" }, - "type": "array", - "title": "Categories" - }, - "input_schema": { - "additionalProperties": true, - "type": "object", - "title": "Input Schema" - }, - "output_schema": { - "additionalProperties": true, - "type": "object", - "title": "Output Schema" - }, - "required_inputs": { - "items": { "$ref": "#/components/schemas/BlockInputFieldInfo" }, - "type": "array", - "title": "Required Inputs", - "description": "List of required input fields for this block" - } + "description": { "type": "string", "title": "Description" } }, "type": "object", - "required": [ - "id", - "name", - "description", - "categories", - "input_schema", - "output_schema" - ], + "required": ["id", "name", "description"], "title": "BlockInfoSummary", "description": "Summary of a block for search results." }, - "BlockInputFieldInfo": { - "properties": { - "name": { "type": "string", "title": "Name" }, - "type": { "type": "string", "title": "Type" }, - "description": { - "type": "string", - "title": "Description", - "default": "" - }, - "required": { - "type": "boolean", - "title": "Required", - "default": false - }, - "default": { "anyOf": [{}, { "type": "null" }], "title": "Default" } - }, - "type": "object", - "required": ["name", "type"], - "title": "BlockInputFieldInfo", - "description": "Information about a block input field." - }, "BlockListResponse": { "properties": { "type": { @@ -7086,12 +7090,7 @@ "title": "Blocks" }, "count": { "type": "integer", "title": "Count" }, - "query": { "type": "string", "title": "Query" }, - "usage_hint": { - "type": "string", - "title": "Usage Hint", - "default": "To execute a block, call run_block with block_id set to the block's 'id' field and input_data containing the required fields from input_schema." - } + "query": { "type": "string", "title": "Query" } }, "type": "object", "required": ["message", "blocks", "count", "query"], @@ -10484,6 +10483,7 @@ "agent_saved", "clarification_needed", "block_list", + "block_details", "block_output", "doc_search_results", "doc_page", From 43b25b5e2fdec3fa0579f952d835355cddbd00f8 Mon Sep 17 00:00:00 2001 From: Reinier van der Leer Date: Fri, 13 Feb 2026 11:09:41 +0100 Subject: [PATCH 03/16] ci(frontend): Speed up E2E test job (#12090) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The frontend `e2e_test` doesn't have a working build cache setup, causing really slow builds = slow test jobs. These changes reduce total test runtime from ~12 minutes to ~5 minutes. ### Changes 🏗️ - Inject build cache config into docker compose config; let `buildx bake` use GHA cache directly - Add `docker-ci-fix-compose-build-cache.py` script - Optimize `backend/Dockerfile` + root `.dockerignore` - Replace broken DIY pnpm store caching with `actions/setup-node` built-in cache management - Add caching for test seed data created in DB ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: - CI --- .dockerignore | 73 +++--- .github/workflows/platform-frontend-ci.yml | 241 +++++++++--------- .../docker-ci-fix-compose-build-cache.py | 195 ++++++++++++++ autogpt_platform/backend/Dockerfile | 69 +++-- autogpt_platform/docker-compose.platform.yml | 4 +- 5 files changed, 406 insertions(+), 176 deletions(-) create mode 100644 .github/workflows/scripts/docker-ci-fix-compose-build-cache.py diff --git a/.dockerignore b/.dockerignore index 9b744e7f9b..427cab29f4 100644 --- a/.dockerignore +++ b/.dockerignore @@ -5,42 +5,13 @@ !docs/ # Platform - Libs -!autogpt_platform/autogpt_libs/autogpt_libs/ -!autogpt_platform/autogpt_libs/pyproject.toml -!autogpt_platform/autogpt_libs/poetry.lock -!autogpt_platform/autogpt_libs/README.md +!autogpt_platform/autogpt_libs/ # Platform - Backend -!autogpt_platform/backend/backend/ -!autogpt_platform/backend/test/e2e_test_data.py -!autogpt_platform/backend/migrations/ -!autogpt_platform/backend/schema.prisma -!autogpt_platform/backend/pyproject.toml -!autogpt_platform/backend/poetry.lock -!autogpt_platform/backend/README.md -!autogpt_platform/backend/.env -!autogpt_platform/backend/gen_prisma_types_stub.py - -# Platform - Market -!autogpt_platform/market/market/ -!autogpt_platform/market/scripts.py -!autogpt_platform/market/schema.prisma -!autogpt_platform/market/pyproject.toml -!autogpt_platform/market/poetry.lock -!autogpt_platform/market/README.md +!autogpt_platform/backend/ # Platform - Frontend -!autogpt_platform/frontend/src/ -!autogpt_platform/frontend/public/ -!autogpt_platform/frontend/scripts/ -!autogpt_platform/frontend/package.json -!autogpt_platform/frontend/pnpm-lock.yaml -!autogpt_platform/frontend/tsconfig.json -!autogpt_platform/frontend/README.md -## config -!autogpt_platform/frontend/*.config.* -!autogpt_platform/frontend/.env.* -!autogpt_platform/frontend/.env +!autogpt_platform/frontend/ # Classic - AutoGPT !classic/original_autogpt/autogpt/ @@ -64,6 +35,38 @@ # Classic - Frontend !classic/frontend/build/web/ -# Explicitly re-ignore some folders -.* -**/__pycache__ +# Explicitly re-ignore unwanted files from whitelisted directories +# Note: These patterns MUST come after the whitelist rules to take effect + +# Hidden files and directories (but keep frontend .env files needed for build) +**/.* +!autogpt_platform/frontend/.env +!autogpt_platform/frontend/.env.default +!autogpt_platform/frontend/.env.production + +# Python artifacts +**/__pycache__/ +**/*.pyc +**/*.pyo +**/.venv/ +**/.ruff_cache/ +**/.pytest_cache/ +**/.coverage +**/htmlcov/ + +# Node artifacts +**/node_modules/ +**/.next/ +**/storybook-static/ +**/playwright-report/ +**/test-results/ + +# Build artifacts +**/dist/ +**/build/ +!autogpt_platform/frontend/src/**/build/ +**/target/ + +# Logs and temp files +**/*.log +**/*.tmp diff --git a/.github/workflows/platform-frontend-ci.yml b/.github/workflows/platform-frontend-ci.yml index 6410daae9f..4bf8a2b80c 100644 --- a/.github/workflows/platform-frontend-ci.yml +++ b/.github/workflows/platform-frontend-ci.yml @@ -26,7 +26,6 @@ jobs: setup: runs-on: ubuntu-latest outputs: - cache-key: ${{ steps.cache-key.outputs.key }} components-changed: ${{ steps.filter.outputs.components }} steps: @@ -41,28 +40,17 @@ jobs: components: - 'autogpt_platform/frontend/src/components/**' - - name: Set up Node.js - uses: actions/setup-node@v6 - with: - node-version: "22.18.0" - - name: Enable corepack run: corepack enable - - name: Generate cache key - id: cache-key - run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT - - - name: Cache dependencies - uses: actions/cache@v5 + - name: Set up Node + uses: actions/setup-node@v6 with: - path: ~/.pnpm-store - key: ${{ steps.cache-key.outputs.key }} - restore-keys: | - ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }} - ${{ runner.os }}-pnpm- + node-version: "22.18.0" + cache: "pnpm" + cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml - - name: Install dependencies + - name: Install dependencies to populate cache run: pnpm install --frozen-lockfile lint: @@ -73,22 +61,15 @@ jobs: - name: Checkout repository uses: actions/checkout@v6 - - name: Set up Node.js - uses: actions/setup-node@v6 - with: - node-version: "22.18.0" - - name: Enable corepack run: corepack enable - - name: Restore dependencies cache - uses: actions/cache@v5 + - name: Set up Node + uses: actions/setup-node@v6 with: - path: ~/.pnpm-store - key: ${{ needs.setup.outputs.cache-key }} - restore-keys: | - ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }} - ${{ runner.os }}-pnpm- + node-version: "22.18.0" + cache: "pnpm" + cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml - name: Install dependencies run: pnpm install --frozen-lockfile @@ -111,22 +92,15 @@ jobs: with: fetch-depth: 0 - - name: Set up Node.js - uses: actions/setup-node@v6 - with: - node-version: "22.18.0" - - name: Enable corepack run: corepack enable - - name: Restore dependencies cache - uses: actions/cache@v5 + - name: Set up Node + uses: actions/setup-node@v6 with: - path: ~/.pnpm-store - key: ${{ needs.setup.outputs.cache-key }} - restore-keys: | - ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }} - ${{ runner.os }}-pnpm- + node-version: "22.18.0" + cache: "pnpm" + cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml - name: Install dependencies run: pnpm install --frozen-lockfile @@ -141,10 +115,8 @@ jobs: exitOnceUploaded: true e2e_test: + name: end-to-end tests runs-on: big-boi - needs: setup - strategy: - fail-fast: false steps: - name: Checkout repository @@ -152,19 +124,11 @@ jobs: with: submodules: recursive - - name: Set up Node.js - uses: actions/setup-node@v6 - with: - node-version: "22.18.0" - - - name: Enable corepack - run: corepack enable - - - name: Copy default supabase .env + - name: Set up Platform - Copy default supabase .env run: | cp ../.env.default ../.env - - name: Copy backend .env and set OpenAI API key + - name: Set up Platform - Copy backend .env and set OpenAI API key run: | cp ../backend/.env.default ../backend/.env echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env @@ -172,77 +136,125 @@ jobs: # Used by E2E test data script to generate embeddings for approved store agents OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} - - name: Set up Docker Buildx + - name: Set up Platform - Set up Docker Buildx uses: docker/setup-buildx-action@v3 + with: + driver: docker-container + driver-opts: network=host - - name: Cache Docker layers + - name: Set up Platform - Expose GHA cache to docker buildx CLI + uses: crazy-max/ghaction-github-runtime@v3 + + - name: Set up Platform - Build Docker images (with cache) + working-directory: autogpt_platform + run: | + pip install pyyaml + + # Resolve extends and generate a flat compose file that bake can understand + docker compose -f docker-compose.yml config > docker-compose.resolved.yml + + # Add cache configuration to the resolved compose file + python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \ + --source docker-compose.resolved.yml \ + --cache-from "type=gha" \ + --cache-to "type=gha,mode=max" \ + --backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend') }}" \ + --frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src') }}" \ + --git-ref "${{ github.ref }}" + + # Build with bake using the resolved compose file (now includes cache config) + docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load + env: + NEXT_PUBLIC_PW_TEST: true + + - name: Set up tests - Cache E2E test data + id: e2e-data-cache uses: actions/cache@v5 with: - path: /tmp/.buildx-cache - key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }} - restore-keys: | - ${{ runner.os }}-buildx-frontend-test- + path: /tmp/e2e_test_data.sql + key: e2e-test-data-${{ hashFiles('autogpt_platform/backend/test/e2e_test_data.py', 'autogpt_platform/backend/migrations/**', '.github/workflows/platform-frontend-ci.yml') }} - - name: Run docker compose + - name: Set up Platform - Start Supabase DB + Auth run: | - NEXT_PUBLIC_PW_TEST=true docker compose -f ../docker-compose.yml up -d + docker compose -f ../docker-compose.resolved.yml up -d db auth --no-build + echo "Waiting for database to be ready..." + timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' + echo "Waiting for auth service to be ready..." + timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -c "SELECT 1 FROM auth.users LIMIT 1" 2>/dev/null; do sleep 2; done' || echo "Auth schema check timeout, continuing..." + + - name: Set up Platform - Run migrations + run: | + echo "Running migrations..." + docker compose -f ../docker-compose.resolved.yml run --rm migrate + echo "✅ Migrations completed" env: - DOCKER_BUILDKIT: 1 - BUILDX_CACHE_FROM: type=local,src=/tmp/.buildx-cache - BUILDX_CACHE_TO: type=local,dest=/tmp/.buildx-cache-new,mode=max + NEXT_PUBLIC_PW_TEST: true - - name: Move cache + - name: Set up tests - Load cached E2E test data + if: steps.e2e-data-cache.outputs.cache-hit == 'true' run: | - rm -rf /tmp/.buildx-cache - if [ -d "/tmp/.buildx-cache-new" ]; then - mv /tmp/.buildx-cache-new /tmp/.buildx-cache - fi + echo "✅ Found cached E2E test data, restoring..." + { + echo "SET session_replication_role = 'replica';" + cat /tmp/e2e_test_data.sql + echo "SET session_replication_role = 'origin';" + } | docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -b + # Refresh materialized views after restore + docker compose -f ../docker-compose.resolved.yml exec -T db \ + psql -U postgres -d postgres -b -c "SET search_path TO platform; SELECT refresh_store_materialized_views();" || true - - name: Wait for services to be ready + echo "✅ E2E test data restored from cache" + + - name: Set up Platform - Start (all other services) run: | + docker compose -f ../docker-compose.resolved.yml up -d --no-build echo "Waiting for rest_server to be ready..." timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..." - echo "Waiting for database to be ready..." - timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..." + env: + NEXT_PUBLIC_PW_TEST: true - - name: Create E2E test data + - name: Set up tests - Create E2E test data + if: steps.e2e-data-cache.outputs.cache-hit != 'true' run: | echo "Creating E2E test data..." - # First try to run the script from inside the container - if docker compose -f ../docker-compose.yml exec -T rest_server test -f /app/autogpt_platform/backend/test/e2e_test_data.py; then - echo "✅ Found e2e_test_data.py in container, running it..." - docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python backend/test/e2e_test_data.py" || { - echo "❌ E2E test data creation failed!" - docker compose -f ../docker-compose.yml logs --tail=50 rest_server - exit 1 - } - else - echo "⚠️ e2e_test_data.py not found in container, copying and running..." - # Copy the script into the container and run it - docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.yml ps -q rest_server):/tmp/e2e_test_data.py || { - echo "❌ Failed to copy script to container" - exit 1 - } - docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || { - echo "❌ E2E test data creation failed!" - docker compose -f ../docker-compose.yml logs --tail=50 rest_server - exit 1 - } - fi + docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.resolved.yml ps -q rest_server):/tmp/e2e_test_data.py + docker compose -f ../docker-compose.resolved.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || { + echo "❌ E2E test data creation failed!" + docker compose -f ../docker-compose.resolved.yml logs --tail=50 rest_server + exit 1 + } - - name: Restore dependencies cache - uses: actions/cache@v5 + # Dump auth.users + platform schema for cache (two separate dumps) + echo "Dumping database for cache..." + { + docker compose -f ../docker-compose.resolved.yml exec -T db \ + pg_dump -U postgres --data-only --column-inserts \ + --table='auth.users' postgres + docker compose -f ../docker-compose.resolved.yml exec -T db \ + pg_dump -U postgres --data-only --column-inserts \ + --schema=platform \ + --exclude-table='platform._prisma_migrations' \ + --exclude-table='platform.apscheduler_jobs' \ + --exclude-table='platform.apscheduler_jobs_batched_notifications' \ + postgres + } > /tmp/e2e_test_data.sql + + echo "✅ Database dump created for caching ($(wc -l < /tmp/e2e_test_data.sql) lines)" + + - name: Set up tests - Enable corepack + run: corepack enable + + - name: Set up tests - Set up Node + uses: actions/setup-node@v6 with: - path: ~/.pnpm-store - key: ${{ needs.setup.outputs.cache-key }} - restore-keys: | - ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }} - ${{ runner.os }}-pnpm- + node-version: "22.18.0" + cache: "pnpm" + cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml - - name: Install dependencies + - name: Set up tests - Install dependencies run: pnpm install --frozen-lockfile - - name: Install Browser 'chromium' + - name: Set up tests - Install browser 'chromium' run: pnpm playwright install --with-deps chromium - name: Run Playwright tests @@ -269,7 +281,7 @@ jobs: - name: Print Final Docker Compose logs if: always() - run: docker compose -f ../docker-compose.yml logs + run: docker compose -f ../docker-compose.resolved.yml logs integration_test: runs-on: ubuntu-latest @@ -281,22 +293,15 @@ jobs: with: submodules: recursive - - name: Set up Node.js - uses: actions/setup-node@v6 - with: - node-version: "22.18.0" - - name: Enable corepack run: corepack enable - - name: Restore dependencies cache - uses: actions/cache@v5 + - name: Set up Node + uses: actions/setup-node@v6 with: - path: ~/.pnpm-store - key: ${{ needs.setup.outputs.cache-key }} - restore-keys: | - ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }} - ${{ runner.os }}-pnpm- + node-version: "22.18.0" + cache: "pnpm" + cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml - name: Install dependencies run: pnpm install --frozen-lockfile diff --git a/.github/workflows/scripts/docker-ci-fix-compose-build-cache.py b/.github/workflows/scripts/docker-ci-fix-compose-build-cache.py new file mode 100644 index 0000000000..33693fc739 --- /dev/null +++ b/.github/workflows/scripts/docker-ci-fix-compose-build-cache.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python3 +""" +Add cache configuration to a resolved docker-compose file for all services +that have a build key, and ensure image names match what docker compose expects. +""" + +import argparse + +import yaml + + +DEFAULT_BRANCH = "dev" +CACHE_BUILDS_FOR_COMPONENTS = ["backend", "frontend"] + + +def main(): + parser = argparse.ArgumentParser( + description="Add cache config to a resolved compose file" + ) + parser.add_argument( + "--source", + required=True, + help="Source compose file to read (should be output of `docker compose config`)", + ) + parser.add_argument( + "--cache-from", + default="type=gha", + help="Cache source configuration", + ) + parser.add_argument( + "--cache-to", + default="type=gha,mode=max", + help="Cache destination configuration", + ) + for component in CACHE_BUILDS_FOR_COMPONENTS: + parser.add_argument( + f"--{component}-hash", + default="", + help=f"Hash for {component} cache scope (e.g., from hashFiles())", + ) + parser.add_argument( + "--git-ref", + default="", + help="Git ref for branch-based cache scope (e.g., refs/heads/master)", + ) + args = parser.parse_args() + + # Normalize git ref to a safe scope name (e.g., refs/heads/master -> master) + git_ref_scope = "" + if args.git_ref: + git_ref_scope = args.git_ref.replace("refs/heads/", "").replace("/", "-") + + with open(args.source, "r") as f: + compose = yaml.safe_load(f) + + # Get project name from compose file or default + project_name = compose.get("name", "autogpt_platform") + + def get_image_name(dockerfile: str, target: str) -> str: + """Generate image name based on Dockerfile folder and build target.""" + dockerfile_parts = dockerfile.replace("\\", "/").split("/") + if len(dockerfile_parts) >= 2: + folder_name = dockerfile_parts[-2] # e.g., "backend" or "frontend" + else: + folder_name = "app" + return f"{project_name}-{folder_name}:{target}" + + def get_build_key(dockerfile: str, target: str) -> str: + """Generate a unique key for a Dockerfile+target combination.""" + return f"{dockerfile}:{target}" + + def get_component(dockerfile: str) -> str | None: + """Get component name (frontend/backend) from dockerfile path.""" + for component in CACHE_BUILDS_FOR_COMPONENTS: + if component in dockerfile: + return component + return None + + # First pass: collect all services with build configs and identify duplicates + # Track which (dockerfile, target) combinations we've seen + build_key_to_first_service: dict[str, str] = {} + services_to_build: list[str] = [] + services_to_dedupe: list[str] = [] + + for service_name, service_config in compose.get("services", {}).items(): + if "build" not in service_config: + continue + + build_config = service_config["build"] + dockerfile = build_config.get("dockerfile", "Dockerfile") + target = build_config.get("target", "default") + build_key = get_build_key(dockerfile, target) + + if build_key not in build_key_to_first_service: + # First service with this build config - it will do the actual build + build_key_to_first_service[build_key] = service_name + services_to_build.append(service_name) + else: + # Duplicate - will just use the image from the first service + services_to_dedupe.append(service_name) + + # Second pass: configure builds and deduplicate + modified_services = [] + for service_name, service_config in compose.get("services", {}).items(): + if "build" not in service_config: + continue + + build_config = service_config["build"] + dockerfile = build_config.get("dockerfile", "Dockerfile") + target = build_config.get("target", "latest") + image_name = get_image_name(dockerfile, target) + + # Set image name for all services (needed for both builders and deduped) + service_config["image"] = image_name + + if service_name in services_to_dedupe: + # Remove build config - this service will use the pre-built image + del service_config["build"] + continue + + # This service will do the actual build - add cache config + cache_from_list = [] + cache_to_list = [] + + component = get_component(dockerfile) + if not component: + # Skip services that don't clearly match frontend/backend + continue + + # Get the hash for this component + component_hash = getattr(args, f"{component}_hash") + + # Scope format: platform-{component}-{target}-{hash|ref} + # Example: platform-backend-server-abc123 + + if "type=gha" in args.cache_from: + # 1. Primary: exact hash match (most specific) + if component_hash: + hash_scope = f"platform-{component}-{target}-{component_hash}" + cache_from_list.append(f"{args.cache_from},scope={hash_scope}") + + # 2. Fallback: branch-based cache + if git_ref_scope: + ref_scope = f"platform-{component}-{target}-{git_ref_scope}" + cache_from_list.append(f"{args.cache_from},scope={ref_scope}") + + # 3. Fallback: dev branch cache (for PRs/feature branches) + if git_ref_scope and git_ref_scope != DEFAULT_BRANCH: + master_scope = f"platform-{component}-{target}-{DEFAULT_BRANCH}" + cache_from_list.append(f"{args.cache_from},scope={master_scope}") + + if "type=gha" in args.cache_to: + # Write to both hash-based and branch-based scopes + if component_hash: + hash_scope = f"platform-{component}-{target}-{component_hash}" + cache_to_list.append(f"{args.cache_to},scope={hash_scope}") + + if git_ref_scope: + ref_scope = f"platform-{component}-{target}-{git_ref_scope}" + cache_to_list.append(f"{args.cache_to},scope={ref_scope}") + + # Ensure we have at least one cache source/target + if not cache_from_list: + cache_from_list.append(args.cache_from) + if not cache_to_list: + cache_to_list.append(args.cache_to) + + build_config["cache_from"] = cache_from_list + build_config["cache_to"] = cache_to_list + modified_services.append(service_name) + + # Write back to the same file + with open(args.source, "w") as f: + yaml.dump(compose, f, default_flow_style=False, sort_keys=False) + + print(f"Added cache config to {len(modified_services)} services in {args.source}:") + for svc in modified_services: + svc_config = compose["services"][svc] + build_cfg = svc_config.get("build", {}) + cache_from_list = build_cfg.get("cache_from", ["none"]) + cache_to_list = build_cfg.get("cache_to", ["none"]) + print(f" - {svc}") + print(f" image: {svc_config.get('image', 'N/A')}") + print(f" cache_from: {cache_from_list}") + print(f" cache_to: {cache_to_list}") + if services_to_dedupe: + print( + f"Deduplicated {len(services_to_dedupe)} services (will use pre-built images):" + ) + for svc in services_to_dedupe: + print(f" - {svc} -> {compose['services'][svc].get('image', 'N/A')}") + + +if __name__ == "__main__": + main() diff --git a/autogpt_platform/backend/Dockerfile b/autogpt_platform/backend/Dockerfile index 9bd455e490..ace534b730 100644 --- a/autogpt_platform/backend/Dockerfile +++ b/autogpt_platform/backend/Dockerfile @@ -1,3 +1,5 @@ +# ============================ DEPENDENCY BUILDER ============================ # + FROM debian:13-slim AS builder # Set environment variables @@ -51,7 +53,9 @@ COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/parti COPY autogpt_platform/backend/gen_prisma_types_stub.py ./ RUN poetry run prisma generate && poetry run gen-prisma-stub -FROM debian:13-slim AS server_dependencies +# ============================== BACKEND SERVER ============================== # + +FROM debian:13-slim AS server WORKDIR /app @@ -63,15 +67,14 @@ ENV POETRY_HOME=/opt/poetry \ ENV PATH=/opt/poetry/bin:$PATH # Install Python, FFmpeg, and ImageMagick (required for video processing blocks) -RUN apt-get update && apt-get install -y \ +# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc. +RUN apt-get update && apt-get install -y --no-install-recommends \ python3.13 \ python3-pip \ ffmpeg \ imagemagick \ && rm -rf /var/lib/apt/lists/* -# Copy only necessary files from builder -COPY --from=builder /app /app COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3* COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry # Copy Node.js installation for Prisma @@ -81,30 +84,54 @@ COPY --from=builder /usr/bin/npm /usr/bin/npm COPY --from=builder /usr/bin/npx /usr/bin/npx COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries -ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH" - -RUN mkdir -p /app/autogpt_platform/autogpt_libs -RUN mkdir -p /app/autogpt_platform/backend - -COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs - -COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml /app/autogpt_platform/backend/ - WORKDIR /app/autogpt_platform/backend -FROM server_dependencies AS migrate +# Copy only the .venv from builder (not the entire /app directory) +# The .venv includes the generated Prisma client +COPY --from=builder /app/autogpt_platform/backend/.venv ./.venv +ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH" -# Migration stage only needs schema and migrations - much lighter than full backend -COPY autogpt_platform/backend/schema.prisma /app/autogpt_platform/backend/ -COPY autogpt_platform/backend/backend/data/partial_types.py /app/autogpt_platform/backend/backend/data/partial_types.py -COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migrations +# Copy dependency files + autogpt_libs (path dependency) +COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs +COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml ./ -FROM server_dependencies AS server - -COPY autogpt_platform/backend /app/autogpt_platform/backend +# Copy backend code + docs (for Copilot docs search) +COPY autogpt_platform/backend ./ COPY docs /app/docs RUN poetry install --no-ansi --only-root ENV PORT=8000 CMD ["poetry", "run", "rest"] + +# =============================== DB MIGRATOR =============================== # + +# Lightweight migrate stage - only needs Prisma CLI, not full Python environment +FROM debian:13-slim AS migrate + +WORKDIR /app/autogpt_platform/backend + +ENV DEBIAN_FRONTEND=noninteractive + +# Install only what's needed for prisma migrate: Node.js and minimal Python for prisma-python +RUN apt-get update && apt-get install -y --no-install-recommends \ + python3.13 \ + python3-pip \ + ca-certificates \ + && rm -rf /var/lib/apt/lists/* + +# Copy Node.js from builder (needed for Prisma CLI) +COPY --from=builder /usr/bin/node /usr/bin/node +COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules +COPY --from=builder /usr/bin/npm /usr/bin/npm + +# Copy Prisma binaries +COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries + +# Install prisma-client-py directly (much smaller than copying full venv) +RUN pip3 install prisma>=0.15.0 --break-system-packages + +COPY autogpt_platform/backend/schema.prisma ./ +COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py +COPY autogpt_platform/backend/gen_prisma_types_stub.py ./ +COPY autogpt_platform/backend/migrations ./migrations diff --git a/autogpt_platform/docker-compose.platform.yml b/autogpt_platform/docker-compose.platform.yml index de6ecfd612..bab92d4693 100644 --- a/autogpt_platform/docker-compose.platform.yml +++ b/autogpt_platform/docker-compose.platform.yml @@ -37,7 +37,7 @@ services: context: ../ dockerfile: autogpt_platform/backend/Dockerfile target: migrate - command: ["sh", "-c", "poetry run prisma generate && poetry run gen-prisma-stub && poetry run prisma migrate deploy"] + command: ["sh", "-c", "prisma generate && python3 gen_prisma_types_stub.py && prisma migrate deploy"] develop: watch: - path: ./ @@ -56,7 +56,7 @@ services: test: [ "CMD-SHELL", - "poetry run prisma migrate status | grep -q 'No pending migrations' || exit 1", + "prisma migrate status | grep -q 'No pending migrations' || exit 1", ] interval: 30s timeout: 10s From dfa517300bf67fcd9cdded9875d71a1195161d07 Mon Sep 17 00:00:00 2001 From: Otto Date: Fri, 13 Feb 2026 13:15:17 +0000 Subject: [PATCH 04/16] debug(copilot): Add detailed API error logging (#11942) ## Summary Adds comprehensive error logging for OpenRouter/OpenAI API errors to help diagnose issues like provider routing failures, context length exceeded, rate limits, etc. ## Background While investigating [SECRT-1859](https://linear.app/autogpt/issue/SECRT-1859), we found that when OpenRouter returns errors, the actual error details weren't being captured or logged. Langfuse traces showed `provider_name: 'unknown'` and `completion: null` without any insight into WHY all providers rejected the request. ## Changes - Add `_extract_api_error_details()` to extract rich information from API errors including: - Status code and request ID - Response body (contains OpenRouter's actual error message) - OpenRouter-specific headers (provider, model) - Rate limit headers - Add `_log_api_error()` helper that logs errors with context: - Session ID for correlation - Message count (helps identify context length issues) - Model being used - Retry count - Update error handling in `_stream_chat_chunks()` and `_generate_llm_continuation()` to use new logging - Extract provider's error message from response body for better user feedback ## Example log output ``` API error: { 'error_type': 'APIStatusError', 'error_message': 'Provider returned error', 'status_code': 400, 'request_id': 'req_xxx', 'response_body': {'error': {'message': 'context_length_exceeded', 'type': 'invalid_request_error'}}, 'openrouter_provider': 'unknown', 'session_id': '44fbb803-...', 'message_count': 52, 'model': 'anthropic/claude-opus-4.5', 'retry_count': 0 } ``` ## Testing - [ ] Verified code passes linting (black, isort, ruff) - [ ] Error details are properly extracted from different error types ## Refs - Linear: SECRT-1859 - Thread: https://discord.com/channels/1126875755960336515/1467066151002571034 --------- Co-authored-by: Reinier van der Leer --- .../backend/api/features/chat/service.py | 140 ++++++++++++++++-- 1 file changed, 127 insertions(+), 13 deletions(-) diff --git a/autogpt_platform/backend/backend/api/features/chat/service.py b/autogpt_platform/backend/backend/api/features/chat/service.py index 193566ea01..b8ddc35960 100644 --- a/autogpt_platform/backend/backend/api/features/chat/service.py +++ b/autogpt_platform/backend/backend/api/features/chat/service.py @@ -1245,6 +1245,7 @@ async def _stream_chat_chunks( return except Exception as e: last_error = e + if _is_retryable_error(e) and retry_count < MAX_RETRIES: retry_count += 1 # Calculate delay with exponential backoff @@ -1260,12 +1261,27 @@ async def _stream_chat_chunks( continue # Retry the stream else: # Non-retryable error or max retries exceeded - logger.error( - f"Error in stream (not retrying): {e!s}", - exc_info=True, + _log_api_error( + error=e, + context="stream (not retrying)", + session_id=session.session_id if session else None, + message_count=len(messages) if messages else None, + model=model, + retry_count=retry_count, ) error_code = None error_text = str(e) + + error_details = _extract_api_error_details(e) + if error_details.get("response_body"): + body = error_details["response_body"] + if isinstance(body, dict): + err = body.get("error") + if isinstance(err, dict) and err.get("message"): + error_text = err["message"] + elif body.get("message"): + error_text = body["message"] + if _is_region_blocked_error(e): error_code = "MODEL_NOT_AVAILABLE_REGION" error_text = ( @@ -1282,9 +1298,13 @@ async def _stream_chat_chunks( # If we exit the retry loop without returning, it means we exhausted retries if last_error: - logger.error( - f"Max retries ({MAX_RETRIES}) exceeded. Last error: {last_error!s}", - exc_info=True, + _log_api_error( + error=last_error, + context=f"stream (max retries {MAX_RETRIES} exceeded)", + session_id=session.session_id if session else None, + message_count=len(messages) if messages else None, + model=model, + retry_count=MAX_RETRIES, ) yield StreamError(errorText=f"Max retries exceeded: {last_error!s}") yield StreamFinish() @@ -1857,6 +1877,7 @@ async def _generate_llm_continuation( break # Success, exit retry loop except Exception as e: last_error = e + if _is_retryable_error(e) and retry_count < MAX_RETRIES: retry_count += 1 delay = min( @@ -1870,17 +1891,25 @@ async def _generate_llm_continuation( await asyncio.sleep(delay) continue else: - # Non-retryable error - log and exit gracefully - logger.error( - f"Non-retryable error in LLM continuation: {e!s}", - exc_info=True, + # Non-retryable error - log details and exit gracefully + _log_api_error( + error=e, + context="LLM continuation (not retrying)", + session_id=session_id, + message_count=len(messages) if messages else None, + model=config.model, + retry_count=retry_count, ) return if last_error: - logger.error( - f"Max retries ({MAX_RETRIES}) exceeded for LLM continuation. " - f"Last error: {last_error!s}" + _log_api_error( + error=last_error, + context=f"LLM continuation (max retries {MAX_RETRIES} exceeded)", + session_id=session_id, + message_count=len(messages) if messages else None, + model=config.model, + retry_count=MAX_RETRIES, ) return @@ -1920,6 +1949,91 @@ async def _generate_llm_continuation( logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True) +def _log_api_error( + error: Exception, + context: str, + session_id: str | None = None, + message_count: int | None = None, + model: str | None = None, + retry_count: int = 0, +) -> None: + """Log detailed API error information for debugging.""" + details = _extract_api_error_details(error) + details["context"] = context + details["session_id"] = session_id + details["message_count"] = message_count + details["model"] = model + details["retry_count"] = retry_count + + if isinstance(error, RateLimitError): + logger.warning(f"Rate limit error in {context}: {details}", exc_info=error) + elif isinstance(error, APIConnectionError): + logger.warning(f"API connection error in {context}: {details}", exc_info=error) + elif isinstance(error, APIStatusError) and error.status_code >= 500: + logger.error(f"API server error (5xx) in {context}: {details}", exc_info=error) + else: + logger.error(f"API error in {context}: {details}", exc_info=error) + + +def _extract_api_error_details(error: Exception) -> dict[str, Any]: + """Extract detailed information from OpenAI/OpenRouter API errors.""" + error_msg = str(error) + details: dict[str, Any] = { + "error_type": type(error).__name__, + "error_message": error_msg[:500] + "..." if len(error_msg) > 500 else error_msg, + } + + if hasattr(error, "code"): + details["code"] = getattr(error, "code", None) + if hasattr(error, "param"): + details["param"] = getattr(error, "param", None) + + if isinstance(error, APIStatusError): + details["status_code"] = error.status_code + details["request_id"] = getattr(error, "request_id", None) + + if hasattr(error, "body") and error.body: + details["response_body"] = _sanitize_error_body(error.body) + + if hasattr(error, "response") and error.response: + headers = error.response.headers + details["openrouter_provider"] = headers.get("x-openrouter-provider") + details["openrouter_model"] = headers.get("x-openrouter-model") + details["retry_after"] = headers.get("retry-after") + details["rate_limit_remaining"] = headers.get("x-ratelimit-remaining") + + return details + + +def _sanitize_error_body( + body: Any, max_length: int = 2000 +) -> dict[str, Any] | str | None: + """Extract only safe fields from error response body to avoid logging sensitive data.""" + if not isinstance(body, dict): + # Non-dict bodies (e.g., HTML error pages) - return truncated string + if body is not None: + body_str = str(body) + if len(body_str) > max_length: + return body_str[:max_length] + "...[truncated]" + return body_str + return None + + safe_fields = ("message", "type", "code", "param", "error") + sanitized: dict[str, Any] = {} + + for field in safe_fields: + if field in body: + value = body[field] + if field == "error" and isinstance(value, dict): + sanitized[field] = _sanitize_error_body(value, max_length) + elif isinstance(value, str) and len(value) > max_length: + sanitized[field] = value[:max_length] + "...[truncated]" + else: + sanitized[field] = value + + return sanitized if sanitized else None + + async def _generate_llm_continuation_with_streaming( session_id: str, user_id: str | None, From 86af8fc856d99527584bc0eb41fea06487568641 Mon Sep 17 00:00:00 2001 From: Otto Date: Fri, 13 Feb 2026 13:48:04 +0000 Subject: [PATCH 05/16] ci: apply E2E CI optimizations to Claude workflows (#12097) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Applies the CI performance optimizations from #12090 to Claude Code workflows. ## Changes ### `claude.yml` & `claude-dependabot.yml` - **pnpm caching**: Replaced manual `actions/cache` with `setup-node` built-in `cache: "pnpm"` - Removes 4 steps (set pnpm store dir, cache step, manual config) → 1 step ### `claude-ci-failure-auto-fix.yml` - **Added dev environment setup** with optimized caching - Now Claude can run lint/tests when fixing CI failures (previously could only edit files) - Uses the same optimized caching patterns ## Dependency This PR is based on #12090 and will merge after it. ## Testing - Workflow YAML syntax validated - Patterns match proven #12090 implementation - CI caching changes fail gracefully to uncached builds ## Linear Fixes [SECRT-1950](https://linear.app/autogpt/issue/SECRT-1950) ## Future Enhancements E2E test data caching could be added to Claude workflows if needed for running integration tests. Currently Claude workflows set up a dev environment but don't run E2E tests by default.

Greptile Overview

Greptile Summary

Applies proven CI performance optimizations to Claude workflows by simplifying pnpm caching and adding dev environment setup to the auto-fix workflow. **Key changes:** - Replaced manual pnpm cache configuration (4 steps) with built-in `setup-node` `cache: "pnpm"` support in `claude.yml` and `claude-dependabot.yml` - Added complete dev environment setup (Python/Poetry + Node.js/pnpm) to `claude-ci-failure-auto-fix.yml` so Claude can run linting and tests when fixing CI failures - Correctly orders `corepack enable` before `setup-node` to ensure pnpm is available for caching The changes mirror the optimizations from PR #12090 and maintain consistency across all Claude workflows.

Confidence Score: 5/5

- This PR is safe to merge with minimal risk - The changes are CI infrastructure optimizations that mirror proven patterns from PR #12090. The pnpm caching simplification reduces complexity without changing functionality (caching failures gracefully fall back to uncached builds). The dev environment setup in the auto-fix workflow is additive and enables Claude to run linting/tests. All YAML syntax is correct and the step ordering follows best practices. - No files require special attention

Sequence Diagram

```mermaid sequenceDiagram participant GHA as GitHub Actions participant Corepack as Corepack participant SetupNode as setup-node@v6 participant Cache as GHA Cache participant pnpm as pnpm Note over GHA,pnpm: Before (Manual Caching) GHA->>SetupNode: Set up Node.js 22 SetupNode-->>GHA: Node.js ready GHA->>Corepack: Enable corepack Corepack-->>GHA: pnpm available GHA->>pnpm: Configure store directory pnpm-->>GHA: Store path set GHA->>Cache: actions/cache (manual key) Cache-->>GHA: Cache restored/missed GHA->>pnpm: Install dependencies pnpm-->>GHA: Dependencies installed Note over GHA,pnpm: After (Built-in Caching) GHA->>Corepack: Enable corepack Corepack-->>GHA: pnpm available GHA->>SetupNode: Set up Node.js 22
cache: "pnpm"
cache-dependency-path: pnpm-lock.yaml SetupNode->>Cache: Auto-detect pnpm store Cache-->>SetupNode: Cache restored/missed SetupNode-->>GHA: Node.js + cache ready GHA->>pnpm: Install dependencies pnpm-->>GHA: Dependencies installed ```
Last reviewed commit: f1681a0 --------- Co-authored-by: Reinier van der Leer Co-authored-by: Ubbe --- .../workflows/claude-ci-failure-auto-fix.yml | 42 +++++ .github/workflows/claude-dependabot.yml | 22 +-- .github/workflows/claude.yml | 22 +-- plans/SECRT-1950-claude-ci-optimizations.md | 165 ++++++++++++++++++ 4 files changed, 217 insertions(+), 34 deletions(-) create mode 100644 plans/SECRT-1950-claude-ci-optimizations.md diff --git a/.github/workflows/claude-ci-failure-auto-fix.yml b/.github/workflows/claude-ci-failure-auto-fix.yml index ab07c8ae10..dbca6dc3f3 100644 --- a/.github/workflows/claude-ci-failure-auto-fix.yml +++ b/.github/workflows/claude-ci-failure-auto-fix.yml @@ -40,6 +40,48 @@ jobs: git checkout -b "$BRANCH_NAME" echo "branch_name=$BRANCH_NAME" >> $GITHUB_OUTPUT + # Backend Python/Poetry setup (so Claude can run linting/tests) + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Set up Python dependency cache + uses: actions/cache@v5 + with: + path: ~/.cache/pypoetry + key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }} + + - name: Install Poetry + run: | + cd autogpt_platform/backend + HEAD_POETRY_VERSION=$(python3 ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry) + curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 - + echo "$HOME/.local/bin" >> $GITHUB_PATH + + - name: Install Python dependencies + working-directory: autogpt_platform/backend + run: poetry install + + - name: Generate Prisma Client + working-directory: autogpt_platform/backend + run: poetry run prisma generate && poetry run gen-prisma-stub + + # Frontend Node.js/pnpm setup (so Claude can run linting/tests) + - name: Enable corepack + run: corepack enable + + - name: Set up Node.js + uses: actions/setup-node@v6 + with: + node-version: "22" + cache: "pnpm" + cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml + + - name: Install JavaScript dependencies + working-directory: autogpt_platform/frontend + run: pnpm install --frozen-lockfile + - name: Get CI failure details id: failure_details uses: actions/github-script@v8 diff --git a/.github/workflows/claude-dependabot.yml b/.github/workflows/claude-dependabot.yml index da37df6de7..274c6d2cab 100644 --- a/.github/workflows/claude-dependabot.yml +++ b/.github/workflows/claude-dependabot.yml @@ -77,27 +77,15 @@ jobs: run: poetry run prisma generate && poetry run gen-prisma-stub # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml) + - name: Enable corepack + run: corepack enable + - name: Set up Node.js uses: actions/setup-node@v6 with: node-version: "22" - - - name: Enable corepack - run: corepack enable - - - name: Set pnpm store directory - run: | - pnpm config set store-dir ~/.pnpm-store - echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV - - - name: Cache frontend dependencies - uses: actions/cache@v5 - with: - path: ~/.pnpm-store - key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }} - restore-keys: | - ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }} - ${{ runner.os }}-pnpm- + cache: "pnpm" + cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml - name: Install JavaScript dependencies working-directory: autogpt_platform/frontend diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml index ee901fe5d4..8b8260af6b 100644 --- a/.github/workflows/claude.yml +++ b/.github/workflows/claude.yml @@ -93,27 +93,15 @@ jobs: run: poetry run prisma generate && poetry run gen-prisma-stub # Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml) + - name: Enable corepack + run: corepack enable + - name: Set up Node.js uses: actions/setup-node@v6 with: node-version: "22" - - - name: Enable corepack - run: corepack enable - - - name: Set pnpm store directory - run: | - pnpm config set store-dir ~/.pnpm-store - echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV - - - name: Cache frontend dependencies - uses: actions/cache@v5 - with: - path: ~/.pnpm-store - key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }} - restore-keys: | - ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }} - ${{ runner.os }}-pnpm- + cache: "pnpm" + cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml - name: Install JavaScript dependencies working-directory: autogpt_platform/frontend diff --git a/plans/SECRT-1950-claude-ci-optimizations.md b/plans/SECRT-1950-claude-ci-optimizations.md new file mode 100644 index 0000000000..15d1419b0e --- /dev/null +++ b/plans/SECRT-1950-claude-ci-optimizations.md @@ -0,0 +1,165 @@ +# Implementation Plan: SECRT-1950 - Apply E2E CI Optimizations to Claude Code Workflows + +## Ticket +[SECRT-1950](https://linear.app/autogpt/issue/SECRT-1950) + +## Summary +Apply Pwuts's CI performance optimizations from PR #12090 to Claude Code workflows. + +## Reference PR +https://github.com/Significant-Gravitas/AutoGPT/pull/12090 + +--- + +## Analysis + +### Current State (claude.yml) + +**pnpm caching (lines 104-118):** +```yaml +- name: Set up Node.js + uses: actions/setup-node@v6 + with: + node-version: "22" + +- name: Enable corepack + run: corepack enable + +- name: Set pnpm store directory + run: | + pnpm config set store-dir ~/.pnpm-store + echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV + +- name: Cache frontend dependencies + uses: actions/cache@v5 + with: + path: ~/.pnpm-store + key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }} + restore-keys: | + ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }} + ${{ runner.os }}-pnpm- +``` + +**Docker setup (lines 134-165):** +- Uses `docker-buildx-action@v3` +- Has manual Docker image caching via `actions/cache` +- Runs `docker compose up` without buildx bake optimization + +### Pwuts's Optimizations (PR #12090) + +1. **Simplified pnpm caching** - Use `setup-node` built-in cache: +```yaml +- name: Enable corepack + run: corepack enable + +- name: Set up Node + uses: actions/setup-node@v6 + with: + node-version: "22.18.0" + cache: "pnpm" + cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml +``` + +2. **Docker build caching via buildx bake**: +```yaml +- name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + with: + driver: docker-container + driver-opts: network=host + +- name: Expose GHA cache to docker buildx CLI + uses: crazy-max/ghaction-github-runtime@v3 + +- name: Build Docker images (with cache) + run: | + pip install pyyaml + docker compose -f docker-compose.yml config > docker-compose.resolved.yml + python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \ + --source docker-compose.resolved.yml \ + --cache-from "type=gha" \ + --cache-to "type=gha,mode=max" \ + ... + docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load +``` + +--- + +## Proposed Changes + +### 1. Update pnpm caching in `claude.yml` + +**Before:** +- Manual cache key generation +- Separate `actions/cache` step +- Manual pnpm store directory config + +**After:** +- Use `setup-node` built-in `cache: "pnpm"` option +- Remove manual cache step +- Keep `corepack enable` before `setup-node` + +### 2. Update Docker build in `claude.yml` + +**Before:** +- Manual Docker layer caching via `actions/cache` with `/tmp/.buildx-cache` +- Simple `docker compose build` + +**After:** +- Use `crazy-max/ghaction-github-runtime@v3` to expose GHA cache +- Use `docker-ci-fix-compose-build-cache.py` script +- Build with `docker buildx bake` + +### 3. Apply same changes to other Claude workflows + +- `claude-dependabot.yml` - Check if it has similar patterns +- `claude-ci-failure-auto-fix.yml` - Check if it has similar patterns +- `copilot-setup-steps.yml` - Reusable workflow, may be the source of truth + +--- + +## Files to Modify + +1. `.github/workflows/claude.yml` +2. `.github/workflows/claude-dependabot.yml` (if applicable) +3. `.github/workflows/claude-ci-failure-auto-fix.yml` (if applicable) + +## Dependencies + +- PR #12090 must be merged first (provides the `docker-ci-fix-compose-build-cache.py` script) +- Backend Dockerfile optimizations (already in PR #12090) + +--- + +## Test Plan + +1. Create PR with changes +2. Trigger Claude workflow manually or via `@claude` mention on a test issue +3. Compare CI runtime before/after +4. Verify Claude agent still works correctly (can checkout, build, run tests) + +--- + +## Risk Assessment + +**Low risk:** +- These are CI infrastructure changes, not code changes +- If caching fails, builds fall back to uncached (slower but works) +- Changes mirror proven patterns from PR #12090 + +--- + +## Questions for Reviewer + +1. Should we wait for PR #12090 to merge before creating this PR? +2. Does `copilot-setup-steps.yml` need updating, or is it a separate concern? +3. Any concerns about cache key collisions between frontend E2E and Claude workflows? + +--- + +## Verified + +- ✅ **`claude-dependabot.yml`**: Has same pnpm caching pattern as `claude.yml` (manual `actions/cache`) — NEEDS UPDATE +- ✅ **`claude-ci-failure-auto-fix.yml`**: Simple workflow with no pnpm or Docker caching — NO CHANGES NEEDED +- ✅ **Script path**: `docker-ci-fix-compose-build-cache.py` will be at `.github/workflows/scripts/` after PR #12090 merges +- ✅ **Test seed caching**: NOT APPLICABLE — Claude workflows spin up a dev environment but don't run E2E tests with pre-seeded data. The seed caching in PR #12090 is specific to the frontend E2E test suite which needs consistent test data. Claude just needs the services running. From 5035b69c792bc071da97507796d61e59cb80bd66 Mon Sep 17 00:00:00 2001 From: Swifty Date: Fri, 13 Feb 2026 15:27:00 +0100 Subject: [PATCH 06/16] feat(platform): add feature request tools for CoPilot chat (#12102) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Users can now search for existing feature requests and submit new ones directly through the CoPilot chat interface. Requests are tracked in Linear with customer need attribution. ### Changes 🏗️ **Backend:** - Added `SearchFeatureRequestsTool` and `CreateFeatureRequestTool` to the CoPilot chat tools registry - Integrated with Linear GraphQL API for searching issues in the feature requests project, creating new issues, upserting customers, and attaching customer needs - Added `linear_api_key` secret to settings for system-level Linear API access - Added response models (`FeatureRequestSearchResponse`, `FeatureRequestCreatedResponse`, `FeatureRequestInfo`) to the tools models **Frontend:** - Added `SearchFeatureRequestsTool` and `CreateFeatureRequestTool` UI components with full streaming state handling (input-streaming, input-available, output-available, output-error) - Added helper utilities for output parsing, type guards, animation text, and icon rendering - Wired tools into `ChatMessagesContainer` for rendering in the chat - Added styleguide examples covering all tool states ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: - [x] Verified search returns matching feature requests from Linear - [x] Verified creating a new feature request creates an issue and customer need in Linear - [x] Verified adding a need to an existing issue works via `existing_issue_id` - [x] Verified error states render correctly in the UI - [x] Verified styleguide page renders all tool states #### For configuration changes: - [x] `.env.default` is updated or already compatible with my changes - [x] I have included a list of my configuration changes in the PR description (under **Changes**) New secret: `LINEAR_API_KEY` — required for system-level Linear API operations (defaults to empty string).

Greptile Overview

Greptile Summary

Adds feature request search and creation tools to CoPilot chat, integrating with Linear's GraphQL API to track user feedback. Users can now search existing feature requests and submit new ones (or add their need to existing issues) directly through conversation. **Key changes:** - Backend: `SearchFeatureRequestsTool` and `CreateFeatureRequestTool` with Linear API integration via system-level `LINEAR_API_KEY` - Frontend: React components with streaming state handling and accordion UI for search results and creation confirmations - Models: Added `FeatureRequestSearchResponse` and `FeatureRequestCreatedResponse` to response types - Customer need tracking: Upserts customers in Linear and attaches needs to issues for better feedback attribution **Issues found:** - Missing `LINEAR_API_KEY` entry in `.env.default` (required per PR description checklist) - Hardcoded project/team IDs reduce maintainability - Global singleton pattern could cause issues in async contexts - Using `user_id` as customer name reduces readability in Linear

Confidence Score: 4/5

- Safe to merge with minor configuration fix required - The implementation is well-structured with proper error handling, type safety, and follows existing patterns in the codebase. The missing `.env.default` entry is a straightforward configuration issue that must be fixed before deployment but doesn't affect code quality. The other findings are style improvements that don't impact functionality. - Verify that `LINEAR_API_KEY` is added to `.env.default` before merging

Sequence Diagram

```mermaid sequenceDiagram participant User participant CoPilot UI participant LLM participant FeatureRequestTool participant LinearClient participant Linear API User->>CoPilot UI: Request feature via chat CoPilot UI->>LLM: Send user message LLM->>FeatureRequestTool: search_feature_requests(query) FeatureRequestTool->>LinearClient: query(SEARCH_ISSUES_QUERY) LinearClient->>Linear API: POST /graphql (search) Linear API-->>LinearClient: searchIssues.nodes[] LinearClient-->>FeatureRequestTool: Feature request data FeatureRequestTool-->>LLM: FeatureRequestSearchResponse alt No existing requests found LLM->>FeatureRequestTool: create_feature_request(title, description) FeatureRequestTool->>LinearClient: mutate(CUSTOMER_UPSERT_MUTATION) LinearClient->>Linear API: POST /graphql (upsert customer) Linear API-->>LinearClient: customer {id, name} LinearClient-->>FeatureRequestTool: Customer data FeatureRequestTool->>LinearClient: mutate(ISSUE_CREATE_MUTATION) LinearClient->>Linear API: POST /graphql (create issue) Linear API-->>LinearClient: issue {id, identifier, url} LinearClient-->>FeatureRequestTool: Issue data FeatureRequestTool->>LinearClient: mutate(CUSTOMER_NEED_CREATE_MUTATION) LinearClient->>Linear API: POST /graphql (attach need) Linear API-->>LinearClient: need {id, issue} LinearClient-->>FeatureRequestTool: Need data FeatureRequestTool-->>LLM: FeatureRequestCreatedResponse else Existing request found LLM->>FeatureRequestTool: create_feature_request(title, description, existing_issue_id) FeatureRequestTool->>LinearClient: mutate(CUSTOMER_UPSERT_MUTATION) LinearClient->>Linear API: POST /graphql (upsert customer) Linear API-->>LinearClient: customer {id} LinearClient-->>FeatureRequestTool: Customer data FeatureRequestTool->>LinearClient: mutate(CUSTOMER_NEED_CREATE_MUTATION) LinearClient->>Linear API: POST /graphql (attach need to existing) Linear API-->>LinearClient: need {id, issue} LinearClient-->>FeatureRequestTool: Need data FeatureRequestTool-->>LLM: FeatureRequestCreatedResponse end LLM-->>CoPilot UI: Tool response + continuation CoPilot UI-->>User: Display result with accordion UI ```
Last reviewed commit: af2e093 --- autogpt_platform/backend/.env.default | 6 + .../api/features/chat/tools/__init__.py | 4 + .../features/chat/tools/feature_requests.py | 448 +++++++++++++ .../chat/tools/feature_requests_test.py | 615 ++++++++++++++++++ .../backend/api/features/chat/tools/models.py | 34 + .../backend/backend/util/settings.py | 11 + .../ChatMessagesContainer.tsx | 18 + .../(platform)/copilot/styleguide/page.tsx | 235 +++++++ .../tools/FeatureRequests/FeatureRequests.tsx | 227 +++++++ .../copilot/tools/FeatureRequests/helpers.tsx | 271 ++++++++ .../frontend/src/app/api/openapi.json | 4 +- 11 files changed, 1872 insertions(+), 1 deletion(-) create mode 100644 autogpt_platform/backend/backend/api/features/chat/tools/feature_requests.py create mode 100644 autogpt_platform/backend/backend/api/features/chat/tools/feature_requests_test.py create mode 100644 autogpt_platform/frontend/src/app/(platform)/copilot/tools/FeatureRequests/FeatureRequests.tsx create mode 100644 autogpt_platform/frontend/src/app/(platform)/copilot/tools/FeatureRequests/helpers.tsx diff --git a/autogpt_platform/backend/.env.default b/autogpt_platform/backend/.env.default index fa52ba812a..2711bd2df9 100644 --- a/autogpt_platform/backend/.env.default +++ b/autogpt_platform/backend/.env.default @@ -104,6 +104,12 @@ TWITTER_CLIENT_SECRET= # Make a new workspace for your OAuth APP -- trust me # https://linear.app/settings/api/applications/new # Callback URL: http://localhost:3000/auth/integrations/oauth_callback +LINEAR_API_KEY= +# Linear project and team IDs for the feature request tracker. +# Find these in your Linear workspace URL: linear.app//project/ +# and in team settings. Used by the chat copilot to file and search feature requests. +LINEAR_FEATURE_REQUEST_PROJECT_ID= +LINEAR_FEATURE_REQUEST_TEAM_ID= LINEAR_CLIENT_ID= LINEAR_CLIENT_SECRET= diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py b/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py index dcbc35ef37..350776081a 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py @@ -12,6 +12,7 @@ from .base import BaseTool from .create_agent import CreateAgentTool from .customize_agent import CustomizeAgentTool from .edit_agent import EditAgentTool +from .feature_requests import CreateFeatureRequestTool, SearchFeatureRequestsTool from .find_agent import FindAgentTool from .find_block import FindBlockTool from .find_library_agent import FindLibraryAgentTool @@ -45,6 +46,9 @@ TOOL_REGISTRY: dict[str, BaseTool] = { "view_agent_output": AgentOutputTool(), "search_docs": SearchDocsTool(), "get_doc_page": GetDocPageTool(), + # Feature request tools + "search_feature_requests": SearchFeatureRequestsTool(), + "create_feature_request": CreateFeatureRequestTool(), # Workspace tools for CoPilot file operations "list_workspace_files": ListWorkspaceFilesTool(), "read_workspace_file": ReadWorkspaceFileTool(), diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/feature_requests.py b/autogpt_platform/backend/backend/api/features/chat/tools/feature_requests.py new file mode 100644 index 0000000000..95f1eb1fbe --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/feature_requests.py @@ -0,0 +1,448 @@ +"""Feature request tools - search and create feature requests via Linear.""" + +import logging +from typing import Any + +from pydantic import SecretStr + +from backend.api.features.chat.model import ChatSession +from backend.api.features.chat.tools.base import BaseTool +from backend.api.features.chat.tools.models import ( + ErrorResponse, + FeatureRequestCreatedResponse, + FeatureRequestInfo, + FeatureRequestSearchResponse, + NoResultsResponse, + ToolResponseBase, +) +from backend.blocks.linear._api import LinearClient +from backend.data.model import APIKeyCredentials +from backend.data.user import get_user_email_by_id +from backend.util.settings import Settings + +logger = logging.getLogger(__name__) + +MAX_SEARCH_RESULTS = 10 + +# GraphQL queries/mutations +SEARCH_ISSUES_QUERY = """ +query SearchFeatureRequests($term: String!, $filter: IssueFilter, $first: Int) { + searchIssues(term: $term, filter: $filter, first: $first) { + nodes { + id + identifier + title + description + } + } +} +""" + +CUSTOMER_UPSERT_MUTATION = """ +mutation CustomerUpsert($input: CustomerUpsertInput!) { + customerUpsert(input: $input) { + success + customer { + id + name + externalIds + } + } +} +""" + +ISSUE_CREATE_MUTATION = """ +mutation IssueCreate($input: IssueCreateInput!) { + issueCreate(input: $input) { + success + issue { + id + identifier + title + url + } + } +} +""" + +CUSTOMER_NEED_CREATE_MUTATION = """ +mutation CustomerNeedCreate($input: CustomerNeedCreateInput!) { + customerNeedCreate(input: $input) { + success + need { + id + body + customer { + id + name + } + issue { + id + identifier + title + url + } + } + } +} +""" + + +_settings: Settings | None = None + + +def _get_settings() -> Settings: + global _settings + if _settings is None: + _settings = Settings() + return _settings + + +def _get_linear_config() -> tuple[LinearClient, str, str]: + """Return a configured Linear client, project ID, and team ID. + + Raises RuntimeError if any required setting is missing. + """ + secrets = _get_settings().secrets + if not secrets.linear_api_key: + raise RuntimeError("LINEAR_API_KEY is not configured") + if not secrets.linear_feature_request_project_id: + raise RuntimeError("LINEAR_FEATURE_REQUEST_PROJECT_ID is not configured") + if not secrets.linear_feature_request_team_id: + raise RuntimeError("LINEAR_FEATURE_REQUEST_TEAM_ID is not configured") + + credentials = APIKeyCredentials( + id="system-linear", + provider="linear", + api_key=SecretStr(secrets.linear_api_key), + title="System Linear API Key", + ) + client = LinearClient(credentials=credentials) + return ( + client, + secrets.linear_feature_request_project_id, + secrets.linear_feature_request_team_id, + ) + + +class SearchFeatureRequestsTool(BaseTool): + """Tool for searching existing feature requests in Linear.""" + + @property + def name(self) -> str: + return "search_feature_requests" + + @property + def description(self) -> str: + return ( + "Search existing feature requests to check if a similar request " + "already exists before creating a new one. Returns matching feature " + "requests with their ID, title, and description." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search term to find matching feature requests.", + }, + }, + "required": ["query"], + } + + @property + def requires_auth(self) -> bool: + return True + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + **kwargs, + ) -> ToolResponseBase: + query = kwargs.get("query", "").strip() + session_id = session.session_id if session else None + + if not query: + return ErrorResponse( + message="Please provide a search query.", + error="Missing query parameter", + session_id=session_id, + ) + + try: + client, project_id, _team_id = _get_linear_config() + data = await client.query( + SEARCH_ISSUES_QUERY, + { + "term": query, + "filter": { + "project": {"id": {"eq": project_id}}, + }, + "first": MAX_SEARCH_RESULTS, + }, + ) + + nodes = data.get("searchIssues", {}).get("nodes", []) + + if not nodes: + return NoResultsResponse( + message=f"No feature requests found matching '{query}'.", + suggestions=[ + "Try different keywords", + "Use broader search terms", + "You can create a new feature request if none exists", + ], + session_id=session_id, + ) + + results = [ + FeatureRequestInfo( + id=node["id"], + identifier=node["identifier"], + title=node["title"], + description=node.get("description"), + ) + for node in nodes + ] + + return FeatureRequestSearchResponse( + message=f"Found {len(results)} feature request(s) matching '{query}'.", + results=results, + count=len(results), + query=query, + session_id=session_id, + ) + except Exception as e: + logger.exception("Failed to search feature requests") + return ErrorResponse( + message="Failed to search feature requests.", + error=str(e), + session_id=session_id, + ) + + +class CreateFeatureRequestTool(BaseTool): + """Tool for creating feature requests (or adding needs to existing ones).""" + + @property + def name(self) -> str: + return "create_feature_request" + + @property + def description(self) -> str: + return ( + "Create a new feature request or add a customer need to an existing one. " + "Always search first with search_feature_requests to avoid duplicates. " + "If a matching request exists, pass its ID as existing_issue_id to add " + "the user's need to it instead of creating a duplicate." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "title": { + "type": "string", + "description": "Title for the feature request.", + }, + "description": { + "type": "string", + "description": "Detailed description of what the user wants and why.", + }, + "existing_issue_id": { + "type": "string", + "description": ( + "If adding a need to an existing feature request, " + "provide its Linear issue ID (from search results). " + "Omit to create a new feature request." + ), + }, + }, + "required": ["title", "description"], + } + + @property + def requires_auth(self) -> bool: + return True + + async def _find_or_create_customer( + self, client: LinearClient, user_id: str, name: str + ) -> dict: + """Find existing customer by user_id or create a new one via upsert. + + Args: + client: Linear API client. + user_id: Stable external ID used to deduplicate customers. + name: Human-readable display name (e.g. the user's email). + """ + data = await client.mutate( + CUSTOMER_UPSERT_MUTATION, + { + "input": { + "name": name, + "externalId": user_id, + }, + }, + ) + result = data.get("customerUpsert", {}) + if not result.get("success"): + raise RuntimeError(f"Failed to upsert customer: {data}") + return result["customer"] + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + **kwargs, + ) -> ToolResponseBase: + title = kwargs.get("title", "").strip() + description = kwargs.get("description", "").strip() + existing_issue_id = kwargs.get("existing_issue_id") + session_id = session.session_id if session else None + + if not title or not description: + return ErrorResponse( + message="Both title and description are required.", + error="Missing required parameters", + session_id=session_id, + ) + + if not user_id: + return ErrorResponse( + message="Authentication required to create feature requests.", + error="Missing user_id", + session_id=session_id, + ) + + try: + client, project_id, team_id = _get_linear_config() + except Exception as e: + logger.exception("Failed to initialize Linear client") + return ErrorResponse( + message="Failed to create feature request.", + error=str(e), + session_id=session_id, + ) + + # Resolve a human-readable name (email) for the Linear customer record. + # Fall back to user_id if the lookup fails or returns None. + try: + customer_display_name = await get_user_email_by_id(user_id) or user_id + except Exception: + customer_display_name = user_id + + # Step 1: Find or create customer for this user + try: + customer = await self._find_or_create_customer( + client, user_id, customer_display_name + ) + customer_id = customer["id"] + customer_name = customer["name"] + except Exception as e: + logger.exception("Failed to upsert customer in Linear") + return ErrorResponse( + message="Failed to create feature request.", + error=str(e), + session_id=session_id, + ) + + # Step 2: Create or reuse issue + issue_id: str | None = None + issue_identifier: str | None = None + if existing_issue_id: + # Add need to existing issue - we still need the issue details for response + is_new_issue = False + issue_id = existing_issue_id + else: + # Create new issue in the feature requests project + try: + data = await client.mutate( + ISSUE_CREATE_MUTATION, + { + "input": { + "title": title, + "description": description, + "teamId": team_id, + "projectId": project_id, + }, + }, + ) + result = data.get("issueCreate", {}) + if not result.get("success"): + return ErrorResponse( + message="Failed to create feature request issue.", + error=str(data), + session_id=session_id, + ) + issue = result["issue"] + issue_id = issue["id"] + issue_identifier = issue.get("identifier") + except Exception as e: + logger.exception("Failed to create feature request issue") + return ErrorResponse( + message="Failed to create feature request.", + error=str(e), + session_id=session_id, + ) + is_new_issue = True + + # Step 3: Create customer need on the issue + try: + data = await client.mutate( + CUSTOMER_NEED_CREATE_MUTATION, + { + "input": { + "customerId": customer_id, + "issueId": issue_id, + "body": description, + "priority": 0, + }, + }, + ) + need_result = data.get("customerNeedCreate", {}) + if not need_result.get("success"): + orphaned = ( + {"issue_id": issue_id, "issue_identifier": issue_identifier} + if is_new_issue + else None + ) + return ErrorResponse( + message="Failed to attach customer need to the feature request.", + error=str(data), + details=orphaned, + session_id=session_id, + ) + need = need_result["need"] + issue_info = need["issue"] + except Exception as e: + logger.exception("Failed to create customer need") + orphaned = ( + {"issue_id": issue_id, "issue_identifier": issue_identifier} + if is_new_issue + else None + ) + return ErrorResponse( + message="Failed to attach customer need to the feature request.", + error=str(e), + details=orphaned, + session_id=session_id, + ) + + return FeatureRequestCreatedResponse( + message=( + f"{'Created new feature request' if is_new_issue else 'Added your request to existing feature request'}: " + f"{issue_info['title']}." + ), + issue_id=issue_info["id"], + issue_identifier=issue_info["identifier"], + issue_title=issue_info["title"], + issue_url=issue_info.get("url", ""), + is_new_issue=is_new_issue, + customer_name=customer_name, + session_id=session_id, + ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/feature_requests_test.py b/autogpt_platform/backend/backend/api/features/chat/tools/feature_requests_test.py new file mode 100644 index 0000000000..438725368f --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/feature_requests_test.py @@ -0,0 +1,615 @@ +"""Tests for SearchFeatureRequestsTool and CreateFeatureRequestTool.""" + +from unittest.mock import AsyncMock, patch + +import pytest + +from backend.api.features.chat.tools.feature_requests import ( + CreateFeatureRequestTool, + SearchFeatureRequestsTool, +) +from backend.api.features.chat.tools.models import ( + ErrorResponse, + FeatureRequestCreatedResponse, + FeatureRequestSearchResponse, + NoResultsResponse, +) + +from ._test_data import make_session + +_TEST_USER_ID = "test-user-feature-requests" +_TEST_USER_EMAIL = "testuser@example.com" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +_FAKE_PROJECT_ID = "test-project-id" +_FAKE_TEAM_ID = "test-team-id" + + +def _mock_linear_config(*, query_return=None, mutate_return=None): + """Return a patched _get_linear_config that yields a mock LinearClient.""" + client = AsyncMock() + if query_return is not None: + client.query.return_value = query_return + if mutate_return is not None: + client.mutate.return_value = mutate_return + return ( + patch( + "backend.api.features.chat.tools.feature_requests._get_linear_config", + return_value=(client, _FAKE_PROJECT_ID, _FAKE_TEAM_ID), + ), + client, + ) + + +def _search_response(nodes: list[dict]) -> dict: + return {"searchIssues": {"nodes": nodes}} + + +def _customer_upsert_response( + customer_id: str = "cust-1", name: str = _TEST_USER_EMAIL, success: bool = True +) -> dict: + return { + "customerUpsert": { + "success": success, + "customer": {"id": customer_id, "name": name, "externalIds": [name]}, + } + } + + +def _issue_create_response( + issue_id: str = "issue-1", + identifier: str = "FR-1", + title: str = "New Feature", + success: bool = True, +) -> dict: + return { + "issueCreate": { + "success": success, + "issue": { + "id": issue_id, + "identifier": identifier, + "title": title, + "url": f"https://linear.app/issue/{identifier}", + }, + } + } + + +def _need_create_response( + need_id: str = "need-1", + issue_id: str = "issue-1", + identifier: str = "FR-1", + title: str = "New Feature", + success: bool = True, +) -> dict: + return { + "customerNeedCreate": { + "success": success, + "need": { + "id": need_id, + "body": "description", + "customer": {"id": "cust-1", "name": _TEST_USER_EMAIL}, + "issue": { + "id": issue_id, + "identifier": identifier, + "title": title, + "url": f"https://linear.app/issue/{identifier}", + }, + }, + } + } + + +# =========================================================================== +# SearchFeatureRequestsTool +# =========================================================================== + + +class TestSearchFeatureRequestsTool: + """Tests for SearchFeatureRequestsTool._execute.""" + + @pytest.mark.asyncio(loop_scope="session") + async def test_successful_search(self): + session = make_session(user_id=_TEST_USER_ID) + nodes = [ + { + "id": "id-1", + "identifier": "FR-1", + "title": "Dark mode", + "description": "Add dark mode support", + }, + { + "id": "id-2", + "identifier": "FR-2", + "title": "Dark theme", + "description": None, + }, + ] + patcher, _ = _mock_linear_config(query_return=_search_response(nodes)) + with patcher: + tool = SearchFeatureRequestsTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, session=session, query="dark mode" + ) + + assert isinstance(resp, FeatureRequestSearchResponse) + assert resp.count == 2 + assert resp.results[0].id == "id-1" + assert resp.results[1].identifier == "FR-2" + assert resp.query == "dark mode" + + @pytest.mark.asyncio(loop_scope="session") + async def test_no_results(self): + session = make_session(user_id=_TEST_USER_ID) + patcher, _ = _mock_linear_config(query_return=_search_response([])) + with patcher: + tool = SearchFeatureRequestsTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, session=session, query="nonexistent" + ) + + assert isinstance(resp, NoResultsResponse) + assert "nonexistent" in resp.message + + @pytest.mark.asyncio(loop_scope="session") + async def test_empty_query_returns_error(self): + session = make_session(user_id=_TEST_USER_ID) + tool = SearchFeatureRequestsTool() + resp = await tool._execute(user_id=_TEST_USER_ID, session=session, query=" ") + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "query" in resp.error.lower() + + @pytest.mark.asyncio(loop_scope="session") + async def test_missing_query_returns_error(self): + session = make_session(user_id=_TEST_USER_ID) + tool = SearchFeatureRequestsTool() + resp = await tool._execute(user_id=_TEST_USER_ID, session=session) + + assert isinstance(resp, ErrorResponse) + + @pytest.mark.asyncio(loop_scope="session") + async def test_api_failure(self): + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.query.side_effect = RuntimeError("Linear API down") + with patcher: + tool = SearchFeatureRequestsTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, session=session, query="test" + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "Linear API down" in resp.error + + @pytest.mark.asyncio(loop_scope="session") + async def test_malformed_node_returns_error(self): + """A node missing required keys should be caught by the try/except.""" + session = make_session(user_id=_TEST_USER_ID) + # Node missing 'identifier' key + bad_nodes = [{"id": "id-1", "title": "Missing identifier"}] + patcher, _ = _mock_linear_config(query_return=_search_response(bad_nodes)) + with patcher: + tool = SearchFeatureRequestsTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, session=session, query="test" + ) + + assert isinstance(resp, ErrorResponse) + + @pytest.mark.asyncio(loop_scope="session") + async def test_linear_client_init_failure(self): + session = make_session(user_id=_TEST_USER_ID) + with patch( + "backend.api.features.chat.tools.feature_requests._get_linear_config", + side_effect=RuntimeError("No API key"), + ): + tool = SearchFeatureRequestsTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, session=session, query="test" + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "No API key" in resp.error + + +# =========================================================================== +# CreateFeatureRequestTool +# =========================================================================== + + +class TestCreateFeatureRequestTool: + """Tests for CreateFeatureRequestTool._execute.""" + + @pytest.fixture(autouse=True) + def _patch_email_lookup(self): + with patch( + "backend.api.features.chat.tools.feature_requests.get_user_email_by_id", + new_callable=AsyncMock, + return_value=_TEST_USER_EMAIL, + ): + yield + + # ---- Happy paths ------------------------------------------------------- + + @pytest.mark.asyncio(loop_scope="session") + async def test_create_new_issue(self): + """Full happy path: upsert customer -> create issue -> attach need.""" + session = make_session(user_id=_TEST_USER_ID) + + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + _issue_create_response(), + _need_create_response(), + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="New Feature", + description="Please add this", + ) + + assert isinstance(resp, FeatureRequestCreatedResponse) + assert resp.is_new_issue is True + assert resp.issue_identifier == "FR-1" + assert resp.customer_name == _TEST_USER_EMAIL + assert client.mutate.call_count == 3 + + @pytest.mark.asyncio(loop_scope="session") + async def test_add_need_to_existing_issue(self): + """When existing_issue_id is provided, skip issue creation.""" + session = make_session(user_id=_TEST_USER_ID) + + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + _need_create_response(issue_id="existing-1", identifier="FR-99"), + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Existing Feature", + description="Me too", + existing_issue_id="existing-1", + ) + + assert isinstance(resp, FeatureRequestCreatedResponse) + assert resp.is_new_issue is False + assert resp.issue_id == "existing-1" + # Only 2 mutations: customer upsert + need create (no issue create) + assert client.mutate.call_count == 2 + + # ---- Validation errors ------------------------------------------------- + + @pytest.mark.asyncio(loop_scope="session") + async def test_missing_title(self): + session = make_session(user_id=_TEST_USER_ID) + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="", + description="some desc", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "required" in resp.error.lower() + + @pytest.mark.asyncio(loop_scope="session") + async def test_missing_description(self): + session = make_session(user_id=_TEST_USER_ID) + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Some title", + description="", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "required" in resp.error.lower() + + @pytest.mark.asyncio(loop_scope="session") + async def test_missing_user_id(self): + session = make_session(user_id=_TEST_USER_ID) + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=None, + session=session, + title="Some title", + description="Some desc", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "user_id" in resp.error.lower() + + # ---- Linear client init failure ---------------------------------------- + + @pytest.mark.asyncio(loop_scope="session") + async def test_linear_client_init_failure(self): + session = make_session(user_id=_TEST_USER_ID) + with patch( + "backend.api.features.chat.tools.feature_requests._get_linear_config", + side_effect=RuntimeError("No API key"), + ): + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "No API key" in resp.error + + # ---- Customer upsert failures ------------------------------------------ + + @pytest.mark.asyncio(loop_scope="session") + async def test_customer_upsert_api_error(self): + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = RuntimeError("Customer API error") + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "Customer API error" in resp.error + + @pytest.mark.asyncio(loop_scope="session") + async def test_customer_upsert_not_success(self): + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.return_value = _customer_upsert_response(success=False) + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + + @pytest.mark.asyncio(loop_scope="session") + async def test_customer_malformed_response(self): + """Customer dict missing 'id' key should be caught.""" + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + # success=True but customer has no 'id' + client.mutate.return_value = { + "customerUpsert": { + "success": True, + "customer": {"name": _TEST_USER_ID}, + } + } + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + + # ---- Issue creation failures ------------------------------------------- + + @pytest.mark.asyncio(loop_scope="session") + async def test_issue_create_api_error(self): + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + RuntimeError("Issue create failed"), + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "Issue create failed" in resp.error + + @pytest.mark.asyncio(loop_scope="session") + async def test_issue_create_not_success(self): + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + _issue_create_response(success=False), + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + assert "Failed to create feature request issue" in resp.message + + @pytest.mark.asyncio(loop_scope="session") + async def test_issue_create_malformed_response(self): + """issueCreate success=True but missing 'issue' key.""" + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + {"issueCreate": {"success": True}}, # no 'issue' key + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + + # ---- Customer need attachment failures --------------------------------- + + @pytest.mark.asyncio(loop_scope="session") + async def test_need_create_api_error_new_issue(self): + """Need creation fails after new issue was created -> orphaned issue info.""" + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + _issue_create_response(issue_id="orphan-1", identifier="FR-10"), + RuntimeError("Need attach failed"), + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.error is not None + assert "Need attach failed" in resp.error + assert resp.details is not None + assert resp.details["issue_id"] == "orphan-1" + assert resp.details["issue_identifier"] == "FR-10" + + @pytest.mark.asyncio(loop_scope="session") + async def test_need_create_api_error_existing_issue(self): + """Need creation fails on existing issue -> no orphaned info.""" + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + RuntimeError("Need attach failed"), + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + existing_issue_id="existing-1", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.details is None + + @pytest.mark.asyncio(loop_scope="session") + async def test_need_create_not_success_includes_orphaned_info(self): + """customerNeedCreate returns success=False -> includes orphaned issue.""" + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + _issue_create_response(issue_id="orphan-2", identifier="FR-20"), + _need_create_response(success=False), + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.details is not None + assert resp.details["issue_id"] == "orphan-2" + assert resp.details["issue_identifier"] == "FR-20" + + @pytest.mark.asyncio(loop_scope="session") + async def test_need_create_not_success_existing_issue_no_details(self): + """customerNeedCreate fails on existing issue -> no orphaned info.""" + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + _need_create_response(success=False), + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + existing_issue_id="existing-1", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.details is None + + @pytest.mark.asyncio(loop_scope="session") + async def test_need_create_malformed_response(self): + """need_result missing 'need' key after success=True.""" + session = make_session(user_id=_TEST_USER_ID) + patcher, client = _mock_linear_config() + client.mutate.side_effect = [ + _customer_upsert_response(), + _issue_create_response(), + {"customerNeedCreate": {"success": True}}, # no 'need' key + ] + + with patcher: + tool = CreateFeatureRequestTool() + resp = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + title="Title", + description="Desc", + ) + + assert isinstance(resp, ErrorResponse) + assert resp.details is not None + assert resp.details["issue_id"] == "issue-1" diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/models.py b/autogpt_platform/backend/backend/api/features/chat/tools/models.py index bd19d590a6..f2d8f364e4 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/models.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/models.py @@ -41,6 +41,9 @@ class ResponseType(str, Enum): OPERATION_IN_PROGRESS = "operation_in_progress" # Input validation INPUT_VALIDATION_ERROR = "input_validation_error" + # Feature request types + FEATURE_REQUEST_SEARCH = "feature_request_search" + FEATURE_REQUEST_CREATED = "feature_request_created" # Base response model @@ -430,3 +433,34 @@ class AsyncProcessingResponse(ToolResponseBase): status: str = "accepted" # Must be "accepted" for detection operation_id: str | None = None task_id: str | None = None + + +# Feature request models +class FeatureRequestInfo(BaseModel): + """Information about a feature request issue.""" + + id: str + identifier: str + title: str + description: str | None = None + + +class FeatureRequestSearchResponse(ToolResponseBase): + """Response for search_feature_requests tool.""" + + type: ResponseType = ResponseType.FEATURE_REQUEST_SEARCH + results: list[FeatureRequestInfo] + count: int + query: str + + +class FeatureRequestCreatedResponse(ToolResponseBase): + """Response for create_feature_request tool.""" + + type: ResponseType = ResponseType.FEATURE_REQUEST_CREATED + issue_id: str + issue_identifier: str + issue_title: str + issue_url: str + is_new_issue: bool # False if added to existing + customer_name: str diff --git a/autogpt_platform/backend/backend/util/settings.py b/autogpt_platform/backend/backend/util/settings.py index 48dadb88f1..c5cca87b6e 100644 --- a/autogpt_platform/backend/backend/util/settings.py +++ b/autogpt_platform/backend/backend/util/settings.py @@ -662,6 +662,17 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings): mem0_api_key: str = Field(default="", description="Mem0 API key") elevenlabs_api_key: str = Field(default="", description="ElevenLabs API key") + linear_api_key: str = Field( + default="", description="Linear API key for system-level operations" + ) + linear_feature_request_project_id: str = Field( + default="", + description="Linear project ID where feature requests are tracked", + ) + linear_feature_request_team_id: str = Field( + default="", + description="Linear team ID used when creating feature request issues", + ) linear_client_id: str = Field(default="", description="Linear client ID") linear_client_secret: str = Field(default="", description="Linear client secret") diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatMessagesContainer/ChatMessagesContainer.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatMessagesContainer/ChatMessagesContainer.tsx index 71ade81a9f..b62e96f58a 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatMessagesContainer/ChatMessagesContainer.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatMessagesContainer/ChatMessagesContainer.tsx @@ -15,6 +15,10 @@ import { ToolUIPart, UIDataTypes, UIMessage, UITools } from "ai"; import { useEffect, useRef, useState } from "react"; import { CreateAgentTool } from "../../tools/CreateAgent/CreateAgent"; import { EditAgentTool } from "../../tools/EditAgent/EditAgent"; +import { + CreateFeatureRequestTool, + SearchFeatureRequestsTool, +} from "../../tools/FeatureRequests/FeatureRequests"; import { FindAgentsTool } from "../../tools/FindAgents/FindAgents"; import { FindBlocksTool } from "../../tools/FindBlocks/FindBlocks"; import { RunAgentTool } from "../../tools/RunAgent/RunAgent"; @@ -254,6 +258,20 @@ export const ChatMessagesContainer = ({ part={part as ToolUIPart} /> ); + case "tool-search_feature_requests": + return ( + + ); + case "tool-create_feature_request": + return ( + + ); default: return null; } diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/styleguide/page.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/styleguide/page.tsx index 6030665f1c..8a35f939ca 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/styleguide/page.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/styleguide/page.tsx @@ -14,6 +14,10 @@ import { Text } from "@/components/atoms/Text/Text"; import { CopilotChatActionsProvider } from "../components/CopilotChatActionsProvider/CopilotChatActionsProvider"; import { CreateAgentTool } from "../tools/CreateAgent/CreateAgent"; import { EditAgentTool } from "../tools/EditAgent/EditAgent"; +import { + CreateFeatureRequestTool, + SearchFeatureRequestsTool, +} from "../tools/FeatureRequests/FeatureRequests"; import { FindAgentsTool } from "../tools/FindAgents/FindAgents"; import { FindBlocksTool } from "../tools/FindBlocks/FindBlocks"; import { RunAgentTool } from "../tools/RunAgent/RunAgent"; @@ -45,6 +49,8 @@ const SECTIONS = [ "Tool: Create Agent", "Tool: Edit Agent", "Tool: View Agent Output", + "Tool: Search Feature Requests", + "Tool: Create Feature Request", "Full Conversation Example", ] as const; @@ -1421,6 +1427,235 @@ export default function StyleguidePage() { + {/* ============================================================= */} + {/* SEARCH FEATURE REQUESTS */} + {/* ============================================================= */} + +
+ + + + + + + + + + + + + + + + + + + + + + + +
+ + {/* ============================================================= */} + {/* CREATE FEATURE REQUEST */} + {/* ============================================================= */} + +
+ + + + + + + + + + + + + + + + + + + + + + + +
+ {/* ============================================================= */} {/* FULL CONVERSATION EXAMPLE */} {/* ============================================================= */} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/FeatureRequests/FeatureRequests.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/FeatureRequests/FeatureRequests.tsx new file mode 100644 index 0000000000..fcd4624b6a --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/FeatureRequests/FeatureRequests.tsx @@ -0,0 +1,227 @@ +"use client"; + +import type { ToolUIPart } from "ai"; +import { useMemo } from "react"; + +import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation"; +import { + ContentBadge, + ContentCard, + ContentCardDescription, + ContentCardHeader, + ContentCardTitle, + ContentGrid, + ContentMessage, + ContentSuggestionsList, +} from "../../components/ToolAccordion/AccordionContent"; +import { ToolAccordion } from "../../components/ToolAccordion/ToolAccordion"; +import { + AccordionIcon, + getAccordionTitle, + getAnimationText, + getFeatureRequestOutput, + isCreatedOutput, + isErrorOutput, + isNoResultsOutput, + isSearchResultsOutput, + ToolIcon, + type FeatureRequestToolType, +} from "./helpers"; + +export interface FeatureRequestToolPart { + type: FeatureRequestToolType; + toolCallId: string; + state: ToolUIPart["state"]; + input?: unknown; + output?: unknown; +} + +interface Props { + part: FeatureRequestToolPart; +} + +function truncate(text: string, maxChars: number): string { + const trimmed = text.trim(); + if (trimmed.length <= maxChars) return trimmed; + return `${trimmed.slice(0, maxChars).trimEnd()}…`; +} + +export function SearchFeatureRequestsTool({ part }: Props) { + const output = getFeatureRequestOutput(part); + const text = getAnimationText(part); + const isStreaming = + part.state === "input-streaming" || part.state === "input-available"; + const isError = + part.state === "output-error" || (!!output && isErrorOutput(output)); + + const normalized = useMemo(() => { + if (!output) return null; + return { title: getAccordionTitle(part.type, output) }; + }, [output, part.type]); + + const isOutputAvailable = part.state === "output-available" && !!output; + + const searchOutput = + isOutputAvailable && output && isSearchResultsOutput(output) + ? output + : null; + const noResultsOutput = + isOutputAvailable && output && isNoResultsOutput(output) ? output : null; + const errorOutput = + isOutputAvailable && output && isErrorOutput(output) ? output : null; + + const hasExpandableContent = + isOutputAvailable && + ((!!searchOutput && searchOutput.count > 0) || + !!noResultsOutput || + !!errorOutput); + + const accordionDescription = + hasExpandableContent && searchOutput + ? `Found ${searchOutput.count} result${searchOutput.count === 1 ? "" : "s"} for "${searchOutput.query}"` + : hasExpandableContent && (noResultsOutput || errorOutput) + ? ((noResultsOutput ?? errorOutput)?.message ?? null) + : null; + + return ( +
+
+ + +
+ + {hasExpandableContent && normalized && ( + } + title={normalized.title} + description={accordionDescription} + > + {searchOutput && ( + + {searchOutput.results.map((r) => ( + + + {r.title} + + {r.description && ( + + {truncate(r.description, 200)} + + )} + + ))} + + )} + + {noResultsOutput && ( +
+ {noResultsOutput.message} + {noResultsOutput.suggestions && + noResultsOutput.suggestions.length > 0 && ( + + )} +
+ )} + + {errorOutput && ( +
+ {errorOutput.message} + {errorOutput.error && ( + + {errorOutput.error} + + )} +
+ )} +
+ )} +
+ ); +} + +export function CreateFeatureRequestTool({ part }: Props) { + const output = getFeatureRequestOutput(part); + const text = getAnimationText(part); + const isStreaming = + part.state === "input-streaming" || part.state === "input-available"; + const isError = + part.state === "output-error" || (!!output && isErrorOutput(output)); + + const normalized = useMemo(() => { + if (!output) return null; + return { title: getAccordionTitle(part.type, output) }; + }, [output, part.type]); + + const isOutputAvailable = part.state === "output-available" && !!output; + + const createdOutput = + isOutputAvailable && output && isCreatedOutput(output) ? output : null; + const errorOutput = + isOutputAvailable && output && isErrorOutput(output) ? output : null; + + const hasExpandableContent = + isOutputAvailable && (!!createdOutput || !!errorOutput); + + const accordionDescription = + hasExpandableContent && createdOutput + ? createdOutput.issue_title + : hasExpandableContent && errorOutput + ? errorOutput.message + : null; + + return ( +
+
+ + +
+ + {hasExpandableContent && normalized && ( + } + title={normalized.title} + description={accordionDescription} + > + {createdOutput && ( + + + {createdOutput.issue_title} + +
+ + {createdOutput.is_new_issue ? "New" : "Existing"} + +
+ {createdOutput.message} +
+ )} + + {errorOutput && ( +
+ {errorOutput.message} + {errorOutput.error && ( + + {errorOutput.error} + + )} +
+ )} +
+ )} +
+ ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/FeatureRequests/helpers.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/FeatureRequests/helpers.tsx new file mode 100644 index 0000000000..75133905b1 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/FeatureRequests/helpers.tsx @@ -0,0 +1,271 @@ +import { + CheckCircleIcon, + LightbulbIcon, + MagnifyingGlassIcon, + PlusCircleIcon, +} from "@phosphor-icons/react"; +import type { ToolUIPart } from "ai"; + +/* ------------------------------------------------------------------ */ +/* Types (local until API client is regenerated) */ +/* ------------------------------------------------------------------ */ + +interface FeatureRequestInfo { + id: string; + identifier: string; + title: string; + description?: string | null; +} + +export interface FeatureRequestSearchResponse { + type: "feature_request_search"; + message: string; + results: FeatureRequestInfo[]; + count: number; + query: string; +} + +export interface FeatureRequestCreatedResponse { + type: "feature_request_created"; + message: string; + issue_id: string; + issue_identifier: string; + issue_title: string; + issue_url: string; + is_new_issue: boolean; + customer_name: string; +} + +interface NoResultsResponse { + type: "no_results"; + message: string; + suggestions?: string[]; +} + +interface ErrorResponse { + type: "error"; + message: string; + error?: string; +} + +export type FeatureRequestOutput = + | FeatureRequestSearchResponse + | FeatureRequestCreatedResponse + | NoResultsResponse + | ErrorResponse; + +export type FeatureRequestToolType = + | "tool-search_feature_requests" + | "tool-create_feature_request" + | string; + +/* ------------------------------------------------------------------ */ +/* Output parsing */ +/* ------------------------------------------------------------------ */ + +function parseOutput(output: unknown): FeatureRequestOutput | null { + if (!output) return null; + if (typeof output === "string") { + const trimmed = output.trim(); + if (!trimmed) return null; + try { + return parseOutput(JSON.parse(trimmed) as unknown); + } catch { + return null; + } + } + if (typeof output === "object") { + const type = (output as { type?: unknown }).type; + if ( + type === "feature_request_search" || + type === "feature_request_created" || + type === "no_results" || + type === "error" + ) { + return output as FeatureRequestOutput; + } + // Fallback structural checks + if ("results" in output && "query" in output) + return output as FeatureRequestSearchResponse; + if ("issue_identifier" in output) + return output as FeatureRequestCreatedResponse; + if ("suggestions" in output && !("error" in output)) + return output as NoResultsResponse; + if ("error" in output || "details" in output) + return output as ErrorResponse; + } + return null; +} + +export function getFeatureRequestOutput( + part: unknown, +): FeatureRequestOutput | null { + if (!part || typeof part !== "object") return null; + return parseOutput((part as { output?: unknown }).output); +} + +/* ------------------------------------------------------------------ */ +/* Type guards */ +/* ------------------------------------------------------------------ */ + +export function isSearchResultsOutput( + output: FeatureRequestOutput, +): output is FeatureRequestSearchResponse { + return ( + output.type === "feature_request_search" || + ("results" in output && "query" in output) + ); +} + +export function isCreatedOutput( + output: FeatureRequestOutput, +): output is FeatureRequestCreatedResponse { + return ( + output.type === "feature_request_created" || "issue_identifier" in output + ); +} + +export function isNoResultsOutput( + output: FeatureRequestOutput, +): output is NoResultsResponse { + return ( + output.type === "no_results" || + ("suggestions" in output && !("error" in output)) + ); +} + +export function isErrorOutput( + output: FeatureRequestOutput, +): output is ErrorResponse { + return output.type === "error" || "error" in output; +} + +/* ------------------------------------------------------------------ */ +/* Accordion metadata */ +/* ------------------------------------------------------------------ */ + +export function getAccordionTitle( + toolType: FeatureRequestToolType, + output: FeatureRequestOutput, +): string { + if (toolType === "tool-search_feature_requests") { + if (isSearchResultsOutput(output)) return "Feature requests"; + if (isNoResultsOutput(output)) return "No feature requests found"; + return "Feature request search error"; + } + if (isCreatedOutput(output)) { + return output.is_new_issue + ? "Feature request created" + : "Added to feature request"; + } + if (isErrorOutput(output)) return "Feature request error"; + return "Feature request"; +} + +/* ------------------------------------------------------------------ */ +/* Animation text */ +/* ------------------------------------------------------------------ */ + +interface AnimationPart { + type: FeatureRequestToolType; + state: ToolUIPart["state"]; + input?: unknown; + output?: unknown; +} + +export function getAnimationText(part: AnimationPart): string { + if (part.type === "tool-search_feature_requests") { + const query = (part.input as { query?: string } | undefined)?.query?.trim(); + const queryText = query ? ` for "${query}"` : ""; + + switch (part.state) { + case "input-streaming": + case "input-available": + return `Searching feature requests${queryText}`; + case "output-available": { + const output = parseOutput(part.output); + if (!output) return `Searching feature requests${queryText}`; + if (isSearchResultsOutput(output)) { + return `Found ${output.count} feature request${output.count === 1 ? "" : "s"}${queryText}`; + } + if (isNoResultsOutput(output)) + return `No feature requests found${queryText}`; + return `Error searching feature requests${queryText}`; + } + case "output-error": + return `Error searching feature requests${queryText}`; + default: + return "Searching feature requests"; + } + } + + // create_feature_request + const title = (part.input as { title?: string } | undefined)?.title?.trim(); + const titleText = title ? ` "${title}"` : ""; + + switch (part.state) { + case "input-streaming": + case "input-available": + return `Creating feature request${titleText}`; + case "output-available": { + const output = parseOutput(part.output); + if (!output) return `Creating feature request${titleText}`; + if (isCreatedOutput(output)) { + return output.is_new_issue + ? "Feature request created" + : "Added to existing feature request"; + } + if (isErrorOutput(output)) return "Error creating feature request"; + return `Created feature request${titleText}`; + } + case "output-error": + return "Error creating feature request"; + default: + return "Creating feature request"; + } +} + +/* ------------------------------------------------------------------ */ +/* Icons */ +/* ------------------------------------------------------------------ */ + +export function ToolIcon({ + toolType, + isStreaming, + isError, +}: { + toolType: FeatureRequestToolType; + isStreaming?: boolean; + isError?: boolean; +}) { + const IconComponent = + toolType === "tool-create_feature_request" + ? PlusCircleIcon + : MagnifyingGlassIcon; + + return ( + + ); +} + +export function AccordionIcon({ + toolType, +}: { + toolType: FeatureRequestToolType; +}) { + const IconComponent = + toolType === "tool-create_feature_request" + ? CheckCircleIcon + : LightbulbIcon; + return ; +} diff --git a/autogpt_platform/frontend/src/app/api/openapi.json b/autogpt_platform/frontend/src/app/api/openapi.json index 496a714ba5..1e8dca865c 100644 --- a/autogpt_platform/frontend/src/app/api/openapi.json +++ b/autogpt_platform/frontend/src/app/api/openapi.json @@ -10495,7 +10495,9 @@ "operation_started", "operation_pending", "operation_in_progress", - "input_validation_error" + "input_validation_error", + "feature_request_search", + "feature_request_created" ], "title": "ResponseType", "description": "Types of tool responses." From 9ac3f64d56611ec1780b9ff6243514091e1f6eac Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 13 Feb 2026 09:04:05 -0600 Subject: [PATCH 07/16] chore(deps): bump github/codeql-action from 3 to 4 (#12033) Bumps [github/codeql-action](https://github.com/github/codeql-action) from 3 to 4.
Release notes

Sourced from github/codeql-action's releases.

v3.32.2

  • Update default CodeQL bundle version to 2.24.1. #3460

v3.32.1

  • A warning is now shown in Default Setup workflow logs if a private package registry is configured using a GitHub Personal Access Token (PAT), but no username is configured. #3422
  • Fixed a bug which caused the CodeQL Action to fail when repository properties cannot successfully be retrieved. #3421

v3.32.0

  • Update default CodeQL bundle version to 2.24.0. #3425

v3.31.11

  • When running a Default Setup workflow with Actions debugging enabled, the CodeQL Action will now use more unique names when uploading logs from the Dependabot authentication proxy as workflow artifacts. This ensures that the artifact names do not clash between multiple jobs in a build matrix. #3409
  • Improved error handling throughout the CodeQL Action. #3415
  • Added experimental support for automatically excluding generated files from the analysis. This feature is not currently enabled for any analysis. In the future, it may be enabled by default for some GitHub-managed analyses. #3318
  • The changelog extracts that are included with releases of the CodeQL Action are now shorter to avoid duplicated information from appearing in Dependabot PRs. #3403

v3.31.10

CodeQL Action Changelog

See the releases page for the relevant changes to the CodeQL CLI and language packs.

3.31.10 - 12 Jan 2026

  • Update default CodeQL bundle version to 2.23.9. #3393

See the full CHANGELOG.md for more information.

v3.31.9

CodeQL Action Changelog

See the releases page for the relevant changes to the CodeQL CLI and language packs.

3.31.9 - 16 Dec 2025

No user facing changes.

See the full CHANGELOG.md for more information.

v3.31.8

CodeQL Action Changelog

See the releases page for the relevant changes to the CodeQL CLI and language packs.

3.31.8 - 11 Dec 2025

  • Update default CodeQL bundle version to 2.23.8. #3354

See the full CHANGELOG.md for more information.

v3.31.7

... (truncated)

Changelog

Sourced from github/codeql-action's changelog.

4.31.11 - 23 Jan 2026

  • When running a Default Setup workflow with Actions debugging enabled, the CodeQL Action will now use more unique names when uploading logs from the Dependabot authentication proxy as workflow artifacts. This ensures that the artifact names do not clash between multiple jobs in a build matrix. #3409
  • Improved error handling throughout the CodeQL Action. #3415
  • Added experimental support for automatically excluding generated files from the analysis. This feature is not currently enabled for any analysis. In the future, it may be enabled by default for some GitHub-managed analyses. #3318
  • The changelog extracts that are included with releases of the CodeQL Action are now shorter to avoid duplicated information from appearing in Dependabot PRs. #3403

4.31.10 - 12 Jan 2026

  • Update default CodeQL bundle version to 2.23.9. #3393

4.31.9 - 16 Dec 2025

No user facing changes.

4.31.8 - 11 Dec 2025

  • Update default CodeQL bundle version to 2.23.8. #3354

4.31.7 - 05 Dec 2025

  • Update default CodeQL bundle version to 2.23.7. #3343

4.31.6 - 01 Dec 2025

No user facing changes.

4.31.5 - 24 Nov 2025

  • Update default CodeQL bundle version to 2.23.6. #3321

4.31.4 - 18 Nov 2025

No user facing changes.

4.31.3 - 13 Nov 2025

  • CodeQL Action v3 will be deprecated in December 2026. The Action now logs a warning for customers who are running v3 but could be running v4. For more information, see Upcoming deprecation of CodeQL Action v3.
  • Update default CodeQL bundle version to 2.23.5. #3288

4.31.2 - 30 Oct 2025

No user facing changes.

4.31.1 - 30 Oct 2025

  • The add-snippets input has been removed from the analyze action. This input has been deprecated since CodeQL Action 3.26.4 in August 2024 when this removal was announced.

4.31.0 - 24 Oct 2025

... (truncated)

Commits
  • 8aac4e4 Merge pull request #3448 from github/mergeback/v4.32.1-to-main-6bc82e05
  • e8d7df4 Rebuild
  • c1bba77 Update changelog and version after v4.32.1
  • 6bc82e0 Merge pull request #3447 from github/update-v4.32.1-f52cbc830
  • 42f00f2 Add a couple of change notes
  • cedee6d Update changelog for v4.32.1
  • f52cbc8 Merge pull request #3445 from github/dependabot/npm_and_yarn/fast-xml-parser-...
  • See full diff in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=github/codeql-action&package-manager=github_actions&previous-version=3&new-version=4)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/codeql.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 966243323c..ff535f8496 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -62,7 +62,7 @@ jobs: # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@v3 + uses: github/codeql-action/init@v4 with: languages: ${{ matrix.language }} build-mode: ${{ matrix.build-mode }} @@ -93,6 +93,6 @@ jobs: exit 1 - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v3 + uses: github/codeql-action/analyze@v4 with: category: "/language:${{matrix.language}}" From c2368f15ffff4b8389eaf03058a059c11786a7ce Mon Sep 17 00:00:00 2001 From: Bently Date: Fri, 13 Feb 2026 15:20:23 +0000 Subject: [PATCH 08/16] fix(blocks): disable PrintToConsoleBlock (#12100) ## Summary Disables the Print to Console block as requested by Nick Tindle. ## Changes - Added `disabled=True` to PrintToConsoleBlock in `basic.py` ## Testing - Block will no longer appear in the platform UI - Existing graphs using this block should be checked (block ID: `f3b1c1b2-4c4f-4f0d-8d2f-4c4f0d8d2f4c`) Closes OPEN-3000

Greptile Overview

Greptile Summary

Added `disabled=True` parameter to `PrintToConsoleBlock` in `basic.py` per Nick Tindle's request (OPEN-3000). - Block follows the same disabling pattern used by other blocks in the codebase (e.g., `BlockInstallationBlock`, video blocks, Ayrshare blocks) - Block will no longer appear in the platform UI for new graph creation - Existing graphs using this block (ID: `f3b1c1b2-4c4f-4f0d-8d2f-4c4f0d8d2f4c`) will need to be checked for compatibility - Comment properly documents the reason for disabling

Confidence Score: 5/5

- This PR is safe to merge with minimal risk - Single-line change that adds a well-documented flag following existing patterns used throughout the codebase. The change is non-destructive and only affects UI visibility of the block for new graphs. - No files require special attention
Last reviewed commit: 759003b --- autogpt_platform/backend/backend/blocks/basic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/autogpt_platform/backend/backend/blocks/basic.py b/autogpt_platform/backend/backend/blocks/basic.py index f129d2707b..5fdcfb6d82 100644 --- a/autogpt_platform/backend/backend/blocks/basic.py +++ b/autogpt_platform/backend/backend/blocks/basic.py @@ -126,6 +126,7 @@ class PrintToConsoleBlock(Block): output_schema=PrintToConsoleBlock.Output, test_input={"text": "Hello, World!"}, is_sensitive_action=True, + disabled=True, # Disabled per Nick Tindle's request (OPEN-3000) test_output=[ ("output", "Hello, World!"), ("status", "printed"), From 965b7d3e04024f6324ad1f851c3be7dd20136bf1 Mon Sep 17 00:00:00 2001 From: Otto Date: Fri, 13 Feb 2026 15:45:10 +0000 Subject: [PATCH 09/16] dx: Add PR overlap detection & alert (#12104) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Adds an automated workflow that detects potential merge conflicts between open PRs, helping contributors coordinate proactively. **Example output:** [See comment on PR #12057](https://github.com/Significant-Gravitas/AutoGPT/pull/12057#issuecomment-3897330632) ## How it works 1. **Triggered on PR events** — runs when a PR is opened, pushed to, or reopened 2. **Compares against all open PRs** targeting the same base branch 3. **Detects overlaps** at multiple levels: - File overlap (same files modified) - Line overlap (same line ranges modified) - Actual merge conflicts (attempts real merges) 4. **Posts a comment** on the PR with findings ## Features - Full file paths with common prefix extraction for readability - Conflict size (number of conflict regions + lines affected) - Conflict types (content, added, deleted, modified/deleted, etc.) - Last-updated timestamps for each PR - Risk categorization (conflict, medium, low) - Ignores noise files (openapi.json, lock files) - Updates existing comment on subsequent pushes (no spam) - Filters out PRs older than 14 days - Clone-once optimization for fast merge testing (~48s for 19 PRs) ## Files - `.github/scripts/detect_overlaps.py` — main detection script - `.github/workflows/pr-overlap-check.yml` — workflow definition --- .github/scripts/detect_overlaps.py | 1229 ++++++++++++++++++++++++ .github/workflows/pr-overlap-check.yml | 39 + 2 files changed, 1268 insertions(+) create mode 100644 .github/scripts/detect_overlaps.py create mode 100644 .github/workflows/pr-overlap-check.yml diff --git a/.github/scripts/detect_overlaps.py b/.github/scripts/detect_overlaps.py new file mode 100644 index 0000000000..1f9f4be7cf --- /dev/null +++ b/.github/scripts/detect_overlaps.py @@ -0,0 +1,1229 @@ +#!/usr/bin/env python3 +""" +PR Overlap Detection Tool + +Detects potential merge conflicts between a given PR and other open PRs +by checking for file overlap, line overlap, and actual merge conflicts. +""" + +import json +import os +import re +import subprocess +import sys +import tempfile +from dataclasses import dataclass +from typing import Optional + + +# ============================================================================= +# MAIN ENTRY POINT +# ============================================================================= + +def main(): + """Main entry point for PR overlap detection.""" + import argparse + + parser = argparse.ArgumentParser(description="Detect PR overlaps and potential merge conflicts") + parser.add_argument("pr_number", type=int, help="PR number to check") + parser.add_argument("--base", default=None, help="Base branch (default: auto-detect from PR)") + parser.add_argument("--skip-merge-test", action="store_true", help="Skip actual merge conflict testing") + parser.add_argument("--discord-webhook", default=os.environ.get("DISCORD_WEBHOOK_URL"), help="Discord webhook URL for notifications") + parser.add_argument("--dry-run", action="store_true", help="Don't post comments, just print") + + args = parser.parse_args() + + owner, repo = get_repo_info() + print(f"Checking PR #{args.pr_number} in {owner}/{repo}") + + # Get current PR info + current_pr = fetch_pr_details(args.pr_number) + base_branch = args.base or current_pr.base_ref + + print(f"PR #{current_pr.number}: {current_pr.title}") + print(f"Base branch: {base_branch}") + print(f"Files changed: {len(current_pr.files)}") + + # Find overlapping PRs + overlaps, all_changes = find_overlapping_prs( + owner, repo, base_branch, current_pr, args.pr_number, args.skip_merge_test + ) + + if not overlaps: + print("No overlaps detected!") + return + + # Generate and post report + comment = format_comment(overlaps, args.pr_number, current_pr.changed_ranges, all_changes) + + if args.dry_run: + print("\n" + "="*60) + print("COMMENT PREVIEW:") + print("="*60) + print(comment) + else: + if comment: + post_or_update_comment(args.pr_number, comment) + print("Posted comment to PR") + + if args.discord_webhook: + send_discord_notification(args.discord_webhook, current_pr, overlaps) + + # Report results and exit + report_results(overlaps) + + +# ============================================================================= +# HIGH-LEVEL WORKFLOW FUNCTIONS +# ============================================================================= + +def fetch_pr_details(pr_number: int) -> "PullRequest": + """Fetch details for a specific PR including its diff.""" + result = run_gh(["pr", "view", str(pr_number), "--json", "number,title,url,author,headRefName,baseRefName,files"]) + data = json.loads(result.stdout) + + pr = PullRequest( + number=data["number"], + title=data["title"], + author=data["author"]["login"] if data.get("author") else "unknown", + url=data["url"], + head_ref=data["headRefName"], + base_ref=data["baseRefName"], + files=[f["path"] for f in data["files"]], + changed_ranges={} + ) + + # Get detailed diff + diff = get_pr_diff(pr_number) + pr.changed_ranges = parse_diff_ranges(diff) + + return pr + + +def find_overlapping_prs( + owner: str, + repo: str, + base_branch: str, + current_pr: "PullRequest", + current_pr_number: int, + skip_merge_test: bool +) -> tuple[list["Overlap"], dict[int, dict[str, "ChangedFile"]]]: + """Find all PRs that overlap with the current PR.""" + # Query other open PRs + all_prs = query_open_prs(owner, repo, base_branch) + other_prs = [p for p in all_prs if p["number"] != current_pr_number] + + print(f"Found {len(other_prs)} other open PRs targeting {base_branch}") + + # Find file overlaps (excluding ignored files, filtering by age) + candidates = find_file_overlap_candidates(current_pr.files, other_prs) + + print(f"Found {len(candidates)} PRs with file overlap (excluding ignored files)") + + if not candidates: + return [], {} + + # First pass: analyze line overlaps (no merge testing yet) + overlaps = [] + all_changes = {} + prs_needing_merge_test = [] + + for pr_data, shared_files in candidates: + overlap, pr_changes = analyze_pr_overlap( + owner, repo, base_branch, current_pr, pr_data, shared_files, + skip_merge_test=True # Always skip in first pass + ) + if overlap: + overlaps.append(overlap) + all_changes[pr_data["number"]] = pr_changes + # Track PRs that need merge testing + if overlap.line_overlaps and not skip_merge_test: + prs_needing_merge_test.append(overlap) + + # Second pass: batch merge testing with shared clone + if prs_needing_merge_test: + run_batch_merge_tests(owner, repo, base_branch, current_pr, prs_needing_merge_test) + + return overlaps, all_changes + + +def run_batch_merge_tests( + owner: str, + repo: str, + base_branch: str, + current_pr: "PullRequest", + overlaps: list["Overlap"] +): + """Run merge tests for multiple PRs using a shared clone.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Clone once + if not clone_repo(owner, repo, base_branch, tmpdir): + return + + configure_git(tmpdir) + + # Fetch current PR branch once + result = run_git(["fetch", "origin", f"pull/{current_pr.number}/head:pr-{current_pr.number}"], cwd=tmpdir, check=False) + if result.returncode != 0: + print(f"Warning: Could not fetch current PR #{current_pr.number}", file=sys.stderr) + return + + for overlap in overlaps: + other_pr = overlap.pr_b if overlap.pr_a.number == current_pr.number else overlap.pr_a + print(f"Testing merge conflict with PR #{other_pr.number}...", flush=True) + + # Clean up any in-progress merge from previous iteration + run_git(["merge", "--abort"], cwd=tmpdir, check=False) + + # Reset to base branch + run_git(["checkout", base_branch], cwd=tmpdir, check=False) + run_git(["reset", "--hard", f"origin/{base_branch}"], cwd=tmpdir, check=False) + run_git(["clean", "-fdx"], cwd=tmpdir, check=False) + + # Fetch the other PR branch + result = run_git(["fetch", "origin", f"pull/{other_pr.number}/head:pr-{other_pr.number}"], cwd=tmpdir, check=False) + if result.returncode != 0: + print(f"Warning: Could not fetch PR #{other_pr.number}: {result.stderr.strip()}", file=sys.stderr) + continue + + # Try merging current PR first + result = run_git(["merge", "--no-commit", "--no-ff", f"pr-{current_pr.number}"], cwd=tmpdir, check=False) + if result.returncode != 0: + # Current PR conflicts with base + conflict_files, conflict_details = extract_conflict_info(tmpdir, result.stderr) + overlap.has_merge_conflict = True + overlap.conflict_files = conflict_files + overlap.conflict_details = conflict_details + overlap.conflict_type = 'pr_a_conflicts_base' + run_git(["merge", "--abort"], cwd=tmpdir, check=False) + continue + + # Commit and try merging other PR + run_git(["commit", "-m", f"Merge PR #{current_pr.number}"], cwd=tmpdir, check=False) + + result = run_git(["merge", "--no-commit", "--no-ff", f"pr-{other_pr.number}"], cwd=tmpdir, check=False) + if result.returncode != 0: + # Conflict between PRs + conflict_files, conflict_details = extract_conflict_info(tmpdir, result.stderr) + overlap.has_merge_conflict = True + overlap.conflict_files = conflict_files + overlap.conflict_details = conflict_details + overlap.conflict_type = 'conflict' + run_git(["merge", "--abort"], cwd=tmpdir, check=False) + + +def analyze_pr_overlap( + owner: str, + repo: str, + base_branch: str, + current_pr: "PullRequest", + other_pr_data: dict, + shared_files: list[str], + skip_merge_test: bool +) -> tuple[Optional["Overlap"], dict[str, "ChangedFile"]]: + """Analyze overlap between current PR and another PR.""" + # Filter out ignored files + non_ignored_shared = [f for f in shared_files if not should_ignore_file(f)] + if not non_ignored_shared: + return None, {} + + other_pr = PullRequest( + number=other_pr_data["number"], + title=other_pr_data["title"], + author=other_pr_data["author"], + url=other_pr_data["url"], + head_ref=other_pr_data["head_ref"], + base_ref=other_pr_data["base_ref"], + files=other_pr_data["files"], + changed_ranges={}, + updated_at=other_pr_data.get("updated_at") + ) + + # Get diff for other PR + other_diff = get_pr_diff(other_pr.number) + other_pr.changed_ranges = parse_diff_ranges(other_diff) + + # Check line overlaps + line_overlaps = find_line_overlaps( + current_pr.changed_ranges, + other_pr.changed_ranges, + shared_files + ) + + overlap = Overlap( + pr_a=current_pr, + pr_b=other_pr, + overlapping_files=non_ignored_shared, + line_overlaps=line_overlaps + ) + + # Test for actual merge conflicts if we have line overlaps + if line_overlaps and not skip_merge_test: + print(f"Testing merge conflict with PR #{other_pr.number}...", flush=True) + has_conflict, conflict_files, conflict_details, error_type = test_merge_conflict( + owner, repo, base_branch, current_pr, other_pr + ) + overlap.has_merge_conflict = has_conflict + overlap.conflict_files = conflict_files + overlap.conflict_details = conflict_details + overlap.conflict_type = error_type + + return overlap, other_pr.changed_ranges + + +def find_file_overlap_candidates( + current_files: list[str], + other_prs: list[dict], + max_age_days: int = 14 +) -> list[tuple[dict, list[str]]]: + """Find PRs that share files with the current PR.""" + from datetime import datetime, timezone, timedelta + + current_files_set = set(f for f in current_files if not should_ignore_file(f)) + candidates = [] + cutoff_date = datetime.now(timezone.utc) - timedelta(days=max_age_days) + + for pr_data in other_prs: + # Filter out PRs older than max_age_days + updated_at = pr_data.get("updated_at") + if updated_at: + try: + pr_date = datetime.fromisoformat(updated_at.replace('Z', '+00:00')) + if pr_date < cutoff_date: + continue # Skip old PRs + except Exception as e: + # If we can't parse date, include the PR (safe fallback) + print(f"Warning: Could not parse date for PR: {e}", file=sys.stderr) + + other_files = set(f for f in pr_data["files"] if not should_ignore_file(f)) + shared = current_files_set & other_files + + if shared: + candidates.append((pr_data, list(shared))) + + return candidates + + +def report_results(overlaps: list["Overlap"]): + """Report results (informational only, always exits 0).""" + conflicts = [o for o in overlaps if o.has_merge_conflict] + if conflicts: + print(f"\n⚠️ Found {len(conflicts)} merge conflict(s)") + + line_overlap_count = len([o for o in overlaps if o.line_overlaps]) + if line_overlap_count: + print(f"\n⚠️ Found {line_overlap_count} PR(s) with line overlap") + + print("\n✅ Done") + # Always exit 0 - this check is informational, not a merge blocker + + +# ============================================================================= +# COMMENT FORMATTING +# ============================================================================= + +def format_comment( + overlaps: list["Overlap"], + current_pr: int, + changes_current: dict[str, "ChangedFile"], + all_changes: dict[int, dict[str, "ChangedFile"]] +) -> str: + """Format the overlap report as a PR comment.""" + if not overlaps: + return "" + + lines = ["## 🔍 PR Overlap Detection"] + lines.append("") + lines.append("This check compares your PR against all other open PRs targeting the same branch to detect potential merge conflicts early.") + lines.append("") + + # Check if current PR conflicts with base branch + format_base_conflicts(overlaps, lines) + + # Classify and sort overlaps + classified = classify_all_overlaps(overlaps, current_pr, changes_current, all_changes) + + # Group by risk + conflicts = [(o, r) for o, r in classified if r == 'conflict'] + medium_risk = [(o, r) for o, r in classified if r == 'medium'] + low_risk = [(o, r) for o, r in classified if r == 'low'] + + # Format each section + format_conflicts_section(conflicts, current_pr, lines) + format_medium_risk_section(medium_risk, current_pr, changes_current, all_changes, lines) + format_low_risk_section(low_risk, current_pr, lines) + + # Summary + total = len(overlaps) + lines.append(f"\n**Summary:** {len(conflicts)} conflict(s), {len(medium_risk)} medium risk, {len(low_risk)} low risk (out of {total} PRs with file overlap)") + lines.append("\n---\n*Auto-generated on push. Ignores: `openapi.json`, lock files.*") + + return "\n".join(lines) + + +def format_base_conflicts(overlaps: list["Overlap"], lines: list[str]): + """Format base branch conflicts section.""" + base_conflicts = [o for o in overlaps if o.conflict_type == 'pr_a_conflicts_base'] + if base_conflicts: + lines.append("### ⚠️ This PR has conflicts with the base branch\n") + lines.append("Conflicts will need to be resolved before merging:\n") + first = base_conflicts[0] + for f in first.conflict_files[:10]: + lines.append(f"- `{f}`") + if len(first.conflict_files) > 10: + lines.append(f"- ... and {len(first.conflict_files) - 10} more files") + lines.append("\n") + + +def format_conflicts_section(conflicts: list[tuple], current_pr: int, lines: list[str]): + """Format the merge conflicts section.""" + pr_conflicts = [(o, r) for o, r in conflicts if o.conflict_type != 'pr_a_conflicts_base'] + + if not pr_conflicts: + return + + lines.append("### 🔴 Merge Conflicts Detected") + lines.append("") + lines.append("The following PRs have been tested and **will have merge conflicts** if merged after this PR. Consider coordinating with the authors.") + lines.append("") + + for o, _ in pr_conflicts: + other = o.pr_b if o.pr_a.number == current_pr else o.pr_a + format_pr_entry(other, lines) + format_conflict_details(o, lines) + lines.append("") + + +def format_medium_risk_section( + medium_risk: list[tuple], + current_pr: int, + changes_current: dict, + all_changes: dict, + lines: list[str] +): + """Format the medium risk section.""" + if not medium_risk: + return + + lines.append("### 🟡 Medium Risk — Some Line Overlap\n") + lines.append("These PRs have some overlapping changes:\n") + + for o, _ in medium_risk: + other = o.pr_b if o.pr_a.number == current_pr else o.pr_a + other_changes = all_changes.get(other.number, {}) + format_pr_entry(other, lines) + + # Note if rename is involved + for file_path in o.overlapping_files: + file_a = changes_current.get(file_path) + file_b = other_changes.get(file_path) + if (file_a and file_a.is_rename) or (file_b and file_b.is_rename): + lines.append(f" - ⚠️ `{file_path}` is being renamed/moved") + break + + if o.line_overlaps: + for file_path, ranges in o.line_overlaps.items(): + range_strs = [f"L{r[0]}-{r[1]}" if r[0] != r[1] else f"L{r[0]}" for r in ranges] + lines.append(f" - `{file_path}`: {', '.join(range_strs)}") + else: + non_ignored = [f for f in o.overlapping_files if not should_ignore_file(f)] + if non_ignored: + lines.append(f" - Shared files: `{'`, `'.join(non_ignored[:5])}`") + lines.append("") + + +def format_low_risk_section(low_risk: list[tuple], current_pr: int, lines: list[str]): + """Format the low risk section.""" + if not low_risk: + return + + lines.append("### 🟢 Low Risk — File Overlap Only\n") + lines.append("
These PRs touch the same files but different sections (click to expand)\n") + + for o, _ in low_risk: + other = o.pr_b if o.pr_a.number == current_pr else o.pr_a + non_ignored = [f for f in o.overlapping_files if not should_ignore_file(f)] + if non_ignored: + format_pr_entry(other, lines) + if o.line_overlaps: + for file_path, ranges in o.line_overlaps.items(): + range_strs = [f"L{r[0]}-{r[1]}" if r[0] != r[1] else f"L{r[0]}" for r in ranges] + lines.append(f" - `{file_path}`: {', '.join(range_strs)}") + else: + lines.append(f" - Shared files: `{'`, `'.join(non_ignored[:5])}`") + lines.append("") # Add blank line between entries + + lines.append("
\n") + + +def format_pr_entry(pr: "PullRequest", lines: list[str]): + """Format a single PR entry line.""" + updated = format_relative_time(pr.updated_at) + updated_str = f" · updated {updated}" if updated else "" + # Just use #number - GitHub auto-renders it with title + lines.append(f"- #{pr.number} ({pr.author}{updated_str})") + + +def format_conflict_details(overlap: "Overlap", lines: list[str]): + """Format conflict details for a PR.""" + if overlap.conflict_details: + all_paths = [d.path for d in overlap.conflict_details] + common_prefix = find_common_prefix(all_paths) + if common_prefix: + lines.append(f" - 📁 `{common_prefix}`") + for detail in overlap.conflict_details: + display_path = detail.path[len(common_prefix):] if common_prefix else detail.path + size_str = format_conflict_size(detail) + lines.append(f" - `{display_path}`{size_str}") + elif overlap.conflict_files: + common_prefix = find_common_prefix(overlap.conflict_files) + if common_prefix: + lines.append(f" - 📁 `{common_prefix}`") + for f in overlap.conflict_files: + display_path = f[len(common_prefix):] if common_prefix else f + lines.append(f" - `{display_path}`") + + +def format_conflict_size(detail: "ConflictInfo") -> str: + """Format conflict size string for a file.""" + if detail.conflict_count > 0: + return f" ({detail.conflict_count} conflict{'s' if detail.conflict_count > 1 else ''}, ~{detail.conflict_lines} lines)" + elif detail.conflict_type != 'content': + type_labels = { + 'both_added': 'added in both', + 'both_deleted': 'deleted in both', + 'deleted_by_us': 'deleted here, modified there', + 'deleted_by_them': 'modified here, deleted there', + 'added_by_us': 'added here', + 'added_by_them': 'added there', + } + label = type_labels.get(detail.conflict_type, detail.conflict_type) + return f" ({label})" + return "" + + +def format_line_overlaps(line_overlaps: dict[str, list[tuple]], lines: list[str]): + """Format line overlap details.""" + all_paths = list(line_overlaps.keys()) + common_prefix = find_common_prefix(all_paths) if len(all_paths) > 1 else "" + if common_prefix: + lines.append(f" - 📁 `{common_prefix}`") + for file_path, ranges in line_overlaps.items(): + display_path = file_path[len(common_prefix):] if common_prefix else file_path + range_strs = [f"L{r[0]}-{r[1]}" if r[0] != r[1] else f"L{r[0]}" for r in ranges] + indent = " " if common_prefix else " " + lines.append(f"{indent}- `{display_path}`: {', '.join(range_strs)}") + + +# ============================================================================= +# OVERLAP ANALYSIS +# ============================================================================= + +def classify_all_overlaps( + overlaps: list["Overlap"], + current_pr: int, + changes_current: dict, + all_changes: dict +) -> list[tuple["Overlap", str]]: + """Classify all overlaps by risk level and sort them.""" + classified = [] + for o in overlaps: + other_pr = o.pr_b if o.pr_a.number == current_pr else o.pr_a + other_changes = all_changes.get(other_pr.number, {}) + risk = classify_overlap_risk(o, changes_current, other_changes) + classified.append((o, risk)) + + def sort_key(item): + o, risk = item + risk_order = {'conflict': 0, 'medium': 1, 'low': 2} + # For conflicts, also sort by total conflict lines (descending) + conflict_lines = sum(d.conflict_lines for d in o.conflict_details) if o.conflict_details else 0 + return (risk_order.get(risk, 99), -conflict_lines) + + classified.sort(key=sort_key) + + return classified + + +def classify_overlap_risk( + overlap: "Overlap", + changes_a: dict[str, "ChangedFile"], + changes_b: dict[str, "ChangedFile"] +) -> str: + """Classify the risk level of an overlap.""" + if overlap.has_merge_conflict: + return 'conflict' + + has_rename = any( + (changes_a.get(f) and changes_a[f].is_rename) or + (changes_b.get(f) and changes_b[f].is_rename) + for f in overlap.overlapping_files + ) + + if overlap.line_overlaps: + total_overlap_lines = sum( + end - start + 1 + for ranges in overlap.line_overlaps.values() + for start, end in ranges + ) + + # Medium risk: >20 lines overlap or file rename + if total_overlap_lines > 20 or has_rename: + return 'medium' + else: + return 'low' + + if has_rename: + return 'medium' + + return 'low' + + +def find_line_overlaps( + changes_a: dict[str, "ChangedFile"], + changes_b: dict[str, "ChangedFile"], + shared_files: list[str] +) -> dict[str, list[tuple[int, int]]]: + """Find overlapping line ranges in shared files.""" + overlaps = {} + + for file_path in shared_files: + if should_ignore_file(file_path): + continue + + file_a = changes_a.get(file_path) + file_b = changes_b.get(file_path) + + if not file_a or not file_b: + continue + + # Skip pure renames + if file_a.is_rename and not file_a.additions and not file_a.deletions: + continue + if file_b.is_rename and not file_b.additions and not file_b.deletions: + continue + + # Note: This mixes old-file (deletions) and new-file (additions) line numbers, + # which can cause false positives when PRs insert/remove many lines. + # Acceptable for v1 since the real merge test is the authoritative check. + file_overlaps = find_range_overlaps( + file_a.additions + file_a.deletions, + file_b.additions + file_b.deletions + ) + + if file_overlaps: + overlaps[file_path] = merge_ranges(file_overlaps) + + return overlaps + + +def find_range_overlaps( + ranges_a: list[tuple[int, int]], + ranges_b: list[tuple[int, int]] +) -> list[tuple[int, int]]: + """Find overlapping regions between two sets of ranges.""" + overlaps = [] + for range_a in ranges_a: + for range_b in ranges_b: + if ranges_overlap(range_a, range_b): + overlap_start = max(range_a[0], range_b[0]) + overlap_end = min(range_a[1], range_b[1]) + overlaps.append((overlap_start, overlap_end)) + return overlaps + + +def ranges_overlap(range_a: tuple[int, int], range_b: tuple[int, int]) -> bool: + """Check if two line ranges overlap.""" + return range_a[0] <= range_b[1] and range_b[0] <= range_a[1] + + +def merge_ranges(ranges: list[tuple[int, int]]) -> list[tuple[int, int]]: + """Merge overlapping line ranges.""" + if not ranges: + return [] + + sorted_ranges = sorted(ranges, key=lambda x: x[0]) + merged = [sorted_ranges[0]] + + for current in sorted_ranges[1:]: + last = merged[-1] + if current[0] <= last[1] + 1: + merged[-1] = (last[0], max(last[1], current[1])) + else: + merged.append(current) + + return merged + + +# ============================================================================= +# MERGE CONFLICT TESTING +# ============================================================================= + +def test_merge_conflict( + owner: str, + repo: str, + base_branch: str, + pr_a: "PullRequest", + pr_b: "PullRequest" +) -> tuple[bool, list[str], list["ConflictInfo"], str]: + """Test if merging both PRs would cause a conflict.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Clone repo + if not clone_repo(owner, repo, base_branch, tmpdir): + return False, [], [], None + + configure_git(tmpdir) + if not fetch_pr_branches(tmpdir, pr_a.number, pr_b.number): + # Fetch failed for one or both PRs - can't test merge + return False, [], [], None + + # Try merging PR A first + conflict_result = try_merge_pr(tmpdir, pr_a.number) + if conflict_result: + return True, conflict_result[0], conflict_result[1], 'pr_a_conflicts_base' + + # Commit and try merging PR B + run_git(["commit", "-m", f"Merge PR #{pr_a.number}"], cwd=tmpdir, check=False) + + conflict_result = try_merge_pr(tmpdir, pr_b.number) + if conflict_result: + return True, conflict_result[0], conflict_result[1], 'conflict' + + return False, [], [], None + + +def clone_repo(owner: str, repo: str, branch: str, tmpdir: str) -> bool: + """Clone the repository.""" + clone_url = f"https://github.com/{owner}/{repo}.git" + result = run_git( + ["clone", "--depth=50", "--branch", branch, clone_url, tmpdir], + check=False + ) + if result.returncode != 0: + print(f"Failed to clone: {result.stderr}", file=sys.stderr) + return False + return True + + +def configure_git(tmpdir: str): + """Configure git for commits.""" + run_git(["config", "user.email", "github-actions[bot]@users.noreply.github.com"], cwd=tmpdir, check=False) + run_git(["config", "user.name", "github-actions[bot]"], cwd=tmpdir, check=False) + + +def fetch_pr_branches(tmpdir: str, pr_a: int, pr_b: int) -> bool: + """Fetch both PR branches. Returns False if any fetch fails.""" + success = True + for pr_num in (pr_a, pr_b): + result = run_git(["fetch", "origin", f"pull/{pr_num}/head:pr-{pr_num}"], cwd=tmpdir, check=False) + if result.returncode != 0: + print(f"Warning: Could not fetch PR #{pr_num}: {result.stderr.strip()}", file=sys.stderr) + success = False + return success + + +def try_merge_pr(tmpdir: str, pr_number: int) -> Optional[tuple[list[str], list["ConflictInfo"]]]: + """Try to merge a PR. Returns conflict info if conflicts, None if success.""" + result = run_git(["merge", "--no-commit", "--no-ff", f"pr-{pr_number}"], cwd=tmpdir, check=False) + + if result.returncode == 0: + return None + + # Conflict detected + conflict_files, conflict_details = extract_conflict_info(tmpdir, result.stderr) + run_git(["merge", "--abort"], cwd=tmpdir, check=False) + + return conflict_files, conflict_details + + +def extract_conflict_info(tmpdir: str, stderr: str) -> tuple[list[str], list["ConflictInfo"]]: + """Extract conflict information from git status.""" + status_result = run_git(["status", "--porcelain"], cwd=tmpdir, check=False) + + status_types = { + 'UU': 'content', + 'AA': 'both_added', + 'DD': 'both_deleted', + 'DU': 'deleted_by_us', + 'UD': 'deleted_by_them', + 'AU': 'added_by_us', + 'UA': 'added_by_them', + } + + conflict_files = [] + conflict_details = [] + + for line in status_result.stdout.split("\n"): + if len(line) >= 3 and line[0:2] in status_types: + status_code = line[0:2] + file_path = line[3:].strip() + conflict_files.append(file_path) + + info = analyze_conflict_markers(file_path, tmpdir) + info.conflict_type = status_types.get(status_code, 'unknown') + conflict_details.append(info) + + # Fallback to stderr parsing + if not conflict_files and stderr: + for line in stderr.split("\n"): + if "CONFLICT" in line and ":" in line: + parts = line.split(":") + if len(parts) > 1: + file_part = parts[-1].strip() + if file_part and not file_part.startswith("Merge"): + conflict_files.append(file_part) + conflict_details.append(ConflictInfo(path=file_part)) + + return conflict_files, conflict_details + + +def analyze_conflict_markers(file_path: str, cwd: str) -> "ConflictInfo": + """Analyze a conflicted file to count conflict regions and lines.""" + info = ConflictInfo(path=file_path) + + try: + full_path = os.path.join(cwd, file_path) + with open(full_path, 'r', errors='ignore') as f: + content = f.read() + + in_conflict = False + current_conflict_lines = 0 + + for line in content.split('\n'): + if line.startswith('<<<<<<<'): + in_conflict = True + info.conflict_count += 1 + current_conflict_lines = 1 + elif line.startswith('>>>>>>>'): + in_conflict = False + current_conflict_lines += 1 + info.conflict_lines += current_conflict_lines + elif in_conflict: + current_conflict_lines += 1 + except Exception as e: + print(f"Warning: Could not analyze conflict markers in {file_path}: {e}", file=sys.stderr) + + return info + + +# ============================================================================= +# DIFF PARSING +# ============================================================================= + +def parse_diff_ranges(diff: str) -> dict[str, "ChangedFile"]: + """Parse a unified diff and extract changed line ranges per file.""" + files = {} + current_file = None + pending_rename_from = None + is_rename = False + + for line in diff.split("\n"): + # Reset rename state on new file diff header + if line.startswith("diff --git "): + is_rename = False + pending_rename_from = None + elif line.startswith("rename from "): + pending_rename_from = line[12:] + is_rename = True + elif line.startswith("rename to "): + pass # rename target is captured via "+++ b/" line + elif line.startswith("similarity index"): + is_rename = True + elif line.startswith("+++ b/"): + path = line[6:] + current_file = ChangedFile( + path=path, + additions=[], + deletions=[], + is_rename=is_rename, + old_path=pending_rename_from + ) + files[path] = current_file + pending_rename_from = None + is_rename = False + elif line.startswith("--- /dev/null"): + is_rename = False + pending_rename_from = None + elif line.startswith("@@") and current_file: + parse_hunk_header(line, current_file) + + return files + + +def parse_hunk_header(line: str, current_file: "ChangedFile"): + """Parse a diff hunk header and add ranges to the file.""" + match = re.match(r"@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@", line) + if match: + old_start = int(match.group(1)) + old_count = int(match.group(2) or 1) + new_start = int(match.group(3)) + new_count = int(match.group(4) or 1) + + if old_count > 0: + current_file.deletions.append((old_start, old_start + old_count - 1)) + if new_count > 0: + current_file.additions.append((new_start, new_start + new_count - 1)) + + +# ============================================================================= +# GITHUB API +# ============================================================================= + +def get_repo_info() -> tuple[str, str]: + """Get owner and repo name from environment or git.""" + if os.environ.get("GITHUB_REPOSITORY"): + owner, repo = os.environ["GITHUB_REPOSITORY"].split("/") + return owner, repo + + result = run_gh(["repo", "view", "--json", "owner,name"]) + data = json.loads(result.stdout) + return data["owner"]["login"], data["name"] + + +def query_open_prs(owner: str, repo: str, base_branch: str) -> list[dict]: + """Query all open PRs targeting the specified base branch.""" + prs = [] + cursor = None + + while True: + after_clause = f', after: "{cursor}"' if cursor else "" + query = f''' + query {{ + repository(owner: "{owner}", name: "{repo}") {{ + pullRequests( + first: 100{after_clause}, + states: OPEN, + baseRefName: "{base_branch}", + orderBy: {{field: UPDATED_AT, direction: DESC}} + ) {{ + totalCount + edges {{ + node {{ + number + title + url + updatedAt + author {{ login }} + headRefName + baseRefName + files(first: 100) {{ + nodes {{ path }} + pageInfo {{ hasNextPage }} + }} + }} + }} + pageInfo {{ + endCursor + hasNextPage + }} + }} + }} + }} + ''' + + result = run_gh(["api", "graphql", "-f", f"query={query}"]) + data = json.loads(result.stdout) + + if "errors" in data: + print(f"GraphQL errors: {data['errors']}", file=sys.stderr) + sys.exit(1) + + pr_data = data["data"]["repository"]["pullRequests"] + for edge in pr_data["edges"]: + node = edge["node"] + files_data = node["files"] + # Warn if PR has more than 100 files (API limit, we only fetch first 100) + if files_data.get("pageInfo", {}).get("hasNextPage"): + print(f"Warning: PR #{node['number']} has >100 files, overlap detection may be incomplete", file=sys.stderr) + prs.append({ + "number": node["number"], + "title": node["title"], + "url": node["url"], + "updated_at": node.get("updatedAt"), + "author": node["author"]["login"] if node["author"] else "unknown", + "head_ref": node["headRefName"], + "base_ref": node["baseRefName"], + "files": [f["path"] for f in files_data["nodes"]] + }) + + if not pr_data["pageInfo"]["hasNextPage"]: + break + cursor = pr_data["pageInfo"]["endCursor"] + + return prs + + +def get_pr_diff(pr_number: int) -> str: + """Get the diff for a PR.""" + result = run_gh(["pr", "diff", str(pr_number)]) + return result.stdout + + +def post_or_update_comment(pr_number: int, body: str): + """Post a new comment or update existing overlap detection comment.""" + if not body: + return + + marker = "## 🔍 PR Overlap Detection" + + # Find existing comment using GraphQL + owner, repo = get_repo_info() + query = f''' + query {{ + repository(owner: "{owner}", name: "{repo}") {{ + pullRequest(number: {pr_number}) {{ + comments(first: 100) {{ + nodes {{ + id + body + author {{ login }} + }} + }} + }} + }} + }} + ''' + + result = run_gh(["api", "graphql", "-f", f"query={query}"], check=False) + + existing_comment_id = None + if result.returncode == 0: + try: + data = json.loads(result.stdout) + comments = data.get("data", {}).get("repository", {}).get("pullRequest", {}).get("comments", {}).get("nodes", []) + for comment in comments: + if marker in comment.get("body", ""): + existing_comment_id = comment["id"] + break + except Exception as e: + print(f"Warning: Could not search for existing comment: {e}", file=sys.stderr) + + if existing_comment_id: + # Update existing comment using GraphQL mutation + # Use json.dumps for proper escaping of all special characters + escaped_body = json.dumps(body)[1:-1] # Strip outer quotes added by json.dumps + mutation = f''' + mutation {{ + updateIssueComment(input: {{id: "{existing_comment_id}", body: "{escaped_body}"}}) {{ + issueComment {{ id }} + }} + }} + ''' + result = run_gh(["api", "graphql", "-f", f"query={mutation}"], check=False) + if result.returncode == 0: + print(f"Updated existing overlap comment") + else: + # Fallback to posting new comment + print(f"Failed to update comment, posting new one: {result.stderr}", file=sys.stderr) + run_gh(["pr", "comment", str(pr_number), "--body", body]) + else: + # Post new comment + run_gh(["pr", "comment", str(pr_number), "--body", body]) + + +def send_discord_notification(webhook_url: str, pr: "PullRequest", overlaps: list["Overlap"]): + """Send a Discord notification about significant overlaps.""" + conflicts = [o for o in overlaps if o.has_merge_conflict] + if not conflicts: + return + + # Discord limits: max 25 fields, max 1024 chars per field value + fields = [] + for o in conflicts[:25]: + other = o.pr_b if o.pr_a.number == pr.number else o.pr_a + # Build value string with truncation to stay under 1024 chars + file_list = o.conflict_files[:3] + files_str = f"Files: `{'`, `'.join(file_list)}`" + if len(o.conflict_files) > 3: + files_str += f" (+{len(o.conflict_files) - 3} more)" + value = f"[{other.title[:100]}]({other.url})\n{files_str}" + # Truncate if still too long + if len(value) > 1024: + value = value[:1020] + "..." + fields.append({ + "name": f"Conflicts with #{other.number}", + "value": value, + "inline": False + }) + + embed = { + "title": f"⚠️ PR #{pr.number} has merge conflicts", + "description": f"[{pr.title}]({pr.url})", + "color": 0xFF0000, + "fields": fields + } + + if len(conflicts) > 25: + embed["footer"] = {"text": f"... and {len(conflicts) - 25} more conflicts"} + + try: + subprocess.run( + ["curl", "-X", "POST", "-H", "Content-Type: application/json", + "--max-time", "10", + "-d", json.dumps({"embeds": [embed]}), webhook_url], + capture_output=True, + timeout=15 + ) + except subprocess.TimeoutExpired: + print("Warning: Discord webhook timed out", file=sys.stderr) + + +# ============================================================================= +# UTILITIES +# ============================================================================= + +def run_gh(args: list[str], check: bool = True) -> subprocess.CompletedProcess: + """Run a gh CLI command.""" + result = subprocess.run( + ["gh"] + args, + capture_output=True, + text=True, + check=False + ) + if check and result.returncode != 0: + print(f"Error running gh {' '.join(args)}: {result.stderr}", file=sys.stderr) + sys.exit(1) + return result + + +def run_git(args: list[str], cwd: str = None, check: bool = True) -> subprocess.CompletedProcess: + """Run a git command.""" + result = subprocess.run( + ["git"] + args, + capture_output=True, + text=True, + cwd=cwd, + check=False + ) + if check and result.returncode != 0: + print(f"Error running git {' '.join(args)}: {result.stderr}", file=sys.stderr) + return result + + +def should_ignore_file(path: str) -> bool: + """Check if a file should be ignored for overlap detection.""" + if path in IGNORE_FILES: + return True + basename = path.split("/")[-1] + return basename in IGNORE_FILES + + +def find_common_prefix(paths: list[str]) -> str: + """Find the common directory prefix of a list of file paths.""" + if not paths: + return "" + if len(paths) == 1: + parts = paths[0].rsplit('/', 1) + return parts[0] + '/' if len(parts) > 1 else "" + + split_paths = [p.split('/') for p in paths] + common = [] + for parts in zip(*split_paths): + if len(set(parts)) == 1: + common.append(parts[0]) + else: + break + + return '/'.join(common) + '/' if common else "" + + +def format_relative_time(iso_timestamp: str) -> str: + """Format an ISO timestamp as relative time.""" + if not iso_timestamp: + return "" + + from datetime import datetime, timezone + try: + dt = datetime.fromisoformat(iso_timestamp.replace('Z', '+00:00')) + now = datetime.now(timezone.utc) + diff = now - dt + + seconds = diff.total_seconds() + if seconds < 60: + return "just now" + elif seconds < 3600: + return f"{int(seconds / 60)}m ago" + elif seconds < 86400: + return f"{int(seconds / 3600)}h ago" + else: + return f"{int(seconds / 86400)}d ago" + except Exception as e: + print(f"Warning: Could not format relative time: {e}", file=sys.stderr) + return "" + + +# ============================================================================= +# DATA CLASSES +# ============================================================================= + +@dataclass +class ChangedFile: + """Represents a file changed in a PR.""" + path: str + additions: list[tuple[int, int]] + deletions: list[tuple[int, int]] + is_rename: bool = False + old_path: str = None + + +@dataclass +class PullRequest: + """Represents a pull request.""" + number: int + title: str + author: str + url: str + head_ref: str + base_ref: str + files: list[str] + changed_ranges: dict[str, ChangedFile] + updated_at: str = None + + +@dataclass +class ConflictInfo: + """Info about a single conflicting file.""" + path: str + conflict_count: int = 0 + conflict_lines: int = 0 + conflict_type: str = "content" + + +@dataclass +class Overlap: + """Represents an overlap between two PRs.""" + pr_a: PullRequest + pr_b: PullRequest + overlapping_files: list[str] + line_overlaps: dict[str, list[tuple[int, int]]] + has_merge_conflict: bool = False + conflict_files: list[str] = None + conflict_details: list[ConflictInfo] = None + conflict_type: str = None + + def __post_init__(self): + if self.conflict_files is None: + self.conflict_files = [] + if self.conflict_details is None: + self.conflict_details = [] + + +# ============================================================================= +# CONSTANTS +# ============================================================================= + +IGNORE_FILES = { + "autogpt_platform/frontend/src/app/api/openapi.json", + "poetry.lock", + "pnpm-lock.yaml", + "package-lock.json", + "yarn.lock", +} + + +# ============================================================================= +# ENTRY POINT +# ============================================================================= + +if __name__ == "__main__": + main() diff --git a/.github/workflows/pr-overlap-check.yml b/.github/workflows/pr-overlap-check.yml new file mode 100644 index 0000000000..c53f56321b --- /dev/null +++ b/.github/workflows/pr-overlap-check.yml @@ -0,0 +1,39 @@ +name: PR Overlap Detection + +on: + pull_request: + types: [opened, synchronize, reopened] + branches: + - dev + - master + +permissions: + contents: read + pull-requests: write + +jobs: + check-overlaps: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 # Need full history for merge testing + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Configure git + run: | + git config user.email "github-actions[bot]@users.noreply.github.com" + git config user.name "github-actions[bot]" + + - name: Run overlap detection + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + # Always succeed - this check informs contributors, it shouldn't block merging + continue-on-error: true + run: | + python .github/scripts/detect_overlaps.py ${{ github.event.pull_request.number }} From 52b3aebf7187386514f4b4de3eb52e06bd8dfd76 Mon Sep 17 00:00:00 2001 From: Zamil Majdy Date: Fri, 13 Feb 2026 19:49:03 +0400 Subject: [PATCH 10/16] feat(backend/sdk): Claude Agent SDK integration for CoPilot (#12103) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Full integration of the **Claude Agent SDK** to replace the existing one-turn OpenAI-compatible CoPilot implementation with a multi-turn, tool-using AI agent. ### What changed **Core SDK Integration** (`chat/sdk/` — new module) - **`service.py`**: Main orchestrator — spawns Claude Code CLI as a subprocess per user message, streams responses back via SSE. Handles conversation history compression, session lifecycle, and error recovery. - **`response_adapter.py`**: Translates Claude Agent SDK events (text deltas, tool use, errors, result messages) into the existing CoPilot `StreamEvent` protocol so the frontend works unchanged. - **`tool_adapter.py`**: Bridges CoPilot's MCP tools (find_block, run_block, create_agent, etc.) into the SDK's tool format. Handles schema conversion and result serialization. - **`security_hooks.py`**: Pre/Post tool-use hooks that enforce a strict allowlist of tools, block path traversal, sandbox file operations to per-session workspace directories, cap sub-agent spawning, and prevent the model from accessing unauthorized system resources. - **`transcript.py`**: JSONL transcript I/O utilities for the stateless `--resume` feature (see below). **Stateless Multi-Turn Resume** (new) - Instead of compressing conversation history via LLM on every turn (lossy and expensive), we capture Claude Code's native JSONL session transcript via a **Stop hook** callback, persist it in the DB (`ChatSession.sdkTranscript`), and restore it on the next turn via `--resume `. - This preserves full tool call/result context across turns with zero token overhead for history. - Feature-flagged via `CLAUDE_AGENT_USE_RESUME` (default: off). - DB migration: `ALTER TABLE "ChatSession" ADD COLUMN "sdkTranscript" TEXT`. **Sandboxed Tool Execution** (`chat/tools/`) - **`bash_exec.py`**: Sandboxed bash execution using bubblewrap (`bwrap`) with read-only root filesystem, per-session writable workspace, resource limits (CPU, memory, file size), and network isolation. - **`sandbox.py`**: Shared bubblewrap sandbox infrastructure — generates `bwrap` command lines with configurable mounts, environment, and resource constraints. - **`web_fetch.py`**: URL fetching tool with domain allowlist, size limits, and content-type filtering. - **`check_operation_status.py`**: Polling tool for long-running operations (agent creation, block execution) so the SDK doesn't block waiting. - **`find_block.py`** / **`run_block.py`**: Enhanced with category filtering, optimized response size (removed raw JSON schemas), and better error handling. **Security** - Path traversal prevention: session IDs sanitized, all file ops confined to workspace dirs, symlink resolution. - Tool allowlist enforcement via SDK hooks — model cannot call arbitrary tools. - Built-in `Bash` tool blocked via `disallowed_tools` to prevent bypassing sandboxed `bash_exec`. - Sub-agent (`Task`) spawning capped at configurable limit (default: 10). - CodeQL-clean path sanitization patterns. **Streaming & Reconnection** - SSE stream registry backed by Redis Streams for crash-resilient reconnection. - Long-running operation tracking with TTL-based cleanup. - Atomic message append to prevent race conditions on concurrent writes. **Configuration** (`config.py`) - `use_claude_agent_sdk` — master toggle (default: on) - `claude_agent_model` — model override for SDK path - `claude_agent_max_buffer_size` — JSON parsing buffer (10MB) - `claude_agent_max_subtasks` — sub-agent cap (10) - `claude_agent_use_resume` — transcript-based resume (default: off) - `thinking_enabled` — extended thinking for Claude models **Tests** - `sdk/response_adapter_test.py` — 366 lines covering all event translation paths - `sdk/security_hooks_test.py` — 165 lines covering tool blocking, path traversal, subtask limits - `chat/model_test.py` — 214 lines covering session model serialization - `chat/service_test.py` — Integration tests including multi-turn resume keyword recall - `tools/find_block_test.py` / `run_block_test.py` — Extended with new tool behavior tests ## Test plan - [x] Unit tests pass (`sdk/response_adapter_test.py`, `security_hooks_test.py`, `model_test.py`) - [x] Integration test: multi-turn keyword recall via `--resume` (`service_test.py::test_sdk_resume_multi_turn`) - [x] Manual E2E: CoPilot chat sessions with tool calls, bash execution, and multi-turn context - [x] Pre-commit hooks pass (ruff, isort, black, pyright, flake8) - [ ] Staging deployment with `claude_agent_use_resume=false` initially - [ ] Enable resume in staging, verify transcript capture and recall

Greptile Overview

Greptile Summary

This PR replaces the existing OpenAI-compatible CoPilot with a full Claude Agent SDK integration, introducing multi-turn conversations, stateless resume via JSONL transcripts, and sandboxed tool execution. **Key changes:** - **SDK integration** (`chat/sdk/`): spawns Claude Code CLI subprocess per message, translates events to frontend protocol, bridges MCP tools - **Stateless resume**: captures JSONL transcripts via Stop hook, persists in `ChatSession.sdkTranscript`, restores with `--resume` (feature-flagged, default off) - **Sandboxed execution**: bubblewrap sandbox for bash commands with filesystem whitelist, network isolation, resource limits - **Security hooks**: tool allowlist enforcement, path traversal prevention, workspace-scoped file operations, sub-agent spawn limits - **Long-running operations**: delegates `create_agent`/`edit_agent` to existing stream_registry infrastructure for SSE reconnection - **Feature flag**: `CHAT_USE_CLAUDE_AGENT_SDK` with LaunchDarkly support, defaults to enabled **Security issues found:** - Path traversal validation has logic errors in `security_hooks.py:82` (tilde expansion order) and `service.py:266` (redundant `..` check) - Config validator always prefers env var over explicit `False` value (`config.py:162`) - Race condition in `routes.py:323` — message persisted before task registration, could duplicate on retry - Resource limits in sandbox may fail silently (`sandbox.py:109`) **Test coverage is strong** with 366 lines for response adapter, 165 for security hooks, and integration tests for multi-turn resume.

Confidence Score: 3/5

- This PR is generally safe but has critical security issues in path validation that must be fixed before merge - Score reflects strong architecture and test coverage offset by real security vulnerabilities: the tilde expansion bug in `security_hooks.py` could allow sandbox escape, the race condition could cause message duplication, and the silent ulimit failures could bypass resource limits. The bubblewrap sandbox and allowlist enforcement are well-designed, but the path validation bugs need fixing. The transcript resume feature is properly feature-flagged. Overall the implementation is solid but the security issues prevent a higher score. - Pay close attention to `backend/api/features/chat/sdk/security_hooks.py` (path traversal vulnerability), `backend/api/features/chat/routes.py` (race condition), `backend/api/features/chat/tools/sandbox.py` (silent resource limit failures), and `backend/api/features/chat/sdk/service.py` (redundant security check)

Sequence Diagram

```mermaid sequenceDiagram participant Frontend participant Routes as routes.py participant SDKService as sdk/service.py participant ClaudeSDK as Claude Agent SDK CLI participant SecurityHooks as security_hooks.py participant ToolAdapter as tool_adapter.py participant CoPilotTools as tools/* participant Sandbox as sandbox.py (bwrap) participant DB as Database participant Redis as stream_registry Frontend->>Routes: POST /chat (user message) Routes->>SDKService: stream_chat_completion_sdk() SDKService->>DB: get_chat_session() DB-->>SDKService: session + messages alt Resume enabled AND transcript exists SDKService->>SDKService: validate_transcript() SDKService->>SDKService: write_transcript_to_tempfile() Note over SDKService: Pass --resume to SDK else No resume SDKService->>SDKService: _compress_conversation_history() Note over SDKService: Inject history into user message end SDKService->>SecurityHooks: create_security_hooks() SDKService->>ToolAdapter: create_copilot_mcp_server() SDKService->>ClaudeSDK: spawn subprocess with MCP server loop Streaming Conversation ClaudeSDK->>SDKService: AssistantMessage (text/tool_use) SDKService->>Frontend: StreamTextDelta / StreamToolInputAvailable alt Tool Call ClaudeSDK->>SecurityHooks: PreToolUse hook SecurityHooks->>SecurityHooks: validate path, check allowlist alt Tool blocked SecurityHooks-->>ClaudeSDK: deny else Tool allowed SecurityHooks-->>ClaudeSDK: allow ClaudeSDK->>ToolAdapter: call MCP tool alt Long-running tool (create_agent, edit_agent) ToolAdapter->>Redis: register task ToolAdapter->>DB: save OperationPendingResponse ToolAdapter->>ToolAdapter: spawn background task ToolAdapter-->>ClaudeSDK: OperationStartedResponse else Regular tool (find_block, bash_exec) ToolAdapter->>CoPilotTools: execute() alt bash_exec CoPilotTools->>Sandbox: run_sandboxed() Sandbox->>Sandbox: build bwrap command Note over Sandbox: Network isolation,
filesystem whitelist,
resource limits Sandbox-->>CoPilotTools: stdout, stderr, exit_code end CoPilotTools-->>ToolAdapter: result ToolAdapter->>ToolAdapter: stash full output ToolAdapter-->>ClaudeSDK: MCP response end SecurityHooks->>SecurityHooks: PostToolUse hook (log) end end ClaudeSDK->>SDKService: UserMessage (ToolResultBlock) SDKService->>ToolAdapter: pop_pending_tool_output() SDKService->>Frontend: StreamToolOutputAvailable end ClaudeSDK->>SecurityHooks: Stop hook SecurityHooks->>SDKService: transcript_path callback SDKService->>SDKService: read_transcript_file() SDKService->>DB: save transcript to session.sdkTranscript ClaudeSDK->>SDKService: ResultMessage (success) SDKService->>Frontend: StreamFinish SDKService->>DB: upsert_chat_session() ```
Last reviewed commit: 28c1121 --------- Co-authored-by: Swifty --- autogpt_platform/backend/Dockerfile | 8 +- .../backend/api/features/chat/config.py | 45 +- .../backend/api/features/chat/model.py | 73 +- .../backend/api/features/chat/routes.py | 106 ++- .../backend/api/features/chat/sdk/__init__.py | 14 + .../api/features/chat/sdk/response_adapter.py | 203 +++++ .../chat/sdk/response_adapter_test.py | 366 +++++++++ .../api/features/chat/sdk/security_hooks.py | 335 ++++++++ .../features/chat/sdk/security_hooks_test.py | 165 ++++ .../backend/api/features/chat/sdk/service.py | 751 ++++++++++++++++++ .../api/features/chat/sdk/tool_adapter.py | 322 ++++++++ .../api/features/chat/sdk/transcript.py | 356 +++++++++ .../backend/api/features/chat/service.py | 22 +- .../backend/api/features/chat/service_test.py | 96 +++ .../api/features/chat/stream_registry.py | 22 + .../api/features/chat/tools/__init__.py | 9 + .../api/features/chat/tools/bash_exec.py | 131 +++ .../chat/tools/check_operation_status.py | 127 +++ .../api/features/chat/tools/find_block.py | 1 + .../backend/api/features/chat/tools/models.py | 44 + .../api/features/chat/tools/sandbox.py | 265 ++++++ .../api/features/chat/tools/web_fetch.py | 151 ++++ .../features/chat/tools/workspace_files.py | 14 +- .../backend/backend/util/feature_flag.py | 1 + autogpt_platform/backend/poetry.lock | 94 ++- autogpt_platform/backend/pyproject.toml | 1 + .../backend/test/chat/__init__.py | 0 .../backend/test/chat/test_security_hooks.py | 133 ++++ .../backend/test/chat/test_transcript.py | 255 ++++++ .../ChatMessagesContainer.tsx | 11 + .../copilot/tools/GenericTool/GenericTool.tsx | 63 ++ .../frontend/src/app/api/openapi.json | 58 +- 32 files changed, 4187 insertions(+), 55 deletions(-) create mode 100644 autogpt_platform/backend/backend/api/features/chat/sdk/__init__.py create mode 100644 autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter.py create mode 100644 autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter_test.py create mode 100644 autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks.py create mode 100644 autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks_test.py create mode 100644 autogpt_platform/backend/backend/api/features/chat/sdk/service.py create mode 100644 autogpt_platform/backend/backend/api/features/chat/sdk/tool_adapter.py create mode 100644 autogpt_platform/backend/backend/api/features/chat/sdk/transcript.py create mode 100644 autogpt_platform/backend/backend/api/features/chat/tools/bash_exec.py create mode 100644 autogpt_platform/backend/backend/api/features/chat/tools/check_operation_status.py create mode 100644 autogpt_platform/backend/backend/api/features/chat/tools/sandbox.py create mode 100644 autogpt_platform/backend/backend/api/features/chat/tools/web_fetch.py create mode 100644 autogpt_platform/backend/test/chat/__init__.py create mode 100644 autogpt_platform/backend/test/chat/test_security_hooks.py create mode 100644 autogpt_platform/backend/test/chat/test_transcript.py create mode 100644 autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/GenericTool.tsx diff --git a/autogpt_platform/backend/Dockerfile b/autogpt_platform/backend/Dockerfile index ace534b730..05a8d4858b 100644 --- a/autogpt_platform/backend/Dockerfile +++ b/autogpt_platform/backend/Dockerfile @@ -66,13 +66,19 @@ ENV POETRY_HOME=/opt/poetry \ DEBIAN_FRONTEND=noninteractive ENV PATH=/opt/poetry/bin:$PATH -# Install Python, FFmpeg, and ImageMagick (required for video processing blocks) +# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use. +# bubblewrap provides OS-level sandbox (whitelist-only FS + no network) +# for the bash_exec MCP tool. # Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc. RUN apt-get update && apt-get install -y --no-install-recommends \ python3.13 \ python3-pip \ ffmpeg \ imagemagick \ + jq \ + ripgrep \ + tree \ + bubblewrap \ && rm -rf /var/lib/apt/lists/* COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3* diff --git a/autogpt_platform/backend/backend/api/features/chat/config.py b/autogpt_platform/backend/backend/api/features/chat/config.py index 808692f97f..04bbe8e60d 100644 --- a/autogpt_platform/backend/backend/api/features/chat/config.py +++ b/autogpt_platform/backend/backend/api/features/chat/config.py @@ -27,12 +27,11 @@ class ChatConfig(BaseSettings): session_ttl: int = Field(default=43200, description="Session TTL in seconds") # Streaming Configuration - max_context_messages: int = Field( - default=50, ge=1, le=200, description="Maximum context messages" - ) - stream_timeout: int = Field(default=300, description="Stream timeout in seconds") - max_retries: int = Field(default=3, description="Maximum number of retries") + max_retries: int = Field( + default=3, + description="Max retries for fallback path (SDK handles retries internally)", + ) max_agent_runs: int = Field(default=30, description="Maximum number of agent runs") max_agent_schedules: int = Field( default=30, description="Maximum number of agent schedules" @@ -93,6 +92,31 @@ class ChatConfig(BaseSettings): description="Name of the prompt in Langfuse to fetch", ) + # Claude Agent SDK Configuration + use_claude_agent_sdk: bool = Field( + default=True, + description="Use Claude Agent SDK for chat completions", + ) + claude_agent_model: str | None = Field( + default=None, + description="Model for the Claude Agent SDK path. If None, derives from " + "the `model` field by stripping the OpenRouter provider prefix.", + ) + claude_agent_max_buffer_size: int = Field( + default=10 * 1024 * 1024, # 10MB (default SDK is 1MB) + description="Max buffer size in bytes for Claude Agent SDK JSON message parsing. " + "Increase if tool outputs exceed the limit.", + ) + claude_agent_max_subtasks: int = Field( + default=10, + description="Max number of sub-agent Tasks the SDK can spawn per session.", + ) + claude_agent_use_resume: bool = Field( + default=True, + description="Use --resume for multi-turn conversations instead of " + "history compression. Falls back to compression when unavailable.", + ) + # Extended thinking configuration for Claude models thinking_enabled: bool = Field( default=True, @@ -138,6 +162,17 @@ class ChatConfig(BaseSettings): v = os.getenv("CHAT_INTERNAL_API_KEY") return v + @field_validator("use_claude_agent_sdk", mode="before") + @classmethod + def get_use_claude_agent_sdk(cls, v): + """Get use_claude_agent_sdk from environment if not provided.""" + # Check environment variable - default to True if not set + env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower() + if env_val: + return env_val in ("true", "1", "yes", "on") + # Default to True (SDK enabled by default) + return True if v is None else v + # Prompt paths for different contexts PROMPT_PATHS: dict[str, str] = { "default": "prompts/chat_system.md", diff --git a/autogpt_platform/backend/backend/api/features/chat/model.py b/autogpt_platform/backend/backend/api/features/chat/model.py index 35418f174f..30ac27aece 100644 --- a/autogpt_platform/backend/backend/api/features/chat/model.py +++ b/autogpt_platform/backend/backend/api/features/chat/model.py @@ -334,9 +334,8 @@ async def _get_session_from_cache(session_id: str) -> ChatSession | None: try: session = ChatSession.model_validate_json(raw_session) logger.info( - f"Loading session {session_id} from cache: " - f"message_count={len(session.messages)}, " - f"roles={[m.role for m in session.messages]}" + f"[CACHE] Loaded session {session_id}: {len(session.messages)} messages, " + f"last_roles={[m.role for m in session.messages[-3:]]}" # Last 3 roles ) return session except Exception as e: @@ -378,11 +377,9 @@ async def _get_session_from_db(session_id: str) -> ChatSession | None: return None messages = prisma_session.Messages - logger.info( - f"Loading session {session_id} from DB: " - f"has_messages={messages is not None}, " - f"message_count={len(messages) if messages else 0}, " - f"roles={[m.role for m in messages] if messages else []}" + logger.debug( + f"[DB] Loaded session {session_id}: {len(messages) if messages else 0} messages, " + f"roles={[m.role for m in messages[-3:]] if messages else []}" # Last 3 roles ) return ChatSession.from_db(prisma_session, messages) @@ -433,10 +430,9 @@ async def _save_session_to_db( "function_call": msg.function_call, } ) - logger.info( - f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: " - f"roles={[m['role'] for m in messages_data]}, " - f"start_sequence={existing_message_count}" + logger.debug( + f"[DB] Saving {len(new_messages)} messages to session {session.session_id}, " + f"roles={[m['role'] for m in messages_data]}" ) await chat_db.add_chat_messages_batch( session_id=session.session_id, @@ -476,7 +472,7 @@ async def get_chat_session( logger.warning(f"Unexpected cache error for session {session_id}: {e}") # Fall back to database - logger.info(f"Session {session_id} not in cache, checking database") + logger.debug(f"Session {session_id} not in cache, checking database") session = await _get_session_from_db(session_id) if session is None: @@ -493,7 +489,6 @@ async def get_chat_session( # Cache the session from DB try: await _cache_session(session) - logger.info(f"Cached session {session_id} from database") except Exception as e: logger.warning(f"Failed to cache session {session_id}: {e}") @@ -558,6 +553,40 @@ async def upsert_chat_session( return session +async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession: + """Atomically append a message to a session and persist it. + + Acquires the session lock, re-fetches the latest session state, + appends the message, and saves — preventing message loss when + concurrent requests modify the same session. + """ + lock = await _get_session_lock(session_id) + + async with lock: + session = await get_chat_session(session_id) + if session is None: + raise ValueError(f"Session {session_id} not found") + + session.messages.append(message) + existing_message_count = await chat_db.get_chat_session_message_count( + session_id + ) + + try: + await _save_session_to_db(session, existing_message_count) + except Exception as e: + raise DatabaseError( + f"Failed to persist message to session {session_id}" + ) from e + + try: + await _cache_session(session) + except Exception as e: + logger.warning(f"Cache write failed for session {session_id}: {e}") + + return session + + async def create_chat_session(user_id: str) -> ChatSession: """Create a new chat session and persist it. @@ -664,13 +693,19 @@ async def update_session_title(session_id: str, title: str) -> bool: logger.warning(f"Session {session_id} not found for title update") return False - # Invalidate cache so next fetch gets updated title + # Update title in cache if it exists (instead of invalidating). + # This prevents race conditions where cache invalidation causes + # the frontend to see stale DB data while streaming is still in progress. try: - redis_key = _get_session_cache_key(session_id) - async_redis = await get_redis_async() - await async_redis.delete(redis_key) + cached = await _get_session_from_cache(session_id) + if cached: + cached.title = title + await _cache_session(cached) except Exception as e: - logger.warning(f"Failed to invalidate cache for session {session_id}: {e}") + # Not critical - title will be correct on next full cache refresh + logger.warning( + f"Failed to update title in cache for session {session_id}: {e}" + ) return True except Exception as e: diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py index 0d8b12b0b7..aa565ca891 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes.py @@ -1,5 +1,6 @@ """Chat API routes for chat session management and streaming via SSE.""" +import asyncio import logging import uuid as uuid_module from collections.abc import AsyncGenerator @@ -11,13 +12,22 @@ from fastapi.responses import StreamingResponse from pydantic import BaseModel from backend.util.exceptions import NotFoundError +from backend.util.feature_flag import Flag, is_feature_enabled from . import service as chat_service from . import stream_registry from .completion_handler import process_operation_failure, process_operation_success from .config import ChatConfig -from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions -from .response_model import StreamFinish, StreamHeartbeat +from .model import ( + ChatMessage, + ChatSession, + append_and_save_message, + create_chat_session, + get_chat_session, + get_user_sessions, +) +from .response_model import StreamError, StreamFinish, StreamHeartbeat, StreamStart +from .sdk import service as sdk_service from .tools.models import ( AgentDetailsResponse, AgentOutputResponse, @@ -41,6 +51,7 @@ from .tools.models import ( SetupRequirementsResponse, UnderstandingUpdatedResponse, ) +from .tracking import track_user_message config = ChatConfig() @@ -232,6 +243,10 @@ async def get_session( active_task, last_message_id = await stream_registry.get_active_task_for_session( session_id, user_id ) + logger.info( + f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, " + f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}" + ) if active_task: # Filter out the in-progress assistant message from the session response. # The client will receive the complete assistant response through the SSE @@ -301,10 +316,9 @@ async def stream_chat_post( f"user={user_id}, message_len={len(request.message)}", extra={"json_fields": log_meta}, ) - session = await _validate_and_get_session(session_id, user_id) logger.info( - f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms", + f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms", extra={ "json_fields": { **log_meta, @@ -313,6 +327,25 @@ async def stream_chat_post( }, ) + # Atomically append user message to session BEFORE creating task to avoid + # race condition where GET_SESSION sees task as "running" but message isn't + # saved yet. append_and_save_message re-fetches inside a lock to prevent + # message loss from concurrent requests. + if request.message: + message = ChatMessage( + role="user" if request.is_user_message else "assistant", + content=request.message, + ) + if request.is_user_message: + track_user_message( + user_id=user_id, + session_id=session_id, + message_length=len(request.message), + ) + logger.info(f"[STREAM] Saving user message to session {session_id}") + session = await append_and_save_message(session_id, message) + logger.info(f"[STREAM] User message saved for session {session_id}") + # Create a task in the stream registry for reconnection support task_id = str(uuid_module.uuid4()) operation_id = str(uuid_module.uuid4()) @@ -328,7 +361,7 @@ async def stream_chat_post( operation_id=operation_id, ) logger.info( - f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start)*1000:.1f}ms", + f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start) * 1000:.1f}ms", extra={ "json_fields": { **log_meta, @@ -349,15 +382,47 @@ async def stream_chat_post( first_chunk_time, ttfc = None, None chunk_count = 0 try: - async for chunk in chat_service.stream_chat_completion( + # Emit a start event with task_id for reconnection + start_chunk = StreamStart(messageId=task_id, taskId=task_id) + await stream_registry.publish_chunk(task_id, start_chunk) + logger.info( + f"[TIMING] StreamStart published at {(time_module.perf_counter() - gen_start_time) * 1000:.1f}ms", + extra={ + "json_fields": { + **log_meta, + "elapsed_ms": (time_module.perf_counter() - gen_start_time) + * 1000, + } + }, + ) + + # Choose service based on LaunchDarkly flag (falls back to config default) + use_sdk = await is_feature_enabled( + Flag.COPILOT_SDK, + user_id or "anonymous", + default=config.use_claude_agent_sdk, + ) + stream_fn = ( + sdk_service.stream_chat_completion_sdk + if use_sdk + else chat_service.stream_chat_completion + ) + logger.info( + f"[TIMING] Calling {'sdk' if use_sdk else 'standard'} stream_chat_completion", + extra={"json_fields": log_meta}, + ) + # Pass message=None since we already added it to the session above + async for chunk in stream_fn( session_id, - request.message, + None, # Message already in session is_user_message=request.is_user_message, user_id=user_id, - session=session, # Pass pre-fetched session to avoid double-fetch + session=session, # Pass session with message already added context=request.context, - _task_id=task_id, # Pass task_id so service emits start with taskId for reconnection ): + # Skip duplicate StreamStart — we already published one above + if isinstance(chunk, StreamStart): + continue chunk_count += 1 if first_chunk_time is None: first_chunk_time = time_module.perf_counter() @@ -378,7 +443,7 @@ async def stream_chat_post( gen_end_time = time_module.perf_counter() total_time = (gen_end_time - gen_start_time) * 1000 logger.info( - f"[TIMING] run_ai_generation FINISHED in {total_time/1000:.1f}s; " + f"[TIMING] run_ai_generation FINISHED in {total_time / 1000:.1f}s; " f"task={task_id}, session={session_id}, " f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}", extra={ @@ -405,6 +470,17 @@ async def stream_chat_post( } }, ) + # Publish a StreamError so the frontend can display an error message + try: + await stream_registry.publish_chunk( + task_id, + StreamError( + errorText="An error occurred. Please try again.", + code="stream_error", + ), + ) + except Exception: + pass # Best-effort; mark_task_completed will publish StreamFinish await stream_registry.mark_task_completed(task_id, "failed") # Start the AI generation in a background task @@ -507,8 +583,14 @@ async def stream_chat_post( "json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)} }, ) + # Surface error to frontend so it doesn't appear stuck + yield StreamError( + errorText="An error occurred. Please try again.", + code="stream_error", + ).to_sse() + yield StreamFinish().to_sse() finally: - # Unsubscribe when client disconnects or stream ends to prevent resource leak + # Unsubscribe when client disconnects or stream ends if subscriber_queue is not None: try: await stream_registry.unsubscribe_from_task( @@ -752,8 +834,6 @@ async def stream_task( ) async def event_generator() -> AsyncGenerator[str, None]: - import asyncio - heartbeat_interval = 15.0 # Send heartbeat every 15 seconds try: while True: diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/__init__.py b/autogpt_platform/backend/backend/api/features/chat/sdk/__init__.py new file mode 100644 index 0000000000..7d9d6371e9 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/__init__.py @@ -0,0 +1,14 @@ +"""Claude Agent SDK integration for CoPilot. + +This module provides the integration layer between the Claude Agent SDK +and the existing CoPilot tool system, enabling drop-in replacement of +the current LLM orchestration with the battle-tested Claude Agent SDK. +""" + +from .service import stream_chat_completion_sdk +from .tool_adapter import create_copilot_mcp_server + +__all__ = [ + "stream_chat_completion_sdk", + "create_copilot_mcp_server", +] diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter.py b/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter.py new file mode 100644 index 0000000000..f7151f8319 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter.py @@ -0,0 +1,203 @@ +"""Response adapter for converting Claude Agent SDK messages to Vercel AI SDK format. + +This module provides the adapter layer that converts streaming messages from +the Claude Agent SDK into the Vercel AI SDK UI Stream Protocol format that +the frontend expects. +""" + +import json +import logging +import uuid + +from claude_agent_sdk import ( + AssistantMessage, + Message, + ResultMessage, + SystemMessage, + TextBlock, + ToolResultBlock, + ToolUseBlock, + UserMessage, +) + +from backend.api.features.chat.response_model import ( + StreamBaseResponse, + StreamError, + StreamFinish, + StreamFinishStep, + StreamStart, + StreamStartStep, + StreamTextDelta, + StreamTextEnd, + StreamTextStart, + StreamToolInputAvailable, + StreamToolInputStart, + StreamToolOutputAvailable, +) +from backend.api.features.chat.sdk.tool_adapter import ( + MCP_TOOL_PREFIX, + pop_pending_tool_output, +) + +logger = logging.getLogger(__name__) + + +class SDKResponseAdapter: + """Adapter for converting Claude Agent SDK messages to Vercel AI SDK format. + + This class maintains state during a streaming session to properly track + text blocks, tool calls, and message lifecycle. + """ + + def __init__(self, message_id: str | None = None): + self.message_id = message_id or str(uuid.uuid4()) + self.text_block_id = str(uuid.uuid4()) + self.has_started_text = False + self.has_ended_text = False + self.current_tool_calls: dict[str, dict[str, str]] = {} + self.task_id: str | None = None + self.step_open = False + + def set_task_id(self, task_id: str) -> None: + """Set the task ID for reconnection support.""" + self.task_id = task_id + + def convert_message(self, sdk_message: Message) -> list[StreamBaseResponse]: + """Convert a single SDK message to Vercel AI SDK format.""" + responses: list[StreamBaseResponse] = [] + + if isinstance(sdk_message, SystemMessage): + if sdk_message.subtype == "init": + responses.append( + StreamStart(messageId=self.message_id, taskId=self.task_id) + ) + # Open the first step (matches non-SDK: StreamStart then StreamStartStep) + responses.append(StreamStartStep()) + self.step_open = True + + elif isinstance(sdk_message, AssistantMessage): + # After tool results, the SDK sends a new AssistantMessage for the + # next LLM turn. Open a new step if the previous one was closed. + if not self.step_open: + responses.append(StreamStartStep()) + self.step_open = True + + for block in sdk_message.content: + if isinstance(block, TextBlock): + if block.text: + self._ensure_text_started(responses) + responses.append( + StreamTextDelta(id=self.text_block_id, delta=block.text) + ) + + elif isinstance(block, ToolUseBlock): + self._end_text_if_open(responses) + + # Strip MCP prefix so frontend sees "find_block" + # instead of "mcp__copilot__find_block". + tool_name = block.name.removeprefix(MCP_TOOL_PREFIX) + + responses.append( + StreamToolInputStart(toolCallId=block.id, toolName=tool_name) + ) + responses.append( + StreamToolInputAvailable( + toolCallId=block.id, + toolName=tool_name, + input=block.input, + ) + ) + self.current_tool_calls[block.id] = {"name": tool_name} + + elif isinstance(sdk_message, UserMessage): + # UserMessage carries tool results back from tool execution. + content = sdk_message.content + blocks = content if isinstance(content, list) else [] + for block in blocks: + if isinstance(block, ToolResultBlock) and block.tool_use_id: + tool_info = self.current_tool_calls.get(block.tool_use_id, {}) + tool_name = tool_info.get("name", "unknown") + + # Prefer the stashed full output over the SDK's + # (potentially truncated) ToolResultBlock content. + # The SDK truncates large results, writing them to disk, + # which breaks frontend widget parsing. + output = pop_pending_tool_output(tool_name) or ( + _extract_tool_output(block.content) + ) + + responses.append( + StreamToolOutputAvailable( + toolCallId=block.tool_use_id, + toolName=tool_name, + output=output, + success=not (block.is_error or False), + ) + ) + + # Close the current step after tool results — the next + # AssistantMessage will open a new step for the continuation. + if self.step_open: + responses.append(StreamFinishStep()) + self.step_open = False + + elif isinstance(sdk_message, ResultMessage): + self._end_text_if_open(responses) + # Close the step before finishing. + if self.step_open: + responses.append(StreamFinishStep()) + self.step_open = False + + if sdk_message.subtype == "success": + responses.append(StreamFinish()) + elif sdk_message.subtype in ("error", "error_during_execution"): + error_msg = getattr(sdk_message, "result", None) or "Unknown error" + responses.append( + StreamError(errorText=str(error_msg), code="sdk_error") + ) + responses.append(StreamFinish()) + else: + logger.warning( + f"Unexpected ResultMessage subtype: {sdk_message.subtype}" + ) + responses.append(StreamFinish()) + + else: + logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}") + + return responses + + def _ensure_text_started(self, responses: list[StreamBaseResponse]) -> None: + """Start (or restart) a text block if needed.""" + if not self.has_started_text or self.has_ended_text: + if self.has_ended_text: + self.text_block_id = str(uuid.uuid4()) + self.has_ended_text = False + responses.append(StreamTextStart(id=self.text_block_id)) + self.has_started_text = True + + def _end_text_if_open(self, responses: list[StreamBaseResponse]) -> None: + """End the current text block if one is open.""" + if self.has_started_text and not self.has_ended_text: + responses.append(StreamTextEnd(id=self.text_block_id)) + self.has_ended_text = True + + +def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str: + """Extract a string output from a ToolResultBlock's content field.""" + if isinstance(content, str): + return content + if isinstance(content, list): + parts = [item.get("text", "") for item in content if item.get("type") == "text"] + if parts: + return "".join(parts) + try: + return json.dumps(content) + except (TypeError, ValueError): + return str(content) + if content is None: + return "" + try: + return json.dumps(content) + except (TypeError, ValueError): + return str(content) diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter_test.py b/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter_test.py new file mode 100644 index 0000000000..a4f2502642 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter_test.py @@ -0,0 +1,366 @@ +"""Unit tests for the SDK response adapter.""" + +from claude_agent_sdk import ( + AssistantMessage, + ResultMessage, + SystemMessage, + TextBlock, + ToolResultBlock, + ToolUseBlock, + UserMessage, +) + +from backend.api.features.chat.response_model import ( + StreamBaseResponse, + StreamError, + StreamFinish, + StreamFinishStep, + StreamStart, + StreamStartStep, + StreamTextDelta, + StreamTextEnd, + StreamTextStart, + StreamToolInputAvailable, + StreamToolInputStart, + StreamToolOutputAvailable, +) + +from .response_adapter import SDKResponseAdapter +from .tool_adapter import MCP_TOOL_PREFIX + + +def _adapter() -> SDKResponseAdapter: + a = SDKResponseAdapter(message_id="msg-1") + a.set_task_id("task-1") + return a + + +# -- SystemMessage ----------------------------------------------------------- + + +def test_system_init_emits_start_and_step(): + adapter = _adapter() + results = adapter.convert_message(SystemMessage(subtype="init", data={})) + assert len(results) == 2 + assert isinstance(results[0], StreamStart) + assert results[0].messageId == "msg-1" + assert results[0].taskId == "task-1" + assert isinstance(results[1], StreamStartStep) + + +def test_system_non_init_emits_nothing(): + adapter = _adapter() + results = adapter.convert_message(SystemMessage(subtype="other", data={})) + assert results == [] + + +# -- AssistantMessage with TextBlock ----------------------------------------- + + +def test_text_block_emits_step_start_and_delta(): + adapter = _adapter() + msg = AssistantMessage(content=[TextBlock(text="hello")], model="test") + results = adapter.convert_message(msg) + assert len(results) == 3 + assert isinstance(results[0], StreamStartStep) + assert isinstance(results[1], StreamTextStart) + assert isinstance(results[2], StreamTextDelta) + assert results[2].delta == "hello" + + +def test_empty_text_block_emits_only_step(): + adapter = _adapter() + msg = AssistantMessage(content=[TextBlock(text="")], model="test") + results = adapter.convert_message(msg) + # Empty text skipped, but step still opens + assert len(results) == 1 + assert isinstance(results[0], StreamStartStep) + + +def test_multiple_text_deltas_reuse_block_id(): + adapter = _adapter() + msg1 = AssistantMessage(content=[TextBlock(text="a")], model="test") + msg2 = AssistantMessage(content=[TextBlock(text="b")], model="test") + r1 = adapter.convert_message(msg1) + r2 = adapter.convert_message(msg2) + # First gets step+start+delta, second only delta (block & step already started) + assert len(r1) == 3 + assert isinstance(r1[0], StreamStartStep) + assert isinstance(r1[1], StreamTextStart) + assert len(r2) == 1 + assert isinstance(r2[0], StreamTextDelta) + assert r1[1].id == r2[0].id # same block ID + + +# -- AssistantMessage with ToolUseBlock -------------------------------------- + + +def test_tool_use_emits_input_start_and_available(): + """Tool names arrive with MCP prefix and should be stripped for the frontend.""" + adapter = _adapter() + msg = AssistantMessage( + content=[ + ToolUseBlock( + id="tool-1", + name=f"{MCP_TOOL_PREFIX}find_agent", + input={"q": "x"}, + ) + ], + model="test", + ) + results = adapter.convert_message(msg) + assert len(results) == 3 + assert isinstance(results[0], StreamStartStep) + assert isinstance(results[1], StreamToolInputStart) + assert results[1].toolCallId == "tool-1" + assert results[1].toolName == "find_agent" # prefix stripped + assert isinstance(results[2], StreamToolInputAvailable) + assert results[2].toolName == "find_agent" # prefix stripped + assert results[2].input == {"q": "x"} + + +def test_text_then_tool_ends_text_block(): + adapter = _adapter() + text_msg = AssistantMessage(content=[TextBlock(text="thinking...")], model="test") + tool_msg = AssistantMessage( + content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})], + model="test", + ) + adapter.convert_message(text_msg) # opens step + text + results = adapter.convert_message(tool_msg) + # Step already open, so: TextEnd, ToolInputStart, ToolInputAvailable + assert len(results) == 3 + assert isinstance(results[0], StreamTextEnd) + assert isinstance(results[1], StreamToolInputStart) + + +# -- UserMessage with ToolResultBlock ---------------------------------------- + + +def test_tool_result_emits_output_and_finish_step(): + adapter = _adapter() + # First register the tool call (opens step) — SDK sends prefixed name + tool_msg = AssistantMessage( + content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}find_agent", input={})], + model="test", + ) + adapter.convert_message(tool_msg) + + # Now send tool result + result_msg = UserMessage( + content=[ToolResultBlock(tool_use_id="t1", content="found 3 agents")] + ) + results = adapter.convert_message(result_msg) + assert len(results) == 2 + assert isinstance(results[0], StreamToolOutputAvailable) + assert results[0].toolCallId == "t1" + assert results[0].toolName == "find_agent" # prefix stripped + assert results[0].output == "found 3 agents" + assert results[0].success is True + assert isinstance(results[1], StreamFinishStep) + + +def test_tool_result_error(): + adapter = _adapter() + adapter.convert_message( + AssistantMessage( + content=[ + ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}run_agent", input={}) + ], + model="test", + ) + ) + result_msg = UserMessage( + content=[ToolResultBlock(tool_use_id="t1", content="timeout", is_error=True)] + ) + results = adapter.convert_message(result_msg) + assert isinstance(results[0], StreamToolOutputAvailable) + assert results[0].success is False + assert isinstance(results[1], StreamFinishStep) + + +def test_tool_result_list_content(): + adapter = _adapter() + adapter.convert_message( + AssistantMessage( + content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})], + model="test", + ) + ) + result_msg = UserMessage( + content=[ + ToolResultBlock( + tool_use_id="t1", + content=[ + {"type": "text", "text": "line1"}, + {"type": "text", "text": "line2"}, + ], + ) + ] + ) + results = adapter.convert_message(result_msg) + assert isinstance(results[0], StreamToolOutputAvailable) + assert results[0].output == "line1line2" + assert isinstance(results[1], StreamFinishStep) + + +def test_string_user_message_ignored(): + """A plain string UserMessage (not tool results) produces no output.""" + adapter = _adapter() + results = adapter.convert_message(UserMessage(content="hello")) + assert results == [] + + +# -- ResultMessage ----------------------------------------------------------- + + +def test_result_success_emits_finish_step_and_finish(): + adapter = _adapter() + # Start some text first (opens step) + adapter.convert_message( + AssistantMessage(content=[TextBlock(text="done")], model="test") + ) + msg = ResultMessage( + subtype="success", + duration_ms=100, + duration_api_ms=50, + is_error=False, + num_turns=1, + session_id="s1", + ) + results = adapter.convert_message(msg) + # TextEnd + FinishStep + StreamFinish + assert len(results) == 3 + assert isinstance(results[0], StreamTextEnd) + assert isinstance(results[1], StreamFinishStep) + assert isinstance(results[2], StreamFinish) + + +def test_result_error_emits_error_and_finish(): + adapter = _adapter() + msg = ResultMessage( + subtype="error", + duration_ms=100, + duration_api_ms=50, + is_error=True, + num_turns=0, + session_id="s1", + result="API rate limited", + ) + results = adapter.convert_message(msg) + # No step was open, so no FinishStep — just Error + Finish + assert len(results) == 2 + assert isinstance(results[0], StreamError) + assert "API rate limited" in results[0].errorText + assert isinstance(results[1], StreamFinish) + + +# -- Text after tools (new block ID) ---------------------------------------- + + +def test_text_after_tool_gets_new_block_id(): + adapter = _adapter() + # Text -> Tool -> ToolResult -> Text should get a new text block ID and step + adapter.convert_message( + AssistantMessage(content=[TextBlock(text="before")], model="test") + ) + adapter.convert_message( + AssistantMessage( + content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})], + model="test", + ) + ) + # Send tool result (closes step) + adapter.convert_message( + UserMessage(content=[ToolResultBlock(tool_use_id="t1", content="ok")]) + ) + results = adapter.convert_message( + AssistantMessage(content=[TextBlock(text="after")], model="test") + ) + # Should get StreamStartStep (new step) + StreamTextStart (new block) + StreamTextDelta + assert len(results) == 3 + assert isinstance(results[0], StreamStartStep) + assert isinstance(results[1], StreamTextStart) + assert isinstance(results[2], StreamTextDelta) + assert results[2].delta == "after" + + +# -- Full conversation flow -------------------------------------------------- + + +def test_full_conversation_flow(): + """Simulate a complete conversation: init -> text -> tool -> result -> text -> finish.""" + adapter = _adapter() + all_responses: list[StreamBaseResponse] = [] + + # 1. Init + all_responses.extend( + adapter.convert_message(SystemMessage(subtype="init", data={})) + ) + # 2. Assistant text + all_responses.extend( + adapter.convert_message( + AssistantMessage(content=[TextBlock(text="Let me search")], model="test") + ) + ) + # 3. Tool use + all_responses.extend( + adapter.convert_message( + AssistantMessage( + content=[ + ToolUseBlock( + id="t1", + name=f"{MCP_TOOL_PREFIX}find_agent", + input={"query": "email"}, + ) + ], + model="test", + ) + ) + ) + # 4. Tool result + all_responses.extend( + adapter.convert_message( + UserMessage( + content=[ToolResultBlock(tool_use_id="t1", content="Found 2 agents")] + ) + ) + ) + # 5. More text + all_responses.extend( + adapter.convert_message( + AssistantMessage(content=[TextBlock(text="I found 2")], model="test") + ) + ) + # 6. Result + all_responses.extend( + adapter.convert_message( + ResultMessage( + subtype="success", + duration_ms=500, + duration_api_ms=400, + is_error=False, + num_turns=2, + session_id="s1", + ) + ) + ) + + types = [type(r).__name__ for r in all_responses] + assert types == [ + "StreamStart", + "StreamStartStep", # step 1: text + tool call + "StreamTextStart", + "StreamTextDelta", # "Let me search" + "StreamTextEnd", # closed before tool + "StreamToolInputStart", + "StreamToolInputAvailable", + "StreamToolOutputAvailable", # tool result + "StreamFinishStep", # step 1 closed after tool result + "StreamStartStep", # step 2: continuation text + "StreamTextStart", # new block after tool + "StreamTextDelta", # "I found 2" + "StreamTextEnd", # closed by result + "StreamFinishStep", # step 2 closed + "StreamFinish", + ] diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks.py b/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks.py new file mode 100644 index 0000000000..14efc6d459 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks.py @@ -0,0 +1,335 @@ +"""Security hooks for Claude Agent SDK integration. + +This module provides security hooks that validate tool calls before execution, +ensuring multi-user isolation and preventing unauthorized operations. +""" + +import json +import logging +import os +import re +from collections.abc import Callable +from typing import Any, cast + +from backend.api.features.chat.sdk.tool_adapter import MCP_TOOL_PREFIX + +logger = logging.getLogger(__name__) + +# Tools that are blocked entirely (CLI/system access). +# "Bash" (capital) is the SDK built-in — it's NOT in allowed_tools but blocked +# here as defence-in-depth. The agent uses mcp__copilot__bash_exec instead, +# which has kernel-level network isolation (unshare --net). +BLOCKED_TOOLS = { + "Bash", + "bash", + "shell", + "exec", + "terminal", + "command", +} + +# Tools allowed only when their path argument stays within the SDK workspace. +# The SDK uses these to handle oversized tool results (writes to tool-results/ +# files, then reads them back) and for workspace file operations. +WORKSPACE_SCOPED_TOOLS = {"Read", "Write", "Edit", "Glob", "Grep"} + +# Dangerous patterns in tool inputs +DANGEROUS_PATTERNS = [ + r"sudo", + r"rm\s+-rf", + r"dd\s+if=", + r"/etc/passwd", + r"/etc/shadow", + r"chmod\s+777", + r"curl\s+.*\|.*sh", + r"wget\s+.*\|.*sh", + r"eval\s*\(", + r"exec\s*\(", + r"__import__", + r"os\.system", + r"subprocess", +] + + +def _deny(reason: str) -> dict[str, Any]: + """Return a hook denial response.""" + return { + "hookSpecificOutput": { + "hookEventName": "PreToolUse", + "permissionDecision": "deny", + "permissionDecisionReason": reason, + } + } + + +def _validate_workspace_path( + tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None +) -> dict[str, Any]: + """Validate that a workspace-scoped tool only accesses allowed paths. + + Allowed directories: + - The SDK working directory (``/tmp/copilot-/``) + - The SDK tool-results directory (``~/.claude/projects/…/tool-results/``) + """ + path = tool_input.get("file_path") or tool_input.get("path") or "" + if not path: + # Glob/Grep without a path default to cwd which is already sandboxed + return {} + + # Resolve relative paths against sdk_cwd (the SDK sets cwd so the LLM + # naturally uses relative paths like "test.txt" instead of absolute ones). + # Tilde paths (~/) are home-dir references, not relative — expand first. + if path.startswith("~"): + resolved = os.path.realpath(os.path.expanduser(path)) + elif not os.path.isabs(path) and sdk_cwd: + resolved = os.path.realpath(os.path.join(sdk_cwd, path)) + else: + resolved = os.path.realpath(path) + + # Allow access within the SDK working directory + if sdk_cwd: + norm_cwd = os.path.realpath(sdk_cwd) + if resolved.startswith(norm_cwd + os.sep) or resolved == norm_cwd: + return {} + + # Allow access to ~/.claude/projects/*/tool-results/ (big tool results) + claude_dir = os.path.realpath(os.path.expanduser("~/.claude/projects")) + tool_results_seg = os.sep + "tool-results" + os.sep + if resolved.startswith(claude_dir + os.sep) and tool_results_seg in resolved: + return {} + + logger.warning( + f"Blocked {tool_name} outside workspace: {path} (resolved={resolved})" + ) + workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else "" + return _deny( + f"[SECURITY] Tool '{tool_name}' can only access files within the workspace " + f"directory.{workspace_hint} " + "This is enforced by the platform and cannot be bypassed." + ) + + +def _validate_tool_access( + tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None = None +) -> dict[str, Any]: + """Validate that a tool call is allowed. + + Returns: + Empty dict to allow, or dict with hookSpecificOutput to deny + """ + # Block forbidden tools + if tool_name in BLOCKED_TOOLS: + logger.warning(f"Blocked tool access attempt: {tool_name}") + return _deny( + f"[SECURITY] Tool '{tool_name}' is blocked for security. " + "This is enforced by the platform and cannot be bypassed. " + "Use the CoPilot-specific MCP tools instead." + ) + + # Workspace-scoped tools: allowed only within the SDK workspace directory + if tool_name in WORKSPACE_SCOPED_TOOLS: + return _validate_workspace_path(tool_name, tool_input, sdk_cwd) + + # Check for dangerous patterns in tool input + # Use json.dumps for predictable format (str() produces Python repr) + input_str = json.dumps(tool_input) if tool_input else "" + + for pattern in DANGEROUS_PATTERNS: + if re.search(pattern, input_str, re.IGNORECASE): + logger.warning( + f"Blocked dangerous pattern in tool input: {pattern} in {tool_name}" + ) + return _deny( + "[SECURITY] Input contains a blocked pattern. " + "This is enforced by the platform and cannot be bypassed." + ) + + return {} + + +def _validate_user_isolation( + tool_name: str, tool_input: dict[str, Any], user_id: str | None +) -> dict[str, Any]: + """Validate that tool calls respect user isolation.""" + # For workspace file tools, ensure path doesn't escape + if "workspace" in tool_name.lower(): + path = tool_input.get("path", "") or tool_input.get("file_path", "") + if path: + # Check for path traversal + if ".." in path or path.startswith("/"): + logger.warning( + f"Blocked path traversal attempt: {path} by user {user_id}" + ) + return { + "hookSpecificOutput": { + "hookEventName": "PreToolUse", + "permissionDecision": "deny", + "permissionDecisionReason": "Path traversal not allowed", + } + } + + return {} + + +def create_security_hooks( + user_id: str | None, + sdk_cwd: str | None = None, + max_subtasks: int = 3, + on_stop: Callable[[str, str], None] | None = None, +) -> dict[str, Any]: + """Create the security hooks configuration for Claude Agent SDK. + + Includes security validation and observability hooks: + - PreToolUse: Security validation before tool execution + - PostToolUse: Log successful tool executions + - PostToolUseFailure: Log and handle failed tool executions + - PreCompact: Log context compaction events (SDK handles compaction automatically) + - Stop: Capture transcript path for stateless resume (when *on_stop* is provided) + + Args: + user_id: Current user ID for isolation validation + sdk_cwd: SDK working directory for workspace-scoped tool validation + max_subtasks: Maximum Task (sub-agent) spawns allowed per session + on_stop: Callback ``(transcript_path, sdk_session_id)`` invoked when + the SDK finishes processing — used to read the JSONL transcript + before the CLI process exits. + + Returns: + Hooks configuration dict for ClaudeAgentOptions + """ + try: + from claude_agent_sdk import HookMatcher + from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput + + # Per-session counter for Task sub-agent spawns + task_spawn_count = 0 + + async def pre_tool_use_hook( + input_data: HookInput, + tool_use_id: str | None, + context: HookContext, + ) -> SyncHookJSONOutput: + """Combined pre-tool-use validation hook.""" + nonlocal task_spawn_count + _ = context # unused but required by signature + tool_name = cast(str, input_data.get("tool_name", "")) + tool_input = cast(dict[str, Any], input_data.get("tool_input", {})) + + # Rate-limit Task (sub-agent) spawns per session + if tool_name == "Task": + task_spawn_count += 1 + if task_spawn_count > max_subtasks: + logger.warning( + f"[SDK] Task limit reached ({max_subtasks}), user={user_id}" + ) + return cast( + SyncHookJSONOutput, + _deny( + f"Maximum {max_subtasks} sub-tasks per session. " + "Please continue in the main conversation." + ), + ) + + # Strip MCP prefix for consistent validation + is_copilot_tool = tool_name.startswith(MCP_TOOL_PREFIX) + clean_name = tool_name.removeprefix(MCP_TOOL_PREFIX) + + # Only block non-CoPilot tools; our MCP-registered tools + # (including Read for oversized results) are already sandboxed. + if not is_copilot_tool: + result = _validate_tool_access(clean_name, tool_input, sdk_cwd) + if result: + return cast(SyncHookJSONOutput, result) + + # Validate user isolation + result = _validate_user_isolation(clean_name, tool_input, user_id) + if result: + return cast(SyncHookJSONOutput, result) + + logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}") + return cast(SyncHookJSONOutput, {}) + + async def post_tool_use_hook( + input_data: HookInput, + tool_use_id: str | None, + context: HookContext, + ) -> SyncHookJSONOutput: + """Log successful tool executions for observability.""" + _ = context + tool_name = cast(str, input_data.get("tool_name", "")) + logger.debug(f"[SDK] Tool success: {tool_name}, tool_use_id={tool_use_id}") + return cast(SyncHookJSONOutput, {}) + + async def post_tool_failure_hook( + input_data: HookInput, + tool_use_id: str | None, + context: HookContext, + ) -> SyncHookJSONOutput: + """Log failed tool executions for debugging.""" + _ = context + tool_name = cast(str, input_data.get("tool_name", "")) + error = input_data.get("error", "Unknown error") + logger.warning( + f"[SDK] Tool failed: {tool_name}, error={error}, " + f"user={user_id}, tool_use_id={tool_use_id}" + ) + return cast(SyncHookJSONOutput, {}) + + async def pre_compact_hook( + input_data: HookInput, + tool_use_id: str | None, + context: HookContext, + ) -> SyncHookJSONOutput: + """Log when SDK triggers context compaction. + + The SDK automatically compacts conversation history when it grows too large. + This hook provides visibility into when compaction happens. + """ + _ = context, tool_use_id + trigger = input_data.get("trigger", "auto") + logger.info( + f"[SDK] Context compaction triggered: {trigger}, user={user_id}" + ) + return cast(SyncHookJSONOutput, {}) + + # --- Stop hook: capture transcript path for stateless resume --- + async def stop_hook( + input_data: HookInput, + tool_use_id: str | None, + context: HookContext, + ) -> SyncHookJSONOutput: + """Capture transcript path when SDK finishes processing. + + The Stop hook fires while the CLI process is still alive, giving us + a reliable window to read the JSONL transcript before SIGTERM. + """ + _ = context, tool_use_id + transcript_path = cast(str, input_data.get("transcript_path", "")) + sdk_session_id = cast(str, input_data.get("session_id", "")) + + if transcript_path and on_stop: + logger.info( + f"[SDK] Stop hook: transcript_path={transcript_path}, " + f"sdk_session_id={sdk_session_id[:12]}..." + ) + on_stop(transcript_path, sdk_session_id) + + return cast(SyncHookJSONOutput, {}) + + hooks: dict[str, Any] = { + "PreToolUse": [HookMatcher(matcher="*", hooks=[pre_tool_use_hook])], + "PostToolUse": [HookMatcher(matcher="*", hooks=[post_tool_use_hook])], + "PostToolUseFailure": [ + HookMatcher(matcher="*", hooks=[post_tool_failure_hook]) + ], + "PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])], + } + + if on_stop is not None: + hooks["Stop"] = [HookMatcher(matcher=None, hooks=[stop_hook])] + + return hooks + except ImportError: + # Fallback for when SDK isn't available - return empty hooks + logger.warning("claude-agent-sdk not available, security hooks disabled") + return {} diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks_test.py b/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks_test.py new file mode 100644 index 0000000000..2d09afdab7 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks_test.py @@ -0,0 +1,165 @@ +"""Unit tests for SDK security hooks.""" + +import os + +from .security_hooks import _validate_tool_access, _validate_user_isolation + +SDK_CWD = "/tmp/copilot-abc123" + + +def _is_denied(result: dict) -> bool: + hook = result.get("hookSpecificOutput", {}) + return hook.get("permissionDecision") == "deny" + + +# -- Blocked tools ----------------------------------------------------------- + + +def test_blocked_tools_denied(): + for tool in ("bash", "shell", "exec", "terminal", "command"): + result = _validate_tool_access(tool, {}) + assert _is_denied(result), f"{tool} should be blocked" + + +def test_unknown_tool_allowed(): + result = _validate_tool_access("SomeCustomTool", {}) + assert result == {} + + +# -- Workspace-scoped tools -------------------------------------------------- + + +def test_read_within_workspace_allowed(): + result = _validate_tool_access( + "Read", {"file_path": f"{SDK_CWD}/file.txt"}, sdk_cwd=SDK_CWD + ) + assert result == {} + + +def test_write_within_workspace_allowed(): + result = _validate_tool_access( + "Write", {"file_path": f"{SDK_CWD}/output.json"}, sdk_cwd=SDK_CWD + ) + assert result == {} + + +def test_edit_within_workspace_allowed(): + result = _validate_tool_access( + "Edit", {"file_path": f"{SDK_CWD}/src/main.py"}, sdk_cwd=SDK_CWD + ) + assert result == {} + + +def test_glob_within_workspace_allowed(): + result = _validate_tool_access("Glob", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD) + assert result == {} + + +def test_grep_within_workspace_allowed(): + result = _validate_tool_access("Grep", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD) + assert result == {} + + +def test_read_outside_workspace_denied(): + result = _validate_tool_access( + "Read", {"file_path": "/etc/passwd"}, sdk_cwd=SDK_CWD + ) + assert _is_denied(result) + + +def test_write_outside_workspace_denied(): + result = _validate_tool_access( + "Write", {"file_path": "/home/user/secrets.txt"}, sdk_cwd=SDK_CWD + ) + assert _is_denied(result) + + +def test_traversal_attack_denied(): + result = _validate_tool_access( + "Read", + {"file_path": f"{SDK_CWD}/../../etc/passwd"}, + sdk_cwd=SDK_CWD, + ) + assert _is_denied(result) + + +def test_no_path_allowed(): + """Glob/Grep without a path argument defaults to cwd — should pass.""" + result = _validate_tool_access("Glob", {}, sdk_cwd=SDK_CWD) + assert result == {} + + +def test_read_no_cwd_denies_absolute(): + """If no sdk_cwd is set, absolute paths are denied.""" + result = _validate_tool_access("Read", {"file_path": "/tmp/anything"}) + assert _is_denied(result) + + +# -- Tool-results directory -------------------------------------------------- + + +def test_read_tool_results_allowed(): + home = os.path.expanduser("~") + path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt" + result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD) + assert result == {} + + +def test_read_claude_projects_without_tool_results_denied(): + home = os.path.expanduser("~") + path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json" + result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD) + assert _is_denied(result) + + +# -- Built-in Bash is blocked (use bash_exec MCP tool instead) --------------- + + +def test_bash_builtin_always_blocked(): + """SDK built-in Bash is blocked — bash_exec MCP tool with bubblewrap is used instead.""" + result = _validate_tool_access("Bash", {"command": "echo hello"}, sdk_cwd=SDK_CWD) + assert _is_denied(result) + + +# -- Dangerous patterns ------------------------------------------------------ + + +def test_dangerous_pattern_blocked(): + result = _validate_tool_access("SomeTool", {"cmd": "sudo rm -rf /"}) + assert _is_denied(result) + + +def test_subprocess_pattern_blocked(): + result = _validate_tool_access("SomeTool", {"code": "subprocess.run(...)"}) + assert _is_denied(result) + + +# -- User isolation ---------------------------------------------------------- + + +def test_workspace_path_traversal_blocked(): + result = _validate_user_isolation( + "workspace_read", {"path": "../../../etc/shadow"}, user_id="user-1" + ) + assert _is_denied(result) + + +def test_workspace_absolute_path_blocked(): + result = _validate_user_isolation( + "workspace_read", {"path": "/etc/passwd"}, user_id="user-1" + ) + assert _is_denied(result) + + +def test_workspace_normal_path_allowed(): + result = _validate_user_isolation( + "workspace_read", {"path": "src/main.py"}, user_id="user-1" + ) + assert result == {} + + +def test_non_workspace_tool_passes_isolation(): + result = _validate_user_isolation( + "find_agent", {"query": "email"}, user_id="user-1" + ) + assert result == {} diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/service.py b/autogpt_platform/backend/backend/api/features/chat/sdk/service.py new file mode 100644 index 0000000000..65195b442c --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/service.py @@ -0,0 +1,751 @@ +"""Claude Agent SDK service layer for CoPilot chat completions.""" + +import asyncio +import json +import logging +import os +import uuid +from collections.abc import AsyncGenerator +from dataclasses import dataclass +from typing import Any + +from backend.util.exceptions import NotFoundError + +from .. import stream_registry +from ..config import ChatConfig +from ..model import ( + ChatMessage, + ChatSession, + get_chat_session, + update_session_title, + upsert_chat_session, +) +from ..response_model import ( + StreamBaseResponse, + StreamError, + StreamFinish, + StreamStart, + StreamTextDelta, + StreamToolInputAvailable, + StreamToolOutputAvailable, +) +from ..service import ( + _build_system_prompt, + _execute_long_running_tool_with_streaming, + _generate_session_title, +) +from ..tools.models import OperationPendingResponse, OperationStartedResponse +from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path +from ..tracking import track_user_message +from .response_adapter import SDKResponseAdapter +from .security_hooks import create_security_hooks +from .tool_adapter import ( + COPILOT_TOOL_NAMES, + LongRunningCallback, + create_copilot_mcp_server, + set_execution_context, +) +from .transcript import ( + download_transcript, + read_transcript_file, + upload_transcript, + validate_transcript, + write_transcript_to_tempfile, +) + +logger = logging.getLogger(__name__) +config = ChatConfig() + +# Set to hold background tasks to prevent garbage collection +_background_tasks: set[asyncio.Task[Any]] = set() + + +@dataclass +class CapturedTranscript: + """Info captured by the SDK Stop hook for stateless --resume.""" + + path: str = "" + sdk_session_id: str = "" + + @property + def available(self) -> bool: + return bool(self.path) + + +_SDK_CWD_PREFIX = WORKSPACE_PREFIX + +# Appended to the system prompt to inform the agent about available tools. +# The SDK built-in Bash is NOT available — use mcp__copilot__bash_exec instead, +# which has kernel-level network isolation (unshare --net). +_SDK_TOOL_SUPPLEMENT = """ + +## Tool notes + +- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool + for shell commands — it runs in a network-isolated sandbox. +- **Shared workspace**: The SDK Read/Write tools and `bash_exec` share the + same working directory. Files created by one are readable by the other. + These files are **ephemeral** — they exist only for the current session. +- **Persistent storage**: Use `write_workspace_file` / `read_workspace_file` + for files that should persist across sessions (stored in cloud storage). +- Long-running tools (create_agent, edit_agent, etc.) are handled + asynchronously. You will receive an immediate response; the actual result + is delivered to the user via a background stream. +""" + + +def _build_long_running_callback(user_id: str | None) -> LongRunningCallback: + """Build a callback that delegates long-running tools to the non-SDK infrastructure. + + Long-running tools (create_agent, edit_agent, etc.) are delegated to the + existing background infrastructure: stream_registry (Redis Streams), + database persistence, and SSE reconnection. This means results survive + page refreshes / pod restarts, and the frontend shows the proper loading + widget with progress updates. + + The returned callback matches the ``LongRunningCallback`` signature: + ``(tool_name, args, session) -> MCP response dict``. + """ + + async def _callback( + tool_name: str, args: dict[str, Any], session: ChatSession + ) -> dict[str, Any]: + operation_id = str(uuid.uuid4()) + task_id = str(uuid.uuid4()) + tool_call_id = f"sdk-{uuid.uuid4().hex[:12]}" + session_id = session.session_id + + # --- Build user-friendly messages (matches non-SDK service) --- + if tool_name == "create_agent": + desc = args.get("description", "") + desc_preview = (desc[:100] + "...") if len(desc) > 100 else desc + pending_msg = ( + f"Creating your agent: {desc_preview}" + if desc_preview + else "Creating agent... This may take a few minutes." + ) + started_msg = ( + "Agent creation started. You can close this tab - " + "check your library in a few minutes." + ) + elif tool_name == "edit_agent": + changes = args.get("changes", "") + changes_preview = (changes[:100] + "...") if len(changes) > 100 else changes + pending_msg = ( + f"Editing agent: {changes_preview}" + if changes_preview + else "Editing agent... This may take a few minutes." + ) + started_msg = ( + "Agent edit started. You can close this tab - " + "check your library in a few minutes." + ) + else: + pending_msg = f"Running {tool_name}... This may take a few minutes." + started_msg = ( + f"{tool_name} started. You can close this tab - " + "check back in a few minutes." + ) + + # --- Register task in Redis for SSE reconnection --- + await stream_registry.create_task( + task_id=task_id, + session_id=session_id, + user_id=user_id, + tool_call_id=tool_call_id, + tool_name=tool_name, + operation_id=operation_id, + ) + + # --- Save OperationPendingResponse to chat history --- + pending_message = ChatMessage( + role="tool", + content=OperationPendingResponse( + message=pending_msg, + operation_id=operation_id, + tool_name=tool_name, + ).model_dump_json(), + tool_call_id=tool_call_id, + ) + session.messages.append(pending_message) + await upsert_chat_session(session) + + # --- Spawn background task (reuses non-SDK infrastructure) --- + bg_task = asyncio.create_task( + _execute_long_running_tool_with_streaming( + tool_name=tool_name, + parameters=args, + tool_call_id=tool_call_id, + operation_id=operation_id, + task_id=task_id, + session_id=session_id, + user_id=user_id, + ) + ) + _background_tasks.add(bg_task) + bg_task.add_done_callback(_background_tasks.discard) + await stream_registry.set_task_asyncio_task(task_id, bg_task) + + logger.info( + f"[SDK] Long-running tool {tool_name} delegated to background " + f"(operation_id={operation_id}, task_id={task_id})" + ) + + # --- Return OperationStartedResponse as MCP tool result --- + # This flows through SDK → response adapter → frontend, triggering + # the loading widget with SSE reconnection support. + started_json = OperationStartedResponse( + message=started_msg, + operation_id=operation_id, + tool_name=tool_name, + task_id=task_id, + ).model_dump_json() + + return { + "content": [{"type": "text", "text": started_json}], + "isError": False, + } + + return _callback + + +def _resolve_sdk_model() -> str | None: + """Resolve the model name for the Claude Agent SDK CLI. + + Uses ``config.claude_agent_model`` if set, otherwise derives from + ``config.model`` by stripping the OpenRouter provider prefix (e.g., + ``"anthropic/claude-opus-4.6"`` → ``"claude-opus-4.6"``). + """ + if config.claude_agent_model: + return config.claude_agent_model + model = config.model + if "/" in model: + return model.split("/", 1)[1] + return model + + +def _build_sdk_env() -> dict[str, str]: + """Build env vars for the SDK CLI process. + + Routes API calls through OpenRouter (or a custom base_url) using + the same ``config.api_key`` / ``config.base_url`` as the non-SDK path. + This gives per-call token and cost tracking on the OpenRouter dashboard. + + Only overrides ``ANTHROPIC_API_KEY`` when a valid proxy URL and auth + token are both present — otherwise returns an empty dict so the SDK + falls back to its default credentials. + """ + env: dict[str, str] = {} + if config.api_key and config.base_url: + # Strip /v1 suffix — SDK expects the base URL without a version path + base = config.base_url.rstrip("/") + if base.endswith("/v1"): + base = base[:-3] + if not base or not base.startswith("http"): + # Invalid base_url — don't override SDK defaults + return env + env["ANTHROPIC_BASE_URL"] = base + env["ANTHROPIC_AUTH_TOKEN"] = config.api_key + # Must be explicitly empty so the CLI uses AUTH_TOKEN instead + env["ANTHROPIC_API_KEY"] = "" + return env + + +def _make_sdk_cwd(session_id: str) -> str: + """Create a safe, session-specific working directory path. + + Delegates to :func:`~backend.api.features.chat.tools.sandbox.make_session_path` + (single source of truth for path sanitization) and adds a defence-in-depth + assertion. + """ + cwd = make_session_path(session_id) + # Defence-in-depth: normpath + startswith is a CodeQL-recognised sanitizer + cwd = os.path.normpath(cwd) + if not cwd.startswith(_SDK_CWD_PREFIX): + raise ValueError(f"SDK cwd escaped prefix: {cwd}") + return cwd + + +def _cleanup_sdk_tool_results(cwd: str) -> None: + """Remove SDK tool-result files for a specific session working directory. + + The SDK creates tool-result files under ~/.claude/projects//tool-results/. + We clean only the specific cwd's results to avoid race conditions between + concurrent sessions. + + Security: cwd MUST be created by _make_sdk_cwd() which sanitizes session_id. + """ + import shutil + + # Validate cwd is under the expected prefix + normalized = os.path.normpath(cwd) + if not normalized.startswith(_SDK_CWD_PREFIX): + logger.warning(f"[SDK] Rejecting cleanup for path outside workspace: {cwd}") + return + + # SDK encodes the cwd path by replacing '/' with '-' + encoded_cwd = normalized.replace("/", "-") + + # Construct the project directory path (known-safe home expansion) + claude_projects = os.path.expanduser("~/.claude/projects") + project_dir = os.path.join(claude_projects, encoded_cwd) + + # Security check 3: Validate project_dir is under ~/.claude/projects + project_dir = os.path.normpath(project_dir) + if not project_dir.startswith(claude_projects): + logger.warning( + f"[SDK] Rejecting cleanup for escaped project path: {project_dir}" + ) + return + + results_dir = os.path.join(project_dir, "tool-results") + if os.path.isdir(results_dir): + for filename in os.listdir(results_dir): + file_path = os.path.join(results_dir, filename) + try: + if os.path.isfile(file_path): + os.remove(file_path) + except OSError: + pass + + # Also clean up the temp cwd directory itself + try: + shutil.rmtree(normalized, ignore_errors=True) + except OSError: + pass + + +async def _compress_conversation_history( + session: ChatSession, +) -> list[ChatMessage]: + """Compress prior conversation messages if they exceed the token threshold. + + Uses the shared compress_context() from prompt.py which supports: + - LLM summarization of old messages (keeps recent ones intact) + - Progressive content truncation as fallback + - Middle-out deletion as last resort + + Returns the compressed prior messages (everything except the current message). + """ + prior = session.messages[:-1] + if len(prior) < 2: + return prior + + from backend.util.prompt import compress_context + + # Convert ChatMessages to dicts for compress_context + messages_dict = [] + for msg in prior: + msg_dict: dict[str, Any] = {"role": msg.role} + if msg.content: + msg_dict["content"] = msg.content + if msg.tool_calls: + msg_dict["tool_calls"] = msg.tool_calls + if msg.tool_call_id: + msg_dict["tool_call_id"] = msg.tool_call_id + messages_dict.append(msg_dict) + + try: + import openai + + async with openai.AsyncOpenAI( + api_key=config.api_key, base_url=config.base_url, timeout=30.0 + ) as client: + result = await compress_context( + messages=messages_dict, + model=config.model, + client=client, + ) + except Exception as e: + logger.warning(f"[SDK] Context compression with LLM failed: {e}") + # Fall back to truncation-only (no LLM summarization) + result = await compress_context( + messages=messages_dict, + model=config.model, + client=None, + ) + + if result.was_compacted: + logger.info( + f"[SDK] Context compacted: {result.original_token_count} -> " + f"{result.token_count} tokens " + f"({result.messages_summarized} summarized, " + f"{result.messages_dropped} dropped)" + ) + # Convert compressed dicts back to ChatMessages + return [ + ChatMessage( + role=m["role"], + content=m.get("content"), + tool_calls=m.get("tool_calls"), + tool_call_id=m.get("tool_call_id"), + ) + for m in result.messages + ] + + return prior + + +def _format_conversation_context(messages: list[ChatMessage]) -> str | None: + """Format conversation messages into a context prefix for the user message. + + Returns a string like: + + User: hello + You responded: Hi! How can I help? + + + Returns None if there are no messages to format. + """ + if not messages: + return None + + lines: list[str] = [] + for msg in messages: + if not msg.content: + continue + if msg.role == "user": + lines.append(f"User: {msg.content}") + elif msg.role == "assistant": + lines.append(f"You responded: {msg.content}") + # Skip tool messages — they're internal details + + if not lines: + return None + + return "\n" + "\n".join(lines) + "\n" + + +async def stream_chat_completion_sdk( + session_id: str, + message: str | None = None, + tool_call_response: str | None = None, # noqa: ARG001 + is_user_message: bool = True, + user_id: str | None = None, + retry_count: int = 0, # noqa: ARG001 + session: ChatSession | None = None, + context: dict[str, str] | None = None, # noqa: ARG001 +) -> AsyncGenerator[StreamBaseResponse, None]: + """Stream chat completion using Claude Agent SDK. + + Drop-in replacement for stream_chat_completion with improved reliability. + """ + + if session is None: + session = await get_chat_session(session_id, user_id) + + if not session: + raise NotFoundError( + f"Session {session_id} not found. Please create a new session first." + ) + + if message: + session.messages.append( + ChatMessage( + role="user" if is_user_message else "assistant", content=message + ) + ) + if is_user_message: + track_user_message( + user_id=user_id, session_id=session_id, message_length=len(message) + ) + + session = await upsert_chat_session(session) + + # Generate title for new sessions (first user message) + if is_user_message and not session.title: + user_messages = [m for m in session.messages if m.role == "user"] + if len(user_messages) == 1: + first_message = user_messages[0].content or message or "" + if first_message: + task = asyncio.create_task( + _update_title_async(session_id, first_message, user_id) + ) + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) + + # Build system prompt (reuses non-SDK path with Langfuse support) + has_history = len(session.messages) > 1 + system_prompt, _ = await _build_system_prompt( + user_id, has_conversation_history=has_history + ) + system_prompt += _SDK_TOOL_SUPPLEMENT + message_id = str(uuid.uuid4()) + task_id = str(uuid.uuid4()) + + yield StreamStart(messageId=message_id, taskId=task_id) + + stream_completed = False + # Initialise sdk_cwd before the try so the finally can reference it + # even if _make_sdk_cwd raises (in that case it stays as ""). + sdk_cwd = "" + use_resume = False + + try: + # Use a session-specific temp dir to avoid cleanup race conditions + # between concurrent sessions. + sdk_cwd = _make_sdk_cwd(session_id) + os.makedirs(sdk_cwd, exist_ok=True) + + set_execution_context( + user_id, + session, + long_running_callback=_build_long_running_callback(user_id), + ) + try: + from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient + + # Fail fast when no API credentials are available at all + sdk_env = _build_sdk_env() + if not sdk_env and not os.environ.get("ANTHROPIC_API_KEY"): + raise RuntimeError( + "No API key configured. Set OPEN_ROUTER_API_KEY " + "(or CHAT_API_KEY) for OpenRouter routing, " + "or ANTHROPIC_API_KEY for direct Anthropic access." + ) + + mcp_server = create_copilot_mcp_server() + + sdk_model = _resolve_sdk_model() + + # --- Transcript capture via Stop hook --- + captured_transcript = CapturedTranscript() + + def _on_stop(transcript_path: str, sdk_session_id: str) -> None: + captured_transcript.path = transcript_path + captured_transcript.sdk_session_id = sdk_session_id + + security_hooks = create_security_hooks( + user_id, + sdk_cwd=sdk_cwd, + max_subtasks=config.claude_agent_max_subtasks, + on_stop=_on_stop if config.claude_agent_use_resume else None, + ) + + # --- Resume strategy: download transcript from bucket --- + resume_file: str | None = None + use_resume = False + + if config.claude_agent_use_resume and user_id and len(session.messages) > 1: + transcript_content = await download_transcript(user_id, session_id) + if transcript_content and validate_transcript(transcript_content): + resume_file = write_transcript_to_tempfile( + transcript_content, session_id, sdk_cwd + ) + if resume_file: + use_resume = True + logger.info( + f"[SDK] Using --resume with transcript " + f"({len(transcript_content)} bytes)" + ) + + sdk_options_kwargs: dict[str, Any] = { + "system_prompt": system_prompt, + "mcp_servers": {"copilot": mcp_server}, + "allowed_tools": COPILOT_TOOL_NAMES, + "disallowed_tools": ["Bash"], + "hooks": security_hooks, + "cwd": sdk_cwd, + "max_buffer_size": config.claude_agent_max_buffer_size, + } + if sdk_env: + sdk_options_kwargs["model"] = sdk_model + sdk_options_kwargs["env"] = sdk_env + if use_resume and resume_file: + sdk_options_kwargs["resume"] = resume_file + + options = ClaudeAgentOptions(**sdk_options_kwargs) # type: ignore[arg-type] + + adapter = SDKResponseAdapter(message_id=message_id) + adapter.set_task_id(task_id) + + async with ClaudeSDKClient(options=options) as client: + current_message = message or "" + if not current_message and session.messages: + last_user = [m for m in session.messages if m.role == "user"] + if last_user: + current_message = last_user[-1].content or "" + + if not current_message.strip(): + yield StreamError( + errorText="Message cannot be empty.", + code="empty_prompt", + ) + yield StreamFinish() + return + + # Build query: with --resume the CLI already has full + # context, so we only send the new message. Without + # resume, compress history into a context prefix. + query_message = current_message + if not use_resume and len(session.messages) > 1: + logger.warning( + f"[SDK] Using compression fallback for session " + f"{session_id} ({len(session.messages)} messages) — " + f"no transcript available for --resume" + ) + compressed = await _compress_conversation_history(session) + history_context = _format_conversation_context(compressed) + if history_context: + query_message = ( + f"{history_context}\n\n" + f"Now, the user says:\n{current_message}" + ) + + logger.info( + f"[SDK] Sending query ({len(session.messages)} msgs in session)" + ) + logger.debug(f"[SDK] Query preview: {current_message[:80]!r}") + await client.query(query_message, session_id=session_id) + + assistant_response = ChatMessage(role="assistant", content="") + accumulated_tool_calls: list[dict[str, Any]] = [] + has_appended_assistant = False + has_tool_results = False + + async for sdk_msg in client.receive_messages(): + logger.debug( + f"[SDK] Received: {type(sdk_msg).__name__} " + f"{getattr(sdk_msg, 'subtype', '')}" + ) + for response in adapter.convert_message(sdk_msg): + if isinstance(response, StreamStart): + continue + + yield response + + if isinstance(response, StreamTextDelta): + delta = response.delta or "" + # After tool results, start a new assistant + # message for the post-tool text. + if has_tool_results and has_appended_assistant: + assistant_response = ChatMessage( + role="assistant", content=delta + ) + accumulated_tool_calls = [] + has_appended_assistant = False + has_tool_results = False + session.messages.append(assistant_response) + has_appended_assistant = True + else: + assistant_response.content = ( + assistant_response.content or "" + ) + delta + if not has_appended_assistant: + session.messages.append(assistant_response) + has_appended_assistant = True + + elif isinstance(response, StreamToolInputAvailable): + accumulated_tool_calls.append( + { + "id": response.toolCallId, + "type": "function", + "function": { + "name": response.toolName, + "arguments": json.dumps(response.input or {}), + }, + } + ) + assistant_response.tool_calls = accumulated_tool_calls + if not has_appended_assistant: + session.messages.append(assistant_response) + has_appended_assistant = True + + elif isinstance(response, StreamToolOutputAvailable): + session.messages.append( + ChatMessage( + role="tool", + content=( + response.output + if isinstance(response.output, str) + else str(response.output) + ), + tool_call_id=response.toolCallId, + ) + ) + has_tool_results = True + + elif isinstance(response, StreamFinish): + stream_completed = True + + if stream_completed: + break + + if ( + assistant_response.content or assistant_response.tool_calls + ) and not has_appended_assistant: + session.messages.append(assistant_response) + + # --- Capture transcript while CLI is still alive --- + # Must happen INSIDE async with: close() sends SIGTERM + # which kills the CLI before it can flush the JSONL. + if ( + config.claude_agent_use_resume + and user_id + and captured_transcript.available + ): + # Give CLI time to flush JSONL writes before we read + await asyncio.sleep(0.5) + raw_transcript = read_transcript_file(captured_transcript.path) + if raw_transcript: + task = asyncio.create_task( + _upload_transcript_bg(user_id, session_id, raw_transcript) + ) + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) + else: + logger.debug("[SDK] Stop hook fired but transcript not usable") + + except ImportError: + raise RuntimeError( + "claude-agent-sdk is not installed. " + "Disable SDK mode (CHAT_USE_CLAUDE_AGENT_SDK=false) " + "to use the OpenAI-compatible fallback." + ) + + await upsert_chat_session(session) + logger.debug( + f"[SDK] Session {session_id} saved with {len(session.messages)} messages" + ) + if not stream_completed: + yield StreamFinish() + + except Exception as e: + logger.error(f"[SDK] Error: {e}", exc_info=True) + try: + await upsert_chat_session(session) + except Exception as save_err: + logger.error(f"[SDK] Failed to save session on error: {save_err}") + yield StreamError( + errorText="An error occurred. Please try again.", + code="sdk_error", + ) + yield StreamFinish() + finally: + if sdk_cwd: + _cleanup_sdk_tool_results(sdk_cwd) + + +async def _upload_transcript_bg( + user_id: str, session_id: str, raw_content: str +) -> None: + """Background task to strip progress entries and upload transcript.""" + try: + await upload_transcript(user_id, session_id, raw_content) + except Exception as e: + logger.error(f"[SDK] Failed to upload transcript for {session_id}: {e}") + + +async def _update_title_async( + session_id: str, message: str, user_id: str | None = None +) -> None: + """Background task to update session title.""" + try: + title = await _generate_session_title( + message, user_id=user_id, session_id=session_id + ) + if title: + await update_session_title(session_id, title) + logger.debug(f"[SDK] Generated title for {session_id}: {title}") + except Exception as e: + logger.warning(f"[SDK] Failed to update session title: {e}") diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/tool_adapter.py b/autogpt_platform/backend/backend/api/features/chat/sdk/tool_adapter.py new file mode 100644 index 0000000000..d983d5e785 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/tool_adapter.py @@ -0,0 +1,322 @@ +"""Tool adapter for wrapping existing CoPilot tools as Claude Agent SDK MCP tools. + +This module provides the adapter layer that converts existing BaseTool implementations +into in-process MCP tools that can be used with the Claude Agent SDK. + +Long-running tools (``is_long_running=True``) are delegated to the non-SDK +background infrastructure (stream_registry, Redis persistence, SSE reconnection) +via a callback provided by the service layer. This avoids wasteful SDK polling +and makes results survive page refreshes. +""" + +import itertools +import json +import logging +import os +import uuid +from collections.abc import Awaitable, Callable +from contextvars import ContextVar +from typing import Any + +from backend.api.features.chat.model import ChatSession +from backend.api.features.chat.tools import TOOL_REGISTRY +from backend.api.features.chat.tools.base import BaseTool + +logger = logging.getLogger(__name__) + +# Allowed base directory for the Read tool (SDK saves oversized tool results here). +# Restricted to ~/.claude/projects/ and further validated to require "tool-results" +# in the path — prevents reading settings, credentials, or other sensitive files. +_SDK_PROJECTS_DIR = os.path.expanduser("~/.claude/projects/") + +# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}" +MCP_SERVER_NAME = "copilot" +MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__" + +# Context variables to pass user/session info to tool execution +_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None) +_current_session: ContextVar[ChatSession | None] = ContextVar( + "current_session", default=None +) +# Stash for MCP tool outputs before the SDK potentially truncates them. +# Keyed by tool_name → full output string. Consumed (popped) by the +# response adapter when it builds StreamToolOutputAvailable. +_pending_tool_outputs: ContextVar[dict[str, str]] = ContextVar( + "pending_tool_outputs", default=None # type: ignore[arg-type] +) + +# Callback type for delegating long-running tools to the non-SDK infrastructure. +# Args: (tool_name, arguments, session) → MCP-formatted response dict. +LongRunningCallback = Callable[ + [str, dict[str, Any], ChatSession], Awaitable[dict[str, Any]] +] + +# ContextVar so the service layer can inject the callback per-request. +_long_running_callback: ContextVar[LongRunningCallback | None] = ContextVar( + "long_running_callback", default=None +) + + +def set_execution_context( + user_id: str | None, + session: ChatSession, + long_running_callback: LongRunningCallback | None = None, +) -> None: + """Set the execution context for tool calls. + + This must be called before streaming begins to ensure tools have access + to user_id and session information. + + Args: + user_id: Current user's ID. + session: Current chat session. + long_running_callback: Optional callback to delegate long-running tools + to the non-SDK background infrastructure (stream_registry + Redis). + """ + _current_user_id.set(user_id) + _current_session.set(session) + _pending_tool_outputs.set({}) + _long_running_callback.set(long_running_callback) + + +def get_execution_context() -> tuple[str | None, ChatSession | None]: + """Get the current execution context.""" + return ( + _current_user_id.get(), + _current_session.get(), + ) + + +def pop_pending_tool_output(tool_name: str) -> str | None: + """Pop and return the stashed full output for *tool_name*. + + The SDK CLI may truncate large tool results (writing them to disk and + replacing the content with a file reference). This stash keeps the + original MCP output so the response adapter can forward it to the + frontend for proper widget rendering. + + Returns ``None`` if nothing was stashed for *tool_name*. + """ + pending = _pending_tool_outputs.get(None) + if pending is None: + return None + return pending.pop(tool_name, None) + + +async def _execute_tool_sync( + base_tool: BaseTool, + user_id: str | None, + session: ChatSession, + args: dict[str, Any], +) -> dict[str, Any]: + """Execute a tool synchronously and return MCP-formatted response.""" + effective_id = f"sdk-{uuid.uuid4().hex[:12]}" + result = await base_tool.execute( + user_id=user_id, + session=session, + tool_call_id=effective_id, + **args, + ) + + text = ( + result.output if isinstance(result.output, str) else json.dumps(result.output) + ) + + # Stash the full output before the SDK potentially truncates it. + pending = _pending_tool_outputs.get(None) + if pending is not None: + pending[base_tool.name] = text + + return { + "content": [{"type": "text", "text": text}], + "isError": not result.success, + } + + +def _mcp_error(message: str) -> dict[str, Any]: + return { + "content": [ + {"type": "text", "text": json.dumps({"error": message, "type": "error"})} + ], + "isError": True, + } + + +def create_tool_handler(base_tool: BaseTool): + """Create an async handler function for a BaseTool. + + This wraps the existing BaseTool._execute method to be compatible + with the Claude Agent SDK MCP tool format. + + Long-running tools (``is_long_running=True``) are delegated to the + non-SDK background infrastructure via a callback set in the execution + context. The callback persists the operation in Redis (stream_registry) + so results survive page refreshes and pod restarts. + """ + + async def tool_handler(args: dict[str, Any]) -> dict[str, Any]: + """Execute the wrapped tool and return MCP-formatted response.""" + user_id, session = get_execution_context() + + if session is None: + return _mcp_error("No session context available") + + # --- Long-running: delegate to non-SDK background infrastructure --- + if base_tool.is_long_running: + callback = _long_running_callback.get(None) + if callback: + try: + return await callback(base_tool.name, args, session) + except Exception as e: + logger.error( + f"Long-running callback failed for {base_tool.name}: {e}", + exc_info=True, + ) + return _mcp_error(f"Failed to start {base_tool.name}: {e}") + # No callback — fall through to synchronous execution + logger.warning( + f"[SDK] No long-running callback for {base_tool.name}, " + f"executing synchronously (may block)" + ) + + # --- Normal (fast) tool: execute synchronously --- + try: + return await _execute_tool_sync(base_tool, user_id, session, args) + except Exception as e: + logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True) + return _mcp_error(f"Failed to execute {base_tool.name}: {e}") + + return tool_handler + + +def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]: + """Build a JSON Schema input schema for a tool.""" + return { + "type": "object", + "properties": base_tool.parameters.get("properties", {}), + "required": base_tool.parameters.get("required", []), + } + + +async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]: + """Read a file with optional offset/limit. Restricted to SDK working directory. + + After reading, the file is deleted to prevent accumulation in long-running pods. + """ + file_path = args.get("file_path", "") + offset = args.get("offset", 0) + limit = args.get("limit", 2000) + + # Security: only allow reads under ~/.claude/projects/**/tool-results/ + real_path = os.path.realpath(file_path) + if not real_path.startswith(_SDK_PROJECTS_DIR) or "tool-results" not in real_path: + return { + "content": [{"type": "text", "text": f"Access denied: {file_path}"}], + "isError": True, + } + + try: + with open(real_path) as f: + selected = list(itertools.islice(f, offset, offset + limit)) + content = "".join(selected) + # Cleanup happens in _cleanup_sdk_tool_results after session ends; + # don't delete here — the SDK may read in multiple chunks. + return {"content": [{"type": "text", "text": content}], "isError": False} + except FileNotFoundError: + return { + "content": [{"type": "text", "text": f"File not found: {file_path}"}], + "isError": True, + } + except Exception as e: + return { + "content": [{"type": "text", "text": f"Error reading file: {e}"}], + "isError": True, + } + + +_READ_TOOL_NAME = "Read" +_READ_TOOL_DESCRIPTION = ( + "Read a file from the local filesystem. " + "Use offset and limit to read specific line ranges for large files." +) +_READ_TOOL_SCHEMA = { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "The absolute path to the file to read", + }, + "offset": { + "type": "integer", + "description": "Line number to start reading from (0-indexed). Default: 0", + }, + "limit": { + "type": "integer", + "description": "Number of lines to read. Default: 2000", + }, + }, + "required": ["file_path"], +} + + +# Create the MCP server configuration +def create_copilot_mcp_server(): + """Create an in-process MCP server configuration for CoPilot tools. + + This can be passed to ClaudeAgentOptions.mcp_servers. + + Note: The actual SDK MCP server creation depends on the claude-agent-sdk + package being available. This function returns the configuration that + can be used with the SDK. + """ + try: + from claude_agent_sdk import create_sdk_mcp_server, tool + + # Create decorated tool functions + sdk_tools = [] + + for tool_name, base_tool in TOOL_REGISTRY.items(): + handler = create_tool_handler(base_tool) + decorated = tool( + tool_name, + base_tool.description, + _build_input_schema(base_tool), + )(handler) + sdk_tools.append(decorated) + + # Add the Read tool so the SDK can read back oversized tool results + read_tool = tool( + _READ_TOOL_NAME, + _READ_TOOL_DESCRIPTION, + _READ_TOOL_SCHEMA, + )(_read_file_handler) + sdk_tools.append(read_tool) + + server = create_sdk_mcp_server( + name=MCP_SERVER_NAME, + version="1.0.0", + tools=sdk_tools, + ) + + return server + + except ImportError: + # Let ImportError propagate so service.py handles the fallback + raise + + +# SDK built-in tools allowed within the workspace directory. +# Security hooks validate that file paths stay within sdk_cwd. +# Bash is NOT included — use the sandboxed MCP bash_exec tool instead, +# which provides kernel-level network isolation via unshare --net. +# Task allows spawning sub-agents (rate-limited by security hooks). +_SDK_BUILTIN_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep", "Task"] + +# List of tool names for allowed_tools configuration +# Include MCP tools, the MCP Read tool for oversized results, +# and SDK built-in file tools for workspace operations. +COPILOT_TOOL_NAMES = [ + *[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()], + f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}", + *_SDK_BUILTIN_TOOLS, +] diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/transcript.py b/autogpt_platform/backend/backend/api/features/chat/sdk/transcript.py new file mode 100644 index 0000000000..aaa5609227 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/transcript.py @@ -0,0 +1,356 @@ +"""JSONL transcript management for stateless multi-turn resume. + +The Claude Code CLI persists conversations as JSONL files (one JSON object per +line). When the SDK's ``Stop`` hook fires we read this file, strip bloat +(progress entries, metadata), and upload the result to bucket storage. On the +next turn we download the transcript, write it to a temp file, and pass +``--resume`` so the CLI can reconstruct the full conversation. + +Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local +filesystem for self-hosted) — no DB column needed. +""" + +import json +import logging +import os +import re + +logger = logging.getLogger(__name__) + +# UUIDs are hex + hyphens; strip everything else to prevent path injection. +_SAFE_ID_RE = re.compile(r"[^0-9a-fA-F-]") + +# Entry types that can be safely removed from the transcript without breaking +# the parentUuid conversation tree that ``--resume`` relies on. +# - progress: UI progress ticks, no message content (avg 97KB for agent_progress) +# - file-history-snapshot: undo tracking metadata +# - queue-operation: internal queue bookkeeping +# - summary: session summaries +# - pr-link: PR link metadata +STRIPPABLE_TYPES = frozenset( + {"progress", "file-history-snapshot", "queue-operation", "summary", "pr-link"} +) + +# Workspace storage constants — deterministic path from session_id. +TRANSCRIPT_STORAGE_PREFIX = "chat-transcripts" + + +# --------------------------------------------------------------------------- +# Progress stripping +# --------------------------------------------------------------------------- + + +def strip_progress_entries(content: str) -> str: + """Remove progress/metadata entries from a JSONL transcript. + + Removes entries whose ``type`` is in ``STRIPPABLE_TYPES`` and reparents + any remaining child entries so the ``parentUuid`` chain stays intact. + Typically reduces transcript size by ~30%. + """ + lines = content.strip().split("\n") + + entries: list[dict] = [] + for line in lines: + try: + entries.append(json.loads(line)) + except json.JSONDecodeError: + # Keep unparseable lines as-is (safety) + entries.append({"_raw": line}) + + stripped_uuids: set[str] = set() + uuid_to_parent: dict[str, str] = {} + kept: list[dict] = [] + + for entry in entries: + if "_raw" in entry: + kept.append(entry) + continue + uid = entry.get("uuid", "") + parent = entry.get("parentUuid", "") + entry_type = entry.get("type", "") + + if uid: + uuid_to_parent[uid] = parent + + if entry_type in STRIPPABLE_TYPES: + if uid: + stripped_uuids.add(uid) + else: + kept.append(entry) + + # Reparent: walk up chain through stripped entries to find surviving ancestor + for entry in kept: + if "_raw" in entry: + continue + parent = entry.get("parentUuid", "") + original_parent = parent + while parent in stripped_uuids: + parent = uuid_to_parent.get(parent, "") + if parent != original_parent: + entry["parentUuid"] = parent + + result_lines: list[str] = [] + for entry in kept: + if "_raw" in entry: + result_lines.append(entry["_raw"]) + else: + result_lines.append(json.dumps(entry, separators=(",", ":"))) + + return "\n".join(result_lines) + "\n" + + +# --------------------------------------------------------------------------- +# Local file I/O (read from CLI's JSONL, write temp file for --resume) +# --------------------------------------------------------------------------- + + +def read_transcript_file(transcript_path: str) -> str | None: + """Read a JSONL transcript file from disk. + + Returns the raw JSONL content, or ``None`` if the file is missing, empty, + or only contains metadata (≤2 lines with no conversation messages). + """ + if not transcript_path or not os.path.isfile(transcript_path): + logger.debug(f"[Transcript] File not found: {transcript_path}") + return None + + try: + with open(transcript_path) as f: + content = f.read() + + if not content.strip(): + logger.debug(f"[Transcript] Empty file: {transcript_path}") + return None + + lines = content.strip().split("\n") + if len(lines) < 3: + # Raw files with ≤2 lines are metadata-only + # (queue-operation + file-history-snapshot, no conversation). + logger.debug( + f"[Transcript] Too few lines ({len(lines)}): {transcript_path}" + ) + return None + + # Quick structural validation — parse first and last lines. + json.loads(lines[0]) + json.loads(lines[-1]) + + logger.info( + f"[Transcript] Read {len(lines)} lines, " + f"{len(content)} bytes from {transcript_path}" + ) + return content + + except (json.JSONDecodeError, OSError) as e: + logger.warning(f"[Transcript] Failed to read {transcript_path}: {e}") + return None + + +def _sanitize_id(raw_id: str, max_len: int = 36) -> str: + """Sanitize an ID for safe use in file paths. + + Session/user IDs are expected to be UUIDs (hex + hyphens). Strip + everything else and truncate to *max_len* so the result cannot introduce + path separators or other special characters. + """ + cleaned = _SAFE_ID_RE.sub("", raw_id or "")[:max_len] + return cleaned or "unknown" + + +_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-") + + +def write_transcript_to_tempfile( + transcript_content: str, + session_id: str, + cwd: str, +) -> str | None: + """Write JSONL transcript to a temp file inside *cwd* for ``--resume``. + + The file lives in the session working directory so it is cleaned up + automatically when the session ends. + + Returns the absolute path to the file, or ``None`` on failure. + """ + # Validate cwd is under the expected sandbox prefix (CodeQL sanitizer). + real_cwd = os.path.realpath(cwd) + if not real_cwd.startswith(_SAFE_CWD_PREFIX): + logger.warning(f"[Transcript] cwd outside sandbox: {cwd}") + return None + + try: + os.makedirs(real_cwd, exist_ok=True) + safe_id = _sanitize_id(session_id, max_len=8) + jsonl_path = os.path.realpath( + os.path.join(real_cwd, f"transcript-{safe_id}.jsonl") + ) + if not jsonl_path.startswith(real_cwd): + logger.warning(f"[Transcript] Path escaped cwd: {jsonl_path}") + return None + + with open(jsonl_path, "w") as f: + f.write(transcript_content) + + logger.info(f"[Transcript] Wrote resume file: {jsonl_path}") + return jsonl_path + + except OSError as e: + logger.warning(f"[Transcript] Failed to write resume file: {e}") + return None + + +def validate_transcript(content: str | None) -> bool: + """Check that a transcript has actual conversation messages. + + A valid transcript for resume needs at least one user message and one + assistant message (not just queue-operation / file-history-snapshot + metadata). + """ + if not content or not content.strip(): + return False + + lines = content.strip().split("\n") + if len(lines) < 2: + return False + + has_user = False + has_assistant = False + + for line in lines: + try: + entry = json.loads(line) + msg_type = entry.get("type") + if msg_type == "user": + has_user = True + elif msg_type == "assistant": + has_assistant = True + except json.JSONDecodeError: + return False + + return has_user and has_assistant + + +# --------------------------------------------------------------------------- +# Bucket storage (GCS / local via WorkspaceStorageBackend) +# --------------------------------------------------------------------------- + + +def _storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]: + """Return (workspace_id, file_id, filename) for a session's transcript. + + Path structure: ``chat-transcripts/{user_id}/{session_id}.jsonl`` + IDs are sanitized to hex+hyphen to prevent path traversal. + """ + return ( + TRANSCRIPT_STORAGE_PREFIX, + _sanitize_id(user_id), + f"{_sanitize_id(session_id)}.jsonl", + ) + + +def _build_storage_path(user_id: str, session_id: str, backend: object) -> str: + """Build the full storage path string that ``retrieve()`` expects. + + ``store()`` returns a path like ``gcs://bucket/workspaces/...`` or + ``local://workspace_id/file_id/filename``. Since we use deterministic + arguments we can reconstruct the same path for download/delete without + having stored the return value. + """ + from backend.util.workspace_storage import GCSWorkspaceStorage + + wid, fid, fname = _storage_path_parts(user_id, session_id) + + if isinstance(backend, GCSWorkspaceStorage): + blob = f"workspaces/{wid}/{fid}/{fname}" + return f"gcs://{backend.bucket_name}/{blob}" + else: + # LocalWorkspaceStorage returns local://{relative_path} + return f"local://{wid}/{fid}/{fname}" + + +async def upload_transcript(user_id: str, session_id: str, content: str) -> None: + """Strip progress entries and upload transcript to bucket storage. + + Safety: only overwrites when the new (stripped) transcript is larger than + what is already stored. Since JSONL is append-only, the latest transcript + is always the longest. This prevents a slow/stale background task from + clobbering a newer upload from a concurrent turn. + """ + from backend.util.workspace_storage import get_workspace_storage + + stripped = strip_progress_entries(content) + if not validate_transcript(stripped): + logger.warning( + f"[Transcript] Skipping upload — stripped content is not a valid " + f"transcript for session {session_id}" + ) + return + + storage = await get_workspace_storage() + wid, fid, fname = _storage_path_parts(user_id, session_id) + encoded = stripped.encode("utf-8") + new_size = len(encoded) + + # Check existing transcript size to avoid overwriting newer with older + path = _build_storage_path(user_id, session_id, storage) + try: + existing = await storage.retrieve(path) + if len(existing) >= new_size: + logger.info( + f"[Transcript] Skipping upload — existing transcript " + f"({len(existing)}B) >= new ({new_size}B) for session " + f"{session_id}" + ) + return + except (FileNotFoundError, Exception): + pass # No existing transcript or retrieval error — proceed with upload + + await storage.store( + workspace_id=wid, + file_id=fid, + filename=fname, + content=encoded, + ) + logger.info( + f"[Transcript] Uploaded {new_size} bytes " + f"(stripped from {len(content)}) for session {session_id}" + ) + + +async def download_transcript(user_id: str, session_id: str) -> str | None: + """Download transcript from bucket storage. + + Returns the JSONL content string, or ``None`` if not found. + """ + from backend.util.workspace_storage import get_workspace_storage + + storage = await get_workspace_storage() + path = _build_storage_path(user_id, session_id, storage) + + try: + data = await storage.retrieve(path) + content = data.decode("utf-8") + logger.info( + f"[Transcript] Downloaded {len(content)} bytes for session {session_id}" + ) + return content + except FileNotFoundError: + logger.debug(f"[Transcript] No transcript in storage for {session_id}") + return None + except Exception as e: + logger.warning(f"[Transcript] Failed to download transcript: {e}") + return None + + +async def delete_transcript(user_id: str, session_id: str) -> None: + """Delete transcript from bucket storage (e.g. after resume failure).""" + from backend.util.workspace_storage import get_workspace_storage + + storage = await get_workspace_storage() + path = _build_storage_path(user_id, session_id, storage) + + try: + await storage.delete(path) + logger.info(f"[Transcript] Deleted transcript for session {session_id}") + except Exception as e: + logger.warning(f"[Transcript] Failed to delete transcript: {e}") diff --git a/autogpt_platform/backend/backend/api/features/chat/service.py b/autogpt_platform/backend/backend/api/features/chat/service.py index b8ddc35960..cb5591e6d0 100644 --- a/autogpt_platform/backend/backend/api/features/chat/service.py +++ b/autogpt_platform/backend/backend/api/features/chat/service.py @@ -245,12 +245,16 @@ async def _get_system_prompt_template(context: str) -> str: return DEFAULT_SYSTEM_PROMPT.format(users_information=context) -async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]: +async def _build_system_prompt( + user_id: str | None, has_conversation_history: bool = False +) -> tuple[str, Any]: """Build the full system prompt including business understanding if available. Args: - user_id: The user ID for fetching business understanding - If "default" and this is the user's first session, will use "onboarding" instead. + user_id: The user ID for fetching business understanding. + has_conversation_history: Whether there's existing conversation history. + If True, we don't tell the model to greet/introduce (since they're + already in a conversation). Returns: Tuple of (compiled prompt string, business understanding object) @@ -266,6 +270,8 @@ async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]: if understanding: context = format_understanding_for_prompt(understanding) + elif has_conversation_history: + context = "No prior understanding saved yet. Continue the existing conversation naturally." else: context = "This is the first time you are meeting the user. Greet them and introduce them to the platform" @@ -374,7 +380,6 @@ async def stream_chat_completion( Raises: NotFoundError: If session_id is invalid - ValueError: If max_context_messages is exceeded """ completion_start = time.monotonic() @@ -459,8 +464,9 @@ async def stream_chat_completion( # Generate title for new sessions on first user message (non-blocking) # Check: is_user_message, no title yet, and this is the first user message - if is_user_message and message and not session.title: - user_messages = [m for m in session.messages if m.role == "user"] + user_messages = [m for m in session.messages if m.role == "user"] + first_user_msg = message or (user_messages[0].content if user_messages else None) + if is_user_message and first_user_msg and not session.title: if len(user_messages) == 1: # First user message - generate title in background import asyncio @@ -468,7 +474,7 @@ async def stream_chat_completion( # Capture only the values we need (not the session object) to avoid # stale data issues when the main flow modifies the session captured_session_id = session_id - captured_message = message + captured_message = first_user_msg captured_user_id = user_id async def _update_title(): @@ -1237,7 +1243,7 @@ async def _stream_chat_chunks( total_time = (time_module.perf_counter() - stream_chunks_start) * 1000 logger.info( - f"[TIMING] _stream_chat_chunks COMPLETED in {total_time/1000:.1f}s; " + f"[TIMING] _stream_chat_chunks COMPLETED in {total_time / 1000:.1f}s; " f"session={session.session_id}, user={session.user_id}", extra={"json_fields": {**log_meta, "total_time_ms": total_time}}, ) diff --git a/autogpt_platform/backend/backend/api/features/chat/service_test.py b/autogpt_platform/backend/backend/api/features/chat/service_test.py index 70f27af14f..b2fc82b790 100644 --- a/autogpt_platform/backend/backend/api/features/chat/service_test.py +++ b/autogpt_platform/backend/backend/api/features/chat/service_test.py @@ -1,3 +1,4 @@ +import asyncio import logging from os import getenv @@ -11,6 +12,8 @@ from .response_model import ( StreamTextDelta, StreamToolOutputAvailable, ) +from .sdk import service as sdk_service +from .sdk.transcript import download_transcript logger = logging.getLogger(__name__) @@ -80,3 +83,96 @@ async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user session = await get_chat_session(session.session_id) assert session, "Session not found" assert session.usage, "Usage is empty" + + +@pytest.mark.asyncio(loop_scope="session") +async def test_sdk_resume_multi_turn(setup_test_user, test_user_id): + """Test that the SDK --resume path captures and uses transcripts across turns. + + Turn 1: Send a message containing a unique keyword. + Turn 2: Ask the model to recall that keyword — proving the transcript was + persisted and restored via --resume. + """ + api_key: str | None = getenv("OPEN_ROUTER_API_KEY") + if not api_key: + return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test") + + from .config import ChatConfig + + cfg = ChatConfig() + if not cfg.claude_agent_use_resume: + return pytest.skip("CLAUDE_AGENT_USE_RESUME is not enabled, skipping test") + + session = await create_chat_session(test_user_id) + session = await upsert_chat_session(session) + + # --- Turn 1: send a message with a unique keyword --- + keyword = "ZEPHYR42" + turn1_msg = ( + f"Please remember this special keyword: {keyword}. " + "Just confirm you've noted it, keep your response brief." + ) + turn1_text = "" + turn1_errors: list[str] = [] + turn1_ended = False + + async for chunk in sdk_service.stream_chat_completion_sdk( + session.session_id, + turn1_msg, + user_id=test_user_id, + ): + if isinstance(chunk, StreamTextDelta): + turn1_text += chunk.delta + elif isinstance(chunk, StreamError): + turn1_errors.append(chunk.errorText) + elif isinstance(chunk, StreamFinish): + turn1_ended = True + + assert turn1_ended, "Turn 1 did not finish" + assert not turn1_errors, f"Turn 1 errors: {turn1_errors}" + assert turn1_text, "Turn 1 produced no text" + + # Wait for background upload task to complete (retry up to 5s) + transcript = None + for _ in range(10): + await asyncio.sleep(0.5) + transcript = await download_transcript(test_user_id, session.session_id) + if transcript: + break + assert transcript, ( + "Transcript was not uploaded to bucket after turn 1 — " + "Stop hook may not have fired or transcript was too small" + ) + logger.info(f"Turn 1 transcript uploaded: {len(transcript)} bytes") + + # Reload session for turn 2 + session = await get_chat_session(session.session_id, test_user_id) + assert session, "Session not found after turn 1" + + # --- Turn 2: ask model to recall the keyword --- + turn2_msg = "What was the special keyword I asked you to remember?" + turn2_text = "" + turn2_errors: list[str] = [] + turn2_ended = False + + async for chunk in sdk_service.stream_chat_completion_sdk( + session.session_id, + turn2_msg, + user_id=test_user_id, + session=session, + ): + if isinstance(chunk, StreamTextDelta): + turn2_text += chunk.delta + elif isinstance(chunk, StreamError): + turn2_errors.append(chunk.errorText) + elif isinstance(chunk, StreamFinish): + turn2_ended = True + + assert turn2_ended, "Turn 2 did not finish" + assert not turn2_errors, f"Turn 2 errors: {turn2_errors}" + assert turn2_text, "Turn 2 produced no text" + assert keyword in turn2_text, ( + f"Model did not recall keyword '{keyword}' in turn 2. " + f"Response: {turn2_text[:200]}" + ) + logger.info(f"Turn 2 recalled keyword successfully: {turn2_text[:100]}") diff --git a/autogpt_platform/backend/backend/api/features/chat/stream_registry.py b/autogpt_platform/backend/backend/api/features/chat/stream_registry.py index abc34b1fc9..671aefc7ba 100644 --- a/autogpt_platform/backend/backend/api/features/chat/stream_registry.py +++ b/autogpt_platform/backend/backend/api/features/chat/stream_registry.py @@ -814,6 +814,28 @@ async def get_active_task_for_session( if task_user_id and user_id != task_user_id: continue + # Auto-expire stale tasks that exceeded stream_timeout + created_at_str = meta.get("created_at", "") + if created_at_str: + try: + created_at = datetime.fromisoformat(created_at_str) + age_seconds = ( + datetime.now(timezone.utc) - created_at + ).total_seconds() + if age_seconds > config.stream_timeout: + logger.warning( + f"[TASK_LOOKUP] Auto-expiring stale task {task_id[:8]}... " + f"(age={age_seconds:.0f}s > timeout={config.stream_timeout}s)" + ) + await mark_task_completed(task_id, "failed") + continue + except (ValueError, TypeError): + pass + + logger.info( + f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..." + ) + # Get the last message ID from Redis Stream stream_key = _get_task_stream_key(task_id) last_id = "0-0" diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py b/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py index 350776081a..1ab4f720bb 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py @@ -9,6 +9,8 @@ from backend.api.features.chat.tracking import track_tool_called from .add_understanding import AddUnderstandingTool from .agent_output import AgentOutputTool from .base import BaseTool +from .bash_exec import BashExecTool +from .check_operation_status import CheckOperationStatusTool from .create_agent import CreateAgentTool from .customize_agent import CustomizeAgentTool from .edit_agent import EditAgentTool @@ -20,6 +22,7 @@ from .get_doc_page import GetDocPageTool from .run_agent import RunAgentTool from .run_block import RunBlockTool from .search_docs import SearchDocsTool +from .web_fetch import WebFetchTool from .workspace_files import ( DeleteWorkspaceFileTool, ListWorkspaceFilesTool, @@ -44,8 +47,14 @@ TOOL_REGISTRY: dict[str, BaseTool] = { "run_agent": RunAgentTool(), "run_block": RunBlockTool(), "view_agent_output": AgentOutputTool(), + "check_operation_status": CheckOperationStatusTool(), "search_docs": SearchDocsTool(), "get_doc_page": GetDocPageTool(), + # Web fetch for safe URL retrieval + "web_fetch": WebFetchTool(), + # Sandboxed code execution (bubblewrap) + "bash_exec": BashExecTool(), + # Persistent workspace tools (cloud storage, survives across sessions) # Feature request tools "search_feature_requests": SearchFeatureRequestsTool(), "create_feature_request": CreateFeatureRequestTool(), diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/bash_exec.py b/autogpt_platform/backend/backend/api/features/chat/tools/bash_exec.py new file mode 100644 index 0000000000..da9d8bf3fa --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/bash_exec.py @@ -0,0 +1,131 @@ +"""Bash execution tool — run shell commands in a bubblewrap sandbox. + +Full Bash scripting is allowed (loops, conditionals, pipes, functions, etc.). +Safety comes from OS-level isolation (bubblewrap): only system dirs visible +read-only, writable workspace only, clean env, no network. + +Requires bubblewrap (``bwrap``) — the tool is disabled when bwrap is not +available (e.g. macOS development). +""" + +import logging +from typing import Any + +from backend.api.features.chat.model import ChatSession +from backend.api.features.chat.tools.base import BaseTool +from backend.api.features.chat.tools.models import ( + BashExecResponse, + ErrorResponse, + ToolResponseBase, +) +from backend.api.features.chat.tools.sandbox import ( + get_workspace_dir, + has_full_sandbox, + run_sandboxed, +) + +logger = logging.getLogger(__name__) + + +class BashExecTool(BaseTool): + """Execute Bash commands in a bubblewrap sandbox.""" + + @property + def name(self) -> str: + return "bash_exec" + + @property + def description(self) -> str: + if not has_full_sandbox(): + return ( + "Bash execution is DISABLED — bubblewrap sandbox is not " + "available on this platform. Do not call this tool." + ) + return ( + "Execute a Bash command or script in a bubblewrap sandbox. " + "Full Bash scripting is supported (loops, conditionals, pipes, " + "functions, etc.). " + "The sandbox shares the same working directory as the SDK Read/Write " + "tools — files created by either are accessible to both. " + "SECURITY: Only system directories (/usr, /bin, /lib, /etc) are " + "visible read-only, the per-session workspace is the only writable " + "path, environment variables are wiped (no secrets), all network " + "access is blocked at the kernel level, and resource limits are " + "enforced (max 64 processes, 512MB memory, 50MB file size). " + "Application code, configs, and other directories are NOT accessible. " + "To fetch web content, use the web_fetch tool instead. " + "Execution is killed after the timeout (default 30s, max 120s). " + "Returns stdout and stderr. " + "Useful for file manipulation, data processing with Unix tools " + "(grep, awk, sed, jq, etc.), and running shell scripts." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "Bash command or script to execute.", + }, + "timeout": { + "type": "integer", + "description": ( + "Max execution time in seconds (default 30, max 120)." + ), + "default": 30, + }, + }, + "required": ["command"], + } + + @property + def requires_auth(self) -> bool: + return False + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + **kwargs: Any, + ) -> ToolResponseBase: + session_id = session.session_id if session else None + + if not has_full_sandbox(): + return ErrorResponse( + message="bash_exec requires bubblewrap sandbox (Linux only).", + error="sandbox_unavailable", + session_id=session_id, + ) + + command: str = (kwargs.get("command") or "").strip() + timeout: int = kwargs.get("timeout", 30) + + if not command: + return ErrorResponse( + message="No command provided.", + error="empty_command", + session_id=session_id, + ) + + workspace = get_workspace_dir(session_id or "default") + + stdout, stderr, exit_code, timed_out = await run_sandboxed( + command=["bash", "-c", command], + cwd=workspace, + timeout=timeout, + ) + + return BashExecResponse( + message=( + "Execution timed out" + if timed_out + else f"Command executed (exit {exit_code})" + ), + stdout=stdout, + stderr=stderr, + exit_code=exit_code, + timed_out=timed_out, + session_id=session_id, + ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/check_operation_status.py b/autogpt_platform/backend/backend/api/features/chat/tools/check_operation_status.py new file mode 100644 index 0000000000..b8ec770fd0 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/check_operation_status.py @@ -0,0 +1,127 @@ +"""CheckOperationStatusTool — query the status of a long-running operation.""" + +import logging +from typing import Any + +from backend.api.features.chat.model import ChatSession +from backend.api.features.chat.tools.base import BaseTool +from backend.api.features.chat.tools.models import ( + ErrorResponse, + ResponseType, + ToolResponseBase, +) + +logger = logging.getLogger(__name__) + + +class OperationStatusResponse(ToolResponseBase): + """Response for check_operation_status tool.""" + + type: ResponseType = ResponseType.OPERATION_STATUS + task_id: str + operation_id: str + status: str # "running", "completed", "failed" + tool_name: str | None = None + message: str = "" + + +class CheckOperationStatusTool(BaseTool): + """Check the status of a long-running operation (create_agent, edit_agent, etc.). + + The CoPilot uses this tool to report back to the user whether an + operation that was started earlier has completed, failed, or is still + running. + """ + + @property + def name(self) -> str: + return "check_operation_status" + + @property + def description(self) -> str: + return ( + "Check the current status of a long-running operation such as " + "create_agent or edit_agent. Accepts either an operation_id or " + "task_id from a previous operation_started response. " + "Returns the current status: running, completed, or failed." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "operation_id": { + "type": "string", + "description": ( + "The operation_id from an operation_started response." + ), + }, + "task_id": { + "type": "string", + "description": ( + "The task_id from an operation_started response. " + "Used as fallback if operation_id is not provided." + ), + }, + }, + "required": [], + } + + @property + def requires_auth(self) -> bool: + return False + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + **kwargs, + ) -> ToolResponseBase: + from backend.api.features.chat import stream_registry + + operation_id = (kwargs.get("operation_id") or "").strip() + task_id = (kwargs.get("task_id") or "").strip() + + if not operation_id and not task_id: + return ErrorResponse( + message="Please provide an operation_id or task_id.", + error="missing_parameter", + ) + + task = None + if operation_id: + task = await stream_registry.find_task_by_operation_id(operation_id) + if task is None and task_id: + task = await stream_registry.get_task(task_id) + + if task is None: + # Task not in Redis — it may have already expired (TTL). + # Check conversation history for the result instead. + return ErrorResponse( + message=( + "Operation not found — it may have already completed and " + "expired from the status tracker. Check the conversation " + "history for the result." + ), + error="not_found", + ) + + status_messages = { + "running": ( + f"The {task.tool_name or 'operation'} is still running. " + "Please wait for it to complete." + ), + "completed": ( + f"The {task.tool_name or 'operation'} has completed successfully." + ), + "failed": f"The {task.tool_name or 'operation'} has failed.", + } + + return OperationStatusResponse( + task_id=task.task_id, + operation_id=task.operation_id, + status=task.status, + tool_name=task.tool_name, + message=status_messages.get(task.status, f"Status: {task.status}"), + ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/find_block.py b/autogpt_platform/backend/backend/api/features/chat/tools/find_block.py index 55b1c0d510..c51317cb62 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/find_block.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/find_block.py @@ -146,6 +146,7 @@ class FindBlockTool(BaseTool): id=block_id, name=block.name, description=block.description or "", + categories=[c.value for c in block.categories], ) ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/models.py b/autogpt_platform/backend/backend/api/features/chat/tools/models.py index f2d8f364e4..b32f6ca2ce 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/models.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/models.py @@ -41,6 +41,12 @@ class ResponseType(str, Enum): OPERATION_IN_PROGRESS = "operation_in_progress" # Input validation INPUT_VALIDATION_ERROR = "input_validation_error" + # Web fetch + WEB_FETCH = "web_fetch" + # Code execution + BASH_EXEC = "bash_exec" + # Operation status check + OPERATION_STATUS = "operation_status" # Feature request types FEATURE_REQUEST_SEARCH = "feature_request_search" FEATURE_REQUEST_CREATED = "feature_request_created" @@ -338,6 +344,19 @@ class BlockInfoSummary(BaseModel): id: str name: str description: str + categories: list[str] + input_schema: dict[str, Any] = Field( + default_factory=dict, + description="Full JSON schema for block inputs", + ) + output_schema: dict[str, Any] = Field( + default_factory=dict, + description="Full JSON schema for block outputs", + ) + required_inputs: list[BlockInputFieldInfo] = Field( + default_factory=list, + description="List of input fields for this block", + ) class BlockListResponse(ToolResponseBase): @@ -347,6 +366,10 @@ class BlockListResponse(ToolResponseBase): blocks: list[BlockInfoSummary] count: int query: str + usage_hint: str = Field( + default="To execute a block, call run_block with block_id set to the block's " + "'id' field and input_data containing the fields listed in required_inputs." + ) class BlockDetails(BaseModel): @@ -435,6 +458,27 @@ class AsyncProcessingResponse(ToolResponseBase): task_id: str | None = None +class WebFetchResponse(ToolResponseBase): + """Response for web_fetch tool.""" + + type: ResponseType = ResponseType.WEB_FETCH + url: str + status_code: int + content_type: str + content: str + truncated: bool = False + + +class BashExecResponse(ToolResponseBase): + """Response for bash_exec tool.""" + + type: ResponseType = ResponseType.BASH_EXEC + stdout: str + stderr: str + exit_code: int + timed_out: bool = False + + # Feature request models class FeatureRequestInfo(BaseModel): """Information about a feature request issue.""" diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/sandbox.py b/autogpt_platform/backend/backend/api/features/chat/tools/sandbox.py new file mode 100644 index 0000000000..beb326f909 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/sandbox.py @@ -0,0 +1,265 @@ +"""Sandbox execution utilities for code execution tools. + +Provides filesystem + network isolated command execution using **bubblewrap** +(``bwrap``): whitelist-only filesystem (only system dirs visible read-only), +writable workspace only, clean environment, network blocked. + +Tools that call :func:`run_sandboxed` must first check :func:`has_full_sandbox` +and refuse to run if bubblewrap is not available. +""" + +import asyncio +import logging +import os +import platform +import shutil + +logger = logging.getLogger(__name__) + +_DEFAULT_TIMEOUT = 30 +_MAX_TIMEOUT = 120 + + +# --------------------------------------------------------------------------- +# Sandbox capability detection (cached at first call) +# --------------------------------------------------------------------------- + +_BWRAP_AVAILABLE: bool | None = None + + +def has_full_sandbox() -> bool: + """Return True if bubblewrap is available (filesystem + network isolation). + + On non-Linux platforms (macOS), always returns False. + """ + global _BWRAP_AVAILABLE + if _BWRAP_AVAILABLE is None: + _BWRAP_AVAILABLE = ( + platform.system() == "Linux" and shutil.which("bwrap") is not None + ) + return _BWRAP_AVAILABLE + + +WORKSPACE_PREFIX = "/tmp/copilot-" + + +def make_session_path(session_id: str) -> str: + """Build a sanitized, session-specific path under :data:`WORKSPACE_PREFIX`. + + Shared by both the SDK working-directory setup and the sandbox tools so + they always resolve to the same directory for a given session. + + Steps: + 1. Strip all characters except ``[A-Za-z0-9-]``. + 2. Construct ``/tmp/copilot-``. + 3. Validate via ``os.path.normpath`` + ``startswith`` (CodeQL-recognised + sanitizer) to prevent path traversal. + + Raises: + ValueError: If the resulting path escapes the prefix. + """ + import re + + safe_id = re.sub(r"[^A-Za-z0-9-]", "", session_id) + if not safe_id: + safe_id = "default" + path = os.path.normpath(f"{WORKSPACE_PREFIX}{safe_id}") + if not path.startswith(WORKSPACE_PREFIX): + raise ValueError(f"Session path escaped prefix: {path}") + return path + + +def get_workspace_dir(session_id: str) -> str: + """Get or create the workspace directory for a session. + + Uses :func:`make_session_path` — the same path the SDK uses — so that + bash_exec shares the workspace with the SDK file tools. + """ + workspace = make_session_path(session_id) + os.makedirs(workspace, exist_ok=True) + return workspace + + +# --------------------------------------------------------------------------- +# Bubblewrap command builder +# --------------------------------------------------------------------------- + +# System directories mounted read-only inside the sandbox. +# ONLY these are visible — /app, /root, /home, /opt, /var etc. are NOT accessible. +_SYSTEM_RO_BINDS = [ + "/usr", # binaries, libraries, Python interpreter + "/etc", # system config: ld.so, locale, passwd, alternatives +] + +# Compat paths: symlinks to /usr/* on modern Debian, real dirs on older systems. +# On Debian 13 these are symlinks (e.g. /bin -> usr/bin). bwrap --ro-bind +# can't create a symlink target, so we detect and use --symlink instead. +# /lib64 is critical: the ELF dynamic linker lives at /lib64/ld-linux-x86-64.so.2. +_COMPAT_PATHS = [ + ("/bin", "usr/bin"), # -> /usr/bin on Debian 13 + ("/sbin", "usr/sbin"), # -> /usr/sbin on Debian 13 + ("/lib", "usr/lib"), # -> /usr/lib on Debian 13 + ("/lib64", "usr/lib64"), # 64-bit libraries / ELF interpreter +] + +# Resource limits to prevent fork bombs, memory exhaustion, and disk abuse. +# Applied via ulimit inside the sandbox before exec'ing the user command. +_RESOURCE_LIMITS = ( + "ulimit -u 64" # max 64 processes (prevents fork bombs) + " -v 524288" # 512 MB virtual memory + " -f 51200" # 50 MB max file size (1024-byte blocks) + " -n 256" # 256 open file descriptors + " 2>/dev/null" +) + + +def _build_bwrap_command( + command: list[str], cwd: str, env: dict[str, str] +) -> list[str]: + """Build a bubblewrap command with strict filesystem + network isolation. + + Security model: + - **Whitelist-only filesystem**: only system directories (``/usr``, ``/etc``, + ``/bin``, ``/lib``) are mounted read-only. Application code (``/app``), + home directories, ``/var``, ``/opt``, etc. are NOT accessible at all. + - **Writable workspace only**: the per-session workspace is the sole + writable path. + - **Clean environment**: ``--clearenv`` wipes all inherited env vars. + Only the explicitly-passed safe env vars are set inside the sandbox. + - **Network isolation**: ``--unshare-net`` blocks all network access. + - **Resource limits**: ulimit caps on processes (64), memory (512MB), + file size (50MB), and open FDs (256) to prevent fork bombs and abuse. + - **New session**: prevents terminal control escape. + - **Die with parent**: prevents orphaned sandbox processes. + """ + cmd = [ + "bwrap", + # Create a new user namespace so bwrap can set up sandboxing + # inside unprivileged Docker containers (no CAP_SYS_ADMIN needed). + "--unshare-user", + # Wipe all inherited environment variables (API keys, secrets, etc.) + "--clearenv", + ] + + # Set only the safe env vars inside the sandbox + for key, value in env.items(): + cmd.extend(["--setenv", key, value]) + + # System directories: read-only + for path in _SYSTEM_RO_BINDS: + cmd.extend(["--ro-bind", path, path]) + + # Compat paths: use --symlink when host path is a symlink (Debian 13), + # --ro-bind when it's a real directory (older distros). + for path, symlink_target in _COMPAT_PATHS: + if os.path.islink(path): + cmd.extend(["--symlink", symlink_target, path]) + elif os.path.exists(path): + cmd.extend(["--ro-bind", path, path]) + + # Wrap the user command with resource limits: + # sh -c 'ulimit ...; exec "$@"' -- + # `exec "$@"` replaces the shell so there's no extra process overhead, + # and properly handles arguments with spaces. + limited_command = [ + "sh", + "-c", + f'{_RESOURCE_LIMITS}; exec "$@"', + "--", + *command, + ] + + cmd.extend( + [ + # Fresh virtual filesystems + "--dev", + "/dev", + "--proc", + "/proc", + "--tmpfs", + "/tmp", + # Workspace bind AFTER --tmpfs /tmp so it's visible through the tmpfs. + # (workspace lives under /tmp/copilot-) + "--bind", + cwd, + cwd, + # Isolation + "--unshare-net", + "--die-with-parent", + "--new-session", + "--chdir", + cwd, + "--", + *limited_command, + ] + ) + + return cmd + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +async def run_sandboxed( + command: list[str], + cwd: str, + timeout: int = _DEFAULT_TIMEOUT, + env: dict[str, str] | None = None, +) -> tuple[str, str, int, bool]: + """Run a command inside a bubblewrap sandbox. + + Callers **must** check :func:`has_full_sandbox` before calling this + function. If bubblewrap is not available, this function raises + :class:`RuntimeError` rather than running unsandboxed. + + Returns: + (stdout, stderr, exit_code, timed_out) + """ + if not has_full_sandbox(): + raise RuntimeError( + "run_sandboxed() requires bubblewrap but bwrap is not available. " + "Callers must check has_full_sandbox() before calling this function." + ) + + timeout = min(max(timeout, 1), _MAX_TIMEOUT) + + safe_env = { + "PATH": "/usr/local/bin:/usr/bin:/bin", + "HOME": cwd, + "TMPDIR": cwd, + "LANG": "en_US.UTF-8", + "PYTHONDONTWRITEBYTECODE": "1", + "PYTHONIOENCODING": "utf-8", + } + if env: + safe_env.update(env) + + full_command = _build_bwrap_command(command, cwd, safe_env) + + try: + proc = await asyncio.create_subprocess_exec( + *full_command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=cwd, + env=safe_env, + ) + + try: + stdout_bytes, stderr_bytes = await asyncio.wait_for( + proc.communicate(), timeout=timeout + ) + stdout = stdout_bytes.decode("utf-8", errors="replace") + stderr = stderr_bytes.decode("utf-8", errors="replace") + return stdout, stderr, proc.returncode or 0, False + except asyncio.TimeoutError: + proc.kill() + await proc.communicate() + return "", f"Execution timed out after {timeout}s", -1, True + + except RuntimeError: + raise + except Exception as e: + return "", f"Sandbox error: {e}", -1, False diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/web_fetch.py b/autogpt_platform/backend/backend/api/features/chat/tools/web_fetch.py new file mode 100644 index 0000000000..fed7cc11fa --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/web_fetch.py @@ -0,0 +1,151 @@ +"""Web fetch tool — safely retrieve public web page content.""" + +import logging +from typing import Any + +import aiohttp +import html2text + +from backend.api.features.chat.model import ChatSession +from backend.api.features.chat.tools.base import BaseTool +from backend.api.features.chat.tools.models import ( + ErrorResponse, + ToolResponseBase, + WebFetchResponse, +) +from backend.util.request import Requests + +logger = logging.getLogger(__name__) + +# Limits +_MAX_CONTENT_BYTES = 102_400 # 100 KB download cap +_REQUEST_TIMEOUT = aiohttp.ClientTimeout(total=15) + +# Content types we'll read as text +_TEXT_CONTENT_TYPES = { + "text/html", + "text/plain", + "text/xml", + "text/csv", + "text/markdown", + "application/json", + "application/xml", + "application/xhtml+xml", + "application/rss+xml", + "application/atom+xml", +} + + +def _is_text_content(content_type: str) -> bool: + base = content_type.split(";")[0].strip().lower() + return base in _TEXT_CONTENT_TYPES or base.startswith("text/") + + +def _html_to_text(html: str) -> str: + h = html2text.HTML2Text() + h.ignore_links = False + h.ignore_images = True + h.body_width = 0 + return h.handle(html) + + +class WebFetchTool(BaseTool): + """Safely fetch content from a public URL using SSRF-protected HTTP.""" + + @property + def name(self) -> str: + return "web_fetch" + + @property + def description(self) -> str: + return ( + "Fetch the content of a public web page by URL. " + "Returns readable text extracted from HTML by default. " + "Useful for reading documentation, articles, and API responses. " + "Only supports HTTP/HTTPS GET requests to public URLs " + "(private/internal network addresses are blocked)." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "url": { + "type": "string", + "description": "The public HTTP/HTTPS URL to fetch.", + }, + "extract_text": { + "type": "boolean", + "description": ( + "If true (default), extract readable text from HTML. " + "If false, return raw content." + ), + "default": True, + }, + }, + "required": ["url"], + } + + @property + def requires_auth(self) -> bool: + return False + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + **kwargs: Any, + ) -> ToolResponseBase: + url: str = (kwargs.get("url") or "").strip() + extract_text: bool = kwargs.get("extract_text", True) + session_id = session.session_id if session else None + + if not url: + return ErrorResponse( + message="Please provide a URL to fetch.", + error="missing_url", + session_id=session_id, + ) + + try: + client = Requests(raise_for_status=False, retry_max_attempts=1) + response = await client.get(url, timeout=_REQUEST_TIMEOUT) + except ValueError as e: + # validate_url raises ValueError for SSRF / blocked IPs + return ErrorResponse( + message=f"URL blocked: {e}", + error="url_blocked", + session_id=session_id, + ) + except Exception as e: + logger.warning(f"[web_fetch] Request failed for {url}: {e}") + return ErrorResponse( + message=f"Failed to fetch URL: {e}", + error="fetch_failed", + session_id=session_id, + ) + + content_type = response.headers.get("content-type", "") + if not _is_text_content(content_type): + return ErrorResponse( + message=f"Non-text content type: {content_type.split(';')[0]}", + error="unsupported_content_type", + session_id=session_id, + ) + + raw = response.content[:_MAX_CONTENT_BYTES] + text = raw.decode("utf-8", errors="replace") + + if extract_text and "html" in content_type.lower(): + text = _html_to_text(text) + + return WebFetchResponse( + message=f"Fetched {url}", + url=response.url, + status_code=response.status, + content_type=content_type.split(";")[0].strip(), + content=text, + truncated=False, + session_id=session_id, + ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/workspace_files.py b/autogpt_platform/backend/backend/api/features/chat/tools/workspace_files.py index 03532c8fee..f37d2c80e0 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/workspace_files.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/workspace_files.py @@ -88,7 +88,9 @@ class ListWorkspaceFilesTool(BaseTool): @property def description(self) -> str: return ( - "List files in the user's workspace. " + "List files in the user's persistent workspace (cloud storage). " + "These files survive across sessions. " + "For ephemeral session files, use the SDK Read/Glob tools instead. " "Returns file names, paths, sizes, and metadata. " "Optionally filter by path prefix." ) @@ -204,7 +206,9 @@ class ReadWorkspaceFileTool(BaseTool): @property def description(self) -> str: return ( - "Read a file from the user's workspace. " + "Read a file from the user's persistent workspace (cloud storage). " + "These files survive across sessions. " + "For ephemeral session files, use the SDK Read tool instead. " "Specify either file_id or path to identify the file. " "For small text files, returns content directly. " "For large or binary files, returns metadata and a download URL. " @@ -378,7 +382,9 @@ class WriteWorkspaceFileTool(BaseTool): @property def description(self) -> str: return ( - "Write or create a file in the user's workspace. " + "Write or create a file in the user's persistent workspace (cloud storage). " + "These files survive across sessions. " + "For ephemeral session files, use the SDK Write tool instead. " "Provide the content as a base64-encoded string. " f"Maximum file size is {Config().max_file_size_mb}MB. " "Files are saved to the current session's folder by default. " @@ -523,7 +529,7 @@ class DeleteWorkspaceFileTool(BaseTool): @property def description(self) -> str: return ( - "Delete a file from the user's workspace. " + "Delete a file from the user's persistent workspace (cloud storage). " "Specify either file_id or path to identify the file. " "Paths are scoped to the current session by default. " "Use /sessions//... for cross-session access." diff --git a/autogpt_platform/backend/backend/util/feature_flag.py b/autogpt_platform/backend/backend/util/feature_flag.py index fbd3573112..4eadc41333 100644 --- a/autogpt_platform/backend/backend/util/feature_flag.py +++ b/autogpt_platform/backend/backend/util/feature_flag.py @@ -38,6 +38,7 @@ class Flag(str, Enum): AGENT_ACTIVITY = "agent-activity" ENABLE_PLATFORM_PAYMENT = "enable-platform-payment" CHAT = "chat" + COPILOT_SDK = "copilot-sdk" def is_configured() -> bool: diff --git a/autogpt_platform/backend/poetry.lock b/autogpt_platform/backend/poetry.lock index d71cca7865..8062457a70 100644 --- a/autogpt_platform/backend/poetry.lock +++ b/autogpt_platform/backend/poetry.lock @@ -897,6 +897,29 @@ files = [ {file = "charset_normalizer-3.4.4.tar.gz", hash = "sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a"}, ] +[[package]] +name = "claude-agent-sdk" +version = "0.1.35" +description = "Python SDK for Claude Code" +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "claude_agent_sdk-0.1.35-py3-none-macosx_11_0_arm64.whl", hash = "sha256:df67f4deade77b16a9678b3a626c176498e40417f33b04beda9628287f375591"}, + {file = "claude_agent_sdk-0.1.35-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:14963944f55ded7c8ed518feebfa5b4284aa6dd8d81aeff2e5b21a962ce65097"}, + {file = "claude_agent_sdk-0.1.35-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:84344dcc535d179c1fc8a11c6f34c37c3b583447bdf09d869effb26514fd7a65"}, + {file = "claude_agent_sdk-0.1.35-py3-none-win_amd64.whl", hash = "sha256:1b3d54b47448c93f6f372acd4d1757f047c3c1e8ef5804be7a1e3e53e2c79a5f"}, + {file = "claude_agent_sdk-0.1.35.tar.gz", hash = "sha256:0f98e2b3c71ca85abfc042e7a35c648df88e87fda41c52e6779ef7b038dcbb52"}, +] + +[package.dependencies] +anyio = ">=4.0.0" +mcp = ">=0.1.0" +typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""} + +[package.extras] +dev = ["anyio[trio] (>=4.0.0)", "mypy (>=1.0.0)", "pytest (>=7.0.0)", "pytest-asyncio (>=0.20.0)", "pytest-cov (>=4.0.0)", "ruff (>=0.1.0)"] + [[package]] name = "cleo" version = "2.1.0" @@ -2593,6 +2616,18 @@ http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "httpx-sse" +version = "0.4.3" +description = "Consume Server-Sent Event (SSE) messages with HTTPX." +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc"}, + {file = "httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d"}, +] + [[package]] name = "huggingface-hub" version = "1.4.1" @@ -3310,6 +3345,39 @@ files = [ {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, ] +[[package]] +name = "mcp" +version = "1.26.0" +description = "Model Context Protocol SDK" +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "mcp-1.26.0-py3-none-any.whl", hash = "sha256:904a21c33c25aa98ddbeb47273033c435e595bbacfdb177f4bd87f6dceebe1ca"}, + {file = "mcp-1.26.0.tar.gz", hash = "sha256:db6e2ef491eecc1a0d93711a76f28dec2e05999f93afd48795da1c1137142c66"}, +] + +[package.dependencies] +anyio = ">=4.5" +httpx = ">=0.27.1" +httpx-sse = ">=0.4" +jsonschema = ">=4.20.0" +pydantic = ">=2.11.0,<3.0.0" +pydantic-settings = ">=2.5.2" +pyjwt = {version = ">=2.10.1", extras = ["crypto"]} +python-multipart = ">=0.0.9" +pywin32 = {version = ">=310", markers = "sys_platform == \"win32\""} +sse-starlette = ">=1.6.1" +starlette = ">=0.27" +typing-extensions = ">=4.9.0" +typing-inspection = ">=0.4.1" +uvicorn = {version = ">=0.31.1", markers = "sys_platform != \"emscripten\""} + +[package.extras] +cli = ["python-dotenv (>=1.0.0)", "typer (>=0.16.0)"] +rich = ["rich (>=13.9.4)"] +ws = ["websockets (>=15.0.1)"] + [[package]] name = "mdurl" version = "0.1.2" @@ -5994,7 +6062,7 @@ description = "Python for Window Extensions" optional = false python-versions = "*" groups = ["main"] -markers = "platform_system == \"Windows\"" +markers = "sys_platform == \"win32\" or platform_system == \"Windows\"" files = [ {file = "pywin32-311-cp310-cp310-win32.whl", hash = "sha256:d03ff496d2a0cd4a5893504789d4a15399133fe82517455e78bad62efbb7f0a3"}, {file = "pywin32-311-cp310-cp310-win_amd64.whl", hash = "sha256:797c2772017851984b97180b0bebe4b620bb86328e8a884bb626156295a63b3b"}, @@ -6974,6 +7042,28 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"] pymysql = ["pymysql"] sqlcipher = ["sqlcipher3_binary"] +[[package]] +name = "sse-starlette" +version = "3.2.0" +description = "SSE plugin for Starlette" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "sse_starlette-3.2.0-py3-none-any.whl", hash = "sha256:5876954bd51920fc2cd51baee47a080eb88a37b5b784e615abb0b283f801cdbf"}, + {file = "sse_starlette-3.2.0.tar.gz", hash = "sha256:8127594edfb51abe44eac9c49e59b0b01f1039d0c7461c6fd91d4e03b70da422"}, +] + +[package.dependencies] +anyio = ">=4.7.0" +starlette = ">=0.49.1" + +[package.extras] +daphne = ["daphne (>=4.2.0)"] +examples = ["aiosqlite (>=0.21.0)", "fastapi (>=0.115.12)", "sqlalchemy[asyncio] (>=2.0.41)", "uvicorn (>=0.34.0)"] +granian = ["granian (>=2.3.1)"] +uvicorn = ["uvicorn (>=0.34.0)"] + [[package]] name = "stagehand" version = "0.5.9" @@ -8440,4 +8530,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt [metadata] lock-version = "2.1" python-versions = ">=3.10,<3.14" -content-hash = "fa9c5deadf593e815dd2190f58e22152373900603f5f244b9616cd721de84d2f" +content-hash = "55e095de555482f0fe47de7695f390fe93e7bcf739b31c391b2e5e3c3d938ae3" diff --git a/autogpt_platform/backend/pyproject.toml b/autogpt_platform/backend/pyproject.toml index 32dfc547bc..7a112e75ca 100644 --- a/autogpt_platform/backend/pyproject.toml +++ b/autogpt_platform/backend/pyproject.toml @@ -16,6 +16,7 @@ anthropic = "^0.79.0" apscheduler = "^3.11.1" autogpt-libs = { path = "../autogpt_libs", develop = true } bleach = { extras = ["css"], version = "^6.2.0" } +claude-agent-sdk = "^0.1.0" click = "^8.2.0" cryptography = "^46.0" discord-py = "^2.5.2" diff --git a/autogpt_platform/backend/test/chat/__init__.py b/autogpt_platform/backend/test/chat/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/autogpt_platform/backend/test/chat/test_security_hooks.py b/autogpt_platform/backend/test/chat/test_security_hooks.py new file mode 100644 index 0000000000..f10a90871b --- /dev/null +++ b/autogpt_platform/backend/test/chat/test_security_hooks.py @@ -0,0 +1,133 @@ +"""Tests for SDK security hooks — workspace paths, tool access, and deny messages. + +These are pure unit tests with no external dependencies (no SDK, no DB, no server). +They validate that the security hooks correctly block unauthorized paths, +tool access, and dangerous input patterns. + +Note: Bash command validation was removed — the SDK built-in Bash tool is not in +allowed_tools, and the bash_exec MCP tool has kernel-level network isolation +(unshare --net) making command-level parsing unnecessary. +""" + +from backend.api.features.chat.sdk.security_hooks import ( + _validate_tool_access, + _validate_workspace_path, +) + +SDK_CWD = "/tmp/copilot-test-session" + + +def _is_denied(result: dict) -> bool: + hook = result.get("hookSpecificOutput", {}) + return hook.get("permissionDecision") == "deny" + + +def _reason(result: dict) -> str: + return result.get("hookSpecificOutput", {}).get("permissionDecisionReason", "") + + +# ============================================================ +# Workspace path validation (Read, Write, Edit, etc.) +# ============================================================ + + +class TestWorkspacePathValidation: + def test_path_in_workspace(self): + result = _validate_workspace_path( + "Read", {"file_path": f"{SDK_CWD}/file.txt"}, SDK_CWD + ) + assert not _is_denied(result) + + def test_path_outside_workspace(self): + result = _validate_workspace_path("Read", {"file_path": "/etc/passwd"}, SDK_CWD) + assert _is_denied(result) + + def test_tool_results_allowed(self): + result = _validate_workspace_path( + "Read", + {"file_path": "~/.claude/projects/abc/tool-results/out.txt"}, + SDK_CWD, + ) + assert not _is_denied(result) + + def test_claude_settings_blocked(self): + result = _validate_workspace_path( + "Read", {"file_path": "~/.claude/settings.json"}, SDK_CWD + ) + assert _is_denied(result) + + def test_claude_projects_without_tool_results(self): + result = _validate_workspace_path( + "Read", {"file_path": "~/.claude/projects/abc/credentials.json"}, SDK_CWD + ) + assert _is_denied(result) + + def test_no_path_allowed(self): + """Glob/Grep without path defaults to cwd — should be allowed.""" + result = _validate_workspace_path("Grep", {"pattern": "foo"}, SDK_CWD) + assert not _is_denied(result) + + def test_path_traversal_with_dotdot(self): + result = _validate_workspace_path( + "Read", {"file_path": f"{SDK_CWD}/../../../etc/passwd"}, SDK_CWD + ) + assert _is_denied(result) + + +# ============================================================ +# Tool access validation +# ============================================================ + + +class TestToolAccessValidation: + def test_blocked_tools(self): + for tool in ("bash", "shell", "exec", "terminal", "command"): + result = _validate_tool_access(tool, {}) + assert _is_denied(result), f"Tool '{tool}' should be blocked" + + def test_bash_builtin_blocked(self): + """SDK built-in Bash (capital) is blocked as defence-in-depth.""" + result = _validate_tool_access("Bash", {"command": "echo hello"}, SDK_CWD) + assert _is_denied(result) + assert "Bash" in _reason(result) + + def test_workspace_tools_delegate(self): + result = _validate_tool_access( + "Read", {"file_path": f"{SDK_CWD}/file.txt"}, SDK_CWD + ) + assert not _is_denied(result) + + def test_dangerous_pattern_blocked(self): + result = _validate_tool_access("SomeUnknownTool", {"data": "sudo rm -rf /"}) + assert _is_denied(result) + + def test_safe_unknown_tool_allowed(self): + result = _validate_tool_access("SomeSafeTool", {"data": "hello world"}) + assert not _is_denied(result) + + +# ============================================================ +# Deny message quality (ntindle feedback) +# ============================================================ + + +class TestDenyMessageClarity: + """Deny messages must include [SECURITY] and 'cannot be bypassed' + so the model knows the restriction is enforced, not a suggestion.""" + + def test_blocked_tool_message(self): + reason = _reason(_validate_tool_access("bash", {})) + assert "[SECURITY]" in reason + assert "cannot be bypassed" in reason + + def test_bash_builtin_blocked_message(self): + reason = _reason(_validate_tool_access("Bash", {"command": "echo hello"})) + assert "[SECURITY]" in reason + assert "cannot be bypassed" in reason + + def test_workspace_path_message(self): + reason = _reason( + _validate_workspace_path("Read", {"file_path": "/etc/passwd"}, SDK_CWD) + ) + assert "[SECURITY]" in reason + assert "cannot be bypassed" in reason diff --git a/autogpt_platform/backend/test/chat/test_transcript.py b/autogpt_platform/backend/test/chat/test_transcript.py new file mode 100644 index 0000000000..71b1fad81f --- /dev/null +++ b/autogpt_platform/backend/test/chat/test_transcript.py @@ -0,0 +1,255 @@ +"""Unit tests for JSONL transcript management utilities.""" + +import json +import os + +from backend.api.features.chat.sdk.transcript import ( + STRIPPABLE_TYPES, + read_transcript_file, + strip_progress_entries, + validate_transcript, + write_transcript_to_tempfile, +) + + +def _make_jsonl(*entries: dict) -> str: + return "\n".join(json.dumps(e) for e in entries) + "\n" + + +# --- Fixtures --- + + +METADATA_LINE = {"type": "queue-operation", "subtype": "create"} +FILE_HISTORY = {"type": "file-history-snapshot", "files": []} +USER_MSG = {"type": "user", "uuid": "u1", "message": {"role": "user", "content": "hi"}} +ASST_MSG = { + "type": "assistant", + "uuid": "a1", + "parentUuid": "u1", + "message": {"role": "assistant", "content": "hello"}, +} +PROGRESS_ENTRY = { + "type": "progress", + "uuid": "p1", + "parentUuid": "u1", + "data": {"type": "bash_progress", "stdout": "running..."}, +} + +VALID_TRANSCRIPT = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG, ASST_MSG) + + +# --- read_transcript_file --- + + +class TestReadTranscriptFile: + def test_returns_content_for_valid_file(self, tmp_path): + path = tmp_path / "session.jsonl" + path.write_text(VALID_TRANSCRIPT) + result = read_transcript_file(str(path)) + assert result is not None + assert "user" in result + + def test_returns_none_for_missing_file(self): + assert read_transcript_file("/nonexistent/path.jsonl") is None + + def test_returns_none_for_empty_path(self): + assert read_transcript_file("") is None + + def test_returns_none_for_empty_file(self, tmp_path): + path = tmp_path / "empty.jsonl" + path.write_text("") + assert read_transcript_file(str(path)) is None + + def test_returns_none_for_metadata_only(self, tmp_path): + content = _make_jsonl(METADATA_LINE, FILE_HISTORY) + path = tmp_path / "meta.jsonl" + path.write_text(content) + assert read_transcript_file(str(path)) is None + + def test_returns_none_for_invalid_json(self, tmp_path): + path = tmp_path / "bad.jsonl" + path.write_text("not json\n{}\n{}\n") + assert read_transcript_file(str(path)) is None + + def test_no_size_limit(self, tmp_path): + """Large files are accepted — bucket storage has no size limit.""" + big_content = {"type": "user", "uuid": "u9", "data": "x" * 1_000_000} + content = _make_jsonl(METADATA_LINE, FILE_HISTORY, big_content, ASST_MSG) + path = tmp_path / "big.jsonl" + path.write_text(content) + result = read_transcript_file(str(path)) + assert result is not None + + +# --- write_transcript_to_tempfile --- + + +class TestWriteTranscriptToTempfile: + """Tests use /tmp/copilot-* paths to satisfy the sandbox prefix check.""" + + def test_writes_file_and_returns_path(self): + cwd = "/tmp/copilot-test-write" + try: + result = write_transcript_to_tempfile( + VALID_TRANSCRIPT, "sess-1234-abcd", cwd + ) + assert result is not None + assert os.path.isfile(result) + assert result.endswith(".jsonl") + with open(result) as f: + assert f.read() == VALID_TRANSCRIPT + finally: + import shutil + + shutil.rmtree(cwd, ignore_errors=True) + + def test_creates_parent_directory(self): + cwd = "/tmp/copilot-test-mkdir" + try: + result = write_transcript_to_tempfile(VALID_TRANSCRIPT, "sess-1234", cwd) + assert result is not None + assert os.path.isdir(cwd) + finally: + import shutil + + shutil.rmtree(cwd, ignore_errors=True) + + def test_uses_session_id_prefix(self): + cwd = "/tmp/copilot-test-prefix" + try: + result = write_transcript_to_tempfile( + VALID_TRANSCRIPT, "abcdef12-rest", cwd + ) + assert result is not None + assert "abcdef12" in os.path.basename(result) + finally: + import shutil + + shutil.rmtree(cwd, ignore_errors=True) + + def test_rejects_cwd_outside_sandbox(self, tmp_path): + cwd = str(tmp_path / "not-copilot") + result = write_transcript_to_tempfile(VALID_TRANSCRIPT, "sess-1234", cwd) + assert result is None + + +# --- validate_transcript --- + + +class TestValidateTranscript: + def test_valid_transcript(self): + assert validate_transcript(VALID_TRANSCRIPT) is True + + def test_none_content(self): + assert validate_transcript(None) is False + + def test_empty_content(self): + assert validate_transcript("") is False + + def test_metadata_only(self): + content = _make_jsonl(METADATA_LINE, FILE_HISTORY) + assert validate_transcript(content) is False + + def test_user_only_no_assistant(self): + content = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG) + assert validate_transcript(content) is False + + def test_assistant_only_no_user(self): + content = _make_jsonl(METADATA_LINE, FILE_HISTORY, ASST_MSG) + assert validate_transcript(content) is False + + def test_invalid_json_returns_false(self): + assert validate_transcript("not json\n{}\n{}\n") is False + + +# --- strip_progress_entries --- + + +class TestStripProgressEntries: + def test_strips_all_strippable_types(self): + """All STRIPPABLE_TYPES are removed from the output.""" + entries = [ + USER_MSG, + {"type": "progress", "uuid": "p1", "parentUuid": "u1"}, + {"type": "file-history-snapshot", "files": []}, + {"type": "queue-operation", "subtype": "create"}, + {"type": "summary", "text": "..."}, + {"type": "pr-link", "url": "..."}, + ASST_MSG, + ] + result = strip_progress_entries(_make_jsonl(*entries)) + result_types = {json.loads(line)["type"] for line in result.strip().split("\n")} + assert result_types == {"user", "assistant"} + for stype in STRIPPABLE_TYPES: + assert stype not in result_types + + def test_reparents_children_of_stripped_entries(self): + """An assistant message whose parent is a progress entry gets reparented.""" + progress = { + "type": "progress", + "uuid": "p1", + "parentUuid": "u1", + "data": {"type": "bash_progress"}, + } + asst = { + "type": "assistant", + "uuid": "a1", + "parentUuid": "p1", # Points to progress + "message": {"role": "assistant", "content": "done"}, + } + content = _make_jsonl(USER_MSG, progress, asst) + result = strip_progress_entries(content) + lines = [json.loads(line) for line in result.strip().split("\n")] + + asst_entry = next(e for e in lines if e["type"] == "assistant") + # Should be reparented to u1 (the user message) + assert asst_entry["parentUuid"] == "u1" + + def test_reparents_through_chain(self): + """Reparenting walks through multiple stripped entries.""" + p1 = {"type": "progress", "uuid": "p1", "parentUuid": "u1"} + p2 = {"type": "progress", "uuid": "p2", "parentUuid": "p1"} + p3 = {"type": "progress", "uuid": "p3", "parentUuid": "p2"} + asst = { + "type": "assistant", + "uuid": "a1", + "parentUuid": "p3", # 3 levels deep + "message": {"role": "assistant", "content": "done"}, + } + content = _make_jsonl(USER_MSG, p1, p2, p3, asst) + result = strip_progress_entries(content) + lines = [json.loads(line) for line in result.strip().split("\n")] + + asst_entry = next(e for e in lines if e["type"] == "assistant") + assert asst_entry["parentUuid"] == "u1" + + def test_preserves_non_strippable_entries(self): + """User, assistant, and system entries are preserved.""" + system = {"type": "system", "uuid": "s1", "message": "prompt"} + content = _make_jsonl(system, USER_MSG, ASST_MSG) + result = strip_progress_entries(content) + result_types = [json.loads(line)["type"] for line in result.strip().split("\n")] + assert result_types == ["system", "user", "assistant"] + + def test_empty_input(self): + result = strip_progress_entries("") + # Should return just a newline (empty content stripped) + assert result.strip() == "" + + def test_no_strippable_entries(self): + """When there's nothing to strip, output matches input structure.""" + content = _make_jsonl(USER_MSG, ASST_MSG) + result = strip_progress_entries(content) + result_lines = result.strip().split("\n") + assert len(result_lines) == 2 + + def test_handles_entries_without_uuid(self): + """Entries without uuid field are handled gracefully.""" + no_uuid = {"type": "queue-operation", "subtype": "create"} + content = _make_jsonl(no_uuid, USER_MSG, ASST_MSG) + result = strip_progress_entries(content) + result_types = [json.loads(line)["type"] for line in result.strip().split("\n")] + # queue-operation is strippable + assert "queue-operation" not in result_types + assert "user" in result_types + assert "assistant" in result_types diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatMessagesContainer/ChatMessagesContainer.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatMessagesContainer/ChatMessagesContainer.tsx index b62e96f58a..c118057963 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatMessagesContainer/ChatMessagesContainer.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatMessagesContainer/ChatMessagesContainer.tsx @@ -24,6 +24,7 @@ import { FindBlocksTool } from "../../tools/FindBlocks/FindBlocks"; import { RunAgentTool } from "../../tools/RunAgent/RunAgent"; import { RunBlockTool } from "../../tools/RunBlock/RunBlock"; import { SearchDocsTool } from "../../tools/SearchDocs/SearchDocs"; +import { GenericTool } from "../../tools/GenericTool/GenericTool"; import { ViewAgentOutputTool } from "../../tools/ViewAgentOutput/ViewAgentOutput"; // --------------------------------------------------------------------------- @@ -273,6 +274,16 @@ export const ChatMessagesContainer = ({ /> ); default: + // Render a generic tool indicator for SDK built-in + // tools (Read, Glob, Grep, etc.) or any unrecognized tool + if (part.type.startsWith("tool-")) { + return ( + + ); + } return null; } })} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/GenericTool.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/GenericTool.tsx new file mode 100644 index 0000000000..677f1d01d1 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/GenericTool.tsx @@ -0,0 +1,63 @@ +"use client"; + +import { ToolUIPart } from "ai"; +import { GearIcon } from "@phosphor-icons/react"; +import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation"; + +interface Props { + part: ToolUIPart; +} + +function extractToolName(part: ToolUIPart): string { + // ToolUIPart.type is "tool-{name}", extract the name portion. + return part.type.replace(/^tool-/, ""); +} + +function formatToolName(name: string): string { + // "search_docs" → "Search docs", "Read" → "Read" + return name.replace(/_/g, " ").replace(/^\w/, (c) => c.toUpperCase()); +} + +function getAnimationText(part: ToolUIPart): string { + const label = formatToolName(extractToolName(part)); + + switch (part.state) { + case "input-streaming": + case "input-available": + return `Running ${label}…`; + case "output-available": + return `${label} completed`; + case "output-error": + return `${label} failed`; + default: + return `Running ${label}…`; + } +} + +export function GenericTool({ part }: Props) { + const isStreaming = + part.state === "input-streaming" || part.state === "input-available"; + const isError = part.state === "output-error"; + + return ( +
+
+ + +
+
+ ); +} diff --git a/autogpt_platform/frontend/src/app/api/openapi.json b/autogpt_platform/frontend/src/app/api/openapi.json index 1e8dca865c..8e48931540 100644 --- a/autogpt_platform/frontend/src/app/api/openapi.json +++ b/autogpt_platform/frontend/src/app/api/openapi.json @@ -7066,13 +7066,57 @@ "properties": { "id": { "type": "string", "title": "Id" }, "name": { "type": "string", "title": "Name" }, - "description": { "type": "string", "title": "Description" } + "description": { "type": "string", "title": "Description" }, + "categories": { + "items": { "type": "string" }, + "type": "array", + "title": "Categories" + }, + "input_schema": { + "additionalProperties": true, + "type": "object", + "title": "Input Schema", + "description": "Full JSON schema for block inputs" + }, + "output_schema": { + "additionalProperties": true, + "type": "object", + "title": "Output Schema", + "description": "Full JSON schema for block outputs" + }, + "required_inputs": { + "items": { "$ref": "#/components/schemas/BlockInputFieldInfo" }, + "type": "array", + "title": "Required Inputs", + "description": "List of input fields for this block" + } }, "type": "object", - "required": ["id", "name", "description"], + "required": ["id", "name", "description", "categories"], "title": "BlockInfoSummary", "description": "Summary of a block for search results." }, + "BlockInputFieldInfo": { + "properties": { + "name": { "type": "string", "title": "Name" }, + "type": { "type": "string", "title": "Type" }, + "description": { + "type": "string", + "title": "Description", + "default": "" + }, + "required": { + "type": "boolean", + "title": "Required", + "default": false + }, + "default": { "anyOf": [{}, { "type": "null" }], "title": "Default" } + }, + "type": "object", + "required": ["name", "type"], + "title": "BlockInputFieldInfo", + "description": "Information about a block input field." + }, "BlockListResponse": { "properties": { "type": { @@ -7090,7 +7134,12 @@ "title": "Blocks" }, "count": { "type": "integer", "title": "Count" }, - "query": { "type": "string", "title": "Query" } + "query": { "type": "string", "title": "Query" }, + "usage_hint": { + "type": "string", + "title": "Usage Hint", + "default": "To execute a block, call run_block with block_id set to the block's 'id' field and input_data containing the fields listed in required_inputs." + } }, "type": "object", "required": ["message", "blocks", "count", "query"], @@ -10496,6 +10545,9 @@ "operation_pending", "operation_in_progress", "input_validation_error", + "web_fetch", + "bash_exec", + "operation_status", "feature_request_search", "feature_request_created" ], From f9f358c5263b32902810769a5bfcfe0254aa3c2c Mon Sep 17 00:00:00 2001 From: Zamil Majdy Date: Fri, 13 Feb 2026 20:17:03 +0400 Subject: [PATCH 11/16] feat(mcp): Add MCP tool block with OAuth, tool discovery, and standard credential integration (#12011) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary image image image image Full-stack MCP (Model Context Protocol) tool block integration that allows users to connect to any MCP server, discover available tools, authenticate via OAuth, and execute tools — all through the standard AutoGPT credential system. ### Backend - **MCPToolBlock** (`blocks/mcp/block.py`): New block using `CredentialsMetaInput` pattern with optional credentials (`default={}`), supporting both authenticated (OAuth) and public MCP servers. Includes auto-lookup fallback for backward compatibility. - **MCP Client** (`blocks/mcp/client.py`): HTTP transport with JSON-RPC 2.0, tool discovery, tool execution with robust error handling (type-checked error fields, non-JSON response handling) - **MCP OAuth Handler** (`blocks/mcp/oauth.py`): RFC 8414 discovery, dynamic per-server OAuth with PKCE, token storage and refresh via `raise_for_status=True` - **MCP API Routes** (`api/features/mcp/routes.py`): `discover-tools`, `oauth/login`, `oauth/callback` endpoints with credential cleanup, defensive OAuth metadata validation - **Credential system integration**: - `CredentialsMetaInput` model_validator normalizes legacy `"ProviderName.MCP"` format from Python 3.13's `str(StrEnum)` change - `CredentialsFieldInfo.combine()` supports URL-based credential discrimination (each MCP server gets its own credential entry) - `aggregate_credentials_inputs` checks block schema defaults for credential optionality - Executor normalizes credential data for both Pydantic and JSON schema validation paths - Chat credential matching handles MCP server URL filtering - `provider_matches()` helper used consistently for Python 3.13 StrEnum compatibility - **Pre-run validation**: `_validate_graph_get_errors` now calls `get_missing_input()` for custom block-level validation (MCP tool arguments) - **Security**: HTML tag stripping loop to prevent XSS bypass, SSRF protection (removed trusted_origins) ### Frontend - **MCPToolDialog** (`MCPToolDialog.tsx`): Full tool discovery UI — enter server URL, authenticate if needed, browse tools, select tool and configure - **OAuth popup** (`oauth-popup.ts`): Shared utility supporting cross-origin MCP OAuth flows with BroadcastChannel + localStorage fallback - **Credential integration**: MCP-specific OAuth flow in `useCredentialsInput`, server URL filtering in `useCredentials`, MCP callback page - **CredentialsSelect**: Auto-selects first available credential instead of defaulting to "None", credentials listed before "None" in dropdown - **Node rendering**: Dynamic tool input schema rendering on MCP nodes, proper handling in both legacy and new flow editors - **Block title persistence**: `customized_name` set at block creation for both MCP and Agent blocks — no fallback logic needed, titles survive save/load reliably - **Stable credential ordering**: Removed `sortByUnsetFirst` that caused credential inputs to jump when selected ### Tests (~2060 lines) - Unit tests: block, client, tool execution - Integration tests: mock MCP server with auth - OAuth flow tests - API endpoint tests - Credential combining/optionality tests - E2e tests (skipped in CI, run manually) ## Key Design Decisions 1. **Optional credentials via `default={}`**: MCP servers can be public (no auth) or private (OAuth). The `credentials` field has `default={}` making it optional at the schema level, so public servers work without prompting for credentials. 2. **URL-based credential discrimination**: Each MCP server URL gets its own credential entry in the "Run agent" form (via `discriminator="server_url"`), so agents using multiple MCP servers prompt for each independently. 3. **Model-level normalization**: Python 3.13 changed `str(StrEnum)` to return `"ClassName.MEMBER"`. Rather than scattering fixes across the codebase, a Pydantic `model_validator(mode="before")` on `CredentialsMetaInput` handles normalization centrally, and `provider_matches()` handles lookups. 4. **Credential auto-select**: `CredentialsSelect` component defaults to the first available credential and notifies the parent state, ensuring credentials are pre-filled in the "Run agent" dialog without requiring manual selection. 5. **customized_name for block titles**: Both MCP and Agent blocks set `customized_name` in metadata at creation time. This eliminates convoluted runtime fallback logic (`agent_name`, hostname extraction) — the title is persisted once and read directly. ## Test plan - [x] Unit/integration tests pass (68 MCP + 11 graph = 79 tests) - [x] Manual: MCP block with public server (DeepWiki) — no credentials needed, tools discovered and executable - [x] Manual: MCP block with OAuth server (Linear, Sentry) — OAuth flow prompts correctly - [x] Manual: "Run agent" form shows correct credential requirements per MCP server - [x] Manual: Credential auto-selects when exactly one matches, pre-selects first when multiple exist - [x] Manual: Credential ordering stays stable when selecting/deselecting - [x] Manual: MCP block title persists after save and refresh - [x] Manual: Agent block title persists after save and refresh (via customized_name) - [ ] Manual: Shared agent with MCP block prompts new user for credentials --------- Co-authored-by: Otto Co-authored-by: Ubbe --- .../backend/api/features/chat/tools/utils.py | 23 +- .../api/features/integrations/router.py | 54 +- .../backend/api/features/mcp/__init__.py | 0 .../backend/api/features/mcp/routes.py | 404 ++++++++++++ .../backend/api/features/mcp/test_routes.py | 436 ++++++++++++ .../backend/backend/api/rest_api.py | 6 + .../backend/backend/blocks/_base.py | 1 + .../backend/backend/blocks/mcp/__init__.py | 0 .../backend/backend/blocks/mcp/block.py | 300 +++++++++ .../backend/backend/blocks/mcp/client.py | 323 +++++++++ .../backend/backend/blocks/mcp/oauth.py | 204 ++++++ .../backend/backend/blocks/mcp/test_e2e.py | 109 +++ .../backend/blocks/mcp/test_integration.py | 389 +++++++++++ .../backend/backend/blocks/mcp/test_mcp.py | 619 ++++++++++++++++++ .../backend/backend/blocks/mcp/test_oauth.py | 242 +++++++ .../backend/backend/blocks/mcp/test_server.py | 162 +++++ .../backend/backend/data/graph.py | 42 +- .../backend/backend/data/graph_test.py | 117 ++++ .../backend/backend/data/model.py | 35 +- .../backend/backend/executor/manager.py | 43 +- .../backend/backend/executor/utils.py | 22 +- .../backend/backend/executor/utils_test.py | 6 +- .../backend/integrations/credentials_store.py | 36 +- .../backend/integrations/creds_manager.py | 38 +- .../backend/backend/integrations/providers.py | 1 + .../webhooks/graph_lifecycle_hooks.py | 15 + .../backend/backend/util/request.py | 4 +- .../auth/integrations/mcp_callback/route.ts | 96 +++ .../nodes/CustomNode/CustomNode.tsx | 6 +- .../CustomNode/components/NodeHeader.tsx | 6 +- .../nodes/CustomNode/useCustomNode.tsx | 39 +- .../FlowEditor/nodes/FormCreator.tsx | 57 +- .../build/components/MCPToolDialog.tsx | 558 ++++++++++++++++ .../NewControlPanel/NewBlockMenu/Block.tsx | 176 +++-- .../app/(platform)/build/components/types.ts | 1 + .../frontend/src/app/api/openapi.json | 219 ++++++- .../CredentialsGroupedView.tsx | 11 +- .../CredentialsGroupedView/helpers.ts | 56 +- .../CredentialsInput/useCredentialsInput.ts | 165 +++-- .../frontend/src/hooks/useCredentials.ts | 5 + .../src/lib/autogpt-server-api/types.ts | 2 + .../frontend/src/lib/oauth-popup.ts | 177 +++++ autogpt_platform/frontend/src/middleware.ts | 2 +- .../credentials-provider.tsx | 37 ++ .../frontend/src/tests/pages/build.page.ts | 3 + docs/integrations/README.md | 1 + docs/integrations/SUMMARY.md | 1 + .../block-integrations/mcp/block.md | 40 ++ 48 files changed, 5074 insertions(+), 215 deletions(-) create mode 100644 autogpt_platform/backend/backend/api/features/mcp/__init__.py create mode 100644 autogpt_platform/backend/backend/api/features/mcp/routes.py create mode 100644 autogpt_platform/backend/backend/api/features/mcp/test_routes.py create mode 100644 autogpt_platform/backend/backend/blocks/mcp/__init__.py create mode 100644 autogpt_platform/backend/backend/blocks/mcp/block.py create mode 100644 autogpt_platform/backend/backend/blocks/mcp/client.py create mode 100644 autogpt_platform/backend/backend/blocks/mcp/oauth.py create mode 100644 autogpt_platform/backend/backend/blocks/mcp/test_e2e.py create mode 100644 autogpt_platform/backend/backend/blocks/mcp/test_integration.py create mode 100644 autogpt_platform/backend/backend/blocks/mcp/test_mcp.py create mode 100644 autogpt_platform/backend/backend/blocks/mcp/test_oauth.py create mode 100644 autogpt_platform/backend/backend/blocks/mcp/test_server.py create mode 100644 autogpt_platform/frontend/src/app/(platform)/auth/integrations/mcp_callback/route.ts create mode 100644 autogpt_platform/frontend/src/app/(platform)/build/components/MCPToolDialog.tsx create mode 100644 autogpt_platform/frontend/src/lib/oauth-popup.ts create mode 100644 docs/integrations/block-integrations/mcp/block.md diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/utils.py b/autogpt_platform/backend/backend/api/features/chat/tools/utils.py index 80a842bf36..3b2168d09e 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/utils.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/utils.py @@ -15,6 +15,7 @@ from backend.data.model import ( OAuth2Credentials, ) from backend.integrations.creds_manager import IntegrationCredentialsManager +from backend.integrations.providers import ProviderName from backend.util.exceptions import NotFoundError logger = logging.getLogger(__name__) @@ -359,7 +360,7 @@ async def match_user_credentials_to_graph( _, _, ) in aggregated_creds.items(): - # Find first matching credential by provider, type, and scopes + # Find first matching credential by provider, type, scopes, and host/URL matching_cred = next( ( cred @@ -374,6 +375,10 @@ async def match_user_credentials_to_graph( cred.type != "host_scoped" or _credential_is_for_host(cred, credential_requirements) ) + and ( + cred.provider != ProviderName.MCP + or _credential_is_for_mcp_server(cred, credential_requirements) + ) ), None, ) @@ -444,6 +449,22 @@ def _credential_is_for_host( return credential.matches_url(list(requirements.discriminator_values)[0]) +def _credential_is_for_mcp_server( + credential: Credentials, + requirements: CredentialsFieldInfo, +) -> bool: + """Check if an MCP OAuth credential matches the required server URL.""" + if not requirements.discriminator_values: + return True + + server_url = ( + credential.metadata.get("mcp_server_url") + if isinstance(credential, OAuth2Credentials) + else None + ) + return server_url in requirements.discriminator_values if server_url else False + + async def check_user_has_required_credentials( user_id: str, required_credentials: list[CredentialsMetaInput], diff --git a/autogpt_platform/backend/backend/api/features/integrations/router.py b/autogpt_platform/backend/backend/api/features/integrations/router.py index 00500dc8a8..4eacf83e71 100644 --- a/autogpt_platform/backend/backend/api/features/integrations/router.py +++ b/autogpt_platform/backend/backend/api/features/integrations/router.py @@ -1,7 +1,7 @@ import asyncio import logging from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Annotated, List, Literal +from typing import TYPE_CHECKING, Annotated, Any, List, Literal from autogpt_libs.auth import get_user_id from fastapi import ( @@ -14,7 +14,7 @@ from fastapi import ( Security, status, ) -from pydantic import BaseModel, Field, SecretStr +from pydantic import BaseModel, Field, SecretStr, model_validator from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY from backend.api.features.library.db import set_preset_webhook, update_preset @@ -39,7 +39,11 @@ from backend.data.onboarding import OnboardingStep, complete_onboarding_step from backend.data.user import get_user_integrations from backend.executor.utils import add_graph_execution from backend.integrations.ayrshare import AyrshareClient, SocialPlatform -from backend.integrations.creds_manager import IntegrationCredentialsManager +from backend.integrations.credentials_store import provider_matches +from backend.integrations.creds_manager import ( + IntegrationCredentialsManager, + create_mcp_oauth_handler, +) from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME from backend.integrations.providers import ProviderName from backend.integrations.webhooks import get_webhook_manager @@ -102,9 +106,37 @@ class CredentialsMetaResponse(BaseModel): scopes: list[str] | None username: str | None host: str | None = Field( - default=None, description="Host pattern for host-scoped credentials" + default=None, + description="Host pattern for host-scoped or MCP server URL for MCP credentials", ) + @model_validator(mode="before") + @classmethod + def _normalize_provider(cls, data: Any) -> Any: + """Fix ``ProviderName.X`` format from Python 3.13 ``str(Enum)`` bug.""" + if isinstance(data, dict): + prov = data.get("provider", "") + if isinstance(prov, str) and prov.startswith("ProviderName."): + member = prov.removeprefix("ProviderName.") + try: + data = {**data, "provider": ProviderName[member].value} + except KeyError: + pass + return data + + @staticmethod + def get_host(cred: Credentials) -> str | None: + """Extract host from credential: HostScoped host or MCP server URL.""" + if isinstance(cred, HostScopedCredentials): + return cred.host + if isinstance(cred, OAuth2Credentials) and cred.provider in ( + ProviderName.MCP, + ProviderName.MCP.value, + "ProviderName.MCP", + ): + return (cred.metadata or {}).get("mcp_server_url") + return None + @router.post("/{provider}/callback", summary="Exchange OAuth code for tokens") async def callback( @@ -179,9 +211,7 @@ async def callback( title=credentials.title, scopes=credentials.scopes, username=credentials.username, - host=( - credentials.host if isinstance(credentials, HostScopedCredentials) else None - ), + host=(CredentialsMetaResponse.get_host(credentials)), ) @@ -199,7 +229,7 @@ async def list_credentials( title=cred.title, scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None, username=cred.username if isinstance(cred, OAuth2Credentials) else None, - host=cred.host if isinstance(cred, HostScopedCredentials) else None, + host=CredentialsMetaResponse.get_host(cred), ) for cred in credentials ] @@ -222,7 +252,7 @@ async def list_credentials_by_provider( title=cred.title, scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None, username=cred.username if isinstance(cred, OAuth2Credentials) else None, - host=cred.host if isinstance(cred, HostScopedCredentials) else None, + host=CredentialsMetaResponse.get_host(cred), ) for cred in credentials ] @@ -322,7 +352,11 @@ async def delete_credentials( tokens_revoked = None if isinstance(creds, OAuth2Credentials): - handler = _get_provider_oauth_handler(request, provider) + if provider_matches(provider.value, ProviderName.MCP.value): + # MCP uses dynamic per-server OAuth — create handler from metadata + handler = create_mcp_oauth_handler(creds) + else: + handler = _get_provider_oauth_handler(request, provider) tokens_revoked = await handler.revoke_tokens(creds) return CredentialsDeletionResponse(revoked=tokens_revoked) diff --git a/autogpt_platform/backend/backend/api/features/mcp/__init__.py b/autogpt_platform/backend/backend/api/features/mcp/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/autogpt_platform/backend/backend/api/features/mcp/routes.py b/autogpt_platform/backend/backend/api/features/mcp/routes.py new file mode 100644 index 0000000000..f8d311f372 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/mcp/routes.py @@ -0,0 +1,404 @@ +""" +MCP (Model Context Protocol) API routes. + +Provides endpoints for MCP tool discovery and OAuth authentication so the +frontend can list available tools on an MCP server before placing a block. +""" + +import logging +from typing import Annotated, Any +from urllib.parse import urlparse + +import fastapi +from autogpt_libs.auth import get_user_id +from fastapi import Security +from pydantic import BaseModel, Field + +from backend.api.features.integrations.router import CredentialsMetaResponse +from backend.blocks.mcp.client import MCPClient, MCPClientError +from backend.blocks.mcp.oauth import MCPOAuthHandler +from backend.data.model import OAuth2Credentials +from backend.integrations.creds_manager import IntegrationCredentialsManager +from backend.integrations.providers import ProviderName +from backend.util.request import HTTPClientError, Requests +from backend.util.settings import Settings + +logger = logging.getLogger(__name__) + +settings = Settings() +router = fastapi.APIRouter(tags=["mcp"]) +creds_manager = IntegrationCredentialsManager() + + +# ====================== Tool Discovery ====================== # + + +class DiscoverToolsRequest(BaseModel): + """Request to discover tools on an MCP server.""" + + server_url: str = Field(description="URL of the MCP server") + auth_token: str | None = Field( + default=None, + description="Optional Bearer token for authenticated MCP servers", + ) + + +class MCPToolResponse(BaseModel): + """A single MCP tool returned by discovery.""" + + name: str + description: str + input_schema: dict[str, Any] + + +class DiscoverToolsResponse(BaseModel): + """Response containing the list of tools available on an MCP server.""" + + tools: list[MCPToolResponse] + server_name: str | None = None + protocol_version: str | None = None + + +@router.post( + "/discover-tools", + summary="Discover available tools on an MCP server", + response_model=DiscoverToolsResponse, +) +async def discover_tools( + request: DiscoverToolsRequest, + user_id: Annotated[str, Security(get_user_id)], +) -> DiscoverToolsResponse: + """ + Connect to an MCP server and return its available tools. + + If the user has a stored MCP credential for this server URL, it will be + used automatically — no need to pass an explicit auth token. + """ + auth_token = request.auth_token + + # Auto-use stored MCP credential when no explicit token is provided. + if not auth_token: + mcp_creds = await creds_manager.store.get_creds_by_provider( + user_id, ProviderName.MCP.value + ) + # Find the freshest credential for this server URL + best_cred: OAuth2Credentials | None = None + for cred in mcp_creds: + if ( + isinstance(cred, OAuth2Credentials) + and (cred.metadata or {}).get("mcp_server_url") == request.server_url + ): + if best_cred is None or ( + (cred.access_token_expires_at or 0) + > (best_cred.access_token_expires_at or 0) + ): + best_cred = cred + if best_cred: + # Refresh the token if expired before using it + best_cred = await creds_manager.refresh_if_needed(user_id, best_cred) + logger.info( + f"Using MCP credential {best_cred.id} for {request.server_url}, " + f"expires_at={best_cred.access_token_expires_at}" + ) + auth_token = best_cred.access_token.get_secret_value() + + client = MCPClient(request.server_url, auth_token=auth_token) + + try: + init_result = await client.initialize() + tools = await client.list_tools() + except HTTPClientError as e: + if e.status_code in (401, 403): + raise fastapi.HTTPException( + status_code=401, + detail="This MCP server requires authentication. " + "Please provide a valid auth token.", + ) + raise fastapi.HTTPException(status_code=502, detail=str(e)) + except MCPClientError as e: + raise fastapi.HTTPException(status_code=502, detail=str(e)) + except Exception as e: + raise fastapi.HTTPException( + status_code=502, + detail=f"Failed to connect to MCP server: {e}", + ) + + return DiscoverToolsResponse( + tools=[ + MCPToolResponse( + name=t.name, + description=t.description, + input_schema=t.input_schema, + ) + for t in tools + ], + server_name=( + init_result.get("serverInfo", {}).get("name") + or urlparse(request.server_url).hostname + or "MCP" + ), + protocol_version=init_result.get("protocolVersion"), + ) + + +# ======================== OAuth Flow ======================== # + + +class MCPOAuthLoginRequest(BaseModel): + """Request to start an OAuth flow for an MCP server.""" + + server_url: str = Field(description="URL of the MCP server that requires OAuth") + + +class MCPOAuthLoginResponse(BaseModel): + """Response with the OAuth login URL for the user to authenticate.""" + + login_url: str + state_token: str + + +@router.post( + "/oauth/login", + summary="Initiate OAuth login for an MCP server", +) +async def mcp_oauth_login( + request: MCPOAuthLoginRequest, + user_id: Annotated[str, Security(get_user_id)], +) -> MCPOAuthLoginResponse: + """ + Discover OAuth metadata from the MCP server and return a login URL. + + 1. Discovers the protected-resource metadata (RFC 9728) + 2. Fetches the authorization server metadata (RFC 8414) + 3. Performs Dynamic Client Registration (RFC 7591) if available + 4. Returns the authorization URL for the frontend to open in a popup + """ + client = MCPClient(request.server_url) + + # Step 1: Discover protected-resource metadata (RFC 9728) + protected_resource = await client.discover_auth() + + metadata: dict[str, Any] | None = None + + if protected_resource and protected_resource.get("authorization_servers"): + auth_server_url = protected_resource["authorization_servers"][0] + resource_url = protected_resource.get("resource", request.server_url) + + # Step 2a: Discover auth-server metadata (RFC 8414) + metadata = await client.discover_auth_server_metadata(auth_server_url) + else: + # Fallback: Some MCP servers (e.g. Linear) are their own auth server + # and serve OAuth metadata directly without protected-resource metadata. + # Don't assume a resource_url — omitting it lets the auth server choose + # the correct audience for the token (RFC 8707 resource is optional). + resource_url = None + metadata = await client.discover_auth_server_metadata(request.server_url) + + if ( + not metadata + or "authorization_endpoint" not in metadata + or "token_endpoint" not in metadata + ): + raise fastapi.HTTPException( + status_code=400, + detail="This MCP server does not advertise OAuth support. " + "You may need to provide an auth token manually.", + ) + + authorize_url = metadata["authorization_endpoint"] + token_url = metadata["token_endpoint"] + registration_endpoint = metadata.get("registration_endpoint") + revoke_url = metadata.get("revocation_endpoint") + + # Step 3: Dynamic Client Registration (RFC 7591) if available + frontend_base_url = settings.config.frontend_base_url + if not frontend_base_url: + raise fastapi.HTTPException( + status_code=500, + detail="Frontend base URL is not configured.", + ) + redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback" + + client_id = "" + client_secret = "" + if registration_endpoint: + reg_result = await _register_mcp_client( + registration_endpoint, redirect_uri, request.server_url + ) + if reg_result: + client_id = reg_result.get("client_id", "") + client_secret = reg_result.get("client_secret", "") + + if not client_id: + client_id = "autogpt-platform" + + # Step 4: Store state token with OAuth metadata for the callback + scopes = (protected_resource or {}).get("scopes_supported") or metadata.get( + "scopes_supported", [] + ) + state_token, code_challenge = await creds_manager.store.store_state_token( + user_id, + ProviderName.MCP.value, + scopes, + state_metadata={ + "authorize_url": authorize_url, + "token_url": token_url, + "revoke_url": revoke_url, + "resource_url": resource_url, + "server_url": request.server_url, + "client_id": client_id, + "client_secret": client_secret, + }, + ) + + # Step 5: Build and return the login URL + handler = MCPOAuthHandler( + client_id=client_id, + client_secret=client_secret, + redirect_uri=redirect_uri, + authorize_url=authorize_url, + token_url=token_url, + resource_url=resource_url, + ) + login_url = handler.get_login_url( + scopes, state_token, code_challenge=code_challenge + ) + + return MCPOAuthLoginResponse(login_url=login_url, state_token=state_token) + + +class MCPOAuthCallbackRequest(BaseModel): + """Request to exchange an OAuth code for tokens.""" + + code: str = Field(description="Authorization code from OAuth callback") + state_token: str = Field(description="State token for CSRF verification") + + +class MCPOAuthCallbackResponse(BaseModel): + """Response after successfully storing OAuth credentials.""" + + credential_id: str + + +@router.post( + "/oauth/callback", + summary="Exchange OAuth code for MCP tokens", +) +async def mcp_oauth_callback( + request: MCPOAuthCallbackRequest, + user_id: Annotated[str, Security(get_user_id)], +) -> CredentialsMetaResponse: + """ + Exchange the authorization code for tokens and store the credential. + + The frontend calls this after receiving the OAuth code from the popup. + On success, subsequent ``/discover-tools`` calls for the same server URL + will automatically use the stored credential. + """ + valid_state = await creds_manager.store.verify_state_token( + user_id, request.state_token, ProviderName.MCP.value + ) + if not valid_state: + raise fastapi.HTTPException( + status_code=400, + detail="Invalid or expired state token.", + ) + + meta = valid_state.state_metadata + frontend_base_url = settings.config.frontend_base_url + if not frontend_base_url: + raise fastapi.HTTPException( + status_code=500, + detail="Frontend base URL is not configured.", + ) + redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback" + + handler = MCPOAuthHandler( + client_id=meta["client_id"], + client_secret=meta.get("client_secret", ""), + redirect_uri=redirect_uri, + authorize_url=meta["authorize_url"], + token_url=meta["token_url"], + revoke_url=meta.get("revoke_url"), + resource_url=meta.get("resource_url"), + ) + + try: + credentials = await handler.exchange_code_for_tokens( + request.code, valid_state.scopes, valid_state.code_verifier + ) + except Exception as e: + raise fastapi.HTTPException( + status_code=400, + detail=f"OAuth token exchange failed: {e}", + ) + + # Enrich credential metadata for future lookup and token refresh + if credentials.metadata is None: + credentials.metadata = {} + credentials.metadata["mcp_server_url"] = meta["server_url"] + credentials.metadata["mcp_client_id"] = meta["client_id"] + credentials.metadata["mcp_client_secret"] = meta.get("client_secret", "") + credentials.metadata["mcp_token_url"] = meta["token_url"] + credentials.metadata["mcp_resource_url"] = meta.get("resource_url", "") + + hostname = urlparse(meta["server_url"]).hostname or meta["server_url"] + credentials.title = f"MCP: {hostname}" + + # Remove old MCP credentials for the same server to prevent stale token buildup. + try: + old_creds = await creds_manager.store.get_creds_by_provider( + user_id, ProviderName.MCP.value + ) + for old in old_creds: + if ( + isinstance(old, OAuth2Credentials) + and (old.metadata or {}).get("mcp_server_url") == meta["server_url"] + ): + await creds_manager.store.delete_creds_by_id(user_id, old.id) + logger.info( + f"Removed old MCP credential {old.id} for {meta['server_url']}" + ) + except Exception: + logger.debug("Could not clean up old MCP credentials", exc_info=True) + + await creds_manager.create(user_id, credentials) + + return CredentialsMetaResponse( + id=credentials.id, + provider=credentials.provider, + type=credentials.type, + title=credentials.title, + scopes=credentials.scopes, + username=credentials.username, + host=credentials.metadata.get("mcp_server_url"), + ) + + +# ======================== Helpers ======================== # + + +async def _register_mcp_client( + registration_endpoint: str, + redirect_uri: str, + server_url: str, +) -> dict[str, Any] | None: + """Attempt Dynamic Client Registration (RFC 7591) with an MCP auth server.""" + try: + response = await Requests(raise_for_status=True).post( + registration_endpoint, + json={ + "client_name": "AutoGPT Platform", + "redirect_uris": [redirect_uri], + "grant_types": ["authorization_code"], + "response_types": ["code"], + "token_endpoint_auth_method": "client_secret_post", + }, + ) + data = response.json() + if isinstance(data, dict) and "client_id" in data: + return data + return None + except Exception as e: + logger.warning(f"Dynamic client registration failed for {server_url}: {e}") + return None diff --git a/autogpt_platform/backend/backend/api/features/mcp/test_routes.py b/autogpt_platform/backend/backend/api/features/mcp/test_routes.py new file mode 100644 index 0000000000..e86b9f4865 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/mcp/test_routes.py @@ -0,0 +1,436 @@ +"""Tests for MCP API routes. + +Uses httpx.AsyncClient with ASGITransport instead of fastapi.testclient.TestClient +to avoid creating blocking portals that can corrupt pytest-asyncio's session event loop. +""" + +from unittest.mock import AsyncMock, patch + +import fastapi +import httpx +import pytest +import pytest_asyncio +from autogpt_libs.auth import get_user_id + +from backend.api.features.mcp.routes import router +from backend.blocks.mcp.client import MCPClientError, MCPTool +from backend.util.request import HTTPClientError + +app = fastapi.FastAPI() +app.include_router(router) +app.dependency_overrides[get_user_id] = lambda: "test-user-id" + + +@pytest_asyncio.fixture(scope="module") +async def client(): + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as c: + yield c + + +class TestDiscoverTools: + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_tools_success(self, client): + mock_tools = [ + MCPTool( + name="get_weather", + description="Get weather for a city", + input_schema={ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + ), + MCPTool( + name="add_numbers", + description="Add two numbers", + input_schema={ + "type": "object", + "properties": { + "a": {"type": "number"}, + "b": {"type": "number"}, + }, + }, + ), + ] + + with ( + patch("backend.api.features.mcp.routes.MCPClient") as MockClient, + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + ): + mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[]) + instance = MockClient.return_value + instance.initialize = AsyncMock( + return_value={ + "protocolVersion": "2025-03-26", + "serverInfo": {"name": "test-server"}, + } + ) + instance.list_tools = AsyncMock(return_value=mock_tools) + + response = await client.post( + "/discover-tools", + json={"server_url": "https://mcp.example.com/mcp"}, + ) + + assert response.status_code == 200 + data = response.json() + assert len(data["tools"]) == 2 + assert data["tools"][0]["name"] == "get_weather" + assert data["tools"][1]["name"] == "add_numbers" + assert data["server_name"] == "test-server" + assert data["protocol_version"] == "2025-03-26" + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_tools_with_auth_token(self, client): + with patch("backend.api.features.mcp.routes.MCPClient") as MockClient: + instance = MockClient.return_value + instance.initialize = AsyncMock( + return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"} + ) + instance.list_tools = AsyncMock(return_value=[]) + + response = await client.post( + "/discover-tools", + json={ + "server_url": "https://mcp.example.com/mcp", + "auth_token": "my-secret-token", + }, + ) + + assert response.status_code == 200 + MockClient.assert_called_once_with( + "https://mcp.example.com/mcp", + auth_token="my-secret-token", + ) + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_tools_auto_uses_stored_credential(self, client): + """When no explicit token is given, stored MCP credentials are used.""" + from pydantic import SecretStr + + from backend.data.model import OAuth2Credentials + + stored_cred = OAuth2Credentials( + provider="mcp", + title="MCP: example.com", + access_token=SecretStr("stored-token-123"), + refresh_token=None, + access_token_expires_at=None, + refresh_token_expires_at=None, + scopes=[], + metadata={"mcp_server_url": "https://mcp.example.com/mcp"}, + ) + + with ( + patch("backend.api.features.mcp.routes.MCPClient") as MockClient, + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + ): + mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[stored_cred]) + mock_cm.refresh_if_needed = AsyncMock(return_value=stored_cred) + instance = MockClient.return_value + instance.initialize = AsyncMock( + return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"} + ) + instance.list_tools = AsyncMock(return_value=[]) + + response = await client.post( + "/discover-tools", + json={"server_url": "https://mcp.example.com/mcp"}, + ) + + assert response.status_code == 200 + MockClient.assert_called_once_with( + "https://mcp.example.com/mcp", + auth_token="stored-token-123", + ) + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_tools_mcp_error(self, client): + with ( + patch("backend.api.features.mcp.routes.MCPClient") as MockClient, + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + ): + mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[]) + instance = MockClient.return_value + instance.initialize = AsyncMock( + side_effect=MCPClientError("Connection refused") + ) + + response = await client.post( + "/discover-tools", + json={"server_url": "https://bad-server.example.com/mcp"}, + ) + + assert response.status_code == 502 + assert "Connection refused" in response.json()["detail"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_tools_generic_error(self, client): + with ( + patch("backend.api.features.mcp.routes.MCPClient") as MockClient, + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + ): + mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[]) + instance = MockClient.return_value + instance.initialize = AsyncMock(side_effect=Exception("Network timeout")) + + response = await client.post( + "/discover-tools", + json={"server_url": "https://timeout.example.com/mcp"}, + ) + + assert response.status_code == 502 + assert "Failed to connect" in response.json()["detail"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_tools_auth_required(self, client): + with ( + patch("backend.api.features.mcp.routes.MCPClient") as MockClient, + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + ): + mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[]) + instance = MockClient.return_value + instance.initialize = AsyncMock( + side_effect=HTTPClientError("HTTP 401 Error: Unauthorized", 401) + ) + + response = await client.post( + "/discover-tools", + json={"server_url": "https://auth-server.example.com/mcp"}, + ) + + assert response.status_code == 401 + assert "requires authentication" in response.json()["detail"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_tools_forbidden(self, client): + with ( + patch("backend.api.features.mcp.routes.MCPClient") as MockClient, + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + ): + mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[]) + instance = MockClient.return_value + instance.initialize = AsyncMock( + side_effect=HTTPClientError("HTTP 403 Error: Forbidden", 403) + ) + + response = await client.post( + "/discover-tools", + json={"server_url": "https://auth-server.example.com/mcp"}, + ) + + assert response.status_code == 401 + assert "requires authentication" in response.json()["detail"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_tools_missing_url(self, client): + response = await client.post("/discover-tools", json={}) + assert response.status_code == 422 + + +class TestOAuthLogin: + @pytest.mark.asyncio(loop_scope="session") + async def test_oauth_login_success(self, client): + with ( + patch("backend.api.features.mcp.routes.MCPClient") as MockClient, + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + patch("backend.api.features.mcp.routes.settings") as mock_settings, + patch( + "backend.api.features.mcp.routes._register_mcp_client" + ) as mock_register, + ): + instance = MockClient.return_value + instance.discover_auth = AsyncMock( + return_value={ + "authorization_servers": ["https://auth.sentry.io"], + "resource": "https://mcp.sentry.dev/mcp", + "scopes_supported": ["openid"], + } + ) + instance.discover_auth_server_metadata = AsyncMock( + return_value={ + "authorization_endpoint": "https://auth.sentry.io/authorize", + "token_endpoint": "https://auth.sentry.io/token", + "registration_endpoint": "https://auth.sentry.io/register", + } + ) + mock_register.return_value = { + "client_id": "registered-client-id", + "client_secret": "registered-secret", + } + mock_cm.store.store_state_token = AsyncMock( + return_value=("state-token-123", "code-challenge-abc") + ) + mock_settings.config.frontend_base_url = "http://localhost:3000" + + response = await client.post( + "/oauth/login", + json={"server_url": "https://mcp.sentry.dev/mcp"}, + ) + + assert response.status_code == 200 + data = response.json() + assert "login_url" in data + assert data["state_token"] == "state-token-123" + assert "auth.sentry.io/authorize" in data["login_url"] + assert "registered-client-id" in data["login_url"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_oauth_login_no_oauth_support(self, client): + with patch("backend.api.features.mcp.routes.MCPClient") as MockClient: + instance = MockClient.return_value + instance.discover_auth = AsyncMock(return_value=None) + instance.discover_auth_server_metadata = AsyncMock(return_value=None) + + response = await client.post( + "/oauth/login", + json={"server_url": "https://simple-server.example.com/mcp"}, + ) + + assert response.status_code == 400 + assert "does not advertise OAuth" in response.json()["detail"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_oauth_login_fallback_to_public_client(self, client): + """When DCR is unavailable, falls back to default public client ID.""" + with ( + patch("backend.api.features.mcp.routes.MCPClient") as MockClient, + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + patch("backend.api.features.mcp.routes.settings") as mock_settings, + ): + instance = MockClient.return_value + instance.discover_auth = AsyncMock( + return_value={ + "authorization_servers": ["https://auth.example.com"], + "resource": "https://mcp.example.com/mcp", + } + ) + instance.discover_auth_server_metadata = AsyncMock( + return_value={ + "authorization_endpoint": "https://auth.example.com/authorize", + "token_endpoint": "https://auth.example.com/token", + # No registration_endpoint + } + ) + mock_cm.store.store_state_token = AsyncMock( + return_value=("state-abc", "challenge-xyz") + ) + mock_settings.config.frontend_base_url = "http://localhost:3000" + + response = await client.post( + "/oauth/login", + json={"server_url": "https://mcp.example.com/mcp"}, + ) + + assert response.status_code == 200 + data = response.json() + assert "autogpt-platform" in data["login_url"] + + +class TestOAuthCallback: + @pytest.mark.asyncio(loop_scope="session") + async def test_oauth_callback_success(self, client): + from pydantic import SecretStr + + from backend.data.model import OAuth2Credentials + + mock_creds = OAuth2Credentials( + provider="mcp", + title=None, + access_token=SecretStr("access-token-xyz"), + refresh_token=None, + access_token_expires_at=None, + refresh_token_expires_at=None, + scopes=[], + metadata={ + "mcp_token_url": "https://auth.sentry.io/token", + "mcp_resource_url": "https://mcp.sentry.dev/mcp", + }, + ) + + with ( + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + patch("backend.api.features.mcp.routes.settings") as mock_settings, + patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler, + ): + mock_settings.config.frontend_base_url = "http://localhost:3000" + + # Mock state verification + mock_state = AsyncMock() + mock_state.state_metadata = { + "authorize_url": "https://auth.sentry.io/authorize", + "token_url": "https://auth.sentry.io/token", + "client_id": "test-client-id", + "client_secret": "test-secret", + "server_url": "https://mcp.sentry.dev/mcp", + } + mock_state.scopes = ["openid"] + mock_state.code_verifier = "verifier-123" + mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state) + mock_cm.create = AsyncMock() + + handler_instance = MockHandler.return_value + handler_instance.exchange_code_for_tokens = AsyncMock( + return_value=mock_creds + ) + + # Mock old credential cleanup + mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[]) + + response = await client.post( + "/oauth/callback", + json={"code": "auth-code-abc", "state_token": "state-token-123"}, + ) + + assert response.status_code == 200 + data = response.json() + assert "id" in data + assert data["provider"] == "mcp" + assert data["type"] == "oauth2" + mock_cm.create.assert_called_once() + + @pytest.mark.asyncio(loop_scope="session") + async def test_oauth_callback_invalid_state(self, client): + with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm: + mock_cm.store.verify_state_token = AsyncMock(return_value=None) + + response = await client.post( + "/oauth/callback", + json={"code": "auth-code", "state_token": "bad-state"}, + ) + + assert response.status_code == 400 + assert "Invalid or expired" in response.json()["detail"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_oauth_callback_token_exchange_fails(self, client): + with ( + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + patch("backend.api.features.mcp.routes.settings") as mock_settings, + patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler, + ): + mock_settings.config.frontend_base_url = "http://localhost:3000" + mock_state = AsyncMock() + mock_state.state_metadata = { + "authorize_url": "https://auth.example.com/authorize", + "token_url": "https://auth.example.com/token", + "client_id": "cid", + "server_url": "https://mcp.example.com/mcp", + } + mock_state.scopes = [] + mock_state.code_verifier = "v" + mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state) + + handler_instance = MockHandler.return_value + handler_instance.exchange_code_for_tokens = AsyncMock( + side_effect=RuntimeError("Token exchange failed") + ) + + response = await client.post( + "/oauth/callback", + json={"code": "bad-code", "state_token": "state"}, + ) + + assert response.status_code == 400 + assert "token exchange failed" in response.json()["detail"].lower() diff --git a/autogpt_platform/backend/backend/api/rest_api.py b/autogpt_platform/backend/backend/api/rest_api.py index 0eef76193e..aed348755b 100644 --- a/autogpt_platform/backend/backend/api/rest_api.py +++ b/autogpt_platform/backend/backend/api/rest_api.py @@ -26,6 +26,7 @@ import backend.api.features.executions.review.routes import backend.api.features.library.db import backend.api.features.library.model import backend.api.features.library.routes +import backend.api.features.mcp.routes as mcp_routes import backend.api.features.oauth import backend.api.features.otto.routes import backend.api.features.postmark.postmark @@ -343,6 +344,11 @@ app.include_router( tags=["workspace"], prefix="/api/workspace", ) +app.include_router( + mcp_routes.router, + tags=["v2", "mcp"], + prefix="/api/mcp", +) app.include_router( backend.api.features.oauth.router, tags=["oauth"], diff --git a/autogpt_platform/backend/backend/blocks/_base.py b/autogpt_platform/backend/backend/blocks/_base.py index 0ba4daec40..632c5e43b9 100644 --- a/autogpt_platform/backend/backend/blocks/_base.py +++ b/autogpt_platform/backend/backend/blocks/_base.py @@ -64,6 +64,7 @@ class BlockType(Enum): AI = "AI" AYRSHARE = "Ayrshare" HUMAN_IN_THE_LOOP = "Human In The Loop" + MCP_TOOL = "MCP Tool" class BlockCategory(Enum): diff --git a/autogpt_platform/backend/backend/blocks/mcp/__init__.py b/autogpt_platform/backend/backend/blocks/mcp/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/autogpt_platform/backend/backend/blocks/mcp/block.py b/autogpt_platform/backend/backend/blocks/mcp/block.py new file mode 100644 index 0000000000..9e3056d928 --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/block.py @@ -0,0 +1,300 @@ +""" +MCP (Model Context Protocol) Tool Block. + +A single dynamic block that can connect to any MCP server, discover available tools, +and execute them. Works like AgentExecutorBlock — the user selects a tool from a +dropdown and the input/output schema adapts dynamically. +""" + +import json +import logging +from typing import Any, Literal + +from pydantic import SecretStr + +from backend.blocks._base import ( + Block, + BlockCategory, + BlockSchemaInput, + BlockSchemaOutput, + BlockType, +) +from backend.blocks.mcp.client import MCPClient, MCPClientError +from backend.data.block import BlockInput, BlockOutput +from backend.data.model import ( + CredentialsField, + CredentialsMetaInput, + OAuth2Credentials, + SchemaField, +) +from backend.integrations.providers import ProviderName +from backend.util.json import validate_with_jsonschema + +logger = logging.getLogger(__name__) + +TEST_CREDENTIALS = OAuth2Credentials( + id="test-mcp-cred", + provider="mcp", + access_token=SecretStr("mock-mcp-token"), + refresh_token=SecretStr("mock-refresh"), + scopes=[], + title="Mock MCP credential", +) +TEST_CREDENTIALS_INPUT = { + "provider": TEST_CREDENTIALS.provider, + "id": TEST_CREDENTIALS.id, + "type": TEST_CREDENTIALS.type, + "title": TEST_CREDENTIALS.title, +} + + +MCPCredentials = CredentialsMetaInput[Literal[ProviderName.MCP], Literal["oauth2"]] + + +class MCPToolBlock(Block): + """ + A block that connects to an MCP server, lets the user pick a tool, + and executes it with dynamic input/output schema. + + The flow: + 1. User provides an MCP server URL (and optional credentials) + 2. Frontend calls the backend to get tool list from that URL + 3. User selects a tool from a dropdown (available_tools) + 4. The block's input schema updates to reflect the selected tool's parameters + 5. On execution, the block calls the MCP server to run the tool + """ + + class Input(BlockSchemaInput): + server_url: str = SchemaField( + description="URL of the MCP server (Streamable HTTP endpoint)", + placeholder="https://mcp.example.com/mcp", + ) + credentials: MCPCredentials = CredentialsField( + discriminator="server_url", + description="MCP server OAuth credentials", + default={}, + ) + selected_tool: str = SchemaField( + description="The MCP tool to execute", + placeholder="Select a tool", + default="", + ) + tool_input_schema: dict[str, Any] = SchemaField( + description="JSON Schema for the selected tool's input parameters. " + "Populated automatically when a tool is selected.", + default={}, + hidden=True, + ) + + tool_arguments: dict[str, Any] = SchemaField( + description="Arguments to pass to the selected MCP tool. " + "The fields here are defined by the tool's input schema.", + default={}, + ) + + @classmethod + def get_input_schema(cls, data: BlockInput) -> dict[str, Any]: + """Return the tool's input schema so the builder UI renders dynamic fields.""" + return data.get("tool_input_schema", {}) + + @classmethod + def get_input_defaults(cls, data: BlockInput) -> BlockInput: + """Return the current tool_arguments as defaults for the dynamic fields.""" + return data.get("tool_arguments", {}) + + @classmethod + def get_missing_input(cls, data: BlockInput) -> set[str]: + """Check which required tool arguments are missing.""" + required_fields = cls.get_input_schema(data).get("required", []) + tool_arguments = data.get("tool_arguments", {}) + return set(required_fields) - set(tool_arguments) + + @classmethod + def get_mismatch_error(cls, data: BlockInput) -> str | None: + """Validate tool_arguments against the tool's input schema.""" + tool_schema = cls.get_input_schema(data) + if not tool_schema: + return None + tool_arguments = data.get("tool_arguments", {}) + return validate_with_jsonschema(tool_schema, tool_arguments) + + class Output(BlockSchemaOutput): + result: Any = SchemaField(description="The result returned by the MCP tool") + error: str = SchemaField(description="Error message if the tool call failed") + + def __init__(self): + super().__init__( + id="a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4", + description="Connect to any MCP server and execute its tools. " + "Provide a server URL, select a tool, and pass arguments dynamically.", + categories={BlockCategory.DEVELOPER_TOOLS}, + input_schema=MCPToolBlock.Input, + output_schema=MCPToolBlock.Output, + block_type=BlockType.MCP_TOOL, + test_credentials=TEST_CREDENTIALS, + test_input={ + "server_url": "https://mcp.example.com/mcp", + "credentials": TEST_CREDENTIALS_INPUT, + "selected_tool": "get_weather", + "tool_input_schema": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + "tool_arguments": {"city": "London"}, + }, + test_output=[ + ( + "result", + {"weather": "sunny", "temperature": 20}, + ), + ], + test_mock={ + "_call_mcp_tool": lambda *a, **kw: { + "weather": "sunny", + "temperature": 20, + }, + }, + ) + + async def _call_mcp_tool( + self, + server_url: str, + tool_name: str, + arguments: dict[str, Any], + auth_token: str | None = None, + ) -> Any: + """Call a tool on the MCP server. Extracted for easy mocking in tests.""" + client = MCPClient(server_url, auth_token=auth_token) + await client.initialize() + result = await client.call_tool(tool_name, arguments) + + if result.is_error: + error_text = "" + for item in result.content: + if item.get("type") == "text": + error_text += item.get("text", "") + raise MCPClientError( + f"MCP tool '{tool_name}' returned an error: " + f"{error_text or 'Unknown error'}" + ) + + # Extract text content from the result + output_parts = [] + for item in result.content: + if item.get("type") == "text": + text = item.get("text", "") + # Try to parse as JSON for structured output + try: + output_parts.append(json.loads(text)) + except (json.JSONDecodeError, ValueError): + output_parts.append(text) + elif item.get("type") == "image": + output_parts.append( + { + "type": "image", + "data": item.get("data"), + "mimeType": item.get("mimeType"), + } + ) + elif item.get("type") == "resource": + output_parts.append(item.get("resource", {})) + + # If single result, unwrap + if len(output_parts) == 1: + return output_parts[0] + return output_parts if output_parts else None + + @staticmethod + async def _auto_lookup_credential( + user_id: str, server_url: str + ) -> "OAuth2Credentials | None": + """Auto-lookup stored MCP credential for a server URL. + + This is a fallback for nodes that don't have ``credentials`` explicitly + set (e.g. nodes created before the credential field was wired up). + """ + from backend.integrations.creds_manager import IntegrationCredentialsManager + from backend.integrations.providers import ProviderName + + try: + mgr = IntegrationCredentialsManager() + mcp_creds = await mgr.store.get_creds_by_provider( + user_id, ProviderName.MCP.value + ) + best: OAuth2Credentials | None = None + for cred in mcp_creds: + if ( + isinstance(cred, OAuth2Credentials) + and (cred.metadata or {}).get("mcp_server_url") == server_url + ): + if best is None or ( + (cred.access_token_expires_at or 0) + > (best.access_token_expires_at or 0) + ): + best = cred + if best: + best = await mgr.refresh_if_needed(user_id, best) + logger.info( + "Auto-resolved MCP credential %s for %s", best.id, server_url + ) + return best + except Exception: + logger.warning("Auto-lookup MCP credential failed", exc_info=True) + return None + + async def run( + self, + input_data: Input, + *, + user_id: str, + credentials: OAuth2Credentials | None = None, + **kwargs, + ) -> BlockOutput: + if not input_data.server_url: + yield "error", "MCP server URL is required" + return + + if not input_data.selected_tool: + yield "error", "No tool selected. Please select a tool from the dropdown." + return + + # Validate required tool arguments before calling the server. + # The executor-level validation is bypassed for MCP blocks because + # get_input_defaults() flattens tool_arguments, stripping tool_input_schema + # from the validation context. + required = set(input_data.tool_input_schema.get("required", [])) + if required: + missing = required - set(input_data.tool_arguments.keys()) + if missing: + yield "error", ( + f"Missing required argument(s): {', '.join(sorted(missing))}. " + f"Please fill in all required fields marked with * in the block form." + ) + return + + # If no credentials were injected by the executor (e.g. legacy nodes + # that don't have the credentials field set), try to auto-lookup + # the stored MCP credential for this server URL. + if credentials is None: + credentials = await self._auto_lookup_credential( + user_id, input_data.server_url + ) + + auth_token = ( + credentials.access_token.get_secret_value() if credentials else None + ) + + try: + result = await self._call_mcp_tool( + server_url=input_data.server_url, + tool_name=input_data.selected_tool, + arguments=input_data.tool_arguments, + auth_token=auth_token, + ) + yield "result", result + except MCPClientError as e: + yield "error", str(e) + except Exception as e: + logger.exception(f"MCP tool call failed: {e}") + yield "error", f"MCP tool call failed: {str(e)}" diff --git a/autogpt_platform/backend/backend/blocks/mcp/client.py b/autogpt_platform/backend/backend/blocks/mcp/client.py new file mode 100644 index 0000000000..050349dbcc --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/client.py @@ -0,0 +1,323 @@ +""" +MCP (Model Context Protocol) HTTP client. + +Implements the MCP Streamable HTTP transport for listing tools and calling tools +on remote MCP servers. Uses JSON-RPC 2.0 over HTTP POST. + +Handles both JSON and SSE (text/event-stream) response formats per the MCP spec. + +Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports +""" + +import json +import logging +from dataclasses import dataclass, field +from typing import Any + +from backend.util.request import Requests + +logger = logging.getLogger(__name__) + + +@dataclass +class MCPTool: + """Represents an MCP tool discovered from a server.""" + + name: str + description: str + input_schema: dict[str, Any] + + +@dataclass +class MCPCallResult: + """Result from calling an MCP tool.""" + + content: list[dict[str, Any]] = field(default_factory=list) + is_error: bool = False + + +class MCPClientError(Exception): + """Raised when an MCP protocol error occurs.""" + + pass + + +class MCPClient: + """ + Async HTTP client for the MCP Streamable HTTP transport. + + Communicates with MCP servers using JSON-RPC 2.0 over HTTP POST. + Supports optional Bearer token authentication. + """ + + def __init__( + self, + server_url: str, + auth_token: str | None = None, + ): + self.server_url = server_url.rstrip("/") + self.auth_token = auth_token + self._request_id = 0 + self._session_id: str | None = None + + def _next_id(self) -> int: + self._request_id += 1 + return self._request_id + + def _build_headers(self) -> dict[str, str]: + headers = { + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + } + if self.auth_token: + headers["Authorization"] = f"Bearer {self.auth_token}" + if self._session_id: + headers["Mcp-Session-Id"] = self._session_id + return headers + + def _build_jsonrpc_request( + self, method: str, params: dict[str, Any] | None = None + ) -> dict[str, Any]: + req: dict[str, Any] = { + "jsonrpc": "2.0", + "method": method, + "id": self._next_id(), + } + if params is not None: + req["params"] = params + return req + + @staticmethod + def _parse_sse_response(text: str) -> dict[str, Any]: + """Parse an SSE (text/event-stream) response body into JSON-RPC data. + + MCP servers may return responses as SSE with format: + event: message + data: {"jsonrpc":"2.0","result":{...},"id":1} + + We extract the last `data:` line that contains a JSON-RPC response + (i.e. has an "id" field), which is the reply to our request. + """ + last_data: dict[str, Any] | None = None + for line in text.splitlines(): + stripped = line.strip() + if stripped.startswith("data:"): + payload = stripped[len("data:") :].strip() + if not payload: + continue + try: + parsed = json.loads(payload) + # Only keep JSON-RPC responses (have "id"), skip notifications + if isinstance(parsed, dict) and "id" in parsed: + last_data = parsed + except (json.JSONDecodeError, ValueError): + continue + if last_data is None: + raise MCPClientError("No JSON-RPC response found in SSE stream") + return last_data + + async def _send_request( + self, method: str, params: dict[str, Any] | None = None + ) -> Any: + """Send a JSON-RPC request to the MCP server and return the result. + + Handles both ``application/json`` and ``text/event-stream`` responses + as required by the MCP Streamable HTTP transport specification. + """ + payload = self._build_jsonrpc_request(method, params) + headers = self._build_headers() + + requests = Requests( + raise_for_status=True, + extra_headers=headers, + ) + response = await requests.post(self.server_url, json=payload) + + # Capture session ID from response (MCP Streamable HTTP transport) + session_id = response.headers.get("Mcp-Session-Id") + if session_id: + self._session_id = session_id + + content_type = response.headers.get("content-type", "") + if "text/event-stream" in content_type: + body = self._parse_sse_response(response.text()) + else: + try: + body = response.json() + except Exception as e: + raise MCPClientError( + f"MCP server returned non-JSON response: {e}" + ) from e + + if not isinstance(body, dict): + raise MCPClientError( + f"MCP server returned unexpected JSON type: {type(body).__name__}" + ) + + # Handle JSON-RPC error + if "error" in body: + error = body["error"] + if isinstance(error, dict): + raise MCPClientError( + f"MCP server error [{error.get('code', '?')}]: " + f"{error.get('message', 'Unknown error')}" + ) + raise MCPClientError(f"MCP server error: {error}") + + return body.get("result") + + async def _send_notification(self, method: str) -> None: + """Send a JSON-RPC notification (no id, no response expected).""" + headers = self._build_headers() + notification = {"jsonrpc": "2.0", "method": method} + requests = Requests( + raise_for_status=False, + extra_headers=headers, + ) + await requests.post(self.server_url, json=notification) + + async def discover_auth(self) -> dict[str, Any] | None: + """Probe the MCP server's OAuth metadata (RFC 9728 / MCP spec). + + Returns ``None`` if the server doesn't require auth, otherwise returns + a dict with: + - ``authorization_servers``: list of authorization server URLs + - ``resource``: the resource indicator URL (usually the MCP endpoint) + - ``scopes_supported``: optional list of supported scopes + + The caller can then fetch the authorization server metadata to get + ``authorization_endpoint``, ``token_endpoint``, etc. + """ + from urllib.parse import urlparse + + parsed = urlparse(self.server_url) + base = f"{parsed.scheme}://{parsed.netloc}" + + # Build candidates for protected-resource metadata (per RFC 9728) + path = parsed.path.rstrip("/") + candidates = [] + if path and path != "/": + candidates.append(f"{base}/.well-known/oauth-protected-resource{path}") + candidates.append(f"{base}/.well-known/oauth-protected-resource") + + requests = Requests( + raise_for_status=False, + ) + for url in candidates: + try: + resp = await requests.get(url) + if resp.status == 200: + data = resp.json() + if isinstance(data, dict) and "authorization_servers" in data: + return data + except Exception: + continue + + return None + + async def discover_auth_server_metadata( + self, auth_server_url: str + ) -> dict[str, Any] | None: + """Fetch the OAuth Authorization Server Metadata (RFC 8414). + + Given an authorization server URL, returns a dict with: + - ``authorization_endpoint`` + - ``token_endpoint`` + - ``registration_endpoint`` (for dynamic client registration) + - ``scopes_supported`` + - ``code_challenge_methods_supported`` + - etc. + """ + from urllib.parse import urlparse + + parsed = urlparse(auth_server_url) + base = f"{parsed.scheme}://{parsed.netloc}" + path = parsed.path.rstrip("/") + + # Try standard metadata endpoints (RFC 8414 and OpenID Connect) + candidates = [] + if path and path != "/": + candidates.append(f"{base}/.well-known/oauth-authorization-server{path}") + candidates.append(f"{base}/.well-known/oauth-authorization-server") + candidates.append(f"{base}/.well-known/openid-configuration") + + requests = Requests( + raise_for_status=False, + ) + for url in candidates: + try: + resp = await requests.get(url) + if resp.status == 200: + data = resp.json() + if isinstance(data, dict) and "authorization_endpoint" in data: + return data + except Exception: + continue + + return None + + async def initialize(self) -> dict[str, Any]: + """ + Send the MCP initialize request. + + This is required by the MCP protocol before any other requests. + Returns the server's capabilities. + """ + result = await self._send_request( + "initialize", + { + "protocolVersion": "2025-03-26", + "capabilities": {}, + "clientInfo": {"name": "AutoGPT-Platform", "version": "1.0.0"}, + }, + ) + # Send initialized notification (no response expected) + await self._send_notification("notifications/initialized") + + return result or {} + + async def list_tools(self) -> list[MCPTool]: + """ + Discover available tools from the MCP server. + + Returns a list of MCPTool objects with name, description, and input schema. + """ + result = await self._send_request("tools/list") + if not result or "tools" not in result: + return [] + + tools = [] + for tool_data in result["tools"]: + tools.append( + MCPTool( + name=tool_data.get("name", ""), + description=tool_data.get("description", ""), + input_schema=tool_data.get("inputSchema", {}), + ) + ) + return tools + + async def call_tool( + self, tool_name: str, arguments: dict[str, Any] + ) -> MCPCallResult: + """ + Call a tool on the MCP server. + + Args: + tool_name: The name of the tool to call. + arguments: The arguments to pass to the tool. + + Returns: + MCPCallResult with the tool's response content. + """ + result = await self._send_request( + "tools/call", + {"name": tool_name, "arguments": arguments}, + ) + if not result: + return MCPCallResult(is_error=True) + + return MCPCallResult( + content=result.get("content", []), + is_error=result.get("isError", False), + ) diff --git a/autogpt_platform/backend/backend/blocks/mcp/oauth.py b/autogpt_platform/backend/backend/blocks/mcp/oauth.py new file mode 100644 index 0000000000..2228336cd3 --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/oauth.py @@ -0,0 +1,204 @@ +""" +MCP OAuth handler for MCP servers that use OAuth 2.1 authorization. + +Unlike other OAuth handlers (GitHub, Google, etc.) where endpoints are fixed, +MCP servers have dynamic endpoints discovered via RFC 9728 / RFC 8414 metadata. +This handler accepts those endpoints at construction time. +""" + +import logging +import time +import urllib.parse +from typing import ClassVar, Optional + +from pydantic import SecretStr + +from backend.data.model import OAuth2Credentials +from backend.integrations.oauth.base import BaseOAuthHandler +from backend.integrations.providers import ProviderName +from backend.util.request import Requests + +logger = logging.getLogger(__name__) + + +class MCPOAuthHandler(BaseOAuthHandler): + """ + OAuth handler for MCP servers with dynamically-discovered endpoints. + + Construction requires the authorization and token endpoint URLs, + which are obtained via MCP OAuth metadata discovery + (``MCPClient.discover_auth`` + ``discover_auth_server_metadata``). + """ + + PROVIDER_NAME: ClassVar[ProviderName | str] = ProviderName.MCP + DEFAULT_SCOPES: ClassVar[list[str]] = [] + + def __init__( + self, + client_id: str, + client_secret: str, + redirect_uri: str, + *, + authorize_url: str, + token_url: str, + revoke_url: str | None = None, + resource_url: str | None = None, + ): + self.client_id = client_id + self.client_secret = client_secret + self.redirect_uri = redirect_uri + self.authorize_url = authorize_url + self.token_url = token_url + self.revoke_url = revoke_url + self.resource_url = resource_url + + def get_login_url( + self, + scopes: list[str], + state: str, + code_challenge: Optional[str], + ) -> str: + scopes = self.handle_default_scopes(scopes) + + params: dict[str, str] = { + "response_type": "code", + "client_id": self.client_id, + "redirect_uri": self.redirect_uri, + "state": state, + } + if scopes: + params["scope"] = " ".join(scopes) + # PKCE (S256) — included when the caller provides a code_challenge + if code_challenge: + params["code_challenge"] = code_challenge + params["code_challenge_method"] = "S256" + # MCP spec requires resource indicator (RFC 8707) + if self.resource_url: + params["resource"] = self.resource_url + + return f"{self.authorize_url}?{urllib.parse.urlencode(params)}" + + async def exchange_code_for_tokens( + self, + code: str, + scopes: list[str], + code_verifier: Optional[str], + ) -> OAuth2Credentials: + data: dict[str, str] = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": self.redirect_uri, + "client_id": self.client_id, + } + if self.client_secret: + data["client_secret"] = self.client_secret + if code_verifier: + data["code_verifier"] = code_verifier + if self.resource_url: + data["resource"] = self.resource_url + + response = await Requests(raise_for_status=True).post( + self.token_url, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + tokens = response.json() + + if "error" in tokens: + raise RuntimeError( + f"Token exchange failed: {tokens.get('error_description', tokens['error'])}" + ) + + if "access_token" not in tokens: + raise RuntimeError("OAuth token response missing 'access_token' field") + + now = int(time.time()) + expires_in = tokens.get("expires_in") + + return OAuth2Credentials( + provider=self.PROVIDER_NAME, + title=None, + access_token=SecretStr(tokens["access_token"]), + refresh_token=( + SecretStr(tokens["refresh_token"]) + if tokens.get("refresh_token") + else None + ), + access_token_expires_at=now + expires_in if expires_in else None, + refresh_token_expires_at=None, + scopes=scopes, + metadata={ + "mcp_token_url": self.token_url, + "mcp_resource_url": self.resource_url, + }, + ) + + async def _refresh_tokens( + self, credentials: OAuth2Credentials + ) -> OAuth2Credentials: + if not credentials.refresh_token: + raise ValueError("No refresh token available for MCP OAuth credentials") + + data: dict[str, str] = { + "grant_type": "refresh_token", + "refresh_token": credentials.refresh_token.get_secret_value(), + "client_id": self.client_id, + } + if self.client_secret: + data["client_secret"] = self.client_secret + if self.resource_url: + data["resource"] = self.resource_url + + response = await Requests(raise_for_status=True).post( + self.token_url, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + tokens = response.json() + + if "error" in tokens: + raise RuntimeError( + f"Token refresh failed: {tokens.get('error_description', tokens['error'])}" + ) + + if "access_token" not in tokens: + raise RuntimeError("OAuth refresh response missing 'access_token' field") + + now = int(time.time()) + expires_in = tokens.get("expires_in") + + return OAuth2Credentials( + id=credentials.id, + provider=self.PROVIDER_NAME, + title=credentials.title, + access_token=SecretStr(tokens["access_token"]), + refresh_token=( + SecretStr(tokens["refresh_token"]) + if tokens.get("refresh_token") + else credentials.refresh_token + ), + access_token_expires_at=now + expires_in if expires_in else None, + refresh_token_expires_at=credentials.refresh_token_expires_at, + scopes=credentials.scopes, + metadata=credentials.metadata, + ) + + async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool: + if not self.revoke_url: + return False + + try: + data = { + "token": credentials.access_token.get_secret_value(), + "token_type_hint": "access_token", + "client_id": self.client_id, + } + await Requests().post( + self.revoke_url, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + return True + except Exception: + logger.warning("Failed to revoke MCP OAuth tokens", exc_info=True) + return False diff --git a/autogpt_platform/backend/backend/blocks/mcp/test_e2e.py b/autogpt_platform/backend/backend/blocks/mcp/test_e2e.py new file mode 100644 index 0000000000..7818fac9ce --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/test_e2e.py @@ -0,0 +1,109 @@ +""" +End-to-end tests against a real public MCP server. + +These tests hit the OpenAI docs MCP server (https://developers.openai.com/mcp) +which is publicly accessible without authentication and returns SSE responses. + +Mark: These are tagged with ``@pytest.mark.e2e`` so they can be run/skipped +independently of the rest of the test suite (they require network access). +""" + +import json +import os + +import pytest + +from backend.blocks.mcp.client import MCPClient + +# Public MCP server that requires no authentication +OPENAI_DOCS_MCP_URL = "https://developers.openai.com/mcp" + +# Skip all tests in this module unless RUN_E2E env var is set +pytestmark = pytest.mark.skipif( + not os.environ.get("RUN_E2E"), reason="set RUN_E2E=1 to run e2e tests" +) + + +class TestRealMCPServer: + """Tests against the live OpenAI docs MCP server.""" + + @pytest.mark.asyncio(loop_scope="session") + async def test_initialize(self): + """Verify we can complete the MCP handshake with a real server.""" + client = MCPClient(OPENAI_DOCS_MCP_URL) + result = await client.initialize() + + assert result["protocolVersion"] == "2025-03-26" + assert "serverInfo" in result + assert result["serverInfo"]["name"] == "openai-docs-mcp" + assert "tools" in result.get("capabilities", {}) + + @pytest.mark.asyncio(loop_scope="session") + async def test_list_tools(self): + """Verify we can discover tools from a real MCP server.""" + client = MCPClient(OPENAI_DOCS_MCP_URL) + await client.initialize() + tools = await client.list_tools() + + assert len(tools) >= 3 # server has at least 5 tools as of writing + + tool_names = {t.name for t in tools} + # These tools are documented and should be stable + assert "search_openai_docs" in tool_names + assert "list_openai_docs" in tool_names + assert "fetch_openai_doc" in tool_names + + # Verify schema structure + search_tool = next(t for t in tools if t.name == "search_openai_docs") + assert "query" in search_tool.input_schema.get("properties", {}) + assert "query" in search_tool.input_schema.get("required", []) + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_tool_list_api_endpoints(self): + """Call the list_api_endpoints tool and verify we get real data.""" + client = MCPClient(OPENAI_DOCS_MCP_URL) + await client.initialize() + result = await client.call_tool("list_api_endpoints", {}) + + assert not result.is_error + assert len(result.content) >= 1 + assert result.content[0]["type"] == "text" + + data = json.loads(result.content[0]["text"]) + assert "paths" in data or "urls" in data + # The OpenAI API should have many endpoints + total = data.get("total", len(data.get("paths", []))) + assert total > 50 + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_tool_search(self): + """Search for docs and verify we get results.""" + client = MCPClient(OPENAI_DOCS_MCP_URL) + await client.initialize() + result = await client.call_tool( + "search_openai_docs", {"query": "chat completions", "limit": 3} + ) + + assert not result.is_error + assert len(result.content) >= 1 + + @pytest.mark.asyncio(loop_scope="session") + async def test_sse_response_handling(self): + """Verify the client correctly handles SSE responses from a real server. + + This is the key test — our local test server returns JSON, + but real MCP servers typically return SSE. This proves the + SSE parsing works end-to-end. + """ + client = MCPClient(OPENAI_DOCS_MCP_URL) + # initialize() internally calls _send_request which must parse SSE + result = await client.initialize() + + # If we got here without error, SSE parsing works + assert isinstance(result, dict) + assert "protocolVersion" in result + + # Also verify list_tools works (another SSE response) + tools = await client.list_tools() + assert len(tools) > 0 + assert all(hasattr(t, "name") for t in tools) diff --git a/autogpt_platform/backend/backend/blocks/mcp/test_integration.py b/autogpt_platform/backend/backend/blocks/mcp/test_integration.py new file mode 100644 index 0000000000..70658dbaaf --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/test_integration.py @@ -0,0 +1,389 @@ +""" +Integration tests for MCP client and MCPToolBlock against a real HTTP server. + +These tests spin up a local MCP test server and run the full client/block flow +against it — no mocking, real HTTP requests. +""" + +import asyncio +import json +import threading +from unittest.mock import patch + +import pytest +from aiohttp import web +from pydantic import SecretStr + +from backend.blocks.mcp.block import MCPToolBlock +from backend.blocks.mcp.client import MCPClient +from backend.blocks.mcp.test_server import create_test_mcp_app +from backend.data.model import OAuth2Credentials + +MOCK_USER_ID = "test-user-integration" + + +class _MCPTestServer: + """ + Run an MCP test server in a background thread with its own event loop. + This avoids event loop conflicts with pytest-asyncio. + """ + + def __init__(self, auth_token: str | None = None): + self.auth_token = auth_token + self.url: str = "" + self._runner: web.AppRunner | None = None + self._loop: asyncio.AbstractEventLoop | None = None + self._thread: threading.Thread | None = None + self._started = threading.Event() + + def _run(self): + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + self._loop.run_until_complete(self._start()) + self._started.set() + self._loop.run_forever() + + async def _start(self): + app = create_test_mcp_app(auth_token=self.auth_token) + self._runner = web.AppRunner(app) + await self._runner.setup() + site = web.TCPSite(self._runner, "127.0.0.1", 0) + await site.start() + port = site._server.sockets[0].getsockname()[1] # type: ignore[union-attr] + self.url = f"http://127.0.0.1:{port}/mcp" + + def start(self): + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + if not self._started.wait(timeout=5): + raise RuntimeError("MCP test server failed to start within 5 seconds") + return self + + def stop(self): + if self._loop and self._runner: + asyncio.run_coroutine_threadsafe(self._runner.cleanup(), self._loop).result( + timeout=5 + ) + self._loop.call_soon_threadsafe(self._loop.stop) + if self._thread: + self._thread.join(timeout=5) + + +@pytest.fixture(scope="module") +def mcp_server(): + """Start a local MCP test server in a background thread.""" + server = _MCPTestServer() + server.start() + yield server.url + server.stop() + + +@pytest.fixture(scope="module") +def mcp_server_with_auth(): + """Start a local MCP test server with auth in a background thread.""" + server = _MCPTestServer(auth_token="test-secret-token") + server.start() + yield server.url, "test-secret-token" + server.stop() + + +@pytest.fixture(autouse=True) +def _allow_localhost(): + """ + Allow 127.0.0.1 through SSRF protection for integration tests. + + The Requests class blocks private IPs by default. We patch the Requests + constructor to always include 127.0.0.1 as a trusted origin so the local + test server is reachable. + """ + from backend.util.request import Requests + + original_init = Requests.__init__ + + def patched_init(self, *args, **kwargs): + trusted = list(kwargs.get("trusted_origins") or []) + trusted.append("http://127.0.0.1") + kwargs["trusted_origins"] = trusted + original_init(self, *args, **kwargs) + + with patch.object(Requests, "__init__", patched_init): + yield + + +def _make_client(url: str, auth_token: str | None = None) -> MCPClient: + """Create an MCPClient for integration tests.""" + return MCPClient(url, auth_token=auth_token) + + +# ── MCPClient integration tests ────────────────────────────────────── + + +class TestMCPClientIntegration: + """Test MCPClient against a real local MCP server.""" + + @pytest.mark.asyncio(loop_scope="session") + async def test_initialize(self, mcp_server): + client = _make_client(mcp_server) + result = await client.initialize() + + assert result["protocolVersion"] == "2025-03-26" + assert result["serverInfo"]["name"] == "test-mcp-server" + assert "tools" in result["capabilities"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_list_tools(self, mcp_server): + client = _make_client(mcp_server) + await client.initialize() + tools = await client.list_tools() + + assert len(tools) == 3 + + tool_names = {t.name for t in tools} + assert tool_names == {"get_weather", "add_numbers", "echo"} + + # Check get_weather schema + weather = next(t for t in tools if t.name == "get_weather") + assert weather.description == "Get current weather for a city" + assert "city" in weather.input_schema["properties"] + assert weather.input_schema["required"] == ["city"] + + # Check add_numbers schema + add = next(t for t in tools if t.name == "add_numbers") + assert "a" in add.input_schema["properties"] + assert "b" in add.input_schema["properties"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_tool_get_weather(self, mcp_server): + client = _make_client(mcp_server) + await client.initialize() + result = await client.call_tool("get_weather", {"city": "London"}) + + assert not result.is_error + assert len(result.content) == 1 + assert result.content[0]["type"] == "text" + + data = json.loads(result.content[0]["text"]) + assert data["city"] == "London" + assert data["temperature"] == 22 + assert data["condition"] == "sunny" + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_tool_add_numbers(self, mcp_server): + client = _make_client(mcp_server) + await client.initialize() + result = await client.call_tool("add_numbers", {"a": 3, "b": 7}) + + assert not result.is_error + data = json.loads(result.content[0]["text"]) + assert data["result"] == 10 + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_tool_echo(self, mcp_server): + client = _make_client(mcp_server) + await client.initialize() + result = await client.call_tool("echo", {"message": "Hello MCP!"}) + + assert not result.is_error + assert result.content[0]["text"] == "Hello MCP!" + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_unknown_tool(self, mcp_server): + client = _make_client(mcp_server) + await client.initialize() + result = await client.call_tool("nonexistent_tool", {}) + + assert result.is_error + assert "Unknown tool" in result.content[0]["text"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_auth_success(self, mcp_server_with_auth): + url, token = mcp_server_with_auth + client = _make_client(url, auth_token=token) + result = await client.initialize() + + assert result["protocolVersion"] == "2025-03-26" + + tools = await client.list_tools() + assert len(tools) == 3 + + @pytest.mark.asyncio(loop_scope="session") + async def test_auth_failure(self, mcp_server_with_auth): + url, _ = mcp_server_with_auth + client = _make_client(url, auth_token="wrong-token") + + with pytest.raises(Exception): + await client.initialize() + + @pytest.mark.asyncio(loop_scope="session") + async def test_auth_missing(self, mcp_server_with_auth): + url, _ = mcp_server_with_auth + client = _make_client(url) + + with pytest.raises(Exception): + await client.initialize() + + +# ── MCPToolBlock integration tests ─────────────────────────────────── + + +class TestMCPToolBlockIntegration: + """Test MCPToolBlock end-to-end against a real local MCP server.""" + + @pytest.mark.asyncio(loop_scope="session") + async def test_full_flow_get_weather(self, mcp_server): + """Full flow: discover tools, select one, execute it.""" + # Step 1: Discover tools (simulating what the frontend/API would do) + client = _make_client(mcp_server) + await client.initialize() + tools = await client.list_tools() + assert len(tools) == 3 + + # Step 2: User selects "get_weather" and we get its schema + weather_tool = next(t for t in tools if t.name == "get_weather") + + # Step 3: Execute the block — no credentials (public server) + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url=mcp_server, + selected_tool="get_weather", + tool_input_schema=weather_tool.input_schema, + tool_arguments={"city": "Paris"}, + ) + + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + + assert len(outputs) == 1 + assert outputs[0][0] == "result" + result = outputs[0][1] + assert result["city"] == "Paris" + assert result["temperature"] == 22 + assert result["condition"] == "sunny" + + @pytest.mark.asyncio(loop_scope="session") + async def test_full_flow_add_numbers(self, mcp_server): + """Full flow for add_numbers tool.""" + client = _make_client(mcp_server) + await client.initialize() + tools = await client.list_tools() + add_tool = next(t for t in tools if t.name == "add_numbers") + + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url=mcp_server, + selected_tool="add_numbers", + tool_input_schema=add_tool.input_schema, + tool_arguments={"a": 42, "b": 58}, + ) + + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + + assert len(outputs) == 1 + assert outputs[0][0] == "result" + assert outputs[0][1]["result"] == 100 + + @pytest.mark.asyncio(loop_scope="session") + async def test_full_flow_echo_plain_text(self, mcp_server): + """Verify plain text (non-JSON) responses work.""" + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url=mcp_server, + selected_tool="echo", + tool_input_schema={ + "type": "object", + "properties": {"message": {"type": "string"}}, + "required": ["message"], + }, + tool_arguments={"message": "Hello from AutoGPT!"}, + ) + + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + + assert len(outputs) == 1 + assert outputs[0][0] == "result" + assert outputs[0][1] == "Hello from AutoGPT!" + + @pytest.mark.asyncio(loop_scope="session") + async def test_full_flow_unknown_tool_yields_error(self, mcp_server): + """Calling an unknown tool should yield an error output.""" + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url=mcp_server, + selected_tool="nonexistent_tool", + tool_arguments={}, + ) + + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + + assert len(outputs) == 1 + assert outputs[0][0] == "error" + assert "returned an error" in outputs[0][1] + + @pytest.mark.asyncio(loop_scope="session") + async def test_full_flow_with_auth(self, mcp_server_with_auth): + """Full flow with authentication via credentials kwarg.""" + url, token = mcp_server_with_auth + + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url=url, + selected_tool="echo", + tool_input_schema={ + "type": "object", + "properties": {"message": {"type": "string"}}, + "required": ["message"], + }, + tool_arguments={"message": "Authenticated!"}, + ) + + # Pass credentials via the standard kwarg (as the executor would) + test_creds = OAuth2Credentials( + id="test-cred", + provider="mcp", + access_token=SecretStr(token), + refresh_token=SecretStr(""), + scopes=[], + title="Test MCP credential", + ) + + outputs = [] + async for name, data in block.run( + input_data, user_id=MOCK_USER_ID, credentials=test_creds + ): + outputs.append((name, data)) + + assert len(outputs) == 1 + assert outputs[0][0] == "result" + assert outputs[0][1] == "Authenticated!" + + @pytest.mark.asyncio(loop_scope="session") + async def test_no_credentials_runs_without_auth(self, mcp_server): + """Block runs without auth when no credentials are provided.""" + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url=mcp_server, + selected_tool="echo", + tool_input_schema={ + "type": "object", + "properties": {"message": {"type": "string"}}, + "required": ["message"], + }, + tool_arguments={"message": "No auth needed"}, + ) + + outputs = [] + async for name, data in block.run( + input_data, user_id=MOCK_USER_ID, credentials=None + ): + outputs.append((name, data)) + + assert len(outputs) == 1 + assert outputs[0][0] == "result" + assert outputs[0][1] == "No auth needed" diff --git a/autogpt_platform/backend/backend/blocks/mcp/test_mcp.py b/autogpt_platform/backend/backend/blocks/mcp/test_mcp.py new file mode 100644 index 0000000000..8cb49b0fee --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/test_mcp.py @@ -0,0 +1,619 @@ +""" +Tests for MCP client and MCPToolBlock. +""" + +import json +from unittest.mock import AsyncMock, patch + +import pytest + +from backend.blocks.mcp.block import MCPToolBlock +from backend.blocks.mcp.client import MCPCallResult, MCPClient, MCPClientError +from backend.util.test import execute_block_test + +# ── SSE parsing unit tests ─────────────────────────────────────────── + + +class TestSSEParsing: + """Tests for SSE (text/event-stream) response parsing.""" + + def test_parse_sse_simple(self): + sse = ( + "event: message\n" + 'data: {"jsonrpc":"2.0","result":{"tools":[]},"id":1}\n' + "\n" + ) + body = MCPClient._parse_sse_response(sse) + assert body["result"] == {"tools": []} + assert body["id"] == 1 + + def test_parse_sse_with_notifications(self): + """SSE streams can contain notifications (no id) before the response.""" + sse = ( + "event: message\n" + 'data: {"jsonrpc":"2.0","method":"some/notification"}\n' + "\n" + "event: message\n" + 'data: {"jsonrpc":"2.0","result":{"ok":true},"id":2}\n' + "\n" + ) + body = MCPClient._parse_sse_response(sse) + assert body["result"] == {"ok": True} + assert body["id"] == 2 + + def test_parse_sse_error_response(self): + sse = ( + "event: message\n" + 'data: {"jsonrpc":"2.0","error":{"code":-32600,"message":"Bad Request"},"id":1}\n' + ) + body = MCPClient._parse_sse_response(sse) + assert "error" in body + assert body["error"]["code"] == -32600 + + def test_parse_sse_no_data_raises(self): + with pytest.raises(MCPClientError, match="No JSON-RPC response found"): + MCPClient._parse_sse_response("event: message\n\n") + + def test_parse_sse_empty_raises(self): + with pytest.raises(MCPClientError, match="No JSON-RPC response found"): + MCPClient._parse_sse_response("") + + def test_parse_sse_ignores_non_data_lines(self): + sse = ( + ": comment line\n" + "event: message\n" + "id: 123\n" + 'data: {"jsonrpc":"2.0","result":"ok","id":1}\n' + "\n" + ) + body = MCPClient._parse_sse_response(sse) + assert body["result"] == "ok" + + def test_parse_sse_uses_last_response(self): + """If multiple responses exist, use the last one.""" + sse = ( + 'data: {"jsonrpc":"2.0","result":"first","id":1}\n' + "\n" + 'data: {"jsonrpc":"2.0","result":"second","id":2}\n' + "\n" + ) + body = MCPClient._parse_sse_response(sse) + assert body["result"] == "second" + + +# ── MCPClient unit tests ───────────────────────────────────────────── + + +class TestMCPClient: + """Tests for the MCP HTTP client.""" + + def test_build_headers_without_auth(self): + client = MCPClient("https://mcp.example.com") + headers = client._build_headers() + assert "Authorization" not in headers + assert headers["Content-Type"] == "application/json" + + def test_build_headers_with_auth(self): + client = MCPClient("https://mcp.example.com", auth_token="my-token") + headers = client._build_headers() + assert headers["Authorization"] == "Bearer my-token" + + def test_build_jsonrpc_request(self): + client = MCPClient("https://mcp.example.com") + req = client._build_jsonrpc_request("tools/list") + assert req["jsonrpc"] == "2.0" + assert req["method"] == "tools/list" + assert "id" in req + assert "params" not in req + + def test_build_jsonrpc_request_with_params(self): + client = MCPClient("https://mcp.example.com") + req = client._build_jsonrpc_request( + "tools/call", {"name": "test", "arguments": {"x": 1}} + ) + assert req["params"] == {"name": "test", "arguments": {"x": 1}} + + def test_request_id_increments(self): + client = MCPClient("https://mcp.example.com") + req1 = client._build_jsonrpc_request("tools/list") + req2 = client._build_jsonrpc_request("tools/list") + assert req2["id"] > req1["id"] + + def test_server_url_trailing_slash_stripped(self): + client = MCPClient("https://mcp.example.com/mcp/") + assert client.server_url == "https://mcp.example.com/mcp" + + @pytest.mark.asyncio(loop_scope="session") + async def test_send_request_success(self): + client = MCPClient("https://mcp.example.com") + + mock_response = AsyncMock() + mock_response.json.return_value = { + "jsonrpc": "2.0", + "result": {"tools": []}, + "id": 1, + } + + with patch.object(client, "_send_request", return_value={"tools": []}): + result = await client._send_request("tools/list") + assert result == {"tools": []} + + @pytest.mark.asyncio(loop_scope="session") + async def test_send_request_error(self): + client = MCPClient("https://mcp.example.com") + + async def mock_send(*args, **kwargs): + raise MCPClientError("MCP server error [-32600]: Invalid Request") + + with patch.object(client, "_send_request", side_effect=mock_send): + with pytest.raises(MCPClientError, match="Invalid Request"): + await client._send_request("tools/list") + + @pytest.mark.asyncio(loop_scope="session") + async def test_list_tools(self): + client = MCPClient("https://mcp.example.com") + + mock_result = { + "tools": [ + { + "name": "get_weather", + "description": "Get current weather for a city", + "inputSchema": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + }, + { + "name": "search", + "description": "Search the web", + "inputSchema": { + "type": "object", + "properties": {"query": {"type": "string"}}, + "required": ["query"], + }, + }, + ] + } + + with patch.object(client, "_send_request", return_value=mock_result): + tools = await client.list_tools() + + assert len(tools) == 2 + assert tools[0].name == "get_weather" + assert tools[0].description == "Get current weather for a city" + assert tools[0].input_schema["properties"]["city"]["type"] == "string" + assert tools[1].name == "search" + + @pytest.mark.asyncio(loop_scope="session") + async def test_list_tools_empty(self): + client = MCPClient("https://mcp.example.com") + + with patch.object(client, "_send_request", return_value={"tools": []}): + tools = await client.list_tools() + + assert tools == [] + + @pytest.mark.asyncio(loop_scope="session") + async def test_list_tools_none_result(self): + client = MCPClient("https://mcp.example.com") + + with patch.object(client, "_send_request", return_value=None): + tools = await client.list_tools() + + assert tools == [] + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_tool_success(self): + client = MCPClient("https://mcp.example.com") + + mock_result = { + "content": [ + {"type": "text", "text": json.dumps({"temp": 20, "city": "London"})} + ], + "isError": False, + } + + with patch.object(client, "_send_request", return_value=mock_result): + result = await client.call_tool("get_weather", {"city": "London"}) + + assert not result.is_error + assert len(result.content) == 1 + assert result.content[0]["type"] == "text" + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_tool_error(self): + client = MCPClient("https://mcp.example.com") + + mock_result = { + "content": [{"type": "text", "text": "City not found"}], + "isError": True, + } + + with patch.object(client, "_send_request", return_value=mock_result): + result = await client.call_tool("get_weather", {"city": "???"}) + + assert result.is_error + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_tool_none_result(self): + client = MCPClient("https://mcp.example.com") + + with patch.object(client, "_send_request", return_value=None): + result = await client.call_tool("get_weather", {"city": "London"}) + + assert result.is_error + + @pytest.mark.asyncio(loop_scope="session") + async def test_initialize(self): + client = MCPClient("https://mcp.example.com") + + mock_result = { + "protocolVersion": "2025-03-26", + "capabilities": {"tools": {}}, + "serverInfo": {"name": "test-server", "version": "1.0.0"}, + } + + with ( + patch.object(client, "_send_request", return_value=mock_result) as mock_req, + patch.object(client, "_send_notification") as mock_notif, + ): + result = await client.initialize() + + mock_req.assert_called_once() + mock_notif.assert_called_once_with("notifications/initialized") + assert result["protocolVersion"] == "2025-03-26" + + +# ── MCPToolBlock unit tests ────────────────────────────────────────── + +MOCK_USER_ID = "test-user-123" + + +class TestMCPToolBlock: + """Tests for the MCPToolBlock.""" + + def test_block_instantiation(self): + block = MCPToolBlock() + assert block.id == "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4" + assert block.name == "MCPToolBlock" + + def test_input_schema_has_required_fields(self): + block = MCPToolBlock() + schema = block.input_schema.jsonschema() + props = schema.get("properties", {}) + assert "server_url" in props + assert "selected_tool" in props + assert "tool_arguments" in props + assert "credentials" in props + + def test_output_schema(self): + block = MCPToolBlock() + schema = block.output_schema.jsonschema() + props = schema.get("properties", {}) + assert "result" in props + assert "error" in props + + def test_get_input_schema_with_tool_schema(self): + tool_schema = { + "type": "object", + "properties": {"query": {"type": "string"}}, + "required": ["query"], + } + data = {"tool_input_schema": tool_schema} + result = MCPToolBlock.Input.get_input_schema(data) + assert result == tool_schema + + def test_get_input_schema_without_tool_schema(self): + result = MCPToolBlock.Input.get_input_schema({}) + assert result == {} + + def test_get_input_defaults(self): + data = {"tool_arguments": {"city": "London"}} + result = MCPToolBlock.Input.get_input_defaults(data) + assert result == {"city": "London"} + + def test_get_missing_input(self): + data = { + "tool_input_schema": { + "type": "object", + "properties": { + "city": {"type": "string"}, + "units": {"type": "string"}, + }, + "required": ["city", "units"], + }, + "tool_arguments": {"city": "London"}, + } + missing = MCPToolBlock.Input.get_missing_input(data) + assert missing == {"units"} + + def test_get_missing_input_all_present(self): + data = { + "tool_input_schema": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + "tool_arguments": {"city": "London"}, + } + missing = MCPToolBlock.Input.get_missing_input(data) + assert missing == set() + + @pytest.mark.asyncio(loop_scope="session") + async def test_run_with_mock(self): + """Test the block using the built-in test infrastructure.""" + block = MCPToolBlock() + await execute_block_test(block) + + @pytest.mark.asyncio(loop_scope="session") + async def test_run_missing_server_url(self): + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url="", + selected_tool="test", + ) + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + assert outputs == [("error", "MCP server URL is required")] + + @pytest.mark.asyncio(loop_scope="session") + async def test_run_missing_tool(self): + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url="https://mcp.example.com/mcp", + selected_tool="", + ) + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + assert outputs == [ + ("error", "No tool selected. Please select a tool from the dropdown.") + ] + + @pytest.mark.asyncio(loop_scope="session") + async def test_run_success(self): + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url="https://mcp.example.com/mcp", + selected_tool="get_weather", + tool_input_schema={ + "type": "object", + "properties": {"city": {"type": "string"}}, + }, + tool_arguments={"city": "London"}, + ) + + async def mock_call(*args, **kwargs): + return {"temp": 20, "city": "London"} + + block._call_mcp_tool = mock_call # type: ignore + + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + + assert len(outputs) == 1 + assert outputs[0][0] == "result" + assert outputs[0][1] == {"temp": 20, "city": "London"} + + @pytest.mark.asyncio(loop_scope="session") + async def test_run_mcp_error(self): + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url="https://mcp.example.com/mcp", + selected_tool="bad_tool", + ) + + async def mock_call(*args, **kwargs): + raise MCPClientError("Tool not found") + + block._call_mcp_tool = mock_call # type: ignore + + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + + assert outputs[0][0] == "error" + assert "Tool not found" in outputs[0][1] + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_mcp_tool_parses_json_text(self): + block = MCPToolBlock() + + mock_result = MCPCallResult( + content=[ + {"type": "text", "text": '{"temp": 20}'}, + ], + is_error=False, + ) + + async def mock_init(self): + return {} + + async def mock_call(self, name, args): + return mock_result + + with ( + patch.object(MCPClient, "initialize", mock_init), + patch.object(MCPClient, "call_tool", mock_call), + ): + result = await block._call_mcp_tool( + "https://mcp.example.com", "test_tool", {} + ) + + assert result == {"temp": 20} + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_mcp_tool_plain_text(self): + block = MCPToolBlock() + + mock_result = MCPCallResult( + content=[ + {"type": "text", "text": "Hello, world!"}, + ], + is_error=False, + ) + + async def mock_init(self): + return {} + + async def mock_call(self, name, args): + return mock_result + + with ( + patch.object(MCPClient, "initialize", mock_init), + patch.object(MCPClient, "call_tool", mock_call), + ): + result = await block._call_mcp_tool( + "https://mcp.example.com", "test_tool", {} + ) + + assert result == "Hello, world!" + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_mcp_tool_multiple_content(self): + block = MCPToolBlock() + + mock_result = MCPCallResult( + content=[ + {"type": "text", "text": "Part 1"}, + {"type": "text", "text": '{"part": 2}'}, + ], + is_error=False, + ) + + async def mock_init(self): + return {} + + async def mock_call(self, name, args): + return mock_result + + with ( + patch.object(MCPClient, "initialize", mock_init), + patch.object(MCPClient, "call_tool", mock_call), + ): + result = await block._call_mcp_tool( + "https://mcp.example.com", "test_tool", {} + ) + + assert result == ["Part 1", {"part": 2}] + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_mcp_tool_error_result(self): + block = MCPToolBlock() + + mock_result = MCPCallResult( + content=[{"type": "text", "text": "Something went wrong"}], + is_error=True, + ) + + async def mock_init(self): + return {} + + async def mock_call(self, name, args): + return mock_result + + with ( + patch.object(MCPClient, "initialize", mock_init), + patch.object(MCPClient, "call_tool", mock_call), + ): + with pytest.raises(MCPClientError, match="returned an error"): + await block._call_mcp_tool("https://mcp.example.com", "test_tool", {}) + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_mcp_tool_image_content(self): + block = MCPToolBlock() + + mock_result = MCPCallResult( + content=[ + { + "type": "image", + "data": "base64data==", + "mimeType": "image/png", + } + ], + is_error=False, + ) + + async def mock_init(self): + return {} + + async def mock_call(self, name, args): + return mock_result + + with ( + patch.object(MCPClient, "initialize", mock_init), + patch.object(MCPClient, "call_tool", mock_call), + ): + result = await block._call_mcp_tool( + "https://mcp.example.com", "test_tool", {} + ) + + assert result == { + "type": "image", + "data": "base64data==", + "mimeType": "image/png", + } + + @pytest.mark.asyncio(loop_scope="session") + async def test_run_with_credentials(self): + """Verify the block uses OAuth2Credentials and passes auth token.""" + from pydantic import SecretStr + + from backend.data.model import OAuth2Credentials + + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url="https://mcp.example.com/mcp", + selected_tool="test_tool", + ) + + captured_tokens: list[str | None] = [] + + async def mock_call(server_url, tool_name, arguments, auth_token=None): + captured_tokens.append(auth_token) + return "ok" + + block._call_mcp_tool = mock_call # type: ignore + + test_creds = OAuth2Credentials( + id="cred-123", + provider="mcp", + access_token=SecretStr("resolved-token"), + refresh_token=SecretStr(""), + scopes=[], + title="Test MCP credential", + ) + + async for _ in block.run( + input_data, user_id=MOCK_USER_ID, credentials=test_creds + ): + pass + + assert captured_tokens == ["resolved-token"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_run_without_credentials(self): + """Verify the block works without credentials (public server).""" + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url="https://mcp.example.com/mcp", + selected_tool="test_tool", + ) + + captured_tokens: list[str | None] = [] + + async def mock_call(server_url, tool_name, arguments, auth_token=None): + captured_tokens.append(auth_token) + return "ok" + + block._call_mcp_tool = mock_call # type: ignore + + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + + assert captured_tokens == [None] + assert outputs == [("result", "ok")] diff --git a/autogpt_platform/backend/backend/blocks/mcp/test_oauth.py b/autogpt_platform/backend/backend/blocks/mcp/test_oauth.py new file mode 100644 index 0000000000..e9a42f68ea --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/test_oauth.py @@ -0,0 +1,242 @@ +""" +Tests for MCP OAuth handler. +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from pydantic import SecretStr + +from backend.blocks.mcp.client import MCPClient +from backend.blocks.mcp.oauth import MCPOAuthHandler +from backend.data.model import OAuth2Credentials + + +def _mock_response(json_data: dict, status: int = 200) -> MagicMock: + """Create a mock Response with synchronous json() (matching Requests.Response).""" + resp = MagicMock() + resp.status = status + resp.ok = 200 <= status < 300 + resp.json.return_value = json_data + return resp + + +class TestMCPOAuthHandler: + """Tests for the MCPOAuthHandler.""" + + def _make_handler(self, **overrides) -> MCPOAuthHandler: + defaults = { + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "redirect_uri": "https://app.example.com/callback", + "authorize_url": "https://auth.example.com/authorize", + "token_url": "https://auth.example.com/token", + } + defaults.update(overrides) + return MCPOAuthHandler(**defaults) + + def test_get_login_url_basic(self): + handler = self._make_handler() + url = handler.get_login_url( + scopes=["read", "write"], + state="random-state-token", + code_challenge="S256-challenge-value", + ) + + assert "https://auth.example.com/authorize?" in url + assert "response_type=code" in url + assert "client_id=test-client-id" in url + assert "state=random-state-token" in url + assert "code_challenge=S256-challenge-value" in url + assert "code_challenge_method=S256" in url + assert "scope=read+write" in url + + def test_get_login_url_with_resource(self): + handler = self._make_handler(resource_url="https://mcp.example.com/mcp") + url = handler.get_login_url( + scopes=[], state="state", code_challenge="challenge" + ) + + assert "resource=https" in url + + def test_get_login_url_without_pkce(self): + handler = self._make_handler() + url = handler.get_login_url(scopes=["read"], state="state", code_challenge=None) + + assert "code_challenge" not in url + assert "code_challenge_method" not in url + + @pytest.mark.asyncio(loop_scope="session") + async def test_exchange_code_for_tokens(self): + handler = self._make_handler() + + resp = _mock_response( + { + "access_token": "new-access-token", + "refresh_token": "new-refresh-token", + "expires_in": 3600, + "token_type": "Bearer", + } + ) + + with patch("backend.blocks.mcp.oauth.Requests") as MockRequests: + instance = MockRequests.return_value + instance.post = AsyncMock(return_value=resp) + + creds = await handler.exchange_code_for_tokens( + code="auth-code", + scopes=["read"], + code_verifier="pkce-verifier", + ) + + assert isinstance(creds, OAuth2Credentials) + assert creds.access_token.get_secret_value() == "new-access-token" + assert creds.refresh_token is not None + assert creds.refresh_token.get_secret_value() == "new-refresh-token" + assert creds.scopes == ["read"] + assert creds.access_token_expires_at is not None + + @pytest.mark.asyncio(loop_scope="session") + async def test_refresh_tokens(self): + handler = self._make_handler() + + existing_creds = OAuth2Credentials( + id="existing-id", + provider="mcp", + access_token=SecretStr("old-token"), + refresh_token=SecretStr("old-refresh"), + scopes=["read"], + title="test", + ) + + resp = _mock_response( + { + "access_token": "refreshed-token", + "refresh_token": "new-refresh", + "expires_in": 3600, + } + ) + + with patch("backend.blocks.mcp.oauth.Requests") as MockRequests: + instance = MockRequests.return_value + instance.post = AsyncMock(return_value=resp) + + refreshed = await handler._refresh_tokens(existing_creds) + + assert refreshed.id == "existing-id" + assert refreshed.access_token.get_secret_value() == "refreshed-token" + assert refreshed.refresh_token is not None + assert refreshed.refresh_token.get_secret_value() == "new-refresh" + + @pytest.mark.asyncio(loop_scope="session") + async def test_refresh_tokens_no_refresh_token(self): + handler = self._make_handler() + + creds = OAuth2Credentials( + provider="mcp", + access_token=SecretStr("token"), + scopes=["read"], + title="test", + ) + + with pytest.raises(ValueError, match="No refresh token"): + await handler._refresh_tokens(creds) + + @pytest.mark.asyncio(loop_scope="session") + async def test_revoke_tokens_no_url(self): + handler = self._make_handler(revoke_url=None) + + creds = OAuth2Credentials( + provider="mcp", + access_token=SecretStr("token"), + scopes=[], + title="test", + ) + + result = await handler.revoke_tokens(creds) + assert result is False + + @pytest.mark.asyncio(loop_scope="session") + async def test_revoke_tokens_with_url(self): + handler = self._make_handler(revoke_url="https://auth.example.com/revoke") + + creds = OAuth2Credentials( + provider="mcp", + access_token=SecretStr("token"), + scopes=[], + title="test", + ) + + resp = _mock_response({}, status=200) + + with patch("backend.blocks.mcp.oauth.Requests") as MockRequests: + instance = MockRequests.return_value + instance.post = AsyncMock(return_value=resp) + + result = await handler.revoke_tokens(creds) + + assert result is True + + +class TestMCPClientDiscovery: + """Tests for MCPClient OAuth metadata discovery.""" + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_auth_found(self): + client = MCPClient("https://mcp.example.com/mcp") + + metadata = { + "authorization_servers": ["https://auth.example.com"], + "resource": "https://mcp.example.com/mcp", + } + + resp = _mock_response(metadata, status=200) + + with patch("backend.blocks.mcp.client.Requests") as MockRequests: + instance = MockRequests.return_value + instance.get = AsyncMock(return_value=resp) + + result = await client.discover_auth() + + assert result is not None + assert result["authorization_servers"] == ["https://auth.example.com"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_auth_not_found(self): + client = MCPClient("https://mcp.example.com/mcp") + + resp = _mock_response({}, status=404) + + with patch("backend.blocks.mcp.client.Requests") as MockRequests: + instance = MockRequests.return_value + instance.get = AsyncMock(return_value=resp) + + result = await client.discover_auth() + + assert result is None + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_auth_server_metadata(self): + client = MCPClient("https://mcp.example.com/mcp") + + server_metadata = { + "issuer": "https://auth.example.com", + "authorization_endpoint": "https://auth.example.com/authorize", + "token_endpoint": "https://auth.example.com/token", + "registration_endpoint": "https://auth.example.com/register", + "code_challenge_methods_supported": ["S256"], + } + + resp = _mock_response(server_metadata, status=200) + + with patch("backend.blocks.mcp.client.Requests") as MockRequests: + instance = MockRequests.return_value + instance.get = AsyncMock(return_value=resp) + + result = await client.discover_auth_server_metadata( + "https://auth.example.com" + ) + + assert result is not None + assert result["authorization_endpoint"] == "https://auth.example.com/authorize" + assert result["token_endpoint"] == "https://auth.example.com/token" diff --git a/autogpt_platform/backend/backend/blocks/mcp/test_server.py b/autogpt_platform/backend/backend/blocks/mcp/test_server.py new file mode 100644 index 0000000000..a6732932bc --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/test_server.py @@ -0,0 +1,162 @@ +""" +Minimal MCP server for integration testing. + +Implements the MCP Streamable HTTP transport (JSON-RPC 2.0 over HTTP POST) +with a few sample tools. Runs on localhost with a random available port. +""" + +import json +import logging + +from aiohttp import web + +logger = logging.getLogger(__name__) + +# Sample tools this test server exposes +TEST_TOOLS = [ + { + "name": "get_weather", + "description": "Get current weather for a city", + "inputSchema": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City name", + }, + }, + "required": ["city"], + }, + }, + { + "name": "add_numbers", + "description": "Add two numbers together", + "inputSchema": { + "type": "object", + "properties": { + "a": {"type": "number", "description": "First number"}, + "b": {"type": "number", "description": "Second number"}, + }, + "required": ["a", "b"], + }, + }, + { + "name": "echo", + "description": "Echo back the input message", + "inputSchema": { + "type": "object", + "properties": { + "message": {"type": "string", "description": "Message to echo"}, + }, + "required": ["message"], + }, + }, +] + + +def _handle_initialize(params: dict) -> dict: + return { + "protocolVersion": "2025-03-26", + "capabilities": {"tools": {"listChanged": False}}, + "serverInfo": {"name": "test-mcp-server", "version": "1.0.0"}, + } + + +def _handle_tools_list(params: dict) -> dict: + return {"tools": TEST_TOOLS} + + +def _handle_tools_call(params: dict) -> dict: + tool_name = params.get("name", "") + arguments = params.get("arguments", {}) + + if tool_name == "get_weather": + city = arguments.get("city", "Unknown") + return { + "content": [ + { + "type": "text", + "text": json.dumps( + {"city": city, "temperature": 22, "condition": "sunny"} + ), + } + ], + } + + elif tool_name == "add_numbers": + a = arguments.get("a", 0) + b = arguments.get("b", 0) + return { + "content": [{"type": "text", "text": json.dumps({"result": a + b})}], + } + + elif tool_name == "echo": + message = arguments.get("message", "") + return { + "content": [{"type": "text", "text": message}], + } + + else: + return { + "content": [{"type": "text", "text": f"Unknown tool: {tool_name}"}], + "isError": True, + } + + +HANDLERS = { + "initialize": _handle_initialize, + "tools/list": _handle_tools_list, + "tools/call": _handle_tools_call, +} + + +async def handle_mcp_request(request: web.Request) -> web.Response: + """Handle incoming MCP JSON-RPC 2.0 requests.""" + # Check auth if configured + expected_token = request.app.get("auth_token") + if expected_token: + auth_header = request.headers.get("Authorization", "") + if auth_header != f"Bearer {expected_token}": + return web.json_response( + { + "jsonrpc": "2.0", + "error": {"code": -32001, "message": "Unauthorized"}, + "id": None, + }, + status=401, + ) + + body = await request.json() + + # Handle notifications (no id field) — just acknowledge + if "id" not in body: + return web.Response(status=202) + + method = body.get("method", "") + params = body.get("params", {}) + request_id = body.get("id") + + handler = HANDLERS.get(method) + if not handler: + return web.json_response( + { + "jsonrpc": "2.0", + "error": { + "code": -32601, + "message": f"Method not found: {method}", + }, + "id": request_id, + } + ) + + result = handler(params) + return web.json_response({"jsonrpc": "2.0", "result": result, "id": request_id}) + + +def create_test_mcp_app(auth_token: str | None = None) -> web.Application: + """Create an aiohttp app that acts as an MCP server.""" + app = web.Application() + app.router.add_post("/mcp", handle_mcp_request) + if auth_token: + app["auth_token"] = auth_token + return app diff --git a/autogpt_platform/backend/backend/data/graph.py b/autogpt_platform/backend/backend/data/graph.py index f39a0144e7..94f99852e8 100644 --- a/autogpt_platform/backend/backend/data/graph.py +++ b/autogpt_platform/backend/backend/data/graph.py @@ -33,6 +33,7 @@ from backend.util import type as type_utils from backend.util.exceptions import GraphNotAccessibleError, GraphNotInLibraryError from backend.util.json import SafeJson from backend.util.models import Pagination +from backend.util.request import parse_url from .block import BlockInput from .db import BaseDbModel @@ -449,6 +450,9 @@ class GraphModel(Graph, GraphMeta): continue if ProviderName.HTTP in field.provider: continue + # MCP credentials are intentionally split by server URL + if ProviderName.MCP in field.provider: + continue # If this happens, that means a block implementation probably needs # to be updated. @@ -505,6 +509,18 @@ class GraphModel(Graph, GraphMeta): "required": ["id", "provider", "type"], } + # Add a descriptive display title when URL-based discriminator values + # are present (e.g. "mcp.sentry.dev" instead of just "Mcp") + if ( + field_info.discriminator + and not field_info.discriminator_mapping + and field_info.discriminator_values + ): + hostnames = sorted( + parse_url(str(v)).netloc for v in field_info.discriminator_values + ) + field_schema["display_name"] = ", ".join(hostnames) + # Add other (optional) field info items field_schema.update( field_info.model_dump( @@ -549,8 +565,17 @@ class GraphModel(Graph, GraphMeta): for graph in [self] + self.sub_graphs: for node in graph.nodes: - # Track if this node requires credentials (credentials_optional=False means required) - node_required_map[node.id] = not node.credentials_optional + # A node's credentials are optional if either: + # 1. The node metadata says so (credentials_optional=True), or + # 2. All credential fields on the block have defaults (not required by schema) + block_required = node.block.input_schema.get_required_fields() + creds_required_by_schema = any( + fname in block_required + for fname in node.block.input_schema.get_credentials_fields() + ) + node_required_map[node.id] = ( + not node.credentials_optional and creds_required_by_schema + ) for ( field_name, @@ -776,6 +801,19 @@ class GraphModel(Graph, GraphMeta): "'credentials' and `*_credentials` are reserved" ) + # Check custom block-level validation (e.g., MCP dynamic tool arguments). + # Blocks can override get_missing_input to report additional missing fields + # beyond the standard top-level required fields. + if for_run: + credential_fields = InputSchema.get_credentials_fields() + custom_missing = InputSchema.get_missing_input(node.input_default) + for field_name in custom_missing: + if ( + field_name not in provided_inputs + and field_name not in credential_fields + ): + node_errors[node.id][field_name] = "This field is required" + # Get input schema properties and check dependencies input_fields = InputSchema.model_fields diff --git a/autogpt_platform/backend/backend/data/graph_test.py b/autogpt_platform/backend/backend/data/graph_test.py index 442c8ed4be..3cb6f24b87 100644 --- a/autogpt_platform/backend/backend/data/graph_test.py +++ b/autogpt_platform/backend/backend/data/graph_test.py @@ -462,3 +462,120 @@ def test_node_credentials_optional_with_other_metadata(): assert node.credentials_optional is True assert node.metadata["position"] == {"x": 100, "y": 200} assert node.metadata["customized_name"] == "My Custom Node" + + +# ============================================================================ +# Tests for MCP Credential Deduplication +# ============================================================================ + + +def test_mcp_credential_combine_different_servers(): + """Two MCP credential fields with different server URLs should produce + separate entries when combined (not merged into one).""" + from backend.data.model import CredentialsFieldInfo, CredentialsType + from backend.integrations.providers import ProviderName + + oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"]) + + field_sentry = CredentialsFieldInfo( + credentials_provider=frozenset([ProviderName.MCP]), + credentials_types=oauth2_types, + credentials_scopes=None, + discriminator="server_url", + discriminator_values={"https://mcp.sentry.dev/mcp"}, + ) + field_linear = CredentialsFieldInfo( + credentials_provider=frozenset([ProviderName.MCP]), + credentials_types=oauth2_types, + credentials_scopes=None, + discriminator="server_url", + discriminator_values={"https://mcp.linear.app/mcp"}, + ) + + combined = CredentialsFieldInfo.combine( + (field_sentry, ("node-sentry", "credentials")), + (field_linear, ("node-linear", "credentials")), + ) + + # Should produce 2 separate credential entries + assert len(combined) == 2, ( + f"Expected 2 credential entries for 2 MCP blocks with different servers, " + f"got {len(combined)}: {list(combined.keys())}" + ) + + # Each entry should contain the server hostname in its key + keys = list(combined.keys()) + assert any( + "mcp.sentry.dev" in k for k in keys + ), f"Expected 'mcp.sentry.dev' in one key, got {keys}" + assert any( + "mcp.linear.app" in k for k in keys + ), f"Expected 'mcp.linear.app' in one key, got {keys}" + + +def test_mcp_credential_combine_same_server(): + """Two MCP credential fields with the same server URL should be combined + into one credential entry.""" + from backend.data.model import CredentialsFieldInfo, CredentialsType + from backend.integrations.providers import ProviderName + + oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"]) + + field_a = CredentialsFieldInfo( + credentials_provider=frozenset([ProviderName.MCP]), + credentials_types=oauth2_types, + credentials_scopes=None, + discriminator="server_url", + discriminator_values={"https://mcp.sentry.dev/mcp"}, + ) + field_b = CredentialsFieldInfo( + credentials_provider=frozenset([ProviderName.MCP]), + credentials_types=oauth2_types, + credentials_scopes=None, + discriminator="server_url", + discriminator_values={"https://mcp.sentry.dev/mcp"}, + ) + + combined = CredentialsFieldInfo.combine( + (field_a, ("node-a", "credentials")), + (field_b, ("node-b", "credentials")), + ) + + # Should produce 1 credential entry (same server URL) + assert len(combined) == 1, ( + f"Expected 1 credential entry for 2 MCP blocks with same server, " + f"got {len(combined)}: {list(combined.keys())}" + ) + + +def test_mcp_credential_combine_no_discriminator_values(): + """MCP credential fields without discriminator_values should be merged + into a single entry (backwards compat for blocks without server_url set).""" + from backend.data.model import CredentialsFieldInfo, CredentialsType + from backend.integrations.providers import ProviderName + + oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"]) + + field_a = CredentialsFieldInfo( + credentials_provider=frozenset([ProviderName.MCP]), + credentials_types=oauth2_types, + credentials_scopes=None, + discriminator="server_url", + ) + field_b = CredentialsFieldInfo( + credentials_provider=frozenset([ProviderName.MCP]), + credentials_types=oauth2_types, + credentials_scopes=None, + discriminator="server_url", + ) + + combined = CredentialsFieldInfo.combine( + (field_a, ("node-a", "credentials")), + (field_b, ("node-b", "credentials")), + ) + + # Should produce 1 entry (no URL differentiation) + assert len(combined) == 1, ( + f"Expected 1 credential entry for MCP blocks without discriminator_values, " + f"got {len(combined)}: {list(combined.keys())}" + ) diff --git a/autogpt_platform/backend/backend/data/model.py b/autogpt_platform/backend/backend/data/model.py index e61f7efbd0..c9d8c5879f 100644 --- a/autogpt_platform/backend/backend/data/model.py +++ b/autogpt_platform/backend/backend/data/model.py @@ -29,6 +29,7 @@ from pydantic import ( GetCoreSchemaHandler, SecretStr, field_serializer, + model_validator, ) from pydantic_core import ( CoreSchema, @@ -502,6 +503,25 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]): provider: CP type: CT + @model_validator(mode="before") + @classmethod + def _normalize_legacy_provider(cls, data: Any) -> Any: + """Fix ``ProviderName.X`` format from Python 3.13 ``str(Enum)`` bug. + + Python 3.13 changed ``str(StrEnum)`` to return ``"ClassName.MEMBER"`` + instead of the plain value. Old stored credential references may have + ``provider: "ProviderName.MCP"`` instead of ``"mcp"``. + """ + if isinstance(data, dict): + prov = data.get("provider", "") + if isinstance(prov, str) and prov.startswith("ProviderName."): + member = prov.removeprefix("ProviderName.") + try: + data = {**data, "provider": ProviderName[member].value} + except KeyError: + pass + return data + @classmethod def allowed_providers(cls) -> tuple[ProviderName, ...] | None: return get_args(cls.model_fields["provider"].annotation) @@ -606,11 +626,18 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]): ] = defaultdict(list) for field, key in fields: - if field.provider == frozenset([ProviderName.HTTP]): - # HTTP host-scoped credentials can have different hosts that reqires different credential sets. - # Group by host extracted from the URL + if ( + field.discriminator + and not field.discriminator_mapping + and field.discriminator_values + ): + # URL-based discrimination (e.g. HTTP host-scoped, MCP server URL): + # Each unique host gets its own credential entry. + provider_prefix = next(iter(field.provider)) + # Use .value for enum types to get the plain string (e.g. "mcp" not "ProviderName.MCP") + prefix_str = getattr(provider_prefix, "value", str(provider_prefix)) providers = frozenset( - [cast(CP, "http")] + [cast(CP, prefix_str)] + [ cast(CP, parse_url(str(value)).netloc) for value in field.discriminator_values diff --git a/autogpt_platform/backend/backend/executor/manager.py b/autogpt_platform/backend/backend/executor/manager.py index 1f76458947..caa98784c2 100644 --- a/autogpt_platform/backend/backend/executor/manager.py +++ b/autogpt_platform/backend/backend/executor/manager.py @@ -20,6 +20,7 @@ from backend.blocks import get_block from backend.blocks._base import BlockSchema from backend.blocks.agent import AgentExecutorBlock from backend.blocks.io import AgentOutputBlock +from backend.blocks.mcp.block import MCPToolBlock from backend.data import redis_client as redis from backend.data.block import BlockInput, BlockOutput, BlockOutputEntry from backend.data.credit import UsageTransactionMetadata @@ -228,6 +229,18 @@ async def execute_node( _input_data.nodes_input_masks = nodes_input_masks _input_data.user_id = user_id input_data = _input_data.model_dump() + elif isinstance(node_block, MCPToolBlock): + _mcp_data = MCPToolBlock.Input(**node.input_default) + # Dynamic tool fields are flattened to top-level by validate_exec + # (via get_input_defaults). Collect them back into tool_arguments. + tool_schema = _mcp_data.tool_input_schema + tool_props = set(tool_schema.get("properties", {}).keys()) + merged_args = {**_mcp_data.tool_arguments} + for key in tool_props: + if key in input_data: + merged_args[key] = input_data[key] + _mcp_data.tool_arguments = merged_args + input_data = _mcp_data.model_dump() data.inputs = input_data # Execute the node @@ -264,8 +277,34 @@ async def execute_node( # Handle regular credentials fields for field_name, input_type in input_model.get_credentials_fields().items(): - credentials_meta = input_type(**input_data[field_name]) - credentials, lock = await creds_manager.acquire(user_id, credentials_meta.id) + field_value = input_data.get(field_name) + if not field_value or ( + isinstance(field_value, dict) and not field_value.get("id") + ): + # No credentials configured — nullify so JSON schema validation + # doesn't choke on the empty default `{}`. + input_data[field_name] = None + continue # Block runs without credentials + + credentials_meta = input_type(**field_value) + # Write normalized values back so JSON schema validation also passes + # (model_validator may have fixed legacy formats like "ProviderName.MCP") + input_data[field_name] = credentials_meta.model_dump(mode="json") + try: + credentials, lock = await creds_manager.acquire( + user_id, credentials_meta.id + ) + except ValueError: + # Credential was deleted or doesn't exist. + # If the field has a default, run without credentials. + if input_model.model_fields[field_name].default is not None: + log_metadata.warning( + f"Credentials #{credentials_meta.id} not found, " + "running without (field has default)" + ) + input_data[field_name] = None + continue + raise creds_locks.append(lock) extra_exec_kwargs[field_name] = credentials diff --git a/autogpt_platform/backend/backend/executor/utils.py b/autogpt_platform/backend/backend/executor/utils.py index bb5da1e527..2b9a454061 100644 --- a/autogpt_platform/backend/backend/executor/utils.py +++ b/autogpt_platform/backend/backend/executor/utils.py @@ -260,7 +260,13 @@ async def _validate_node_input_credentials( # Track if any credential field is missing for this node has_missing_credentials = False + # A credential field is optional if the node metadata says so, or if + # the block schema declares a default for the field. + required_fields = block.input_schema.get_required_fields() + is_creds_optional = node.credentials_optional + for field_name, credentials_meta_type in credentials_fields.items(): + field_is_optional = is_creds_optional or field_name not in required_fields try: # Check nodes_input_masks first, then input_default field_value = None @@ -273,7 +279,7 @@ async def _validate_node_input_credentials( elif field_name in node.input_default: # For optional credentials, don't use input_default - treat as missing # This prevents stale credential IDs from failing validation - if node.credentials_optional: + if field_is_optional: field_value = None else: field_value = node.input_default[field_name] @@ -283,8 +289,8 @@ async def _validate_node_input_credentials( isinstance(field_value, dict) and not field_value.get("id") ): has_missing_credentials = True - # If node has credentials_optional flag, mark for skipping instead of error - if node.credentials_optional: + # If credential field is optional, skip instead of error + if field_is_optional: continue # Don't add error, will be marked for skip after loop else: credential_errors[node.id][ @@ -334,16 +340,16 @@ async def _validate_node_input_credentials( ] = "Invalid credentials: type/provider mismatch" continue - # If node has optional credentials and any are missing, mark for skipping - # But only if there are no other errors for this node + # If node has optional credentials and any are missing, allow running without. + # The executor will pass credentials=None to the block's run(). if ( has_missing_credentials - and node.credentials_optional + and is_creds_optional and node.id not in credential_errors ): - nodes_to_skip.add(node.id) logger.info( - f"Node #{node.id} will be skipped: optional credentials not configured" + f"Node #{node.id}: optional credentials not configured, " + "running without" ) return credential_errors, nodes_to_skip diff --git a/autogpt_platform/backend/backend/executor/utils_test.py b/autogpt_platform/backend/backend/executor/utils_test.py index db33249583..069086a6fd 100644 --- a/autogpt_platform/backend/backend/executor/utils_test.py +++ b/autogpt_platform/backend/backend/executor/utils_test.py @@ -495,6 +495,7 @@ async def test_validate_node_input_credentials_returns_nodes_to_skip( mock_block.input_schema.get_credentials_fields.return_value = { "credentials": mock_credentials_field_type } + mock_block.input_schema.get_required_fields.return_value = {"credentials"} mock_node.block = mock_block # Create mock graph @@ -508,8 +509,8 @@ async def test_validate_node_input_credentials_returns_nodes_to_skip( nodes_input_masks=None, ) - # Node should be in nodes_to_skip, not in errors - assert mock_node.id in nodes_to_skip + # Node should NOT be in nodes_to_skip (runs without credentials) and not in errors + assert mock_node.id not in nodes_to_skip assert mock_node.id not in errors @@ -535,6 +536,7 @@ async def test_validate_node_input_credentials_required_missing_creds_error( mock_block.input_schema.get_credentials_fields.return_value = { "credentials": mock_credentials_field_type } + mock_block.input_schema.get_required_fields.return_value = {"credentials"} mock_node.block = mock_block # Create mock graph diff --git a/autogpt_platform/backend/backend/integrations/credentials_store.py b/autogpt_platform/backend/backend/integrations/credentials_store.py index 384405b0c7..3e79a6c047 100644 --- a/autogpt_platform/backend/backend/integrations/credentials_store.py +++ b/autogpt_platform/backend/backend/integrations/credentials_store.py @@ -22,6 +22,27 @@ from backend.util.settings import Settings settings = Settings() + +def provider_matches(stored: str, expected: str) -> bool: + """Compare provider strings, handling Python 3.13 ``str(StrEnum)`` bug. + + On Python 3.13, ``str(ProviderName.MCP)`` returns ``"ProviderName.MCP"`` + instead of ``"mcp"``. OAuth states persisted with the buggy format need + to match when ``expected`` is the canonical value (e.g. ``"mcp"``). + """ + if stored == expected: + return True + if stored.startswith("ProviderName."): + member = stored.removeprefix("ProviderName.") + from backend.integrations.providers import ProviderName + + try: + return ProviderName[member].value == expected + except KeyError: + pass + return False + + # This is an overrride since ollama doesn't actually require an API key, but the creddential system enforces one be attached ollama_credentials = APIKeyCredentials( id="744fdc56-071a-4761-b5a5-0af0ce10a2b5", @@ -389,7 +410,7 @@ class IntegrationCredentialsStore: self, user_id: str, provider: str ) -> list[Credentials]: credentials = await self.get_all_creds(user_id) - return [c for c in credentials if c.provider == provider] + return [c for c in credentials if provider_matches(c.provider, provider)] async def get_authorized_providers(self, user_id: str) -> list[str]: credentials = await self.get_all_creds(user_id) @@ -485,17 +506,6 @@ class IntegrationCredentialsStore: async with self.edit_user_integrations(user_id) as user_integrations: user_integrations.oauth_states.append(state) - async with await self.locked_user_integrations(user_id): - - user_integrations = await self._get_user_integrations(user_id) - oauth_states = user_integrations.oauth_states - oauth_states.append(state) - user_integrations.oauth_states = oauth_states - - await self.db_manager.update_user_integrations( - user_id=user_id, data=user_integrations - ) - return token, code_challenge def _generate_code_challenge(self) -> tuple[str, str]: @@ -521,7 +531,7 @@ class IntegrationCredentialsStore: state for state in oauth_states if secrets.compare_digest(state.token, token) - and state.provider == provider + and provider_matches(state.provider, provider) and state.expires_at > now.timestamp() ), None, diff --git a/autogpt_platform/backend/backend/integrations/creds_manager.py b/autogpt_platform/backend/backend/integrations/creds_manager.py index f2b6a9da4f..5634dd73b6 100644 --- a/autogpt_platform/backend/backend/integrations/creds_manager.py +++ b/autogpt_platform/backend/backend/integrations/creds_manager.py @@ -9,7 +9,10 @@ from redis.asyncio.lock import Lock as AsyncRedisLock from backend.data.model import Credentials, OAuth2Credentials from backend.data.redis_client import get_redis_async -from backend.integrations.credentials_store import IntegrationCredentialsStore +from backend.integrations.credentials_store import ( + IntegrationCredentialsStore, + provider_matches, +) from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME from backend.integrations.providers import ProviderName from backend.util.exceptions import MissingConfigError @@ -137,7 +140,10 @@ class IntegrationCredentialsManager: self, user_id: str, credentials: OAuth2Credentials, lock: bool = True ) -> OAuth2Credentials: async with self._locked(user_id, credentials.id, "refresh"): - oauth_handler = await _get_provider_oauth_handler(credentials.provider) + if provider_matches(credentials.provider, ProviderName.MCP.value): + oauth_handler = create_mcp_oauth_handler(credentials) + else: + oauth_handler = await _get_provider_oauth_handler(credentials.provider) if oauth_handler.needs_refresh(credentials): logger.debug( f"Refreshing '{credentials.provider}' " @@ -236,3 +242,31 @@ async def _get_provider_oauth_handler(provider_name_str: str) -> "BaseOAuthHandl client_secret=client_secret, redirect_uri=f"{frontend_base_url}/auth/integrations/oauth_callback", ) + + +def create_mcp_oauth_handler( + credentials: OAuth2Credentials, +) -> "BaseOAuthHandler": + """Create an MCPOAuthHandler from credential metadata for token refresh. + + MCP OAuth handlers have dynamic endpoints discovered per-server, so they + can't be registered as singletons in HANDLERS_BY_NAME. Instead, the handler + is reconstructed from metadata stored on the credential during initial auth. + """ + from backend.blocks.mcp.oauth import MCPOAuthHandler + + meta = credentials.metadata or {} + token_url = meta.get("mcp_token_url", "") + if not token_url: + raise ValueError( + f"MCP credential {credentials.id} is missing 'mcp_token_url' metadata; " + "cannot refresh tokens" + ) + return MCPOAuthHandler( + client_id=meta.get("mcp_client_id", ""), + client_secret=meta.get("mcp_client_secret", ""), + redirect_uri="", # Not needed for token refresh + authorize_url="", # Not needed for token refresh + token_url=token_url, + resource_url=meta.get("mcp_resource_url"), + ) diff --git a/autogpt_platform/backend/backend/integrations/providers.py b/autogpt_platform/backend/backend/integrations/providers.py index 8a0d6fd183..a462cd787f 100644 --- a/autogpt_platform/backend/backend/integrations/providers.py +++ b/autogpt_platform/backend/backend/integrations/providers.py @@ -30,6 +30,7 @@ class ProviderName(str, Enum): IDEOGRAM = "ideogram" JINA = "jina" LLAMA_API = "llama_api" + MCP = "mcp" MEDIUM = "medium" MEM0 = "mem0" NOTION = "notion" diff --git a/autogpt_platform/backend/backend/integrations/webhooks/graph_lifecycle_hooks.py b/autogpt_platform/backend/backend/integrations/webhooks/graph_lifecycle_hooks.py index 99eee404b9..8fdbe10383 100644 --- a/autogpt_platform/backend/backend/integrations/webhooks/graph_lifecycle_hooks.py +++ b/autogpt_platform/backend/backend/integrations/webhooks/graph_lifecycle_hooks.py @@ -51,6 +51,21 @@ async def _on_graph_activate(graph: "BaseGraph | GraphModel", user_id: str): if ( creds_meta := new_node.input_default.get(creds_field_name) ) and not await get_credentials(creds_meta["id"]): + # If the credential field is optional (has a default in the + # schema, or node metadata marks it optional), clear the stale + # reference instead of blocking the save. + creds_field_optional = ( + new_node.credentials_optional + or creds_field_name not in block_input_schema.get_required_fields() + ) + if creds_field_optional: + new_node.input_default[creds_field_name] = {} + logger.warning( + f"Node #{new_node.id}: cleared stale optional " + f"credentials #{creds_meta['id']} for " + f"'{creds_field_name}'" + ) + continue raise ValueError( f"Node #{new_node.id} input '{creds_field_name}' updated with " f"non-existent credentials #{creds_meta['id']}" diff --git a/autogpt_platform/backend/backend/util/request.py b/autogpt_platform/backend/backend/util/request.py index 95e5ee32f7..9470909dfc 100644 --- a/autogpt_platform/backend/backend/util/request.py +++ b/autogpt_platform/backend/backend/util/request.py @@ -101,7 +101,7 @@ class HostResolver(abc.AbstractResolver): def __init__(self, ssl_hostname: str, ip_addresses: list[str]): self.ssl_hostname = ssl_hostname self.ip_addresses = ip_addresses - self._default = aiohttp.AsyncResolver() + self._default = aiohttp.ThreadedResolver() async def resolve(self, host, port=0, family=socket.AF_INET): if host == self.ssl_hostname: @@ -467,7 +467,7 @@ class Requests: resolver = HostResolver(ssl_hostname=hostname, ip_addresses=ip_addresses) ssl_context = ssl.create_default_context() connector = aiohttp.TCPConnector(resolver=resolver, ssl=ssl_context) - session_kwargs = {} + session_kwargs: dict = {} if connector: session_kwargs["connector"] = connector diff --git a/autogpt_platform/frontend/src/app/(platform)/auth/integrations/mcp_callback/route.ts b/autogpt_platform/frontend/src/app/(platform)/auth/integrations/mcp_callback/route.ts new file mode 100644 index 0000000000..326f42e049 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/auth/integrations/mcp_callback/route.ts @@ -0,0 +1,96 @@ +import { NextResponse } from "next/server"; + +/** + * Safely encode a value as JSON for embedding in a script tag. + * Escapes characters that could break out of the script context to prevent XSS. + */ +function safeJsonStringify(value: unknown): string { + return JSON.stringify(value) + .replace(//g, "\\u003e") + .replace(/&/g, "\\u0026"); +} + +// MCP-specific OAuth callback route. +// +// Unlike the generic oauth_callback which relies on window.opener.postMessage, +// this route uses BroadcastChannel as the PRIMARY communication method. +// This is critical because cross-origin OAuth flows (e.g. Sentry → localhost) +// often lose window.opener due to COOP (Cross-Origin-Opener-Policy) headers. +// +// BroadcastChannel works across all same-origin tabs/popups regardless of opener. +export async function GET(request: Request) { + const { searchParams } = new URL(request.url); + const code = searchParams.get("code"); + const state = searchParams.get("state"); + + const success = Boolean(code && state); + const message = success + ? { success: true, code, state } + : { + success: false, + message: `Missing parameters: ${searchParams.toString()}`, + }; + + return new NextResponse( + ` + + MCP Sign-in + +
+
+

Completing sign-in...

+
+ + + +`, + { headers: { "Content-Type": "text/html" } }, + ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/CustomNode.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/CustomNode.tsx index d4aa26480d..62e796b748 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/CustomNode.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/CustomNode.tsx @@ -47,7 +47,10 @@ export type CustomNode = XYNode; export const CustomNode: React.FC> = React.memo( ({ data, id: nodeId, selected }) => { - const { inputSchema, outputSchema } = useCustomNode({ data, nodeId }); + const { inputSchema, outputSchema, isMCPWithTool } = useCustomNode({ + data, + nodeId, + }); const isAgent = data.uiType === BlockUIType.AGENT; @@ -98,6 +101,7 @@ export const CustomNode: React.FC> = React.memo( jsonSchema={preprocessInputSchema(inputSchema)} nodeId={nodeId} uiType={data.uiType} + isMCPWithTool={isMCPWithTool} className={cn( "bg-white px-4", isWebhook && "pointer-events-none opacity-50", diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx index c4659b8dcf..9a3add62b6 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx @@ -20,10 +20,8 @@ type Props = { export const NodeHeader = ({ data, nodeId }: Props) => { const updateNodeData = useNodeStore((state) => state.updateNodeData); - const title = - (data.metadata?.customized_name as string) || - data.hardcodedValues?.agent_name || - data.title; + + const title = (data.metadata?.customized_name as string) || data.title; const [isEditingTitle, setIsEditingTitle] = useState(false); const [editedTitle, setEditedTitle] = useState(title); diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/useCustomNode.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/useCustomNode.tsx index e58d0ab12b..050515a02f 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/useCustomNode.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/useCustomNode.tsx @@ -3,6 +3,34 @@ import { CustomNodeData } from "./CustomNode"; import { BlockUIType } from "../../../types"; import { useMemo } from "react"; import { mergeSchemaForResolution } from "./helpers"; +/** + * Build a dynamic input schema for MCP blocks. + * + * When a tool has been selected (tool_input_schema is populated), the block + * renders the selected tool's input parameters *plus* the credentials field + * so users can select/change the OAuth credential used for execution. + * + * Static fields like server_url, selected_tool, available_tools, and + * tool_arguments are hidden because they're pre-configured from the dialog. + */ +function buildMCPInputSchema( + toolInputSchema: Record, + blockInputSchema: Record, +): Record { + // Extract the credentials field from the block's original input schema + const credentialsSchema = + blockInputSchema?.properties?.credentials ?? undefined; + + return { + type: "object", + properties: { + // Credentials field first so the dropdown appears at the top + ...(credentialsSchema ? { credentials: credentialsSchema } : {}), + ...(toolInputSchema.properties ?? {}), + }, + required: [...(toolInputSchema.required ?? [])], + }; +} export const useCustomNode = ({ data, @@ -19,10 +47,18 @@ export const useCustomNode = ({ ); const isAgent = data.uiType === BlockUIType.AGENT; + const isMCPWithTool = + data.uiType === BlockUIType.MCP_TOOL && + !!data.hardcodedValues?.tool_input_schema?.properties; const currentInputSchema = isAgent ? (data.hardcodedValues.input_schema ?? {}) - : data.inputSchema; + : isMCPWithTool + ? buildMCPInputSchema( + data.hardcodedValues.tool_input_schema, + data.inputSchema, + ) + : data.inputSchema; const currentOutputSchema = isAgent ? (data.hardcodedValues.output_schema ?? {}) : data.outputSchema; @@ -54,5 +90,6 @@ export const useCustomNode = ({ return { inputSchema, outputSchema, + isMCPWithTool, }; }; diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/FormCreator.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/FormCreator.tsx index d6a3fabffa..77b21dda92 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/FormCreator.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/FormCreator.tsx @@ -9,39 +9,72 @@ interface FormCreatorProps { jsonSchema: RJSFSchema; nodeId: string; uiType: BlockUIType; + /** When true the block is an MCP Tool with a selected tool. */ + isMCPWithTool?: boolean; showHandles?: boolean; className?: string; } export const FormCreator: React.FC = React.memo( - ({ jsonSchema, nodeId, uiType, showHandles = true, className }) => { + ({ + jsonSchema, + nodeId, + uiType, + isMCPWithTool = false, + showHandles = true, + className, + }) => { const updateNodeData = useNodeStore((state) => state.updateNodeData); const getHardCodedValues = useNodeStore( (state) => state.getHardCodedValues, ); + const isAgent = uiType === BlockUIType.AGENT; + const handleChange = ({ formData }: any) => { if ("credentials" in formData && !formData.credentials?.id) { delete formData.credentials; } - const updatedValues = - uiType === BlockUIType.AGENT - ? { - ...getHardCodedValues(nodeId), - inputs: formData, - } - : formData; + let updatedValues; + if (isAgent) { + updatedValues = { + ...getHardCodedValues(nodeId), + inputs: formData, + }; + } else if (isMCPWithTool) { + // Separate credentials from tool arguments — credentials are stored + // at the top level of hardcodedValues, not inside tool_arguments. + const { credentials, ...toolArgs } = formData; + updatedValues = { + ...getHardCodedValues(nodeId), + tool_arguments: toolArgs, + ...(credentials?.id ? { credentials } : {}), + }; + } else { + updatedValues = formData; + } updateNodeData(nodeId, { hardcodedValues: updatedValues }); }; const hardcodedValues = getHardCodedValues(nodeId); - const initialValues = - uiType === BlockUIType.AGENT - ? (hardcodedValues.inputs ?? {}) - : hardcodedValues; + + let initialValues; + if (isAgent) { + initialValues = hardcodedValues.inputs ?? {}; + } else if (isMCPWithTool) { + // Merge tool arguments with credentials for the form + initialValues = { + ...(hardcodedValues.tool_arguments ?? {}), + ...(hardcodedValues.credentials?.id + ? { credentials: hardcodedValues.credentials } + : {}), + }; + } else { + initialValues = hardcodedValues; + } return (
; + availableTools: Record; + /** Credentials meta from OAuth flow, null for public servers. */ + credentials: CredentialsMetaInput | null; +}; + +interface MCPToolDialogProps { + open: boolean; + onClose: () => void; + onConfirm: (result: MCPToolDialogResult) => void; +} + +type DialogStep = "url" | "tool"; + +export function MCPToolDialog({ + open, + onClose, + onConfirm, +}: MCPToolDialogProps) { + const allProviders = useContext(CredentialsProvidersContext); + + const [step, setStep] = useState("url"); + const [serverUrl, setServerUrl] = useState(""); + const [tools, setTools] = useState([]); + const [serverName, setServerName] = useState(null); + const [loading, setLoading] = useState(false); + const [error, setError] = useState(null); + const [authRequired, setAuthRequired] = useState(false); + const [oauthLoading, setOauthLoading] = useState(false); + const [showManualToken, setShowManualToken] = useState(false); + const [manualToken, setManualToken] = useState(""); + const [selectedTool, setSelectedTool] = useState( + null, + ); + const [credentials, setCredentials] = useState( + null, + ); + + const startOAuthRef = useRef(false); + const oauthAbortRef = useRef<((reason?: string) => void) | null>(null); + + // Clean up on unmount + useEffect(() => { + return () => { + oauthAbortRef.current?.(); + }; + }, []); + + const reset = useCallback(() => { + oauthAbortRef.current?.(); + oauthAbortRef.current = null; + setStep("url"); + setServerUrl(""); + setManualToken(""); + setTools([]); + setServerName(null); + setLoading(false); + setError(null); + setAuthRequired(false); + setOauthLoading(false); + setShowManualToken(false); + setSelectedTool(null); + setCredentials(null); + }, []); + + const handleClose = useCallback(() => { + reset(); + onClose(); + }, [reset, onClose]); + + const discoverTools = useCallback(async (url: string, authToken?: string) => { + setLoading(true); + setError(null); + try { + const response = await postV2DiscoverAvailableToolsOnAnMcpServer({ + server_url: url, + auth_token: authToken || null, + }); + if (response.status !== 200) throw response.data; + setTools(response.data.tools); + setServerName(response.data.server_name ?? null); + setAuthRequired(false); + setShowManualToken(false); + setStep("tool"); + } catch (e: any) { + if (e?.status === 401 || e?.status === 403) { + setAuthRequired(true); + setError(null); + // Automatically start OAuth sign-in instead of requiring a second click + setLoading(false); + startOAuthRef.current = true; + return; + } else { + const message = + e?.message || e?.detail || "Failed to connect to MCP server"; + setError( + typeof message === "string" ? message : JSON.stringify(message), + ); + } + } finally { + setLoading(false); + } + }, []); + + const handleDiscoverTools = useCallback(() => { + if (!serverUrl.trim()) return; + discoverTools(serverUrl.trim(), manualToken.trim() || undefined); + }, [serverUrl, manualToken, discoverTools]); + + const handleOAuthSignIn = useCallback(async () => { + if (!serverUrl.trim()) return; + setError(null); + + // Abort any previous OAuth flow + oauthAbortRef.current?.(); + + setOauthLoading(true); + + try { + const loginResponse = await postV2InitiateOauthLoginForAnMcpServer({ + server_url: serverUrl.trim(), + }); + if (loginResponse.status !== 200) throw loginResponse.data; + const { login_url, state_token } = loginResponse.data; + + const { promise, cleanup } = openOAuthPopup(login_url, { + stateToken: state_token, + useCrossOriginListeners: true, + }); + oauthAbortRef.current = cleanup.abort; + + const result = await promise; + + // Exchange code for tokens via the credentials provider (updates cache) + setLoading(true); + setOauthLoading(false); + + const mcpProvider = allProviders?.["mcp"]; + let callbackResult; + if (mcpProvider) { + callbackResult = await mcpProvider.mcpOAuthCallback( + result.code, + state_token, + ); + } else { + const cbResponse = await postV2ExchangeOauthCodeForMcpTokens({ + code: result.code, + state_token, + }); + if (cbResponse.status !== 200) throw cbResponse.data; + callbackResult = cbResponse.data; + } + + setCredentials({ + id: callbackResult.id, + provider: callbackResult.provider, + type: callbackResult.type, + title: callbackResult.title, + }); + setAuthRequired(false); + + // Discover tools now that we're authenticated + const toolsResponse = await postV2DiscoverAvailableToolsOnAnMcpServer({ + server_url: serverUrl.trim(), + }); + if (toolsResponse.status !== 200) throw toolsResponse.data; + setTools(toolsResponse.data.tools); + setServerName(toolsResponse.data.server_name ?? null); + setStep("tool"); + } catch (e: any) { + // If server doesn't support OAuth → show manual token entry + if (e?.status === 400) { + setShowManualToken(true); + setError( + "This server does not support OAuth sign-in. Please enter a token manually.", + ); + } else if (e?.message === "OAuth flow timed out") { + setError("OAuth sign-in timed out. Please try again."); + } else { + const status = e?.status; + let message: string; + if (status === 401 || status === 403) { + message = + "Authentication succeeded but the server still rejected the request. " + + "The token audience may not match. Please try again."; + } else { + message = e?.message || e?.detail || "Failed to complete sign-in"; + } + setError( + typeof message === "string" ? message : JSON.stringify(message), + ); + } + } finally { + setOauthLoading(false); + setLoading(false); + oauthAbortRef.current = null; + } + }, [serverUrl, allProviders]); + + // Auto-start OAuth sign-in when server returns 401/403 + useEffect(() => { + if (authRequired && startOAuthRef.current) { + startOAuthRef.current = false; + handleOAuthSignIn(); + } + }, [authRequired, handleOAuthSignIn]); + + const handleConfirm = useCallback(() => { + if (!selectedTool) return; + + const availableTools: Record = {}; + for (const t of tools) { + availableTools[t.name] = { + description: t.description, + input_schema: t.input_schema, + }; + } + + onConfirm({ + serverUrl: serverUrl.trim(), + serverName, + selectedTool: selectedTool.name, + toolInputSchema: selectedTool.input_schema, + availableTools, + credentials, + }); + reset(); + }, [ + selectedTool, + tools, + serverUrl, + serverName, + credentials, + onConfirm, + reset, + ]); + + return ( + !isOpen && handleClose()}> + + + + {step === "url" + ? "Connect to MCP Server" + : `Select a Tool${serverName ? ` — ${serverName}` : ""}`} + + + {step === "url" + ? "Enter the URL of an MCP server to discover its available tools." + : `Found ${tools.length} tool${tools.length !== 1 ? "s" : ""}. Select one to add to your agent.`} + + + + {step === "url" && ( +
+
+ + setServerUrl(e.target.value)} + onKeyDown={(e) => e.key === "Enter" && handleDiscoverTools()} + autoFocus + /> +
+ + {/* Auth required: show manual token option */} + {authRequired && !showManualToken && ( + + )} + + {/* Manual token entry — only visible when expanded */} + {showManualToken && ( +
+ + setManualToken(e.target.value)} + onKeyDown={(e) => e.key === "Enter" && handleDiscoverTools()} + autoFocus + /> +
+ )} + + {error &&

{error}

} +
+ )} + + {step === "tool" && ( + +
+ {tools.map((tool) => ( + setSelectedTool(tool)} + /> + ))} +
+
+ )} + + + {step === "tool" && ( + + )} + + {step === "url" && ( + + )} + {step === "tool" && ( + + )} + +
+
+ ); +} + +// --------------- Tool Card Component --------------- // + +/** Truncate a description to a reasonable length for the collapsed view. */ +function truncateDescription(text: string, maxLen = 120): string { + if (text.length <= maxLen) return text; + return text.slice(0, maxLen).trimEnd() + "…"; +} + +/** Pretty-print a JSON Schema type for a parameter. */ +function schemaTypeLabel(schema: Record): string { + if (schema.type) return schema.type; + if (schema.anyOf) + return schema.anyOf.map((s: any) => s.type ?? "any").join(" | "); + if (schema.oneOf) + return schema.oneOf.map((s: any) => s.type ?? "any").join(" | "); + return "any"; +} + +function MCPToolCard({ + tool, + selected, + onSelect, +}: { + tool: MCPToolResponse; + selected: boolean; + onSelect: () => void; +}) { + const [expanded, setExpanded] = useState(false); + const schema = tool.input_schema as Record; + const properties = schema?.properties ?? {}; + const required = new Set(schema?.required ?? []); + const paramNames = Object.keys(properties); + + // Strip XML-like tags from description for cleaner display. + // Loop to handle nested tags like ipt> (CodeQL fix). + let cleanDescription = tool.description ?? ""; + let prev = ""; + while (prev !== cleanDescription) { + prev = cleanDescription; + cleanDescription = cleanDescription.replace(/<[^>]*>/g, ""); + } + cleanDescription = cleanDescription.trim(); + + return ( + + )} + + ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/Block.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/Block.tsx index 10f4fc8a44..07c6795808 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/Block.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/Block.tsx @@ -1,7 +1,7 @@ import { Button } from "@/components/__legacy__/ui/button"; import { Skeleton } from "@/components/__legacy__/ui/skeleton"; import { beautifyString, cn } from "@/lib/utils"; -import React, { ButtonHTMLAttributes } from "react"; +import React, { ButtonHTMLAttributes, useCallback, useState } from "react"; import { highlightText } from "./helpers"; import { PlusIcon } from "@phosphor-icons/react"; import { BlockInfo } from "@/app/api/__generated__/models/blockInfo"; @@ -9,6 +9,12 @@ import { useControlPanelStore } from "../../../stores/controlPanelStore"; import { blockDragPreviewStyle } from "./style"; import { useReactFlow } from "@xyflow/react"; import { useNodeStore } from "../../../stores/nodeStore"; +import { BlockUIType, SpecialBlockID } from "@/lib/autogpt-server-api"; +import { + MCPToolDialog, + type MCPToolDialogResult, +} from "@/app/(platform)/build/components/MCPToolDialog"; + interface Props extends ButtonHTMLAttributes { title?: string; description?: string; @@ -33,22 +39,86 @@ export const Block: BlockComponent = ({ ); const { setViewport } = useReactFlow(); const { addBlock } = useNodeStore(); + const [mcpDialogOpen, setMcpDialogOpen] = useState(false); + + const isMCPBlock = blockData.uiType === BlockUIType.MCP_TOOL; + + const addBlockAndCenter = useCallback( + (block: BlockInfo, hardcodedValues?: Record) => { + const customNode = addBlock(block, hardcodedValues); + setTimeout(() => { + setViewport( + { + x: -customNode.position.x * 0.8 + window.innerWidth / 2, + y: -customNode.position.y * 0.8 + (window.innerHeight - 400) / 2, + zoom: 0.8, + }, + { duration: 500 }, + ); + }, 50); + return customNode; + }, + [addBlock, setViewport], + ); + + const updateNodeData = useNodeStore((state) => state.updateNodeData); + + const handleMCPToolConfirm = useCallback( + (result: MCPToolDialogResult) => { + // Derive a display label: prefer server name, fall back to URL hostname. + let serverLabel = result.serverName; + if (!serverLabel) { + try { + serverLabel = new URL(result.serverUrl).hostname; + } catch { + serverLabel = "MCP"; + } + } + + const customNode = addBlockAndCenter(blockData, { + server_url: result.serverUrl, + server_name: serverLabel, + selected_tool: result.selectedTool, + tool_input_schema: result.toolInputSchema, + available_tools: result.availableTools, + credentials: result.credentials ?? undefined, + }); + if (customNode) { + const title = result.selectedTool + ? `${serverLabel}: ${beautifyString(result.selectedTool)}` + : undefined; + updateNodeData(customNode.id, { + metadata: { + ...customNode.data.metadata, + credentials_optional: true, + ...(title && { customized_name: title }), + }, + }); + } + setMcpDialogOpen(false); + }, + [addBlockAndCenter, blockData, updateNodeData], + ); const handleClick = () => { - const customNode = addBlock(blockData); - setTimeout(() => { - setViewport( - { - x: -customNode.position.x * 0.8 + window.innerWidth / 2, - y: -customNode.position.y * 0.8 + (window.innerHeight - 400) / 2, - zoom: 0.8, + if (isMCPBlock) { + setMcpDialogOpen(true); + return; + } + const customNode = addBlockAndCenter(blockData); + // Set customized_name for agent blocks so the agent's name persists + if (customNode && blockData.id === SpecialBlockID.AGENT) { + updateNodeData(customNode.id, { + metadata: { + ...customNode.data.metadata, + customized_name: blockData.name, }, - { duration: 500 }, - ); - }, 50); + }); + } }; const handleDragStart = (e: React.DragEvent) => { + if (isMCPBlock) return; e.dataTransfer.effectAllowed = "copy"; e.dataTransfer.setData("application/reactflow", JSON.stringify(blockData)); @@ -71,46 +141,56 @@ export const Block: BlockComponent = ({ : undefined; return ( -
- +
+ {title && ( + + {highlightText(beautifyString(title), highlightedText)} + + )} + {description && ( + + {highlightText(description, highlightedText)} + + )} +
+
+ +
+ + {isMCPBlock && ( + setMcpDialogOpen(false)} + onConfirm={handleMCPToolConfirm} + /> + )} + ); }; diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/types.ts b/autogpt_platform/frontend/src/app/(platform)/build/components/types.ts index 2fde427330..0f5021351d 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/types.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/types.ts @@ -9,4 +9,5 @@ export enum BlockUIType { AGENT = "Agent", AI = "AI", AYRSHARE = "Ayrshare", + MCP_TOOL = "MCP Tool", } diff --git a/autogpt_platform/frontend/src/app/api/openapi.json b/autogpt_platform/frontend/src/app/api/openapi.json index 8e48931540..63a8a856b9 100644 --- a/autogpt_platform/frontend/src/app/api/openapi.json +++ b/autogpt_platform/frontend/src/app/api/openapi.json @@ -4269,6 +4269,128 @@ } } }, + "/api/mcp/discover-tools": { + "post": { + "tags": ["v2", "mcp", "mcp"], + "summary": "Discover available tools on an MCP server", + "description": "Connect to an MCP server and return its available tools.\n\nIf the user has a stored MCP credential for this server URL, it will be\nused automatically — no need to pass an explicit auth token.", + "operationId": "postV2Discover available tools on an mcp server", + "requestBody": { + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/DiscoverToolsRequest" } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/DiscoverToolsResponse" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + }, + "security": [{ "HTTPBearerJWT": [] }] + } + }, + "/api/mcp/oauth/callback": { + "post": { + "tags": ["v2", "mcp", "mcp"], + "summary": "Exchange OAuth code for MCP tokens", + "description": "Exchange the authorization code for tokens and store the credential.\n\nThe frontend calls this after receiving the OAuth code from the popup.\nOn success, subsequent ``/discover-tools`` calls for the same server URL\nwill automatically use the stored credential.", + "operationId": "postV2Exchange oauth code for mcp tokens", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/MCPOAuthCallbackRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CredentialsMetaResponse" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + }, + "security": [{ "HTTPBearerJWT": [] }] + } + }, + "/api/mcp/oauth/login": { + "post": { + "tags": ["v2", "mcp", "mcp"], + "summary": "Initiate OAuth login for an MCP server", + "description": "Discover OAuth metadata from the MCP server and return a login URL.\n\n1. Discovers the protected-resource metadata (RFC 9728)\n2. Fetches the authorization server metadata (RFC 8414)\n3. Performs Dynamic Client Registration (RFC 7591) if available\n4. Returns the authorization URL for the frontend to open in a popup", + "operationId": "postV2Initiate oauth login for an mcp server", + "requestBody": { + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/MCPOAuthLoginRequest" } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/MCPOAuthLoginResponse" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + }, + "security": [{ "HTTPBearerJWT": [] }] + } + }, "/api/oauth/app/{client_id}": { "get": { "tags": ["oauth"], @@ -7691,7 +7813,7 @@ "host": { "anyOf": [{ "type": "string" }, { "type": "null" }], "title": "Host", - "description": "Host pattern for host-scoped credentials" + "description": "Host pattern for host-scoped or MCP server URL for MCP credentials" } }, "type": "object", @@ -7711,6 +7833,45 @@ "required": ["version_counts"], "title": "DeleteGraphResponse" }, + "DiscoverToolsRequest": { + "properties": { + "server_url": { + "type": "string", + "title": "Server Url", + "description": "URL of the MCP server" + }, + "auth_token": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Auth Token", + "description": "Optional Bearer token for authenticated MCP servers" + } + }, + "type": "object", + "required": ["server_url"], + "title": "DiscoverToolsRequest", + "description": "Request to discover tools on an MCP server." + }, + "DiscoverToolsResponse": { + "properties": { + "tools": { + "items": { "$ref": "#/components/schemas/MCPToolResponse" }, + "type": "array", + "title": "Tools" + }, + "server_name": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Server Name" + }, + "protocol_version": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Protocol Version" + } + }, + "type": "object", + "required": ["tools"], + "title": "DiscoverToolsResponse", + "description": "Response containing the list of tools available on an MCP server." + }, "DocPageResponse": { "properties": { "type": { @@ -9287,6 +9448,62 @@ "required": ["login_url", "state_token"], "title": "LoginResponse" }, + "MCPOAuthCallbackRequest": { + "properties": { + "code": { + "type": "string", + "title": "Code", + "description": "Authorization code from OAuth callback" + }, + "state_token": { + "type": "string", + "title": "State Token", + "description": "State token for CSRF verification" + } + }, + "type": "object", + "required": ["code", "state_token"], + "title": "MCPOAuthCallbackRequest", + "description": "Request to exchange an OAuth code for tokens." + }, + "MCPOAuthLoginRequest": { + "properties": { + "server_url": { + "type": "string", + "title": "Server Url", + "description": "URL of the MCP server that requires OAuth" + } + }, + "type": "object", + "required": ["server_url"], + "title": "MCPOAuthLoginRequest", + "description": "Request to start an OAuth flow for an MCP server." + }, + "MCPOAuthLoginResponse": { + "properties": { + "login_url": { "type": "string", "title": "Login Url" }, + "state_token": { "type": "string", "title": "State Token" } + }, + "type": "object", + "required": ["login_url", "state_token"], + "title": "MCPOAuthLoginResponse", + "description": "Response with the OAuth login URL for the user to authenticate." + }, + "MCPToolResponse": { + "properties": { + "name": { "type": "string", "title": "Name" }, + "description": { "type": "string", "title": "Description" }, + "input_schema": { + "additionalProperties": true, + "type": "object", + "title": "Input Schema" + } + }, + "type": "object", + "required": ["name", "description", "input_schema"], + "title": "MCPToolResponse", + "description": "A single MCP tool returned by discovery." + }, "MarketplaceListing": { "properties": { "id": { "type": "string", "title": "Id" }, diff --git a/autogpt_platform/frontend/src/components/contextual/CredentialsInput/components/CredentialsGroupedView/CredentialsGroupedView.tsx b/autogpt_platform/frontend/src/components/contextual/CredentialsInput/components/CredentialsGroupedView/CredentialsGroupedView.tsx index 135a960431..22d0a318a9 100644 --- a/autogpt_platform/frontend/src/components/contextual/CredentialsInput/components/CredentialsGroupedView/CredentialsGroupedView.tsx +++ b/autogpt_platform/frontend/src/components/contextual/CredentialsInput/components/CredentialsGroupedView/CredentialsGroupedView.tsx @@ -38,13 +38,8 @@ export function CredentialsGroupedView({ const allProviders = useContext(CredentialsProvidersContext); const { userCredentialFields, systemCredentialFields } = useMemo( - () => - splitCredentialFieldsBySystem( - credentialFields, - allProviders, - inputCredentials, - ), - [credentialFields, allProviders, inputCredentials], + () => splitCredentialFieldsBySystem(credentialFields, allProviders), + [credentialFields, allProviders], ); const hasSystemCredentials = systemCredentialFields.length > 0; @@ -86,11 +81,13 @@ export function CredentialsGroupedView({ const providerNames = schema.credentials_provider || []; const credentialTypes = schema.credentials_types || []; const requiredScopes = schema.credentials_scopes; + const discriminatorValues = schema.discriminator_values; const savedCredential = findSavedCredentialByProviderAndType( providerNames, credentialTypes, requiredScopes, allProviders, + discriminatorValues, ); if (savedCredential) { diff --git a/autogpt_platform/frontend/src/components/contextual/CredentialsInput/components/CredentialsGroupedView/helpers.ts b/autogpt_platform/frontend/src/components/contextual/CredentialsInput/components/CredentialsGroupedView/helpers.ts index 5f439d3a32..2d8d001a72 100644 --- a/autogpt_platform/frontend/src/components/contextual/CredentialsInput/components/CredentialsGroupedView/helpers.ts +++ b/autogpt_platform/frontend/src/components/contextual/CredentialsInput/components/CredentialsGroupedView/helpers.ts @@ -23,10 +23,35 @@ function hasRequiredScopes( return true; } +/** Check if a credential matches the discriminator values (e.g. MCP server URL). */ +function matchesDiscriminatorValues( + credential: { host?: string | null; provider: string; type: string }, + discriminatorValues?: string[], +) { + // MCP OAuth2 credentials must match by server URL + if (credential.type === "oauth2" && credential.provider === "mcp") { + if (!discriminatorValues || discriminatorValues.length === 0) return false; + return ( + credential.host != null && discriminatorValues.includes(credential.host) + ); + } + // Host-scoped credentials match by host + if (credential.type === "host_scoped" && credential.host) { + if (!discriminatorValues || discriminatorValues.length === 0) return true; + return discriminatorValues.some((v) => { + try { + return new URL(v).hostname === credential.host; + } catch { + return false; + } + }); + } + return true; +} + export function splitCredentialFieldsBySystem( credentialFields: CredentialField[], allProviders: CredentialsProvidersContextType | null, - inputCredentials?: Record, ) { if (!allProviders || credentialFields.length === 0) { return { @@ -52,17 +77,9 @@ export function splitCredentialFieldsBySystem( } } - const sortByUnsetFirst = (a: CredentialField, b: CredentialField) => { - const aIsSet = Boolean(inputCredentials?.[a[0]]); - const bIsSet = Boolean(inputCredentials?.[b[0]]); - - if (aIsSet === bIsSet) return 0; - return aIsSet ? 1 : -1; - }; - return { - userCredentialFields: userFields.sort(sortByUnsetFirst), - systemCredentialFields: systemFields.sort(sortByUnsetFirst), + userCredentialFields: userFields, + systemCredentialFields: systemFields, }; } @@ -160,6 +177,7 @@ export function findSavedCredentialByProviderAndType( credentialTypes: string[], requiredScopes: string[] | undefined, allProviders: CredentialsProvidersContextType | null, + discriminatorValues?: string[], ): SavedCredential | undefined { for (const providerName of providerNames) { const providerData = allProviders?.[providerName]; @@ -176,9 +194,14 @@ export function findSavedCredentialByProviderAndType( credentialTypes.length === 0 || credentialTypes.includes(credential.type); const scopesMatch = hasRequiredScopes(credential, requiredScopes); + const hostMatches = matchesDiscriminatorValues( + credential, + discriminatorValues, + ); if (!typeMatches) continue; if (!scopesMatch) continue; + if (!hostMatches) continue; matchingCredentials.push(credential as SavedCredential); } @@ -190,9 +213,14 @@ export function findSavedCredentialByProviderAndType( credentialTypes.length === 0 || credentialTypes.includes(credential.type); const scopesMatch = hasRequiredScopes(credential, requiredScopes); + const hostMatches = matchesDiscriminatorValues( + credential, + discriminatorValues, + ); if (!typeMatches) continue; if (!scopesMatch) continue; + if (!hostMatches) continue; matchingCredentials.push(credential as SavedCredential); } @@ -214,6 +242,7 @@ export function findSavedUserCredentialByProviderAndType( credentialTypes: string[], requiredScopes: string[] | undefined, allProviders: CredentialsProvidersContextType | null, + discriminatorValues?: string[], ): SavedCredential | undefined { for (const providerName of providerNames) { const providerData = allProviders?.[providerName]; @@ -230,9 +259,14 @@ export function findSavedUserCredentialByProviderAndType( credentialTypes.length === 0 || credentialTypes.includes(credential.type); const scopesMatch = hasRequiredScopes(credential, requiredScopes); + const hostMatches = matchesDiscriminatorValues( + credential, + discriminatorValues, + ); if (!typeMatches) continue; if (!scopesMatch) continue; + if (!hostMatches) continue; matchingCredentials.push(credential as SavedCredential); } diff --git a/autogpt_platform/frontend/src/components/contextual/CredentialsInput/useCredentialsInput.ts b/autogpt_platform/frontend/src/components/contextual/CredentialsInput/useCredentialsInput.ts index 509713ff1e..9ab2e08141 100644 --- a/autogpt_platform/frontend/src/components/contextual/CredentialsInput/useCredentialsInput.ts +++ b/autogpt_platform/frontend/src/components/contextual/CredentialsInput/useCredentialsInput.ts @@ -5,14 +5,14 @@ import { BlockIOCredentialsSubSchema, CredentialsMetaInput, } from "@/lib/autogpt-server-api/types"; +import { postV2InitiateOauthLoginForAnMcpServer } from "@/app/api/__generated__/endpoints/mcp/mcp"; +import { openOAuthPopup } from "@/lib/oauth-popup"; import { useQueryClient } from "@tanstack/react-query"; import { useEffect, useRef, useState } from "react"; import { filterSystemCredentials, getActionButtonText, getSystemCredentials, - OAUTH_TIMEOUT_MS, - OAuthPopupResultMessage, } from "./helpers"; export type CredentialsInputState = ReturnType; @@ -57,6 +57,14 @@ export function useCredentialsInput({ const queryClient = useQueryClient(); const credentials = useCredentials(schema, siblingInputs); const hasAttemptedAutoSelect = useRef(false); + const oauthAbortRef = useRef<((reason?: string) => void) | null>(null); + + // Clean up on unmount + useEffect(() => { + return () => { + oauthAbortRef.current?.(); + }; + }, []); const deleteCredentialsMutation = useDeleteV1DeleteCredentials({ mutation: { @@ -81,11 +89,14 @@ export function useCredentialsInput({ } }, [credentials, onLoaded]); - // Unselect credential if not available + // Unselect credential if not available in the loaded credential list. + // Skip when no credentials have been loaded yet (empty list could mean + // the provider data hasn't finished loading, not that the credential is invalid). useEffect(() => { if (readOnly) return; if (!credentials || !("savedCredentials" in credentials)) return; const availableCreds = credentials.savedCredentials; + if (availableCreds.length === 0) return; if ( selectedCredential && !availableCreds.some((c) => c.id === selectedCredential.id) @@ -110,7 +121,9 @@ export function useCredentialsInput({ if (hasAttemptedAutoSelect.current) return; hasAttemptedAutoSelect.current = true; - if (isOptional) return; + // Auto-select if exactly one credential matches. + // For optional fields with multiple options, let the user choose. + if (isOptional && savedCreds.length > 1) return; const cred = savedCreds[0]; onSelectCredential({ @@ -148,7 +161,9 @@ export function useCredentialsInput({ supportsHostScoped, savedCredentials, oAuthCallback, + mcpOAuthCallback, isSystemProvider, + discriminatorValue, } = credentials; // Split credentials into user and system @@ -157,72 +172,66 @@ export function useCredentialsInput({ async function handleOAuthLogin() { setOAuthError(null); - const { login_url, state_token } = await api.oAuthLogin( - provider, - schema.credentials_scopes, - ); - setOAuth2FlowInProgress(true); - const popup = window.open(login_url, "_blank", "popup=true"); - if (!popup) { - throw new Error( - "Failed to open popup window. Please allow popups for this site.", + // Abort any previous OAuth flow + oauthAbortRef.current?.(); + + // MCP uses dynamic OAuth discovery per server URL + const isMCP = provider === "mcp" && !!discriminatorValue; + + try { + let login_url: string; + let state_token: string; + + if (isMCP) { + const mcpLoginResponse = await postV2InitiateOauthLoginForAnMcpServer({ + server_url: discriminatorValue!, + }); + if (mcpLoginResponse.status !== 200) throw mcpLoginResponse.data; + ({ login_url, state_token } = mcpLoginResponse.data); + } else { + ({ login_url, state_token } = await api.oAuthLogin( + provider, + schema.credentials_scopes, + )); + } + + setOAuth2FlowInProgress(true); + + const { promise, cleanup } = openOAuthPopup(login_url, { + stateToken: state_token, + useCrossOriginListeners: isMCP, + // Standard OAuth uses "oauth_popup_result", MCP uses "mcp_oauth_result" + acceptMessageTypes: isMCP + ? ["mcp_oauth_result"] + : ["oauth_popup_result"], + }); + + oauthAbortRef.current = cleanup.abort; + // Expose abort signal for the waiting modal's cancel button + const controller = new AbortController(); + cleanup.signal.addEventListener("abort", () => + controller.abort("completed"), ); - } + setOAuthPopupController(controller); - const controller = new AbortController(); - setOAuthPopupController(controller); - controller.signal.onabort = () => { - console.debug("OAuth flow aborted"); - setOAuth2FlowInProgress(false); - popup.close(); - }; + const result = await promise; - const handleMessage = async (e: MessageEvent) => { - console.debug("Message received:", e.data); - if ( - typeof e.data != "object" || - !("message_type" in e.data) || - e.data.message_type !== "oauth_popup_result" - ) { - console.debug("Ignoring irrelevant message"); - return; - } + // Exchange code for tokens via the provider (updates credential cache) + const credentialResult = isMCP + ? await mcpOAuthCallback(result.code, state_token) + : await oAuthCallback(result.code, result.state); - if (!e.data.success) { - console.error("OAuth flow failed:", e.data.message); - setOAuthError(`OAuth flow failed: ${e.data.message}`); - setOAuth2FlowInProgress(false); - return; - } - - if (e.data.state !== state_token) { - console.error("Invalid state token received"); - setOAuthError("Invalid state token received"); - setOAuth2FlowInProgress(false); - return; - } - - try { - console.debug("Processing OAuth callback"); - const credentials = await oAuthCallback(e.data.code, e.data.state); - console.debug("OAuth callback processed successfully"); - - // Check if the credential's scopes match the required scopes + // Check if the credential's scopes match the required scopes (skip for MCP) + if (!isMCP) { const requiredScopes = schema.credentials_scopes; if (requiredScopes && requiredScopes.length > 0) { - const grantedScopes = new Set(credentials.scopes || []); + const grantedScopes = new Set(credentialResult.scopes || []); const hasAllRequiredScopes = new Set(requiredScopes).isSubsetOf( grantedScopes, ); if (!hasAllRequiredScopes) { - console.error( - `Newly created OAuth credential for ${providerName} has insufficient scopes. Required:`, - requiredScopes, - "Granted:", - credentials.scopes, - ); setOAuthError( "Connection failed: the granted permissions don't match what's required. " + "Please contact the application administrator.", @@ -230,38 +239,28 @@ export function useCredentialsInput({ return; } } + } - onSelectCredential({ - id: credentials.id, - type: "oauth2", - title: credentials.title, - provider, - }); - } catch (error) { - console.error("Error in OAuth callback:", error); + onSelectCredential({ + id: credentialResult.id, + type: "oauth2", + title: credentialResult.title, + provider, + }); + } catch (error) { + if (error instanceof Error && error.message === "OAuth flow timed out") { + setOAuthError("OAuth flow timed out"); + } else { setOAuthError( - `Error in OAuth callback: ${ + `OAuth error: ${ error instanceof Error ? error.message : String(error) }`, ); - } finally { - console.debug("Finalizing OAuth flow"); - setOAuth2FlowInProgress(false); - controller.abort("success"); } - }; - - console.debug("Adding message event listener"); - window.addEventListener("message", handleMessage, { - signal: controller.signal, - }); - - setTimeout(() => { - console.debug("OAuth flow timed out"); - controller.abort("timeout"); + } finally { setOAuth2FlowInProgress(false); - setOAuthError("OAuth flow timed out"); - }, OAUTH_TIMEOUT_MS); + oauthAbortRef.current = null; + } } function handleActionButtonClick() { diff --git a/autogpt_platform/frontend/src/hooks/useCredentials.ts b/autogpt_platform/frontend/src/hooks/useCredentials.ts index eda6ab0278..9a78e5b8f4 100644 --- a/autogpt_platform/frontend/src/hooks/useCredentials.ts +++ b/autogpt_platform/frontend/src/hooks/useCredentials.ts @@ -100,6 +100,11 @@ export default function useCredentials( return false; } + // Filter MCP OAuth2 credentials by server URL matching + if (c.type === "oauth2" && c.provider === "mcp") { + return discriminatorValue != null && c.host === discriminatorValue; + } + // Filter by OAuth credentials that have sufficient scopes for this block if (c.type === "oauth2") { const requiredScopes = credsInputSchema.credentials_scopes; diff --git a/autogpt_platform/frontend/src/lib/autogpt-server-api/types.ts b/autogpt_platform/frontend/src/lib/autogpt-server-api/types.ts index 65625f1cfb..ffc21269e6 100644 --- a/autogpt_platform/frontend/src/lib/autogpt-server-api/types.ts +++ b/autogpt_platform/frontend/src/lib/autogpt-server-api/types.ts @@ -749,10 +749,12 @@ export enum BlockUIType { AGENT = "Agent", AI = "AI", AYRSHARE = "Ayrshare", + MCP_TOOL = "MCP Tool", } export enum SpecialBlockID { AGENT = "e189baac-8c20-45a1-94a7-55177ea42565", + MCP_TOOL = "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4", SMART_DECISION = "3b191d9f-356f-482d-8238-ba04b6d18381", OUTPUT = "363ae599-353e-4804-937e-b2ee3cef3da4", } diff --git a/autogpt_platform/frontend/src/lib/oauth-popup.ts b/autogpt_platform/frontend/src/lib/oauth-popup.ts new file mode 100644 index 0000000000..2927887751 --- /dev/null +++ b/autogpt_platform/frontend/src/lib/oauth-popup.ts @@ -0,0 +1,177 @@ +/** + * Shared utility for OAuth popup flows with cross-origin support. + * + * Handles BroadcastChannel, postMessage, and localStorage polling + * to reliably receive OAuth callback results even when COOP headers + * sever the window.opener relationship. + */ + +const DEFAULT_TIMEOUT_MS = 5 * 60 * 1000; // 5 minutes + +export type OAuthPopupResult = { + code: string; + state: string; +}; + +export type OAuthPopupOptions = { + /** State token to validate against incoming messages */ + stateToken: string; + /** + * Use BroadcastChannel + localStorage polling for cross-origin OAuth (MCP). + * Standard OAuth only uses postMessage via window.opener. + */ + useCrossOriginListeners?: boolean; + /** BroadcastChannel name (default: "mcp_oauth") */ + broadcastChannelName?: string; + /** localStorage key for cross-origin fallback (default: "mcp_oauth_result") */ + localStorageKey?: string; + /** Message types to accept (default: ["oauth_popup_result", "mcp_oauth_result"]) */ + acceptMessageTypes?: string[]; + /** Timeout in ms (default: 5 minutes) */ + timeout?: number; +}; + +type Cleanup = { + /** Abort the OAuth flow and close the popup */ + abort: (reason?: string) => void; + /** The AbortController signal */ + signal: AbortSignal; +}; + +/** + * Opens an OAuth popup and sets up listeners for the callback result. + * + * Opens a blank popup synchronously (to avoid popup blockers), then navigates + * it to the login URL. Returns a promise that resolves with the OAuth code/state. + * + * @param loginUrl - The OAuth authorization URL to navigate to + * @param options - Configuration for message handling + * @returns Object with `promise` (resolves with OAuth result) and `abort` (cancels flow) + */ +export function openOAuthPopup( + loginUrl: string, + options: OAuthPopupOptions, +): { promise: Promise; cleanup: Cleanup } { + const { + stateToken, + useCrossOriginListeners = false, + broadcastChannelName = "mcp_oauth", + localStorageKey = "mcp_oauth_result", + acceptMessageTypes = ["oauth_popup_result", "mcp_oauth_result"], + timeout = DEFAULT_TIMEOUT_MS, + } = options; + + const controller = new AbortController(); + + // Open popup synchronously (before any async work) to avoid browser popup blockers + const width = 500; + const height = 700; + const left = window.screenX + (window.outerWidth - width) / 2; + const top = window.screenY + (window.outerHeight - height) / 2; + const popup = window.open( + "about:blank", + "_blank", + `width=${width},height=${height},left=${left},top=${top},popup=true,scrollbars=yes`, + ); + + if (popup && !popup.closed) { + popup.location.href = loginUrl; + } else { + // Popup was blocked — open in new tab as fallback + window.open(loginUrl, "_blank"); + } + + // Close popup on abort + controller.signal.addEventListener("abort", () => { + if (popup && !popup.closed) popup.close(); + }); + + // Clear any stale localStorage entry + if (useCrossOriginListeners) { + try { + localStorage.removeItem(localStorageKey); + } catch {} + } + + const promise = new Promise((resolve, reject) => { + let handled = false; + + const handleResult = (data: any) => { + if (handled) return; // Prevent double-handling + + // Validate message type + const messageType = data?.message_type ?? data?.type; + if (!messageType || !acceptMessageTypes.includes(messageType)) return; + + // Validate state token + if (data.state !== stateToken) { + // State mismatch — this message is for a different listener. Ignore silently. + return; + } + + handled = true; + + if (!data.success) { + reject(new Error(data.message || "OAuth authentication failed")); + } else { + resolve({ code: data.code, state: data.state }); + } + + controller.abort("completed"); + }; + + // Listener: postMessage (works for same-origin popups) + window.addEventListener( + "message", + (event: MessageEvent) => { + if (typeof event.data === "object") { + handleResult(event.data); + } + }, + { signal: controller.signal }, + ); + + // Cross-origin listeners for MCP OAuth + if (useCrossOriginListeners) { + // Listener: BroadcastChannel (works across tabs/popups without opener) + try { + const bc = new BroadcastChannel(broadcastChannelName); + bc.onmessage = (event) => handleResult(event.data); + controller.signal.addEventListener("abort", () => bc.close()); + } catch {} + + // Listener: localStorage polling (most reliable cross-tab fallback) + const pollInterval = setInterval(() => { + try { + const stored = localStorage.getItem(localStorageKey); + if (stored) { + const data = JSON.parse(stored); + localStorage.removeItem(localStorageKey); + handleResult(data); + } + } catch {} + }, 500); + controller.signal.addEventListener("abort", () => + clearInterval(pollInterval), + ); + } + + // Timeout + const timeoutId = setTimeout(() => { + if (!handled) { + handled = true; + reject(new Error("OAuth flow timed out")); + controller.abort("timeout"); + } + }, timeout); + controller.signal.addEventListener("abort", () => clearTimeout(timeoutId)); + }); + + return { + promise, + cleanup: { + abort: (reason?: string) => controller.abort(reason || "canceled"), + signal: controller.signal, + }, + }; +} diff --git a/autogpt_platform/frontend/src/middleware.ts b/autogpt_platform/frontend/src/middleware.ts index af1c823295..8cec8a2645 100644 --- a/autogpt_platform/frontend/src/middleware.ts +++ b/autogpt_platform/frontend/src/middleware.ts @@ -18,6 +18,6 @@ export const config = { * Note: /auth/authorize and /auth/integrations/* ARE protected and need * middleware to run for authentication checks. */ - "/((?!_next/static|_next/image|favicon.ico|auth/callback|.*\\.(?:svg|png|jpg|jpeg|gif|webp)$).*)", + "/((?!_next/static|_next/image|favicon.ico|auth/callback|auth/integrations/mcp_callback|.*\\.(?:svg|png|jpg|jpeg|gif|webp)$).*)", ], }; diff --git a/autogpt_platform/frontend/src/providers/agent-credentials/credentials-provider.tsx b/autogpt_platform/frontend/src/providers/agent-credentials/credentials-provider.tsx index e47cc65e13..a426d8f667 100644 --- a/autogpt_platform/frontend/src/providers/agent-credentials/credentials-provider.tsx +++ b/autogpt_platform/frontend/src/providers/agent-credentials/credentials-provider.tsx @@ -8,6 +8,7 @@ import { HostScopedCredentials, UserPasswordCredentials, } from "@/lib/autogpt-server-api"; +import { postV2ExchangeOauthCodeForMcpTokens } from "@/app/api/__generated__/endpoints/mcp/mcp"; import { useBackendAPI } from "@/lib/autogpt-server-api/context"; import { useSupabase } from "@/lib/supabase/hooks/useSupabase"; import { toDisplayName } from "@/providers/agent-credentials/helper"; @@ -38,6 +39,11 @@ export type CredentialsProviderData = { code: string, state_token: string, ) => Promise; + /** MCP-specific OAuth callback that uses dynamic per-server OAuth discovery. */ + mcpOAuthCallback: ( + code: string, + state_token: string, + ) => Promise; createAPIKeyCredentials: ( credentials: APIKeyCredentialsCreatable, ) => Promise; @@ -120,6 +126,35 @@ export default function CredentialsProvider({ [api, addCredentials, onFailToast], ); + /** Exchanges an MCP OAuth code for tokens and adds the result to the internal credentials store. */ + const mcpOAuthCallback = useCallback( + async ( + code: string, + state_token: string, + ): Promise => { + try { + const response = await postV2ExchangeOauthCodeForMcpTokens({ + code, + state_token, + }); + if (response.status !== 200) throw response.data; + const credsMeta: CredentialsMetaResponse = { + ...response.data, + title: response.data.title ?? undefined, + scopes: response.data.scopes ?? undefined, + username: response.data.username ?? undefined, + host: response.data.host ?? undefined, + }; + addCredentials("mcp", credsMeta); + return credsMeta; + } catch (error) { + onFailToast("complete MCP OAuth authentication")(error); + throw error; + } + }, + [addCredentials, onFailToast], + ); + /** Wraps `BackendAPI.createAPIKeyCredentials`, and adds the result to the internal credentials store. */ const createAPIKeyCredentials = useCallback( async ( @@ -258,6 +293,7 @@ export default function CredentialsProvider({ isSystemProvider: systemProviders.has(provider), oAuthCallback: (code: string, state_token: string) => oAuthCallback(provider, code, state_token), + mcpOAuthCallback, createAPIKeyCredentials: ( credentials: APIKeyCredentialsCreatable, ) => createAPIKeyCredentials(provider, credentials), @@ -286,6 +322,7 @@ export default function CredentialsProvider({ createHostScopedCredentials, deleteCredentials, oAuthCallback, + mcpOAuthCallback, onFailToast, ]); diff --git a/autogpt_platform/frontend/src/tests/pages/build.page.ts b/autogpt_platform/frontend/src/tests/pages/build.page.ts index 9370288f8e..3bb9552b82 100644 --- a/autogpt_platform/frontend/src/tests/pages/build.page.ts +++ b/autogpt_platform/frontend/src/tests/pages/build.page.ts @@ -528,6 +528,9 @@ export class BuildPage extends BasePage { async getBlocksToSkip(): Promise { return [ (await this.getGithubTriggerBlockDetails()).map((b) => b.id), + // MCP Tool block requires an interactive dialog (server URL + OAuth) before + // it can be placed, so it can't be tested via the standard "add block" flow. + "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4", ].flat(); } diff --git a/docs/integrations/README.md b/docs/integrations/README.md index a471ef3533..c216aa4836 100644 --- a/docs/integrations/README.md +++ b/docs/integrations/README.md @@ -467,6 +467,7 @@ Below is a comprehensive list of all available blocks, categorized by their prim | [Github Update Comment](block-integrations/github/issues.md#github-update-comment) | A block that updates an existing comment on a GitHub issue or pull request | | [Github Update File](block-integrations/github/repo.md#github-update-file) | This block updates an existing file in a GitHub repository | | [Instantiate Code Sandbox](block-integrations/misc.md#instantiate-code-sandbox) | Instantiate a sandbox environment with internet access in which you can execute code with the Execute Code Step block | +| [MCP Tool](block-integrations/mcp/block.md#mcp-tool) | Connect to any MCP server and execute its tools | | [Slant3D Order Webhook](block-integrations/slant3d/webhook.md#slant3d-order-webhook) | This block triggers on Slant3D order status updates and outputs the event details, including tracking information when orders are shipped | ## Media Generation diff --git a/docs/integrations/SUMMARY.md b/docs/integrations/SUMMARY.md index f481ae2e0a..3ad4bf2c6d 100644 --- a/docs/integrations/SUMMARY.md +++ b/docs/integrations/SUMMARY.md @@ -84,6 +84,7 @@ * [Linear Projects](block-integrations/linear/projects.md) * [LLM](block-integrations/llm.md) * [Logic](block-integrations/logic.md) +* [Mcp Block](block-integrations/mcp/block.md) * [Misc](block-integrations/misc.md) * [Notion Create Page](block-integrations/notion/create_page.md) * [Notion Read Database](block-integrations/notion/read_database.md) diff --git a/docs/integrations/block-integrations/mcp/block.md b/docs/integrations/block-integrations/mcp/block.md new file mode 100644 index 0000000000..6858e42e94 --- /dev/null +++ b/docs/integrations/block-integrations/mcp/block.md @@ -0,0 +1,40 @@ +# Mcp Block + +Blocks for connecting to and executing tools on MCP (Model Context Protocol) servers. + + +## MCP Tool + +### What it is +Connect to any MCP server and execute its tools. Provide a server URL, select a tool, and pass arguments dynamically. + +### How it works + +The block uses JSON-RPC 2.0 over HTTP to communicate with MCP servers. When configuring, it sends an `initialize` request followed by `tools/list` to discover available tools and their input schemas. On execution, it calls `tools/call` with the selected tool name and arguments, then extracts text, image, or resource content from the response. + +Authentication is handled via OAuth 2.0 when the server requires it. The block supports optional credentials — public servers work without authentication, while protected servers trigger a standard OAuth flow with PKCE. Tokens are automatically refreshed when they expire. + + +### Inputs + +| Input | Description | Type | Required | +|-------|-------------|------|----------| +| server_url | URL of the MCP server (Streamable HTTP endpoint) | str | Yes | +| selected_tool | The MCP tool to execute | str | No | +| tool_arguments | Arguments to pass to the selected MCP tool. The fields here are defined by the tool's input schema. | Dict[str, Any] | No | + +### Outputs + +| Output | Description | Type | +|--------|-------------|------| +| error | Error message if the tool call failed | str | +| result | The result returned by the MCP tool | Result | + +### Possible use case + +- **Connecting to third-party APIs**: Use an MCP server like Sentry or Linear to query issues, create tickets, or manage projects without building custom integrations. +- **AI-powered tool execution**: Chain MCP tool calls with AI blocks to let agents dynamically discover and use external tools based on task requirements. +- **Data retrieval from knowledge bases**: Connect to MCP servers like DeepWiki to search documentation, retrieve code context, or query structured knowledge bases. + + +--- From ca216dfd7f91ef1e79c3129879b31a8a1d2b79a9 Mon Sep 17 00:00:00 2001 From: Otto Date: Fri, 13 Feb 2026 16:46:23 +0000 Subject: [PATCH 12/16] ci(docs-claude-review): Update comments instead of creating new ones (#12106) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Changes 🏗️ This PR updates the Claude Block Docs Review CI workflow to update existing comments instead of creating new ones on each push. ### What's Changed: 1. **Concurrency group** - Prevents race conditions if the workflow runs twice simultaneously 2. **Comment cleanup step** - Deletes any previous Claude review comment before posting a new one 3. **Marker instruction** - Instructs Claude to include a `` marker in its comment for identification ### Why: Previously, every PR push would create a new review comment, cluttering the PR with multiple comments. Now only the most recent review is shown. ### Testing: 1. Create a PR that triggers this workflow (modify a file in `docs/integrations/` or `autogpt_platform/backend/backend/blocks/`) 2. Verify first run creates comment with marker 3. Push another commit 4. Verify old comment is deleted and new comment is created (not accumulated) Requested by @Bentlybro --- ## Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [ ] I have made a test plan - [ ] I have tested my changes according to the test plan (will be tested on merge) #### For configuration changes: - [x] `.env.default` is updated or already compatible with my changes - [x] `docker-compose.yml` is updated or already compatible with my changes - [x] I have included a list of my configuration changes in the PR description (under **Changes**)

Greptile Overview

Greptile Summary

Added concurrency control and comment deduplication to prevent multiple Claude review comments from accumulating on PRs. The workflow now deletes previous review comments (identified by `` marker) before posting new ones, and uses concurrency groups to prevent race conditions.

Confidence Score: 5/5

- This PR is safe to merge with minimal risk - The changes are well-contained, follow GitHub Actions best practices, and use built-in GitHub APIs safely. The concurrency control prevents race conditions, and the comment cleanup logic uses proper filtering with `head -1` to handle edge cases. The HTML comment marker approach is standard and reliable. - No files require special attention

Sequence Diagram

```mermaid sequenceDiagram participant GH as GitHub PR Event participant WF as Workflow participant API as GitHub API participant Claude as Claude Action GH->>WF: PR opened/synchronized WF->>WF: Check concurrency group Note over WF: Cancel any in-progress runs
for same PR number WF->>API: Query PR comments API-->>WF: Return all comments WF->>WF: Filter for CLAUDE_DOCS_REVIEW marker alt Previous comment exists WF->>API: DELETE comment by ID API-->>WF: Comment deleted else No previous comment WF->>WF: Skip deletion end WF->>Claude: Run code review Claude->>API: POST new comment with marker API-->>Claude: Comment created ```
Last reviewed commit: fb1b436 --- .github/workflows/docs-claude-review.yml | 34 ++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/.github/workflows/docs-claude-review.yml b/.github/workflows/docs-claude-review.yml index ca2788b387..19d5dd667b 100644 --- a/.github/workflows/docs-claude-review.yml +++ b/.github/workflows/docs-claude-review.yml @@ -7,6 +7,10 @@ on: - "docs/integrations/**" - "autogpt_platform/backend/backend/blocks/**" +concurrency: + group: claude-docs-review-${{ github.event.pull_request.number }} + cancel-in-progress: true + jobs: claude-review: # Only run for PRs from members/collaborators @@ -91,5 +95,35 @@ jobs: 3. Read corresponding documentation files to verify accuracy 4. Provide your feedback as a PR comment + ## IMPORTANT: Comment Marker + Start your PR comment with exactly this HTML comment marker on its own line: + + + This marker is used to identify and replace your comment on subsequent runs. + Be constructive and specific. If everything looks good, say so! If there are issues, explain what's wrong and suggest how to fix it. + + - name: Delete old Claude review comments + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + # Get all comment IDs with our marker, sorted by creation date (oldest first) + COMMENT_IDS=$(gh api \ + repos/${{ github.repository }}/issues/${{ github.event.pull_request.number }}/comments \ + --jq '[.[] | select(.body | contains(""))] | sort_by(.created_at) | .[].id') + + # Count comments + COMMENT_COUNT=$(echo "$COMMENT_IDS" | grep -c . || true) + + if [ "$COMMENT_COUNT" -gt 1 ]; then + # Delete all but the last (newest) comment + echo "$COMMENT_IDS" | head -n -1 | while read -r COMMENT_ID; do + if [ -n "$COMMENT_ID" ]; then + echo "Deleting old review comment: $COMMENT_ID" + gh api -X DELETE repos/${{ github.repository }}/issues/comments/$COMMENT_ID + fi + done + else + echo "No old review comments to clean up" + fi From b8f5c208d08e313306ad3ee87020d8746d9afbb4 Mon Sep 17 00:00:00 2001 From: DEEVEN SERU <144827577+DEVELOPER-DEEVEN@users.noreply.github.com> Date: Sat, 14 Feb 2026 00:45:09 +0530 Subject: [PATCH 13/16] Handle errors in Jina ExtractWebsiteContentBlock (#12048) ## Summary - catch Jina reader client/server errors in ExtractWebsiteContentBlock and surface a clear error output keyed to the user URL - guard empty responses to return an explicit error instead of yielding blank content - add regression tests covering the happy path and HTTP client failures via a monkeypatched fetch ## Testing - not run (pytest unavailable in this environment) --------- Co-authored-by: Nicholas Tindle Co-authored-by: Nicholas Tindle --- .../backend/backend/blocks/jina/search.py | 25 ++++++- .../test/blocks/test_jina_extract_website.py | 66 +++++++++++++++++++ 2 files changed, 89 insertions(+), 2 deletions(-) create mode 100644 autogpt_platform/backend/test/blocks/test_jina_extract_website.py diff --git a/autogpt_platform/backend/backend/blocks/jina/search.py b/autogpt_platform/backend/backend/blocks/jina/search.py index 22a883fa03..5e58ddcab4 100644 --- a/autogpt_platform/backend/backend/blocks/jina/search.py +++ b/autogpt_platform/backend/backend/blocks/jina/search.py @@ -17,6 +17,7 @@ from backend.blocks.jina._auth import ( from backend.blocks.search import GetRequest from backend.data.model import SchemaField from backend.util.exceptions import BlockExecutionError +from backend.util.request import HTTPClientError, HTTPServerError, validate_url class SearchTheWebBlock(Block, GetRequest): @@ -110,7 +111,12 @@ class ExtractWebsiteContentBlock(Block, GetRequest): self, input_data: Input, *, credentials: JinaCredentials, **kwargs ) -> BlockOutput: if input_data.raw_content: - url = input_data.url + try: + parsed_url, _, _ = await validate_url(input_data.url, []) + url = parsed_url.geturl() + except ValueError as e: + yield "error", f"Invalid URL: {e}" + return headers = {} else: url = f"https://r.jina.ai/{input_data.url}" @@ -119,5 +125,20 @@ class ExtractWebsiteContentBlock(Block, GetRequest): "Authorization": f"Bearer {credentials.api_key.get_secret_value()}", } - content = await self.get_request(url, json=False, headers=headers) + try: + content = await self.get_request(url, json=False, headers=headers) + except HTTPClientError as e: + yield "error", f"Client error ({e.status_code}) fetching {input_data.url}: {e}" + return + except HTTPServerError as e: + yield "error", f"Server error ({e.status_code}) fetching {input_data.url}: {e}" + return + except Exception as e: + yield "error", f"Failed to fetch {input_data.url}: {e}" + return + + if not content: + yield "error", f"No content returned for {input_data.url}" + return + yield "content", content diff --git a/autogpt_platform/backend/test/blocks/test_jina_extract_website.py b/autogpt_platform/backend/test/blocks/test_jina_extract_website.py new file mode 100644 index 0000000000..335c43f966 --- /dev/null +++ b/autogpt_platform/backend/test/blocks/test_jina_extract_website.py @@ -0,0 +1,66 @@ +from typing import cast + +import pytest + +from backend.blocks.jina._auth import ( + TEST_CREDENTIALS, + TEST_CREDENTIALS_INPUT, + JinaCredentialsInput, +) +from backend.blocks.jina.search import ExtractWebsiteContentBlock +from backend.util.request import HTTPClientError + + +@pytest.mark.asyncio +async def test_extract_website_content_returns_content(monkeypatch): + block = ExtractWebsiteContentBlock() + input_data = block.Input( + url="https://example.com", + credentials=cast(JinaCredentialsInput, TEST_CREDENTIALS_INPUT), + raw_content=True, + ) + + async def fake_get_request(url, json=False, headers=None): + assert url == "https://example.com" + assert headers == {} + return "page content" + + monkeypatch.setattr(block, "get_request", fake_get_request) + + results = [ + output + async for output in block.run( + input_data=input_data, credentials=TEST_CREDENTIALS + ) + ] + + assert ("content", "page content") in results + assert all(key != "error" for key, _ in results) + + +@pytest.mark.asyncio +async def test_extract_website_content_handles_http_error(monkeypatch): + block = ExtractWebsiteContentBlock() + input_data = block.Input( + url="https://example.com", + credentials=cast(JinaCredentialsInput, TEST_CREDENTIALS_INPUT), + raw_content=False, + ) + + async def fake_get_request(url, json=False, headers=None): + raise HTTPClientError("HTTP 400 Error: Bad Request", 400) + + monkeypatch.setattr(block, "get_request", fake_get_request) + + results = [ + output + async for output in block.run( + input_data=input_data, credentials=TEST_CREDENTIALS + ) + ] + + assert ("content", "page content") not in results + error_messages = [value for key, value in results if key == "error"] + assert error_messages + assert "Client error (400)" in error_messages[0] + assert "https://example.com" in error_messages[0] From 27d94e395cc8d191a1b284dda2b42572ed9d4796 Mon Sep 17 00:00:00 2001 From: Zamil Majdy Date: Sun, 15 Feb 2026 10:51:25 +0400 Subject: [PATCH 14/16] feat(backend/sdk): enable WebSearch, block WebFetch, consolidate tool constants (#12108) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary - Enable Claude Agent SDK built-in **WebSearch** tool (Brave Search via Anthropic API) for the CoPilot SDK agent - Explicitly **block WebFetch** via `SDK_DISALLOWED_TOOLS`. The agent uses the SSRF-protected `mcp__copilot__web_fetch` MCP tool instead - **Consolidate** all tool security constants (`BLOCKED_TOOLS`, `WORKSPACE_SCOPED_TOOLS`, `DANGEROUS_PATTERNS`, `SDK_DISALLOWED_TOOLS`) into `tool_adapter.py` as a single source of truth — previously scattered across `tool_adapter.py`, `security_hooks.py`, and inline in `service.py` ## Changes - `tool_adapter.py`: Add `WebSearch` to `_SDK_BUILTIN_TOOLS`, add `SDK_DISALLOWED_TOOLS`, move security constants here - `security_hooks.py`: Import constants from `tool_adapter.py` instead of defining locally - `service.py`: Use `SDK_DISALLOWED_TOOLS` instead of inline `["Bash"]` ## Test plan - [x] All 21 security hooks tests pass - [x] Ruff lint clean - [x] All pre-commit hooks pass - [ ] Verify WebSearch works in CoPilot chat (manual test)

Greptile Overview

Greptile Summary

Consolidates tool security constants into `tool_adapter.py` as single source of truth, enables WebSearch (Brave via Anthropic API), and explicitly blocks WebFetch to prevent SSRF attacks. The change improves security by ensuring the agent uses the SSRF-protected `mcp__copilot__web_fetch` tool instead of the built-in WebFetch which can access internal networks like `localhost:8006`.

Confidence Score: 5/5

- This PR is safe to merge with minimal risk - The changes improve security by blocking WebFetch (SSRF risk) while enabling safe WebSearch. The consolidation of constants into a single source of truth improves maintainability. All existing tests pass (21 security hooks tests), and the refactoring is straightforward with no behavioral changes to existing security logic. The only suggestions are minor improvements: adding a test for WebFetch blocking and considering a lowercase alias for consistency. - No files require special attention

Sequence Diagram

```mermaid sequenceDiagram participant Agent as SDK Agent participant Hooks as Security Hooks participant TA as tool_adapter.py participant MCP as MCP Tools Note over TA: SDK_DISALLOWED_TOOLS = ["Bash", "WebFetch"] Note over TA: _SDK_BUILTIN_TOOLS includes WebSearch Agent->>Hooks: Request WebSearch (Brave API) Hooks->>TA: Check BLOCKED_TOOLS TA-->>Hooks: Not blocked Hooks-->>Agent: Allowed ✓ Agent->>Agent: Execute via Anthropic API Agent->>Hooks: Request WebFetch (SSRF risk) Hooks->>TA: Check BLOCKED_TOOLS Note over TA: WebFetch in SDK_DISALLOWED_TOOLS TA-->>Hooks: Blocked Hooks-->>Agent: Denied ✗ Note over Agent: Use mcp__copilot__web_fetch instead Agent->>Hooks: Request mcp__copilot__web_fetch Hooks->>MCP: Validate (MCP tool, not SDK builtin) MCP-->>Hooks: Has SSRF protection Hooks-->>Agent: Allowed ✓ Agent->>MCP: Execute with SSRF checks ```
Last reviewed commit: 2d9975f --- .../api/features/chat/sdk/security_hooks.py | 42 +++--------------- .../backend/api/features/chat/sdk/service.py | 3 +- .../api/features/chat/sdk/tool_adapter.py | 43 ++++++++++++++++++- 3 files changed, 50 insertions(+), 38 deletions(-) diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks.py b/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks.py index 14efc6d459..89853402a3 100644 --- a/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks.py +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks.py @@ -11,45 +11,15 @@ import re from collections.abc import Callable from typing import Any, cast -from backend.api.features.chat.sdk.tool_adapter import MCP_TOOL_PREFIX +from backend.api.features.chat.sdk.tool_adapter import ( + BLOCKED_TOOLS, + DANGEROUS_PATTERNS, + MCP_TOOL_PREFIX, + WORKSPACE_SCOPED_TOOLS, +) logger = logging.getLogger(__name__) -# Tools that are blocked entirely (CLI/system access). -# "Bash" (capital) is the SDK built-in — it's NOT in allowed_tools but blocked -# here as defence-in-depth. The agent uses mcp__copilot__bash_exec instead, -# which has kernel-level network isolation (unshare --net). -BLOCKED_TOOLS = { - "Bash", - "bash", - "shell", - "exec", - "terminal", - "command", -} - -# Tools allowed only when their path argument stays within the SDK workspace. -# The SDK uses these to handle oversized tool results (writes to tool-results/ -# files, then reads them back) and for workspace file operations. -WORKSPACE_SCOPED_TOOLS = {"Read", "Write", "Edit", "Glob", "Grep"} - -# Dangerous patterns in tool inputs -DANGEROUS_PATTERNS = [ - r"sudo", - r"rm\s+-rf", - r"dd\s+if=", - r"/etc/passwd", - r"/etc/shadow", - r"chmod\s+777", - r"curl\s+.*\|.*sh", - r"wget\s+.*\|.*sh", - r"eval\s*\(", - r"exec\s*\(", - r"__import__", - r"os\.system", - r"subprocess", -] - def _deny(reason: str) -> dict[str, Any]: """Return a hook denial response.""" diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/service.py b/autogpt_platform/backend/backend/api/features/chat/sdk/service.py index 65195b442c..65c4cebb06 100644 --- a/autogpt_platform/backend/backend/api/features/chat/sdk/service.py +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/service.py @@ -41,6 +41,7 @@ from .response_adapter import SDKResponseAdapter from .security_hooks import create_security_hooks from .tool_adapter import ( COPILOT_TOOL_NAMES, + SDK_DISALLOWED_TOOLS, LongRunningCallback, create_copilot_mcp_server, set_execution_context, @@ -543,7 +544,7 @@ async def stream_chat_completion_sdk( "system_prompt": system_prompt, "mcp_servers": {"copilot": mcp_server}, "allowed_tools": COPILOT_TOOL_NAMES, - "disallowed_tools": ["Bash"], + "disallowed_tools": SDK_DISALLOWED_TOOLS, "hooks": security_hooks, "cwd": sdk_cwd, "max_buffer_size": config.claude_agent_max_buffer_size, diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/tool_adapter.py b/autogpt_platform/backend/backend/api/features/chat/sdk/tool_adapter.py index d983d5e785..2d259730bf 100644 --- a/autogpt_platform/backend/backend/api/features/chat/sdk/tool_adapter.py +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/tool_adapter.py @@ -310,7 +310,48 @@ def create_copilot_mcp_server(): # Bash is NOT included — use the sandboxed MCP bash_exec tool instead, # which provides kernel-level network isolation via unshare --net. # Task allows spawning sub-agents (rate-limited by security hooks). -_SDK_BUILTIN_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep", "Task"] +# WebSearch uses Brave Search via Anthropic's API — safe, no SSRF risk. +_SDK_BUILTIN_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep", "Task", "WebSearch"] + +# SDK built-in tools that must be explicitly blocked. +# Bash: dangerous — agent uses mcp__copilot__bash_exec with kernel-level +# network isolation (unshare --net) instead. +# WebFetch: SSRF risk — can reach internal network (localhost, 10.x, etc.). +# Agent uses the SSRF-protected mcp__copilot__web_fetch tool instead. +SDK_DISALLOWED_TOOLS = ["Bash", "WebFetch"] + +# Tools that are blocked entirely in security hooks (defence-in-depth). +# Includes SDK_DISALLOWED_TOOLS plus common aliases/synonyms. +BLOCKED_TOOLS = { + *SDK_DISALLOWED_TOOLS, + "bash", + "shell", + "exec", + "terminal", + "command", +} + +# Tools allowed only when their path argument stays within the SDK workspace. +# The SDK uses these to handle oversized tool results (writes to tool-results/ +# files, then reads them back) and for workspace file operations. +WORKSPACE_SCOPED_TOOLS = {"Read", "Write", "Edit", "Glob", "Grep"} + +# Dangerous patterns in tool inputs +DANGEROUS_PATTERNS = [ + r"sudo", + r"rm\s+-rf", + r"dd\s+if=", + r"/etc/passwd", + r"/etc/shadow", + r"chmod\s+777", + r"curl\s+.*\|.*sh", + r"wget\s+.*\|.*sh", + r"eval\s*\(", + r"exec\s*\(", + r"__import__", + r"os\.system", + r"subprocess", +] # List of tool names for allowed_tools configuration # Include MCP tools, the MCP Read tool for oversized results, From 647c8ed8d46b133cff44572cb99f8e683f9083b2 Mon Sep 17 00:00:00 2001 From: Eve <162624394+aviu16@users.noreply.github.com> Date: Mon, 16 Feb 2026 00:39:53 -0500 Subject: [PATCH 15/16] feat(backend/blocks): enhance list concatenation with advanced operations (#12105) ## Summary Enhances the existing `ConcatenateListsBlock` and adds five new companion blocks for comprehensive list manipulation, addressing issue #11139 ("Implement block to concatenate lists"). ### Changes - **Enhanced `ConcatenateListsBlock`** with optional deduplication (`deduplicate`) and None-value filtering (`remove_none`), plus an output `length` field - **New `FlattenListBlock`**: Recursively flattens nested list structures with configurable `max_depth` - **New `InterleaveListsBlock`**: Round-robin interleaving of elements from multiple lists - **New `ZipListsBlock`**: Zips corresponding elements from multiple lists with support for padding to longest or truncating to shortest - **New `ListDifferenceBlock`**: Computes set difference between two lists (regular or symmetric) - **New `ListIntersectionBlock`**: Finds common elements between two lists, preserving order ### Helper Utilities Extracted reusable helper functions for validation, flattening, deduplication, interleaving, chunking, and statistics computation to support the blocks and enable future reuse. ### Test Coverage Comprehensive test suite with 188 test functions across 29 test classes covering: - Built-in block test harness validation for all 6 blocks - Manual edge-case tests for each block (empty inputs, large lists, mixed types, nested structures) - Internal method tests for all block classes - Unit tests for all helper utility functions Closes #11139 ## Test plan - [x] All files pass Python syntax validation (`ast.parse`) - [x] Built-in `test_input`/`test_output` tests defined for all blocks - [x] Manual tests cover edge cases: empty lists, large lists, mixed types, nested structures, deduplication, None removal - [x] Helper function tests validate all utility functions independently - [x] All block IDs are valid UUID4 - [x] Block categories set to `BlockCategory.BASIC` for consistency with existing list blocks

Greptile Overview

Greptile Summary

Enhanced `ConcatenateListsBlock` with deduplication and None-filtering options, and added five new list manipulation blocks (`FlattenListBlock`, `InterleaveListsBlock`, `ZipListsBlock`, `ListDifferenceBlock`, `ListIntersectionBlock`) with comprehensive helper functions and test coverage. **Key Changes:** - Enhanced `ConcatenateListsBlock` with `deduplicate` and `remove_none` options, plus `length` output field - Added `FlattenListBlock` for recursively flattening nested lists with configurable `max_depth` - Added `InterleaveListsBlock` for round-robin element interleaving - Added `ZipListsBlock` with support for padding/truncation - Added `ListDifferenceBlock` and `ListIntersectionBlock` for set operations - Extracted 12 reusable helper functions for validation, flattening, deduplication, etc. - Comprehensive test suite with 188 test functions covering edge cases **Minor Issues:** - Helper function `_deduplicate_list` has redundant logic in the `else` branch that duplicates the `if` branch - Three helper functions (`_filter_empty_collections`, `_compute_list_statistics`, `_chunk_list`) are defined but unused - consider removing unless planned for future use - The `_make_hashable` function uses `hash(repr(item))` for unhashable types, which correctly treats structurally identical dicts/lists as duplicates

Confidence Score: 4/5

- Safe to merge with minor style improvements recommended - The implementation is well-structured with comprehensive test coverage (188 tests), proper error handling, and follows existing block patterns. All blocks use valid UUID4 IDs and correct categories. The helper functions provide good code reuse. The minor issues are purely stylistic (redundant code, unused helpers) and don't affect functionality or safety. - No files require special attention - both files are well-tested and follow project conventions

Sequence Diagram

```mermaid sequenceDiagram participant User participant Block as List Block participant Helper as Helper Functions participant Output User->>Block: Input (lists/parameters) Block->>Helper: _validate_all_lists() Helper-->>Block: validation result alt validation fails Block->>Output: error message else validation succeeds Block->>Helper: _concatenate_lists_simple() / _flatten_nested_list() / etc. Helper-->>Block: processed result opt deduplicate enabled Block->>Helper: _deduplicate_list() Helper-->>Block: deduplicated result end opt remove_none enabled Block->>Helper: _filter_none_values() Helper-->>Block: filtered result end Block->>Output: result + length end Output-->>User: Block outputs ```
Last reviewed commit: a6d5445 (2/5) Greptile learns from your feedback when you react with thumbs up/down! --------- Co-authored-by: Otto --- .../backend/blocks/data_manipulation.py | 704 ++++++++- .../test/blocks/test_list_concatenation.py | 1276 +++++++++++++++++ docs/integrations/README.md | 5 + docs/integrations/block-integrations/basic.md | 197 ++- 4 files changed, 2164 insertions(+), 18 deletions(-) create mode 100644 autogpt_platform/backend/test/blocks/test_list_concatenation.py diff --git a/autogpt_platform/backend/backend/blocks/data_manipulation.py b/autogpt_platform/backend/backend/blocks/data_manipulation.py index a8f25ecb18..fe878acfa9 100644 --- a/autogpt_platform/backend/backend/blocks/data_manipulation.py +++ b/autogpt_platform/backend/backend/blocks/data_manipulation.py @@ -682,17 +682,219 @@ class ListIsEmptyBlock(Block): yield "is_empty", len(input_data.list) == 0 +# ============================================================================= +# List Concatenation Helpers +# ============================================================================= + + +def _validate_list_input(item: Any, index: int) -> str | None: + """Validate that an item is a list. Returns error message or None.""" + if item is None: + return None # None is acceptable, will be skipped + if not isinstance(item, list): + return ( + f"Invalid input at index {index}: expected a list, " + f"got {type(item).__name__}. " + f"All items in 'lists' must be lists (e.g., [[1, 2], [3, 4]])." + ) + return None + + +def _validate_all_lists(lists: List[Any]) -> str | None: + """Validate that all items in a sequence are lists. Returns first error or None.""" + for idx, item in enumerate(lists): + error = _validate_list_input(item, idx) + if error is not None and item is not None: + return error + return None + + +def _concatenate_lists_simple(lists: List[List[Any]]) -> List[Any]: + """Concatenate a sequence of lists into a single list, skipping None values.""" + result: List[Any] = [] + for lst in lists: + if lst is None: + continue + result.extend(lst) + return result + + +def _flatten_nested_list(nested: List[Any], max_depth: int = -1) -> List[Any]: + """ + Recursively flatten a nested list structure. + + Args: + nested: The list to flatten. + max_depth: Maximum recursion depth. -1 means unlimited. + + Returns: + A flat list with all nested elements extracted. + """ + result: List[Any] = [] + _flatten_recursive(nested, result, current_depth=0, max_depth=max_depth) + return result + + +_MAX_FLATTEN_DEPTH = 1000 + + +def _flatten_recursive( + items: List[Any], + result: List[Any], + current_depth: int, + max_depth: int, +) -> None: + """Internal recursive helper for flattening nested lists.""" + if current_depth > _MAX_FLATTEN_DEPTH: + raise RecursionError( + f"Flattening exceeded maximum depth of {_MAX_FLATTEN_DEPTH} levels. " + "Input may be too deeply nested." + ) + for item in items: + if isinstance(item, list) and (max_depth == -1 or current_depth < max_depth): + _flatten_recursive(item, result, current_depth + 1, max_depth) + else: + result.append(item) + + +def _deduplicate_list(items: List[Any]) -> List[Any]: + """ + Remove duplicate elements from a list, preserving order of first occurrences. + + Args: + items: The list to deduplicate. + + Returns: + A list with duplicates removed, maintaining original order. + """ + seen: set = set() + result: List[Any] = [] + for item in items: + item_id = _make_hashable(item) + if item_id not in seen: + seen.add(item_id) + result.append(item) + return result + + +def _make_hashable(item: Any): + """ + Create a hashable representation of any item for deduplication. + Converts unhashable types (dicts, lists) into deterministic tuple structures. + """ + if isinstance(item, dict): + return tuple( + sorted( + ((_make_hashable(k), _make_hashable(v)) for k, v in item.items()), + key=lambda x: (str(type(x[0])), str(x[0])), + ) + ) + if isinstance(item, (list, tuple)): + return tuple(_make_hashable(i) for i in item) + if isinstance(item, set): + return frozenset(_make_hashable(i) for i in item) + return item + + +def _filter_none_values(items: List[Any]) -> List[Any]: + """Remove None values from a list.""" + return [item for item in items if item is not None] + + +def _compute_nesting_depth( + items: Any, current: int = 0, max_depth: int = _MAX_FLATTEN_DEPTH +) -> int: + """ + Compute the maximum nesting depth of a list structure using iteration to avoid RecursionError. + + Uses a stack-based approach to handle deeply nested structures without hitting Python's + recursion limit (~1000 levels). + """ + if not isinstance(items, list): + return current + + # Stack contains tuples of (item, depth) + stack = [(items, current)] + max_observed_depth = current + + while stack: + item, depth = stack.pop() + + if depth > max_depth: + return depth + + if not isinstance(item, list): + max_observed_depth = max(max_observed_depth, depth) + continue + + if len(item) == 0: + max_observed_depth = max(max_observed_depth, depth + 1) + continue + + # Add all children to stack with incremented depth + for child in item: + stack.append((child, depth + 1)) + + return max_observed_depth + + +def _interleave_lists(lists: List[List[Any]]) -> List[Any]: + """ + Interleave elements from multiple lists in round-robin fashion. + Example: [[1,2,3], [a,b], [x,y,z]] -> [1, a, x, 2, b, y, 3, z] + """ + if not lists: + return [] + filtered = [lst for lst in lists if lst is not None] + if not filtered: + return [] + result: List[Any] = [] + max_len = max(len(lst) for lst in filtered) + for i in range(max_len): + for lst in filtered: + if i < len(lst): + result.append(lst[i]) + return result + + +# ============================================================================= +# List Concatenation Blocks +# ============================================================================= + + class ConcatenateListsBlock(Block): + """ + Concatenates two or more lists into a single list. + + This block accepts a list of lists and combines all their elements + in order into one flat output list. It supports options for + deduplication and None-filtering to provide flexible list merging + capabilities for workflow pipelines. + """ + class Input(BlockSchemaInput): lists: List[List[Any]] = SchemaField( description="A list of lists to concatenate together. All lists will be combined in order into a single list.", placeholder="e.g., [[1, 2], [3, 4], [5, 6]]", ) + deduplicate: bool = SchemaField( + description="If True, remove duplicate elements from the concatenated result while preserving order.", + default=False, + advanced=True, + ) + remove_none: bool = SchemaField( + description="If True, remove None values from the concatenated result.", + default=False, + advanced=True, + ) class Output(BlockSchemaOutput): concatenated_list: List[Any] = SchemaField( description="The concatenated list containing all elements from all input lists in order." ) + length: int = SchemaField( + description="The total number of elements in the concatenated list." + ) error: str = SchemaField( description="Error message if concatenation failed due to invalid input types." ) @@ -700,7 +902,7 @@ class ConcatenateListsBlock(Block): def __init__(self): super().__init__( id="3cf9298b-5817-4141-9d80-7c2cc5199c8e", - description="Concatenates multiple lists into a single list. All elements from all input lists are combined in order.", + description="Concatenates multiple lists into a single list. All elements from all input lists are combined in order. Supports optional deduplication and None removal.", categories={BlockCategory.BASIC}, input_schema=ConcatenateListsBlock.Input, output_schema=ConcatenateListsBlock.Output, @@ -709,29 +911,497 @@ class ConcatenateListsBlock(Block): {"lists": [["a", "b"], ["c"], ["d", "e", "f"]]}, {"lists": [[1, 2], []]}, {"lists": []}, + {"lists": [[1, 2, 2, 3], [3, 4]], "deduplicate": True}, + {"lists": [[1, None, 2], [None, 3]], "remove_none": True}, ], test_output=[ ("concatenated_list", [1, 2, 3, 4, 5, 6]), + ("length", 6), ("concatenated_list", ["a", "b", "c", "d", "e", "f"]), + ("length", 6), ("concatenated_list", [1, 2]), + ("length", 2), ("concatenated_list", []), + ("length", 0), + ("concatenated_list", [1, 2, 3, 4]), + ("length", 4), + ("concatenated_list", [1, 2, 3]), + ("length", 3), ], ) + def _validate_inputs(self, lists: List[Any]) -> str | None: + return _validate_all_lists(lists) + + def _perform_concatenation(self, lists: List[List[Any]]) -> List[Any]: + return _concatenate_lists_simple(lists) + + def _apply_deduplication(self, items: List[Any]) -> List[Any]: + return _deduplicate_list(items) + + def _apply_none_removal(self, items: List[Any]) -> List[Any]: + return _filter_none_values(items) + + def _post_process( + self, items: List[Any], deduplicate: bool, remove_none: bool + ) -> List[Any]: + """Apply all post-processing steps to the concatenated result.""" + result = items + if remove_none: + result = self._apply_none_removal(result) + if deduplicate: + result = self._apply_deduplication(result) + return result + async def run(self, input_data: Input, **kwargs) -> BlockOutput: - concatenated = [] - for idx, lst in enumerate(input_data.lists): - if lst is None: - # Skip None values to avoid errors - continue - if not isinstance(lst, list): - # Type validation: each item must be a list - # Strings are iterable and would cause extend() to iterate character-by-character - # Non-iterable types would raise TypeError - yield "error", ( - f"Invalid input at index {idx}: expected a list, got {type(lst).__name__}. " - f"All items in 'lists' must be lists (e.g., [[1, 2], [3, 4]])." - ) - return - concatenated.extend(lst) - yield "concatenated_list", concatenated + # Validate all inputs are lists + validation_error = self._validate_inputs(input_data.lists) + if validation_error is not None: + yield "error", validation_error + return + + # Perform concatenation + concatenated = self._perform_concatenation(input_data.lists) + + # Apply post-processing + result = self._post_process( + concatenated, input_data.deduplicate, input_data.remove_none + ) + + yield "concatenated_list", result + yield "length", len(result) + + +class FlattenListBlock(Block): + """ + Flattens a nested list structure into a single flat list. + + This block takes a list that may contain nested lists at any depth + and produces a single-level list with all leaf elements. Useful + for normalizing data structures from multiple sources that may + have varying levels of nesting. + """ + + class Input(BlockSchemaInput): + nested_list: List[Any] = SchemaField( + description="A potentially nested list to flatten into a single-level list.", + placeholder="e.g., [[1, [2, 3]], [4, [5, [6]]]]", + ) + max_depth: int = SchemaField( + description="Maximum depth to flatten. -1 means flatten completely. 1 means flatten only one level.", + default=-1, + advanced=True, + ) + + class Output(BlockSchemaOutput): + flattened_list: List[Any] = SchemaField( + description="The flattened list with all nested elements extracted." + ) + length: int = SchemaField( + description="The number of elements in the flattened list." + ) + original_depth: int = SchemaField( + description="The maximum nesting depth of the original input list." + ) + error: str = SchemaField(description="Error message if flattening failed.") + + def __init__(self): + super().__init__( + id="cc45bb0f-d035-4756-96a7-fe3e36254b4d", + description="Flattens a nested list structure into a single flat list. Supports configurable maximum flattening depth.", + categories={BlockCategory.BASIC}, + input_schema=FlattenListBlock.Input, + output_schema=FlattenListBlock.Output, + test_input=[ + {"nested_list": [[1, 2], [3, [4, 5]]]}, + {"nested_list": [1, [2, [3, [4]]]]}, + {"nested_list": [1, [2, [3, [4]]], 5], "max_depth": 1}, + {"nested_list": []}, + {"nested_list": [1, 2, 3]}, + ], + test_output=[ + ("flattened_list", [1, 2, 3, 4, 5]), + ("length", 5), + ("original_depth", 3), + ("flattened_list", [1, 2, 3, 4]), + ("length", 4), + ("original_depth", 4), + ("flattened_list", [1, 2, [3, [4]], 5]), + ("length", 4), + ("original_depth", 4), + ("flattened_list", []), + ("length", 0), + ("original_depth", 1), + ("flattened_list", [1, 2, 3]), + ("length", 3), + ("original_depth", 1), + ], + ) + + def _compute_depth(self, items: List[Any]) -> int: + """Compute the nesting depth of the input list.""" + return _compute_nesting_depth(items) + + def _flatten(self, items: List[Any], max_depth: int) -> List[Any]: + """Flatten the list to the specified depth.""" + return _flatten_nested_list(items, max_depth=max_depth) + + def _validate_max_depth(self, max_depth: int) -> str | None: + """Validate the max_depth parameter.""" + if max_depth < -1: + return f"max_depth must be -1 (unlimited) or a non-negative integer, got {max_depth}" + return None + + async def run(self, input_data: Input, **kwargs) -> BlockOutput: + # Validate max_depth + depth_error = self._validate_max_depth(input_data.max_depth) + if depth_error is not None: + yield "error", depth_error + return + + original_depth = self._compute_depth(input_data.nested_list) + flattened = self._flatten(input_data.nested_list, input_data.max_depth) + + yield "flattened_list", flattened + yield "length", len(flattened) + yield "original_depth", original_depth + + +class InterleaveListsBlock(Block): + """ + Interleaves elements from multiple lists in round-robin fashion. + + Given multiple input lists, this block takes one element from each + list in turn, producing an output where elements alternate between + sources. Lists of different lengths are handled gracefully - shorter + lists simply stop contributing once exhausted. + """ + + class Input(BlockSchemaInput): + lists: List[List[Any]] = SchemaField( + description="A list of lists to interleave. Elements will be taken in round-robin order.", + placeholder="e.g., [[1, 2, 3], ['a', 'b', 'c']]", + ) + + class Output(BlockSchemaOutput): + interleaved_list: List[Any] = SchemaField( + description="The interleaved list with elements alternating from each input list." + ) + length: int = SchemaField( + description="The total number of elements in the interleaved list." + ) + error: str = SchemaField(description="Error message if interleaving failed.") + + def __init__(self): + super().__init__( + id="9f616084-1d9f-4f8e-bc00-5b9d2a75cd75", + description="Interleaves elements from multiple lists in round-robin fashion, alternating between sources.", + categories={BlockCategory.BASIC}, + input_schema=InterleaveListsBlock.Input, + output_schema=InterleaveListsBlock.Output, + test_input=[ + {"lists": [[1, 2, 3], ["a", "b", "c"]]}, + {"lists": [[1, 2, 3], ["a", "b"], ["x", "y", "z"]]}, + {"lists": [[1], [2], [3]]}, + {"lists": []}, + ], + test_output=[ + ("interleaved_list", [1, "a", 2, "b", 3, "c"]), + ("length", 6), + ("interleaved_list", [1, "a", "x", 2, "b", "y", 3, "z"]), + ("length", 8), + ("interleaved_list", [1, 2, 3]), + ("length", 3), + ("interleaved_list", []), + ("length", 0), + ], + ) + + def _validate_inputs(self, lists: List[Any]) -> str | None: + return _validate_all_lists(lists) + + def _interleave(self, lists: List[List[Any]]) -> List[Any]: + return _interleave_lists(lists) + + async def run(self, input_data: Input, **kwargs) -> BlockOutput: + validation_error = self._validate_inputs(input_data.lists) + if validation_error is not None: + yield "error", validation_error + return + + result = self._interleave(input_data.lists) + yield "interleaved_list", result + yield "length", len(result) + + +class ZipListsBlock(Block): + """ + Zips multiple lists together into a list of grouped tuples/lists. + + Takes two or more input lists and combines corresponding elements + into sub-lists. For example, zipping [1,2,3] and ['a','b','c'] + produces [[1,'a'], [2,'b'], [3,'c']]. Supports both truncating + to shortest list and padding to longest list with a fill value. + """ + + class Input(BlockSchemaInput): + lists: List[List[Any]] = SchemaField( + description="A list of lists to zip together. Corresponding elements will be grouped.", + placeholder="e.g., [[1, 2, 3], ['a', 'b', 'c']]", + ) + pad_to_longest: bool = SchemaField( + description="If True, pad shorter lists with fill_value to match the longest list. If False, truncate to shortest.", + default=False, + advanced=True, + ) + fill_value: Any = SchemaField( + description="Value to use for padding when pad_to_longest is True.", + default=None, + advanced=True, + ) + + class Output(BlockSchemaOutput): + zipped_list: List[List[Any]] = SchemaField( + description="The zipped list of grouped elements." + ) + length: int = SchemaField( + description="The number of groups in the zipped result." + ) + error: str = SchemaField(description="Error message if zipping failed.") + + def __init__(self): + super().__init__( + id="0d0e684f-5cb9-4c4b-b8d1-47a0860e0c07", + description="Zips multiple lists together into a list of grouped elements. Supports padding to longest or truncating to shortest.", + categories={BlockCategory.BASIC}, + input_schema=ZipListsBlock.Input, + output_schema=ZipListsBlock.Output, + test_input=[ + {"lists": [[1, 2, 3], ["a", "b", "c"]]}, + {"lists": [[1, 2, 3], ["a", "b"]]}, + { + "lists": [[1, 2], ["a", "b", "c"]], + "pad_to_longest": True, + "fill_value": 0, + }, + {"lists": []}, + ], + test_output=[ + ("zipped_list", [[1, "a"], [2, "b"], [3, "c"]]), + ("length", 3), + ("zipped_list", [[1, "a"], [2, "b"]]), + ("length", 2), + ("zipped_list", [[1, "a"], [2, "b"], [0, "c"]]), + ("length", 3), + ("zipped_list", []), + ("length", 0), + ], + ) + + def _validate_inputs(self, lists: List[Any]) -> str | None: + return _validate_all_lists(lists) + + def _zip_truncate(self, lists: List[List[Any]]) -> List[List[Any]]: + """Zip lists, truncating to shortest.""" + filtered = [lst for lst in lists if lst is not None] + if not filtered: + return [] + return [list(group) for group in zip(*filtered)] + + def _zip_pad(self, lists: List[List[Any]], fill_value: Any) -> List[List[Any]]: + """Zip lists, padding shorter ones with fill_value.""" + if not lists: + return [] + lists = [lst for lst in lists if lst is not None] + if not lists: + return [] + max_len = max(len(lst) for lst in lists) + result: List[List[Any]] = [] + for i in range(max_len): + group: List[Any] = [] + for lst in lists: + if i < len(lst): + group.append(lst[i]) + else: + group.append(fill_value) + result.append(group) + return result + + async def run(self, input_data: Input, **kwargs) -> BlockOutput: + validation_error = self._validate_inputs(input_data.lists) + if validation_error is not None: + yield "error", validation_error + return + + if not input_data.lists: + yield "zipped_list", [] + yield "length", 0 + return + + if input_data.pad_to_longest: + result = self._zip_pad(input_data.lists, input_data.fill_value) + else: + result = self._zip_truncate(input_data.lists) + + yield "zipped_list", result + yield "length", len(result) + + +class ListDifferenceBlock(Block): + """ + Computes the difference between two lists (elements in the first + list that are not in the second list). + + This is useful for finding items that exist in one dataset but + not in another, such as finding new items, missing items, or + items that need to be processed. + """ + + class Input(BlockSchemaInput): + list_a: List[Any] = SchemaField( + description="The primary list to check elements from.", + placeholder="e.g., [1, 2, 3, 4, 5]", + ) + list_b: List[Any] = SchemaField( + description="The list to subtract. Elements found here will be removed from list_a.", + placeholder="e.g., [3, 4, 5, 6]", + ) + symmetric: bool = SchemaField( + description="If True, compute symmetric difference (elements in either list but not both).", + default=False, + advanced=True, + ) + + class Output(BlockSchemaOutput): + difference: List[Any] = SchemaField( + description="Elements from list_a not found in list_b (or symmetric difference if enabled)." + ) + length: int = SchemaField( + description="The number of elements in the difference result." + ) + error: str = SchemaField(description="Error message if the operation failed.") + + def __init__(self): + super().__init__( + id="05309873-9d61-447e-96b5-b804e2511829", + description="Computes the difference between two lists. Returns elements in the first list not found in the second, or symmetric difference.", + categories={BlockCategory.BASIC}, + input_schema=ListDifferenceBlock.Input, + output_schema=ListDifferenceBlock.Output, + test_input=[ + {"list_a": [1, 2, 3, 4, 5], "list_b": [3, 4, 5, 6, 7]}, + { + "list_a": [1, 2, 3, 4, 5], + "list_b": [3, 4, 5, 6, 7], + "symmetric": True, + }, + {"list_a": ["a", "b", "c"], "list_b": ["b"]}, + {"list_a": [], "list_b": [1, 2, 3]}, + ], + test_output=[ + ("difference", [1, 2]), + ("length", 2), + ("difference", [1, 2, 6, 7]), + ("length", 4), + ("difference", ["a", "c"]), + ("length", 2), + ("difference", []), + ("length", 0), + ], + ) + + def _compute_difference(self, list_a: List[Any], list_b: List[Any]) -> List[Any]: + """Compute elements in list_a not in list_b.""" + b_hashes = {_make_hashable(item) for item in list_b} + return [item for item in list_a if _make_hashable(item) not in b_hashes] + + def _compute_symmetric_difference( + self, list_a: List[Any], list_b: List[Any] + ) -> List[Any]: + """Compute elements in either list but not both.""" + a_hashes = {_make_hashable(item) for item in list_a} + b_hashes = {_make_hashable(item) for item in list_b} + only_in_a = [item for item in list_a if _make_hashable(item) not in b_hashes] + only_in_b = [item for item in list_b if _make_hashable(item) not in a_hashes] + return only_in_a + only_in_b + + async def run(self, input_data: Input, **kwargs) -> BlockOutput: + if input_data.symmetric: + result = self._compute_symmetric_difference( + input_data.list_a, input_data.list_b + ) + else: + result = self._compute_difference(input_data.list_a, input_data.list_b) + + yield "difference", result + yield "length", len(result) + + +class ListIntersectionBlock(Block): + """ + Computes the intersection of two lists (elements present in both lists). + + This is useful for finding common items between two datasets, + such as shared tags, mutual connections, or overlapping categories. + """ + + class Input(BlockSchemaInput): + list_a: List[Any] = SchemaField( + description="The first list to intersect.", + placeholder="e.g., [1, 2, 3, 4, 5]", + ) + list_b: List[Any] = SchemaField( + description="The second list to intersect.", + placeholder="e.g., [3, 4, 5, 6, 7]", + ) + + class Output(BlockSchemaOutput): + intersection: List[Any] = SchemaField( + description="Elements present in both list_a and list_b." + ) + length: int = SchemaField( + description="The number of elements in the intersection." + ) + error: str = SchemaField(description="Error message if the operation failed.") + + def __init__(self): + super().__init__( + id="b6eb08b6-dbe3-411b-b9b4-2508cb311a1f", + description="Computes the intersection of two lists, returning only elements present in both.", + categories={BlockCategory.BASIC}, + input_schema=ListIntersectionBlock.Input, + output_schema=ListIntersectionBlock.Output, + test_input=[ + {"list_a": [1, 2, 3, 4, 5], "list_b": [3, 4, 5, 6, 7]}, + {"list_a": ["a", "b", "c"], "list_b": ["c", "d", "e"]}, + {"list_a": [1, 2], "list_b": [3, 4]}, + {"list_a": [], "list_b": [1, 2, 3]}, + ], + test_output=[ + ("intersection", [3, 4, 5]), + ("length", 3), + ("intersection", ["c"]), + ("length", 1), + ("intersection", []), + ("length", 0), + ("intersection", []), + ("length", 0), + ], + ) + + def _compute_intersection(self, list_a: List[Any], list_b: List[Any]) -> List[Any]: + """Compute elements present in both lists, preserving order from list_a.""" + b_hashes = {_make_hashable(item) for item in list_b} + seen: set = set() + result: List[Any] = [] + for item in list_a: + h = _make_hashable(item) + if h in b_hashes and h not in seen: + result.append(item) + seen.add(h) + return result + + async def run(self, input_data: Input, **kwargs) -> BlockOutput: + result = self._compute_intersection(input_data.list_a, input_data.list_b) + yield "intersection", result + yield "length", len(result) diff --git a/autogpt_platform/backend/test/blocks/test_list_concatenation.py b/autogpt_platform/backend/test/blocks/test_list_concatenation.py new file mode 100644 index 0000000000..8cea3b60f7 --- /dev/null +++ b/autogpt_platform/backend/test/blocks/test_list_concatenation.py @@ -0,0 +1,1276 @@ +""" +Comprehensive test suite for list concatenation and manipulation blocks. + +Tests cover: +- ConcatenateListsBlock: basic concatenation, deduplication, None removal +- FlattenListBlock: nested list flattening with depth control +- InterleaveListsBlock: round-robin interleaving of multiple lists +- ZipListsBlock: zipping lists with truncation and padding +- ListDifferenceBlock: computing list differences (regular and symmetric) +- ListIntersectionBlock: finding common elements between lists +- Helper utility functions: validation, flattening, deduplication, etc. +""" + +import pytest + +from backend.blocks.data_manipulation import ( + _MAX_FLATTEN_DEPTH, + ConcatenateListsBlock, + FlattenListBlock, + InterleaveListsBlock, + ListDifferenceBlock, + ListIntersectionBlock, + ZipListsBlock, + _compute_nesting_depth, + _concatenate_lists_simple, + _deduplicate_list, + _filter_none_values, + _flatten_nested_list, + _interleave_lists, + _make_hashable, + _validate_all_lists, + _validate_list_input, +) +from backend.util.test import execute_block_test + +# ============================================================================= +# Helper Function Tests +# ============================================================================= + + +class TestValidateListInput: + """Tests for the _validate_list_input helper.""" + + def test_valid_list_returns_none(self): + assert _validate_list_input([1, 2, 3], 0) is None + + def test_empty_list_returns_none(self): + assert _validate_list_input([], 0) is None + + def test_none_returns_none(self): + assert _validate_list_input(None, 0) is None + + def test_string_returns_error(self): + result = _validate_list_input("hello", 0) + assert result is not None + assert "str" in result + assert "index 0" in result + + def test_integer_returns_error(self): + result = _validate_list_input(42, 1) + assert result is not None + assert "int" in result + assert "index 1" in result + + def test_dict_returns_error(self): + result = _validate_list_input({"a": 1}, 2) + assert result is not None + assert "dict" in result + assert "index 2" in result + + def test_tuple_returns_error(self): + result = _validate_list_input((1, 2), 3) + assert result is not None + assert "tuple" in result + + def test_boolean_returns_error(self): + result = _validate_list_input(True, 0) + assert result is not None + assert "bool" in result + + def test_float_returns_error(self): + result = _validate_list_input(3.14, 0) + assert result is not None + assert "float" in result + + +class TestValidateAllLists: + """Tests for the _validate_all_lists helper.""" + + def test_all_valid_lists(self): + assert _validate_all_lists([[1], [2], [3]]) is None + + def test_empty_outer_list(self): + assert _validate_all_lists([]) is None + + def test_mixed_valid_and_none(self): + # None is skipped, so this should pass + assert _validate_all_lists([[1], None, [3]]) is None + + def test_invalid_item_returns_error(self): + result = _validate_all_lists([[1], "bad", [3]]) + assert result is not None + assert "index 1" in result + + def test_first_invalid_is_returned(self): + result = _validate_all_lists(["first_bad", "second_bad"]) + assert result is not None + assert "index 0" in result + + def test_all_none_passes(self): + assert _validate_all_lists([None, None, None]) is None + + +class TestConcatenateListsSimple: + """Tests for the _concatenate_lists_simple helper.""" + + def test_basic_concatenation(self): + assert _concatenate_lists_simple([[1, 2], [3, 4]]) == [1, 2, 3, 4] + + def test_empty_lists(self): + assert _concatenate_lists_simple([[], []]) == [] + + def test_single_list(self): + assert _concatenate_lists_simple([[1, 2, 3]]) == [1, 2, 3] + + def test_no_lists(self): + assert _concatenate_lists_simple([]) == [] + + def test_skip_none_values(self): + assert _concatenate_lists_simple([[1, 2], None, [3, 4]]) == [1, 2, 3, 4] # type: ignore[arg-type] + + def test_mixed_types(self): + result = _concatenate_lists_simple([[1, "a"], [True, 3.14]]) + assert result == [1, "a", True, 3.14] + + def test_nested_lists_preserved(self): + result = _concatenate_lists_simple([[[1, 2]], [[3, 4]]]) + assert result == [[1, 2], [3, 4]] + + def test_large_number_of_lists(self): + lists = [[i] for i in range(100)] + result = _concatenate_lists_simple(lists) + assert result == list(range(100)) + + +class TestFlattenNestedList: + """Tests for the _flatten_nested_list helper.""" + + def test_already_flat(self): + assert _flatten_nested_list([1, 2, 3]) == [1, 2, 3] + + def test_one_level_nesting(self): + assert _flatten_nested_list([[1, 2], [3, 4]]) == [1, 2, 3, 4] + + def test_deep_nesting(self): + assert _flatten_nested_list([1, [2, [3, [4, [5]]]]]) == [1, 2, 3, 4, 5] + + def test_empty_list(self): + assert _flatten_nested_list([]) == [] + + def test_mixed_nesting(self): + assert _flatten_nested_list([1, [2, 3], 4, [5, [6]]]) == [1, 2, 3, 4, 5, 6] + + def test_max_depth_zero(self): + # max_depth=0 means no flattening at all + result = _flatten_nested_list([[1, 2], [3, 4]], max_depth=0) + assert result == [[1, 2], [3, 4]] + + def test_max_depth_one(self): + result = _flatten_nested_list([[1, [2, 3]], [4, [5]]], max_depth=1) + assert result == [1, [2, 3], 4, [5]] + + def test_max_depth_two(self): + result = _flatten_nested_list([[[1, 2], [3]], [[4, [5]]]], max_depth=2) + assert result == [1, 2, 3, 4, [5]] + + def test_unlimited_depth(self): + deeply_nested = [[[[[[[1]]]]]]] + assert _flatten_nested_list(deeply_nested, max_depth=-1) == [1] + + def test_preserves_non_list_iterables(self): + result = _flatten_nested_list(["hello", [1, 2]]) + assert result == ["hello", 1, 2] + + def test_preserves_dicts(self): + result = _flatten_nested_list([{"a": 1}, [{"b": 2}]]) + assert result == [{"a": 1}, {"b": 2}] + + def test_excessive_depth_raises_recursion_error(self): + """Deeply nested lists beyond 1000 levels should raise RecursionError.""" + # Build a list nested 1100 levels deep + nested = [42] + for _ in range(1100): + nested = [nested] + with pytest.raises(RecursionError, match="maximum.*depth"): + _flatten_nested_list(nested, max_depth=-1) + + +class TestDeduplicateList: + """Tests for the _deduplicate_list helper.""" + + def test_no_duplicates(self): + assert _deduplicate_list([1, 2, 3]) == [1, 2, 3] + + def test_with_duplicates(self): + assert _deduplicate_list([1, 2, 2, 3, 3, 3]) == [1, 2, 3] + + def test_all_duplicates(self): + assert _deduplicate_list([1, 1, 1]) == [1] + + def test_empty_list(self): + assert _deduplicate_list([]) == [] + + def test_preserves_order(self): + result = _deduplicate_list([3, 1, 2, 1, 3]) + assert result == [3, 1, 2] + + def test_string_duplicates(self): + assert _deduplicate_list(["a", "b", "a", "c"]) == ["a", "b", "c"] + + def test_mixed_types(self): + result = _deduplicate_list([1, "1", 1, "1"]) + assert result == [1, "1"] + + def test_dict_duplicates(self): + result = _deduplicate_list([{"a": 1}, {"a": 1}, {"b": 2}]) + assert result == [{"a": 1}, {"b": 2}] + + def test_list_duplicates(self): + result = _deduplicate_list([[1, 2], [1, 2], [3, 4]]) + assert result == [[1, 2], [3, 4]] + + def test_none_duplicates(self): + result = _deduplicate_list([None, 1, None, 2]) + assert result == [None, 1, 2] + + def test_single_element(self): + assert _deduplicate_list([42]) == [42] + + +class TestMakeHashable: + """Tests for the _make_hashable helper.""" + + def test_integer(self): + assert _make_hashable(42) == 42 + + def test_string(self): + assert _make_hashable("hello") == "hello" + + def test_none(self): + assert _make_hashable(None) is None + + def test_dict_returns_tuple(self): + result = _make_hashable({"a": 1}) + assert isinstance(result, tuple) + # Should be hashable + hash(result) + + def test_list_returns_tuple(self): + result = _make_hashable([1, 2, 3]) + assert result == (1, 2, 3) + + def test_same_dict_same_hash(self): + assert _make_hashable({"a": 1, "b": 2}) == _make_hashable({"a": 1, "b": 2}) + + def test_different_dict_different_hash(self): + assert _make_hashable({"a": 1}) != _make_hashable({"a": 2}) + + def test_dict_key_order_independent(self): + """Dicts with same keys in different insertion order produce same result.""" + d1 = {"b": 2, "a": 1} + d2 = {"a": 1, "b": 2} + assert _make_hashable(d1) == _make_hashable(d2) + + def test_tuple_hashable(self): + result = _make_hashable((1, 2, 3)) + assert result == (1, 2, 3) + hash(result) + + def test_boolean(self): + result = _make_hashable(True) + assert result is True + + def test_float(self): + result = _make_hashable(3.14) + assert result == 3.14 + + +class TestFilterNoneValues: + """Tests for the _filter_none_values helper.""" + + def test_removes_none(self): + assert _filter_none_values([1, None, 2, None, 3]) == [1, 2, 3] + + def test_no_none(self): + assert _filter_none_values([1, 2, 3]) == [1, 2, 3] + + def test_all_none(self): + assert _filter_none_values([None, None, None]) == [] + + def test_empty_list(self): + assert _filter_none_values([]) == [] + + def test_preserves_falsy_values(self): + assert _filter_none_values([0, False, "", None, []]) == [0, False, "", []] + + +class TestComputeNestingDepth: + """Tests for the _compute_nesting_depth helper.""" + + def test_flat_list(self): + assert _compute_nesting_depth([1, 2, 3]) == 1 + + def test_one_level(self): + assert _compute_nesting_depth([[1, 2], [3, 4]]) == 2 + + def test_deep_nesting(self): + assert _compute_nesting_depth([[[[]]]]) == 4 + + def test_mixed_depth(self): + depth = _compute_nesting_depth([1, [2, [3]]]) + assert depth == 3 + + def test_empty_list(self): + assert _compute_nesting_depth([]) == 1 + + def test_non_list(self): + assert _compute_nesting_depth(42) == 0 + + def test_string_not_recursed(self): + # Strings should not be treated as nested lists + assert _compute_nesting_depth(["hello"]) == 1 + + +class TestInterleaveListsHelper: + """Tests for the _interleave_lists helper.""" + + def test_equal_length_lists(self): + result = _interleave_lists([[1, 2, 3], ["a", "b", "c"]]) + assert result == [1, "a", 2, "b", 3, "c"] + + def test_unequal_length_lists(self): + result = _interleave_lists([[1, 2, 3], ["a"]]) + assert result == [1, "a", 2, 3] + + def test_empty_input(self): + assert _interleave_lists([]) == [] + + def test_single_list(self): + assert _interleave_lists([[1, 2, 3]]) == [1, 2, 3] + + def test_three_lists(self): + result = _interleave_lists([[1], [2], [3]]) + assert result == [1, 2, 3] + + def test_with_none_list(self): + result = _interleave_lists([[1, 2], None, [3, 4]]) # type: ignore[arg-type] + assert result == [1, 3, 2, 4] + + def test_all_empty_lists(self): + assert _interleave_lists([[], [], []]) == [] + + def test_all_none_lists(self): + """All-None inputs should return empty list, not crash.""" + assert _interleave_lists([None, None, None]) == [] # type: ignore[arg-type] + + +class TestComputeNestingDepthEdgeCases: + """Tests for _compute_nesting_depth with deeply nested input.""" + + def test_deeply_nested_does_not_crash(self): + """Deeply nested lists beyond 1000 levels should not raise RecursionError.""" + nested = [42] + for _ in range(1100): + nested = [nested] + # Should return a depth value without crashing + depth = _compute_nesting_depth(nested) + assert depth >= _MAX_FLATTEN_DEPTH + + +class TestMakeHashableMixedKeys: + """Tests for _make_hashable with mixed-type dict keys.""" + + def test_mixed_type_dict_keys(self): + """Dicts with mixed-type keys (int and str) should not crash sorted().""" + d = {1: "one", "two": 2} + result = _make_hashable(d) + assert isinstance(result, tuple) + hash(result) # Should be hashable without error + + def test_mixed_type_keys_deterministic(self): + """Same dict with mixed keys produces same result.""" + d1 = {1: "a", "b": 2} + d2 = {1: "a", "b": 2} + assert _make_hashable(d1) == _make_hashable(d2) + + +class TestZipListsNoneHandling: + """Tests for ZipListsBlock with None values in input.""" + + def setup_method(self): + self.block = ZipListsBlock() + + def test_zip_truncate_with_none(self): + """_zip_truncate should handle None values in input lists.""" + result = self.block._zip_truncate([[1, 2], None, [3, 4]]) # type: ignore[arg-type] + assert result == [[1, 3], [2, 4]] + + def test_zip_pad_with_none(self): + """_zip_pad should handle None values in input lists.""" + result = self.block._zip_pad([[1, 2, 3], None, ["a"]], fill_value="X") # type: ignore[arg-type] + assert result == [[1, "a"], [2, "X"], [3, "X"]] + + def test_zip_truncate_all_none(self): + """All-None inputs should return empty list.""" + result = self.block._zip_truncate([None, None]) # type: ignore[arg-type] + assert result == [] + + def test_zip_pad_all_none(self): + """All-None inputs should return empty list.""" + result = self.block._zip_pad([None, None], fill_value=0) # type: ignore[arg-type] + assert result == [] + + +# ============================================================================= +# Block Built-in Tests (using test_input/test_output) +# ============================================================================= + + +class TestConcatenateListsBlockBuiltin: + """Run the built-in test_input/test_output tests for ConcatenateListsBlock.""" + + @pytest.mark.asyncio + async def test_builtin_tests(self): + block = ConcatenateListsBlock() + await execute_block_test(block) + + +class TestFlattenListBlockBuiltin: + """Run the built-in test_input/test_output tests for FlattenListBlock.""" + + @pytest.mark.asyncio + async def test_builtin_tests(self): + block = FlattenListBlock() + await execute_block_test(block) + + +class TestInterleaveListsBlockBuiltin: + """Run the built-in test_input/test_output tests for InterleaveListsBlock.""" + + @pytest.mark.asyncio + async def test_builtin_tests(self): + block = InterleaveListsBlock() + await execute_block_test(block) + + +class TestZipListsBlockBuiltin: + """Run the built-in test_input/test_output tests for ZipListsBlock.""" + + @pytest.mark.asyncio + async def test_builtin_tests(self): + block = ZipListsBlock() + await execute_block_test(block) + + +class TestListDifferenceBlockBuiltin: + """Run the built-in test_input/test_output tests for ListDifferenceBlock.""" + + @pytest.mark.asyncio + async def test_builtin_tests(self): + block = ListDifferenceBlock() + await execute_block_test(block) + + +class TestListIntersectionBlockBuiltin: + """Run the built-in test_input/test_output tests for ListIntersectionBlock.""" + + @pytest.mark.asyncio + async def test_builtin_tests(self): + block = ListIntersectionBlock() + await execute_block_test(block) + + +# ============================================================================= +# ConcatenateListsBlock Manual Tests +# ============================================================================= + + +class TestConcatenateListsBlockManual: + """Manual test cases for ConcatenateListsBlock edge cases.""" + + def setup_method(self): + self.block = ConcatenateListsBlock() + + @pytest.mark.asyncio + async def test_two_lists(self): + """Test basic two-list concatenation.""" + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input(lists=[[1, 2], [3, 4]]) + ): + results[name] = value + assert results["concatenated_list"] == [1, 2, 3, 4] + assert results["length"] == 4 + + @pytest.mark.asyncio + async def test_three_lists(self): + """Test three-list concatenation.""" + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input(lists=[[1], [2], [3]]) + ): + results[name] = value + assert results["concatenated_list"] == [1, 2, 3] + + @pytest.mark.asyncio + async def test_five_lists(self): + """Test concatenation of five lists.""" + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input(lists=[[1], [2], [3], [4], [5]]) + ): + results[name] = value + assert results["concatenated_list"] == [1, 2, 3, 4, 5] + assert results["length"] == 5 + + @pytest.mark.asyncio + async def test_empty_lists_only(self): + """Test concatenation of only empty lists.""" + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input(lists=[[], [], []]) + ): + results[name] = value + assert results["concatenated_list"] == [] + assert results["length"] == 0 + + @pytest.mark.asyncio + async def test_mixed_types_in_lists(self): + """Test concatenation with mixed types.""" + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input( + lists=[[1, "a"], [True, 3.14], [None, {"key": "val"}]] + ) + ): + results[name] = value + assert results["concatenated_list"] == [ + 1, + "a", + True, + 3.14, + None, + {"key": "val"}, + ] + + @pytest.mark.asyncio + async def test_deduplication_enabled(self): + """Test deduplication removes duplicates.""" + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input( + lists=[[1, 2, 3], [2, 3, 4], [3, 4, 5]], + deduplicate=True, + ) + ): + results[name] = value + assert results["concatenated_list"] == [1, 2, 3, 4, 5] + + @pytest.mark.asyncio + async def test_deduplication_preserves_order(self): + """Test that deduplication preserves first-occurrence order.""" + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input( + lists=[[3, 1, 2], [2, 4, 1]], + deduplicate=True, + ) + ): + results[name] = value + assert results["concatenated_list"] == [3, 1, 2, 4] + + @pytest.mark.asyncio + async def test_remove_none_enabled(self): + """Test None removal from concatenated results.""" + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input( + lists=[[1, None], [None, 2], [3, None]], + remove_none=True, + ) + ): + results[name] = value + assert results["concatenated_list"] == [1, 2, 3] + + @pytest.mark.asyncio + async def test_dedup_and_remove_none_combined(self): + """Test both deduplication and None removal together.""" + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input( + lists=[[1, None, 2], [2, None, 3]], + deduplicate=True, + remove_none=True, + ) + ): + results[name] = value + assert results["concatenated_list"] == [1, 2, 3] + + @pytest.mark.asyncio + async def test_nested_lists_preserved(self): + """Test that nested lists are not flattened during concatenation.""" + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input(lists=[[[1, 2]], [[3, 4]]]) + ): + results[name] = value + assert results["concatenated_list"] == [[1, 2], [3, 4]] + + @pytest.mark.asyncio + async def test_large_lists(self): + """Test concatenation of large lists.""" + list_a = list(range(1000)) + list_b = list(range(1000, 2000)) + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input(lists=[list_a, list_b]) + ): + results[name] = value + assert results["concatenated_list"] == list(range(2000)) + assert results["length"] == 2000 + + @pytest.mark.asyncio + async def test_single_list_input(self): + """Test concatenation with a single list.""" + results = {} + async for name, value in self.block.run( + ConcatenateListsBlock.Input(lists=[[1, 2, 3]]) + ): + results[name] = value + assert results["concatenated_list"] == [1, 2, 3] + + @pytest.mark.asyncio + async def test_block_id_is_valid_uuid(self): + """Test that the block has a valid UUID4 ID.""" + import uuid + + parsed = uuid.UUID(self.block.id) + assert parsed.version == 4 + + @pytest.mark.asyncio + async def test_block_category(self): + """Test that the block has the correct category.""" + from backend.blocks._base import BlockCategory + + assert BlockCategory.BASIC in self.block.categories + + +# ============================================================================= +# FlattenListBlock Manual Tests +# ============================================================================= + + +class TestFlattenListBlockManual: + """Manual test cases for FlattenListBlock.""" + + def setup_method(self): + self.block = FlattenListBlock() + + @pytest.mark.asyncio + async def test_simple_flatten(self): + """Test flattening a simple nested list.""" + results = {} + async for name, value in self.block.run( + FlattenListBlock.Input(nested_list=[[1, 2], [3, 4]]) + ): + results[name] = value + assert results["flattened_list"] == [1, 2, 3, 4] + assert results["length"] == 4 + + @pytest.mark.asyncio + async def test_deeply_nested(self): + """Test flattening a deeply nested structure.""" + results = {} + async for name, value in self.block.run( + FlattenListBlock.Input(nested_list=[1, [2, [3, [4, [5]]]]]) + ): + results[name] = value + assert results["flattened_list"] == [1, 2, 3, 4, 5] + + @pytest.mark.asyncio + async def test_partial_flatten(self): + """Test flattening with max_depth=1.""" + results = {} + async for name, value in self.block.run( + FlattenListBlock.Input( + nested_list=[[1, [2, 3]], [4, [5]]], + max_depth=1, + ) + ): + results[name] = value + assert results["flattened_list"] == [1, [2, 3], 4, [5]] + + @pytest.mark.asyncio + async def test_already_flat_list(self): + """Test flattening an already flat list.""" + results = {} + async for name, value in self.block.run( + FlattenListBlock.Input(nested_list=[1, 2, 3, 4]) + ): + results[name] = value + assert results["flattened_list"] == [1, 2, 3, 4] + + @pytest.mark.asyncio + async def test_empty_nested_lists(self): + """Test flattening with empty nested lists.""" + results = {} + async for name, value in self.block.run( + FlattenListBlock.Input(nested_list=[[], [1], [], [2], []]) + ): + results[name] = value + assert results["flattened_list"] == [1, 2] + + @pytest.mark.asyncio + async def test_mixed_types_preserved(self): + """Test that non-list types are preserved during flattening.""" + results = {} + async for name, value in self.block.run( + FlattenListBlock.Input(nested_list=["hello", [1, {"a": 1}], [True]]) + ): + results[name] = value + assert results["flattened_list"] == ["hello", 1, {"a": 1}, True] + + @pytest.mark.asyncio + async def test_original_depth_reported(self): + """Test that original nesting depth is correctly reported.""" + results = {} + async for name, value in self.block.run( + FlattenListBlock.Input(nested_list=[1, [2, [3]]]) + ): + results[name] = value + assert results["original_depth"] == 3 + + @pytest.mark.asyncio + async def test_block_id_is_valid_uuid(self): + """Test that the block has a valid UUID4 ID.""" + import uuid + + parsed = uuid.UUID(self.block.id) + assert parsed.version == 4 + + +# ============================================================================= +# InterleaveListsBlock Manual Tests +# ============================================================================= + + +class TestInterleaveListsBlockManual: + """Manual test cases for InterleaveListsBlock.""" + + def setup_method(self): + self.block = InterleaveListsBlock() + + @pytest.mark.asyncio + async def test_equal_length_interleave(self): + """Test interleaving two equal-length lists.""" + results = {} + async for name, value in self.block.run( + InterleaveListsBlock.Input(lists=[[1, 2, 3], ["a", "b", "c"]]) + ): + results[name] = value + assert results["interleaved_list"] == [1, "a", 2, "b", 3, "c"] + + @pytest.mark.asyncio + async def test_unequal_length_interleave(self): + """Test interleaving lists of different lengths.""" + results = {} + async for name, value in self.block.run( + InterleaveListsBlock.Input(lists=[[1, 2, 3, 4], ["a", "b"]]) + ): + results[name] = value + assert results["interleaved_list"] == [1, "a", 2, "b", 3, 4] + + @pytest.mark.asyncio + async def test_three_lists_interleave(self): + """Test interleaving three lists.""" + results = {} + async for name, value in self.block.run( + InterleaveListsBlock.Input(lists=[[1, 2], ["a", "b"], ["x", "y"]]) + ): + results[name] = value + assert results["interleaved_list"] == [1, "a", "x", 2, "b", "y"] + + @pytest.mark.asyncio + async def test_single_element_lists(self): + """Test interleaving single-element lists.""" + results = {} + async for name, value in self.block.run( + InterleaveListsBlock.Input(lists=[[1], [2], [3], [4]]) + ): + results[name] = value + assert results["interleaved_list"] == [1, 2, 3, 4] + + @pytest.mark.asyncio + async def test_block_id_is_valid_uuid(self): + """Test that the block has a valid UUID4 ID.""" + import uuid + + parsed = uuid.UUID(self.block.id) + assert parsed.version == 4 + + +# ============================================================================= +# ZipListsBlock Manual Tests +# ============================================================================= + + +class TestZipListsBlockManual: + """Manual test cases for ZipListsBlock.""" + + def setup_method(self): + self.block = ZipListsBlock() + + @pytest.mark.asyncio + async def test_basic_zip(self): + """Test basic zipping of two lists.""" + results = {} + async for name, value in self.block.run( + ZipListsBlock.Input(lists=[[1, 2, 3], ["a", "b", "c"]]) + ): + results[name] = value + assert results["zipped_list"] == [[1, "a"], [2, "b"], [3, "c"]] + + @pytest.mark.asyncio + async def test_truncate_to_shortest(self): + """Test that default behavior truncates to shortest list.""" + results = {} + async for name, value in self.block.run( + ZipListsBlock.Input(lists=[[1, 2, 3], ["a", "b"]]) + ): + results[name] = value + assert results["zipped_list"] == [[1, "a"], [2, "b"]] + assert results["length"] == 2 + + @pytest.mark.asyncio + async def test_pad_to_longest(self): + """Test padding shorter lists with fill value.""" + results = {} + async for name, value in self.block.run( + ZipListsBlock.Input( + lists=[[1, 2, 3], ["a"]], + pad_to_longest=True, + fill_value="X", + ) + ): + results[name] = value + assert results["zipped_list"] == [[1, "a"], [2, "X"], [3, "X"]] + + @pytest.mark.asyncio + async def test_pad_with_none(self): + """Test padding with None (default fill value).""" + results = {} + async for name, value in self.block.run( + ZipListsBlock.Input( + lists=[[1, 2], ["a"]], + pad_to_longest=True, + ) + ): + results[name] = value + assert results["zipped_list"] == [[1, "a"], [2, None]] + + @pytest.mark.asyncio + async def test_three_lists_zip(self): + """Test zipping three lists.""" + results = {} + async for name, value in self.block.run( + ZipListsBlock.Input(lists=[[1, 2], ["a", "b"], [True, False]]) + ): + results[name] = value + assert results["zipped_list"] == [[1, "a", True], [2, "b", False]] + + @pytest.mark.asyncio + async def test_empty_lists_zip(self): + """Test zipping empty input.""" + results = {} + async for name, value in self.block.run(ZipListsBlock.Input(lists=[])): + results[name] = value + assert results["zipped_list"] == [] + assert results["length"] == 0 + + @pytest.mark.asyncio + async def test_block_id_is_valid_uuid(self): + """Test that the block has a valid UUID4 ID.""" + import uuid + + parsed = uuid.UUID(self.block.id) + assert parsed.version == 4 + + +# ============================================================================= +# ListDifferenceBlock Manual Tests +# ============================================================================= + + +class TestListDifferenceBlockManual: + """Manual test cases for ListDifferenceBlock.""" + + def setup_method(self): + self.block = ListDifferenceBlock() + + @pytest.mark.asyncio + async def test_basic_difference(self): + """Test basic set difference.""" + results = {} + async for name, value in self.block.run( + ListDifferenceBlock.Input( + list_a=[1, 2, 3, 4, 5], + list_b=[3, 4, 5, 6, 7], + ) + ): + results[name] = value + assert results["difference"] == [1, 2] + + @pytest.mark.asyncio + async def test_symmetric_difference(self): + """Test symmetric difference.""" + results = {} + async for name, value in self.block.run( + ListDifferenceBlock.Input( + list_a=[1, 2, 3], + list_b=[2, 3, 4], + symmetric=True, + ) + ): + results[name] = value + assert results["difference"] == [1, 4] + + @pytest.mark.asyncio + async def test_no_difference(self): + """Test when lists are identical.""" + results = {} + async for name, value in self.block.run( + ListDifferenceBlock.Input( + list_a=[1, 2, 3], + list_b=[1, 2, 3], + ) + ): + results[name] = value + assert results["difference"] == [] + assert results["length"] == 0 + + @pytest.mark.asyncio + async def test_complete_difference(self): + """Test when lists share no elements.""" + results = {} + async for name, value in self.block.run( + ListDifferenceBlock.Input( + list_a=[1, 2, 3], + list_b=[4, 5, 6], + ) + ): + results[name] = value + assert results["difference"] == [1, 2, 3] + + @pytest.mark.asyncio + async def test_empty_list_a(self): + """Test with empty list_a.""" + results = {} + async for name, value in self.block.run( + ListDifferenceBlock.Input(list_a=[], list_b=[1, 2, 3]) + ): + results[name] = value + assert results["difference"] == [] + + @pytest.mark.asyncio + async def test_empty_list_b(self): + """Test with empty list_b.""" + results = {} + async for name, value in self.block.run( + ListDifferenceBlock.Input(list_a=[1, 2, 3], list_b=[]) + ): + results[name] = value + assert results["difference"] == [1, 2, 3] + + @pytest.mark.asyncio + async def test_string_difference(self): + """Test difference with string elements.""" + results = {} + async for name, value in self.block.run( + ListDifferenceBlock.Input( + list_a=["apple", "banana", "cherry"], + list_b=["banana", "date"], + ) + ): + results[name] = value + assert results["difference"] == ["apple", "cherry"] + + @pytest.mark.asyncio + async def test_dict_difference(self): + """Test difference with dictionary elements.""" + results = {} + async for name, value in self.block.run( + ListDifferenceBlock.Input( + list_a=[{"a": 1}, {"b": 2}, {"c": 3}], + list_b=[{"b": 2}], + ) + ): + results[name] = value + assert results["difference"] == [{"a": 1}, {"c": 3}] + + @pytest.mark.asyncio + async def test_block_id_is_valid_uuid(self): + """Test that the block has a valid UUID4 ID.""" + import uuid + + parsed = uuid.UUID(self.block.id) + assert parsed.version == 4 + + +# ============================================================================= +# ListIntersectionBlock Manual Tests +# ============================================================================= + + +class TestListIntersectionBlockManual: + """Manual test cases for ListIntersectionBlock.""" + + def setup_method(self): + self.block = ListIntersectionBlock() + + @pytest.mark.asyncio + async def test_basic_intersection(self): + """Test basic intersection.""" + results = {} + async for name, value in self.block.run( + ListIntersectionBlock.Input( + list_a=[1, 2, 3, 4, 5], + list_b=[3, 4, 5, 6, 7], + ) + ): + results[name] = value + assert results["intersection"] == [3, 4, 5] + assert results["length"] == 3 + + @pytest.mark.asyncio + async def test_no_intersection(self): + """Test when lists share no elements.""" + results = {} + async for name, value in self.block.run( + ListIntersectionBlock.Input( + list_a=[1, 2, 3], + list_b=[4, 5, 6], + ) + ): + results[name] = value + assert results["intersection"] == [] + assert results["length"] == 0 + + @pytest.mark.asyncio + async def test_identical_lists(self): + """Test intersection of identical lists.""" + results = {} + async for name, value in self.block.run( + ListIntersectionBlock.Input( + list_a=[1, 2, 3], + list_b=[1, 2, 3], + ) + ): + results[name] = value + assert results["intersection"] == [1, 2, 3] + + @pytest.mark.asyncio + async def test_preserves_order_from_list_a(self): + """Test that intersection preserves order from list_a.""" + results = {} + async for name, value in self.block.run( + ListIntersectionBlock.Input( + list_a=[5, 3, 1], + list_b=[1, 3, 5], + ) + ): + results[name] = value + assert results["intersection"] == [5, 3, 1] + + @pytest.mark.asyncio + async def test_empty_list_a(self): + """Test with empty list_a.""" + results = {} + async for name, value in self.block.run( + ListIntersectionBlock.Input(list_a=[], list_b=[1, 2, 3]) + ): + results[name] = value + assert results["intersection"] == [] + + @pytest.mark.asyncio + async def test_empty_list_b(self): + """Test with empty list_b.""" + results = {} + async for name, value in self.block.run( + ListIntersectionBlock.Input(list_a=[1, 2, 3], list_b=[]) + ): + results[name] = value + assert results["intersection"] == [] + + @pytest.mark.asyncio + async def test_string_intersection(self): + """Test intersection with string elements.""" + results = {} + async for name, value in self.block.run( + ListIntersectionBlock.Input( + list_a=["apple", "banana", "cherry"], + list_b=["banana", "cherry", "date"], + ) + ): + results[name] = value + assert results["intersection"] == ["banana", "cherry"] + + @pytest.mark.asyncio + async def test_deduplication_in_intersection(self): + """Test that duplicates in input don't cause duplicate results.""" + results = {} + async for name, value in self.block.run( + ListIntersectionBlock.Input( + list_a=[1, 1, 2, 2, 3], + list_b=[1, 2], + ) + ): + results[name] = value + assert results["intersection"] == [1, 2] + + @pytest.mark.asyncio + async def test_block_id_is_valid_uuid(self): + """Test that the block has a valid UUID4 ID.""" + import uuid + + parsed = uuid.UUID(self.block.id) + assert parsed.version == 4 + + +# ============================================================================= +# Block Method Tests +# ============================================================================= + + +class TestConcatenateListsBlockMethods: + """Tests for internal methods of ConcatenateListsBlock.""" + + def setup_method(self): + self.block = ConcatenateListsBlock() + + def test_validate_inputs_valid(self): + assert self.block._validate_inputs([[1], [2]]) is None + + def test_validate_inputs_invalid(self): + result = self.block._validate_inputs([[1], "bad"]) + assert result is not None + + def test_perform_concatenation(self): + result = self.block._perform_concatenation([[1, 2], [3, 4]]) + assert result == [1, 2, 3, 4] + + def test_apply_deduplication(self): + result = self.block._apply_deduplication([1, 2, 2, 3]) + assert result == [1, 2, 3] + + def test_apply_none_removal(self): + result = self.block._apply_none_removal([1, None, 2]) + assert result == [1, 2] + + def test_post_process_all_options(self): + result = self.block._post_process( + [1, None, 2, None, 2], deduplicate=True, remove_none=True + ) + assert result == [1, 2] + + def test_post_process_no_options(self): + result = self.block._post_process( + [1, None, 2, None, 2], deduplicate=False, remove_none=False + ) + assert result == [1, None, 2, None, 2] + + +class TestFlattenListBlockMethods: + """Tests for internal methods of FlattenListBlock.""" + + def setup_method(self): + self.block = FlattenListBlock() + + def test_compute_depth_flat(self): + assert self.block._compute_depth([1, 2, 3]) == 1 + + def test_compute_depth_nested(self): + assert self.block._compute_depth([[1, [2]]]) == 3 + + def test_flatten_unlimited(self): + result = self.block._flatten([1, [2, [3]]], max_depth=-1) + assert result == [1, 2, 3] + + def test_flatten_limited(self): + result = self.block._flatten([1, [2, [3]]], max_depth=1) + assert result == [1, 2, [3]] + + def test_validate_max_depth_valid(self): + assert self.block._validate_max_depth(-1) is None + assert self.block._validate_max_depth(0) is None + assert self.block._validate_max_depth(5) is None + + def test_validate_max_depth_invalid(self): + result = self.block._validate_max_depth(-2) + assert result is not None + + +class TestZipListsBlockMethods: + """Tests for internal methods of ZipListsBlock.""" + + def setup_method(self): + self.block = ZipListsBlock() + + def test_zip_truncate(self): + result = self.block._zip_truncate([[1, 2, 3], ["a", "b"]]) + assert result == [[1, "a"], [2, "b"]] + + def test_zip_pad(self): + result = self.block._zip_pad([[1, 2, 3], ["a"]], fill_value="X") + assert result == [[1, "a"], [2, "X"], [3, "X"]] + + def test_zip_pad_empty(self): + result = self.block._zip_pad([], fill_value=None) + assert result == [] + + def test_validate_inputs(self): + assert self.block._validate_inputs([[1], [2]]) is None + result = self.block._validate_inputs([[1], "bad"]) + assert result is not None + + +class TestListDifferenceBlockMethods: + """Tests for internal methods of ListDifferenceBlock.""" + + def setup_method(self): + self.block = ListDifferenceBlock() + + def test_compute_difference(self): + result = self.block._compute_difference([1, 2, 3], [2, 3, 4]) + assert result == [1] + + def test_compute_symmetric_difference(self): + result = self.block._compute_symmetric_difference([1, 2, 3], [2, 3, 4]) + assert result == [1, 4] + + def test_compute_difference_empty(self): + result = self.block._compute_difference([], [1, 2]) + assert result == [] + + def test_compute_symmetric_difference_identical(self): + result = self.block._compute_symmetric_difference([1, 2], [1, 2]) + assert result == [] + + +class TestListIntersectionBlockMethods: + """Tests for internal methods of ListIntersectionBlock.""" + + def setup_method(self): + self.block = ListIntersectionBlock() + + def test_compute_intersection(self): + result = self.block._compute_intersection([1, 2, 3], [2, 3, 4]) + assert result == [2, 3] + + def test_compute_intersection_empty(self): + result = self.block._compute_intersection([], [1, 2]) + assert result == [] + + def test_compute_intersection_no_overlap(self): + result = self.block._compute_intersection([1, 2], [3, 4]) + assert result == [] diff --git a/docs/integrations/README.md b/docs/integrations/README.md index c216aa4836..00d4b0c73a 100644 --- a/docs/integrations/README.md +++ b/docs/integrations/README.md @@ -56,12 +56,16 @@ Below is a comprehensive list of all available blocks, categorized by their prim | [File Store](block-integrations/basic.md#file-store) | Downloads and stores a file from a URL, data URI, or local path | | [Find In Dictionary](block-integrations/basic.md#find-in-dictionary) | A block that looks up a value in a dictionary, list, or object by key or index and returns the corresponding value | | [Find In List](block-integrations/basic.md#find-in-list) | Finds the index of the value in the list | +| [Flatten List](block-integrations/basic.md#flatten-list) | Flattens a nested list structure into a single flat list | | [Get All Memories](block-integrations/basic.md#get-all-memories) | Retrieve all memories from Mem0 with optional conversation filtering | | [Get Latest Memory](block-integrations/basic.md#get-latest-memory) | Retrieve the latest memory from Mem0 with optional key filtering | | [Get List Item](block-integrations/basic.md#get-list-item) | Returns the element at the given index | | [Get Store Agent Details](block-integrations/system/store_operations.md#get-store-agent-details) | Get detailed information about an agent from the store | | [Get Weather Information](block-integrations/basic.md#get-weather-information) | Retrieves weather information for a specified location using OpenWeatherMap API | | [Human In The Loop](block-integrations/basic.md#human-in-the-loop) | Pause execution for human review | +| [Interleave Lists](block-integrations/basic.md#interleave-lists) | Interleaves elements from multiple lists in round-robin fashion, alternating between sources | +| [List Difference](block-integrations/basic.md#list-difference) | Computes the difference between two lists | +| [List Intersection](block-integrations/basic.md#list-intersection) | Computes the intersection of two lists, returning only elements present in both | | [List Is Empty](block-integrations/basic.md#list-is-empty) | Checks if a list is empty | | [List Library Agents](block-integrations/system/library_operations.md#list-library-agents) | List all agents in your personal library | | [Note](block-integrations/basic.md#note) | A visual annotation block that displays a sticky note in the workflow editor for documentation and organization purposes | @@ -84,6 +88,7 @@ Below is a comprehensive list of all available blocks, categorized by their prim | [Store Value](block-integrations/basic.md#store-value) | A basic block that stores and forwards a value throughout workflows, allowing it to be reused without changes across multiple blocks | | [Universal Type Converter](block-integrations/basic.md#universal-type-converter) | This block is used to convert a value to a universal type | | [XML Parser](block-integrations/basic.md#xml-parser) | Parses XML using gravitasml to tokenize and coverts it to dict | +| [Zip Lists](block-integrations/basic.md#zip-lists) | Zips multiple lists together into a list of grouped elements | ## Data Processing diff --git a/docs/integrations/block-integrations/basic.md b/docs/integrations/block-integrations/basic.md index 08def38ede..e032690edc 100644 --- a/docs/integrations/block-integrations/basic.md +++ b/docs/integrations/block-integrations/basic.md @@ -637,7 +637,7 @@ This enables extensibility by allowing custom blocks to be added without modifyi ## Concatenate Lists ### What it is -Concatenates multiple lists into a single list. All elements from all input lists are combined in order. +Concatenates multiple lists into a single list. All elements from all input lists are combined in order. Supports optional deduplication and None removal. ### How it works @@ -651,6 +651,8 @@ The block includes validation to ensure each item is actually a list. If a non-l | Input | Description | Type | Required | |-------|-------------|------|----------| | lists | A list of lists to concatenate together. All lists will be combined in order into a single list. | List[List[Any]] | Yes | +| deduplicate | If True, remove duplicate elements from the concatenated result while preserving order. | bool | No | +| remove_none | If True, remove None values from the concatenated result. | bool | No | ### Outputs @@ -658,6 +660,7 @@ The block includes validation to ensure each item is actually a list. If a non-l |--------|-------------|------| | error | Error message if concatenation failed due to invalid input types. | str | | concatenated_list | The concatenated list containing all elements from all input lists in order. | List[Any] | +| length | The total number of elements in the concatenated list. | int | ### Possible use case @@ -820,6 +823,45 @@ This enables conditional logic based on list membership and helps locate items f --- +## Flatten List + +### What it is +Flattens a nested list structure into a single flat list. Supports configurable maximum flattening depth. + +### How it works + +This block recursively traverses a nested list and extracts all leaf elements into a single flat list. You can control how deep the flattening goes with the max_depth parameter: set it to -1 to flatten completely, or to a positive integer to flatten only that many levels. + +The block also reports the original nesting depth of the input, which is useful for understanding the structure of data coming from sources with varying levels of nesting. + + +### Inputs + +| Input | Description | Type | Required | +|-------|-------------|------|----------| +| nested_list | A potentially nested list to flatten into a single-level list. | List[Any] | Yes | +| max_depth | Maximum depth to flatten. -1 means flatten completely. 1 means flatten only one level. | int | No | + +### Outputs + +| Output | Description | Type | +|--------|-------------|------| +| error | Error message if flattening failed. | str | +| flattened_list | The flattened list with all nested elements extracted. | List[Any] | +| length | The number of elements in the flattened list. | int | +| original_depth | The maximum nesting depth of the original input list. | int | + +### Possible use case + +**Normalizing API Responses**: Flatten nested JSON arrays from different API endpoints into a uniform single-level list for consistent processing. + +**Aggregating Nested Results**: Combine results from recursive file searches or nested category trees into a flat list of items for display or export. + +**Data Pipeline Cleanup**: Simplify deeply nested data structures from multiple transformation steps into a clean flat list before final output. + + +--- + ## Get All Memories ### What it is @@ -1012,6 +1054,120 @@ This enables human oversight at critical points in automated workflows, ensuring --- +## Interleave Lists + +### What it is +Interleaves elements from multiple lists in round-robin fashion, alternating between sources. + +### How it works + +This block takes elements from each input list in round-robin order, picking one element from each list in turn. For example, given `[[1, 2, 3], ['a', 'b', 'c']]`, it produces `[1, 'a', 2, 'b', 3, 'c']`. + +When lists have different lengths, shorter lists stop contributing once exhausted, and remaining elements from longer lists continue to be added in order. + + +### Inputs + +| Input | Description | Type | Required | +|-------|-------------|------|----------| +| lists | A list of lists to interleave. Elements will be taken in round-robin order. | List[List[Any]] | Yes | + +### Outputs + +| Output | Description | Type | +|--------|-------------|------| +| error | Error message if interleaving failed. | str | +| interleaved_list | The interleaved list with elements alternating from each input list. | List[Any] | +| length | The total number of elements in the interleaved list. | int | + +### Possible use case + +**Balanced Content Mixing**: Alternate between content from different sources (e.g., mixing promotional and organic posts) for a balanced feed. + +**Round-Robin Scheduling**: Distribute tasks evenly across workers or queues by interleaving items from separate task lists. + +**Multi-Language Output**: Weave together translated text segments with their original counterparts for side-by-side comparison. + + +--- + +## List Difference + +### What it is +Computes the difference between two lists. Returns elements in the first list not found in the second, or symmetric difference. + +### How it works + +This block compares two lists and returns elements from list_a that do not appear in list_b. It uses hash-based lookup for efficient comparison. When symmetric mode is enabled, it returns elements that are in either list but not in both. + +The order of elements from list_a is preserved in the output, and elements from list_b are appended when using symmetric difference. + + +### Inputs + +| Input | Description | Type | Required | +|-------|-------------|------|----------| +| list_a | The primary list to check elements from. | List[Any] | Yes | +| list_b | The list to subtract. Elements found here will be removed from list_a. | List[Any] | Yes | +| symmetric | If True, compute symmetric difference (elements in either list but not both). | bool | No | + +### Outputs + +| Output | Description | Type | +|--------|-------------|------| +| error | Error message if the operation failed. | str | +| difference | Elements from list_a not found in list_b (or symmetric difference if enabled). | List[Any] | +| length | The number of elements in the difference result. | int | + +### Possible use case + +**Change Detection**: Compare a current list of records against a previous snapshot to find newly added or removed items. + +**Exclusion Filtering**: Remove items from a list that appear in a blocklist or already-processed list to avoid duplicates. + +**Data Sync**: Identify which items exist in one system but not another to determine what needs to be synced. + + +--- + +## List Intersection + +### What it is +Computes the intersection of two lists, returning only elements present in both. + +### How it works + +This block finds elements that appear in both input lists by hashing elements from list_b for efficient lookup, then checking each element of list_a against that set. The output preserves the order from list_a and removes duplicates. + +This is useful for finding common items between two datasets without needing to manually iterate or compare. + + +### Inputs + +| Input | Description | Type | Required | +|-------|-------------|------|----------| +| list_a | The first list to intersect. | List[Any] | Yes | +| list_b | The second list to intersect. | List[Any] | Yes | + +### Outputs + +| Output | Description | Type | +|--------|-------------|------| +| error | Error message if the operation failed. | str | +| intersection | Elements present in both list_a and list_b. | List[Any] | +| length | The number of elements in the intersection. | int | + +### Possible use case + +**Finding Common Tags**: Identify shared tags or categories between two items for recommendation or grouping purposes. + +**Mutual Connections**: Find users or contacts that appear in both of two different lists, such as shared friends or overlapping team members. + +**Feature Comparison**: Determine which features or capabilities are supported by both of two systems or products. + + +--- + ## List Is Empty ### What it is @@ -1452,3 +1608,42 @@ This makes XML data accessible using standard dictionary operations, allowing yo --- + +## Zip Lists + +### What it is +Zips multiple lists together into a list of grouped elements. Supports padding to longest or truncating to shortest. + +### How it works + +This block pairs up corresponding elements from multiple input lists into sub-lists. For example, zipping `[[1, 2, 3], ['a', 'b', 'c']]` produces `[[1, 'a'], [2, 'b'], [3, 'c']]`. + +By default, the result is truncated to the length of the shortest input list. Enable pad_to_longest to instead pad shorter lists with a fill_value so no elements from longer lists are lost. + + +### Inputs + +| Input | Description | Type | Required | +|-------|-------------|------|----------| +| lists | A list of lists to zip together. Corresponding elements will be grouped. | List[List[Any]] | Yes | +| pad_to_longest | If True, pad shorter lists with fill_value to match the longest list. If False, truncate to shortest. | bool | No | +| fill_value | Value to use for padding when pad_to_longest is True. | Fill Value | No | + +### Outputs + +| Output | Description | Type | +|--------|-------------|------| +| error | Error message if zipping failed. | str | +| zipped_list | The zipped list of grouped elements. | List[List[Any]] | +| length | The number of groups in the zipped result. | int | + +### Possible use case + +**Creating Key-Value Pairs**: Combine a list of field names with a list of values to build structured records or dictionaries. + +**Parallel Data Alignment**: Pair up corresponding items from separate data sources (e.g., names and email addresses) for processing together. + +**Table Row Construction**: Group column data into rows by zipping each column's values together for CSV export or display. + + +--- From e2d3c8a21761cbbdee88f608833ffcab35eba79e Mon Sep 17 00:00:00 2001 From: Abhimanyu Yadav <122007096+Abhi1992002@users.noreply.github.com> Date: Mon, 16 Feb 2026 12:29:33 +0530 Subject: [PATCH 16/16] fix(frontend): Prevent node drag when selecting text in object editor key input (#11955) ## Summary - Add `nodrag` class to the key name input wrapper in `WrapIfAdditionalTemplate.tsx` - This prevents the node from being dragged when users try to select text in the key name input field - Follows the same pattern used by other input components like `TextWidget.tsx` ## Test plan - [x] Open the new builder - [x] Add a custom node with an Object input field - [x] Try to select text in the key name input by clicking and dragging - [x] Verify that text selection works without moving the block Co-authored-by: Claude --- .../InputRenderer/base/object/WrapIfAdditionalTemplate.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autogpt_platform/frontend/src/components/renderers/InputRenderer/base/object/WrapIfAdditionalTemplate.tsx b/autogpt_platform/frontend/src/components/renderers/InputRenderer/base/object/WrapIfAdditionalTemplate.tsx index 97478e9eaf..a8b3514d41 100644 --- a/autogpt_platform/frontend/src/components/renderers/InputRenderer/base/object/WrapIfAdditionalTemplate.tsx +++ b/autogpt_platform/frontend/src/components/renderers/InputRenderer/base/object/WrapIfAdditionalTemplate.tsx @@ -80,7 +80,7 @@ export default function WrapIfAdditionalTemplate( uiSchema={uiSchema} /> {!isHandleConnected && ( -
+