mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-10 14:55:16 -05:00
Compare commits
68 Commits
fix/copilo
...
feat/mcp-b
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4ac025da09 | ||
|
|
fbe4c740cb | ||
|
|
ff48f4335b | ||
|
|
11fbb51a70 | ||
|
|
bb8b56c7de | ||
|
|
6ebd97f874 | ||
|
|
e934f0d0c2 | ||
|
|
3e38b141dd | ||
|
|
ec72e7eb7b | ||
|
|
db038cd0e0 | ||
|
|
6805a4f3c5 | ||
|
|
7d4c020a9b | ||
|
|
cb7a0cbdd7 | ||
|
|
79d6e8e2d7 | ||
|
|
472117a872 | ||
|
|
75a7ccf36e | ||
|
|
84809f4b94 | ||
|
|
4364a771d4 | ||
|
|
4d4ed562f0 | ||
|
|
e596ea87cb | ||
|
|
8bea7cf875 | ||
|
|
c1c269c4a9 | ||
|
|
65987ff15e | ||
|
|
ed50f7f87d | ||
|
|
c03fb170e0 | ||
|
|
8a2f98b23c | ||
|
|
5e2ae3cec5 | ||
|
|
f8771484fe | ||
|
|
81e4f0a4b0 | ||
|
|
66aada30f0 | ||
|
|
74e04f71f4 | ||
|
|
81f8290f01 | ||
|
|
4db27ca112 | ||
|
|
27ba4e8e93 | ||
|
|
1a1985186a | ||
|
|
8fd13ade74 | ||
|
|
88ee4b3a11 | ||
|
|
8eed4ad653 | ||
|
|
7744b89e96 | ||
|
|
4c02cd8f2f | ||
|
|
909f313e1e | ||
|
|
edd9a90903 | ||
|
|
ba031329e9 | ||
|
|
6ab1a6867e | ||
|
|
d9269310cc | ||
|
|
fe70b6929f | ||
|
|
340520ba85 | ||
|
|
6467f6734f | ||
|
|
6c2791b00b | ||
|
|
5a30d11416 | ||
|
|
7decc20a32 | ||
|
|
54375065d5 | ||
|
|
d62fde9445 | ||
|
|
03487f7b4d | ||
|
|
1f4105e8f9 | ||
|
|
caf9ff34e6 | ||
|
|
df41d02fce | ||
|
|
7c9e47ba76 | ||
|
|
e8fc8ee623 | ||
|
|
1a16e203b8 | ||
|
|
e59e8dd9a9 | ||
|
|
7aab2eb1d5 | ||
|
|
5ab28ccda2 | ||
|
|
4fe0f05980 | ||
|
|
19b3373052 | ||
|
|
7db3f12876 | ||
|
|
e9b996abb0 | ||
|
|
9b972389a0 |
2
.github/workflows/claude-dependabot.yml
vendored
2
.github/workflows/claude-dependabot.yml
vendored
@@ -78,7 +78,7 @@ jobs:
|
|||||||
|
|
||||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: "22"
|
node-version: "22"
|
||||||
|
|
||||||
|
|||||||
2
.github/workflows/claude.yml
vendored
2
.github/workflows/claude.yml
vendored
@@ -94,7 +94,7 @@ jobs:
|
|||||||
|
|
||||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: "22"
|
node-version: "22"
|
||||||
|
|
||||||
|
|||||||
2
.github/workflows/copilot-setup-steps.yml
vendored
2
.github/workflows/copilot-setup-steps.yml
vendored
@@ -76,7 +76,7 @@ jobs:
|
|||||||
|
|
||||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: "22"
|
node-version: "22"
|
||||||
|
|
||||||
|
|||||||
10
.github/workflows/platform-frontend-ci.yml
vendored
10
.github/workflows/platform-frontend-ci.yml
vendored
@@ -42,7 +42,7 @@ jobs:
|
|||||||
- 'autogpt_platform/frontend/src/components/**'
|
- 'autogpt_platform/frontend/src/components/**'
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
|
||||||
@@ -74,7 +74,7 @@ jobs:
|
|||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
|
||||||
@@ -112,7 +112,7 @@ jobs:
|
|||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
|
||||||
@@ -153,7 +153,7 @@ jobs:
|
|||||||
submodules: recursive
|
submodules: recursive
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
|
||||||
@@ -282,7 +282,7 @@ jobs:
|
|||||||
submodules: recursive
|
submodules: recursive
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
|
||||||
|
|||||||
8
.github/workflows/platform-fullstack-ci.yml
vendored
8
.github/workflows/platform-fullstack-ci.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
|||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
|
||||||
@@ -56,7 +56,7 @@ jobs:
|
|||||||
run: pnpm install --frozen-lockfile
|
run: pnpm install --frozen-lockfile
|
||||||
|
|
||||||
types:
|
types:
|
||||||
runs-on: ubuntu-latest
|
runs-on: big-boi
|
||||||
needs: setup
|
needs: setup
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
@@ -68,7 +68,7 @@ jobs:
|
|||||||
submodules: recursive
|
submodules: recursive
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
|
||||||
@@ -85,7 +85,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Run docker compose
|
- name: Run docker compose
|
||||||
run: |
|
run: |
|
||||||
docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
|
docker compose -f ../docker-compose.yml --profile local up -d deps_backend
|
||||||
|
|
||||||
- name: Restore dependencies cache
|
- name: Restore dependencies cache
|
||||||
uses: actions/cache@v5
|
uses: actions/cache@v5
|
||||||
|
|||||||
220
autogpt_platform/autogpt_libs/poetry.lock
generated
220
autogpt_platform/autogpt_libs/poetry.lock
generated
@@ -1,4 +1,4 @@
|
|||||||
# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand.
|
# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "annotated-doc"
|
name = "annotated-doc"
|
||||||
@@ -67,7 +67,7 @@ description = "Backport of asyncio.Runner, a context manager that controls event
|
|||||||
optional = false
|
optional = false
|
||||||
python-versions = "<3.11,>=3.8"
|
python-versions = "<3.11,>=3.8"
|
||||||
groups = ["dev"]
|
groups = ["dev"]
|
||||||
markers = "python_version == \"3.10\""
|
markers = "python_version < \"3.11\""
|
||||||
files = [
|
files = [
|
||||||
{file = "backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5"},
|
{file = "backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5"},
|
||||||
{file = "backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162"},
|
{file = "backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162"},
|
||||||
@@ -326,100 +326,118 @@ files = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "coverage"
|
name = "coverage"
|
||||||
version = "7.10.5"
|
version = "7.13.4"
|
||||||
description = "Code coverage measurement for Python"
|
description = "Code coverage measurement for Python"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.9"
|
python-versions = ">=3.10"
|
||||||
groups = ["dev"]
|
groups = ["dev"]
|
||||||
files = [
|
files = [
|
||||||
{file = "coverage-7.10.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c6a5c3414bfc7451b879141ce772c546985163cf553f08e0f135f0699a911801"},
|
{file = "coverage-7.13.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0fc31c787a84f8cd6027eba44010517020e0d18487064cd3d8968941856d1415"},
|
||||||
{file = "coverage-7.10.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:bc8e4d99ce82f1710cc3c125adc30fd1487d3cf6c2cd4994d78d68a47b16989a"},
|
{file = "coverage-7.13.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a32ebc02a1805adf637fc8dec324b5cdacd2e493515424f70ee33799573d661b"},
|
||||||
{file = "coverage-7.10.5-cp310-cp310-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:02252dc1216e512a9311f596b3169fad54abcb13827a8d76d5630c798a50a754"},
|
{file = "coverage-7.13.4-cp310-cp310-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:e24f9156097ff9dc286f2f913df3a7f63c0e333dcafa3c196f2c18b4175ca09a"},
|
||||||
{file = "coverage-7.10.5-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:73269df37883e02d460bee0cc16be90509faea1e3bd105d77360b512d5bb9c33"},
|
{file = "coverage-7.13.4-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:8041b6c5bfdc03257666e9881d33b1abc88daccaf73f7b6340fb7946655cd10f"},
|
||||||
{file = "coverage-7.10.5-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1f8a81b0614642f91c9effd53eec284f965577591f51f547a1cbeb32035b4c2f"},
|
{file = "coverage-7.13.4-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2a09cfa6a5862bc2fc6ca7c3def5b2926194a56b8ab78ffcf617d28911123012"},
|
||||||
{file = "coverage-7.10.5-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:6a29f8e0adb7f8c2b95fa2d4566a1d6e6722e0a637634c6563cb1ab844427dd9"},
|
{file = "coverage-7.13.4-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:296f8b0af861d3970c2a4d8c91d48eb4dd4771bcef9baedec6a9b515d7de3def"},
|
||||||
{file = "coverage-7.10.5-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:fcf6ab569436b4a647d4e91accba12509ad9f2554bc93d3aee23cc596e7f99c3"},
|
{file = "coverage-7.13.4-cp310-cp310-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:e101609bcbbfb04605ea1027b10dc3735c094d12d40826a60f897b98b1c30256"},
|
||||||
{file = "coverage-7.10.5-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:90dc3d6fb222b194a5de60af8d190bedeeddcbc7add317e4a3cd333ee6b7c879"},
|
{file = "coverage-7.13.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:aa3feb8db2e87ff5e6d00d7e1480ae241876286691265657b500886c98f38bda"},
|
||||||
{file = "coverage-7.10.5-cp310-cp310-win32.whl", hash = "sha256:414a568cd545f9dc75f0686a0049393de8098414b58ea071e03395505b73d7a8"},
|
{file = "coverage-7.13.4-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:4fc7fa81bbaf5a02801b65346c8b3e657f1d93763e58c0abdf7c992addd81a92"},
|
||||||
{file = "coverage-7.10.5-cp310-cp310-win_amd64.whl", hash = "sha256:e551f9d03347196271935fd3c0c165f0e8c049220280c1120de0084d65e9c7ff"},
|
{file = "coverage-7.13.4-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:33901f604424145c6e9c2398684b92e176c0b12df77d52db81c20abd48c3794c"},
|
||||||
{file = "coverage-7.10.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c177e6ffe2ebc7c410785307758ee21258aa8e8092b44d09a2da767834f075f2"},
|
{file = "coverage-7.13.4-cp310-cp310-musllinux_1_2_riscv64.whl", hash = "sha256:bb28c0f2cf2782508a40cec377935829d5fcc3ad9a3681375af4e84eb34b6b58"},
|
||||||
{file = "coverage-7.10.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:14d6071c51ad0f703d6440827eaa46386169b5fdced42631d5a5ac419616046f"},
|
{file = "coverage-7.13.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:9d107aff57a83222ddbd8d9ee705ede2af2cc926608b57abed8ef96b50b7e8f9"},
|
||||||
{file = "coverage-7.10.5-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:61f78c7c3bc272a410c5ae3fde7792b4ffb4acc03d35a7df73ca8978826bb7ab"},
|
{file = "coverage-7.13.4-cp310-cp310-win32.whl", hash = "sha256:a6f94a7d00eb18f1b6d403c91a88fd58cfc92d4b16080dfdb774afc8294469bf"},
|
||||||
{file = "coverage-7.10.5-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f39071caa126f69d63f99b324fb08c7b1da2ec28cbb1fe7b5b1799926492f65c"},
|
{file = "coverage-7.13.4-cp310-cp310-win_amd64.whl", hash = "sha256:2cb0f1e000ebc419632bbe04366a8990b6e32c4e0b51543a6484ffe15eaeda95"},
|
||||||
{file = "coverage-7.10.5-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:343a023193f04d46edc46b2616cdbee68c94dd10208ecd3adc56fcc54ef2baa1"},
|
{file = "coverage-7.13.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d490ba50c3f35dd7c17953c68f3270e7ccd1c6642e2d2afe2d8e720b98f5a053"},
|
||||||
{file = "coverage-7.10.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:585ffe93ae5894d1ebdee69fc0b0d4b7c75d8007983692fb300ac98eed146f78"},
|
{file = "coverage-7.13.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:19bc3c88078789f8ef36acb014d7241961dbf883fd2533d18cb1e7a5b4e28b11"},
|
||||||
{file = "coverage-7.10.5-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:b0ef4e66f006ed181df29b59921bd8fc7ed7cd6a9289295cd8b2824b49b570df"},
|
{file = "coverage-7.13.4-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:3998e5a32e62fdf410c0dbd3115df86297995d6e3429af80b8798aad894ca7aa"},
|
||||||
{file = "coverage-7.10.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:eb7b0bbf7cc1d0453b843eca7b5fa017874735bef9bfdfa4121373d2cc885ed6"},
|
{file = "coverage-7.13.4-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:8e264226ec98e01a8e1054314af91ee6cde0eacac4f465cc93b03dbe0bce2fd7"},
|
||||||
{file = "coverage-7.10.5-cp311-cp311-win32.whl", hash = "sha256:1d043a8a06987cc0c98516e57c4d3fc2c1591364831e9deb59c9e1b4937e8caf"},
|
{file = "coverage-7.13.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a3aa4e7b9e416774b21797365b358a6e827ffadaaca81b69ee02946852449f00"},
|
||||||
{file = "coverage-7.10.5-cp311-cp311-win_amd64.whl", hash = "sha256:fefafcca09c3ac56372ef64a40f5fe17c5592fab906e0fdffd09543f3012ba50"},
|
{file = "coverage-7.13.4-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:71ca20079dd8f27fcf808817e281e90220475cd75115162218d0e27549f95fef"},
|
||||||
{file = "coverage-7.10.5-cp311-cp311-win_arm64.whl", hash = "sha256:7e78b767da8b5fc5b2faa69bb001edafcd6f3995b42a331c53ef9572c55ceb82"},
|
{file = "coverage-7.13.4-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:e2f25215f1a359ab17320b47bcdaca3e6e6356652e8256f2441e4ef972052903"},
|
||||||
{file = "coverage-7.10.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c2d05c7e73c60a4cecc7d9b60dbfd603b4ebc0adafaef371445b47d0f805c8a9"},
|
{file = "coverage-7.13.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d65b2d373032411e86960604dc4edac91fdfb5dca539461cf2cbe78327d1e64f"},
|
||||||
{file = "coverage-7.10.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:32ddaa3b2c509778ed5373b177eb2bf5662405493baeff52278a0b4f9415188b"},
|
{file = "coverage-7.13.4-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:94eb63f9b363180aff17de3e7c8760c3ba94664ea2695c52f10111244d16a299"},
|
||||||
{file = "coverage-7.10.5-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:dd382410039fe062097aa0292ab6335a3f1e7af7bba2ef8d27dcda484918f20c"},
|
{file = "coverage-7.13.4-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:e856bf6616714c3a9fbc270ab54103f4e685ba236fa98c054e8f87f266c93505"},
|
||||||
{file = "coverage-7.10.5-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:7fa22800f3908df31cea6fb230f20ac49e343515d968cc3a42b30d5c3ebf9b5a"},
|
{file = "coverage-7.13.4-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:65dfcbe305c3dfe658492df2d85259e0d79ead4177f9ae724b6fb245198f55d6"},
|
||||||
{file = "coverage-7.10.5-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f366a57ac81f5e12797136552f5b7502fa053c861a009b91b80ed51f2ce651c6"},
|
{file = "coverage-7.13.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b507778ae8a4c915436ed5c2e05b4a6cecfa70f734e19c22a005152a11c7b6a9"},
|
||||||
{file = "coverage-7.10.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:5f1dc8f1980a272ad4a6c84cba7981792344dad33bf5869361576b7aef42733a"},
|
{file = "coverage-7.13.4-cp311-cp311-win32.whl", hash = "sha256:784fc3cf8be001197b652d51d3fd259b1e2262888693a4636e18879f613a62a9"},
|
||||||
{file = "coverage-7.10.5-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:2285c04ee8676f7938b02b4936d9b9b672064daab3187c20f73a55f3d70e6b4a"},
|
{file = "coverage-7.13.4-cp311-cp311-win_amd64.whl", hash = "sha256:2421d591f8ca05b308cf0092807308b2facbefe54af7c02ac22548b88b95c98f"},
|
||||||
{file = "coverage-7.10.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c2492e4dd9daab63f5f56286f8a04c51323d237631eb98505d87e4c4ff19ec34"},
|
{file = "coverage-7.13.4-cp311-cp311-win_arm64.whl", hash = "sha256:79e73a76b854d9c6088fe5d8b2ebe745f8681c55f7397c3c0a016192d681045f"},
|
||||||
{file = "coverage-7.10.5-cp312-cp312-win32.whl", hash = "sha256:38a9109c4ee8135d5df5505384fc2f20287a47ccbe0b3f04c53c9a1989c2bbaf"},
|
{file = "coverage-7.13.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:02231499b08dabbe2b96612993e5fc34217cdae907a51b906ac7fca8027a4459"},
|
||||||
{file = "coverage-7.10.5-cp312-cp312-win_amd64.whl", hash = "sha256:6b87f1ad60b30bc3c43c66afa7db6b22a3109902e28c5094957626a0143a001f"},
|
{file = "coverage-7.13.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40aa8808140e55dc022b15d8aa7f651b6b3d68b365ea0398f1441e0b04d859c3"},
|
||||||
{file = "coverage-7.10.5-cp312-cp312-win_arm64.whl", hash = "sha256:672a6c1da5aea6c629819a0e1461e89d244f78d7b60c424ecf4f1f2556c041d8"},
|
{file = "coverage-7.13.4-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:5b856a8ccf749480024ff3bd7310adaef57bf31fd17e1bfc404b7940b6986634"},
|
||||||
{file = "coverage-7.10.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ef3b83594d933020f54cf65ea1f4405d1f4e41a009c46df629dd964fcb6e907c"},
|
{file = "coverage-7.13.4-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:2c048ea43875fbf8b45d476ad79f179809c590ec7b79e2035c662e7afa3192e3"},
|
||||||
{file = "coverage-7.10.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:2b96bfdf7c0ea9faebce088a3ecb2382819da4fbc05c7b80040dbc428df6af44"},
|
{file = "coverage-7.13.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b7b38448866e83176e28086674fe7368ab8590e4610fb662b44e345b86d63ffa"},
|
||||||
{file = "coverage-7.10.5-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:63df1fdaffa42d914d5c4d293e838937638bf75c794cf20bee12978fc8c4e3bc"},
|
{file = "coverage-7.13.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:de6defc1c9badbf8b9e67ae90fd00519186d6ab64e5cc5f3d21359c2a9b2c1d3"},
|
||||||
{file = "coverage-7.10.5-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:8002dc6a049aac0e81ecec97abfb08c01ef0c1fbf962d0c98da3950ace89b869"},
|
{file = "coverage-7.13.4-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:7eda778067ad7ffccd23ecffce537dface96212576a07924cbf0d8799d2ded5a"},
|
||||||
{file = "coverage-7.10.5-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:63d4bb2966d6f5f705a6b0c6784c8969c468dbc4bcf9d9ded8bff1c7e092451f"},
|
{file = "coverage-7.13.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e87f6c587c3f34356c3759f0420693e35e7eb0e2e41e4c011cb6ec6ecbbf1db7"},
|
||||||
{file = "coverage-7.10.5-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:1f672efc0731a6846b157389b6e6d5d5e9e59d1d1a23a5c66a99fd58339914d5"},
|
{file = "coverage-7.13.4-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:8248977c2e33aecb2ced42fef99f2d319e9904a36e55a8a68b69207fb7e43edc"},
|
||||||
{file = "coverage-7.10.5-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:3f39cef43d08049e8afc1fde4a5da8510fc6be843f8dea350ee46e2a26b2f54c"},
|
{file = "coverage-7.13.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:25381386e80ae727608e662474db537d4df1ecd42379b5ba33c84633a2b36d47"},
|
||||||
{file = "coverage-7.10.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:2968647e3ed5a6c019a419264386b013979ff1fb67dd11f5c9886c43d6a31fc2"},
|
{file = "coverage-7.13.4-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:ee756f00726693e5ba94d6df2bdfd64d4852d23b09bb0bc700e3b30e6f333985"},
|
||||||
{file = "coverage-7.10.5-cp313-cp313-win32.whl", hash = "sha256:0d511dda38595b2b6934c2b730a1fd57a3635c6aa2a04cb74714cdfdd53846f4"},
|
{file = "coverage-7.13.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fdfc1e28e7c7cdce44985b3043bc13bbd9c747520f94a4d7164af8260b3d91f0"},
|
||||||
{file = "coverage-7.10.5-cp313-cp313-win_amd64.whl", hash = "sha256:9a86281794a393513cf117177fd39c796b3f8e3759bb2764259a2abba5cce54b"},
|
{file = "coverage-7.13.4-cp312-cp312-win32.whl", hash = "sha256:01d4cbc3c283a17fc1e42d614a119f7f438eabb593391283adca8dc86eff1246"},
|
||||||
{file = "coverage-7.10.5-cp313-cp313-win_arm64.whl", hash = "sha256:cebd8e906eb98bb09c10d1feed16096700b1198d482267f8bf0474e63a7b8d84"},
|
{file = "coverage-7.13.4-cp312-cp312-win_amd64.whl", hash = "sha256:9401ebc7ef522f01d01d45532c68c5ac40fb27113019b6b7d8b208f6e9baa126"},
|
||||||
{file = "coverage-7.10.5-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0520dff502da5e09d0d20781df74d8189ab334a1e40d5bafe2efaa4158e2d9e7"},
|
{file = "coverage-7.13.4-cp312-cp312-win_arm64.whl", hash = "sha256:b1ec7b6b6e93255f952e27ab58fbc68dcc468844b16ecbee881aeb29b6ab4d8d"},
|
||||||
{file = "coverage-7.10.5-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:d9cd64aca68f503ed3f1f18c7c9174cbb797baba02ca8ab5112f9d1c0328cd4b"},
|
{file = "coverage-7.13.4-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b66a2da594b6068b48b2692f043f35d4d3693fb639d5ea8b39533c2ad9ac3ab9"},
|
||||||
{file = "coverage-7.10.5-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:0913dd1613a33b13c4f84aa6e3f4198c1a21ee28ccb4f674985c1f22109f0aae"},
|
{file = "coverage-7.13.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3599eb3992d814d23b35c536c28df1a882caa950f8f507cef23d1cbf334995ac"},
|
||||||
{file = "coverage-7.10.5-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:1b7181c0feeb06ed8a02da02792f42f829a7b29990fef52eff257fef0885d760"},
|
{file = "coverage-7.13.4-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:93550784d9281e374fb5a12bf1324cc8a963fd63b2d2f223503ef0fd4aa339ea"},
|
||||||
{file = "coverage-7.10.5-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:36d42b7396b605f774d4372dd9c49bed71cbabce4ae1ccd074d155709dd8f235"},
|
{file = "coverage-7.13.4-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b720ce6a88a2755f7c697c23268ddc47a571b88052e6b155224347389fdf6a3b"},
|
||||||
{file = "coverage-7.10.5-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:b4fdc777e05c4940b297bf47bf7eedd56a39a61dc23ba798e4b830d585486ca5"},
|
{file = "coverage-7.13.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7b322db1284a2ed3aa28ffd8ebe3db91c929b7a333c0820abec3d838ef5b3525"},
|
||||||
{file = "coverage-7.10.5-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:42144e8e346de44a6f1dbd0a56575dd8ab8dfa7e9007da02ea5b1c30ab33a7db"},
|
{file = "coverage-7.13.4-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:f4594c67d8a7c89cf922d9df0438c7c7bb022ad506eddb0fdb2863359ff78242"},
|
||||||
{file = "coverage-7.10.5-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:66c644cbd7aed8fe266d5917e2c9f65458a51cfe5eeff9c05f15b335f697066e"},
|
{file = "coverage-7.13.4-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:53d133df809c743eb8bce33b24bcababb371f4441340578cd406e084d94a6148"},
|
||||||
{file = "coverage-7.10.5-cp313-cp313t-win32.whl", hash = "sha256:2d1b73023854068c44b0c554578a4e1ef1b050ed07cf8b431549e624a29a66ee"},
|
{file = "coverage-7.13.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:76451d1978b95ba6507a039090ba076105c87cc76fc3efd5d35d72093964d49a"},
|
||||||
{file = "coverage-7.10.5-cp313-cp313t-win_amd64.whl", hash = "sha256:54a1532c8a642d8cc0bd5a9a51f5a9dcc440294fd06e9dda55e743c5ec1a8f14"},
|
{file = "coverage-7.13.4-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:7f57b33491e281e962021de110b451ab8a24182589be17e12a22c79047935e23"},
|
||||||
{file = "coverage-7.10.5-cp313-cp313t-win_arm64.whl", hash = "sha256:74d5b63fe3f5f5d372253a4ef92492c11a4305f3550631beaa432fc9df16fcff"},
|
{file = "coverage-7.13.4-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:1731dc33dc276dafc410a885cbf5992f1ff171393e48a21453b78727d090de80"},
|
||||||
{file = "coverage-7.10.5-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:68c5e0bc5f44f68053369fa0d94459c84548a77660a5f2561c5e5f1e3bed7031"},
|
{file = "coverage-7.13.4-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:bd60d4fe2f6fa7dff9223ca1bbc9f05d2b6697bc5961072e5d3b952d46e1b1ea"},
|
||||||
{file = "coverage-7.10.5-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:cf33134ffae93865e32e1e37df043bef15a5e857d8caebc0099d225c579b0fa3"},
|
{file = "coverage-7.13.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9181a3ccead280b828fae232df12b16652702b49d41e99d657f46cc7b1f6ec7a"},
|
||||||
{file = "coverage-7.10.5-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:ad8fa9d5193bafcf668231294241302b5e683a0518bf1e33a9a0dfb142ec3031"},
|
{file = "coverage-7.13.4-cp313-cp313-win32.whl", hash = "sha256:f53d492307962561ac7de4cd1de3e363589b000ab69617c6156a16ba7237998d"},
|
||||||
{file = "coverage-7.10.5-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:146fa1531973d38ab4b689bc764592fe6c2f913e7e80a39e7eeafd11f0ef6db2"},
|
{file = "coverage-7.13.4-cp313-cp313-win_amd64.whl", hash = "sha256:e6f70dec1cc557e52df5306d051ef56003f74d56e9c4dd7ddb07e07ef32a84dd"},
|
||||||
{file = "coverage-7.10.5-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6013a37b8a4854c478d3219ee8bc2392dea51602dd0803a12d6f6182a0061762"},
|
{file = "coverage-7.13.4-cp313-cp313-win_arm64.whl", hash = "sha256:fb07dc5da7e849e2ad31a5d74e9bece81f30ecf5a42909d0a695f8bd1874d6af"},
|
||||||
{file = "coverage-7.10.5-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:eb90fe20db9c3d930fa2ad7a308207ab5b86bf6a76f54ab6a40be4012d88fcae"},
|
{file = "coverage-7.13.4-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:40d74da8e6c4b9ac18b15331c4b5ebc35a17069410cad462ad4f40dcd2d50c0d"},
|
||||||
{file = "coverage-7.10.5-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:384b34482272e960c438703cafe63316dfbea124ac62006a455c8410bf2a2262"},
|
{file = "coverage-7.13.4-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:4223b4230a376138939a9173f1bdd6521994f2aff8047fae100d6d94d50c5a12"},
|
||||||
{file = "coverage-7.10.5-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:467dc74bd0a1a7de2bedf8deaf6811f43602cb532bd34d81ffd6038d6d8abe99"},
|
{file = "coverage-7.13.4-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:1d4be36a5114c499f9f1f9195e95ebf979460dbe2d88e6816ea202010ba1c34b"},
|
||||||
{file = "coverage-7.10.5-cp314-cp314-win32.whl", hash = "sha256:556d23d4e6393ca898b2e63a5bca91e9ac2d5fb13299ec286cd69a09a7187fde"},
|
{file = "coverage-7.13.4-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:200dea7d1e8095cc6e98cdabe3fd1d21ab17d3cee6dab00cadbb2fe35d9c15b9"},
|
||||||
{file = "coverage-7.10.5-cp314-cp314-win_amd64.whl", hash = "sha256:f4446a9547681533c8fa3e3c6cf62121eeee616e6a92bd9201c6edd91beffe13"},
|
{file = "coverage-7.13.4-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b8eb931ee8e6d8243e253e5ed7336deea6904369d2fd8ae6e43f68abbf167092"},
|
||||||
{file = "coverage-7.10.5-cp314-cp314-win_arm64.whl", hash = "sha256:5e78bd9cf65da4c303bf663de0d73bf69f81e878bf72a94e9af67137c69b9fe9"},
|
{file = "coverage-7.13.4-cp313-cp313t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:75eab1ebe4f2f64d9509b984f9314d4aa788540368218b858dad56dc8f3e5eb9"},
|
||||||
{file = "coverage-7.10.5-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:5661bf987d91ec756a47c7e5df4fbcb949f39e32f9334ccd3f43233bbb65e508"},
|
{file = "coverage-7.13.4-cp313-cp313t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:c35eb28c1d085eb7d8c9b3296567a1bebe03ce72962e932431b9a61f28facf26"},
|
||||||
{file = "coverage-7.10.5-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:a46473129244db42a720439a26984f8c6f834762fc4573616c1f37f13994b357"},
|
{file = "coverage-7.13.4-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:eb88b316ec33760714a4720feb2816a3a59180fd58c1985012054fa7aebee4c2"},
|
||||||
{file = "coverage-7.10.5-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:1f64b8d3415d60f24b058b58d859e9512624bdfa57a2d1f8aff93c1ec45c429b"},
|
{file = "coverage-7.13.4-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:7d41eead3cc673cbd38a4417deb7fd0b4ca26954ff7dc6078e33f6ff97bed940"},
|
||||||
{file = "coverage-7.10.5-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:44d43de99a9d90b20e0163f9770542357f58860a26e24dc1d924643bd6aa7cb4"},
|
{file = "coverage-7.13.4-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:fb26a934946a6afe0e326aebe0730cdff393a8bc0bbb65a2f41e30feddca399c"},
|
||||||
{file = "coverage-7.10.5-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a931a87e5ddb6b6404e65443b742cb1c14959622777f2a4efd81fba84f5d91ba"},
|
{file = "coverage-7.13.4-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:dae88bc0fc77edaa65c14be099bd57ee140cf507e6bfdeea7938457ab387efb0"},
|
||||||
{file = "coverage-7.10.5-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:f9559b906a100029274448f4c8b8b0a127daa4dade5661dfd821b8c188058842"},
|
{file = "coverage-7.13.4-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:845f352911777a8e722bfce168958214951e07e47e5d5d9744109fa5fe77f79b"},
|
||||||
{file = "coverage-7.10.5-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:b08801e25e3b4526ef9ced1aa29344131a8f5213c60c03c18fe4c6170ffa2874"},
|
{file = "coverage-7.13.4-cp313-cp313t-win32.whl", hash = "sha256:2fa8d5f8de70688a28240de9e139fa16b153cc3cbb01c5f16d88d6505ebdadf9"},
|
||||||
{file = "coverage-7.10.5-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:ed9749bb8eda35f8b636fb7632f1c62f735a236a5d4edadd8bbcc5ea0542e732"},
|
{file = "coverage-7.13.4-cp313-cp313t-win_amd64.whl", hash = "sha256:9351229c8c8407645840edcc277f4a2d44814d1bc34a2128c11c2a031d45a5dd"},
|
||||||
{file = "coverage-7.10.5-cp314-cp314t-win32.whl", hash = "sha256:609b60d123fc2cc63ccee6d17e4676699075db72d14ac3c107cc4976d516f2df"},
|
{file = "coverage-7.13.4-cp313-cp313t-win_arm64.whl", hash = "sha256:30b8d0512f2dc8c8747557e8fb459d6176a2c9e5731e2b74d311c03b78451997"},
|
||||||
{file = "coverage-7.10.5-cp314-cp314t-win_amd64.whl", hash = "sha256:0666cf3d2c1626b5a3463fd5b05f5e21f99e6aec40a3192eee4d07a15970b07f"},
|
{file = "coverage-7.13.4-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:300deaee342f90696ed186e3a00c71b5b3d27bffe9e827677954f4ee56969601"},
|
||||||
{file = "coverage-7.10.5-cp314-cp314t-win_arm64.whl", hash = "sha256:bc85eb2d35e760120540afddd3044a5bf69118a91a296a8b3940dfc4fdcfe1e2"},
|
{file = "coverage-7.13.4-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:29e3220258d682b6226a9b0925bc563ed9a1ebcff3cad30f043eceea7eaf2689"},
|
||||||
{file = "coverage-7.10.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:62835c1b00c4a4ace24c1a88561a5a59b612fbb83a525d1c70ff5720c97c0610"},
|
{file = "coverage-7.13.4-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:391ee8f19bef69210978363ca930f7328081c6a0152f1166c91f0b5fdd2a773c"},
|
||||||
{file = "coverage-7.10.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5255b3bbcc1d32a4069d6403820ac8e6dbcc1d68cb28a60a1ebf17e47028e898"},
|
{file = "coverage-7.13.4-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:0dd7ab8278f0d58a0128ba2fca25824321f05d059c1441800e934ff2efa52129"},
|
||||||
{file = "coverage-7.10.5-cp39-cp39-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:3876385722e335d6e991c430302c24251ef9c2a9701b2b390f5473199b1b8ebf"},
|
{file = "coverage-7.13.4-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:78cdf0d578b15148b009ccf18c686aa4f719d887e76e6b40c38ffb61d264a552"},
|
||||||
{file = "coverage-7.10.5-cp39-cp39-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:8048ce4b149c93447a55d279078c8ae98b08a6951a3c4d2d7e87f4efc7bfe100"},
|
{file = "coverage-7.13.4-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:48685fee12c2eb3b27c62f2658e7ea21e9c3239cba5a8a242801a0a3f6a8c62a"},
|
||||||
{file = "coverage-7.10.5-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4028e7558e268dd8bcf4d9484aad393cafa654c24b4885f6f9474bf53183a82a"},
|
{file = "coverage-7.13.4-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:4e83efc079eb39480e6346a15a1bcb3e9b04759c5202d157e1dd4303cd619356"},
|
||||||
{file = "coverage-7.10.5-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:03f47dc870eec0367fcdd603ca6a01517d2504e83dc18dbfafae37faec66129a"},
|
{file = "coverage-7.13.4-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:ecae9737b72408d6a950f7e525f30aca12d4bd8dd95e37342e5beb3a2a8c4f71"},
|
||||||
{file = "coverage-7.10.5-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:2d488d7d42b6ded7ea0704884f89dcabd2619505457de8fc9a6011c62106f6e5"},
|
{file = "coverage-7.13.4-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:ae4578f8528569d3cf303fef2ea569c7f4c4059a38c8667ccef15c6e1f118aa5"},
|
||||||
{file = "coverage-7.10.5-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:b3dcf2ead47fa8be14224ee817dfc1df98043af568fe120a22f81c0eb3c34ad2"},
|
{file = "coverage-7.13.4-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:6fdef321fdfbb30a197efa02d48fcd9981f0d8ad2ae8903ac318adc653f5df98"},
|
||||||
{file = "coverage-7.10.5-cp39-cp39-win32.whl", hash = "sha256:02650a11324b80057b8c9c29487020073d5e98a498f1857f37e3f9b6ea1b2426"},
|
{file = "coverage-7.13.4-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:2b0f6ccf3dbe577170bebfce1318707d0e8c3650003cb4b3a9dd744575daa8b5"},
|
||||||
{file = "coverage-7.10.5-cp39-cp39-win_amd64.whl", hash = "sha256:b45264dd450a10f9e03237b41a9a24e85cbb1e278e5a32adb1a303f58f0017f3"},
|
{file = "coverage-7.13.4-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:75fcd519f2a5765db3f0e391eb3b7d150cce1a771bf4c9f861aeab86c767a3c0"},
|
||||||
{file = "coverage-7.10.5-py3-none-any.whl", hash = "sha256:0be24d35e4db1d23d0db5c0f6a74a962e2ec83c426b5cac09f4234aadef38e4a"},
|
{file = "coverage-7.13.4-cp314-cp314-win32.whl", hash = "sha256:8e798c266c378da2bd819b0677df41ab46d78065fb2a399558f3f6cae78b2fbb"},
|
||||||
{file = "coverage-7.10.5.tar.gz", hash = "sha256:f2e57716a78bc3ae80b2207be0709a3b2b63b9f2dcf9740ee6ac03588a2015b6"},
|
{file = "coverage-7.13.4-cp314-cp314-win_amd64.whl", hash = "sha256:245e37f664d89861cf2329c9afa2c1fe9e6d4e1a09d872c947e70718aeeac505"},
|
||||||
|
{file = "coverage-7.13.4-cp314-cp314-win_arm64.whl", hash = "sha256:ad27098a189e5838900ce4c2a99f2fe42a0bf0c2093c17c69b45a71579e8d4a2"},
|
||||||
|
{file = "coverage-7.13.4-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:85480adfb35ffc32d40918aad81b89c69c9cc5661a9b8a81476d3e645321a056"},
|
||||||
|
{file = "coverage-7.13.4-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:79be69cf7f3bf9b0deeeb062eab7ac7f36cd4cc4c4dd694bd28921ba4d8596cc"},
|
||||||
|
{file = "coverage-7.13.4-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:caa421e2684e382c5d8973ac55e4f36bed6821a9bad5c953494de960c74595c9"},
|
||||||
|
{file = "coverage-7.13.4-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:14375934243ee05f56c45393fe2ce81fe5cc503c07cee2bdf1725fb8bef3ffaf"},
|
||||||
|
{file = "coverage-7.13.4-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:25a41c3104d08edb094d9db0d905ca54d0cd41c928bb6be3c4c799a54753af55"},
|
||||||
|
{file = "coverage-7.13.4-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:6f01afcff62bf9a08fb32b2c1d6e924236c0383c02c790732b6537269e466a72"},
|
||||||
|
{file = "coverage-7.13.4-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:eb9078108fbf0bcdde37c3f4779303673c2fa1fe8f7956e68d447d0dd426d38a"},
|
||||||
|
{file = "coverage-7.13.4-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:0e086334e8537ddd17e5f16a344777c1ab8194986ec533711cbe6c41cde841b6"},
|
||||||
|
{file = "coverage-7.13.4-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:725d985c5ab621268b2edb8e50dfe57633dc69bda071abc470fed55a14935fd3"},
|
||||||
|
{file = "coverage-7.13.4-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:3c06f0f1337c667b971ca2f975523347e63ec5e500b9aa5882d91931cd3ef750"},
|
||||||
|
{file = "coverage-7.13.4-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:590c0ed4bf8e85f745e6b805b2e1c457b2e33d5255dd9729743165253bc9ad39"},
|
||||||
|
{file = "coverage-7.13.4-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:eb30bf180de3f632cd043322dad5751390e5385108b2807368997d1a92a509d0"},
|
||||||
|
{file = "coverage-7.13.4-cp314-cp314t-win32.whl", hash = "sha256:c4240e7eded42d131a2d2c4dec70374b781b043ddc79a9de4d55ca71f8e98aea"},
|
||||||
|
{file = "coverage-7.13.4-cp314-cp314t-win_amd64.whl", hash = "sha256:4c7d3cc01e7350f2f0f6f7036caaf5673fb56b6998889ccfe9e1c1fe75a9c932"},
|
||||||
|
{file = "coverage-7.13.4-cp314-cp314t-win_arm64.whl", hash = "sha256:23e3f687cf945070d1c90f85db66d11e3025665d8dafa831301a0e0038f3db9b"},
|
||||||
|
{file = "coverage-7.13.4-py3-none-any.whl", hash = "sha256:1af1641e57cf7ba1bd67d677c9abdbcd6cc2ab7da3bca7fa1e2b7e50e65f2ad0"},
|
||||||
|
{file = "coverage-7.13.4.tar.gz", hash = "sha256:e5c8f6ed1e61a8b2dcdf31eb0b9bbf0130750ca79c1c49eb898e2ad86f5ccc91"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@@ -523,7 +541,7 @@ description = "Backport of PEP 654 (exception groups)"
|
|||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
groups = ["main", "dev"]
|
groups = ["main", "dev"]
|
||||||
markers = "python_version == \"3.10\""
|
markers = "python_version < \"3.11\""
|
||||||
files = [
|
files = [
|
||||||
{file = "exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10"},
|
{file = "exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10"},
|
||||||
{file = "exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88"},
|
{file = "exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88"},
|
||||||
@@ -2162,23 +2180,23 @@ testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pytest-cov"
|
name = "pytest-cov"
|
||||||
version = "6.2.1"
|
version = "7.0.0"
|
||||||
description = "Pytest plugin for measuring coverage."
|
description = "Pytest plugin for measuring coverage."
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.9"
|
python-versions = ">=3.9"
|
||||||
groups = ["dev"]
|
groups = ["dev"]
|
||||||
files = [
|
files = [
|
||||||
{file = "pytest_cov-6.2.1-py3-none-any.whl", hash = "sha256:f5bc4c23f42f1cdd23c70b1dab1bbaef4fc505ba950d53e0081d0730dd7e86d5"},
|
{file = "pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861"},
|
||||||
{file = "pytest_cov-6.2.1.tar.gz", hash = "sha256:25cc6cc0a5358204b8108ecedc51a9b57b34cc6b8c967cc2c01a4e00d8a67da2"},
|
{file = "pytest_cov-7.0.0.tar.gz", hash = "sha256:33c97eda2e049a0c5298e91f519302a1334c26ac65c1a483d6206fd458361af1"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
coverage = {version = ">=7.5", extras = ["toml"]}
|
coverage = {version = ">=7.10.6", extras = ["toml"]}
|
||||||
pluggy = ">=1.2"
|
pluggy = ">=1.2"
|
||||||
pytest = ">=6.2.5"
|
pytest = ">=7"
|
||||||
|
|
||||||
[package.extras]
|
[package.extras]
|
||||||
testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"]
|
testing = ["process-tests", "pytest-xdist", "virtualenv"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pytest-mock"
|
name = "pytest-mock"
|
||||||
@@ -2545,7 +2563,7 @@ description = "A lil' TOML parser"
|
|||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8"
|
python-versions = ">=3.8"
|
||||||
groups = ["dev"]
|
groups = ["dev"]
|
||||||
markers = "python_version == \"3.10\""
|
markers = "python_version < \"3.11\""
|
||||||
files = [
|
files = [
|
||||||
{file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"},
|
{file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"},
|
||||||
{file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"},
|
{file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"},
|
||||||
@@ -2893,4 +2911,4 @@ type = ["pytest-mypy"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.1"
|
lock-version = "2.1"
|
||||||
python-versions = ">=3.10,<4.0"
|
python-versions = ">=3.10,<4.0"
|
||||||
content-hash = "b7ac335a86aa44c3d7d2802298818b389a6f1286e3e9b7b0edb2ff06377cecaf"
|
content-hash = "40eae94995dc0a388fa832ed4af9b6137f28d5b5ced3aaea70d5f91d4d9a179d"
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ pyright = "^1.1.408"
|
|||||||
pytest = "^8.4.1"
|
pytest = "^8.4.1"
|
||||||
pytest-asyncio = "^1.3.0"
|
pytest-asyncio = "^1.3.0"
|
||||||
pytest-mock = "^3.15.1"
|
pytest-mock = "^3.15.1"
|
||||||
pytest-cov = "^6.2.1"
|
pytest-cov = "^7.0.0"
|
||||||
ruff = "^0.15.0"
|
ruff = "^0.15.0"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
|
|||||||
@@ -45,10 +45,7 @@ async def create_chat_session(
|
|||||||
successfulAgentRuns=SafeJson({}),
|
successfulAgentRuns=SafeJson({}),
|
||||||
successfulAgentSchedules=SafeJson({}),
|
successfulAgentSchedules=SafeJson({}),
|
||||||
)
|
)
|
||||||
return await PrismaChatSession.prisma().create(
|
return await PrismaChatSession.prisma().create(data=data)
|
||||||
data=data,
|
|
||||||
include={"Messages": True},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def update_chat_session(
|
async def update_chat_session(
|
||||||
|
|||||||
@@ -18,6 +18,10 @@ class ResponseType(str, Enum):
|
|||||||
START = "start"
|
START = "start"
|
||||||
FINISH = "finish"
|
FINISH = "finish"
|
||||||
|
|
||||||
|
# Step lifecycle (one LLM API call within a message)
|
||||||
|
START_STEP = "start-step"
|
||||||
|
FINISH_STEP = "finish-step"
|
||||||
|
|
||||||
# Text streaming
|
# Text streaming
|
||||||
TEXT_START = "text-start"
|
TEXT_START = "text-start"
|
||||||
TEXT_DELTA = "text-delta"
|
TEXT_DELTA = "text-delta"
|
||||||
@@ -57,6 +61,16 @@ class StreamStart(StreamBaseResponse):
|
|||||||
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
|
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def to_sse(self) -> str:
|
||||||
|
"""Convert to SSE format, excluding non-protocol fields like taskId."""
|
||||||
|
import json
|
||||||
|
|
||||||
|
data: dict[str, Any] = {
|
||||||
|
"type": self.type.value,
|
||||||
|
"messageId": self.messageId,
|
||||||
|
}
|
||||||
|
return f"data: {json.dumps(data)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
class StreamFinish(StreamBaseResponse):
|
class StreamFinish(StreamBaseResponse):
|
||||||
"""End of message/stream."""
|
"""End of message/stream."""
|
||||||
@@ -64,6 +78,26 @@ class StreamFinish(StreamBaseResponse):
|
|||||||
type: ResponseType = ResponseType.FINISH
|
type: ResponseType = ResponseType.FINISH
|
||||||
|
|
||||||
|
|
||||||
|
class StreamStartStep(StreamBaseResponse):
|
||||||
|
"""Start of a step (one LLM API call within a message).
|
||||||
|
|
||||||
|
The AI SDK uses this to add a step-start boundary to message.parts,
|
||||||
|
enabling visual separation between multiple LLM calls in a single message.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.START_STEP
|
||||||
|
|
||||||
|
|
||||||
|
class StreamFinishStep(StreamBaseResponse):
|
||||||
|
"""End of a step (one LLM API call within a message).
|
||||||
|
|
||||||
|
The AI SDK uses this to reset activeTextParts and activeReasoningParts,
|
||||||
|
so the next LLM call in a tool-call continuation starts with clean state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.FINISH_STEP
|
||||||
|
|
||||||
|
|
||||||
# ========== Text Streaming ==========
|
# ========== Text Streaming ==========
|
||||||
|
|
||||||
|
|
||||||
@@ -117,7 +151,7 @@ class StreamToolOutputAvailable(StreamBaseResponse):
|
|||||||
type: ResponseType = ResponseType.TOOL_OUTPUT_AVAILABLE
|
type: ResponseType = ResponseType.TOOL_OUTPUT_AVAILABLE
|
||||||
toolCallId: str = Field(..., description="Tool call ID this responds to")
|
toolCallId: str = Field(..., description="Tool call ID this responds to")
|
||||||
output: str | dict[str, Any] = Field(..., description="Tool execution output")
|
output: str | dict[str, Any] = Field(..., description="Tool execution output")
|
||||||
# Additional fields for internal use (not part of AI SDK spec but useful)
|
# Keep these for internal backend use
|
||||||
toolName: str | None = Field(
|
toolName: str | None = Field(
|
||||||
default=None, description="Name of the tool that was executed"
|
default=None, description="Name of the tool that was executed"
|
||||||
)
|
)
|
||||||
@@ -125,6 +159,17 @@ class StreamToolOutputAvailable(StreamBaseResponse):
|
|||||||
default=True, description="Whether the tool execution succeeded"
|
default=True, description="Whether the tool execution succeeded"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def to_sse(self) -> str:
|
||||||
|
"""Convert to SSE format, excluding non-spec fields."""
|
||||||
|
import json
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"type": self.type.value,
|
||||||
|
"toolCallId": self.toolCallId,
|
||||||
|
"output": self.output,
|
||||||
|
}
|
||||||
|
return f"data: {json.dumps(data)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
# ========== Other ==========
|
# ========== Other ==========
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from collections.abc import AsyncGenerator
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from autogpt_libs import auth
|
from autogpt_libs import auth
|
||||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
|
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response, Security
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@@ -17,7 +17,29 @@ from . import stream_registry
|
|||||||
from .completion_handler import process_operation_failure, process_operation_success
|
from .completion_handler import process_operation_failure, process_operation_success
|
||||||
from .config import ChatConfig
|
from .config import ChatConfig
|
||||||
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
||||||
from .response_model import StreamFinish, StreamHeartbeat, StreamStart
|
from .response_model import StreamFinish, StreamHeartbeat
|
||||||
|
from .tools.models import (
|
||||||
|
AgentDetailsResponse,
|
||||||
|
AgentOutputResponse,
|
||||||
|
AgentPreviewResponse,
|
||||||
|
AgentSavedResponse,
|
||||||
|
AgentsFoundResponse,
|
||||||
|
BlockListResponse,
|
||||||
|
BlockOutputResponse,
|
||||||
|
ClarificationNeededResponse,
|
||||||
|
DocPageResponse,
|
||||||
|
DocSearchResultsResponse,
|
||||||
|
ErrorResponse,
|
||||||
|
ExecutionStartedResponse,
|
||||||
|
InputValidationErrorResponse,
|
||||||
|
NeedLoginResponse,
|
||||||
|
NoResultsResponse,
|
||||||
|
OperationInProgressResponse,
|
||||||
|
OperationPendingResponse,
|
||||||
|
OperationStartedResponse,
|
||||||
|
SetupRequirementsResponse,
|
||||||
|
UnderstandingUpdatedResponse,
|
||||||
|
)
|
||||||
|
|
||||||
config = ChatConfig()
|
config = ChatConfig()
|
||||||
|
|
||||||
@@ -266,12 +288,36 @@ async def stream_chat_post(
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import time
|
||||||
|
|
||||||
|
stream_start_time = time.perf_counter()
|
||||||
|
log_meta = {"component": "ChatStream", "session_id": session_id}
|
||||||
|
if user_id:
|
||||||
|
log_meta["user_id"] = user_id
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] stream_chat_post STARTED, session={session_id}, "
|
||||||
|
f"user={user_id}, message_len={len(request.message)}",
|
||||||
|
extra={"json_fields": log_meta},
|
||||||
|
)
|
||||||
|
|
||||||
session = await _validate_and_get_session(session_id, user_id)
|
session = await _validate_and_get_session(session_id, user_id)
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"duration_ms": (time.perf_counter() - stream_start_time) * 1000,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# Create a task in the stream registry for reconnection support
|
# Create a task in the stream registry for reconnection support
|
||||||
task_id = str(uuid_module.uuid4())
|
task_id = str(uuid_module.uuid4())
|
||||||
operation_id = str(uuid_module.uuid4())
|
operation_id = str(uuid_module.uuid4())
|
||||||
|
log_meta["task_id"] = task_id
|
||||||
|
|
||||||
|
task_create_start = time.perf_counter()
|
||||||
await stream_registry.create_task(
|
await stream_registry.create_task(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@@ -280,14 +326,28 @@ async def stream_chat_post(
|
|||||||
tool_name="chat",
|
tool_name="chat",
|
||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
)
|
)
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start)*1000:.1f}ms",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"duration_ms": (time.perf_counter() - task_create_start) * 1000,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# Background task that runs the AI generation independently of SSE connection
|
# Background task that runs the AI generation independently of SSE connection
|
||||||
async def run_ai_generation():
|
async def run_ai_generation():
|
||||||
try:
|
import time as time_module
|
||||||
# Emit a start event with task_id for reconnection
|
|
||||||
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
|
|
||||||
await stream_registry.publish_chunk(task_id, start_chunk)
|
|
||||||
|
|
||||||
|
gen_start_time = time_module.perf_counter()
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] run_ai_generation STARTED, task={task_id}, session={session_id}, user={user_id}",
|
||||||
|
extra={"json_fields": log_meta},
|
||||||
|
)
|
||||||
|
first_chunk_time, ttfc = None, None
|
||||||
|
chunk_count = 0
|
||||||
|
try:
|
||||||
async for chunk in chat_service.stream_chat_completion(
|
async for chunk in chat_service.stream_chat_completion(
|
||||||
session_id,
|
session_id,
|
||||||
request.message,
|
request.message,
|
||||||
@@ -295,25 +355,79 @@ async def stream_chat_post(
|
|||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||||
context=request.context,
|
context=request.context,
|
||||||
|
_task_id=task_id, # Pass task_id so service emits start with taskId for reconnection
|
||||||
):
|
):
|
||||||
|
chunk_count += 1
|
||||||
|
if first_chunk_time is None:
|
||||||
|
first_chunk_time = time_module.perf_counter()
|
||||||
|
ttfc = first_chunk_time - gen_start_time
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] FIRST AI CHUNK at {ttfc:.2f}s, type={type(chunk).__name__}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"chunk_type": type(chunk).__name__,
|
||||||
|
"time_to_first_chunk_ms": ttfc * 1000,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
# Write to Redis (subscribers will receive via XREAD)
|
# Write to Redis (subscribers will receive via XREAD)
|
||||||
await stream_registry.publish_chunk(task_id, chunk)
|
await stream_registry.publish_chunk(task_id, chunk)
|
||||||
|
|
||||||
# Mark task as completed
|
gen_end_time = time_module.perf_counter()
|
||||||
|
total_time = (gen_end_time - gen_start_time) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] run_ai_generation FINISHED in {total_time/1000:.1f}s; "
|
||||||
|
f"task={task_id}, session={session_id}, "
|
||||||
|
f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"total_time_ms": total_time,
|
||||||
|
"time_to_first_chunk_ms": (
|
||||||
|
ttfc * 1000 if ttfc is not None else None
|
||||||
|
),
|
||||||
|
"n_chunks": chunk_count,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
await stream_registry.mark_task_completed(task_id, "completed")
|
await stream_registry.mark_task_completed(task_id, "completed")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
elapsed = time_module.perf_counter() - gen_start_time
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Error in background AI generation for session {session_id}: {e}"
|
f"[TIMING] run_ai_generation ERROR after {elapsed:.2f}s: {e}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"elapsed_ms": elapsed * 1000,
|
||||||
|
"error": str(e),
|
||||||
|
}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
await stream_registry.mark_task_completed(task_id, "failed")
|
await stream_registry.mark_task_completed(task_id, "failed")
|
||||||
|
|
||||||
# Start the AI generation in a background task
|
# Start the AI generation in a background task
|
||||||
bg_task = asyncio.create_task(run_ai_generation())
|
bg_task = asyncio.create_task(run_ai_generation())
|
||||||
await stream_registry.set_task_asyncio_task(task_id, bg_task)
|
await stream_registry.set_task_asyncio_task(task_id, bg_task)
|
||||||
|
setup_time = (time.perf_counter() - stream_start_time) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] Background task started, setup={setup_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
|
||||||
|
)
|
||||||
|
|
||||||
# SSE endpoint that subscribes to the task's stream
|
# SSE endpoint that subscribes to the task's stream
|
||||||
async def event_generator() -> AsyncGenerator[str, None]:
|
async def event_generator() -> AsyncGenerator[str, None]:
|
||||||
|
import time as time_module
|
||||||
|
|
||||||
|
event_gen_start = time_module.perf_counter()
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] event_generator STARTED, task={task_id}, session={session_id}, "
|
||||||
|
f"user={user_id}",
|
||||||
|
extra={"json_fields": log_meta},
|
||||||
|
)
|
||||||
subscriber_queue = None
|
subscriber_queue = None
|
||||||
|
first_chunk_yielded = False
|
||||||
|
chunks_yielded = 0
|
||||||
try:
|
try:
|
||||||
# Subscribe to the task stream (this replays existing messages + live updates)
|
# Subscribe to the task stream (this replays existing messages + live updates)
|
||||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||||
@@ -328,22 +442,70 @@ async def stream_chat_post(
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Read from the subscriber queue and yield to SSE
|
# Read from the subscriber queue and yield to SSE
|
||||||
|
logger.info(
|
||||||
|
"[TIMING] Starting to read from subscriber_queue",
|
||||||
|
extra={"json_fields": log_meta},
|
||||||
|
)
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
||||||
|
chunks_yielded += 1
|
||||||
|
|
||||||
|
if not first_chunk_yielded:
|
||||||
|
first_chunk_yielded = True
|
||||||
|
elapsed = time_module.perf_counter() - event_gen_start
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] FIRST CHUNK from queue at {elapsed:.2f}s, "
|
||||||
|
f"type={type(chunk).__name__}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"chunk_type": type(chunk).__name__,
|
||||||
|
"elapsed_ms": elapsed * 1000,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
yield chunk.to_sse()
|
yield chunk.to_sse()
|
||||||
|
|
||||||
# Check for finish signal
|
# Check for finish signal
|
||||||
if isinstance(chunk, StreamFinish):
|
if isinstance(chunk, StreamFinish):
|
||||||
|
total_time = time_module.perf_counter() - event_gen_start
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] StreamFinish received in {total_time:.2f}s; "
|
||||||
|
f"n_chunks={chunks_yielded}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"chunks_yielded": chunks_yielded,
|
||||||
|
"total_time_ms": total_time * 1000,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
break
|
break
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
# Send heartbeat to keep connection alive
|
|
||||||
yield StreamHeartbeat().to_sse()
|
yield StreamHeartbeat().to_sse()
|
||||||
|
|
||||||
except GeneratorExit:
|
except GeneratorExit:
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] GeneratorExit (client disconnected), chunks={chunks_yielded}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"chunks_yielded": chunks_yielded,
|
||||||
|
"reason": "client_disconnect",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
pass # Client disconnected - background task continues
|
pass # Client disconnected - background task continues
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in SSE stream for task {task_id}: {e}")
|
elapsed = (time_module.perf_counter() - event_gen_start) * 1000
|
||||||
|
logger.error(
|
||||||
|
f"[TIMING] event_generator ERROR after {elapsed:.1f}ms: {e}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
|
||||||
|
},
|
||||||
|
)
|
||||||
finally:
|
finally:
|
||||||
# Unsubscribe when client disconnects or stream ends to prevent resource leak
|
# Unsubscribe when client disconnects or stream ends to prevent resource leak
|
||||||
if subscriber_queue is not None:
|
if subscriber_queue is not None:
|
||||||
@@ -357,6 +519,18 @@ async def stream_chat_post(
|
|||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
# AI SDK protocol termination - always yield even if unsubscribe fails
|
# AI SDK protocol termination - always yield even if unsubscribe fails
|
||||||
|
total_time = time_module.perf_counter() - event_gen_start
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
|
||||||
|
f"task={task_id}, session={session_id}, n_chunks={chunks_yielded}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"total_time_ms": total_time * 1000,
|
||||||
|
"chunks_yielded": chunks_yielded,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
@@ -374,44 +548,53 @@ async def stream_chat_post(
|
|||||||
@router.get(
|
@router.get(
|
||||||
"/sessions/{session_id}/stream",
|
"/sessions/{session_id}/stream",
|
||||||
)
|
)
|
||||||
async def stream_chat_get(
|
async def resume_session_stream(
|
||||||
session_id: str,
|
session_id: str,
|
||||||
message: Annotated[str, Query(min_length=1, max_length=10000)],
|
|
||||||
user_id: str | None = Depends(auth.get_user_id),
|
user_id: str | None = Depends(auth.get_user_id),
|
||||||
is_user_message: bool = Query(default=True),
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Stream chat responses for a session (GET - legacy endpoint).
|
Resume an active stream for a session.
|
||||||
|
|
||||||
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
|
Called by the AI SDK's ``useChat(resume: true)`` on page load.
|
||||||
- Text fragments as they are generated
|
Checks for an active (in-progress) task on the session and either replays
|
||||||
- Tool call UI elements (if invoked)
|
the full SSE stream or returns 204 No Content if nothing is running.
|
||||||
- Tool execution results
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session_id: The chat session identifier to associate with the streamed messages.
|
session_id: The chat session identifier.
|
||||||
message: The user's new message to process.
|
|
||||||
user_id: Optional authenticated user ID.
|
user_id: Optional authenticated user ID.
|
||||||
is_user_message: Whether the message is a user message.
|
|
||||||
Returns:
|
|
||||||
StreamingResponse: SSE-formatted response chunks.
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
StreamingResponse (SSE) when an active stream exists,
|
||||||
|
or 204 No Content when there is nothing to resume.
|
||||||
"""
|
"""
|
||||||
session = await _validate_and_get_session(session_id, user_id)
|
import asyncio
|
||||||
|
|
||||||
|
active_task, _last_id = await stream_registry.get_active_task_for_session(
|
||||||
|
session_id, user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if not active_task:
|
||||||
|
return Response(status_code=204)
|
||||||
|
|
||||||
|
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||||
|
task_id=active_task.task_id,
|
||||||
|
user_id=user_id,
|
||||||
|
last_message_id="0-0", # Full replay so useChat rebuilds the message
|
||||||
|
)
|
||||||
|
|
||||||
|
if subscriber_queue is None:
|
||||||
|
return Response(status_code=204)
|
||||||
|
|
||||||
async def event_generator() -> AsyncGenerator[str, None]:
|
async def event_generator() -> AsyncGenerator[str, None]:
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
first_chunk_type: str | None = None
|
first_chunk_type: str | None = None
|
||||||
async for chunk in chat_service.stream_chat_completion(
|
try:
|
||||||
session_id,
|
while True:
|
||||||
message,
|
try:
|
||||||
is_user_message=is_user_message,
|
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
||||||
user_id=user_id,
|
|
||||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
|
||||||
):
|
|
||||||
if chunk_count < 3:
|
if chunk_count < 3:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Chat stream chunk",
|
"Resume stream chunk",
|
||||||
extra={
|
extra={
|
||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
"chunk_type": str(chunk.type),
|
"chunk_type": str(chunk.type),
|
||||||
@@ -421,15 +604,33 @@ async def stream_chat_get(
|
|||||||
first_chunk_type = str(chunk.type)
|
first_chunk_type = str(chunk.type)
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
yield chunk.to_sse()
|
yield chunk.to_sse()
|
||||||
|
|
||||||
|
if isinstance(chunk, StreamFinish):
|
||||||
|
break
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
yield StreamHeartbeat().to_sse()
|
||||||
|
except GeneratorExit:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in resume stream for session {session_id}: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
await stream_registry.unsubscribe_from_task(
|
||||||
|
active_task.task_id, subscriber_queue
|
||||||
|
)
|
||||||
|
except Exception as unsub_err:
|
||||||
|
logger.error(
|
||||||
|
f"Error unsubscribing from task {active_task.task_id}: {unsub_err}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
"Chat stream completed",
|
"Resume stream completed",
|
||||||
extra={
|
extra={
|
||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
"chunk_count": chunk_count,
|
"n_chunks": chunk_count,
|
||||||
"first_chunk_type": first_chunk_type,
|
"first_chunk_type": first_chunk_type,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
# AI SDK protocol termination
|
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
@@ -438,8 +639,8 @@ async def stream_chat_get(
|
|||||||
headers={
|
headers={
|
||||||
"Cache-Control": "no-cache",
|
"Cache-Control": "no-cache",
|
||||||
"Connection": "keep-alive",
|
"Connection": "keep-alive",
|
||||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
"X-Accel-Buffering": "no",
|
||||||
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
|
"x-vercel-ai-ui-message-stream": "v1",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -751,3 +952,42 @@ async def health_check() -> dict:
|
|||||||
"service": "chat",
|
"service": "chat",
|
||||||
"version": "0.1.0",
|
"version": "0.1.0",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ========== Schema Export (for OpenAPI / Orval codegen) ==========
|
||||||
|
|
||||||
|
ToolResponseUnion = (
|
||||||
|
AgentsFoundResponse
|
||||||
|
| NoResultsResponse
|
||||||
|
| AgentDetailsResponse
|
||||||
|
| SetupRequirementsResponse
|
||||||
|
| ExecutionStartedResponse
|
||||||
|
| NeedLoginResponse
|
||||||
|
| ErrorResponse
|
||||||
|
| InputValidationErrorResponse
|
||||||
|
| AgentOutputResponse
|
||||||
|
| UnderstandingUpdatedResponse
|
||||||
|
| AgentPreviewResponse
|
||||||
|
| AgentSavedResponse
|
||||||
|
| ClarificationNeededResponse
|
||||||
|
| BlockListResponse
|
||||||
|
| BlockOutputResponse
|
||||||
|
| DocSearchResultsResponse
|
||||||
|
| DocPageResponse
|
||||||
|
| OperationStartedResponse
|
||||||
|
| OperationPendingResponse
|
||||||
|
| OperationInProgressResponse
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/schema/tool-responses",
|
||||||
|
response_model=ToolResponseUnion,
|
||||||
|
include_in_schema=True,
|
||||||
|
summary="[Dummy] Tool response type export for codegen",
|
||||||
|
description="This endpoint is not meant to be called. It exists solely to "
|
||||||
|
"expose tool response models in the OpenAPI schema for frontend codegen.",
|
||||||
|
)
|
||||||
|
async def _tool_response_schema() -> ToolResponseUnion: # type: ignore[return]
|
||||||
|
"""Never called at runtime. Exists only so Orval generates TS types."""
|
||||||
|
raise HTTPException(status_code=501, detail="Schema-only endpoint")
|
||||||
|
|||||||
@@ -52,8 +52,10 @@ from .response_model import (
|
|||||||
StreamBaseResponse,
|
StreamBaseResponse,
|
||||||
StreamError,
|
StreamError,
|
||||||
StreamFinish,
|
StreamFinish,
|
||||||
|
StreamFinishStep,
|
||||||
StreamHeartbeat,
|
StreamHeartbeat,
|
||||||
StreamStart,
|
StreamStart,
|
||||||
|
StreamStartStep,
|
||||||
StreamTextDelta,
|
StreamTextDelta,
|
||||||
StreamTextEnd,
|
StreamTextEnd,
|
||||||
StreamTextStart,
|
StreamTextStart,
|
||||||
@@ -351,6 +353,10 @@ async def stream_chat_completion(
|
|||||||
retry_count: int = 0,
|
retry_count: int = 0,
|
||||||
session: ChatSession | None = None,
|
session: ChatSession | None = None,
|
||||||
context: dict[str, str] | None = None, # {url: str, content: str}
|
context: dict[str, str] | None = None, # {url: str, content: str}
|
||||||
|
_continuation_message_id: (
|
||||||
|
str | None
|
||||||
|
) = None, # Internal: reuse message ID for tool call continuations
|
||||||
|
_task_id: str | None = None, # Internal: task ID for SSE reconnection support
|
||||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||||
"""Main entry point for streaming chat completions with database handling.
|
"""Main entry point for streaming chat completions with database handling.
|
||||||
|
|
||||||
@@ -371,21 +377,45 @@ async def stream_chat_completion(
|
|||||||
ValueError: If max_context_messages is exceeded
|
ValueError: If max_context_messages is exceeded
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
completion_start = time.monotonic()
|
||||||
|
|
||||||
|
# Build log metadata for structured logging
|
||||||
|
log_meta = {"component": "ChatService", "session_id": session_id}
|
||||||
|
if user_id:
|
||||||
|
log_meta["user_id"] = user_id
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Streaming chat completion for session {session_id} for message {message} and user id {user_id}. Message is user message: {is_user_message}"
|
f"[TIMING] stream_chat_completion STARTED, session={session_id}, user={user_id}, "
|
||||||
|
f"message_len={len(message) if message else 0}, is_user={is_user_message}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"message_len": len(message) if message else 0,
|
||||||
|
"is_user_message": is_user_message,
|
||||||
|
}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only fetch from Redis if session not provided (initial call)
|
# Only fetch from Redis if session not provided (initial call)
|
||||||
if session is None:
|
if session is None:
|
||||||
|
fetch_start = time.monotonic()
|
||||||
session = await get_chat_session(session_id, user_id)
|
session = await get_chat_session(session_id, user_id)
|
||||||
|
fetch_time = (time.monotonic() - fetch_start) * 1000
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Fetched session from Redis: {session.session_id if session else 'None'}, "
|
f"[TIMING] get_chat_session took {fetch_time:.1f}ms, "
|
||||||
f"message_count={len(session.messages) if session else 0}"
|
f"n_messages={len(session.messages) if session else 0}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"duration_ms": fetch_time,
|
||||||
|
"n_messages": len(session.messages) if session else 0,
|
||||||
|
}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Using provided session object: {session.session_id}, "
|
f"[TIMING] Using provided session, messages={len(session.messages)}",
|
||||||
f"message_count={len(session.messages)}"
|
extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
|
||||||
)
|
)
|
||||||
|
|
||||||
if not session:
|
if not session:
|
||||||
@@ -406,17 +436,25 @@ async def stream_chat_completion(
|
|||||||
|
|
||||||
# Track user message in PostHog
|
# Track user message in PostHog
|
||||||
if is_user_message:
|
if is_user_message:
|
||||||
|
posthog_start = time.monotonic()
|
||||||
track_user_message(
|
track_user_message(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
message_length=len(message),
|
message_length=len(message),
|
||||||
)
|
)
|
||||||
|
posthog_time = (time.monotonic() - posthog_start) * 1000
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Upserting session: {session.session_id} with user id {session.user_id}, "
|
f"[TIMING] track_user_message took {posthog_time:.1f}ms",
|
||||||
f"message_count={len(session.messages)}"
|
extra={"json_fields": {**log_meta, "duration_ms": posthog_time}},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
upsert_start = time.monotonic()
|
||||||
session = await upsert_chat_session(session)
|
session = await upsert_chat_session(session)
|
||||||
|
upsert_time = (time.monotonic() - upsert_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] upsert_chat_session took {upsert_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "duration_ms": upsert_time}},
|
||||||
|
)
|
||||||
assert session, "Session not found"
|
assert session, "Session not found"
|
||||||
|
|
||||||
# Generate title for new sessions on first user message (non-blocking)
|
# Generate title for new sessions on first user message (non-blocking)
|
||||||
@@ -454,7 +492,13 @@ async def stream_chat_completion(
|
|||||||
asyncio.create_task(_update_title())
|
asyncio.create_task(_update_title())
|
||||||
|
|
||||||
# Build system prompt with business understanding
|
# Build system prompt with business understanding
|
||||||
|
prompt_start = time.monotonic()
|
||||||
system_prompt, understanding = await _build_system_prompt(user_id)
|
system_prompt, understanding = await _build_system_prompt(user_id)
|
||||||
|
prompt_time = (time.monotonic() - prompt_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] _build_system_prompt took {prompt_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "duration_ms": prompt_time}},
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize variables for streaming
|
# Initialize variables for streaming
|
||||||
assistant_response = ChatMessage(
|
assistant_response = ChatMessage(
|
||||||
@@ -479,13 +523,27 @@ async def stream_chat_completion(
|
|||||||
# Generate unique IDs for AI SDK protocol
|
# Generate unique IDs for AI SDK protocol
|
||||||
import uuid as uuid_module
|
import uuid as uuid_module
|
||||||
|
|
||||||
message_id = str(uuid_module.uuid4())
|
is_continuation = _continuation_message_id is not None
|
||||||
|
message_id = _continuation_message_id or str(uuid_module.uuid4())
|
||||||
text_block_id = str(uuid_module.uuid4())
|
text_block_id = str(uuid_module.uuid4())
|
||||||
|
|
||||||
# Yield message start
|
# Only yield message start for the initial call, not for continuations.
|
||||||
yield StreamStart(messageId=message_id)
|
setup_time = (time.monotonic() - completion_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] Setup complete, yielding StreamStart at {setup_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
|
||||||
|
)
|
||||||
|
if not is_continuation:
|
||||||
|
yield StreamStart(messageId=message_id, taskId=_task_id)
|
||||||
|
|
||||||
|
# Emit start-step before each LLM call (AI SDK uses this to add step boundaries)
|
||||||
|
yield StreamStartStep()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
logger.info(
|
||||||
|
"[TIMING] Calling _stream_chat_chunks",
|
||||||
|
extra={"json_fields": log_meta},
|
||||||
|
)
|
||||||
async for chunk in _stream_chat_chunks(
|
async for chunk in _stream_chat_chunks(
|
||||||
session=session,
|
session=session,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
@@ -585,6 +643,10 @@ async def stream_chat_completion(
|
|||||||
)
|
)
|
||||||
yield chunk
|
yield chunk
|
||||||
elif isinstance(chunk, StreamFinish):
|
elif isinstance(chunk, StreamFinish):
|
||||||
|
if has_done_tool_call:
|
||||||
|
# Tool calls happened — close the step but don't send message-level finish.
|
||||||
|
# The continuation will open a new step, and finish will come at the end.
|
||||||
|
yield StreamFinishStep()
|
||||||
if not has_done_tool_call:
|
if not has_done_tool_call:
|
||||||
# Emit text-end before finish if we received text but haven't closed it
|
# Emit text-end before finish if we received text but haven't closed it
|
||||||
if has_received_text and not text_streaming_ended:
|
if has_received_text and not text_streaming_ended:
|
||||||
@@ -616,6 +678,8 @@ async def stream_chat_completion(
|
|||||||
has_saved_assistant_message = True
|
has_saved_assistant_message = True
|
||||||
|
|
||||||
has_yielded_end = True
|
has_yielded_end = True
|
||||||
|
# Emit finish-step before finish (resets AI SDK text/reasoning state)
|
||||||
|
yield StreamFinishStep()
|
||||||
yield chunk
|
yield chunk
|
||||||
elif isinstance(chunk, StreamError):
|
elif isinstance(chunk, StreamError):
|
||||||
has_yielded_error = True
|
has_yielded_error = True
|
||||||
@@ -665,6 +729,10 @@ async def stream_chat_completion(
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"Retryable error encountered. Attempt {retry_count + 1}/{config.max_retries}"
|
f"Retryable error encountered. Attempt {retry_count + 1}/{config.max_retries}"
|
||||||
)
|
)
|
||||||
|
# Close the current step before retrying so the recursive call's
|
||||||
|
# StreamStartStep doesn't produce unbalanced step events.
|
||||||
|
if not has_yielded_end:
|
||||||
|
yield StreamFinishStep()
|
||||||
should_retry = True
|
should_retry = True
|
||||||
else:
|
else:
|
||||||
# Non-retryable error or max retries exceeded
|
# Non-retryable error or max retries exceeded
|
||||||
@@ -700,6 +768,7 @@ async def stream_chat_completion(
|
|||||||
error_response = StreamError(errorText=error_message)
|
error_response = StreamError(errorText=error_message)
|
||||||
yield error_response
|
yield error_response
|
||||||
if not has_yielded_end:
|
if not has_yielded_end:
|
||||||
|
yield StreamFinishStep()
|
||||||
yield StreamFinish()
|
yield StreamFinish()
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -714,6 +783,8 @@ async def stream_chat_completion(
|
|||||||
retry_count=retry_count + 1,
|
retry_count=retry_count + 1,
|
||||||
session=session,
|
session=session,
|
||||||
context=context,
|
context=context,
|
||||||
|
_continuation_message_id=message_id, # Reuse message ID since start was already sent
|
||||||
|
_task_id=_task_id,
|
||||||
):
|
):
|
||||||
yield chunk
|
yield chunk
|
||||||
return # Exit after retry to avoid double-saving in finally block
|
return # Exit after retry to avoid double-saving in finally block
|
||||||
@@ -783,6 +854,8 @@ async def stream_chat_completion(
|
|||||||
session=session, # Pass session object to avoid Redis refetch
|
session=session, # Pass session object to avoid Redis refetch
|
||||||
context=context,
|
context=context,
|
||||||
tool_call_response=str(tool_response_messages),
|
tool_call_response=str(tool_response_messages),
|
||||||
|
_continuation_message_id=message_id, # Reuse message ID to avoid duplicates
|
||||||
|
_task_id=_task_id,
|
||||||
):
|
):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
@@ -893,9 +966,21 @@ async def _stream_chat_chunks(
|
|||||||
SSE formatted JSON response objects
|
SSE formatted JSON response objects
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
import time as time_module
|
||||||
|
|
||||||
|
stream_chunks_start = time_module.perf_counter()
|
||||||
model = config.model
|
model = config.model
|
||||||
|
|
||||||
logger.info("Starting pure chat stream")
|
# Build log metadata for structured logging
|
||||||
|
log_meta = {"component": "ChatService", "session_id": session.session_id}
|
||||||
|
if session.user_id:
|
||||||
|
log_meta["user_id"] = session.user_id
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] _stream_chat_chunks STARTED, session={session.session_id}, "
|
||||||
|
f"user={session.user_id}, n_messages={len(session.messages)}",
|
||||||
|
extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
|
||||||
|
)
|
||||||
|
|
||||||
messages = session.to_openai_messages()
|
messages = session.to_openai_messages()
|
||||||
if system_prompt:
|
if system_prompt:
|
||||||
@@ -906,12 +991,18 @@ async def _stream_chat_chunks(
|
|||||||
messages = [system_message] + messages
|
messages = [system_message] + messages
|
||||||
|
|
||||||
# Apply context window management
|
# Apply context window management
|
||||||
|
context_start = time_module.perf_counter()
|
||||||
context_result = await _manage_context_window(
|
context_result = await _manage_context_window(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
model=model,
|
model=model,
|
||||||
api_key=config.api_key,
|
api_key=config.api_key,
|
||||||
base_url=config.base_url,
|
base_url=config.base_url,
|
||||||
)
|
)
|
||||||
|
context_time = (time_module.perf_counter() - context_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] _manage_context_window took {context_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "duration_ms": context_time}},
|
||||||
|
)
|
||||||
|
|
||||||
if context_result.error:
|
if context_result.error:
|
||||||
if "System prompt dropped" in context_result.error:
|
if "System prompt dropped" in context_result.error:
|
||||||
@@ -946,9 +1037,19 @@ async def _stream_chat_chunks(
|
|||||||
|
|
||||||
while retry_count <= MAX_RETRIES:
|
while retry_count <= MAX_RETRIES:
|
||||||
try:
|
try:
|
||||||
|
elapsed = (time_module.perf_counter() - stream_chunks_start) * 1000
|
||||||
|
retry_info = (
|
||||||
|
f" (retry {retry_count}/{MAX_RETRIES})" if retry_count > 0 else ""
|
||||||
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Creating OpenAI chat completion stream..."
|
f"[TIMING] Creating OpenAI stream at {elapsed:.1f}ms{retry_info}",
|
||||||
f"{f' (retry {retry_count}/{MAX_RETRIES})' if retry_count > 0 else ''}"
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"elapsed_ms": elapsed,
|
||||||
|
"retry_count": retry_count,
|
||||||
|
}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build extra_body for OpenRouter tracing and PostHog analytics
|
# Build extra_body for OpenRouter tracing and PostHog analytics
|
||||||
@@ -965,6 +1066,7 @@ async def _stream_chat_chunks(
|
|||||||
:128
|
:128
|
||||||
] # OpenRouter limit
|
] # OpenRouter limit
|
||||||
|
|
||||||
|
api_call_start = time_module.perf_counter()
|
||||||
stream = await client.chat.completions.create(
|
stream = await client.chat.completions.create(
|
||||||
model=model,
|
model=model,
|
||||||
messages=cast(list[ChatCompletionMessageParam], messages),
|
messages=cast(list[ChatCompletionMessageParam], messages),
|
||||||
@@ -974,6 +1076,11 @@ async def _stream_chat_chunks(
|
|||||||
stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
|
stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
|
||||||
extra_body=extra_body,
|
extra_body=extra_body,
|
||||||
)
|
)
|
||||||
|
api_init_time = (time_module.perf_counter() - api_call_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] OpenAI stream object returned in {api_init_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "duration_ms": api_init_time}},
|
||||||
|
)
|
||||||
|
|
||||||
# Variables to accumulate tool calls
|
# Variables to accumulate tool calls
|
||||||
tool_calls: list[dict[str, Any]] = []
|
tool_calls: list[dict[str, Any]] = []
|
||||||
@@ -984,10 +1091,13 @@ async def _stream_chat_chunks(
|
|||||||
|
|
||||||
# Track if we've started the text block
|
# Track if we've started the text block
|
||||||
text_started = False
|
text_started = False
|
||||||
|
first_content_chunk = True
|
||||||
|
chunk_count = 0
|
||||||
|
|
||||||
# Process the stream
|
# Process the stream
|
||||||
chunk: ChatCompletionChunk
|
chunk: ChatCompletionChunk
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
|
chunk_count += 1
|
||||||
if chunk.usage:
|
if chunk.usage:
|
||||||
yield StreamUsage(
|
yield StreamUsage(
|
||||||
promptTokens=chunk.usage.prompt_tokens,
|
promptTokens=chunk.usage.prompt_tokens,
|
||||||
@@ -1010,6 +1120,23 @@ async def _stream_chat_chunks(
|
|||||||
if not text_started and text_block_id:
|
if not text_started and text_block_id:
|
||||||
yield StreamTextStart(id=text_block_id)
|
yield StreamTextStart(id=text_block_id)
|
||||||
text_started = True
|
text_started = True
|
||||||
|
# Log timing for first content chunk
|
||||||
|
if first_content_chunk:
|
||||||
|
first_content_chunk = False
|
||||||
|
ttfc = (
|
||||||
|
time_module.perf_counter() - api_call_start
|
||||||
|
) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] FIRST CONTENT CHUNK at {ttfc:.1f}ms "
|
||||||
|
f"(since API call), n_chunks={chunk_count}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"time_to_first_chunk_ms": ttfc,
|
||||||
|
"n_chunks": chunk_count,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
# Stream the text delta
|
# Stream the text delta
|
||||||
text_response = StreamTextDelta(
|
text_response = StreamTextDelta(
|
||||||
id=text_block_id or "",
|
id=text_block_id or "",
|
||||||
@@ -1066,7 +1193,21 @@ async def _stream_chat_chunks(
|
|||||||
toolName=tool_calls[idx]["function"]["name"],
|
toolName=tool_calls[idx]["function"]["name"],
|
||||||
)
|
)
|
||||||
emitted_start_for_idx.add(idx)
|
emitted_start_for_idx.add(idx)
|
||||||
logger.info(f"Stream complete. Finish reason: {finish_reason}")
|
stream_duration = time_module.perf_counter() - api_call_start
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] OpenAI stream COMPLETE, finish_reason={finish_reason}, "
|
||||||
|
f"duration={stream_duration:.2f}s, "
|
||||||
|
f"n_chunks={chunk_count}, n_tool_calls={len(tool_calls)}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"stream_duration_ms": stream_duration * 1000,
|
||||||
|
"finish_reason": finish_reason,
|
||||||
|
"n_chunks": chunk_count,
|
||||||
|
"n_tool_calls": len(tool_calls),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# Yield all accumulated tool calls after the stream is complete
|
# Yield all accumulated tool calls after the stream is complete
|
||||||
# This ensures all tool call arguments have been fully received
|
# This ensures all tool call arguments have been fully received
|
||||||
@@ -1086,11 +1227,16 @@ async def _stream_chat_chunks(
|
|||||||
# Re-raise to trigger retry logic in the parent function
|
# Re-raise to trigger retry logic in the parent function
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] _stream_chat_chunks COMPLETED in {total_time/1000:.1f}s; "
|
||||||
|
f"session={session.session_id}, user={session.user_id}",
|
||||||
|
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
|
||||||
|
)
|
||||||
yield StreamFinish()
|
yield StreamFinish()
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
last_error = e
|
last_error = e
|
||||||
|
|
||||||
if _is_retryable_error(e) and retry_count < MAX_RETRIES:
|
if _is_retryable_error(e) and retry_count < MAX_RETRIES:
|
||||||
retry_count += 1
|
retry_count += 1
|
||||||
# Calculate delay with exponential backoff
|
# Calculate delay with exponential backoff
|
||||||
@@ -1106,26 +1252,12 @@ async def _stream_chat_chunks(
|
|||||||
continue # Retry the stream
|
continue # Retry the stream
|
||||||
else:
|
else:
|
||||||
# Non-retryable error or max retries exceeded
|
# Non-retryable error or max retries exceeded
|
||||||
_log_api_error(
|
logger.error(
|
||||||
error=e,
|
f"Error in stream (not retrying): {e!s}",
|
||||||
session_id=session.session_id if session else None,
|
exc_info=True,
|
||||||
message_count=len(messages) if messages else None,
|
|
||||||
model=model,
|
|
||||||
retry_count=retry_count,
|
|
||||||
)
|
)
|
||||||
error_code = None
|
error_code = None
|
||||||
error_text = str(e)
|
error_text = str(e)
|
||||||
|
|
||||||
error_details = _extract_api_error_details(e)
|
|
||||||
if error_details.get("response_body"):
|
|
||||||
body = error_details["response_body"]
|
|
||||||
if isinstance(body, dict):
|
|
||||||
err = body.get("error")
|
|
||||||
if isinstance(err, dict) and err.get("message"):
|
|
||||||
error_text = err["message"]
|
|
||||||
elif body.get("message"):
|
|
||||||
error_text = body["message"]
|
|
||||||
|
|
||||||
if _is_region_blocked_error(e):
|
if _is_region_blocked_error(e):
|
||||||
error_code = "MODEL_NOT_AVAILABLE_REGION"
|
error_code = "MODEL_NOT_AVAILABLE_REGION"
|
||||||
error_text = (
|
error_text = (
|
||||||
@@ -1142,12 +1274,9 @@ async def _stream_chat_chunks(
|
|||||||
|
|
||||||
# If we exit the retry loop without returning, it means we exhausted retries
|
# If we exit the retry loop without returning, it means we exhausted retries
|
||||||
if last_error:
|
if last_error:
|
||||||
_log_api_error(
|
logger.error(
|
||||||
error=last_error,
|
f"Max retries ({MAX_RETRIES}) exceeded. Last error: {last_error!s}",
|
||||||
session_id=session.session_id if session else None,
|
exc_info=True,
|
||||||
message_count=len(messages) if messages else None,
|
|
||||||
model=model,
|
|
||||||
retry_count=MAX_RETRIES,
|
|
||||||
)
|
)
|
||||||
yield StreamError(errorText=f"Max retries exceeded: {last_error!s}")
|
yield StreamError(errorText=f"Max retries exceeded: {last_error!s}")
|
||||||
yield StreamFinish()
|
yield StreamFinish()
|
||||||
@@ -1583,6 +1712,7 @@ async def _execute_long_running_tool_with_streaming(
|
|||||||
task_id,
|
task_id,
|
||||||
StreamError(errorText=str(e)),
|
StreamError(errorText=str(e)),
|
||||||
)
|
)
|
||||||
|
await stream_registry.publish_chunk(task_id, StreamFinishStep())
|
||||||
await stream_registry.publish_chunk(task_id, StreamFinish())
|
await stream_registry.publish_chunk(task_id, StreamFinish())
|
||||||
|
|
||||||
await _update_pending_operation(
|
await _update_pending_operation(
|
||||||
@@ -1719,7 +1849,6 @@ async def _generate_llm_continuation(
|
|||||||
break # Success, exit retry loop
|
break # Success, exit retry loop
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
last_error = e
|
last_error = e
|
||||||
|
|
||||||
if _is_retryable_error(e) and retry_count < MAX_RETRIES:
|
if _is_retryable_error(e) and retry_count < MAX_RETRIES:
|
||||||
retry_count += 1
|
retry_count += 1
|
||||||
delay = min(
|
delay = min(
|
||||||
@@ -1733,23 +1862,17 @@ async def _generate_llm_continuation(
|
|||||||
await asyncio.sleep(delay)
|
await asyncio.sleep(delay)
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
# Non-retryable error - log details and exit gracefully
|
# Non-retryable error - log and exit gracefully
|
||||||
_log_api_error(
|
logger.error(
|
||||||
error=e,
|
f"Non-retryable error in LLM continuation: {e!s}",
|
||||||
session_id=session_id,
|
exc_info=True,
|
||||||
message_count=len(messages) if messages else None,
|
|
||||||
model=config.model,
|
|
||||||
retry_count=retry_count,
|
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
if last_error:
|
if last_error:
|
||||||
_log_api_error(
|
logger.error(
|
||||||
error=last_error,
|
f"Max retries ({MAX_RETRIES}) exceeded for LLM continuation. "
|
||||||
session_id=session_id,
|
f"Last error: {last_error!s}"
|
||||||
message_count=len(messages) if messages else None,
|
|
||||||
model=config.model,
|
|
||||||
retry_count=MAX_RETRIES,
|
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -1789,89 +1912,6 @@ async def _generate_llm_continuation(
|
|||||||
logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True)
|
logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
def _log_api_error(
|
|
||||||
error: Exception,
|
|
||||||
session_id: str | None = None,
|
|
||||||
message_count: int | None = None,
|
|
||||||
model: str | None = None,
|
|
||||||
retry_count: int = 0,
|
|
||||||
) -> None:
|
|
||||||
"""Log detailed API error information for debugging."""
|
|
||||||
details = _extract_api_error_details(error)
|
|
||||||
details["session_id"] = session_id
|
|
||||||
details["message_count"] = message_count
|
|
||||||
details["model"] = model
|
|
||||||
details["retry_count"] = retry_count
|
|
||||||
|
|
||||||
if isinstance(error, RateLimitError):
|
|
||||||
logger.warning(f"Rate limit error: {details}")
|
|
||||||
elif isinstance(error, APIConnectionError):
|
|
||||||
logger.warning(f"API connection error: {details}")
|
|
||||||
elif isinstance(error, APIStatusError) and error.status_code >= 500:
|
|
||||||
logger.error(f"API server error (5xx): {details}")
|
|
||||||
else:
|
|
||||||
logger.error(f"API error: {details}")
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_api_error_details(error: Exception) -> dict[str, Any]:
|
|
||||||
"""Extract detailed information from OpenAI/OpenRouter API errors."""
|
|
||||||
error_msg = str(error)
|
|
||||||
details: dict[str, Any] = {
|
|
||||||
"error_type": type(error).__name__,
|
|
||||||
"error_message": error_msg[:500] + "..." if len(error_msg) > 500 else error_msg,
|
|
||||||
}
|
|
||||||
|
|
||||||
if hasattr(error, "code"):
|
|
||||||
details["code"] = getattr(error, "code", None)
|
|
||||||
if hasattr(error, "param"):
|
|
||||||
details["param"] = getattr(error, "param", None)
|
|
||||||
|
|
||||||
if isinstance(error, APIStatusError):
|
|
||||||
details["status_code"] = error.status_code
|
|
||||||
details["request_id"] = getattr(error, "request_id", None)
|
|
||||||
|
|
||||||
if hasattr(error, "body") and error.body:
|
|
||||||
details["response_body"] = _sanitize_error_body(error.body)
|
|
||||||
|
|
||||||
if hasattr(error, "response") and error.response:
|
|
||||||
headers = error.response.headers
|
|
||||||
details["openrouter_provider"] = headers.get("x-openrouter-provider")
|
|
||||||
details["openrouter_model"] = headers.get("x-openrouter-model")
|
|
||||||
details["retry_after"] = headers.get("retry-after")
|
|
||||||
details["rate_limit_remaining"] = headers.get("x-ratelimit-remaining")
|
|
||||||
|
|
||||||
return details
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_error_body(
|
|
||||||
body: Any, max_length: int = 2000
|
|
||||||
) -> dict[str, Any] | str | None:
|
|
||||||
"""Extract only safe fields from error response body to avoid logging sensitive data."""
|
|
||||||
if not isinstance(body, dict):
|
|
||||||
# Non-dict bodies (e.g., HTML error pages) - return truncated string
|
|
||||||
if body is not None:
|
|
||||||
body_str = str(body)
|
|
||||||
if len(body_str) > max_length:
|
|
||||||
return body_str[:max_length] + "...[truncated]"
|
|
||||||
return body_str
|
|
||||||
return None
|
|
||||||
|
|
||||||
safe_fields = ("message", "type", "code", "param", "error")
|
|
||||||
sanitized: dict[str, Any] = {}
|
|
||||||
|
|
||||||
for field in safe_fields:
|
|
||||||
if field in body:
|
|
||||||
value = body[field]
|
|
||||||
if field == "error" and isinstance(value, dict):
|
|
||||||
sanitized[field] = _sanitize_error_body(value, max_length)
|
|
||||||
elif isinstance(value, str) and len(value) > max_length:
|
|
||||||
sanitized[field] = value[:max_length] + "...[truncated]"
|
|
||||||
else:
|
|
||||||
sanitized[field] = value
|
|
||||||
|
|
||||||
return sanitized if sanitized else None
|
|
||||||
|
|
||||||
|
|
||||||
async def _generate_llm_continuation_with_streaming(
|
async def _generate_llm_continuation_with_streaming(
|
||||||
session_id: str,
|
session_id: str,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
@@ -1930,6 +1970,7 @@ async def _generate_llm_continuation_with_streaming(
|
|||||||
|
|
||||||
# Publish start event
|
# Publish start event
|
||||||
await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id))
|
await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id))
|
||||||
|
await stream_registry.publish_chunk(task_id, StreamStartStep())
|
||||||
await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id))
|
await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id))
|
||||||
|
|
||||||
# Stream the response
|
# Stream the response
|
||||||
@@ -1953,6 +1994,7 @@ async def _generate_llm_continuation_with_streaming(
|
|||||||
|
|
||||||
# Publish end events
|
# Publish end events
|
||||||
await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id))
|
await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id))
|
||||||
|
await stream_registry.publish_chunk(task_id, StreamFinishStep())
|
||||||
|
|
||||||
if assistant_content:
|
if assistant_content:
|
||||||
# Reload session from DB to avoid race condition with user messages
|
# Reload session from DB to avoid race condition with user messages
|
||||||
@@ -1994,4 +2036,5 @@ async def _generate_llm_continuation_with_streaming(
|
|||||||
task_id,
|
task_id,
|
||||||
StreamError(errorText=f"Failed to generate response: {e}"),
|
StreamError(errorText=f"Failed to generate response: {e}"),
|
||||||
)
|
)
|
||||||
|
await stream_registry.publish_chunk(task_id, StreamFinishStep())
|
||||||
await stream_registry.publish_chunk(task_id, StreamFinish())
|
await stream_registry.publish_chunk(task_id, StreamFinish())
|
||||||
|
|||||||
@@ -104,6 +104,24 @@ async def create_task(
|
|||||||
Returns:
|
Returns:
|
||||||
The created ActiveTask instance (metadata only)
|
The created ActiveTask instance (metadata only)
|
||||||
"""
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
# Build log metadata for structured logging
|
||||||
|
log_meta = {
|
||||||
|
"component": "StreamRegistry",
|
||||||
|
"task_id": task_id,
|
||||||
|
"session_id": session_id,
|
||||||
|
}
|
||||||
|
if user_id:
|
||||||
|
log_meta["user_id"] = user_id
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] create_task STARTED, task={task_id}, session={session_id}, user={user_id}",
|
||||||
|
extra={"json_fields": log_meta},
|
||||||
|
)
|
||||||
|
|
||||||
task = ActiveTask(
|
task = ActiveTask(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@@ -114,10 +132,18 @@ async def create_task(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Store metadata in Redis
|
# Store metadata in Redis
|
||||||
|
redis_start = time.perf_counter()
|
||||||
redis = await get_redis_async()
|
redis = await get_redis_async()
|
||||||
|
redis_time = (time.perf_counter() - redis_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] get_redis_async took {redis_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "duration_ms": redis_time}},
|
||||||
|
)
|
||||||
|
|
||||||
meta_key = _get_task_meta_key(task_id)
|
meta_key = _get_task_meta_key(task_id)
|
||||||
op_key = _get_operation_mapping_key(operation_id)
|
op_key = _get_operation_mapping_key(operation_id)
|
||||||
|
|
||||||
|
hset_start = time.perf_counter()
|
||||||
await redis.hset( # type: ignore[misc]
|
await redis.hset( # type: ignore[misc]
|
||||||
meta_key,
|
meta_key,
|
||||||
mapping={
|
mapping={
|
||||||
@@ -131,12 +157,22 @@ async def create_task(
|
|||||||
"created_at": task.created_at.isoformat(),
|
"created_at": task.created_at.isoformat(),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
hset_time = (time.perf_counter() - hset_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] redis.hset took {hset_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "duration_ms": hset_time}},
|
||||||
|
)
|
||||||
|
|
||||||
await redis.expire(meta_key, config.stream_ttl)
|
await redis.expire(meta_key, config.stream_ttl)
|
||||||
|
|
||||||
# Create operation_id -> task_id mapping for webhook lookups
|
# Create operation_id -> task_id mapping for webhook lookups
|
||||||
await redis.set(op_key, task_id, ex=config.stream_ttl)
|
await redis.set(op_key, task_id, ex=config.stream_ttl)
|
||||||
|
|
||||||
logger.debug(f"Created task {task_id} for session {session_id}")
|
total_time = (time.perf_counter() - start_time) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] create_task COMPLETED in {total_time:.1f}ms; task={task_id}, session={session_id}",
|
||||||
|
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
|
||||||
|
)
|
||||||
|
|
||||||
return task
|
return task
|
||||||
|
|
||||||
@@ -156,26 +192,60 @@ async def publish_chunk(
|
|||||||
Returns:
|
Returns:
|
||||||
The Redis Stream message ID
|
The Redis Stream message ID
|
||||||
"""
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
chunk_type = type(chunk).__name__
|
||||||
chunk_json = chunk.model_dump_json()
|
chunk_json = chunk.model_dump_json()
|
||||||
message_id = "0-0"
|
message_id = "0-0"
|
||||||
|
|
||||||
|
# Build log metadata
|
||||||
|
log_meta = {
|
||||||
|
"component": "StreamRegistry",
|
||||||
|
"task_id": task_id,
|
||||||
|
"chunk_type": chunk_type,
|
||||||
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
redis = await get_redis_async()
|
redis = await get_redis_async()
|
||||||
stream_key = _get_task_stream_key(task_id)
|
stream_key = _get_task_stream_key(task_id)
|
||||||
|
|
||||||
# Write to Redis Stream for persistence and real-time delivery
|
# Write to Redis Stream for persistence and real-time delivery
|
||||||
|
xadd_start = time.perf_counter()
|
||||||
raw_id = await redis.xadd(
|
raw_id = await redis.xadd(
|
||||||
stream_key,
|
stream_key,
|
||||||
{"data": chunk_json},
|
{"data": chunk_json},
|
||||||
maxlen=config.stream_max_length,
|
maxlen=config.stream_max_length,
|
||||||
)
|
)
|
||||||
|
xadd_time = (time.perf_counter() - xadd_start) * 1000
|
||||||
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
|
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
|
||||||
|
|
||||||
# Set TTL on stream to match task metadata TTL
|
# Set TTL on stream to match task metadata TTL
|
||||||
await redis.expire(stream_key, config.stream_ttl)
|
await redis.expire(stream_key, config.stream_ttl)
|
||||||
|
|
||||||
|
total_time = (time.perf_counter() - start_time) * 1000
|
||||||
|
# Only log timing for significant chunks or slow operations
|
||||||
|
if (
|
||||||
|
chunk_type
|
||||||
|
in ("StreamStart", "StreamFinish", "StreamTextStart", "StreamTextEnd")
|
||||||
|
or total_time > 50
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] publish_chunk {chunk_type} in {total_time:.1f}ms (xadd={xadd_time:.1f}ms)",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"total_time_ms": total_time,
|
||||||
|
"xadd_time_ms": xadd_time,
|
||||||
|
"message_id": message_id,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
elapsed = (time.perf_counter() - start_time) * 1000
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Failed to publish chunk for task {task_id}: {e}",
|
f"[TIMING] Failed to publish chunk {chunk_type} after {elapsed:.1f}ms: {e}",
|
||||||
|
extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -200,24 +270,61 @@ async def subscribe_to_task(
|
|||||||
An asyncio Queue that will receive stream chunks, or None if task not found
|
An asyncio Queue that will receive stream chunks, or None if task not found
|
||||||
or user doesn't have access
|
or user doesn't have access
|
||||||
"""
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
# Build log metadata
|
||||||
|
log_meta = {"component": "StreamRegistry", "task_id": task_id}
|
||||||
|
if user_id:
|
||||||
|
log_meta["user_id"] = user_id
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] subscribe_to_task STARTED, task={task_id}, user={user_id}, last_msg={last_message_id}",
|
||||||
|
extra={"json_fields": {**log_meta, "last_message_id": last_message_id}},
|
||||||
|
)
|
||||||
|
|
||||||
|
redis_start = time.perf_counter()
|
||||||
redis = await get_redis_async()
|
redis = await get_redis_async()
|
||||||
meta_key = _get_task_meta_key(task_id)
|
meta_key = _get_task_meta_key(task_id)
|
||||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||||
|
hgetall_time = (time.perf_counter() - redis_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] Redis hgetall took {hgetall_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}},
|
||||||
|
)
|
||||||
|
|
||||||
if not meta:
|
if not meta:
|
||||||
logger.debug(f"Task {task_id} not found in Redis")
|
elapsed = (time.perf_counter() - start_time) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] Task not found in Redis after {elapsed:.1f}ms",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"elapsed_ms": elapsed,
|
||||||
|
"reason": "task_not_found",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Note: Redis client uses decode_responses=True, so keys are strings
|
# Note: Redis client uses decode_responses=True, so keys are strings
|
||||||
task_status = meta.get("status", "")
|
task_status = meta.get("status", "")
|
||||||
task_user_id = meta.get("user_id", "") or None
|
task_user_id = meta.get("user_id", "") or None
|
||||||
|
log_meta["session_id"] = meta.get("session_id", "")
|
||||||
|
|
||||||
# Validate ownership - if task has an owner, requester must match
|
# Validate ownership - if task has an owner, requester must match
|
||||||
if task_user_id:
|
if task_user_id:
|
||||||
if user_id != task_user_id:
|
if user_id != task_user_id:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"User {user_id} denied access to task {task_id} "
|
f"[TIMING] Access denied: user {user_id} tried to access task owned by {task_user_id}",
|
||||||
f"owned by {task_user_id}"
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"task_owner": task_user_id,
|
||||||
|
"reason": "access_denied",
|
||||||
|
}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -225,7 +332,19 @@ async def subscribe_to_task(
|
|||||||
stream_key = _get_task_stream_key(task_id)
|
stream_key = _get_task_stream_key(task_id)
|
||||||
|
|
||||||
# Step 1: Replay messages from Redis Stream
|
# Step 1: Replay messages from Redis Stream
|
||||||
|
xread_start = time.perf_counter()
|
||||||
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
|
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
|
||||||
|
xread_time = (time.perf_counter() - xread_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] Redis xread (replay) took {xread_time:.1f}ms, status={task_status}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"duration_ms": xread_time,
|
||||||
|
"task_status": task_status,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
replayed_count = 0
|
replayed_count = 0
|
||||||
replay_last_id = last_message_id
|
replay_last_id = last_message_id
|
||||||
@@ -244,19 +363,48 @@ async def subscribe_to_task(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to replay message: {e}")
|
logger.warning(f"Failed to replay message: {e}")
|
||||||
|
|
||||||
logger.debug(f"Task {task_id}: replayed {replayed_count} messages")
|
logger.info(
|
||||||
|
f"[TIMING] Replayed {replayed_count} messages, last_id={replay_last_id}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"n_messages_replayed": replayed_count,
|
||||||
|
"replay_last_id": replay_last_id,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# Step 2: If task is still running, start stream listener for live updates
|
# Step 2: If task is still running, start stream listener for live updates
|
||||||
if task_status == "running":
|
if task_status == "running":
|
||||||
|
logger.info(
|
||||||
|
"[TIMING] Task still running, starting _stream_listener",
|
||||||
|
extra={"json_fields": {**log_meta, "task_status": task_status}},
|
||||||
|
)
|
||||||
listener_task = asyncio.create_task(
|
listener_task = asyncio.create_task(
|
||||||
_stream_listener(task_id, subscriber_queue, replay_last_id)
|
_stream_listener(task_id, subscriber_queue, replay_last_id, log_meta)
|
||||||
)
|
)
|
||||||
# Track listener task for cleanup on unsubscribe
|
# Track listener task for cleanup on unsubscribe
|
||||||
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
|
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
|
||||||
else:
|
else:
|
||||||
# Task is completed/failed - add finish marker
|
# Task is completed/failed - add finish marker
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] Task already {task_status}, adding StreamFinish",
|
||||||
|
extra={"json_fields": {**log_meta, "task_status": task_status}},
|
||||||
|
)
|
||||||
await subscriber_queue.put(StreamFinish())
|
await subscriber_queue.put(StreamFinish())
|
||||||
|
|
||||||
|
total_time = (time.perf_counter() - start_time) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] subscribe_to_task COMPLETED in {total_time:.1f}ms; task={task_id}, "
|
||||||
|
f"n_messages_replayed={replayed_count}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"total_time_ms": total_time,
|
||||||
|
"n_messages_replayed": replayed_count,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
return subscriber_queue
|
return subscriber_queue
|
||||||
|
|
||||||
|
|
||||||
@@ -264,6 +412,7 @@ async def _stream_listener(
|
|||||||
task_id: str,
|
task_id: str,
|
||||||
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
||||||
last_replayed_id: str,
|
last_replayed_id: str,
|
||||||
|
log_meta: dict | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Listen to Redis Stream for new messages using blocking XREAD.
|
"""Listen to Redis Stream for new messages using blocking XREAD.
|
||||||
|
|
||||||
@@ -274,10 +423,27 @@ async def _stream_listener(
|
|||||||
task_id: Task ID to listen for
|
task_id: Task ID to listen for
|
||||||
subscriber_queue: Queue to deliver messages to
|
subscriber_queue: Queue to deliver messages to
|
||||||
last_replayed_id: Last message ID from replay (continue from here)
|
last_replayed_id: Last message ID from replay (continue from here)
|
||||||
|
log_meta: Structured logging metadata
|
||||||
"""
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
# Use provided log_meta or build minimal one
|
||||||
|
if log_meta is None:
|
||||||
|
log_meta = {"component": "StreamRegistry", "task_id": task_id}
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] _stream_listener STARTED, task={task_id}, last_id={last_replayed_id}",
|
||||||
|
extra={"json_fields": {**log_meta, "last_replayed_id": last_replayed_id}},
|
||||||
|
)
|
||||||
|
|
||||||
queue_id = id(subscriber_queue)
|
queue_id = id(subscriber_queue)
|
||||||
# Track the last successfully delivered message ID for recovery hints
|
# Track the last successfully delivered message ID for recovery hints
|
||||||
last_delivered_id = last_replayed_id
|
last_delivered_id = last_replayed_id
|
||||||
|
messages_delivered = 0
|
||||||
|
first_message_time = None
|
||||||
|
xread_count = 0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
redis = await get_redis_async()
|
redis = await get_redis_async()
|
||||||
@@ -287,9 +453,39 @@ async def _stream_listener(
|
|||||||
while True:
|
while True:
|
||||||
# Block for up to 30 seconds waiting for new messages
|
# Block for up to 30 seconds waiting for new messages
|
||||||
# This allows periodic checking if task is still running
|
# This allows periodic checking if task is still running
|
||||||
|
xread_start = time.perf_counter()
|
||||||
|
xread_count += 1
|
||||||
messages = await redis.xread(
|
messages = await redis.xread(
|
||||||
{stream_key: current_id}, block=30000, count=100
|
{stream_key: current_id}, block=30000, count=100
|
||||||
)
|
)
|
||||||
|
xread_time = (time.perf_counter() - xread_start) * 1000
|
||||||
|
|
||||||
|
if messages:
|
||||||
|
msg_count = sum(len(msgs) for _, msgs in messages)
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] xread #{xread_count} returned {msg_count} messages in {xread_time:.1f}ms",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"xread_count": xread_count,
|
||||||
|
"n_messages": msg_count,
|
||||||
|
"duration_ms": xread_time,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
elif xread_time > 1000:
|
||||||
|
# Only log timeouts (30s blocking)
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] xread #{xread_count} timeout after {xread_time:.1f}ms",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"xread_count": xread_count,
|
||||||
|
"duration_ms": xread_time,
|
||||||
|
"reason": "timeout",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
if not messages:
|
if not messages:
|
||||||
# Timeout - check if task is still running
|
# Timeout - check if task is still running
|
||||||
@@ -326,10 +522,30 @@ async def _stream_listener(
|
|||||||
)
|
)
|
||||||
# Update last delivered ID on successful delivery
|
# Update last delivered ID on successful delivery
|
||||||
last_delivered_id = current_id
|
last_delivered_id = current_id
|
||||||
|
messages_delivered += 1
|
||||||
|
if first_message_time is None:
|
||||||
|
first_message_time = time.perf_counter()
|
||||||
|
elapsed = (first_message_time - start_time) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] FIRST live message at {elapsed:.1f}ms, type={type(chunk).__name__}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"elapsed_ms": elapsed,
|
||||||
|
"chunk_type": type(chunk).__name__,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Subscriber queue full for task {task_id}, "
|
f"[TIMING] Subscriber queue full, delivery timed out after {QUEUE_PUT_TIMEOUT}s",
|
||||||
f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s"
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"timeout_s": QUEUE_PUT_TIMEOUT,
|
||||||
|
"reason": "queue_full",
|
||||||
|
}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
# Send overflow error with recovery info
|
# Send overflow error with recovery info
|
||||||
try:
|
try:
|
||||||
@@ -351,15 +567,44 @@ async def _stream_listener(
|
|||||||
|
|
||||||
# Stop listening on finish
|
# Stop listening on finish
|
||||||
if isinstance(chunk, StreamFinish):
|
if isinstance(chunk, StreamFinish):
|
||||||
|
total_time = (time.perf_counter() - start_time) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] StreamFinish received in {total_time/1000:.1f}s; delivered={messages_delivered}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"total_time_ms": total_time,
|
||||||
|
"messages_delivered": messages_delivered,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error processing stream message: {e}")
|
logger.warning(
|
||||||
|
f"Error processing stream message: {e}",
|
||||||
|
extra={"json_fields": {**log_meta, "error": str(e)}},
|
||||||
|
)
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.debug(f"Stream listener cancelled for task {task_id}")
|
elapsed = (time.perf_counter() - start_time) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] _stream_listener CANCELLED after {elapsed:.1f}ms, delivered={messages_delivered}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"elapsed_ms": elapsed,
|
||||||
|
"messages_delivered": messages_delivered,
|
||||||
|
"reason": "cancelled",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
raise # Re-raise to propagate cancellation
|
raise # Re-raise to propagate cancellation
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Stream listener error for task {task_id}: {e}")
|
elapsed = (time.perf_counter() - start_time) * 1000
|
||||||
|
logger.error(
|
||||||
|
f"[TIMING] _stream_listener ERROR after {elapsed:.1f}ms: {e}",
|
||||||
|
extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
|
||||||
|
)
|
||||||
# On error, send finish to unblock subscriber
|
# On error, send finish to unblock subscriber
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(
|
||||||
@@ -368,10 +613,24 @@ async def _stream_listener(
|
|||||||
)
|
)
|
||||||
except (asyncio.TimeoutError, asyncio.QueueFull):
|
except (asyncio.TimeoutError, asyncio.QueueFull):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Could not deliver finish event for task {task_id} after error"
|
"Could not deliver finish event after error",
|
||||||
|
extra={"json_fields": log_meta},
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
# Clean up listener task mapping on exit
|
# Clean up listener task mapping on exit
|
||||||
|
total_time = (time.perf_counter() - start_time) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] _stream_listener FINISHED in {total_time/1000:.1f}s; task={task_id}, "
|
||||||
|
f"delivered={messages_delivered}, xread_count={xread_count}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"total_time_ms": total_time,
|
||||||
|
"messages_delivered": messages_delivered,
|
||||||
|
"xread_count": xread_count,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
_listener_tasks.pop(queue_id, None)
|
_listener_tasks.pop(queue_id, None)
|
||||||
|
|
||||||
|
|
||||||
@@ -598,8 +857,10 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
|
|||||||
ResponseType,
|
ResponseType,
|
||||||
StreamError,
|
StreamError,
|
||||||
StreamFinish,
|
StreamFinish,
|
||||||
|
StreamFinishStep,
|
||||||
StreamHeartbeat,
|
StreamHeartbeat,
|
||||||
StreamStart,
|
StreamStart,
|
||||||
|
StreamStartStep,
|
||||||
StreamTextDelta,
|
StreamTextDelta,
|
||||||
StreamTextEnd,
|
StreamTextEnd,
|
||||||
StreamTextStart,
|
StreamTextStart,
|
||||||
@@ -613,6 +874,8 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
|
|||||||
type_to_class: dict[str, type[StreamBaseResponse]] = {
|
type_to_class: dict[str, type[StreamBaseResponse]] = {
|
||||||
ResponseType.START.value: StreamStart,
|
ResponseType.START.value: StreamStart,
|
||||||
ResponseType.FINISH.value: StreamFinish,
|
ResponseType.FINISH.value: StreamFinish,
|
||||||
|
ResponseType.START_STEP.value: StreamStartStep,
|
||||||
|
ResponseType.FINISH_STEP.value: StreamFinishStep,
|
||||||
ResponseType.TEXT_START.value: StreamTextStart,
|
ResponseType.TEXT_START.value: StreamTextStart,
|
||||||
ResponseType.TEXT_DELTA.value: StreamTextDelta,
|
ResponseType.TEXT_DELTA.value: StreamTextDelta,
|
||||||
ResponseType.TEXT_END.value: StreamTextEnd,
|
ResponseType.TEXT_END.value: StreamTextEnd,
|
||||||
|
|||||||
@@ -13,10 +13,32 @@ from backend.api.features.chat.tools.models import (
|
|||||||
NoResultsResponse,
|
NoResultsResponse,
|
||||||
)
|
)
|
||||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||||
from backend.data.block import get_block
|
from backend.data.block import BlockType, get_block
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_TARGET_RESULTS = 10
|
||||||
|
# Over-fetch to compensate for post-hoc filtering of graph-only blocks.
|
||||||
|
# 40 is 2x current removed; speed of query 10 vs 40 is minimial
|
||||||
|
_OVERFETCH_PAGE_SIZE = 40
|
||||||
|
|
||||||
|
# Block types that only work within graphs and cannot run standalone in CoPilot.
|
||||||
|
COPILOT_EXCLUDED_BLOCK_TYPES = {
|
||||||
|
BlockType.INPUT, # Graph interface definition - data enters via chat, not graph inputs
|
||||||
|
BlockType.OUTPUT, # Graph interface definition - data exits via chat, not graph outputs
|
||||||
|
BlockType.WEBHOOK, # Wait for external events - would hang forever in CoPilot
|
||||||
|
BlockType.WEBHOOK_MANUAL, # Same as WEBHOOK
|
||||||
|
BlockType.NOTE, # Visual annotation only - no runtime behavior
|
||||||
|
BlockType.HUMAN_IN_THE_LOOP, # Pauses for human approval - CoPilot IS human-in-the-loop
|
||||||
|
BlockType.AGENT, # AgentExecutorBlock requires execution_context - use run_agent tool
|
||||||
|
}
|
||||||
|
|
||||||
|
# Specific block IDs excluded from CoPilot (STANDARD type but still require graph context)
|
||||||
|
COPILOT_EXCLUDED_BLOCK_IDS = {
|
||||||
|
# SmartDecisionMakerBlock - dynamically discovers downstream blocks via graph topology
|
||||||
|
"3b191d9f-356f-482d-8238-ba04b6d18381",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class FindBlockTool(BaseTool):
|
class FindBlockTool(BaseTool):
|
||||||
"""Tool for searching available blocks."""
|
"""Tool for searching available blocks."""
|
||||||
@@ -88,7 +110,7 @@ class FindBlockTool(BaseTool):
|
|||||||
query=query,
|
query=query,
|
||||||
content_types=[ContentType.BLOCK],
|
content_types=[ContentType.BLOCK],
|
||||||
page=1,
|
page=1,
|
||||||
page_size=10,
|
page_size=_OVERFETCH_PAGE_SIZE,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not results:
|
if not results:
|
||||||
@@ -108,18 +130,35 @@ class FindBlockTool(BaseTool):
|
|||||||
block = get_block(block_id)
|
block = get_block(block_id)
|
||||||
|
|
||||||
# Skip disabled blocks
|
# Skip disabled blocks
|
||||||
if block and not block.disabled:
|
if not block or block.disabled:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Skip blocks excluded from CoPilot (graph-only blocks)
|
||||||
|
if (
|
||||||
|
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
# Get input/output schemas
|
# Get input/output schemas
|
||||||
input_schema = {}
|
input_schema = {}
|
||||||
output_schema = {}
|
output_schema = {}
|
||||||
try:
|
try:
|
||||||
input_schema = block.input_schema.jsonschema()
|
input_schema = block.input_schema.jsonschema()
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
logger.debug(
|
||||||
|
"Failed to generate input schema for block %s: %s",
|
||||||
|
block_id,
|
||||||
|
e,
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
output_schema = block.output_schema.jsonschema()
|
output_schema = block.output_schema.jsonschema()
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
logger.debug(
|
||||||
|
"Failed to generate output schema for block %s: %s",
|
||||||
|
block_id,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
|
||||||
# Get categories from block instance
|
# Get categories from block instance
|
||||||
categories = []
|
categories = []
|
||||||
@@ -163,6 +202,19 @@ class FindBlockTool(BaseTool):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if len(blocks) >= _TARGET_RESULTS:
|
||||||
|
break
|
||||||
|
|
||||||
|
if blocks and len(blocks) < _TARGET_RESULTS:
|
||||||
|
logger.debug(
|
||||||
|
"find_block returned %d/%d results for query '%s' "
|
||||||
|
"(filtered %d excluded/disabled blocks)",
|
||||||
|
len(blocks),
|
||||||
|
_TARGET_RESULTS,
|
||||||
|
query,
|
||||||
|
len(results) - len(blocks),
|
||||||
|
)
|
||||||
|
|
||||||
if not blocks:
|
if not blocks:
|
||||||
return NoResultsResponse(
|
return NoResultsResponse(
|
||||||
message=f"No blocks found for '{query}'",
|
message=f"No blocks found for '{query}'",
|
||||||
|
|||||||
@@ -0,0 +1,139 @@
|
|||||||
|
"""Tests for block filtering in FindBlockTool."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from backend.api.features.chat.tools.find_block import (
|
||||||
|
COPILOT_EXCLUDED_BLOCK_IDS,
|
||||||
|
COPILOT_EXCLUDED_BLOCK_TYPES,
|
||||||
|
FindBlockTool,
|
||||||
|
)
|
||||||
|
from backend.api.features.chat.tools.models import BlockListResponse
|
||||||
|
from backend.data.block import BlockType
|
||||||
|
|
||||||
|
from ._test_data import make_session
|
||||||
|
|
||||||
|
_TEST_USER_ID = "test-user-find-block"
|
||||||
|
|
||||||
|
|
||||||
|
def make_mock_block(
|
||||||
|
block_id: str, name: str, block_type: BlockType, disabled: bool = False
|
||||||
|
):
|
||||||
|
"""Create a mock block for testing."""
|
||||||
|
mock = MagicMock()
|
||||||
|
mock.id = block_id
|
||||||
|
mock.name = name
|
||||||
|
mock.description = f"{name} description"
|
||||||
|
mock.block_type = block_type
|
||||||
|
mock.disabled = disabled
|
||||||
|
mock.input_schema = MagicMock()
|
||||||
|
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
|
||||||
|
mock.input_schema.get_credentials_fields.return_value = {}
|
||||||
|
mock.output_schema = MagicMock()
|
||||||
|
mock.output_schema.jsonschema.return_value = {}
|
||||||
|
mock.categories = []
|
||||||
|
return mock
|
||||||
|
|
||||||
|
|
||||||
|
class TestFindBlockFiltering:
|
||||||
|
"""Tests for block filtering in FindBlockTool."""
|
||||||
|
|
||||||
|
def test_excluded_block_types_contains_expected_types(self):
|
||||||
|
"""Verify COPILOT_EXCLUDED_BLOCK_TYPES contains all graph-only types."""
|
||||||
|
assert BlockType.INPUT in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
assert BlockType.OUTPUT in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
assert BlockType.WEBHOOK in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
assert BlockType.WEBHOOK_MANUAL in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
assert BlockType.NOTE in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
assert BlockType.HUMAN_IN_THE_LOOP in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
assert BlockType.AGENT in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
|
||||||
|
def test_excluded_block_ids_contains_smart_decision_maker(self):
|
||||||
|
"""Verify SmartDecisionMakerBlock is in COPILOT_EXCLUDED_BLOCK_IDS."""
|
||||||
|
assert "3b191d9f-356f-482d-8238-ba04b6d18381" in COPILOT_EXCLUDED_BLOCK_IDS
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_excluded_block_type_filtered_from_results(self):
|
||||||
|
"""Verify blocks with excluded BlockTypes are filtered from search results."""
|
||||||
|
session = make_session(user_id=_TEST_USER_ID)
|
||||||
|
|
||||||
|
# Mock search returns an INPUT block (excluded) and a STANDARD block (included)
|
||||||
|
search_results = [
|
||||||
|
{"content_id": "input-block-id", "score": 0.9},
|
||||||
|
{"content_id": "standard-block-id", "score": 0.8},
|
||||||
|
]
|
||||||
|
|
||||||
|
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
|
||||||
|
standard_block = make_mock_block(
|
||||||
|
"standard-block-id", "HTTP Request", BlockType.STANDARD
|
||||||
|
)
|
||||||
|
|
||||||
|
def mock_get_block(block_id):
|
||||||
|
return {
|
||||||
|
"input-block-id": input_block,
|
||||||
|
"standard-block-id": standard_block,
|
||||||
|
}.get(block_id)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=(search_results, 2),
|
||||||
|
):
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.find_block.get_block",
|
||||||
|
side_effect=mock_get_block,
|
||||||
|
):
|
||||||
|
tool = FindBlockTool()
|
||||||
|
response = await tool._execute(
|
||||||
|
user_id=_TEST_USER_ID, session=session, query="test"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should only return the standard block, not the INPUT block
|
||||||
|
assert isinstance(response, BlockListResponse)
|
||||||
|
assert len(response.blocks) == 1
|
||||||
|
assert response.blocks[0].id == "standard-block-id"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_excluded_block_id_filtered_from_results(self):
|
||||||
|
"""Verify SmartDecisionMakerBlock is filtered from search results."""
|
||||||
|
session = make_session(user_id=_TEST_USER_ID)
|
||||||
|
|
||||||
|
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
||||||
|
search_results = [
|
||||||
|
{"content_id": smart_decision_id, "score": 0.9},
|
||||||
|
{"content_id": "normal-block-id", "score": 0.8},
|
||||||
|
]
|
||||||
|
|
||||||
|
# SmartDecisionMakerBlock has STANDARD type but is excluded by ID
|
||||||
|
smart_block = make_mock_block(
|
||||||
|
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
|
||||||
|
)
|
||||||
|
normal_block = make_mock_block(
|
||||||
|
"normal-block-id", "Normal Block", BlockType.STANDARD
|
||||||
|
)
|
||||||
|
|
||||||
|
def mock_get_block(block_id):
|
||||||
|
return {
|
||||||
|
smart_decision_id: smart_block,
|
||||||
|
"normal-block-id": normal_block,
|
||||||
|
}.get(block_id)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=(search_results, 2),
|
||||||
|
):
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.find_block.get_block",
|
||||||
|
side_effect=mock_get_block,
|
||||||
|
):
|
||||||
|
tool = FindBlockTool()
|
||||||
|
response = await tool._execute(
|
||||||
|
user_id=_TEST_USER_ID, session=session, query="decision"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should only return normal block, not SmartDecisionMakerBlock
|
||||||
|
assert isinstance(response, BlockListResponse)
|
||||||
|
assert len(response.blocks) == 1
|
||||||
|
assert response.blocks[0].id == "normal-block-id"
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
"""Shared helpers for chat tools."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def get_inputs_from_schema(
|
||||||
|
input_schema: dict[str, Any],
|
||||||
|
exclude_fields: set[str] | None = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Extract input field info from JSON schema."""
|
||||||
|
if not isinstance(input_schema, dict):
|
||||||
|
return []
|
||||||
|
|
||||||
|
exclude = exclude_fields or set()
|
||||||
|
properties = input_schema.get("properties", {})
|
||||||
|
required = set(input_schema.get("required", []))
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": name,
|
||||||
|
"title": schema.get("title", name),
|
||||||
|
"type": schema.get("type", "string"),
|
||||||
|
"description": schema.get("description", ""),
|
||||||
|
"required": name in required,
|
||||||
|
"default": schema.get("default"),
|
||||||
|
}
|
||||||
|
for name, schema in properties.items()
|
||||||
|
if name not in exclude
|
||||||
|
]
|
||||||
@@ -24,6 +24,7 @@ from backend.util.timezone_utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
|
from .helpers import get_inputs_from_schema
|
||||||
from .models import (
|
from .models import (
|
||||||
AgentDetails,
|
AgentDetails,
|
||||||
AgentDetailsResponse,
|
AgentDetailsResponse,
|
||||||
@@ -261,7 +262,7 @@ class RunAgentTool(BaseTool):
|
|||||||
),
|
),
|
||||||
requirements={
|
requirements={
|
||||||
"credentials": requirements_creds_list,
|
"credentials": requirements_creds_list,
|
||||||
"inputs": self._get_inputs_list(graph.input_schema),
|
"inputs": get_inputs_from_schema(graph.input_schema),
|
||||||
"execution_modes": self._get_execution_modes(graph),
|
"execution_modes": self._get_execution_modes(graph),
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
@@ -369,22 +370,6 @@ class RunAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
|
|
||||||
"""Extract inputs list from schema."""
|
|
||||||
inputs_list = []
|
|
||||||
if isinstance(input_schema, dict) and "properties" in input_schema:
|
|
||||||
for field_name, field_schema in input_schema["properties"].items():
|
|
||||||
inputs_list.append(
|
|
||||||
{
|
|
||||||
"name": field_name,
|
|
||||||
"title": field_schema.get("title", field_name),
|
|
||||||
"type": field_schema.get("type", "string"),
|
|
||||||
"description": field_schema.get("description", ""),
|
|
||||||
"required": field_name in input_schema.get("required", []),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return inputs_list
|
|
||||||
|
|
||||||
def _get_execution_modes(self, graph: GraphModel) -> list[str]:
|
def _get_execution_modes(self, graph: GraphModel) -> list[str]:
|
||||||
"""Get available execution modes for the graph."""
|
"""Get available execution modes for the graph."""
|
||||||
trigger_info = graph.trigger_setup_info
|
trigger_info = graph.trigger_setup_info
|
||||||
@@ -398,7 +383,7 @@ class RunAgentTool(BaseTool):
|
|||||||
suffix: str,
|
suffix: str,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Build a message describing available inputs for an agent."""
|
"""Build a message describing available inputs for an agent."""
|
||||||
inputs_list = self._get_inputs_list(graph.input_schema)
|
inputs_list = get_inputs_from_schema(graph.input_schema)
|
||||||
required_names = [i["name"] for i in inputs_list if i["required"]]
|
required_names = [i["name"] for i in inputs_list if i["required"]]
|
||||||
optional_names = [i["name"] for i in inputs_list if not i["required"]]
|
optional_names = [i["name"] for i in inputs_list if not i["required"]]
|
||||||
|
|
||||||
|
|||||||
@@ -8,14 +8,19 @@ from typing import Any
|
|||||||
from pydantic_core import PydanticUndefined
|
from pydantic_core import PydanticUndefined
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
from backend.data.block import get_block
|
from backend.api.features.chat.tools.find_block import (
|
||||||
|
COPILOT_EXCLUDED_BLOCK_IDS,
|
||||||
|
COPILOT_EXCLUDED_BLOCK_TYPES,
|
||||||
|
)
|
||||||
|
from backend.data.block import AnyBlockSchema, get_block
|
||||||
from backend.data.execution import ExecutionContext
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import CredentialsMetaInput
|
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||||
from backend.data.workspace import get_or_create_workspace
|
from backend.data.workspace import get_or_create_workspace
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.util.exceptions import BlockError
|
from backend.util.exceptions import BlockError
|
||||||
|
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
|
from .helpers import get_inputs_from_schema
|
||||||
from .models import (
|
from .models import (
|
||||||
BlockOutputResponse,
|
BlockOutputResponse,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
@@ -24,7 +29,10 @@ from .models import (
|
|||||||
ToolResponseBase,
|
ToolResponseBase,
|
||||||
UserReadiness,
|
UserReadiness,
|
||||||
)
|
)
|
||||||
from .utils import build_missing_credentials_from_field_info
|
from .utils import (
|
||||||
|
build_missing_credentials_from_field_info,
|
||||||
|
match_credentials_to_requirements,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -73,91 +81,6 @@ class RunBlockTool(BaseTool):
|
|||||||
def requires_auth(self) -> bool:
|
def requires_auth(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def _check_block_credentials(
|
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
block: Any,
|
|
||||||
input_data: dict[str, Any] | None = None,
|
|
||||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
|
||||||
"""
|
|
||||||
Check if user has required credentials for a block.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: User ID
|
|
||||||
block: Block to check credentials for
|
|
||||||
input_data: Input data for the block (used to determine provider via discriminator)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[matched_credentials, missing_credentials]
|
|
||||||
"""
|
|
||||||
matched_credentials: dict[str, CredentialsMetaInput] = {}
|
|
||||||
missing_credentials: list[CredentialsMetaInput] = []
|
|
||||||
input_data = input_data or {}
|
|
||||||
|
|
||||||
# Get credential field info from block's input schema
|
|
||||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
|
||||||
|
|
||||||
if not credentials_fields_info:
|
|
||||||
return matched_credentials, missing_credentials
|
|
||||||
|
|
||||||
# Get user's available credentials
|
|
||||||
creds_manager = IntegrationCredentialsManager()
|
|
||||||
available_creds = await creds_manager.store.get_all_creds(user_id)
|
|
||||||
|
|
||||||
for field_name, field_info in credentials_fields_info.items():
|
|
||||||
effective_field_info = field_info
|
|
||||||
if field_info.discriminator and field_info.discriminator_mapping:
|
|
||||||
# Get discriminator from input, falling back to schema default
|
|
||||||
discriminator_value = input_data.get(field_info.discriminator)
|
|
||||||
if discriminator_value is None:
|
|
||||||
field = block.input_schema.model_fields.get(
|
|
||||||
field_info.discriminator
|
|
||||||
)
|
|
||||||
if field and field.default is not PydanticUndefined:
|
|
||||||
discriminator_value = field.default
|
|
||||||
|
|
||||||
if (
|
|
||||||
discriminator_value
|
|
||||||
and discriminator_value in field_info.discriminator_mapping
|
|
||||||
):
|
|
||||||
effective_field_info = field_info.discriminate(discriminator_value)
|
|
||||||
logger.debug(
|
|
||||||
f"Discriminated provider for {field_name}: "
|
|
||||||
f"{discriminator_value} -> {effective_field_info.provider}"
|
|
||||||
)
|
|
||||||
|
|
||||||
matching_cred = next(
|
|
||||||
(
|
|
||||||
cred
|
|
||||||
for cred in available_creds
|
|
||||||
if cred.provider in effective_field_info.provider
|
|
||||||
and cred.type in effective_field_info.supported_types
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
|
|
||||||
if matching_cred:
|
|
||||||
matched_credentials[field_name] = CredentialsMetaInput(
|
|
||||||
id=matching_cred.id,
|
|
||||||
provider=matching_cred.provider, # type: ignore
|
|
||||||
type=matching_cred.type,
|
|
||||||
title=matching_cred.title,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Create a placeholder for the missing credential
|
|
||||||
provider = next(iter(effective_field_info.provider), "unknown")
|
|
||||||
cred_type = next(iter(effective_field_info.supported_types), "api_key")
|
|
||||||
missing_credentials.append(
|
|
||||||
CredentialsMetaInput(
|
|
||||||
id=field_name,
|
|
||||||
provider=provider, # type: ignore
|
|
||||||
type=cred_type, # type: ignore
|
|
||||||
title=field_name.replace("_", " ").title(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return matched_credentials, missing_credentials
|
|
||||||
|
|
||||||
async def _execute(
|
async def _execute(
|
||||||
self,
|
self,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
@@ -212,11 +135,24 @@ class RunBlockTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check if block is excluded from CoPilot (graph-only blocks)
|
||||||
|
if (
|
||||||
|
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
|
||||||
|
):
|
||||||
|
return ErrorResponse(
|
||||||
|
message=(
|
||||||
|
f"Block '{block.name}' cannot be run directly in CoPilot. "
|
||||||
|
"This block is designed for use within graphs only."
|
||||||
|
),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
||||||
|
|
||||||
creds_manager = IntegrationCredentialsManager()
|
creds_manager = IntegrationCredentialsManager()
|
||||||
matched_credentials, missing_credentials = await self._check_block_credentials(
|
matched_credentials, missing_credentials = (
|
||||||
user_id, block, input_data
|
await self._resolve_block_credentials(user_id, block, input_data)
|
||||||
)
|
)
|
||||||
|
|
||||||
if missing_credentials:
|
if missing_credentials:
|
||||||
@@ -345,29 +281,75 @@ class RunBlockTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
|
async def _resolve_block_credentials(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
block: AnyBlockSchema,
|
||||||
|
input_data: dict[str, Any] | None = None,
|
||||||
|
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||||
|
"""
|
||||||
|
Resolve credentials for a block by matching user's available credentials.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: User ID
|
||||||
|
block: Block to resolve credentials for
|
||||||
|
input_data: Input data for the block (used to determine provider via discriminator)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple of (matched_credentials, missing_credentials) - matched credentials
|
||||||
|
are used for block execution, missing ones indicate setup requirements.
|
||||||
|
"""
|
||||||
|
input_data = input_data or {}
|
||||||
|
requirements = self._resolve_discriminated_credentials(block, input_data)
|
||||||
|
|
||||||
|
if not requirements:
|
||||||
|
return {}, []
|
||||||
|
|
||||||
|
return await match_credentials_to_requirements(user_id, requirements)
|
||||||
|
|
||||||
|
def _get_inputs_list(self, block: AnyBlockSchema) -> list[dict[str, Any]]:
|
||||||
"""Extract non-credential inputs from block schema."""
|
"""Extract non-credential inputs from block schema."""
|
||||||
inputs_list = []
|
|
||||||
schema = block.input_schema.jsonschema()
|
schema = block.input_schema.jsonschema()
|
||||||
properties = schema.get("properties", {})
|
|
||||||
required_fields = set(schema.get("required", []))
|
|
||||||
|
|
||||||
# Get credential field names to exclude
|
|
||||||
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
|
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
|
||||||
|
return get_inputs_from_schema(schema, exclude_fields=credentials_fields)
|
||||||
|
|
||||||
for field_name, field_schema in properties.items():
|
def _resolve_discriminated_credentials(
|
||||||
# Skip credential fields
|
self,
|
||||||
if field_name in credentials_fields:
|
block: AnyBlockSchema,
|
||||||
continue
|
input_data: dict[str, Any],
|
||||||
|
) -> dict[str, CredentialsFieldInfo]:
|
||||||
|
"""Resolve credential requirements, applying discriminator logic where needed."""
|
||||||
|
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||||
|
if not credentials_fields_info:
|
||||||
|
return {}
|
||||||
|
|
||||||
inputs_list.append(
|
resolved: dict[str, CredentialsFieldInfo] = {}
|
||||||
{
|
|
||||||
"name": field_name,
|
for field_name, field_info in credentials_fields_info.items():
|
||||||
"title": field_schema.get("title", field_name),
|
effective_field_info = field_info
|
||||||
"type": field_schema.get("type", "string"),
|
|
||||||
"description": field_schema.get("description", ""),
|
if field_info.discriminator and field_info.discriminator_mapping:
|
||||||
"required": field_name in required_fields,
|
discriminator_value = input_data.get(field_info.discriminator)
|
||||||
}
|
if discriminator_value is None:
|
||||||
|
field = block.input_schema.model_fields.get(
|
||||||
|
field_info.discriminator
|
||||||
|
)
|
||||||
|
if field and field.default is not PydanticUndefined:
|
||||||
|
discriminator_value = field.default
|
||||||
|
|
||||||
|
if (
|
||||||
|
discriminator_value
|
||||||
|
and discriminator_value in field_info.discriminator_mapping
|
||||||
|
):
|
||||||
|
effective_field_info = field_info.discriminate(discriminator_value)
|
||||||
|
# For host-scoped credentials, add the discriminator value
|
||||||
|
# (e.g., URL) so _credential_is_for_host can match it
|
||||||
|
effective_field_info.discriminator_values.add(discriminator_value)
|
||||||
|
logger.debug(
|
||||||
|
f"Discriminated provider for {field_name}: "
|
||||||
|
f"{discriminator_value} -> {effective_field_info.provider}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return inputs_list
|
resolved[field_name] = effective_field_info
|
||||||
|
|
||||||
|
return resolved
|
||||||
|
|||||||
@@ -0,0 +1,106 @@
|
|||||||
|
"""Tests for block execution guards in RunBlockTool."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from backend.api.features.chat.tools.models import ErrorResponse
|
||||||
|
from backend.api.features.chat.tools.run_block import RunBlockTool
|
||||||
|
from backend.data.block import BlockType
|
||||||
|
|
||||||
|
from ._test_data import make_session
|
||||||
|
|
||||||
|
_TEST_USER_ID = "test-user-run-block"
|
||||||
|
|
||||||
|
|
||||||
|
def make_mock_block(
|
||||||
|
block_id: str, name: str, block_type: BlockType, disabled: bool = False
|
||||||
|
):
|
||||||
|
"""Create a mock block for testing."""
|
||||||
|
mock = MagicMock()
|
||||||
|
mock.id = block_id
|
||||||
|
mock.name = name
|
||||||
|
mock.block_type = block_type
|
||||||
|
mock.disabled = disabled
|
||||||
|
mock.input_schema = MagicMock()
|
||||||
|
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
|
||||||
|
mock.input_schema.get_credentials_fields_info.return_value = []
|
||||||
|
return mock
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunBlockFiltering:
|
||||||
|
"""Tests for block execution guards in RunBlockTool."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_excluded_block_type_returns_error(self):
|
||||||
|
"""Attempting to execute a block with excluded BlockType returns error."""
|
||||||
|
session = make_session(user_id=_TEST_USER_ID)
|
||||||
|
|
||||||
|
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.run_block.get_block",
|
||||||
|
return_value=input_block,
|
||||||
|
):
|
||||||
|
tool = RunBlockTool()
|
||||||
|
response = await tool._execute(
|
||||||
|
user_id=_TEST_USER_ID,
|
||||||
|
session=session,
|
||||||
|
block_id="input-block-id",
|
||||||
|
input_data={},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(response, ErrorResponse)
|
||||||
|
assert "cannot be run directly in CoPilot" in response.message
|
||||||
|
assert "designed for use within graphs only" in response.message
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_excluded_block_id_returns_error(self):
|
||||||
|
"""Attempting to execute SmartDecisionMakerBlock returns error."""
|
||||||
|
session = make_session(user_id=_TEST_USER_ID)
|
||||||
|
|
||||||
|
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
||||||
|
smart_block = make_mock_block(
|
||||||
|
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.run_block.get_block",
|
||||||
|
return_value=smart_block,
|
||||||
|
):
|
||||||
|
tool = RunBlockTool()
|
||||||
|
response = await tool._execute(
|
||||||
|
user_id=_TEST_USER_ID,
|
||||||
|
session=session,
|
||||||
|
block_id=smart_decision_id,
|
||||||
|
input_data={},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(response, ErrorResponse)
|
||||||
|
assert "cannot be run directly in CoPilot" in response.message
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_non_excluded_block_passes_guard(self):
|
||||||
|
"""Non-excluded blocks pass the filtering guard (may fail later for other reasons)."""
|
||||||
|
session = make_session(user_id=_TEST_USER_ID)
|
||||||
|
|
||||||
|
standard_block = make_mock_block(
|
||||||
|
"standard-id", "HTTP Request", BlockType.STANDARD
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.run_block.get_block",
|
||||||
|
return_value=standard_block,
|
||||||
|
):
|
||||||
|
tool = RunBlockTool()
|
||||||
|
response = await tool._execute(
|
||||||
|
user_id=_TEST_USER_ID,
|
||||||
|
session=session,
|
||||||
|
block_id="standard-id",
|
||||||
|
input_data={},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should NOT be an ErrorResponse about CoPilot exclusion
|
||||||
|
# (may be other errors like missing credentials, but not the exclusion guard)
|
||||||
|
if isinstance(response, ErrorResponse):
|
||||||
|
assert "cannot be run directly in CoPilot" not in response.message
|
||||||
@@ -8,12 +8,14 @@ from backend.api.features.library import model as library_model
|
|||||||
from backend.api.features.store import db as store_db
|
from backend.api.features.store import db as store_db
|
||||||
from backend.data.graph import GraphModel
|
from backend.data.graph import GraphModel
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
|
Credentials,
|
||||||
CredentialsFieldInfo,
|
CredentialsFieldInfo,
|
||||||
CredentialsMetaInput,
|
CredentialsMetaInput,
|
||||||
HostScopedCredentials,
|
HostScopedCredentials,
|
||||||
OAuth2Credentials,
|
OAuth2Credentials,
|
||||||
)
|
)
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.util.exceptions import NotFoundError
|
from backend.util.exceptions import NotFoundError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -223,6 +225,99 @@ async def get_or_create_library_agent(
|
|||||||
return library_agents[0]
|
return library_agents[0]
|
||||||
|
|
||||||
|
|
||||||
|
async def match_credentials_to_requirements(
|
||||||
|
user_id: str,
|
||||||
|
requirements: dict[str, CredentialsFieldInfo],
|
||||||
|
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||||
|
"""
|
||||||
|
Match user's credentials against a dictionary of credential requirements.
|
||||||
|
|
||||||
|
This is the core matching logic shared by both graph and block credential matching.
|
||||||
|
"""
|
||||||
|
matched: dict[str, CredentialsMetaInput] = {}
|
||||||
|
missing: list[CredentialsMetaInput] = []
|
||||||
|
|
||||||
|
if not requirements:
|
||||||
|
return matched, missing
|
||||||
|
|
||||||
|
available_creds = await get_user_credentials(user_id)
|
||||||
|
|
||||||
|
for field_name, field_info in requirements.items():
|
||||||
|
matching_cred = find_matching_credential(available_creds, field_info)
|
||||||
|
|
||||||
|
if matching_cred:
|
||||||
|
try:
|
||||||
|
matched[field_name] = create_credential_meta_from_match(matching_cred)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to create CredentialsMetaInput for field '{field_name}': "
|
||||||
|
f"provider={matching_cred.provider}, type={matching_cred.type}, "
|
||||||
|
f"credential_id={matching_cred.id}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
provider = next(iter(field_info.provider), "unknown")
|
||||||
|
cred_type = next(iter(field_info.supported_types), "api_key")
|
||||||
|
missing.append(
|
||||||
|
CredentialsMetaInput(
|
||||||
|
id=field_name,
|
||||||
|
provider=provider, # type: ignore
|
||||||
|
type=cred_type, # type: ignore
|
||||||
|
title=f"{field_name} (validation failed: {e})",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
provider = next(iter(field_info.provider), "unknown")
|
||||||
|
cred_type = next(iter(field_info.supported_types), "api_key")
|
||||||
|
missing.append(
|
||||||
|
CredentialsMetaInput(
|
||||||
|
id=field_name,
|
||||||
|
provider=provider, # type: ignore
|
||||||
|
type=cred_type, # type: ignore
|
||||||
|
title=field_name.replace("_", " ").title(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return matched, missing
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_credentials(user_id: str) -> list[Credentials]:
|
||||||
|
"""Get all available credentials for a user."""
|
||||||
|
creds_manager = IntegrationCredentialsManager()
|
||||||
|
return await creds_manager.store.get_all_creds(user_id)
|
||||||
|
|
||||||
|
|
||||||
|
def find_matching_credential(
|
||||||
|
available_creds: list[Credentials],
|
||||||
|
field_info: CredentialsFieldInfo,
|
||||||
|
) -> Credentials | None:
|
||||||
|
"""Find a credential that matches the required provider, type, scopes, and host."""
|
||||||
|
for cred in available_creds:
|
||||||
|
if cred.provider not in field_info.provider:
|
||||||
|
continue
|
||||||
|
if cred.type not in field_info.supported_types:
|
||||||
|
continue
|
||||||
|
if cred.type == "oauth2" and not _credential_has_required_scopes(
|
||||||
|
cred, field_info
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
if cred.type == "host_scoped" and not _credential_is_for_host(cred, field_info):
|
||||||
|
continue
|
||||||
|
return cred
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def create_credential_meta_from_match(
|
||||||
|
matching_cred: Credentials,
|
||||||
|
) -> CredentialsMetaInput:
|
||||||
|
"""Create a CredentialsMetaInput from a matched credential."""
|
||||||
|
return CredentialsMetaInput(
|
||||||
|
id=matching_cred.id,
|
||||||
|
provider=matching_cred.provider, # type: ignore
|
||||||
|
type=matching_cred.type,
|
||||||
|
title=matching_cred.title,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def match_user_credentials_to_graph(
|
async def match_user_credentials_to_graph(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
graph: GraphModel,
|
graph: GraphModel,
|
||||||
@@ -265,7 +360,7 @@ async def match_user_credentials_to_graph(
|
|||||||
_,
|
_,
|
||||||
_,
|
_,
|
||||||
) in aggregated_creds.items():
|
) in aggregated_creds.items():
|
||||||
# Find first matching credential by provider, type, and scopes
|
# Find first matching credential by provider, type, scopes, and host/URL
|
||||||
matching_cred = next(
|
matching_cred = next(
|
||||||
(
|
(
|
||||||
cred
|
cred
|
||||||
@@ -280,6 +375,10 @@ async def match_user_credentials_to_graph(
|
|||||||
cred.type != "host_scoped"
|
cred.type != "host_scoped"
|
||||||
or _credential_is_for_host(cred, credential_requirements)
|
or _credential_is_for_host(cred, credential_requirements)
|
||||||
)
|
)
|
||||||
|
and (
|
||||||
|
cred.provider != ProviderName.MCP
|
||||||
|
or _credential_is_for_mcp_server(cred, credential_requirements)
|
||||||
|
)
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
@@ -331,8 +430,6 @@ def _credential_has_required_scopes(
|
|||||||
# If no scopes are required, any credential matches
|
# If no scopes are required, any credential matches
|
||||||
if not requirements.required_scopes:
|
if not requirements.required_scopes:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Check that credential scopes are a superset of required scopes
|
|
||||||
return set(credential.scopes).issuperset(requirements.required_scopes)
|
return set(credential.scopes).issuperset(requirements.required_scopes)
|
||||||
|
|
||||||
|
|
||||||
@@ -352,6 +449,22 @@ def _credential_is_for_host(
|
|||||||
return credential.matches_url(list(requirements.discriminator_values)[0])
|
return credential.matches_url(list(requirements.discriminator_values)[0])
|
||||||
|
|
||||||
|
|
||||||
|
def _credential_is_for_mcp_server(
|
||||||
|
credential: Credentials,
|
||||||
|
requirements: CredentialsFieldInfo,
|
||||||
|
) -> bool:
|
||||||
|
"""Check if an MCP OAuth credential matches the required server URL."""
|
||||||
|
if not requirements.discriminator_values:
|
||||||
|
return True
|
||||||
|
|
||||||
|
server_url = (
|
||||||
|
credential.metadata.get("mcp_server_url")
|
||||||
|
if isinstance(credential, OAuth2Credentials)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
return server_url in requirements.discriminator_values if server_url else False
|
||||||
|
|
||||||
|
|
||||||
async def check_user_has_required_credentials(
|
async def check_user_has_required_credentials(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
required_credentials: list[CredentialsMetaInput],
|
required_credentials: list[CredentialsMetaInput],
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import TYPE_CHECKING, Annotated, List, Literal
|
from typing import TYPE_CHECKING, Annotated, Any, List, Literal
|
||||||
|
|
||||||
from autogpt_libs.auth import get_user_id
|
from autogpt_libs.auth import get_user_id
|
||||||
from fastapi import (
|
from fastapi import (
|
||||||
@@ -14,7 +14,7 @@ from fastapi import (
|
|||||||
Security,
|
Security,
|
||||||
status,
|
status,
|
||||||
)
|
)
|
||||||
from pydantic import BaseModel, Field, SecretStr
|
from pydantic import BaseModel, Field, SecretStr, model_validator
|
||||||
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
|
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
|
||||||
|
|
||||||
from backend.api.features.library.db import set_preset_webhook, update_preset
|
from backend.api.features.library.db import set_preset_webhook, update_preset
|
||||||
@@ -39,7 +39,11 @@ from backend.data.onboarding import OnboardingStep, complete_onboarding_step
|
|||||||
from backend.data.user import get_user_integrations
|
from backend.data.user import get_user_integrations
|
||||||
from backend.executor.utils import add_graph_execution
|
from backend.executor.utils import add_graph_execution
|
||||||
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
|
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.credentials_store import provider_matches
|
||||||
|
from backend.integrations.creds_manager import (
|
||||||
|
IntegrationCredentialsManager,
|
||||||
|
create_mcp_oauth_handler,
|
||||||
|
)
|
||||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.integrations.webhooks import get_webhook_manager
|
from backend.integrations.webhooks import get_webhook_manager
|
||||||
@@ -102,9 +106,37 @@ class CredentialsMetaResponse(BaseModel):
|
|||||||
scopes: list[str] | None
|
scopes: list[str] | None
|
||||||
username: str | None
|
username: str | None
|
||||||
host: str | None = Field(
|
host: str | None = Field(
|
||||||
default=None, description="Host pattern for host-scoped credentials"
|
default=None,
|
||||||
|
description="Host pattern for host-scoped or MCP server URL for MCP credentials",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _normalize_provider(cls, data: Any) -> Any:
|
||||||
|
"""Fix ``ProviderName.X`` format from Python 3.13 ``str(Enum)`` bug."""
|
||||||
|
if isinstance(data, dict):
|
||||||
|
prov = data.get("provider", "")
|
||||||
|
if isinstance(prov, str) and prov.startswith("ProviderName."):
|
||||||
|
member = prov.removeprefix("ProviderName.")
|
||||||
|
try:
|
||||||
|
data = {**data, "provider": ProviderName[member].value}
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
return data
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_host(cred: Credentials) -> str | None:
|
||||||
|
"""Extract host from credential: HostScoped host or MCP server URL."""
|
||||||
|
if isinstance(cred, HostScopedCredentials):
|
||||||
|
return cred.host
|
||||||
|
if isinstance(cred, OAuth2Credentials) and cred.provider in (
|
||||||
|
ProviderName.MCP,
|
||||||
|
ProviderName.MCP.value,
|
||||||
|
"ProviderName.MCP",
|
||||||
|
):
|
||||||
|
return (cred.metadata or {}).get("mcp_server_url")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
|
@router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
|
||||||
async def callback(
|
async def callback(
|
||||||
@@ -179,9 +211,7 @@ async def callback(
|
|||||||
title=credentials.title,
|
title=credentials.title,
|
||||||
scopes=credentials.scopes,
|
scopes=credentials.scopes,
|
||||||
username=credentials.username,
|
username=credentials.username,
|
||||||
host=(
|
host=(CredentialsMetaResponse.get_host(credentials)),
|
||||||
credentials.host if isinstance(credentials, HostScopedCredentials) else None
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -199,7 +229,7 @@ async def list_credentials(
|
|||||||
title=cred.title,
|
title=cred.title,
|
||||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||||
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
host=CredentialsMetaResponse.get_host(cred),
|
||||||
)
|
)
|
||||||
for cred in credentials
|
for cred in credentials
|
||||||
]
|
]
|
||||||
@@ -222,7 +252,7 @@ async def list_credentials_by_provider(
|
|||||||
title=cred.title,
|
title=cred.title,
|
||||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||||
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
host=CredentialsMetaResponse.get_host(cred),
|
||||||
)
|
)
|
||||||
for cred in credentials
|
for cred in credentials
|
||||||
]
|
]
|
||||||
@@ -322,6 +352,10 @@ async def delete_credentials(
|
|||||||
|
|
||||||
tokens_revoked = None
|
tokens_revoked = None
|
||||||
if isinstance(creds, OAuth2Credentials):
|
if isinstance(creds, OAuth2Credentials):
|
||||||
|
if provider_matches(provider.value, ProviderName.MCP.value):
|
||||||
|
# MCP uses dynamic per-server OAuth — create handler from metadata
|
||||||
|
handler = create_mcp_oauth_handler(creds)
|
||||||
|
else:
|
||||||
handler = _get_provider_oauth_handler(request, provider)
|
handler = _get_provider_oauth_handler(request, provider)
|
||||||
tokens_revoked = await handler.revoke_tokens(creds)
|
tokens_revoked = await handler.revoke_tokens(creds)
|
||||||
|
|
||||||
|
|||||||
402
autogpt_platform/backend/backend/api/features/mcp/routes.py
Normal file
402
autogpt_platform/backend/backend/api/features/mcp/routes.py
Normal file
@@ -0,0 +1,402 @@
|
|||||||
|
"""
|
||||||
|
MCP (Model Context Protocol) API routes.
|
||||||
|
|
||||||
|
Provides endpoints for MCP tool discovery and OAuth authentication so the
|
||||||
|
frontend can list available tools on an MCP server before placing a block.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Annotated, Any
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import fastapi
|
||||||
|
from autogpt_libs.auth import get_user_id
|
||||||
|
from fastapi import Security
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from backend.api.features.integrations.router import CredentialsMetaResponse
|
||||||
|
from backend.blocks.mcp.client import MCPClient, MCPClientError
|
||||||
|
from backend.blocks.mcp.oauth import MCPOAuthHandler
|
||||||
|
from backend.data.model import OAuth2Credentials
|
||||||
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
|
from backend.integrations.providers import ProviderName
|
||||||
|
from backend.util.request import HTTPClientError, Requests
|
||||||
|
from backend.util.settings import Settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
|
router = fastapi.APIRouter(tags=["mcp"])
|
||||||
|
creds_manager = IntegrationCredentialsManager()
|
||||||
|
|
||||||
|
|
||||||
|
# ====================== Tool Discovery ====================== #
|
||||||
|
|
||||||
|
|
||||||
|
class DiscoverToolsRequest(BaseModel):
|
||||||
|
"""Request to discover tools on an MCP server."""
|
||||||
|
|
||||||
|
server_url: str = Field(description="URL of the MCP server")
|
||||||
|
auth_token: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Optional Bearer token for authenticated MCP servers",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MCPToolResponse(BaseModel):
|
||||||
|
"""A single MCP tool returned by discovery."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
input_schema: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class DiscoverToolsResponse(BaseModel):
|
||||||
|
"""Response containing the list of tools available on an MCP server."""
|
||||||
|
|
||||||
|
tools: list[MCPToolResponse]
|
||||||
|
server_name: str | None = None
|
||||||
|
protocol_version: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/discover-tools",
|
||||||
|
summary="Discover available tools on an MCP server",
|
||||||
|
response_model=DiscoverToolsResponse,
|
||||||
|
)
|
||||||
|
async def discover_tools(
|
||||||
|
request: DiscoverToolsRequest,
|
||||||
|
user_id: Annotated[str, Security(get_user_id)],
|
||||||
|
) -> DiscoverToolsResponse:
|
||||||
|
"""
|
||||||
|
Connect to an MCP server and return its available tools.
|
||||||
|
|
||||||
|
If the user has a stored MCP credential for this server URL, it will be
|
||||||
|
used automatically — no need to pass an explicit auth token.
|
||||||
|
"""
|
||||||
|
auth_token = request.auth_token
|
||||||
|
|
||||||
|
# Auto-use stored MCP credential when no explicit token is provided.
|
||||||
|
if not auth_token:
|
||||||
|
try:
|
||||||
|
mcp_creds = await creds_manager.store.get_creds_by_provider(
|
||||||
|
user_id, ProviderName.MCP.value
|
||||||
|
)
|
||||||
|
# Find the freshest credential for this server URL
|
||||||
|
best_cred: OAuth2Credentials | None = None
|
||||||
|
for cred in mcp_creds:
|
||||||
|
if (
|
||||||
|
isinstance(cred, OAuth2Credentials)
|
||||||
|
and cred.metadata.get("mcp_server_url") == request.server_url
|
||||||
|
):
|
||||||
|
if best_cred is None or (
|
||||||
|
(cred.access_token_expires_at or 0)
|
||||||
|
> (best_cred.access_token_expires_at or 0)
|
||||||
|
):
|
||||||
|
best_cred = cred
|
||||||
|
if best_cred:
|
||||||
|
# Refresh the token if expired before using it
|
||||||
|
best_cred = await creds_manager.refresh_if_needed(user_id, best_cred)
|
||||||
|
logger.info(
|
||||||
|
f"Using MCP credential {best_cred.id} for {request.server_url}, "
|
||||||
|
f"expires_at={best_cred.access_token_expires_at}"
|
||||||
|
)
|
||||||
|
auth_token = best_cred.access_token.get_secret_value()
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Could not look up stored MCP credentials", exc_info=True)
|
||||||
|
|
||||||
|
client = MCPClient(request.server_url, auth_token=auth_token)
|
||||||
|
|
||||||
|
try:
|
||||||
|
init_result = await client.initialize()
|
||||||
|
tools = await client.list_tools()
|
||||||
|
except HTTPClientError as e:
|
||||||
|
if e.status_code in (401, 403):
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail="This MCP server requires authentication. "
|
||||||
|
"Please provide a valid auth token.",
|
||||||
|
)
|
||||||
|
raise fastapi.HTTPException(status_code=502, detail=str(e))
|
||||||
|
except MCPClientError as e:
|
||||||
|
raise fastapi.HTTPException(status_code=502, detail=str(e))
|
||||||
|
except Exception as e:
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=502,
|
||||||
|
detail=f"Failed to connect to MCP server: {e}",
|
||||||
|
)
|
||||||
|
|
||||||
|
return DiscoverToolsResponse(
|
||||||
|
tools=[
|
||||||
|
MCPToolResponse(
|
||||||
|
name=t.name,
|
||||||
|
description=t.description,
|
||||||
|
input_schema=t.input_schema,
|
||||||
|
)
|
||||||
|
for t in tools
|
||||||
|
],
|
||||||
|
server_name=(
|
||||||
|
init_result.get("serverInfo", {}).get("name")
|
||||||
|
or urlparse(request.server_url).hostname
|
||||||
|
or "MCP"
|
||||||
|
),
|
||||||
|
protocol_version=init_result.get("protocolVersion"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ======================== OAuth Flow ======================== #
|
||||||
|
|
||||||
|
|
||||||
|
class MCPOAuthLoginRequest(BaseModel):
|
||||||
|
"""Request to start an OAuth flow for an MCP server."""
|
||||||
|
|
||||||
|
server_url: str = Field(description="URL of the MCP server that requires OAuth")
|
||||||
|
|
||||||
|
|
||||||
|
class MCPOAuthLoginResponse(BaseModel):
|
||||||
|
"""Response with the OAuth login URL for the user to authenticate."""
|
||||||
|
|
||||||
|
login_url: str
|
||||||
|
state_token: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/oauth/login",
|
||||||
|
summary="Initiate OAuth login for an MCP server",
|
||||||
|
)
|
||||||
|
async def mcp_oauth_login(
|
||||||
|
request: MCPOAuthLoginRequest,
|
||||||
|
user_id: Annotated[str, Security(get_user_id)],
|
||||||
|
) -> MCPOAuthLoginResponse:
|
||||||
|
"""
|
||||||
|
Discover OAuth metadata from the MCP server and return a login URL.
|
||||||
|
|
||||||
|
1. Discovers the protected-resource metadata (RFC 9728)
|
||||||
|
2. Fetches the authorization server metadata (RFC 8414)
|
||||||
|
3. Performs Dynamic Client Registration (RFC 7591) if available
|
||||||
|
4. Returns the authorization URL for the frontend to open in a popup
|
||||||
|
"""
|
||||||
|
client = MCPClient(request.server_url)
|
||||||
|
|
||||||
|
# Step 1: Discover protected-resource metadata (RFC 9728)
|
||||||
|
protected_resource = await client.discover_auth()
|
||||||
|
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
if protected_resource and protected_resource.get("authorization_servers"):
|
||||||
|
auth_server_url = protected_resource["authorization_servers"][0]
|
||||||
|
resource_url = protected_resource.get("resource", request.server_url)
|
||||||
|
|
||||||
|
# Step 2a: Discover auth-server metadata (RFC 8414)
|
||||||
|
metadata = await client.discover_auth_server_metadata(auth_server_url)
|
||||||
|
else:
|
||||||
|
# Fallback: Some MCP servers (e.g. Linear) are their own auth server
|
||||||
|
# and serve OAuth metadata directly without protected-resource metadata.
|
||||||
|
# Don't assume a resource_url — omitting it lets the auth server choose
|
||||||
|
# the correct audience for the token (RFC 8707 resource is optional).
|
||||||
|
resource_url = None
|
||||||
|
metadata = await client.discover_auth_server_metadata(request.server_url)
|
||||||
|
|
||||||
|
if (
|
||||||
|
not metadata
|
||||||
|
or "authorization_endpoint" not in metadata
|
||||||
|
or "token_endpoint" not in metadata
|
||||||
|
):
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="This MCP server does not advertise OAuth support. "
|
||||||
|
"You may need to provide an auth token manually.",
|
||||||
|
)
|
||||||
|
|
||||||
|
authorize_url = metadata["authorization_endpoint"]
|
||||||
|
token_url = metadata["token_endpoint"]
|
||||||
|
registration_endpoint = metadata.get("registration_endpoint")
|
||||||
|
revoke_url = metadata.get("revocation_endpoint")
|
||||||
|
|
||||||
|
# Step 3: Dynamic Client Registration (RFC 7591) if available
|
||||||
|
frontend_base_url = settings.config.frontend_base_url
|
||||||
|
if not frontend_base_url:
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail="Frontend base URL is not configured.",
|
||||||
|
)
|
||||||
|
redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"
|
||||||
|
|
||||||
|
client_id = ""
|
||||||
|
client_secret = ""
|
||||||
|
if registration_endpoint:
|
||||||
|
reg_result = await _register_mcp_client(
|
||||||
|
registration_endpoint, redirect_uri, request.server_url
|
||||||
|
)
|
||||||
|
if reg_result:
|
||||||
|
client_id = reg_result.get("client_id", "")
|
||||||
|
client_secret = reg_result.get("client_secret", "")
|
||||||
|
|
||||||
|
if not client_id:
|
||||||
|
client_id = "autogpt-platform"
|
||||||
|
|
||||||
|
# Step 4: Store state token with OAuth metadata for the callback
|
||||||
|
scopes = (protected_resource or {}).get("scopes_supported") or metadata.get(
|
||||||
|
"scopes_supported", []
|
||||||
|
)
|
||||||
|
state_token, code_challenge = await creds_manager.store.store_state_token(
|
||||||
|
user_id,
|
||||||
|
ProviderName.MCP.value,
|
||||||
|
scopes,
|
||||||
|
state_metadata={
|
||||||
|
"authorize_url": authorize_url,
|
||||||
|
"token_url": token_url,
|
||||||
|
"revoke_url": revoke_url,
|
||||||
|
"resource_url": resource_url,
|
||||||
|
"server_url": request.server_url,
|
||||||
|
"client_id": client_id,
|
||||||
|
"client_secret": client_secret,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 5: Build and return the login URL
|
||||||
|
handler = MCPOAuthHandler(
|
||||||
|
client_id=client_id,
|
||||||
|
client_secret=client_secret,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
authorize_url=authorize_url,
|
||||||
|
token_url=token_url,
|
||||||
|
resource_url=resource_url,
|
||||||
|
)
|
||||||
|
login_url = handler.get_login_url(
|
||||||
|
scopes, state_token, code_challenge=code_challenge
|
||||||
|
)
|
||||||
|
|
||||||
|
return MCPOAuthLoginResponse(login_url=login_url, state_token=state_token)
|
||||||
|
|
||||||
|
|
||||||
|
class MCPOAuthCallbackRequest(BaseModel):
|
||||||
|
"""Request to exchange an OAuth code for tokens."""
|
||||||
|
|
||||||
|
code: str = Field(description="Authorization code from OAuth callback")
|
||||||
|
state_token: str = Field(description="State token for CSRF verification")
|
||||||
|
|
||||||
|
|
||||||
|
class MCPOAuthCallbackResponse(BaseModel):
|
||||||
|
"""Response after successfully storing OAuth credentials."""
|
||||||
|
|
||||||
|
credential_id: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/oauth/callback",
|
||||||
|
summary="Exchange OAuth code for MCP tokens",
|
||||||
|
)
|
||||||
|
async def mcp_oauth_callback(
|
||||||
|
request: MCPOAuthCallbackRequest,
|
||||||
|
user_id: Annotated[str, Security(get_user_id)],
|
||||||
|
) -> CredentialsMetaResponse:
|
||||||
|
"""
|
||||||
|
Exchange the authorization code for tokens and store the credential.
|
||||||
|
|
||||||
|
The frontend calls this after receiving the OAuth code from the popup.
|
||||||
|
On success, subsequent ``/discover-tools`` calls for the same server URL
|
||||||
|
will automatically use the stored credential.
|
||||||
|
"""
|
||||||
|
valid_state = await creds_manager.store.verify_state_token(
|
||||||
|
user_id, request.state_token, ProviderName.MCP.value
|
||||||
|
)
|
||||||
|
if not valid_state:
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Invalid or expired state token.",
|
||||||
|
)
|
||||||
|
|
||||||
|
meta = valid_state.state_metadata
|
||||||
|
frontend_base_url = settings.config.frontend_base_url
|
||||||
|
redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"
|
||||||
|
|
||||||
|
handler = MCPOAuthHandler(
|
||||||
|
client_id=meta["client_id"],
|
||||||
|
client_secret=meta.get("client_secret", ""),
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
authorize_url=meta["authorize_url"],
|
||||||
|
token_url=meta["token_url"],
|
||||||
|
revoke_url=meta.get("revoke_url"),
|
||||||
|
resource_url=meta.get("resource_url"),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
credentials = await handler.exchange_code_for_tokens(
|
||||||
|
request.code, valid_state.scopes, valid_state.code_verifier
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"OAuth token exchange failed: {e}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Enrich credential metadata for future lookup and token refresh
|
||||||
|
if credentials.metadata is None:
|
||||||
|
credentials.metadata = {}
|
||||||
|
credentials.metadata["mcp_server_url"] = meta["server_url"]
|
||||||
|
credentials.metadata["mcp_client_id"] = meta["client_id"]
|
||||||
|
credentials.metadata["mcp_client_secret"] = meta.get("client_secret", "")
|
||||||
|
credentials.metadata["mcp_token_url"] = meta["token_url"]
|
||||||
|
credentials.metadata["mcp_resource_url"] = meta.get("resource_url", "")
|
||||||
|
|
||||||
|
hostname = urlparse(meta["server_url"]).hostname or meta["server_url"]
|
||||||
|
credentials.title = f"MCP: {hostname}"
|
||||||
|
|
||||||
|
# Remove old MCP credentials for the same server to prevent stale token buildup.
|
||||||
|
try:
|
||||||
|
old_creds = await creds_manager.store.get_creds_by_provider(
|
||||||
|
user_id, ProviderName.MCP.value
|
||||||
|
)
|
||||||
|
for old in old_creds:
|
||||||
|
if (
|
||||||
|
isinstance(old, OAuth2Credentials)
|
||||||
|
and old.metadata.get("mcp_server_url") == meta["server_url"]
|
||||||
|
):
|
||||||
|
await creds_manager.store.delete_creds_by_id(user_id, old.id)
|
||||||
|
logger.info(
|
||||||
|
f"Removed old MCP credential {old.id} for {meta['server_url']}"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Could not clean up old MCP credentials", exc_info=True)
|
||||||
|
|
||||||
|
await creds_manager.create(user_id, credentials)
|
||||||
|
|
||||||
|
return CredentialsMetaResponse(
|
||||||
|
id=credentials.id,
|
||||||
|
provider=credentials.provider,
|
||||||
|
type=credentials.type,
|
||||||
|
title=credentials.title,
|
||||||
|
scopes=credentials.scopes,
|
||||||
|
username=credentials.username,
|
||||||
|
host=credentials.metadata.get("mcp_server_url"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ======================== Helpers ======================== #
|
||||||
|
|
||||||
|
|
||||||
|
async def _register_mcp_client(
|
||||||
|
registration_endpoint: str,
|
||||||
|
redirect_uri: str,
|
||||||
|
server_url: str,
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Attempt Dynamic Client Registration (RFC 7591) with an MCP auth server."""
|
||||||
|
try:
|
||||||
|
response = await Requests(raise_for_status=True).post(
|
||||||
|
registration_endpoint,
|
||||||
|
json={
|
||||||
|
"client_name": "AutoGPT Platform",
|
||||||
|
"redirect_uris": [redirect_uri],
|
||||||
|
"grant_types": ["authorization_code"],
|
||||||
|
"response_types": ["code"],
|
||||||
|
"token_endpoint_auth_method": "client_secret_post",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
data = response.json()
|
||||||
|
if isinstance(data, dict) and "client_id" in data:
|
||||||
|
return data
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Dynamic client registration failed for {server_url}: {e}")
|
||||||
|
return None
|
||||||
390
autogpt_platform/backend/backend/api/features/mcp/test_routes.py
Normal file
390
autogpt_platform/backend/backend/api/features/mcp/test_routes.py
Normal file
@@ -0,0 +1,390 @@
|
|||||||
|
"""Tests for MCP API routes."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import fastapi
|
||||||
|
import fastapi.testclient
|
||||||
|
from autogpt_libs.auth import get_user_id
|
||||||
|
|
||||||
|
from backend.api.features.mcp.routes import router
|
||||||
|
from backend.blocks.mcp.client import MCPClientError, MCPTool
|
||||||
|
from backend.util.request import HTTPClientError
|
||||||
|
|
||||||
|
app = fastapi.FastAPI()
|
||||||
|
app.include_router(router)
|
||||||
|
app.dependency_overrides[get_user_id] = lambda: "test-user-id"
|
||||||
|
client = fastapi.testclient.TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDiscoverTools:
|
||||||
|
def test_discover_tools_success(self):
|
||||||
|
mock_tools = [
|
||||||
|
MCPTool(
|
||||||
|
name="get_weather",
|
||||||
|
description="Get weather for a city",
|
||||||
|
input_schema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"city": {"type": "string"}},
|
||||||
|
"required": ["city"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
MCPTool(
|
||||||
|
name="add_numbers",
|
||||||
|
description="Add two numbers",
|
||||||
|
input_schema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"a": {"type": "number"},
|
||||||
|
"b": {"type": "number"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
with (patch("backend.api.features.mcp.routes.MCPClient") as MockClient,):
|
||||||
|
instance = MockClient.return_value
|
||||||
|
instance.initialize = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"protocolVersion": "2025-03-26",
|
||||||
|
"serverInfo": {"name": "test-server"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
instance.list_tools = AsyncMock(return_value=mock_tools)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/discover-tools",
|
||||||
|
json={"server_url": "https://mcp.example.com/mcp"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert len(data["tools"]) == 2
|
||||||
|
assert data["tools"][0]["name"] == "get_weather"
|
||||||
|
assert data["tools"][1]["name"] == "add_numbers"
|
||||||
|
assert data["server_name"] == "test-server"
|
||||||
|
assert data["protocol_version"] == "2025-03-26"
|
||||||
|
|
||||||
|
def test_discover_tools_with_auth_token(self):
|
||||||
|
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
|
||||||
|
instance = MockClient.return_value
|
||||||
|
instance.initialize = AsyncMock(
|
||||||
|
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
|
||||||
|
)
|
||||||
|
instance.list_tools = AsyncMock(return_value=[])
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/discover-tools",
|
||||||
|
json={
|
||||||
|
"server_url": "https://mcp.example.com/mcp",
|
||||||
|
"auth_token": "my-secret-token",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
MockClient.assert_called_once_with(
|
||||||
|
"https://mcp.example.com/mcp",
|
||||||
|
auth_token="my-secret-token",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_discover_tools_auto_uses_stored_credential(self):
|
||||||
|
"""When no explicit token is given, stored MCP credentials are used."""
|
||||||
|
from pydantic import SecretStr
|
||||||
|
|
||||||
|
from backend.data.model import OAuth2Credentials
|
||||||
|
|
||||||
|
stored_cred = OAuth2Credentials(
|
||||||
|
provider="mcp",
|
||||||
|
title="MCP: example.com",
|
||||||
|
access_token=SecretStr("stored-token-123"),
|
||||||
|
refresh_token=None,
|
||||||
|
access_token_expires_at=None,
|
||||||
|
refresh_token_expires_at=None,
|
||||||
|
scopes=[],
|
||||||
|
metadata={"mcp_server_url": "https://mcp.example.com/mcp"},
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||||
|
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||||
|
):
|
||||||
|
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[stored_cred])
|
||||||
|
mock_cm.refresh_if_needed = AsyncMock(return_value=stored_cred)
|
||||||
|
instance = MockClient.return_value
|
||||||
|
instance.initialize = AsyncMock(
|
||||||
|
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
|
||||||
|
)
|
||||||
|
instance.list_tools = AsyncMock(return_value=[])
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/discover-tools",
|
||||||
|
json={"server_url": "https://mcp.example.com/mcp"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
MockClient.assert_called_once_with(
|
||||||
|
"https://mcp.example.com/mcp",
|
||||||
|
auth_token="stored-token-123",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_discover_tools_mcp_error(self):
|
||||||
|
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
|
||||||
|
instance = MockClient.return_value
|
||||||
|
instance.initialize = AsyncMock(
|
||||||
|
side_effect=MCPClientError("Connection refused")
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/discover-tools",
|
||||||
|
json={"server_url": "https://bad-server.example.com/mcp"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 502
|
||||||
|
assert "Connection refused" in response.json()["detail"]
|
||||||
|
|
||||||
|
def test_discover_tools_generic_error(self):
|
||||||
|
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
|
||||||
|
instance = MockClient.return_value
|
||||||
|
instance.initialize = AsyncMock(side_effect=Exception("Network timeout"))
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/discover-tools",
|
||||||
|
json={"server_url": "https://timeout.example.com/mcp"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 502
|
||||||
|
assert "Failed to connect" in response.json()["detail"]
|
||||||
|
|
||||||
|
def test_discover_tools_auth_required(self):
|
||||||
|
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
|
||||||
|
instance = MockClient.return_value
|
||||||
|
instance.initialize = AsyncMock(
|
||||||
|
side_effect=HTTPClientError("HTTP 401 Error: Unauthorized", 401)
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/discover-tools",
|
||||||
|
json={"server_url": "https://auth-server.example.com/mcp"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "requires authentication" in response.json()["detail"]
|
||||||
|
|
||||||
|
def test_discover_tools_forbidden(self):
|
||||||
|
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
|
||||||
|
instance = MockClient.return_value
|
||||||
|
instance.initialize = AsyncMock(
|
||||||
|
side_effect=HTTPClientError("HTTP 403 Error: Forbidden", 403)
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/discover-tools",
|
||||||
|
json={"server_url": "https://auth-server.example.com/mcp"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "requires authentication" in response.json()["detail"]
|
||||||
|
|
||||||
|
def test_discover_tools_missing_url(self):
|
||||||
|
response = client.post("/discover-tools", json={})
|
||||||
|
assert response.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
class TestOAuthLogin:
|
||||||
|
def test_oauth_login_success(self):
|
||||||
|
with (
|
||||||
|
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||||
|
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||||
|
patch("backend.api.features.mcp.routes.settings") as mock_settings,
|
||||||
|
patch(
|
||||||
|
"backend.api.features.mcp.routes._register_mcp_client"
|
||||||
|
) as mock_register,
|
||||||
|
):
|
||||||
|
instance = MockClient.return_value
|
||||||
|
instance.discover_auth = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"authorization_servers": ["https://auth.sentry.io"],
|
||||||
|
"resource": "https://mcp.sentry.dev/mcp",
|
||||||
|
"scopes_supported": ["openid"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
instance.discover_auth_server_metadata = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"authorization_endpoint": "https://auth.sentry.io/authorize",
|
||||||
|
"token_endpoint": "https://auth.sentry.io/token",
|
||||||
|
"registration_endpoint": "https://auth.sentry.io/register",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
mock_register.return_value = {
|
||||||
|
"client_id": "registered-client-id",
|
||||||
|
"client_secret": "registered-secret",
|
||||||
|
}
|
||||||
|
mock_cm.store.store_state_token = AsyncMock(
|
||||||
|
return_value=("state-token-123", "code-challenge-abc")
|
||||||
|
)
|
||||||
|
mock_settings.config.frontend_base_url = "http://localhost:3000"
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/oauth/login",
|
||||||
|
json={"server_url": "https://mcp.sentry.dev/mcp"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "login_url" in data
|
||||||
|
assert data["state_token"] == "state-token-123"
|
||||||
|
assert "auth.sentry.io/authorize" in data["login_url"]
|
||||||
|
assert "registered-client-id" in data["login_url"]
|
||||||
|
|
||||||
|
def test_oauth_login_no_oauth_support(self):
|
||||||
|
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
|
||||||
|
instance = MockClient.return_value
|
||||||
|
instance.discover_auth = AsyncMock(return_value=None)
|
||||||
|
instance.discover_auth_server_metadata = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/oauth/login",
|
||||||
|
json={"server_url": "https://simple-server.example.com/mcp"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert "does not advertise OAuth" in response.json()["detail"]
|
||||||
|
|
||||||
|
def test_oauth_login_fallback_to_public_client(self):
|
||||||
|
"""When DCR is unavailable, falls back to default public client ID."""
|
||||||
|
with (
|
||||||
|
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||||
|
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||||
|
patch("backend.api.features.mcp.routes.settings") as mock_settings,
|
||||||
|
):
|
||||||
|
instance = MockClient.return_value
|
||||||
|
instance.discover_auth = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"authorization_servers": ["https://auth.example.com"],
|
||||||
|
"resource": "https://mcp.example.com/mcp",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
instance.discover_auth_server_metadata = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"authorization_endpoint": "https://auth.example.com/authorize",
|
||||||
|
"token_endpoint": "https://auth.example.com/token",
|
||||||
|
# No registration_endpoint
|
||||||
|
}
|
||||||
|
)
|
||||||
|
mock_cm.store.store_state_token = AsyncMock(
|
||||||
|
return_value=("state-abc", "challenge-xyz")
|
||||||
|
)
|
||||||
|
mock_settings.config.frontend_base_url = "http://localhost:3000"
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/oauth/login",
|
||||||
|
json={"server_url": "https://mcp.example.com/mcp"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "autogpt-platform" in data["login_url"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestOAuthCallback:
|
||||||
|
def test_oauth_callback_success(self):
|
||||||
|
from pydantic import SecretStr
|
||||||
|
|
||||||
|
from backend.data.model import OAuth2Credentials
|
||||||
|
|
||||||
|
mock_creds = OAuth2Credentials(
|
||||||
|
provider="mcp",
|
||||||
|
title=None,
|
||||||
|
access_token=SecretStr("access-token-xyz"),
|
||||||
|
refresh_token=None,
|
||||||
|
access_token_expires_at=None,
|
||||||
|
refresh_token_expires_at=None,
|
||||||
|
scopes=[],
|
||||||
|
metadata={
|
||||||
|
"mcp_token_url": "https://auth.sentry.io/token",
|
||||||
|
"mcp_resource_url": "https://mcp.sentry.dev/mcp",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||||
|
patch("backend.api.features.mcp.routes.settings") as mock_settings,
|
||||||
|
patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
|
||||||
|
):
|
||||||
|
mock_settings.config.frontend_base_url = "http://localhost:3000"
|
||||||
|
|
||||||
|
# Mock state verification
|
||||||
|
mock_state = AsyncMock()
|
||||||
|
mock_state.state_metadata = {
|
||||||
|
"authorize_url": "https://auth.sentry.io/authorize",
|
||||||
|
"token_url": "https://auth.sentry.io/token",
|
||||||
|
"client_id": "test-client-id",
|
||||||
|
"client_secret": "test-secret",
|
||||||
|
"server_url": "https://mcp.sentry.dev/mcp",
|
||||||
|
}
|
||||||
|
mock_state.scopes = ["openid"]
|
||||||
|
mock_state.code_verifier = "verifier-123"
|
||||||
|
mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
|
||||||
|
mock_cm.create = AsyncMock()
|
||||||
|
|
||||||
|
handler_instance = MockHandler.return_value
|
||||||
|
handler_instance.exchange_code_for_tokens = AsyncMock(
|
||||||
|
return_value=mock_creds
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock old credential cleanup
|
||||||
|
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/oauth/callback",
|
||||||
|
json={"code": "auth-code-abc", "state_token": "state-token-123"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "id" in data
|
||||||
|
assert data["provider"] == "mcp"
|
||||||
|
assert data["type"] == "oauth2"
|
||||||
|
mock_cm.create.assert_called_once()
|
||||||
|
|
||||||
|
def test_oauth_callback_invalid_state(self):
|
||||||
|
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
|
||||||
|
mock_cm.store.verify_state_token = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/oauth/callback",
|
||||||
|
json={"code": "auth-code", "state_token": "bad-state"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert "Invalid or expired" in response.json()["detail"]
|
||||||
|
|
||||||
|
def test_oauth_callback_token_exchange_fails(self):
|
||||||
|
with (
|
||||||
|
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||||
|
patch("backend.api.features.mcp.routes.settings") as mock_settings,
|
||||||
|
patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
|
||||||
|
):
|
||||||
|
mock_settings.config.frontend_base_url = "http://localhost:3000"
|
||||||
|
mock_state = AsyncMock()
|
||||||
|
mock_state.state_metadata = {
|
||||||
|
"authorize_url": "https://auth.example.com/authorize",
|
||||||
|
"token_url": "https://auth.example.com/token",
|
||||||
|
"client_id": "cid",
|
||||||
|
"server_url": "https://mcp.example.com/mcp",
|
||||||
|
}
|
||||||
|
mock_state.scopes = []
|
||||||
|
mock_state.code_verifier = "v"
|
||||||
|
mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
|
||||||
|
|
||||||
|
handler_instance = MockHandler.return_value
|
||||||
|
handler_instance.exchange_code_for_tokens = AsyncMock(
|
||||||
|
side_effect=RuntimeError("Token exchange failed")
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/oauth/callback",
|
||||||
|
json={"code": "bad-code", "state_token": "state"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert "token exchange failed" in response.json()["detail"].lower()
|
||||||
@@ -8,6 +8,7 @@ Includes BM25 reranking for improved lexical relevance.
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
@@ -362,7 +363,11 @@ async def unified_hybrid_search(
|
|||||||
LIMIT {limit_param} OFFSET {offset_param}
|
LIMIT {limit_param} OFFSET {offset_param}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
results = await query_raw_with_schema(sql_query, *params)
|
results = await query_raw_with_schema(sql_query, *params)
|
||||||
|
except Exception as e:
|
||||||
|
await _log_vector_error_diagnostics(e)
|
||||||
|
raise
|
||||||
|
|
||||||
total = results[0]["total_count"] if results else 0
|
total = results[0]["total_count"] if results else 0
|
||||||
# Apply BM25 reranking
|
# Apply BM25 reranking
|
||||||
@@ -686,7 +691,11 @@ async def hybrid_search(
|
|||||||
LIMIT {limit_param} OFFSET {offset_param}
|
LIMIT {limit_param} OFFSET {offset_param}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
results = await query_raw_with_schema(sql_query, *params)
|
results = await query_raw_with_schema(sql_query, *params)
|
||||||
|
except Exception as e:
|
||||||
|
await _log_vector_error_diagnostics(e)
|
||||||
|
raise
|
||||||
|
|
||||||
total = results[0]["total_count"] if results else 0
|
total = results[0]["total_count"] if results else 0
|
||||||
|
|
||||||
@@ -718,6 +727,87 @@ async def hybrid_search_simple(
|
|||||||
return await hybrid_search(query=query, page=page, page_size=page_size)
|
return await hybrid_search(query=query, page=page, page_size=page_size)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Diagnostics
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# Rate limit: only log vector error diagnostics once per this interval
|
||||||
|
_VECTOR_DIAG_INTERVAL_SECONDS = 60
|
||||||
|
_last_vector_diag_time: float = 0
|
||||||
|
|
||||||
|
|
||||||
|
async def _log_vector_error_diagnostics(error: Exception) -> None:
|
||||||
|
"""Log diagnostic info when 'type vector does not exist' error occurs.
|
||||||
|
|
||||||
|
Note: Diagnostic queries use query_raw_with_schema which may run on a different
|
||||||
|
pooled connection than the one that failed. Session-level search_path can differ,
|
||||||
|
so these diagnostics show cluster-wide state, not necessarily the failed session.
|
||||||
|
|
||||||
|
Includes rate limiting to avoid log spam - only logs once per minute.
|
||||||
|
Caller should re-raise the error after calling this function.
|
||||||
|
"""
|
||||||
|
global _last_vector_diag_time
|
||||||
|
|
||||||
|
# Check if this is the vector type error
|
||||||
|
error_str = str(error).lower()
|
||||||
|
if not (
|
||||||
|
"type" in error_str and "vector" in error_str and "does not exist" in error_str
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Rate limit: only log once per interval
|
||||||
|
now = time.time()
|
||||||
|
if now - _last_vector_diag_time < _VECTOR_DIAG_INTERVAL_SECONDS:
|
||||||
|
return
|
||||||
|
_last_vector_diag_time = now
|
||||||
|
|
||||||
|
try:
|
||||||
|
diagnostics: dict[str, object] = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
search_path_result = await query_raw_with_schema("SHOW search_path")
|
||||||
|
diagnostics["search_path"] = search_path_result
|
||||||
|
except Exception as e:
|
||||||
|
diagnostics["search_path"] = f"Error: {e}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
schema_result = await query_raw_with_schema("SELECT current_schema()")
|
||||||
|
diagnostics["current_schema"] = schema_result
|
||||||
|
except Exception as e:
|
||||||
|
diagnostics["current_schema"] = f"Error: {e}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
user_result = await query_raw_with_schema(
|
||||||
|
"SELECT current_user, session_user, current_database()"
|
||||||
|
)
|
||||||
|
diagnostics["user_info"] = user_result
|
||||||
|
except Exception as e:
|
||||||
|
diagnostics["user_info"] = f"Error: {e}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check pgvector extension installation (cluster-wide, stable info)
|
||||||
|
ext_result = await query_raw_with_schema(
|
||||||
|
"SELECT extname, extversion, nspname as schema "
|
||||||
|
"FROM pg_extension e "
|
||||||
|
"JOIN pg_namespace n ON e.extnamespace = n.oid "
|
||||||
|
"WHERE extname = 'vector'"
|
||||||
|
)
|
||||||
|
diagnostics["pgvector_extension"] = ext_result
|
||||||
|
except Exception as e:
|
||||||
|
diagnostics["pgvector_extension"] = f"Error: {e}"
|
||||||
|
|
||||||
|
logger.error(
|
||||||
|
f"Vector type error diagnostics:\n"
|
||||||
|
f" Error: {error}\n"
|
||||||
|
f" search_path: {diagnostics.get('search_path')}\n"
|
||||||
|
f" current_schema: {diagnostics.get('current_schema')}\n"
|
||||||
|
f" user_info: {diagnostics.get('user_info')}\n"
|
||||||
|
f" pgvector_extension: {diagnostics.get('pgvector_extension')}"
|
||||||
|
)
|
||||||
|
except Exception as diag_error:
|
||||||
|
logger.error(f"Failed to collect vector error diagnostics: {diag_error}")
|
||||||
|
|
||||||
|
|
||||||
# Backward compatibility alias - HybridSearchWeights maps to StoreAgentSearchWeights
|
# Backward compatibility alias - HybridSearchWeights maps to StoreAgentSearchWeights
|
||||||
# for existing code that expects the popularity parameter
|
# for existing code that expects the popularity parameter
|
||||||
HybridSearchWeights = StoreAgentSearchWeights
|
HybridSearchWeights = StoreAgentSearchWeights
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ import backend.api.features.executions.review.routes
|
|||||||
import backend.api.features.library.db
|
import backend.api.features.library.db
|
||||||
import backend.api.features.library.model
|
import backend.api.features.library.model
|
||||||
import backend.api.features.library.routes
|
import backend.api.features.library.routes
|
||||||
|
import backend.api.features.mcp.routes as mcp_routes
|
||||||
import backend.api.features.oauth
|
import backend.api.features.oauth
|
||||||
import backend.api.features.otto.routes
|
import backend.api.features.otto.routes
|
||||||
import backend.api.features.postmark.postmark
|
import backend.api.features.postmark.postmark
|
||||||
@@ -343,6 +344,11 @@ app.include_router(
|
|||||||
tags=["workspace"],
|
tags=["workspace"],
|
||||||
prefix="/api/workspace",
|
prefix="/api/workspace",
|
||||||
)
|
)
|
||||||
|
app.include_router(
|
||||||
|
mcp_routes.router,
|
||||||
|
tags=["v2", "mcp"],
|
||||||
|
prefix="/api/mcp",
|
||||||
|
)
|
||||||
app.include_router(
|
app.include_router(
|
||||||
backend.api.features.oauth.router,
|
backend.api.features.oauth.router,
|
||||||
tags=["oauth"],
|
tags=["oauth"],
|
||||||
|
|||||||
301
autogpt_platform/backend/backend/blocks/mcp/block.py
Normal file
301
autogpt_platform/backend/backend/blocks/mcp/block.py
Normal file
@@ -0,0 +1,301 @@
|
|||||||
|
"""
|
||||||
|
MCP (Model Context Protocol) Tool Block.
|
||||||
|
|
||||||
|
A single dynamic block that can connect to any MCP server, discover available tools,
|
||||||
|
and execute them. Works like AgentExecutorBlock — the user selects a tool from a
|
||||||
|
dropdown and the input/output schema adapts dynamically.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from pydantic import SecretStr
|
||||||
|
|
||||||
|
from backend.blocks.mcp.client import MCPClient, MCPClientError
|
||||||
|
from backend.data.block import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockInput,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
BlockType,
|
||||||
|
)
|
||||||
|
from backend.data.model import (
|
||||||
|
CredentialsField,
|
||||||
|
CredentialsMetaInput,
|
||||||
|
OAuth2Credentials,
|
||||||
|
SchemaField,
|
||||||
|
)
|
||||||
|
from backend.integrations.providers import ProviderName
|
||||||
|
from backend.util.json import validate_with_jsonschema
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
TEST_CREDENTIALS = OAuth2Credentials(
|
||||||
|
id="test-mcp-cred",
|
||||||
|
provider="mcp",
|
||||||
|
access_token=SecretStr("mock-mcp-token"),
|
||||||
|
refresh_token=SecretStr("mock-refresh"),
|
||||||
|
scopes=[],
|
||||||
|
title="Mock MCP credential",
|
||||||
|
)
|
||||||
|
TEST_CREDENTIALS_INPUT = {
|
||||||
|
"provider": TEST_CREDENTIALS.provider,
|
||||||
|
"id": TEST_CREDENTIALS.id,
|
||||||
|
"type": TEST_CREDENTIALS.type,
|
||||||
|
"title": TEST_CREDENTIALS.title,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
MCPCredentials = CredentialsMetaInput[Literal[ProviderName.MCP], Literal["oauth2"]]
|
||||||
|
|
||||||
|
|
||||||
|
class MCPToolBlock(Block):
|
||||||
|
"""
|
||||||
|
A block that connects to an MCP server, lets the user pick a tool,
|
||||||
|
and executes it with dynamic input/output schema.
|
||||||
|
|
||||||
|
The flow:
|
||||||
|
1. User provides an MCP server URL (and optional credentials)
|
||||||
|
2. Frontend calls the backend to get tool list from that URL
|
||||||
|
3. User selects a tool from a dropdown (available_tools)
|
||||||
|
4. The block's input schema updates to reflect the selected tool's parameters
|
||||||
|
5. On execution, the block calls the MCP server to run the tool
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
server_url: str = SchemaField(
|
||||||
|
description="URL of the MCP server (Streamable HTTP endpoint)",
|
||||||
|
placeholder="https://mcp.example.com/mcp",
|
||||||
|
)
|
||||||
|
credentials: MCPCredentials = CredentialsField(
|
||||||
|
discriminator="server_url",
|
||||||
|
description="MCP server OAuth credentials",
|
||||||
|
default={},
|
||||||
|
)
|
||||||
|
selected_tool: str = SchemaField(
|
||||||
|
description="The MCP tool to execute",
|
||||||
|
placeholder="Select a tool",
|
||||||
|
default="",
|
||||||
|
)
|
||||||
|
tool_input_schema: dict[str, Any] = SchemaField(
|
||||||
|
description="JSON Schema for the selected tool's input parameters. "
|
||||||
|
"Populated automatically when a tool is selected.",
|
||||||
|
default={},
|
||||||
|
hidden=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_arguments: dict[str, Any] = SchemaField(
|
||||||
|
description="Arguments to pass to the selected MCP tool. "
|
||||||
|
"The fields here are defined by the tool's input schema.",
|
||||||
|
default={},
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
|
||||||
|
"""Return the tool's input schema so the builder UI renders dynamic fields."""
|
||||||
|
return data.get("tool_input_schema", {})
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
|
||||||
|
"""Return the current tool_arguments as defaults for the dynamic fields."""
|
||||||
|
return data.get("tool_arguments", {})
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_missing_input(cls, data: BlockInput) -> set[str]:
|
||||||
|
"""Check which required tool arguments are missing."""
|
||||||
|
required_fields = cls.get_input_schema(data).get("required", [])
|
||||||
|
tool_arguments = data.get("tool_arguments", {})
|
||||||
|
return set(required_fields) - set(tool_arguments)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_mismatch_error(cls, data: BlockInput) -> str | None:
|
||||||
|
"""Validate tool_arguments against the tool's input schema."""
|
||||||
|
tool_schema = cls.get_input_schema(data)
|
||||||
|
if not tool_schema:
|
||||||
|
return None
|
||||||
|
tool_arguments = data.get("tool_arguments", {})
|
||||||
|
return validate_with_jsonschema(tool_schema, tool_arguments)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
result: Any = SchemaField(description="The result returned by the MCP tool")
|
||||||
|
error: str = SchemaField(description="Error message if the tool call failed")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4",
|
||||||
|
description="Connect to any MCP server and execute its tools. "
|
||||||
|
"Provide a server URL, select a tool, and pass arguments dynamically.",
|
||||||
|
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||||
|
input_schema=MCPToolBlock.Input,
|
||||||
|
output_schema=MCPToolBlock.Output,
|
||||||
|
block_type=BlockType.STANDARD,
|
||||||
|
test_credentials=TEST_CREDENTIALS,
|
||||||
|
test_input={
|
||||||
|
"server_url": "https://mcp.example.com/mcp",
|
||||||
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
|
"selected_tool": "get_weather",
|
||||||
|
"tool_input_schema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"city": {"type": "string"}},
|
||||||
|
"required": ["city"],
|
||||||
|
},
|
||||||
|
"tool_arguments": {"city": "London"},
|
||||||
|
},
|
||||||
|
test_output=[
|
||||||
|
(
|
||||||
|
"result",
|
||||||
|
{"weather": "sunny", "temperature": 20},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
test_mock={
|
||||||
|
"_call_mcp_tool": lambda *a, **kw: {
|
||||||
|
"weather": "sunny",
|
||||||
|
"temperature": 20,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _call_mcp_tool(
|
||||||
|
self,
|
||||||
|
server_url: str,
|
||||||
|
tool_name: str,
|
||||||
|
arguments: dict[str, Any],
|
||||||
|
auth_token: str | None = None,
|
||||||
|
) -> Any:
|
||||||
|
"""Call a tool on the MCP server. Extracted for easy mocking in tests."""
|
||||||
|
client = MCPClient(server_url, auth_token=auth_token)
|
||||||
|
await client.initialize()
|
||||||
|
result = await client.call_tool(tool_name, arguments)
|
||||||
|
|
||||||
|
if result.is_error:
|
||||||
|
error_text = ""
|
||||||
|
for item in result.content:
|
||||||
|
if item.get("type") == "text":
|
||||||
|
error_text += item.get("text", "")
|
||||||
|
raise MCPClientError(
|
||||||
|
f"MCP tool '{tool_name}' returned an error: "
|
||||||
|
f"{error_text or 'Unknown error'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract text content from the result
|
||||||
|
output_parts = []
|
||||||
|
for item in result.content:
|
||||||
|
if item.get("type") == "text":
|
||||||
|
text = item.get("text", "")
|
||||||
|
# Try to parse as JSON for structured output
|
||||||
|
try:
|
||||||
|
output_parts.append(json.loads(text))
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
output_parts.append(text)
|
||||||
|
elif item.get("type") == "image":
|
||||||
|
output_parts.append(
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"data": item.get("data"),
|
||||||
|
"mimeType": item.get("mimeType"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif item.get("type") == "resource":
|
||||||
|
output_parts.append(item.get("resource", {}))
|
||||||
|
|
||||||
|
# If single result, unwrap
|
||||||
|
if len(output_parts) == 1:
|
||||||
|
return output_parts[0]
|
||||||
|
return output_parts if output_parts else None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _auto_lookup_credential(
|
||||||
|
user_id: str, server_url: str
|
||||||
|
) -> "OAuth2Credentials | None":
|
||||||
|
"""Auto-lookup stored MCP credential for a server URL.
|
||||||
|
|
||||||
|
This is a fallback for nodes that don't have ``credentials`` explicitly
|
||||||
|
set (e.g. nodes created before the credential field was wired up).
|
||||||
|
"""
|
||||||
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
|
from backend.integrations.providers import ProviderName
|
||||||
|
|
||||||
|
try:
|
||||||
|
mgr = IntegrationCredentialsManager()
|
||||||
|
mcp_creds = await mgr.store.get_creds_by_provider(
|
||||||
|
user_id, ProviderName.MCP.value
|
||||||
|
)
|
||||||
|
best: OAuth2Credentials | None = None
|
||||||
|
for cred in mcp_creds:
|
||||||
|
if (
|
||||||
|
isinstance(cred, OAuth2Credentials)
|
||||||
|
and cred.metadata.get("mcp_server_url") == server_url
|
||||||
|
):
|
||||||
|
if best is None or (
|
||||||
|
(cred.access_token_expires_at or 0)
|
||||||
|
> (best.access_token_expires_at or 0)
|
||||||
|
):
|
||||||
|
best = cred
|
||||||
|
if best:
|
||||||
|
best = await mgr.refresh_if_needed(user_id, best)
|
||||||
|
logger.info(
|
||||||
|
"Auto-resolved MCP credential %s for %s", best.id, server_url
|
||||||
|
)
|
||||||
|
return best
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Auto-lookup MCP credential failed", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
user_id: str,
|
||||||
|
credentials: OAuth2Credentials | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
if not input_data.server_url:
|
||||||
|
yield "error", "MCP server URL is required"
|
||||||
|
return
|
||||||
|
|
||||||
|
if not input_data.selected_tool:
|
||||||
|
yield "error", "No tool selected. Please select a tool from the dropdown."
|
||||||
|
return
|
||||||
|
|
||||||
|
# Validate required tool arguments before calling the server.
|
||||||
|
# The executor-level validation is bypassed for MCP blocks because
|
||||||
|
# get_input_defaults() flattens tool_arguments, stripping tool_input_schema
|
||||||
|
# from the validation context.
|
||||||
|
required = set(input_data.tool_input_schema.get("required", []))
|
||||||
|
if required:
|
||||||
|
missing = required - set(input_data.tool_arguments.keys())
|
||||||
|
if missing:
|
||||||
|
yield "error", (
|
||||||
|
f"Missing required argument(s): {', '.join(sorted(missing))}. "
|
||||||
|
f"Please fill in all required fields marked with * in the block form."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# If no credentials were injected by the executor (e.g. legacy nodes
|
||||||
|
# that don't have the credentials field set), try to auto-lookup
|
||||||
|
# the stored MCP credential for this server URL.
|
||||||
|
if credentials is None:
|
||||||
|
credentials = await self._auto_lookup_credential(
|
||||||
|
user_id, input_data.server_url
|
||||||
|
)
|
||||||
|
|
||||||
|
auth_token = (
|
||||||
|
credentials.access_token.get_secret_value() if credentials else None
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await self._call_mcp_tool(
|
||||||
|
server_url=input_data.server_url,
|
||||||
|
tool_name=input_data.selected_tool,
|
||||||
|
arguments=input_data.tool_arguments,
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
yield "result", result
|
||||||
|
except MCPClientError as e:
|
||||||
|
yield "error", str(e)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"MCP tool call failed: {e}")
|
||||||
|
yield "error", f"MCP tool call failed: {str(e)}"
|
||||||
323
autogpt_platform/backend/backend/blocks/mcp/client.py
Normal file
323
autogpt_platform/backend/backend/blocks/mcp/client.py
Normal file
@@ -0,0 +1,323 @@
|
|||||||
|
"""
|
||||||
|
MCP (Model Context Protocol) HTTP client.
|
||||||
|
|
||||||
|
Implements the MCP Streamable HTTP transport for listing tools and calling tools
|
||||||
|
on remote MCP servers. Uses JSON-RPC 2.0 over HTTP POST.
|
||||||
|
|
||||||
|
Handles both JSON and SSE (text/event-stream) response formats per the MCP spec.
|
||||||
|
|
||||||
|
Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from backend.util.request import Requests
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MCPTool:
|
||||||
|
"""Represents an MCP tool discovered from a server."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
input_schema: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MCPCallResult:
|
||||||
|
"""Result from calling an MCP tool."""
|
||||||
|
|
||||||
|
content: list[dict[str, Any]] = field(default_factory=list)
|
||||||
|
is_error: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class MCPClientError(Exception):
|
||||||
|
"""Raised when an MCP protocol error occurs."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MCPClient:
|
||||||
|
"""
|
||||||
|
Async HTTP client for the MCP Streamable HTTP transport.
|
||||||
|
|
||||||
|
Communicates with MCP servers using JSON-RPC 2.0 over HTTP POST.
|
||||||
|
Supports optional Bearer token authentication.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
server_url: str,
|
||||||
|
auth_token: str | None = None,
|
||||||
|
):
|
||||||
|
self.server_url = server_url.rstrip("/")
|
||||||
|
self.auth_token = auth_token
|
||||||
|
self._request_id = 0
|
||||||
|
self._session_id: str | None = None
|
||||||
|
|
||||||
|
def _next_id(self) -> int:
|
||||||
|
self._request_id += 1
|
||||||
|
return self._request_id
|
||||||
|
|
||||||
|
def _build_headers(self) -> dict[str, str]:
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Accept": "application/json, text/event-stream",
|
||||||
|
}
|
||||||
|
if self.auth_token:
|
||||||
|
headers["Authorization"] = f"Bearer {self.auth_token}"
|
||||||
|
if self._session_id:
|
||||||
|
headers["Mcp-Session-Id"] = self._session_id
|
||||||
|
return headers
|
||||||
|
|
||||||
|
def _build_jsonrpc_request(
|
||||||
|
self, method: str, params: dict[str, Any] | None = None
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
req: dict[str, Any] = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"method": method,
|
||||||
|
"id": self._next_id(),
|
||||||
|
}
|
||||||
|
if params is not None:
|
||||||
|
req["params"] = params
|
||||||
|
return req
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_sse_response(text: str) -> dict[str, Any]:
|
||||||
|
"""Parse an SSE (text/event-stream) response body into JSON-RPC data.
|
||||||
|
|
||||||
|
MCP servers may return responses as SSE with format:
|
||||||
|
event: message
|
||||||
|
data: {"jsonrpc":"2.0","result":{...},"id":1}
|
||||||
|
|
||||||
|
We extract the last `data:` line that contains a JSON-RPC response
|
||||||
|
(i.e. has an "id" field), which is the reply to our request.
|
||||||
|
"""
|
||||||
|
last_data: dict[str, Any] | None = None
|
||||||
|
for line in text.splitlines():
|
||||||
|
stripped = line.strip()
|
||||||
|
if stripped.startswith("data:"):
|
||||||
|
payload = stripped[len("data:") :].strip()
|
||||||
|
if not payload:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
parsed = json.loads(payload)
|
||||||
|
# Only keep JSON-RPC responses (have "id"), skip notifications
|
||||||
|
if isinstance(parsed, dict) and "id" in parsed:
|
||||||
|
last_data = parsed
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
continue
|
||||||
|
if last_data is None:
|
||||||
|
raise MCPClientError("No JSON-RPC response found in SSE stream")
|
||||||
|
return last_data
|
||||||
|
|
||||||
|
async def _send_request(
|
||||||
|
self, method: str, params: dict[str, Any] | None = None
|
||||||
|
) -> Any:
|
||||||
|
"""Send a JSON-RPC request to the MCP server and return the result.
|
||||||
|
|
||||||
|
Handles both ``application/json`` and ``text/event-stream`` responses
|
||||||
|
as required by the MCP Streamable HTTP transport specification.
|
||||||
|
"""
|
||||||
|
payload = self._build_jsonrpc_request(method, params)
|
||||||
|
headers = self._build_headers()
|
||||||
|
|
||||||
|
requests = Requests(
|
||||||
|
raise_for_status=True,
|
||||||
|
extra_headers=headers,
|
||||||
|
)
|
||||||
|
response = await requests.post(self.server_url, json=payload)
|
||||||
|
|
||||||
|
# Capture session ID from response (MCP Streamable HTTP transport)
|
||||||
|
session_id = response.headers.get("Mcp-Session-Id")
|
||||||
|
if session_id:
|
||||||
|
self._session_id = session_id
|
||||||
|
|
||||||
|
content_type = response.headers.get("content-type", "")
|
||||||
|
if "text/event-stream" in content_type:
|
||||||
|
body = self._parse_sse_response(response.text())
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
body = response.json()
|
||||||
|
except Exception as e:
|
||||||
|
raise MCPClientError(
|
||||||
|
f"MCP server returned non-JSON response: {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
if not isinstance(body, dict):
|
||||||
|
raise MCPClientError(
|
||||||
|
f"MCP server returned unexpected JSON type: {type(body).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle JSON-RPC error
|
||||||
|
if "error" in body:
|
||||||
|
error = body["error"]
|
||||||
|
if isinstance(error, dict):
|
||||||
|
raise MCPClientError(
|
||||||
|
f"MCP server error [{error.get('code', '?')}]: "
|
||||||
|
f"{error.get('message', 'Unknown error')}"
|
||||||
|
)
|
||||||
|
raise MCPClientError(f"MCP server error: {error}")
|
||||||
|
|
||||||
|
return body.get("result")
|
||||||
|
|
||||||
|
async def _send_notification(self, method: str) -> None:
|
||||||
|
"""Send a JSON-RPC notification (no id, no response expected)."""
|
||||||
|
headers = self._build_headers()
|
||||||
|
notification = {"jsonrpc": "2.0", "method": method}
|
||||||
|
requests = Requests(
|
||||||
|
raise_for_status=False,
|
||||||
|
extra_headers=headers,
|
||||||
|
)
|
||||||
|
await requests.post(self.server_url, json=notification)
|
||||||
|
|
||||||
|
async def discover_auth(self) -> dict[str, Any] | None:
|
||||||
|
"""Probe the MCP server's OAuth metadata (RFC 9728 / MCP spec).
|
||||||
|
|
||||||
|
Returns ``None`` if the server doesn't require auth, otherwise returns
|
||||||
|
a dict with:
|
||||||
|
- ``authorization_servers``: list of authorization server URLs
|
||||||
|
- ``resource``: the resource indicator URL (usually the MCP endpoint)
|
||||||
|
- ``scopes_supported``: optional list of supported scopes
|
||||||
|
|
||||||
|
The caller can then fetch the authorization server metadata to get
|
||||||
|
``authorization_endpoint``, ``token_endpoint``, etc.
|
||||||
|
"""
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
parsed = urlparse(self.server_url)
|
||||||
|
base = f"{parsed.scheme}://{parsed.netloc}"
|
||||||
|
|
||||||
|
# Build candidates for protected-resource metadata (per RFC 9728)
|
||||||
|
path = parsed.path.rstrip("/")
|
||||||
|
candidates = []
|
||||||
|
if path and path != "/":
|
||||||
|
candidates.append(f"{base}/.well-known/oauth-protected-resource{path}")
|
||||||
|
candidates.append(f"{base}/.well-known/oauth-protected-resource")
|
||||||
|
|
||||||
|
requests = Requests(
|
||||||
|
raise_for_status=False,
|
||||||
|
)
|
||||||
|
for url in candidates:
|
||||||
|
try:
|
||||||
|
resp = await requests.get(url)
|
||||||
|
if resp.status == 200:
|
||||||
|
data = resp.json()
|
||||||
|
if isinstance(data, dict) and "authorization_servers" in data:
|
||||||
|
return data
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def discover_auth_server_metadata(
    self, auth_server_url: str
) -> dict[str, Any] | None:
    """Fetch OAuth Authorization Server Metadata (RFC 8414).

    Probes the standard well-known endpoints — the path-aware RFC 8414
    form, the root RFC 8414 form, then OpenID Connect discovery — and
    returns the first JSON document that contains an
    ``authorization_endpoint`` (along with ``token_endpoint``,
    ``registration_endpoint``, ``scopes_supported``,
    ``code_challenge_methods_supported``, etc.).

    Returns ``None`` when no metadata endpoint responds usefully.
    """
    from urllib.parse import urlparse

    parts = urlparse(auth_server_url)
    origin = f"{parts.scheme}://{parts.netloc}"
    issuer_path = parts.path.rstrip("/")

    # Candidate well-known locations, most specific first
    # (RFC 8414 + OpenID Connect discovery).
    probe_urls: list[str] = []
    if issuer_path and issuer_path != "/":
        probe_urls.append(
            f"{origin}/.well-known/oauth-authorization-server{issuer_path}"
        )
    probe_urls += [
        f"{origin}/.well-known/oauth-authorization-server",
        f"{origin}/.well-known/openid-configuration",
    ]

    http = Requests(
        raise_for_status=False,
    )
    for probe_url in probe_urls:
        try:
            response = await http.get(probe_url)
            if response.status != 200:
                continue
            metadata = response.json()
            # Only accept documents that look like real AS metadata.
            if isinstance(metadata, dict) and "authorization_endpoint" in metadata:
                return metadata
        except Exception:
            # Discovery is best-effort; fall through to the next candidate.
            continue

    return None
|
||||||
|
|
||||||
|
async def initialize(self) -> dict[str, Any]:
    """Perform the MCP ``initialize`` handshake.

    The MCP protocol requires this exchange before any other request.
    Returns the server's reported capabilities (empty dict on a null
    result).
    """
    init_params = {
        "protocolVersion": "2025-03-26",
        "capabilities": {},
        "clientInfo": {"name": "AutoGPT-Platform", "version": "1.0.0"},
    }
    server_info = await self._send_request("initialize", init_params)
    # Acknowledge the handshake; notifications expect no response.
    await self._send_notification("notifications/initialized")

    return server_info or {}
|
||||||
|
|
||||||
|
async def list_tools(self) -> list[MCPTool]:
    """Discover the tools exposed by the MCP server.

    Returns one ``MCPTool`` (name, description, input schema) per entry
    in the server's ``tools/list`` response; an empty list when the
    server reports none.
    """
    response = await self._send_request("tools/list")
    if not response or "tools" not in response:
        return []

    return [
        MCPTool(
            name=entry.get("name", ""),
            description=entry.get("description", ""),
            input_schema=entry.get("inputSchema", {}),
        )
        for entry in response["tools"]
    ]
|
||||||
|
|
||||||
|
async def call_tool(
    self, tool_name: str, arguments: dict[str, Any]
) -> MCPCallResult:
    """Invoke a named tool on the MCP server.

    Args:
        tool_name: The name of the tool to invoke.
        arguments: The arguments forwarded to the tool.

    Returns:
        MCPCallResult carrying the tool's response content; flagged as
        an error when the server returns no result or reports one.
    """
    response = await self._send_request(
        "tools/call",
        {"name": tool_name, "arguments": arguments},
    )
    if not response:
        return MCPCallResult(is_error=True)

    return MCPCallResult(
        content=response.get("content", []),
        is_error=response.get("isError", False),
    )
|
||||||
204
autogpt_platform/backend/backend/blocks/mcp/oauth.py
Normal file
204
autogpt_platform/backend/backend/blocks/mcp/oauth.py
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
"""
|
||||||
|
MCP OAuth handler for MCP servers that use OAuth 2.1 authorization.
|
||||||
|
|
||||||
|
Unlike other OAuth handlers (GitHub, Google, etc.) where endpoints are fixed,
|
||||||
|
MCP servers have dynamic endpoints discovered via RFC 9728 / RFC 8414 metadata.
|
||||||
|
This handler accepts those endpoints at construction time.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import urllib.parse
|
||||||
|
from typing import ClassVar, Optional
|
||||||
|
|
||||||
|
from pydantic import SecretStr
|
||||||
|
|
||||||
|
from backend.data.model import OAuth2Credentials
|
||||||
|
from backend.integrations.oauth.base import BaseOAuthHandler
|
||||||
|
from backend.integrations.providers import ProviderName
|
||||||
|
from backend.util.request import Requests
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MCPOAuthHandler(BaseOAuthHandler):
    """
    OAuth handler for MCP servers with dynamically-discovered endpoints.

    Unlike other OAuth handlers (GitHub, Google, etc.) where endpoints are
    fixed, MCP servers expose their OAuth endpoints via RFC 9728 / RFC 8414
    metadata discovery (``MCPClient.discover_auth`` +
    ``discover_auth_server_metadata``), so this handler accepts the
    authorization and token endpoint URLs at construction time.
    """

    PROVIDER_NAME: ClassVar[ProviderName | str] = ProviderName.MCP
    DEFAULT_SCOPES: ClassVar[list[str]] = []

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        redirect_uri: str,
        *,
        authorize_url: str,
        token_url: str,
        revoke_url: str | None = None,
        resource_url: str | None = None,
    ):
        self.client_id = client_id
        self.client_secret = client_secret
        self.redirect_uri = redirect_uri
        self.authorize_url = authorize_url
        self.token_url = token_url
        self.revoke_url = revoke_url
        # Resource indicator (RFC 8707); usually the MCP endpoint URL.
        self.resource_url = resource_url

    @staticmethod
    def _parse_expires_in(tokens: dict) -> int | None:
        """Read ``expires_in`` from a token response, tolerating strings.

        RFC 6749 defines ``expires_in`` as a number, but some servers send
        it as a string. Anything absent, zero, or unparseable is treated
        as "no known expiry" and yields ``None``.
        """
        raw = tokens.get("expires_in")
        if not raw:
            return None
        try:
            return int(raw)
        except (TypeError, ValueError):
            return None

    def get_login_url(
        self,
        scopes: list[str],
        state: str,
        code_challenge: Optional[str],
    ) -> str:
        """Build the authorization-request URL for the discovered endpoint."""
        scopes = self.handle_default_scopes(scopes)

        params: dict[str, str] = {
            "response_type": "code",
            "client_id": self.client_id,
            "redirect_uri": self.redirect_uri,
            "state": state,
        }
        if scopes:
            params["scope"] = " ".join(scopes)
        # PKCE (S256) — included when the caller provides a code_challenge
        if code_challenge:
            params["code_challenge"] = code_challenge
            params["code_challenge_method"] = "S256"
        # MCP spec requires resource indicator (RFC 8707)
        if self.resource_url:
            params["resource"] = self.resource_url

        return f"{self.authorize_url}?{urllib.parse.urlencode(params)}"

    async def exchange_code_for_tokens(
        self,
        code: str,
        scopes: list[str],
        code_verifier: Optional[str],
    ) -> OAuth2Credentials:
        """Exchange an authorization code for access/refresh tokens.

        Raises:
            RuntimeError: if the token endpoint reports an error or the
                response omits the ``access_token`` field.
        """
        data: dict[str, str] = {
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": self.redirect_uri,
            "client_id": self.client_id,
        }
        if self.client_secret:
            data["client_secret"] = self.client_secret
        if code_verifier:
            data["code_verifier"] = code_verifier
        if self.resource_url:
            data["resource"] = self.resource_url

        response = await Requests(raise_for_status=True).post(
            self.token_url,
            data=data,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )
        tokens = response.json()

        if "error" in tokens:
            raise RuntimeError(
                f"Token exchange failed: {tokens.get('error_description', tokens['error'])}"
            )

        if "access_token" not in tokens:
            raise RuntimeError("OAuth token response missing 'access_token' field")

        now = int(time.time())
        # Coerce defensively: some servers return expires_in as a string.
        expires_in = self._parse_expires_in(tokens)

        return OAuth2Credentials(
            provider=self.PROVIDER_NAME,
            title=None,
            access_token=SecretStr(tokens["access_token"]),
            refresh_token=(
                SecretStr(tokens["refresh_token"])
                if tokens.get("refresh_token")
                else None
            ),
            access_token_expires_at=now + expires_in if expires_in else None,
            refresh_token_expires_at=None,
            scopes=scopes,
            metadata={
                "mcp_token_url": self.token_url,
                "mcp_resource_url": self.resource_url,
            },
        )

    async def _refresh_tokens(
        self, credentials: OAuth2Credentials
    ) -> OAuth2Credentials:
        """Refresh the access token using the stored refresh token.

        Raises:
            ValueError: if the credentials carry no refresh token.
            RuntimeError: if the token endpoint reports an error or the
                response omits the ``access_token`` field.
        """
        if not credentials.refresh_token:
            raise ValueError("No refresh token available for MCP OAuth credentials")

        data: dict[str, str] = {
            "grant_type": "refresh_token",
            "refresh_token": credentials.refresh_token.get_secret_value(),
            "client_id": self.client_id,
        }
        if self.client_secret:
            data["client_secret"] = self.client_secret
        if self.resource_url:
            data["resource"] = self.resource_url

        response = await Requests(raise_for_status=True).post(
            self.token_url,
            data=data,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )
        tokens = response.json()

        if "error" in tokens:
            raise RuntimeError(
                f"Token refresh failed: {tokens.get('error_description', tokens['error'])}"
            )

        if "access_token" not in tokens:
            raise RuntimeError("OAuth refresh response missing 'access_token' field")

        now = int(time.time())
        # Coerce defensively: some servers return expires_in as a string.
        expires_in = self._parse_expires_in(tokens)

        return OAuth2Credentials(
            id=credentials.id,
            provider=self.PROVIDER_NAME,
            title=credentials.title,
            access_token=SecretStr(tokens["access_token"]),
            # Servers may not rotate refresh tokens; keep the old one then.
            refresh_token=(
                SecretStr(tokens["refresh_token"])
                if tokens.get("refresh_token")
                else credentials.refresh_token
            ),
            access_token_expires_at=now + expires_in if expires_in else None,
            refresh_token_expires_at=credentials.refresh_token_expires_at,
            scopes=credentials.scopes,
            metadata=credentials.metadata,
        )

    async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
        """Best-effort token revocation.

        Returns ``False`` when the server advertises no revocation
        endpoint or the revocation request fails; never raises.
        """
        if not self.revoke_url:
            return False

        try:
            data = {
                "token": credentials.access_token.get_secret_value(),
                "token_type_hint": "access_token",
                "client_id": self.client_id,
            }
            await Requests().post(
                self.revoke_url,
                data=data,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
            )
            return True
        except Exception:
            logger.warning("Failed to revoke MCP OAuth tokens", exc_info=True)
            return False
|
||||||
109
autogpt_platform/backend/backend/blocks/mcp/test_e2e.py
Normal file
109
autogpt_platform/backend/backend/blocks/mcp/test_e2e.py
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
"""
|
||||||
|
End-to-end tests against a real public MCP server.
|
||||||
|
|
||||||
|
These tests hit the OpenAI docs MCP server (https://developers.openai.com/mcp)
|
||||||
|
which is publicly accessible without authentication and returns SSE responses.
|
||||||
|
|
||||||
|
Mark: These are tagged with ``@pytest.mark.e2e`` so they can be run/skipped
|
||||||
|
independently of the rest of the test suite (they require network access).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from backend.blocks.mcp.client import MCPClient
|
||||||
|
|
||||||
|
# Public MCP server that requires no authentication
|
||||||
|
OPENAI_DOCS_MCP_URL = "https://developers.openai.com/mcp"
|
||||||
|
|
||||||
|
# Skip all tests in this module unless RUN_E2E env var is set
|
||||||
|
pytestmark = pytest.mark.skipif(
|
||||||
|
not os.environ.get("RUN_E2E"), reason="set RUN_E2E=1 to run e2e tests"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRealMCPServer:
    """Tests against the live OpenAI docs MCP server."""

    @pytest.mark.asyncio(loop_scope="session")
    async def test_initialize(self):
        """Verify we can complete the MCP handshake with a real server."""
        mcp = MCPClient(OPENAI_DOCS_MCP_URL)
        handshake = await mcp.initialize()

        assert handshake["protocolVersion"] == "2025-03-26"
        assert "serverInfo" in handshake
        assert handshake["serverInfo"]["name"] == "openai-docs-mcp"
        assert "tools" in handshake.get("capabilities", {})

    @pytest.mark.asyncio(loop_scope="session")
    async def test_list_tools(self):
        """Verify we can discover tools from a real MCP server."""
        mcp = MCPClient(OPENAI_DOCS_MCP_URL)
        await mcp.initialize()
        discovered = await mcp.list_tools()

        assert len(discovered) >= 3  # server has at least 5 tools as of writing

        names = {tool.name for tool in discovered}
        # These tools are documented and should be stable
        assert "search_openai_docs" in names
        assert "list_openai_docs" in names
        assert "fetch_openai_doc" in names

        # Verify the schema structure of one known tool
        search = next(tool for tool in discovered if tool.name == "search_openai_docs")
        assert "query" in search.input_schema.get("properties", {})
        assert "query" in search.input_schema.get("required", [])

    @pytest.mark.asyncio(loop_scope="session")
    async def test_call_tool_list_api_endpoints(self):
        """Call the list_api_endpoints tool and verify we get real data."""
        mcp = MCPClient(OPENAI_DOCS_MCP_URL)
        await mcp.initialize()
        call = await mcp.call_tool("list_api_endpoints", {})

        assert not call.is_error
        assert len(call.content) >= 1
        assert call.content[0]["type"] == "text"

        payload = json.loads(call.content[0]["text"])
        assert "paths" in payload or "urls" in payload
        # The OpenAI API should have many endpoints
        endpoint_count = payload.get("total", len(payload.get("paths", [])))
        assert endpoint_count > 50

    @pytest.mark.asyncio(loop_scope="session")
    async def test_call_tool_search(self):
        """Search for docs and verify we get results."""
        mcp = MCPClient(OPENAI_DOCS_MCP_URL)
        await mcp.initialize()
        call = await mcp.call_tool(
            "search_openai_docs", {"query": "chat completions", "limit": 3}
        )

        assert not call.is_error
        assert len(call.content) >= 1

    @pytest.mark.asyncio(loop_scope="session")
    async def test_sse_response_handling(self):
        """Verify the client correctly handles SSE responses from a real server.

        This is the key test — our local test server returns JSON,
        but real MCP servers typically return SSE. This proves the
        SSE parsing works end-to-end.
        """
        mcp = MCPClient(OPENAI_DOCS_MCP_URL)
        # initialize() internally calls _send_request which must parse SSE
        handshake = await mcp.initialize()

        # If we got here without error, SSE parsing works
        assert isinstance(handshake, dict)
        assert "protocolVersion" in handshake

        # Also verify list_tools works (another SSE response)
        discovered = await mcp.list_tools()
        assert len(discovered) > 0
        assert all(hasattr(tool, "name") for tool in discovered)
|
||||||
389
autogpt_platform/backend/backend/blocks/mcp/test_integration.py
Normal file
389
autogpt_platform/backend/backend/blocks/mcp/test_integration.py
Normal file
@@ -0,0 +1,389 @@
|
|||||||
|
"""
|
||||||
|
Integration tests for MCP client and MCPToolBlock against a real HTTP server.
|
||||||
|
|
||||||
|
These tests spin up a local MCP test server and run the full client/block flow
|
||||||
|
against it — no mocking, real HTTP requests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import threading
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from aiohttp import web
|
||||||
|
from pydantic import SecretStr
|
||||||
|
|
||||||
|
from backend.blocks.mcp.block import MCPToolBlock
|
||||||
|
from backend.blocks.mcp.client import MCPClient
|
||||||
|
from backend.blocks.mcp.test_server import create_test_mcp_app
|
||||||
|
from backend.data.model import OAuth2Credentials
|
||||||
|
|
||||||
|
MOCK_USER_ID = "test-user-integration"
|
||||||
|
|
||||||
|
|
||||||
|
class _MCPTestServer:
    """
    Run an MCP test server in a background thread with its own event loop.

    This avoids event loop conflicts with pytest-asyncio. Use
    ``start()``/``stop()`` as a pair; ``url`` is populated once the server
    has bound a port.
    """

    def __init__(self, auth_token: str | None = None):
        # Optional bearer token the test server will require.
        self.auth_token = auth_token
        self.url: str = ""
        self._runner: web.AppRunner | None = None
        self._loop: asyncio.AbstractEventLoop | None = None
        self._thread: threading.Thread | None = None
        self._started = threading.Event()
        # Startup exception captured from the server thread so start()
        # can surface it instead of failing with a generic timeout.
        self._startup_error: BaseException | None = None

    def _run(self):
        # Thread entry point: create a private loop and run until stop().
        self._loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._loop)
        try:
            self._loop.run_until_complete(self._start())
        except BaseException as exc:
            self._startup_error = exc
            return
        finally:
            # Always release start(); it checks _startup_error afterwards.
            self._started.set()
        self._loop.run_forever()

    async def _start(self):
        app = create_test_mcp_app(auth_token=self.auth_token)
        self._runner = web.AppRunner(app)
        await self._runner.setup()
        # Port 0 lets the OS pick a free port; read it back from the socket.
        site = web.TCPSite(self._runner, "127.0.0.1", 0)
        await site.start()
        port = site._server.sockets[0].getsockname()[1]  # type: ignore[union-attr]
        self.url = f"http://127.0.0.1:{port}/mcp"

    def start(self):
        """Launch the server thread and block until it is serving."""
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()
        if not self._started.wait(timeout=5):
            raise RuntimeError("MCP test server failed to start within 5 seconds")
        if self._startup_error is not None:
            # Fail fast with the real cause rather than a timeout.
            raise RuntimeError(
                "MCP test server failed to start"
            ) from self._startup_error
        return self

    def stop(self):
        """Shut down the server, stop the loop, and release its resources."""
        if self._loop and self._runner:
            asyncio.run_coroutine_threadsafe(self._runner.cleanup(), self._loop).result(
                timeout=5
            )
            self._loop.call_soon_threadsafe(self._loop.stop)
        if self._thread:
            self._thread.join(timeout=5)
        # Close the loop once its thread has exited to avoid leaking
        # the selector/file descriptors across test modules.
        if self._loop and not self._loop.is_running():
            self._loop.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def mcp_server():
    """Start a local MCP test server in a background thread."""
    # start() blocks until the server is serving and returns the instance.
    srv = _MCPTestServer().start()
    yield srv.url
    srv.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def mcp_server_with_auth():
    """Start a local MCP test server with auth in a background thread."""
    # Yields (url, token) so tests can authenticate against the server.
    srv = _MCPTestServer(auth_token="test-secret-token").start()
    yield srv.url, "test-secret-token"
    srv.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def _allow_localhost():
    """
    Allow 127.0.0.1 through SSRF protection for integration tests.

    The Requests class blocks private IPs by default, so wrap its
    constructor to always whitelist the local test server's origin as a
    trusted origin.
    """
    from backend.util.request import Requests

    real_init = Requests.__init__

    def init_with_localhost(self, *args, **kwargs):
        origins = list(kwargs.get("trusted_origins") or [])
        origins.append("http://127.0.0.1")
        kwargs["trusted_origins"] = origins
        real_init(self, *args, **kwargs)

    with patch.object(Requests, "__init__", init_with_localhost):
        yield
|
||||||
|
|
||||||
|
|
||||||
|
def _make_client(url: str, auth_token: str | None = None) -> MCPClient:
    """Build an MCPClient pointed at the given integration-test server."""
    return MCPClient(url, auth_token=auth_token)
|
||||||
|
|
||||||
|
|
||||||
|
# ── MCPClient integration tests ──────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPClientIntegration:
    """Test MCPClient against a real local MCP server."""

    @pytest.mark.asyncio(loop_scope="session")
    async def test_initialize(self, mcp_server):
        """Handshake succeeds and reports the expected protocol/server info."""
        client = _make_client(mcp_server)
        result = await client.initialize()

        assert result["protocolVersion"] == "2025-03-26"
        assert result["serverInfo"]["name"] == "test-mcp-server"
        assert "tools" in result["capabilities"]

    @pytest.mark.asyncio(loop_scope="session")
    async def test_list_tools(self, mcp_server):
        """Tool discovery returns the three known tools with their schemas."""
        client = _make_client(mcp_server)
        await client.initialize()
        tools = await client.list_tools()

        assert len(tools) == 3

        tool_names = {t.name for t in tools}
        assert tool_names == {"get_weather", "add_numbers", "echo"}

        # Check get_weather schema
        weather = next(t for t in tools if t.name == "get_weather")
        assert weather.description == "Get current weather for a city"
        assert "city" in weather.input_schema["properties"]
        assert weather.input_schema["required"] == ["city"]

        # Check add_numbers schema
        add = next(t for t in tools if t.name == "add_numbers")
        assert "a" in add.input_schema["properties"]
        assert "b" in add.input_schema["properties"]

    @pytest.mark.asyncio(loop_scope="session")
    async def test_call_tool_get_weather(self, mcp_server):
        """Calling get_weather yields the server's canned JSON payload."""
        client = _make_client(mcp_server)
        await client.initialize()
        result = await client.call_tool("get_weather", {"city": "London"})

        assert not result.is_error
        assert len(result.content) == 1
        assert result.content[0]["type"] == "text"

        data = json.loads(result.content[0]["text"])
        assert data["city"] == "London"
        assert data["temperature"] == 22
        assert data["condition"] == "sunny"

    @pytest.mark.asyncio(loop_scope="session")
    async def test_call_tool_add_numbers(self, mcp_server):
        """Calling add_numbers returns the arithmetic result."""
        client = _make_client(mcp_server)
        await client.initialize()
        result = await client.call_tool("add_numbers", {"a": 3, "b": 7})

        assert not result.is_error
        data = json.loads(result.content[0]["text"])
        assert data["result"] == 10

    @pytest.mark.asyncio(loop_scope="session")
    async def test_call_tool_echo(self, mcp_server):
        """Echo returns the input message verbatim as plain text."""
        client = _make_client(mcp_server)
        await client.initialize()
        result = await client.call_tool("echo", {"message": "Hello MCP!"})

        assert not result.is_error
        assert result.content[0]["text"] == "Hello MCP!"

    @pytest.mark.asyncio(loop_scope="session")
    async def test_call_unknown_tool(self, mcp_server):
        """An unknown tool name produces an error result, not an exception."""
        client = _make_client(mcp_server)
        await client.initialize()
        result = await client.call_tool("nonexistent_tool", {})

        assert result.is_error
        assert "Unknown tool" in result.content[0]["text"]

    @pytest.mark.asyncio(loop_scope="session")
    async def test_auth_success(self, mcp_server_with_auth):
        """With the correct bearer token, handshake and discovery succeed."""
        url, token = mcp_server_with_auth
        client = _make_client(url, auth_token=token)
        result = await client.initialize()

        assert result["protocolVersion"] == "2025-03-26"

        tools = await client.list_tools()
        assert len(tools) == 3

    @pytest.mark.asyncio(loop_scope="session")
    async def test_auth_failure(self, mcp_server_with_auth):
        """A wrong bearer token is rejected during initialization."""
        url, _ = mcp_server_with_auth
        client = _make_client(url, auth_token="wrong-token")

        with pytest.raises(Exception):
            await client.initialize()

    @pytest.mark.asyncio(loop_scope="session")
    async def test_auth_missing(self, mcp_server_with_auth):
        """Omitting the bearer token entirely is also rejected."""
        url, _ = mcp_server_with_auth
        client = _make_client(url)

        with pytest.raises(Exception):
            await client.initialize()
|
||||||
|
|
||||||
|
|
||||||
|
# ── MCPToolBlock integration tests ───────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPToolBlockIntegration:
    """Test MCPToolBlock end-to-end against a real local MCP server."""

    @pytest.mark.asyncio(loop_scope="session")
    async def test_full_flow_get_weather(self, mcp_server):
        """Full flow: discover tools, select one, execute it."""
        # Step 1: Discover tools (simulating what the frontend/API would do)
        client = _make_client(mcp_server)
        await client.initialize()
        tools = await client.list_tools()
        assert len(tools) == 3

        # Step 2: User selects "get_weather" and we get its schema
        weather_tool = next(t for t in tools if t.name == "get_weather")

        # Step 3: Execute the block — no credentials (public server)
        block = MCPToolBlock()
        input_data = MCPToolBlock.Input(
            server_url=mcp_server,
            selected_tool="get_weather",
            tool_input_schema=weather_tool.input_schema,
            tool_arguments={"city": "Paris"},
        )

        # Collect every (output_name, data) pair the block yields.
        outputs = []
        async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
            outputs.append((name, data))

        assert len(outputs) == 1
        assert outputs[0][0] == "result"
        result = outputs[0][1]
        assert result["city"] == "Paris"
        assert result["temperature"] == 22
        assert result["condition"] == "sunny"

    @pytest.mark.asyncio(loop_scope="session")
    async def test_full_flow_add_numbers(self, mcp_server):
        """Full flow for add_numbers tool."""
        client = _make_client(mcp_server)
        await client.initialize()
        tools = await client.list_tools()
        add_tool = next(t for t in tools if t.name == "add_numbers")

        block = MCPToolBlock()
        input_data = MCPToolBlock.Input(
            server_url=mcp_server,
            selected_tool="add_numbers",
            tool_input_schema=add_tool.input_schema,
            tool_arguments={"a": 42, "b": 58},
        )

        outputs = []
        async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
            outputs.append((name, data))

        assert len(outputs) == 1
        assert outputs[0][0] == "result"
        assert outputs[0][1]["result"] == 100

    @pytest.mark.asyncio(loop_scope="session")
    async def test_full_flow_echo_plain_text(self, mcp_server):
        """Verify plain text (non-JSON) responses work."""
        block = MCPToolBlock()
        # Schema supplied inline — no discovery step needed for this test.
        input_data = MCPToolBlock.Input(
            server_url=mcp_server,
            selected_tool="echo",
            tool_input_schema={
                "type": "object",
                "properties": {"message": {"type": "string"}},
                "required": ["message"],
            },
            tool_arguments={"message": "Hello from AutoGPT!"},
        )

        outputs = []
        async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
            outputs.append((name, data))

        assert len(outputs) == 1
        assert outputs[0][0] == "result"
        assert outputs[0][1] == "Hello from AutoGPT!"

    @pytest.mark.asyncio(loop_scope="session")
    async def test_full_flow_unknown_tool_yields_error(self, mcp_server):
        """Calling an unknown tool should yield an error output."""
        block = MCPToolBlock()
        input_data = MCPToolBlock.Input(
            server_url=mcp_server,
            selected_tool="nonexistent_tool",
            tool_arguments={},
        )

        outputs = []
        async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
            outputs.append((name, data))

        # The block reports the failure on its "error" output, not by raising.
        assert len(outputs) == 1
        assert outputs[0][0] == "error"
        assert "returned an error" in outputs[0][1]

    @pytest.mark.asyncio(loop_scope="session")
    async def test_full_flow_with_auth(self, mcp_server_with_auth):
        """Full flow with authentication via credentials kwarg."""
        url, token = mcp_server_with_auth

        block = MCPToolBlock()
        input_data = MCPToolBlock.Input(
            server_url=url,
            selected_tool="echo",
            tool_input_schema={
                "type": "object",
                "properties": {"message": {"type": "string"}},
                "required": ["message"],
            },
            tool_arguments={"message": "Authenticated!"},
        )

        # Pass credentials via the standard kwarg (as the executor would)
        test_creds = OAuth2Credentials(
            id="test-cred",
            provider="mcp",
            access_token=SecretStr(token),
            refresh_token=SecretStr(""),
            scopes=[],
            title="Test MCP credential",
        )

        outputs = []
        async for name, data in block.run(
            input_data, user_id=MOCK_USER_ID, credentials=test_creds
        ):
            outputs.append((name, data))

        assert len(outputs) == 1
        assert outputs[0][0] == "result"
        assert outputs[0][1] == "Authenticated!"

    @pytest.mark.asyncio(loop_scope="session")
    async def test_no_credentials_runs_without_auth(self, mcp_server):
        """Block runs without auth when no credentials are provided."""
        block = MCPToolBlock()
        input_data = MCPToolBlock.Input(
            server_url=mcp_server,
            selected_tool="echo",
            tool_input_schema={
                "type": "object",
                "properties": {"message": {"type": "string"}},
                "required": ["message"],
            },
            tool_arguments={"message": "No auth needed"},
        )

        outputs = []
        # credentials=None exercises the unauthenticated code path explicitly.
        async for name, data in block.run(
            input_data, user_id=MOCK_USER_ID, credentials=None
        ):
            outputs.append((name, data))

        assert len(outputs) == 1
        assert outputs[0][0] == "result"
        assert outputs[0][1] == "No auth needed"
|
||||||
619
autogpt_platform/backend/backend/blocks/mcp/test_mcp.py
Normal file
619
autogpt_platform/backend/backend/blocks/mcp/test_mcp.py
Normal file
@@ -0,0 +1,619 @@
|
|||||||
|
"""
|
||||||
|
Tests for MCP client and MCPToolBlock.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from backend.blocks.mcp.block import MCPToolBlock
|
||||||
|
from backend.blocks.mcp.client import MCPCallResult, MCPClient, MCPClientError
|
||||||
|
from backend.util.test import execute_block_test
|
||||||
|
|
||||||
|
# ── SSE parsing unit tests ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestSSEParsing:
|
||||||
|
"""Tests for SSE (text/event-stream) response parsing."""
|
||||||
|
|
||||||
|
def test_parse_sse_simple(self):
|
||||||
|
sse = (
|
||||||
|
"event: message\n"
|
||||||
|
'data: {"jsonrpc":"2.0","result":{"tools":[]},"id":1}\n'
|
||||||
|
"\n"
|
||||||
|
)
|
||||||
|
body = MCPClient._parse_sse_response(sse)
|
||||||
|
assert body["result"] == {"tools": []}
|
||||||
|
assert body["id"] == 1
|
||||||
|
|
||||||
|
def test_parse_sse_with_notifications(self):
|
||||||
|
"""SSE streams can contain notifications (no id) before the response."""
|
||||||
|
sse = (
|
||||||
|
"event: message\n"
|
||||||
|
'data: {"jsonrpc":"2.0","method":"some/notification"}\n'
|
||||||
|
"\n"
|
||||||
|
"event: message\n"
|
||||||
|
'data: {"jsonrpc":"2.0","result":{"ok":true},"id":2}\n'
|
||||||
|
"\n"
|
||||||
|
)
|
||||||
|
body = MCPClient._parse_sse_response(sse)
|
||||||
|
assert body["result"] == {"ok": True}
|
||||||
|
assert body["id"] == 2
|
||||||
|
|
||||||
|
def test_parse_sse_error_response(self):
|
||||||
|
sse = (
|
||||||
|
"event: message\n"
|
||||||
|
'data: {"jsonrpc":"2.0","error":{"code":-32600,"message":"Bad Request"},"id":1}\n'
|
||||||
|
)
|
||||||
|
body = MCPClient._parse_sse_response(sse)
|
||||||
|
assert "error" in body
|
||||||
|
assert body["error"]["code"] == -32600
|
||||||
|
|
||||||
|
def test_parse_sse_no_data_raises(self):
|
||||||
|
with pytest.raises(MCPClientError, match="No JSON-RPC response found"):
|
||||||
|
MCPClient._parse_sse_response("event: message\n\n")
|
||||||
|
|
||||||
|
def test_parse_sse_empty_raises(self):
|
||||||
|
with pytest.raises(MCPClientError, match="No JSON-RPC response found"):
|
||||||
|
MCPClient._parse_sse_response("")
|
||||||
|
|
||||||
|
def test_parse_sse_ignores_non_data_lines(self):
|
||||||
|
sse = (
|
||||||
|
": comment line\n"
|
||||||
|
"event: message\n"
|
||||||
|
"id: 123\n"
|
||||||
|
'data: {"jsonrpc":"2.0","result":"ok","id":1}\n'
|
||||||
|
"\n"
|
||||||
|
)
|
||||||
|
body = MCPClient._parse_sse_response(sse)
|
||||||
|
assert body["result"] == "ok"
|
||||||
|
|
||||||
|
def test_parse_sse_uses_last_response(self):
|
||||||
|
"""If multiple responses exist, use the last one."""
|
||||||
|
sse = (
|
||||||
|
'data: {"jsonrpc":"2.0","result":"first","id":1}\n'
|
||||||
|
"\n"
|
||||||
|
'data: {"jsonrpc":"2.0","result":"second","id":2}\n'
|
||||||
|
"\n"
|
||||||
|
)
|
||||||
|
body = MCPClient._parse_sse_response(sse)
|
||||||
|
assert body["result"] == "second"
|
||||||
|
|
||||||
|
|
||||||
|
# ── MCPClient unit tests ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPClient:
|
||||||
|
"""Tests for the MCP HTTP client."""
|
||||||
|
|
||||||
|
def test_build_headers_without_auth(self):
|
||||||
|
client = MCPClient("https://mcp.example.com")
|
||||||
|
headers = client._build_headers()
|
||||||
|
assert "Authorization" not in headers
|
||||||
|
assert headers["Content-Type"] == "application/json"
|
||||||
|
|
||||||
|
def test_build_headers_with_auth(self):
|
||||||
|
client = MCPClient("https://mcp.example.com", auth_token="my-token")
|
||||||
|
headers = client._build_headers()
|
||||||
|
assert headers["Authorization"] == "Bearer my-token"
|
||||||
|
|
||||||
|
def test_build_jsonrpc_request(self):
|
||||||
|
client = MCPClient("https://mcp.example.com")
|
||||||
|
req = client._build_jsonrpc_request("tools/list")
|
||||||
|
assert req["jsonrpc"] == "2.0"
|
||||||
|
assert req["method"] == "tools/list"
|
||||||
|
assert "id" in req
|
||||||
|
assert "params" not in req
|
||||||
|
|
||||||
|
def test_build_jsonrpc_request_with_params(self):
|
||||||
|
client = MCPClient("https://mcp.example.com")
|
||||||
|
req = client._build_jsonrpc_request(
|
||||||
|
"tools/call", {"name": "test", "arguments": {"x": 1}}
|
||||||
|
)
|
||||||
|
assert req["params"] == {"name": "test", "arguments": {"x": 1}}
|
||||||
|
|
||||||
|
def test_request_id_increments(self):
|
||||||
|
client = MCPClient("https://mcp.example.com")
|
||||||
|
req1 = client._build_jsonrpc_request("tools/list")
|
||||||
|
req2 = client._build_jsonrpc_request("tools/list")
|
||||||
|
assert req2["id"] > req1["id"]
|
||||||
|
|
||||||
|
def test_server_url_trailing_slash_stripped(self):
|
||||||
|
client = MCPClient("https://mcp.example.com/mcp/")
|
||||||
|
assert client.server_url == "https://mcp.example.com/mcp"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_send_request_success(self):
|
||||||
|
client = MCPClient("https://mcp.example.com")
|
||||||
|
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"result": {"tools": []},
|
||||||
|
"id": 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(client, "_send_request", return_value={"tools": []}):
|
||||||
|
result = await client._send_request("tools/list")
|
||||||
|
assert result == {"tools": []}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_send_request_error(self):
|
||||||
|
client = MCPClient("https://mcp.example.com")
|
||||||
|
|
||||||
|
async def mock_send(*args, **kwargs):
|
||||||
|
raise MCPClientError("MCP server error [-32600]: Invalid Request")
|
||||||
|
|
||||||
|
with patch.object(client, "_send_request", side_effect=mock_send):
|
||||||
|
with pytest.raises(MCPClientError, match="Invalid Request"):
|
||||||
|
await client._send_request("tools/list")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_list_tools(self):
|
||||||
|
client = MCPClient("https://mcp.example.com")
|
||||||
|
|
||||||
|
mock_result = {
|
||||||
|
"tools": [
|
||||||
|
{
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get current weather for a city",
|
||||||
|
"inputSchema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"city": {"type": "string"}},
|
||||||
|
"required": ["city"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "search",
|
||||||
|
"description": "Search the web",
|
||||||
|
"inputSchema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"query": {"type": "string"}},
|
||||||
|
"required": ["query"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(client, "_send_request", return_value=mock_result):
|
||||||
|
tools = await client.list_tools()
|
||||||
|
|
||||||
|
assert len(tools) == 2
|
||||||
|
assert tools[0].name == "get_weather"
|
||||||
|
assert tools[0].description == "Get current weather for a city"
|
||||||
|
assert tools[0].input_schema["properties"]["city"]["type"] == "string"
|
||||||
|
assert tools[1].name == "search"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_list_tools_empty(self):
|
||||||
|
client = MCPClient("https://mcp.example.com")
|
||||||
|
|
||||||
|
with patch.object(client, "_send_request", return_value={"tools": []}):
|
||||||
|
tools = await client.list_tools()
|
||||||
|
|
||||||
|
assert tools == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_list_tools_none_result(self):
|
||||||
|
client = MCPClient("https://mcp.example.com")
|
||||||
|
|
||||||
|
with patch.object(client, "_send_request", return_value=None):
|
||||||
|
tools = await client.list_tools()
|
||||||
|
|
||||||
|
assert tools == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_call_tool_success(self):
|
||||||
|
client = MCPClient("https://mcp.example.com")
|
||||||
|
|
||||||
|
mock_result = {
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": json.dumps({"temp": 20, "city": "London"})}
|
||||||
|
],
|
||||||
|
"isError": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(client, "_send_request", return_value=mock_result):
|
||||||
|
result = await client.call_tool("get_weather", {"city": "London"})
|
||||||
|
|
||||||
|
assert not result.is_error
|
||||||
|
assert len(result.content) == 1
|
||||||
|
assert result.content[0]["type"] == "text"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_call_tool_error(self):
|
||||||
|
client = MCPClient("https://mcp.example.com")
|
||||||
|
|
||||||
|
mock_result = {
|
||||||
|
"content": [{"type": "text", "text": "City not found"}],
|
||||||
|
"isError": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(client, "_send_request", return_value=mock_result):
|
||||||
|
result = await client.call_tool("get_weather", {"city": "???"})
|
||||||
|
|
||||||
|
assert result.is_error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_call_tool_none_result(self):
|
||||||
|
client = MCPClient("https://mcp.example.com")
|
||||||
|
|
||||||
|
with patch.object(client, "_send_request", return_value=None):
|
||||||
|
result = await client.call_tool("get_weather", {"city": "London"})
|
||||||
|
|
||||||
|
assert result.is_error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_initialize(self):
|
||||||
|
client = MCPClient("https://mcp.example.com")
|
||||||
|
|
||||||
|
mock_result = {
|
||||||
|
"protocolVersion": "2025-03-26",
|
||||||
|
"capabilities": {"tools": {}},
|
||||||
|
"serverInfo": {"name": "test-server", "version": "1.0.0"},
|
||||||
|
}
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(client, "_send_request", return_value=mock_result) as mock_req,
|
||||||
|
patch.object(client, "_send_notification") as mock_notif,
|
||||||
|
):
|
||||||
|
result = await client.initialize()
|
||||||
|
|
||||||
|
mock_req.assert_called_once()
|
||||||
|
mock_notif.assert_called_once_with("notifications/initialized")
|
||||||
|
assert result["protocolVersion"] == "2025-03-26"
|
||||||
|
|
||||||
|
|
||||||
|
# ── MCPToolBlock unit tests ──────────────────────────────────────────
|
||||||
|
|
||||||
|
MOCK_USER_ID = "test-user-123"
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPToolBlock:
|
||||||
|
"""Tests for the MCPToolBlock."""
|
||||||
|
|
||||||
|
def test_block_instantiation(self):
|
||||||
|
block = MCPToolBlock()
|
||||||
|
assert block.id == "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4"
|
||||||
|
assert block.name == "MCPToolBlock"
|
||||||
|
|
||||||
|
def test_input_schema_has_required_fields(self):
|
||||||
|
block = MCPToolBlock()
|
||||||
|
schema = block.input_schema.jsonschema()
|
||||||
|
props = schema.get("properties", {})
|
||||||
|
assert "server_url" in props
|
||||||
|
assert "selected_tool" in props
|
||||||
|
assert "tool_arguments" in props
|
||||||
|
assert "credentials" in props
|
||||||
|
|
||||||
|
def test_output_schema(self):
|
||||||
|
block = MCPToolBlock()
|
||||||
|
schema = block.output_schema.jsonschema()
|
||||||
|
props = schema.get("properties", {})
|
||||||
|
assert "result" in props
|
||||||
|
assert "error" in props
|
||||||
|
|
||||||
|
def test_get_input_schema_with_tool_schema(self):
|
||||||
|
tool_schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"query": {"type": "string"}},
|
||||||
|
"required": ["query"],
|
||||||
|
}
|
||||||
|
data = {"tool_input_schema": tool_schema}
|
||||||
|
result = MCPToolBlock.Input.get_input_schema(data)
|
||||||
|
assert result == tool_schema
|
||||||
|
|
||||||
|
def test_get_input_schema_without_tool_schema(self):
|
||||||
|
result = MCPToolBlock.Input.get_input_schema({})
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
def test_get_input_defaults(self):
|
||||||
|
data = {"tool_arguments": {"city": "London"}}
|
||||||
|
result = MCPToolBlock.Input.get_input_defaults(data)
|
||||||
|
assert result == {"city": "London"}
|
||||||
|
|
||||||
|
def test_get_missing_input(self):
|
||||||
|
data = {
|
||||||
|
"tool_input_schema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"city": {"type": "string"},
|
||||||
|
"units": {"type": "string"},
|
||||||
|
},
|
||||||
|
"required": ["city", "units"],
|
||||||
|
},
|
||||||
|
"tool_arguments": {"city": "London"},
|
||||||
|
}
|
||||||
|
missing = MCPToolBlock.Input.get_missing_input(data)
|
||||||
|
assert missing == {"units"}
|
||||||
|
|
||||||
|
def test_get_missing_input_all_present(self):
|
||||||
|
data = {
|
||||||
|
"tool_input_schema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"city": {"type": "string"}},
|
||||||
|
"required": ["city"],
|
||||||
|
},
|
||||||
|
"tool_arguments": {"city": "London"},
|
||||||
|
}
|
||||||
|
missing = MCPToolBlock.Input.get_missing_input(data)
|
||||||
|
assert missing == set()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_run_with_mock(self):
|
||||||
|
"""Test the block using the built-in test infrastructure."""
|
||||||
|
block = MCPToolBlock()
|
||||||
|
await execute_block_test(block)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_run_missing_server_url(self):
|
||||||
|
block = MCPToolBlock()
|
||||||
|
input_data = MCPToolBlock.Input(
|
||||||
|
server_url="",
|
||||||
|
selected_tool="test",
|
||||||
|
)
|
||||||
|
outputs = []
|
||||||
|
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
||||||
|
outputs.append((name, data))
|
||||||
|
assert outputs == [("error", "MCP server URL is required")]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_run_missing_tool(self):
|
||||||
|
block = MCPToolBlock()
|
||||||
|
input_data = MCPToolBlock.Input(
|
||||||
|
server_url="https://mcp.example.com/mcp",
|
||||||
|
selected_tool="",
|
||||||
|
)
|
||||||
|
outputs = []
|
||||||
|
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
||||||
|
outputs.append((name, data))
|
||||||
|
assert outputs == [
|
||||||
|
("error", "No tool selected. Please select a tool from the dropdown.")
|
||||||
|
]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_run_success(self):
|
||||||
|
block = MCPToolBlock()
|
||||||
|
input_data = MCPToolBlock.Input(
|
||||||
|
server_url="https://mcp.example.com/mcp",
|
||||||
|
selected_tool="get_weather",
|
||||||
|
tool_input_schema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"city": {"type": "string"}},
|
||||||
|
},
|
||||||
|
tool_arguments={"city": "London"},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mock_call(*args, **kwargs):
|
||||||
|
return {"temp": 20, "city": "London"}
|
||||||
|
|
||||||
|
block._call_mcp_tool = mock_call # type: ignore
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
||||||
|
outputs.append((name, data))
|
||||||
|
|
||||||
|
assert len(outputs) == 1
|
||||||
|
assert outputs[0][0] == "result"
|
||||||
|
assert outputs[0][1] == {"temp": 20, "city": "London"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_run_mcp_error(self):
|
||||||
|
block = MCPToolBlock()
|
||||||
|
input_data = MCPToolBlock.Input(
|
||||||
|
server_url="https://mcp.example.com/mcp",
|
||||||
|
selected_tool="bad_tool",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mock_call(*args, **kwargs):
|
||||||
|
raise MCPClientError("Tool not found")
|
||||||
|
|
||||||
|
block._call_mcp_tool = mock_call # type: ignore
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
||||||
|
outputs.append((name, data))
|
||||||
|
|
||||||
|
assert outputs[0][0] == "error"
|
||||||
|
assert "Tool not found" in outputs[0][1]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_call_mcp_tool_parses_json_text(self):
|
||||||
|
block = MCPToolBlock()
|
||||||
|
|
||||||
|
mock_result = MCPCallResult(
|
||||||
|
content=[
|
||||||
|
{"type": "text", "text": '{"temp": 20}'},
|
||||||
|
],
|
||||||
|
is_error=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mock_init(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def mock_call(self, name, args):
|
||||||
|
return mock_result
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(MCPClient, "initialize", mock_init),
|
||||||
|
patch.object(MCPClient, "call_tool", mock_call),
|
||||||
|
):
|
||||||
|
result = await block._call_mcp_tool(
|
||||||
|
"https://mcp.example.com", "test_tool", {}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == {"temp": 20}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_call_mcp_tool_plain_text(self):
|
||||||
|
block = MCPToolBlock()
|
||||||
|
|
||||||
|
mock_result = MCPCallResult(
|
||||||
|
content=[
|
||||||
|
{"type": "text", "text": "Hello, world!"},
|
||||||
|
],
|
||||||
|
is_error=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mock_init(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def mock_call(self, name, args):
|
||||||
|
return mock_result
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(MCPClient, "initialize", mock_init),
|
||||||
|
patch.object(MCPClient, "call_tool", mock_call),
|
||||||
|
):
|
||||||
|
result = await block._call_mcp_tool(
|
||||||
|
"https://mcp.example.com", "test_tool", {}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == "Hello, world!"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_call_mcp_tool_multiple_content(self):
|
||||||
|
block = MCPToolBlock()
|
||||||
|
|
||||||
|
mock_result = MCPCallResult(
|
||||||
|
content=[
|
||||||
|
{"type": "text", "text": "Part 1"},
|
||||||
|
{"type": "text", "text": '{"part": 2}'},
|
||||||
|
],
|
||||||
|
is_error=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mock_init(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def mock_call(self, name, args):
|
||||||
|
return mock_result
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(MCPClient, "initialize", mock_init),
|
||||||
|
patch.object(MCPClient, "call_tool", mock_call),
|
||||||
|
):
|
||||||
|
result = await block._call_mcp_tool(
|
||||||
|
"https://mcp.example.com", "test_tool", {}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == ["Part 1", {"part": 2}]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_call_mcp_tool_error_result(self):
|
||||||
|
block = MCPToolBlock()
|
||||||
|
|
||||||
|
mock_result = MCPCallResult(
|
||||||
|
content=[{"type": "text", "text": "Something went wrong"}],
|
||||||
|
is_error=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mock_init(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def mock_call(self, name, args):
|
||||||
|
return mock_result
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(MCPClient, "initialize", mock_init),
|
||||||
|
patch.object(MCPClient, "call_tool", mock_call),
|
||||||
|
):
|
||||||
|
with pytest.raises(MCPClientError, match="returned an error"):
|
||||||
|
await block._call_mcp_tool("https://mcp.example.com", "test_tool", {})
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_call_mcp_tool_image_content(self):
|
||||||
|
block = MCPToolBlock()
|
||||||
|
|
||||||
|
mock_result = MCPCallResult(
|
||||||
|
content=[
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"data": "base64data==",
|
||||||
|
"mimeType": "image/png",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
is_error=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mock_init(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def mock_call(self, name, args):
|
||||||
|
return mock_result
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(MCPClient, "initialize", mock_init),
|
||||||
|
patch.object(MCPClient, "call_tool", mock_call),
|
||||||
|
):
|
||||||
|
result = await block._call_mcp_tool(
|
||||||
|
"https://mcp.example.com", "test_tool", {}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == {
|
||||||
|
"type": "image",
|
||||||
|
"data": "base64data==",
|
||||||
|
"mimeType": "image/png",
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_run_with_credentials(self):
|
||||||
|
"""Verify the block uses OAuth2Credentials and passes auth token."""
|
||||||
|
from pydantic import SecretStr
|
||||||
|
|
||||||
|
from backend.data.model import OAuth2Credentials
|
||||||
|
|
||||||
|
block = MCPToolBlock()
|
||||||
|
input_data = MCPToolBlock.Input(
|
||||||
|
server_url="https://mcp.example.com/mcp",
|
||||||
|
selected_tool="test_tool",
|
||||||
|
)
|
||||||
|
|
||||||
|
captured_tokens: list[str | None] = []
|
||||||
|
|
||||||
|
async def mock_call(server_url, tool_name, arguments, auth_token=None):
|
||||||
|
captured_tokens.append(auth_token)
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
block._call_mcp_tool = mock_call # type: ignore
|
||||||
|
|
||||||
|
test_creds = OAuth2Credentials(
|
||||||
|
id="cred-123",
|
||||||
|
provider="mcp",
|
||||||
|
access_token=SecretStr("resolved-token"),
|
||||||
|
refresh_token=SecretStr(""),
|
||||||
|
scopes=[],
|
||||||
|
title="Test MCP credential",
|
||||||
|
)
|
||||||
|
|
||||||
|
async for _ in block.run(
|
||||||
|
input_data, user_id=MOCK_USER_ID, credentials=test_creds
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert captured_tokens == ["resolved-token"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_run_without_credentials(self):
|
||||||
|
"""Verify the block works without credentials (public server)."""
|
||||||
|
block = MCPToolBlock()
|
||||||
|
input_data = MCPToolBlock.Input(
|
||||||
|
server_url="https://mcp.example.com/mcp",
|
||||||
|
selected_tool="test_tool",
|
||||||
|
)
|
||||||
|
|
||||||
|
captured_tokens: list[str | None] = []
|
||||||
|
|
||||||
|
async def mock_call(server_url, tool_name, arguments, auth_token=None):
|
||||||
|
captured_tokens.append(auth_token)
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
block._call_mcp_tool = mock_call # type: ignore
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
||||||
|
outputs.append((name, data))
|
||||||
|
|
||||||
|
assert captured_tokens == [None]
|
||||||
|
assert outputs == [("result", "ok")]
|
||||||
242
autogpt_platform/backend/backend/blocks/mcp/test_oauth.py
Normal file
242
autogpt_platform/backend/backend/blocks/mcp/test_oauth.py
Normal file
@@ -0,0 +1,242 @@
|
|||||||
|
"""
|
||||||
|
Tests for MCP OAuth handler.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import SecretStr
|
||||||
|
|
||||||
|
from backend.blocks.mcp.client import MCPClient
|
||||||
|
from backend.blocks.mcp.oauth import MCPOAuthHandler
|
||||||
|
from backend.data.model import OAuth2Credentials
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_response(json_data: dict, status: int = 200) -> MagicMock:
|
||||||
|
"""Create a mock Response with synchronous json() (matching Requests.Response)."""
|
||||||
|
resp = MagicMock()
|
||||||
|
resp.status = status
|
||||||
|
resp.ok = 200 <= status < 300
|
||||||
|
resp.json.return_value = json_data
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPOAuthHandler:
|
||||||
|
"""Tests for the MCPOAuthHandler."""
|
||||||
|
|
||||||
|
def _make_handler(self, **overrides) -> MCPOAuthHandler:
|
||||||
|
defaults = {
|
||||||
|
"client_id": "test-client-id",
|
||||||
|
"client_secret": "test-client-secret",
|
||||||
|
"redirect_uri": "https://app.example.com/callback",
|
||||||
|
"authorize_url": "https://auth.example.com/authorize",
|
||||||
|
"token_url": "https://auth.example.com/token",
|
||||||
|
}
|
||||||
|
defaults.update(overrides)
|
||||||
|
return MCPOAuthHandler(**defaults)
|
||||||
|
|
||||||
|
def test_get_login_url_basic(self):
|
||||||
|
handler = self._make_handler()
|
||||||
|
url = handler.get_login_url(
|
||||||
|
scopes=["read", "write"],
|
||||||
|
state="random-state-token",
|
||||||
|
code_challenge="S256-challenge-value",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "https://auth.example.com/authorize?" in url
|
||||||
|
assert "response_type=code" in url
|
||||||
|
assert "client_id=test-client-id" in url
|
||||||
|
assert "state=random-state-token" in url
|
||||||
|
assert "code_challenge=S256-challenge-value" in url
|
||||||
|
assert "code_challenge_method=S256" in url
|
||||||
|
assert "scope=read+write" in url
|
||||||
|
|
||||||
|
def test_get_login_url_with_resource(self):
|
||||||
|
handler = self._make_handler(resource_url="https://mcp.example.com/mcp")
|
||||||
|
url = handler.get_login_url(
|
||||||
|
scopes=[], state="state", code_challenge="challenge"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "resource=https" in url
|
||||||
|
|
||||||
|
def test_get_login_url_without_pkce(self):
|
||||||
|
handler = self._make_handler()
|
||||||
|
url = handler.get_login_url(scopes=["read"], state="state", code_challenge=None)
|
||||||
|
|
||||||
|
assert "code_challenge" not in url
|
||||||
|
assert "code_challenge_method" not in url
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_exchange_code_for_tokens(self):
|
||||||
|
handler = self._make_handler()
|
||||||
|
|
||||||
|
resp = _mock_response(
|
||||||
|
{
|
||||||
|
"access_token": "new-access-token",
|
||||||
|
"refresh_token": "new-refresh-token",
|
||||||
|
"expires_in": 3600,
|
||||||
|
"token_type": "Bearer",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
|
||||||
|
instance = MockRequests.return_value
|
||||||
|
instance.post = AsyncMock(return_value=resp)
|
||||||
|
|
||||||
|
creds = await handler.exchange_code_for_tokens(
|
||||||
|
code="auth-code",
|
||||||
|
scopes=["read"],
|
||||||
|
code_verifier="pkce-verifier",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(creds, OAuth2Credentials)
|
||||||
|
assert creds.access_token.get_secret_value() == "new-access-token"
|
||||||
|
assert creds.refresh_token is not None
|
||||||
|
assert creds.refresh_token.get_secret_value() == "new-refresh-token"
|
||||||
|
assert creds.scopes == ["read"]
|
||||||
|
assert creds.access_token_expires_at is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_refresh_tokens(self):
|
||||||
|
handler = self._make_handler()
|
||||||
|
|
||||||
|
existing_creds = OAuth2Credentials(
|
||||||
|
id="existing-id",
|
||||||
|
provider="mcp",
|
||||||
|
access_token=SecretStr("old-token"),
|
||||||
|
refresh_token=SecretStr("old-refresh"),
|
||||||
|
scopes=["read"],
|
||||||
|
title="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
resp = _mock_response(
|
||||||
|
{
|
||||||
|
"access_token": "refreshed-token",
|
||||||
|
"refresh_token": "new-refresh",
|
||||||
|
"expires_in": 3600,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
|
||||||
|
instance = MockRequests.return_value
|
||||||
|
instance.post = AsyncMock(return_value=resp)
|
||||||
|
|
||||||
|
refreshed = await handler._refresh_tokens(existing_creds)
|
||||||
|
|
||||||
|
assert refreshed.id == "existing-id"
|
||||||
|
assert refreshed.access_token.get_secret_value() == "refreshed-token"
|
||||||
|
assert refreshed.refresh_token is not None
|
||||||
|
assert refreshed.refresh_token.get_secret_value() == "new-refresh"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_refresh_tokens_no_refresh_token(self):
|
||||||
|
handler = self._make_handler()
|
||||||
|
|
||||||
|
creds = OAuth2Credentials(
|
||||||
|
provider="mcp",
|
||||||
|
access_token=SecretStr("token"),
|
||||||
|
scopes=["read"],
|
||||||
|
title="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="No refresh token"):
|
||||||
|
await handler._refresh_tokens(creds)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_revoke_tokens_no_url(self):
|
||||||
|
handler = self._make_handler(revoke_url=None)
|
||||||
|
|
||||||
|
creds = OAuth2Credentials(
|
||||||
|
provider="mcp",
|
||||||
|
access_token=SecretStr("token"),
|
||||||
|
scopes=[],
|
||||||
|
title="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await handler.revoke_tokens(creds)
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_revoke_tokens_with_url(self):
|
||||||
|
handler = self._make_handler(revoke_url="https://auth.example.com/revoke")
|
||||||
|
|
||||||
|
creds = OAuth2Credentials(
|
||||||
|
provider="mcp",
|
||||||
|
access_token=SecretStr("token"),
|
||||||
|
scopes=[],
|
||||||
|
title="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
resp = _mock_response({}, status=200)
|
||||||
|
|
||||||
|
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
|
||||||
|
instance = MockRequests.return_value
|
||||||
|
instance.post = AsyncMock(return_value=resp)
|
||||||
|
|
||||||
|
result = await handler.revoke_tokens(creds)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPClientDiscovery:
|
||||||
|
"""Tests for MCPClient OAuth metadata discovery."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_discover_auth_found(self):
|
||||||
|
client = MCPClient("https://mcp.example.com/mcp")
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"authorization_servers": ["https://auth.example.com"],
|
||||||
|
"resource": "https://mcp.example.com/mcp",
|
||||||
|
}
|
||||||
|
|
||||||
|
resp = _mock_response(metadata, status=200)
|
||||||
|
|
||||||
|
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
|
||||||
|
instance = MockRequests.return_value
|
||||||
|
instance.get = AsyncMock(return_value=resp)
|
||||||
|
|
||||||
|
result = await client.discover_auth()
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result["authorization_servers"] == ["https://auth.example.com"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_discover_auth_not_found(self):
|
||||||
|
client = MCPClient("https://mcp.example.com/mcp")
|
||||||
|
|
||||||
|
resp = _mock_response({}, status=404)
|
||||||
|
|
||||||
|
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
|
||||||
|
instance = MockRequests.return_value
|
||||||
|
instance.get = AsyncMock(return_value=resp)
|
||||||
|
|
||||||
|
result = await client.discover_auth()
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_discover_auth_server_metadata(self):
|
||||||
|
client = MCPClient("https://mcp.example.com/mcp")
|
||||||
|
|
||||||
|
server_metadata = {
|
||||||
|
"issuer": "https://auth.example.com",
|
||||||
|
"authorization_endpoint": "https://auth.example.com/authorize",
|
||||||
|
"token_endpoint": "https://auth.example.com/token",
|
||||||
|
"registration_endpoint": "https://auth.example.com/register",
|
||||||
|
"code_challenge_methods_supported": ["S256"],
|
||||||
|
}
|
||||||
|
|
||||||
|
resp = _mock_response(server_metadata, status=200)
|
||||||
|
|
||||||
|
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
|
||||||
|
instance = MockRequests.return_value
|
||||||
|
instance.get = AsyncMock(return_value=resp)
|
||||||
|
|
||||||
|
result = await client.discover_auth_server_metadata(
|
||||||
|
"https://auth.example.com"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result["authorization_endpoint"] == "https://auth.example.com/authorize"
|
||||||
|
assert result["token_endpoint"] == "https://auth.example.com/token"
|
||||||
162
autogpt_platform/backend/backend/blocks/mcp/test_server.py
Normal file
162
autogpt_platform/backend/backend/blocks/mcp/test_server.py
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
"""
|
||||||
|
Minimal MCP server for integration testing.
|
||||||
|
|
||||||
|
Implements the MCP Streamable HTTP transport (JSON-RPC 2.0 over HTTP POST)
|
||||||
|
with a few sample tools. Runs on localhost with a random available port.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Sample tools this test server exposes
|
||||||
|
TEST_TOOLS = [
|
||||||
|
{
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get current weather for a city",
|
||||||
|
"inputSchema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"city": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "City name",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["city"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "add_numbers",
|
||||||
|
"description": "Add two numbers together",
|
||||||
|
"inputSchema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"a": {"type": "number", "description": "First number"},
|
||||||
|
"b": {"type": "number", "description": "Second number"},
|
||||||
|
},
|
||||||
|
"required": ["a", "b"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "echo",
|
||||||
|
"description": "Echo back the input message",
|
||||||
|
"inputSchema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"message": {"type": "string", "description": "Message to echo"},
|
||||||
|
},
|
||||||
|
"required": ["message"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_initialize(params: dict) -> dict:
|
||||||
|
return {
|
||||||
|
"protocolVersion": "2025-03-26",
|
||||||
|
"capabilities": {"tools": {"listChanged": False}},
|
||||||
|
"serverInfo": {"name": "test-mcp-server", "version": "1.0.0"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_tools_list(params: dict) -> dict:
|
||||||
|
return {"tools": TEST_TOOLS}
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_tools_call(params: dict) -> dict:
|
||||||
|
tool_name = params.get("name", "")
|
||||||
|
arguments = params.get("arguments", {})
|
||||||
|
|
||||||
|
if tool_name == "get_weather":
|
||||||
|
city = arguments.get("city", "Unknown")
|
||||||
|
return {
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": json.dumps(
|
||||||
|
{"city": city, "temperature": 22, "condition": "sunny"}
|
||||||
|
),
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
elif tool_name == "add_numbers":
|
||||||
|
a = arguments.get("a", 0)
|
||||||
|
b = arguments.get("b", 0)
|
||||||
|
return {
|
||||||
|
"content": [{"type": "text", "text": json.dumps({"result": a + b})}],
|
||||||
|
}
|
||||||
|
|
||||||
|
elif tool_name == "echo":
|
||||||
|
message = arguments.get("message", "")
|
||||||
|
return {
|
||||||
|
"content": [{"type": "text", "text": message}],
|
||||||
|
}
|
||||||
|
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"content": [{"type": "text", "text": f"Unknown tool: {tool_name}"}],
|
||||||
|
"isError": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
HANDLERS = {
|
||||||
|
"initialize": _handle_initialize,
|
||||||
|
"tools/list": _handle_tools_list,
|
||||||
|
"tools/call": _handle_tools_call,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_mcp_request(request: web.Request) -> web.Response:
|
||||||
|
"""Handle incoming MCP JSON-RPC 2.0 requests."""
|
||||||
|
# Check auth if configured
|
||||||
|
expected_token = request.app.get("auth_token")
|
||||||
|
if expected_token:
|
||||||
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
if auth_header != f"Bearer {expected_token}":
|
||||||
|
return web.json_response(
|
||||||
|
{
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"error": {"code": -32001, "message": "Unauthorized"},
|
||||||
|
"id": None,
|
||||||
|
},
|
||||||
|
status=401,
|
||||||
|
)
|
||||||
|
|
||||||
|
body = await request.json()
|
||||||
|
|
||||||
|
# Handle notifications (no id field) — just acknowledge
|
||||||
|
if "id" not in body:
|
||||||
|
return web.Response(status=202)
|
||||||
|
|
||||||
|
method = body.get("method", "")
|
||||||
|
params = body.get("params", {})
|
||||||
|
request_id = body.get("id")
|
||||||
|
|
||||||
|
handler = HANDLERS.get(method)
|
||||||
|
if not handler:
|
||||||
|
return web.json_response(
|
||||||
|
{
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"error": {
|
||||||
|
"code": -32601,
|
||||||
|
"message": f"Method not found: {method}",
|
||||||
|
},
|
||||||
|
"id": request_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = handler(params)
|
||||||
|
return web.json_response({"jsonrpc": "2.0", "result": result, "id": request_id})
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_mcp_app(auth_token: str | None = None) -> web.Application:
|
||||||
|
"""Create an aiohttp app that acts as an MCP server."""
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_post("/mcp", handle_mcp_request)
|
||||||
|
if auth_token:
|
||||||
|
app["auth_token"] = auth_token
|
||||||
|
return app
|
||||||
@@ -39,6 +39,7 @@ from backend.util import type as type_utils
|
|||||||
from backend.util.exceptions import GraphNotAccessibleError, GraphNotInLibraryError
|
from backend.util.exceptions import GraphNotAccessibleError, GraphNotInLibraryError
|
||||||
from backend.util.json import SafeJson
|
from backend.util.json import SafeJson
|
||||||
from backend.util.models import Pagination
|
from backend.util.models import Pagination
|
||||||
|
from backend.util.request import parse_url
|
||||||
|
|
||||||
from .block import (
|
from .block import (
|
||||||
AnyBlockSchema,
|
AnyBlockSchema,
|
||||||
@@ -462,6 +463,9 @@ class GraphModel(Graph, GraphMeta):
|
|||||||
continue
|
continue
|
||||||
if ProviderName.HTTP in field.provider:
|
if ProviderName.HTTP in field.provider:
|
||||||
continue
|
continue
|
||||||
|
# MCP credentials are intentionally split by server URL
|
||||||
|
if ProviderName.MCP in field.provider:
|
||||||
|
continue
|
||||||
|
|
||||||
# If this happens, that means a block implementation probably needs
|
# If this happens, that means a block implementation probably needs
|
||||||
# to be updated.
|
# to be updated.
|
||||||
@@ -518,6 +522,18 @@ class GraphModel(Graph, GraphMeta):
|
|||||||
"required": ["id", "provider", "type"],
|
"required": ["id", "provider", "type"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Add a descriptive display title when URL-based discriminator values
|
||||||
|
# are present (e.g. "mcp.sentry.dev" instead of just "Mcp")
|
||||||
|
if (
|
||||||
|
field_info.discriminator
|
||||||
|
and not field_info.discriminator_mapping
|
||||||
|
and field_info.discriminator_values
|
||||||
|
):
|
||||||
|
hostnames = sorted(
|
||||||
|
parse_url(str(v)).netloc for v in field_info.discriminator_values
|
||||||
|
)
|
||||||
|
field_schema["display_name"] = ", ".join(hostnames)
|
||||||
|
|
||||||
# Add other (optional) field info items
|
# Add other (optional) field info items
|
||||||
field_schema.update(
|
field_schema.update(
|
||||||
field_info.model_dump(
|
field_info.model_dump(
|
||||||
@@ -562,8 +578,17 @@ class GraphModel(Graph, GraphMeta):
|
|||||||
|
|
||||||
for graph in [self] + self.sub_graphs:
|
for graph in [self] + self.sub_graphs:
|
||||||
for node in graph.nodes:
|
for node in graph.nodes:
|
||||||
# Track if this node requires credentials (credentials_optional=False means required)
|
# A node's credentials are optional if either:
|
||||||
node_required_map[node.id] = not node.credentials_optional
|
# 1. The node metadata says so (credentials_optional=True), or
|
||||||
|
# 2. All credential fields on the block have defaults (not required by schema)
|
||||||
|
block_required = node.block.input_schema.get_required_fields()
|
||||||
|
creds_required_by_schema = any(
|
||||||
|
fname in block_required
|
||||||
|
for fname in node.block.input_schema.get_credentials_fields()
|
||||||
|
)
|
||||||
|
node_required_map[node.id] = (
|
||||||
|
not node.credentials_optional and creds_required_by_schema
|
||||||
|
)
|
||||||
|
|
||||||
for (
|
for (
|
||||||
field_name,
|
field_name,
|
||||||
@@ -784,6 +809,19 @@ class GraphModel(Graph, GraphMeta):
|
|||||||
"'credentials' and `*_credentials` are reserved"
|
"'credentials' and `*_credentials` are reserved"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check custom block-level validation (e.g., MCP dynamic tool arguments).
|
||||||
|
# Blocks can override get_missing_input to report additional missing fields
|
||||||
|
# beyond the standard top-level required fields.
|
||||||
|
if for_run:
|
||||||
|
credential_fields = InputSchema.get_credentials_fields()
|
||||||
|
custom_missing = InputSchema.get_missing_input(node.input_default)
|
||||||
|
for field_name in custom_missing:
|
||||||
|
if (
|
||||||
|
field_name not in provided_inputs
|
||||||
|
and field_name not in credential_fields
|
||||||
|
):
|
||||||
|
node_errors[node.id][field_name] = "This field is required"
|
||||||
|
|
||||||
# Get input schema properties and check dependencies
|
# Get input schema properties and check dependencies
|
||||||
input_fields = InputSchema.model_fields
|
input_fields = InputSchema.model_fields
|
||||||
|
|
||||||
|
|||||||
@@ -463,3 +463,120 @@ def test_node_credentials_optional_with_other_metadata():
|
|||||||
assert node.credentials_optional is True
|
assert node.credentials_optional is True
|
||||||
assert node.metadata["position"] == {"x": 100, "y": 200}
|
assert node.metadata["position"] == {"x": 100, "y": 200}
|
||||||
assert node.metadata["customized_name"] == "My Custom Node"
|
assert node.metadata["customized_name"] == "My Custom Node"
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Tests for MCP Credential Deduplication
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_mcp_credential_combine_different_servers():
|
||||||
|
"""Two MCP credential fields with different server URLs should produce
|
||||||
|
separate entries when combined (not merged into one)."""
|
||||||
|
from backend.data.model import CredentialsFieldInfo, CredentialsType
|
||||||
|
from backend.integrations.providers import ProviderName
|
||||||
|
|
||||||
|
oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"])
|
||||||
|
|
||||||
|
field_sentry = CredentialsFieldInfo(
|
||||||
|
credentials_provider=frozenset([ProviderName.MCP]),
|
||||||
|
credentials_types=oauth2_types,
|
||||||
|
credentials_scopes=None,
|
||||||
|
discriminator="server_url",
|
||||||
|
discriminator_values={"https://mcp.sentry.dev/mcp"},
|
||||||
|
)
|
||||||
|
field_linear = CredentialsFieldInfo(
|
||||||
|
credentials_provider=frozenset([ProviderName.MCP]),
|
||||||
|
credentials_types=oauth2_types,
|
||||||
|
credentials_scopes=None,
|
||||||
|
discriminator="server_url",
|
||||||
|
discriminator_values={"https://mcp.linear.app/mcp"},
|
||||||
|
)
|
||||||
|
|
||||||
|
combined = CredentialsFieldInfo.combine(
|
||||||
|
(field_sentry, ("node-sentry", "credentials")),
|
||||||
|
(field_linear, ("node-linear", "credentials")),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should produce 2 separate credential entries
|
||||||
|
assert len(combined) == 2, (
|
||||||
|
f"Expected 2 credential entries for 2 MCP blocks with different servers, "
|
||||||
|
f"got {len(combined)}: {list(combined.keys())}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Each entry should contain the server hostname in its key
|
||||||
|
keys = list(combined.keys())
|
||||||
|
assert any(
|
||||||
|
"mcp.sentry.dev" in k for k in keys
|
||||||
|
), f"Expected 'mcp.sentry.dev' in one key, got {keys}"
|
||||||
|
assert any(
|
||||||
|
"mcp.linear.app" in k for k in keys
|
||||||
|
), f"Expected 'mcp.linear.app' in one key, got {keys}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_mcp_credential_combine_same_server():
|
||||||
|
"""Two MCP credential fields with the same server URL should be combined
|
||||||
|
into one credential entry."""
|
||||||
|
from backend.data.model import CredentialsFieldInfo, CredentialsType
|
||||||
|
from backend.integrations.providers import ProviderName
|
||||||
|
|
||||||
|
oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"])
|
||||||
|
|
||||||
|
field_a = CredentialsFieldInfo(
|
||||||
|
credentials_provider=frozenset([ProviderName.MCP]),
|
||||||
|
credentials_types=oauth2_types,
|
||||||
|
credentials_scopes=None,
|
||||||
|
discriminator="server_url",
|
||||||
|
discriminator_values={"https://mcp.sentry.dev/mcp"},
|
||||||
|
)
|
||||||
|
field_b = CredentialsFieldInfo(
|
||||||
|
credentials_provider=frozenset([ProviderName.MCP]),
|
||||||
|
credentials_types=oauth2_types,
|
||||||
|
credentials_scopes=None,
|
||||||
|
discriminator="server_url",
|
||||||
|
discriminator_values={"https://mcp.sentry.dev/mcp"},
|
||||||
|
)
|
||||||
|
|
||||||
|
combined = CredentialsFieldInfo.combine(
|
||||||
|
(field_a, ("node-a", "credentials")),
|
||||||
|
(field_b, ("node-b", "credentials")),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should produce 1 credential entry (same server URL)
|
||||||
|
assert len(combined) == 1, (
|
||||||
|
f"Expected 1 credential entry for 2 MCP blocks with same server, "
|
||||||
|
f"got {len(combined)}: {list(combined.keys())}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_mcp_credential_combine_no_discriminator_values():
|
||||||
|
"""MCP credential fields without discriminator_values should be merged
|
||||||
|
into a single entry (backwards compat for blocks without server_url set)."""
|
||||||
|
from backend.data.model import CredentialsFieldInfo, CredentialsType
|
||||||
|
from backend.integrations.providers import ProviderName
|
||||||
|
|
||||||
|
oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"])
|
||||||
|
|
||||||
|
field_a = CredentialsFieldInfo(
|
||||||
|
credentials_provider=frozenset([ProviderName.MCP]),
|
||||||
|
credentials_types=oauth2_types,
|
||||||
|
credentials_scopes=None,
|
||||||
|
discriminator="server_url",
|
||||||
|
)
|
||||||
|
field_b = CredentialsFieldInfo(
|
||||||
|
credentials_provider=frozenset([ProviderName.MCP]),
|
||||||
|
credentials_types=oauth2_types,
|
||||||
|
credentials_scopes=None,
|
||||||
|
discriminator="server_url",
|
||||||
|
)
|
||||||
|
|
||||||
|
combined = CredentialsFieldInfo.combine(
|
||||||
|
(field_a, ("node-a", "credentials")),
|
||||||
|
(field_b, ("node-b", "credentials")),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should produce 1 entry (no URL differentiation)
|
||||||
|
assert len(combined) == 1, (
|
||||||
|
f"Expected 1 credential entry for MCP blocks without discriminator_values, "
|
||||||
|
f"got {len(combined)}: {list(combined.keys())}"
|
||||||
|
)
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from pydantic import (
|
|||||||
GetCoreSchemaHandler,
|
GetCoreSchemaHandler,
|
||||||
SecretStr,
|
SecretStr,
|
||||||
field_serializer,
|
field_serializer,
|
||||||
|
model_validator,
|
||||||
)
|
)
|
||||||
from pydantic_core import (
|
from pydantic_core import (
|
||||||
CoreSchema,
|
CoreSchema,
|
||||||
@@ -499,6 +500,25 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
|||||||
provider: CP
|
provider: CP
|
||||||
type: CT
|
type: CT
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _normalize_legacy_provider(cls, data: Any) -> Any:
|
||||||
|
"""Fix ``ProviderName.X`` format from Python 3.13 ``str(Enum)`` bug.
|
||||||
|
|
||||||
|
Python 3.13 changed ``str(StrEnum)`` to return ``"ClassName.MEMBER"``
|
||||||
|
instead of the plain value. Old stored credential references may have
|
||||||
|
``provider: "ProviderName.MCP"`` instead of ``"mcp"``.
|
||||||
|
"""
|
||||||
|
if isinstance(data, dict):
|
||||||
|
prov = data.get("provider", "")
|
||||||
|
if isinstance(prov, str) and prov.startswith("ProviderName."):
|
||||||
|
member = prov.removeprefix("ProviderName.")
|
||||||
|
try:
|
||||||
|
data = {**data, "provider": ProviderName[member].value}
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
return data
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def allowed_providers(cls) -> tuple[ProviderName, ...] | None:
|
def allowed_providers(cls) -> tuple[ProviderName, ...] | None:
|
||||||
return get_args(cls.model_fields["provider"].annotation)
|
return get_args(cls.model_fields["provider"].annotation)
|
||||||
@@ -603,11 +623,18 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
|||||||
] = defaultdict(list)
|
] = defaultdict(list)
|
||||||
|
|
||||||
for field, key in fields:
|
for field, key in fields:
|
||||||
if field.provider == frozenset([ProviderName.HTTP]):
|
if (
|
||||||
# HTTP host-scoped credentials can have different hosts that reqires different credential sets.
|
field.discriminator
|
||||||
# Group by host extracted from the URL
|
and not field.discriminator_mapping
|
||||||
|
and field.discriminator_values
|
||||||
|
):
|
||||||
|
# URL-based discrimination (e.g. HTTP host-scoped, MCP server URL):
|
||||||
|
# Each unique host gets its own credential entry.
|
||||||
|
provider_prefix = next(iter(field.provider))
|
||||||
|
# Use .value for enum types to get the plain string (e.g. "mcp" not "ProviderName.MCP")
|
||||||
|
prefix_str = getattr(provider_prefix, "value", str(provider_prefix))
|
||||||
providers = frozenset(
|
providers = frozenset(
|
||||||
[cast(CP, "http")]
|
[cast(CP, prefix_str)]
|
||||||
+ [
|
+ [
|
||||||
cast(CP, parse_url(str(value)).netloc)
|
cast(CP, parse_url(str(value)).netloc)
|
||||||
for value in field.discriminator_values
|
for value in field.discriminator_values
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@@ -225,6 +226,10 @@ class SyncRabbitMQ(RabbitMQBase):
|
|||||||
class AsyncRabbitMQ(RabbitMQBase):
|
class AsyncRabbitMQ(RabbitMQBase):
|
||||||
"""Asynchronous RabbitMQ client"""
|
"""Asynchronous RabbitMQ client"""
|
||||||
|
|
||||||
|
def __init__(self, config: RabbitMQConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self._reconnect_lock: asyncio.Lock | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_connected(self) -> bool:
|
def is_connected(self) -> bool:
|
||||||
return bool(self._connection and not self._connection.is_closed)
|
return bool(self._connection and not self._connection.is_closed)
|
||||||
@@ -235,7 +240,17 @@ class AsyncRabbitMQ(RabbitMQBase):
|
|||||||
|
|
||||||
@conn_retry("AsyncRabbitMQ", "Acquiring async connection")
|
@conn_retry("AsyncRabbitMQ", "Acquiring async connection")
|
||||||
async def connect(self):
|
async def connect(self):
|
||||||
if self.is_connected:
|
if self.is_connected and self._channel and not self._channel.is_closed:
|
||||||
|
return
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.is_connected
|
||||||
|
and self._connection
|
||||||
|
and (self._channel is None or self._channel.is_closed)
|
||||||
|
):
|
||||||
|
self._channel = await self._connection.channel()
|
||||||
|
await self._channel.set_qos(prefetch_count=1)
|
||||||
|
await self.declare_infrastructure()
|
||||||
return
|
return
|
||||||
|
|
||||||
self._connection = await aio_pika.connect_robust(
|
self._connection = await aio_pika.connect_robust(
|
||||||
@@ -291,24 +306,46 @@ class AsyncRabbitMQ(RabbitMQBase):
|
|||||||
exchange, routing_key=queue.routing_key or queue.name
|
exchange, routing_key=queue.routing_key or queue.name
|
||||||
)
|
)
|
||||||
|
|
||||||
@func_retry
|
@property
|
||||||
async def publish_message(
|
def _lock(self) -> asyncio.Lock:
|
||||||
|
if self._reconnect_lock is None:
|
||||||
|
self._reconnect_lock = asyncio.Lock()
|
||||||
|
return self._reconnect_lock
|
||||||
|
|
||||||
|
async def _ensure_channel(self) -> aio_pika.abc.AbstractChannel:
|
||||||
|
"""Get a valid channel, reconnecting if the current one is stale.
|
||||||
|
|
||||||
|
Uses a lock to prevent concurrent reconnection attempts from racing.
|
||||||
|
"""
|
||||||
|
if self.is_ready:
|
||||||
|
return self._channel # type: ignore # is_ready guarantees non-None
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
# Double-check after acquiring lock
|
||||||
|
if self.is_ready:
|
||||||
|
return self._channel # type: ignore
|
||||||
|
|
||||||
|
self._channel = None
|
||||||
|
await self.connect()
|
||||||
|
|
||||||
|
if self._channel is None:
|
||||||
|
raise RuntimeError("Channel should be established after connect")
|
||||||
|
|
||||||
|
return self._channel
|
||||||
|
|
||||||
|
async def _publish_once(
|
||||||
self,
|
self,
|
||||||
routing_key: str,
|
routing_key: str,
|
||||||
message: str,
|
message: str,
|
||||||
exchange: Optional[Exchange] = None,
|
exchange: Optional[Exchange] = None,
|
||||||
persistent: bool = True,
|
persistent: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
if not self.is_ready:
|
channel = await self._ensure_channel()
|
||||||
await self.connect()
|
|
||||||
|
|
||||||
if self._channel is None:
|
|
||||||
raise RuntimeError("Channel should be established after connect")
|
|
||||||
|
|
||||||
if exchange:
|
if exchange:
|
||||||
exchange_obj = await self._channel.get_exchange(exchange.name)
|
exchange_obj = await channel.get_exchange(exchange.name)
|
||||||
else:
|
else:
|
||||||
exchange_obj = self._channel.default_exchange
|
exchange_obj = channel.default_exchange
|
||||||
|
|
||||||
await exchange_obj.publish(
|
await exchange_obj.publish(
|
||||||
aio_pika.Message(
|
aio_pika.Message(
|
||||||
@@ -322,9 +359,23 @@ class AsyncRabbitMQ(RabbitMQBase):
|
|||||||
routing_key=routing_key,
|
routing_key=routing_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@func_retry
|
||||||
|
async def publish_message(
|
||||||
|
self,
|
||||||
|
routing_key: str,
|
||||||
|
message: str,
|
||||||
|
exchange: Optional[Exchange] = None,
|
||||||
|
persistent: bool = True,
|
||||||
|
) -> None:
|
||||||
|
try:
|
||||||
|
await self._publish_once(routing_key, message, exchange, persistent)
|
||||||
|
except aio_pika.exceptions.ChannelInvalidStateError:
|
||||||
|
logger.warning(
|
||||||
|
"RabbitMQ channel invalid, forcing reconnect and retrying publish"
|
||||||
|
)
|
||||||
|
async with self._lock:
|
||||||
|
self._channel = None
|
||||||
|
await self._publish_once(routing_key, message, exchange, persistent)
|
||||||
|
|
||||||
async def get_channel(self) -> aio_pika.abc.AbstractChannel:
|
async def get_channel(self) -> aio_pika.abc.AbstractChannel:
|
||||||
if not self.is_ready:
|
return await self._ensure_channel()
|
||||||
await self.connect()
|
|
||||||
if self._channel is None:
|
|
||||||
raise RuntimeError("Channel should be established after connect")
|
|
||||||
return self._channel
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from redis.asyncio.lock import Lock as AsyncRedisLock
|
|||||||
|
|
||||||
from backend.blocks.agent import AgentExecutorBlock
|
from backend.blocks.agent import AgentExecutorBlock
|
||||||
from backend.blocks.io import AgentOutputBlock
|
from backend.blocks.io import AgentOutputBlock
|
||||||
|
from backend.blocks.mcp.block import MCPToolBlock
|
||||||
from backend.data import redis_client as redis
|
from backend.data import redis_client as redis
|
||||||
from backend.data.block import (
|
from backend.data.block import (
|
||||||
BlockInput,
|
BlockInput,
|
||||||
@@ -229,6 +230,10 @@ async def execute_node(
|
|||||||
_input_data.nodes_input_masks = nodes_input_masks
|
_input_data.nodes_input_masks = nodes_input_masks
|
||||||
_input_data.user_id = user_id
|
_input_data.user_id = user_id
|
||||||
input_data = _input_data.model_dump()
|
input_data = _input_data.model_dump()
|
||||||
|
elif isinstance(node_block, MCPToolBlock):
|
||||||
|
_mcp_data = MCPToolBlock.Input(**node.input_default)
|
||||||
|
_mcp_data.tool_arguments = input_data.get("tool_arguments", {})
|
||||||
|
input_data = _mcp_data.model_dump()
|
||||||
data.inputs = input_data
|
data.inputs = input_data
|
||||||
|
|
||||||
# Execute the node
|
# Execute the node
|
||||||
@@ -265,8 +270,34 @@ async def execute_node(
|
|||||||
|
|
||||||
# Handle regular credentials fields
|
# Handle regular credentials fields
|
||||||
for field_name, input_type in input_model.get_credentials_fields().items():
|
for field_name, input_type in input_model.get_credentials_fields().items():
|
||||||
credentials_meta = input_type(**input_data[field_name])
|
field_value = input_data.get(field_name)
|
||||||
credentials, lock = await creds_manager.acquire(user_id, credentials_meta.id)
|
if not field_value or (
|
||||||
|
isinstance(field_value, dict) and not field_value.get("id")
|
||||||
|
):
|
||||||
|
# No credentials configured — nullify so JSON schema validation
|
||||||
|
# doesn't choke on the empty default `{}`.
|
||||||
|
input_data[field_name] = None
|
||||||
|
continue # Block runs without credentials
|
||||||
|
|
||||||
|
credentials_meta = input_type(**field_value)
|
||||||
|
# Write normalized values back so JSON schema validation also passes
|
||||||
|
# (model_validator may have fixed legacy formats like "ProviderName.MCP")
|
||||||
|
input_data[field_name] = credentials_meta.model_dump(mode="json")
|
||||||
|
try:
|
||||||
|
credentials, lock = await creds_manager.acquire(
|
||||||
|
user_id, credentials_meta.id
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
# Credential was deleted or doesn't exist.
|
||||||
|
# If the field has a default, run without credentials.
|
||||||
|
if input_model.model_fields[field_name].default is not None:
|
||||||
|
log_metadata.warning(
|
||||||
|
f"Credentials #{credentials_meta.id} not found, "
|
||||||
|
"running without (field has default)"
|
||||||
|
)
|
||||||
|
input_data[field_name] = input_model.model_fields[field_name].default
|
||||||
|
continue
|
||||||
|
raise
|
||||||
creds_locks.append(lock)
|
creds_locks.append(lock)
|
||||||
extra_exec_kwargs[field_name] = credentials
|
extra_exec_kwargs[field_name] = credentials
|
||||||
|
|
||||||
|
|||||||
@@ -265,7 +265,13 @@ async def _validate_node_input_credentials(
|
|||||||
# Track if any credential field is missing for this node
|
# Track if any credential field is missing for this node
|
||||||
has_missing_credentials = False
|
has_missing_credentials = False
|
||||||
|
|
||||||
|
# A credential field is optional if the node metadata says so, or if
|
||||||
|
# the block schema declares a default for the field.
|
||||||
|
required_fields = block.input_schema.get_required_fields()
|
||||||
|
is_creds_optional = node.credentials_optional
|
||||||
|
|
||||||
for field_name, credentials_meta_type in credentials_fields.items():
|
for field_name, credentials_meta_type in credentials_fields.items():
|
||||||
|
field_is_optional = is_creds_optional or field_name not in required_fields
|
||||||
try:
|
try:
|
||||||
# Check nodes_input_masks first, then input_default
|
# Check nodes_input_masks first, then input_default
|
||||||
field_value = None
|
field_value = None
|
||||||
@@ -278,7 +284,7 @@ async def _validate_node_input_credentials(
|
|||||||
elif field_name in node.input_default:
|
elif field_name in node.input_default:
|
||||||
# For optional credentials, don't use input_default - treat as missing
|
# For optional credentials, don't use input_default - treat as missing
|
||||||
# This prevents stale credential IDs from failing validation
|
# This prevents stale credential IDs from failing validation
|
||||||
if node.credentials_optional:
|
if field_is_optional:
|
||||||
field_value = None
|
field_value = None
|
||||||
else:
|
else:
|
||||||
field_value = node.input_default[field_name]
|
field_value = node.input_default[field_name]
|
||||||
@@ -288,8 +294,8 @@ async def _validate_node_input_credentials(
|
|||||||
isinstance(field_value, dict) and not field_value.get("id")
|
isinstance(field_value, dict) and not field_value.get("id")
|
||||||
):
|
):
|
||||||
has_missing_credentials = True
|
has_missing_credentials = True
|
||||||
# If node has credentials_optional flag, mark for skipping instead of error
|
# If credential field is optional, skip instead of error
|
||||||
if node.credentials_optional:
|
if field_is_optional:
|
||||||
continue # Don't add error, will be marked for skip after loop
|
continue # Don't add error, will be marked for skip after loop
|
||||||
else:
|
else:
|
||||||
credential_errors[node.id][
|
credential_errors[node.id][
|
||||||
@@ -339,16 +345,16 @@ async def _validate_node_input_credentials(
|
|||||||
] = "Invalid credentials: type/provider mismatch"
|
] = "Invalid credentials: type/provider mismatch"
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# If node has optional credentials and any are missing, mark for skipping
|
# If node has optional credentials and any are missing, allow running without.
|
||||||
# But only if there are no other errors for this node
|
# The executor will pass credentials=None to the block's run().
|
||||||
if (
|
if (
|
||||||
has_missing_credentials
|
has_missing_credentials
|
||||||
and node.credentials_optional
|
and is_creds_optional
|
||||||
and node.id not in credential_errors
|
and node.id not in credential_errors
|
||||||
):
|
):
|
||||||
nodes_to_skip.add(node.id)
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Node #{node.id} will be skipped: optional credentials not configured"
|
f"Node #{node.id}: optional credentials not configured, "
|
||||||
|
"running without"
|
||||||
)
|
)
|
||||||
|
|
||||||
return credential_errors, nodes_to_skip
|
return credential_errors, nodes_to_skip
|
||||||
|
|||||||
@@ -495,6 +495,7 @@ async def test_validate_node_input_credentials_returns_nodes_to_skip(
|
|||||||
mock_block.input_schema.get_credentials_fields.return_value = {
|
mock_block.input_schema.get_credentials_fields.return_value = {
|
||||||
"credentials": mock_credentials_field_type
|
"credentials": mock_credentials_field_type
|
||||||
}
|
}
|
||||||
|
mock_block.input_schema.get_required_fields.return_value = {"credentials"}
|
||||||
mock_node.block = mock_block
|
mock_node.block = mock_block
|
||||||
|
|
||||||
# Create mock graph
|
# Create mock graph
|
||||||
@@ -508,8 +509,8 @@ async def test_validate_node_input_credentials_returns_nodes_to_skip(
|
|||||||
nodes_input_masks=None,
|
nodes_input_masks=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Node should be in nodes_to_skip, not in errors
|
# Node should NOT be in nodes_to_skip (runs without credentials) and not in errors
|
||||||
assert mock_node.id in nodes_to_skip
|
assert mock_node.id not in nodes_to_skip
|
||||||
assert mock_node.id not in errors
|
assert mock_node.id not in errors
|
||||||
|
|
||||||
|
|
||||||
@@ -535,6 +536,7 @@ async def test_validate_node_input_credentials_required_missing_creds_error(
|
|||||||
mock_block.input_schema.get_credentials_fields.return_value = {
|
mock_block.input_schema.get_credentials_fields.return_value = {
|
||||||
"credentials": mock_credentials_field_type
|
"credentials": mock_credentials_field_type
|
||||||
}
|
}
|
||||||
|
mock_block.input_schema.get_required_fields.return_value = {"credentials"}
|
||||||
mock_node.block = mock_block
|
mock_node.block = mock_block
|
||||||
|
|
||||||
# Create mock graph
|
# Create mock graph
|
||||||
|
|||||||
@@ -22,6 +22,27 @@ from backend.util.settings import Settings
|
|||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|
||||||
|
|
||||||
|
def provider_matches(stored: str, expected: str) -> bool:
|
||||||
|
"""Compare provider strings, handling Python 3.13 ``str(StrEnum)`` bug.
|
||||||
|
|
||||||
|
On Python 3.13, ``str(ProviderName.MCP)`` returns ``"ProviderName.MCP"``
|
||||||
|
instead of ``"mcp"``. OAuth states persisted with the buggy format need
|
||||||
|
to match when ``expected`` is the canonical value (e.g. ``"mcp"``).
|
||||||
|
"""
|
||||||
|
if stored == expected:
|
||||||
|
return True
|
||||||
|
if stored.startswith("ProviderName."):
|
||||||
|
member = stored.removeprefix("ProviderName.")
|
||||||
|
from backend.integrations.providers import ProviderName
|
||||||
|
|
||||||
|
try:
|
||||||
|
return ProviderName[member].value == expected
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
# This is an overrride since ollama doesn't actually require an API key, but the creddential system enforces one be attached
|
# This is an overrride since ollama doesn't actually require an API key, but the creddential system enforces one be attached
|
||||||
ollama_credentials = APIKeyCredentials(
|
ollama_credentials = APIKeyCredentials(
|
||||||
id="744fdc56-071a-4761-b5a5-0af0ce10a2b5",
|
id="744fdc56-071a-4761-b5a5-0af0ce10a2b5",
|
||||||
@@ -389,7 +410,7 @@ class IntegrationCredentialsStore:
|
|||||||
self, user_id: str, provider: str
|
self, user_id: str, provider: str
|
||||||
) -> list[Credentials]:
|
) -> list[Credentials]:
|
||||||
credentials = await self.get_all_creds(user_id)
|
credentials = await self.get_all_creds(user_id)
|
||||||
return [c for c in credentials if c.provider == provider]
|
return [c for c in credentials if provider_matches(c.provider, provider)]
|
||||||
|
|
||||||
async def get_authorized_providers(self, user_id: str) -> list[str]:
|
async def get_authorized_providers(self, user_id: str) -> list[str]:
|
||||||
credentials = await self.get_all_creds(user_id)
|
credentials = await self.get_all_creds(user_id)
|
||||||
@@ -485,17 +506,6 @@ class IntegrationCredentialsStore:
|
|||||||
async with self.edit_user_integrations(user_id) as user_integrations:
|
async with self.edit_user_integrations(user_id) as user_integrations:
|
||||||
user_integrations.oauth_states.append(state)
|
user_integrations.oauth_states.append(state)
|
||||||
|
|
||||||
async with await self.locked_user_integrations(user_id):
|
|
||||||
|
|
||||||
user_integrations = await self._get_user_integrations(user_id)
|
|
||||||
oauth_states = user_integrations.oauth_states
|
|
||||||
oauth_states.append(state)
|
|
||||||
user_integrations.oauth_states = oauth_states
|
|
||||||
|
|
||||||
await self.db_manager.update_user_integrations(
|
|
||||||
user_id=user_id, data=user_integrations
|
|
||||||
)
|
|
||||||
|
|
||||||
return token, code_challenge
|
return token, code_challenge
|
||||||
|
|
||||||
def _generate_code_challenge(self) -> tuple[str, str]:
|
def _generate_code_challenge(self) -> tuple[str, str]:
|
||||||
@@ -521,7 +531,7 @@ class IntegrationCredentialsStore:
|
|||||||
state
|
state
|
||||||
for state in oauth_states
|
for state in oauth_states
|
||||||
if secrets.compare_digest(state.token, token)
|
if secrets.compare_digest(state.token, token)
|
||||||
and state.provider == provider
|
and provider_matches(state.provider, provider)
|
||||||
and state.expires_at > now.timestamp()
|
and state.expires_at > now.timestamp()
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
|
|||||||
@@ -9,7 +9,10 @@ from redis.asyncio.lock import Lock as AsyncRedisLock
|
|||||||
|
|
||||||
from backend.data.model import Credentials, OAuth2Credentials
|
from backend.data.model import Credentials, OAuth2Credentials
|
||||||
from backend.data.redis_client import get_redis_async
|
from backend.data.redis_client import get_redis_async
|
||||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
from backend.integrations.credentials_store import (
|
||||||
|
IntegrationCredentialsStore,
|
||||||
|
provider_matches,
|
||||||
|
)
|
||||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.util.exceptions import MissingConfigError
|
from backend.util.exceptions import MissingConfigError
|
||||||
@@ -137,6 +140,9 @@ class IntegrationCredentialsManager:
|
|||||||
self, user_id: str, credentials: OAuth2Credentials, lock: bool = True
|
self, user_id: str, credentials: OAuth2Credentials, lock: bool = True
|
||||||
) -> OAuth2Credentials:
|
) -> OAuth2Credentials:
|
||||||
async with self._locked(user_id, credentials.id, "refresh"):
|
async with self._locked(user_id, credentials.id, "refresh"):
|
||||||
|
if provider_matches(credentials.provider, ProviderName.MCP.value):
|
||||||
|
oauth_handler = create_mcp_oauth_handler(credentials)
|
||||||
|
else:
|
||||||
oauth_handler = await _get_provider_oauth_handler(credentials.provider)
|
oauth_handler = await _get_provider_oauth_handler(credentials.provider)
|
||||||
if oauth_handler.needs_refresh(credentials):
|
if oauth_handler.needs_refresh(credentials):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -236,3 +242,25 @@ async def _get_provider_oauth_handler(provider_name_str: str) -> "BaseOAuthHandl
|
|||||||
client_secret=client_secret,
|
client_secret=client_secret,
|
||||||
redirect_uri=f"{frontend_base_url}/auth/integrations/oauth_callback",
|
redirect_uri=f"{frontend_base_url}/auth/integrations/oauth_callback",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_mcp_oauth_handler(
|
||||||
|
credentials: OAuth2Credentials,
|
||||||
|
) -> "BaseOAuthHandler":
|
||||||
|
"""Create an MCPOAuthHandler from credential metadata for token refresh.
|
||||||
|
|
||||||
|
MCP OAuth handlers have dynamic endpoints discovered per-server, so they
|
||||||
|
can't be registered as singletons in HANDLERS_BY_NAME. Instead, the handler
|
||||||
|
is reconstructed from metadata stored on the credential during initial auth.
|
||||||
|
"""
|
||||||
|
from backend.blocks.mcp.oauth import MCPOAuthHandler
|
||||||
|
|
||||||
|
meta = credentials.metadata or {}
|
||||||
|
return MCPOAuthHandler(
|
||||||
|
client_id=meta.get("mcp_client_id", ""),
|
||||||
|
client_secret=meta.get("mcp_client_secret", ""),
|
||||||
|
redirect_uri="", # Not needed for token refresh
|
||||||
|
authorize_url="", # Not needed for token refresh
|
||||||
|
token_url=meta.get("mcp_token_url", ""),
|
||||||
|
resource_url=meta.get("mcp_resource_url"),
|
||||||
|
)
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ class ProviderName(str, Enum):
|
|||||||
IDEOGRAM = "ideogram"
|
IDEOGRAM = "ideogram"
|
||||||
JINA = "jina"
|
JINA = "jina"
|
||||||
LLAMA_API = "llama_api"
|
LLAMA_API = "llama_api"
|
||||||
|
MCP = "mcp"
|
||||||
MEDIUM = "medium"
|
MEDIUM = "medium"
|
||||||
MEM0 = "mem0"
|
MEM0 = "mem0"
|
||||||
NOTION = "notion"
|
NOTION = "notion"
|
||||||
|
|||||||
@@ -50,6 +50,21 @@ async def _on_graph_activate(graph: "BaseGraph | GraphModel", user_id: str):
|
|||||||
if (
|
if (
|
||||||
creds_meta := new_node.input_default.get(creds_field_name)
|
creds_meta := new_node.input_default.get(creds_field_name)
|
||||||
) and not await get_credentials(creds_meta["id"]):
|
) and not await get_credentials(creds_meta["id"]):
|
||||||
|
# If the credential field is optional (has a default in the
|
||||||
|
# schema, or node metadata marks it optional), clear the stale
|
||||||
|
# reference instead of blocking the save.
|
||||||
|
creds_field_optional = (
|
||||||
|
new_node.credentials_optional
|
||||||
|
or creds_field_name not in block_input_schema.get_required_fields()
|
||||||
|
)
|
||||||
|
if creds_field_optional:
|
||||||
|
new_node.input_default[creds_field_name] = {}
|
||||||
|
logger.warning(
|
||||||
|
f"Node #{new_node.id}: cleared stale optional "
|
||||||
|
f"credentials #{creds_meta['id']} for "
|
||||||
|
f"'{creds_field_name}'"
|
||||||
|
)
|
||||||
|
continue
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Node #{new_node.id} input '{creds_field_name}' updated with "
|
f"Node #{new_node.id} input '{creds_field_name}' updated with "
|
||||||
f"non-existent credentials #{creds_meta['id']}"
|
f"non-existent credentials #{creds_meta['id']}"
|
||||||
|
|||||||
@@ -101,7 +101,7 @@ class HostResolver(abc.AbstractResolver):
|
|||||||
def __init__(self, ssl_hostname: str, ip_addresses: list[str]):
|
def __init__(self, ssl_hostname: str, ip_addresses: list[str]):
|
||||||
self.ssl_hostname = ssl_hostname
|
self.ssl_hostname = ssl_hostname
|
||||||
self.ip_addresses = ip_addresses
|
self.ip_addresses = ip_addresses
|
||||||
self._default = aiohttp.AsyncResolver()
|
self._default = aiohttp.ThreadedResolver()
|
||||||
|
|
||||||
async def resolve(self, host, port=0, family=socket.AF_INET):
|
async def resolve(self, host, port=0, family=socket.AF_INET):
|
||||||
if host == self.ssl_hostname:
|
if host == self.ssl_hostname:
|
||||||
@@ -467,7 +467,7 @@ class Requests:
|
|||||||
resolver = HostResolver(ssl_hostname=hostname, ip_addresses=ip_addresses)
|
resolver = HostResolver(ssl_hostname=hostname, ip_addresses=ip_addresses)
|
||||||
ssl_context = ssl.create_default_context()
|
ssl_context = ssl.create_default_context()
|
||||||
connector = aiohttp.TCPConnector(resolver=resolver, ssl=ssl_context)
|
connector = aiohttp.TCPConnector(resolver=resolver, ssl=ssl_context)
|
||||||
session_kwargs = {}
|
session_kwargs: dict = {}
|
||||||
if connector:
|
if connector:
|
||||||
session_kwargs["connector"] = connector
|
session_kwargs["connector"] = connector
|
||||||
|
|
||||||
|
|||||||
@@ -25,8 +25,12 @@ RUN if [ -f .env.production ]; then \
|
|||||||
cp .env.default .env; \
|
cp .env.default .env; \
|
||||||
fi
|
fi
|
||||||
RUN pnpm run generate:api
|
RUN pnpm run generate:api
|
||||||
|
# Disable source-map generation in Docker builds to halve webpack memory usage.
|
||||||
|
# Source maps are only useful when SENTRY_AUTH_TOKEN is set (Vercel deploys);
|
||||||
|
# the Docker image never uploads them, so generating them just wastes RAM.
|
||||||
|
ENV NEXT_PUBLIC_SOURCEMAPS="false"
|
||||||
# In CI, we want NEXT_PUBLIC_PW_TEST=true during build so Next.js inlines it
|
# In CI, we want NEXT_PUBLIC_PW_TEST=true during build so Next.js inlines it
|
||||||
RUN if [ "$NEXT_PUBLIC_PW_TEST" = "true" ]; then NEXT_PUBLIC_PW_TEST=true NODE_OPTIONS="--max-old-space-size=4096" pnpm build; else NODE_OPTIONS="--max-old-space-size=4096" pnpm build; fi
|
RUN if [ "$NEXT_PUBLIC_PW_TEST" = "true" ]; then NEXT_PUBLIC_PW_TEST=true NODE_OPTIONS="--max-old-space-size=8192" pnpm build; else NODE_OPTIONS="--max-old-space-size=8192" pnpm build; fi
|
||||||
|
|
||||||
# Prod stage - based on NextJS reference Dockerfile https://github.com/vercel/next.js/blob/64271354533ed16da51be5dce85f0dbd15f17517/examples/with-docker/Dockerfile
|
# Prod stage - based on NextJS reference Dockerfile https://github.com/vercel/next.js/blob/64271354533ed16da51be5dce85f0dbd15f17517/examples/with-docker/Dockerfile
|
||||||
FROM node:21-alpine AS prod
|
FROM node:21-alpine AS prod
|
||||||
|
|||||||
@@ -1,8 +1,12 @@
|
|||||||
import { withSentryConfig } from "@sentry/nextjs";
|
import { withSentryConfig } from "@sentry/nextjs";
|
||||||
|
|
||||||
|
// Allow Docker builds to skip source-map generation (halves memory usage).
|
||||||
|
// Defaults to true so Vercel/local builds are unaffected.
|
||||||
|
const enableSourceMaps = process.env.NEXT_PUBLIC_SOURCEMAPS !== "false";
|
||||||
|
|
||||||
/** @type {import('next').NextConfig} */
|
/** @type {import('next').NextConfig} */
|
||||||
const nextConfig = {
|
const nextConfig = {
|
||||||
productionBrowserSourceMaps: true,
|
productionBrowserSourceMaps: enableSourceMaps,
|
||||||
// Externalize OpenTelemetry packages to fix Turbopack HMR issues
|
// Externalize OpenTelemetry packages to fix Turbopack HMR issues
|
||||||
serverExternalPackages: [
|
serverExternalPackages: [
|
||||||
"@opentelemetry/instrumentation",
|
"@opentelemetry/instrumentation",
|
||||||
@@ -14,9 +18,37 @@ const nextConfig = {
|
|||||||
serverActions: {
|
serverActions: {
|
||||||
bodySizeLimit: "256mb",
|
bodySizeLimit: "256mb",
|
||||||
},
|
},
|
||||||
// Increase body size limit for API routes (file uploads) - 256MB to match backend limit
|
|
||||||
proxyClientMaxBodySize: "256mb",
|
|
||||||
middlewareClientMaxBodySize: "256mb",
|
middlewareClientMaxBodySize: "256mb",
|
||||||
|
// Limit parallel webpack workers to reduce peak memory during builds.
|
||||||
|
cpus: 2,
|
||||||
|
},
|
||||||
|
// Work around cssnano "Invalid array length" bug in Next.js's bundled
|
||||||
|
// cssnano-simple comment parser when processing very large CSS chunks.
|
||||||
|
// CSS is still bundled correctly; gzip handles most of the size savings anyway.
|
||||||
|
webpack: (config, { dev }) => {
|
||||||
|
if (!dev) {
|
||||||
|
// Next.js adds CssMinimizerPlugin internally (after user config), so we
|
||||||
|
// can't filter it from config.plugins. Instead, intercept the webpack
|
||||||
|
// compilation hooks and replace the buggy plugin's tap with a no-op.
|
||||||
|
config.plugins.push({
|
||||||
|
apply(compiler) {
|
||||||
|
compiler.hooks.compilation.tap(
|
||||||
|
"DisableCssMinimizer",
|
||||||
|
(compilation) => {
|
||||||
|
compilation.hooks.processAssets.intercept({
|
||||||
|
register: (tap) => {
|
||||||
|
if (tap.name === "CssMinimizerPlugin") {
|
||||||
|
return { ...tap, fn: async () => {} };
|
||||||
|
}
|
||||||
|
return tap;
|
||||||
|
},
|
||||||
|
});
|
||||||
|
},
|
||||||
|
);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
return config;
|
||||||
},
|
},
|
||||||
images: {
|
images: {
|
||||||
domains: [
|
domains: [
|
||||||
@@ -54,9 +86,16 @@ const nextConfig = {
|
|||||||
transpilePackages: ["geist"],
|
transpilePackages: ["geist"],
|
||||||
};
|
};
|
||||||
|
|
||||||
const isDevelopmentBuild = process.env.NODE_ENV !== "production";
|
// Only run the Sentry webpack plugin when we can actually upload source maps
|
||||||
|
// (i.e. on Vercel with SENTRY_AUTH_TOKEN set). The Sentry *runtime* SDK
|
||||||
|
// (imported in app code) still captures errors without the plugin.
|
||||||
|
// Skipping the plugin saves ~1 GB of peak memory during `next build`.
|
||||||
|
const skipSentryPlugin =
|
||||||
|
process.env.NODE_ENV !== "production" ||
|
||||||
|
!enableSourceMaps ||
|
||||||
|
!process.env.SENTRY_AUTH_TOKEN;
|
||||||
|
|
||||||
export default isDevelopmentBuild
|
export default skipSentryPlugin
|
||||||
? nextConfig
|
? nextConfig
|
||||||
: withSentryConfig(nextConfig, {
|
: withSentryConfig(nextConfig, {
|
||||||
// For all available options, see:
|
// For all available options, see:
|
||||||
@@ -96,7 +135,7 @@ export default isDevelopmentBuild
|
|||||||
|
|
||||||
// This helps Sentry with sourcemaps... https://docs.sentry.io/platforms/javascript/guides/nextjs/sourcemaps/
|
// This helps Sentry with sourcemaps... https://docs.sentry.io/platforms/javascript/guides/nextjs/sourcemaps/
|
||||||
sourcemaps: {
|
sourcemaps: {
|
||||||
disable: false,
|
disable: !enableSourceMaps,
|
||||||
assets: [".next/**/*.js", ".next/**/*.js.map"],
|
assets: [".next/**/*.js", ".next/**/*.js.map"],
|
||||||
ignore: ["**/node_modules/**"],
|
ignore: ["**/node_modules/**"],
|
||||||
deleteSourcemapsAfterUpload: false, // Source is public anyway :)
|
deleteSourcemapsAfterUpload: false, // Source is public anyway :)
|
||||||
|
|||||||
@@ -7,7 +7,7 @@
|
|||||||
},
|
},
|
||||||
"scripts": {
|
"scripts": {
|
||||||
"dev": "pnpm run generate:api:force && next dev --turbo",
|
"dev": "pnpm run generate:api:force && next dev --turbo",
|
||||||
"build": "next build",
|
"build": "cross-env NODE_OPTIONS=--max-old-space-size=16384 next build",
|
||||||
"start": "next start",
|
"start": "next start",
|
||||||
"start:standalone": "cd .next/standalone && node server.js",
|
"start:standalone": "cd .next/standalone && node server.js",
|
||||||
"lint": "next lint && prettier --check .",
|
"lint": "next lint && prettier --check .",
|
||||||
@@ -30,6 +30,7 @@
|
|||||||
"defaults"
|
"defaults"
|
||||||
],
|
],
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
|
"@ai-sdk/react": "3.0.61",
|
||||||
"@faker-js/faker": "10.0.0",
|
"@faker-js/faker": "10.0.0",
|
||||||
"@hookform/resolvers": "5.2.2",
|
"@hookform/resolvers": "5.2.2",
|
||||||
"@next/third-parties": "15.4.6",
|
"@next/third-parties": "15.4.6",
|
||||||
@@ -60,6 +61,10 @@
|
|||||||
"@rjsf/utils": "6.1.2",
|
"@rjsf/utils": "6.1.2",
|
||||||
"@rjsf/validator-ajv8": "6.1.2",
|
"@rjsf/validator-ajv8": "6.1.2",
|
||||||
"@sentry/nextjs": "10.27.0",
|
"@sentry/nextjs": "10.27.0",
|
||||||
|
"@streamdown/cjk": "1.0.1",
|
||||||
|
"@streamdown/code": "1.0.1",
|
||||||
|
"@streamdown/math": "1.0.1",
|
||||||
|
"@streamdown/mermaid": "1.0.1",
|
||||||
"@supabase/ssr": "0.7.0",
|
"@supabase/ssr": "0.7.0",
|
||||||
"@supabase/supabase-js": "2.78.0",
|
"@supabase/supabase-js": "2.78.0",
|
||||||
"@tanstack/react-query": "5.90.6",
|
"@tanstack/react-query": "5.90.6",
|
||||||
@@ -68,6 +73,7 @@
|
|||||||
"@vercel/analytics": "1.5.0",
|
"@vercel/analytics": "1.5.0",
|
||||||
"@vercel/speed-insights": "1.2.0",
|
"@vercel/speed-insights": "1.2.0",
|
||||||
"@xyflow/react": "12.9.2",
|
"@xyflow/react": "12.9.2",
|
||||||
|
"ai": "6.0.59",
|
||||||
"boring-avatars": "1.11.2",
|
"boring-avatars": "1.11.2",
|
||||||
"class-variance-authority": "0.7.1",
|
"class-variance-authority": "0.7.1",
|
||||||
"clsx": "2.1.1",
|
"clsx": "2.1.1",
|
||||||
@@ -87,7 +93,6 @@
|
|||||||
"launchdarkly-react-client-sdk": "3.9.0",
|
"launchdarkly-react-client-sdk": "3.9.0",
|
||||||
"lodash": "4.17.21",
|
"lodash": "4.17.21",
|
||||||
"lucide-react": "0.552.0",
|
"lucide-react": "0.552.0",
|
||||||
"moment": "2.30.1",
|
|
||||||
"next": "15.4.10",
|
"next": "15.4.10",
|
||||||
"next-themes": "0.4.6",
|
"next-themes": "0.4.6",
|
||||||
"nuqs": "2.7.2",
|
"nuqs": "2.7.2",
|
||||||
@@ -112,9 +117,11 @@
|
|||||||
"remark-math": "6.0.0",
|
"remark-math": "6.0.0",
|
||||||
"shepherd.js": "14.5.1",
|
"shepherd.js": "14.5.1",
|
||||||
"sonner": "2.0.7",
|
"sonner": "2.0.7",
|
||||||
|
"streamdown": "2.1.0",
|
||||||
"tailwind-merge": "2.6.0",
|
"tailwind-merge": "2.6.0",
|
||||||
"tailwind-scrollbar": "3.1.0",
|
"tailwind-scrollbar": "3.1.0",
|
||||||
"tailwindcss-animate": "1.0.7",
|
"tailwindcss-animate": "1.0.7",
|
||||||
|
"use-stick-to-bottom": "1.1.2",
|
||||||
"uuid": "11.1.0",
|
"uuid": "11.1.0",
|
||||||
"vaul": "1.1.2",
|
"vaul": "1.1.2",
|
||||||
"zod": "3.25.76",
|
"zod": "3.25.76",
|
||||||
@@ -172,7 +179,8 @@
|
|||||||
},
|
},
|
||||||
"pnpm": {
|
"pnpm": {
|
||||||
"overrides": {
|
"overrides": {
|
||||||
"@opentelemetry/instrumentation": "0.209.0"
|
"@opentelemetry/instrumentation": "0.209.0",
|
||||||
|
"lodash-es": "4.17.23"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"packageManager": "pnpm@10.20.0+sha512.cf9998222162dd85864d0a8102e7892e7ba4ceadebbf5a31f9c2fce48dfce317a9c53b9f6464d1ef9042cba2e02ae02a9f7c143a2b438cd93c91840f0192b9dd"
|
"packageManager": "pnpm@10.20.0+sha512.cf9998222162dd85864d0a8102e7892e7ba4ceadebbf5a31f9c2fce48dfce317a9c53b9f6464d1ef9042cba2e02ae02a9f7c143a2b438cd93c91840f0192b9dd"
|
||||||
|
|||||||
1180
autogpt_platform/frontend/pnpm-lock.yaml
generated
1180
autogpt_platform/frontend/pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,96 @@
|
|||||||
|
import { NextResponse } from "next/server";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Safely encode a value as JSON for embedding in a script tag.
|
||||||
|
* Escapes characters that could break out of the script context to prevent XSS.
|
||||||
|
*/
|
||||||
|
function safeJsonStringify(value: unknown): string {
|
||||||
|
return JSON.stringify(value)
|
||||||
|
.replace(/</g, "\\u003c")
|
||||||
|
.replace(/>/g, "\\u003e")
|
||||||
|
.replace(/&/g, "\\u0026");
|
||||||
|
}
|
||||||
|
|
||||||
|
// MCP-specific OAuth callback route.
|
||||||
|
//
|
||||||
|
// Unlike the generic oauth_callback which relies on window.opener.postMessage,
|
||||||
|
// this route uses BroadcastChannel as the PRIMARY communication method.
|
||||||
|
// This is critical because cross-origin OAuth flows (e.g. Sentry → localhost)
|
||||||
|
// often lose window.opener due to COOP (Cross-Origin-Opener-Policy) headers.
|
||||||
|
//
|
||||||
|
// BroadcastChannel works across all same-origin tabs/popups regardless of opener.
|
||||||
|
export async function GET(request: Request) {
|
||||||
|
const { searchParams } = new URL(request.url);
|
||||||
|
const code = searchParams.get("code");
|
||||||
|
const state = searchParams.get("state");
|
||||||
|
|
||||||
|
const success = Boolean(code && state);
|
||||||
|
const message = success
|
||||||
|
? { success: true, code, state }
|
||||||
|
: {
|
||||||
|
success: false,
|
||||||
|
message: `Missing parameters: ${searchParams.toString()}`,
|
||||||
|
};
|
||||||
|
|
||||||
|
return new NextResponse(
|
||||||
|
`<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head><title>MCP Sign-in</title></head>
|
||||||
|
<body style="font-family: system-ui, -apple-system, sans-serif; display: flex; align-items: center; justify-content: center; min-height: 100vh; margin: 0; background: #f9fafb;">
|
||||||
|
<div style="text-align: center; max-width: 400px; padding: 2rem;">
|
||||||
|
<div id="spinner" style="margin: 0 auto 1rem; width: 32px; height: 32px; border: 3px solid #e5e7eb; border-top-color: #3b82f6; border-radius: 50%; animation: spin 0.8s linear infinite;"></div>
|
||||||
|
<p id="status" style="color: #374151; font-size: 16px;">Completing sign-in...</p>
|
||||||
|
</div>
|
||||||
|
<style>@keyframes spin { to { transform: rotate(360deg); } }</style>
|
||||||
|
<script>
|
||||||
|
(function() {
|
||||||
|
var msg = ${safeJsonStringify(message)};
|
||||||
|
var sent = false;
|
||||||
|
|
||||||
|
// Method 1: BroadcastChannel (reliable across tabs/popups, no opener needed)
|
||||||
|
try {
|
||||||
|
var bc = new BroadcastChannel("mcp_oauth");
|
||||||
|
bc.postMessage({ type: "mcp_oauth_result", success: msg.success, code: msg.code, state: msg.state, message: msg.message });
|
||||||
|
bc.close();
|
||||||
|
sent = true;
|
||||||
|
} catch(e) { /* BroadcastChannel not supported */ }
|
||||||
|
|
||||||
|
// Method 2: window.opener.postMessage (fallback for same-origin popups)
|
||||||
|
try {
|
||||||
|
if (window.opener && !window.opener.closed) {
|
||||||
|
window.opener.postMessage(
|
||||||
|
{ message_type: "mcp_oauth_result", success: msg.success, code: msg.code, state: msg.state, message: msg.message },
|
||||||
|
window.location.origin
|
||||||
|
);
|
||||||
|
sent = true;
|
||||||
|
}
|
||||||
|
} catch(e) { /* opener not available (COOP) */ }
|
||||||
|
|
||||||
|
// Method 3: localStorage (most reliable cross-tab fallback)
|
||||||
|
try {
|
||||||
|
localStorage.setItem("mcp_oauth_result", JSON.stringify(msg));
|
||||||
|
sent = true;
|
||||||
|
} catch(e) { /* localStorage not available */ }
|
||||||
|
|
||||||
|
var statusEl = document.getElementById("status");
|
||||||
|
var spinnerEl = document.getElementById("spinner");
|
||||||
|
spinnerEl.style.display = "none";
|
||||||
|
|
||||||
|
if (msg.success && sent) {
|
||||||
|
statusEl.textContent = "Sign-in complete! This window will close.";
|
||||||
|
statusEl.style.color = "#059669";
|
||||||
|
setTimeout(function() { window.close(); }, 1500);
|
||||||
|
} else if (msg.success) {
|
||||||
|
statusEl.textContent = "Sign-in successful! You can close this tab and return to the builder.";
|
||||||
|
statusEl.style.color = "#059669";
|
||||||
|
} else {
|
||||||
|
statusEl.textContent = "Sign-in failed: " + (msg.message || "Unknown error");
|
||||||
|
statusEl.style.color = "#dc2626";
|
||||||
|
}
|
||||||
|
})();
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>`,
|
||||||
|
{ headers: { "Content-Type": "text/html" } },
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -47,7 +47,10 @@ export type CustomNode = XYNode<CustomNodeData, "custom">;
|
|||||||
|
|
||||||
export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
|
export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
|
||||||
({ data, id: nodeId, selected }) => {
|
({ data, id: nodeId, selected }) => {
|
||||||
const { inputSchema, outputSchema } = useCustomNode({ data, nodeId });
|
const { inputSchema, outputSchema, isMCPWithTool } = useCustomNode({
|
||||||
|
data,
|
||||||
|
nodeId,
|
||||||
|
});
|
||||||
|
|
||||||
const isAgent = data.uiType === BlockUIType.AGENT;
|
const isAgent = data.uiType === BlockUIType.AGENT;
|
||||||
|
|
||||||
@@ -98,6 +101,7 @@ export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
|
|||||||
jsonSchema={preprocessInputSchema(inputSchema)}
|
jsonSchema={preprocessInputSchema(inputSchema)}
|
||||||
nodeId={nodeId}
|
nodeId={nodeId}
|
||||||
uiType={data.uiType}
|
uiType={data.uiType}
|
||||||
|
isMCPWithTool={isMCPWithTool}
|
||||||
className={cn(
|
className={cn(
|
||||||
"bg-white px-4",
|
"bg-white px-4",
|
||||||
isWebhook && "pointer-events-none opacity-50",
|
isWebhook && "pointer-events-none opacity-50",
|
||||||
|
|||||||
@@ -20,10 +20,8 @@ type Props = {
|
|||||||
|
|
||||||
export const NodeHeader = ({ data, nodeId }: Props) => {
|
export const NodeHeader = ({ data, nodeId }: Props) => {
|
||||||
const updateNodeData = useNodeStore((state) => state.updateNodeData);
|
const updateNodeData = useNodeStore((state) => state.updateNodeData);
|
||||||
const title =
|
|
||||||
(data.metadata?.customized_name as string) ||
|
const title = (data.metadata?.customized_name as string) || data.title;
|
||||||
data.hardcodedValues?.agent_name ||
|
|
||||||
data.title;
|
|
||||||
|
|
||||||
const [isEditingTitle, setIsEditingTitle] = useState(false);
|
const [isEditingTitle, setIsEditingTitle] = useState(false);
|
||||||
const [editedTitle, setEditedTitle] = useState(title);
|
const [editedTitle, setEditedTitle] = useState(title);
|
||||||
|
|||||||
@@ -3,6 +3,36 @@ import { CustomNodeData } from "./CustomNode";
|
|||||||
import { BlockUIType } from "../../../types";
|
import { BlockUIType } from "../../../types";
|
||||||
import { useMemo } from "react";
|
import { useMemo } from "react";
|
||||||
import { mergeSchemaForResolution } from "./helpers";
|
import { mergeSchemaForResolution } from "./helpers";
|
||||||
|
import { SpecialBlockID } from "@/lib/autogpt-server-api";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build a dynamic input schema for MCP blocks.
|
||||||
|
*
|
||||||
|
* When a tool has been selected (tool_input_schema is populated), the block
|
||||||
|
* renders the selected tool's input parameters *plus* the credentials field
|
||||||
|
* so users can select/change the OAuth credential used for execution.
|
||||||
|
*
|
||||||
|
* Static fields like server_url, selected_tool, available_tools, and
|
||||||
|
* tool_arguments are hidden because they're pre-configured from the dialog.
|
||||||
|
*/
|
||||||
|
function buildMCPInputSchema(
|
||||||
|
toolInputSchema: Record<string, any>,
|
||||||
|
blockInputSchema: Record<string, any>,
|
||||||
|
): Record<string, any> {
|
||||||
|
// Extract the credentials field from the block's original input schema
|
||||||
|
const credentialsSchema =
|
||||||
|
blockInputSchema?.properties?.credentials ?? undefined;
|
||||||
|
|
||||||
|
return {
|
||||||
|
type: "object",
|
||||||
|
properties: {
|
||||||
|
// Credentials field first so the dropdown appears at the top
|
||||||
|
...(credentialsSchema ? { credentials: credentialsSchema } : {}),
|
||||||
|
...(toolInputSchema.properties ?? {}),
|
||||||
|
},
|
||||||
|
required: [...(toolInputSchema.required ?? [])],
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
export const useCustomNode = ({
|
export const useCustomNode = ({
|
||||||
data,
|
data,
|
||||||
@@ -19,9 +49,17 @@ export const useCustomNode = ({
|
|||||||
);
|
);
|
||||||
|
|
||||||
const isAgent = data.uiType === BlockUIType.AGENT;
|
const isAgent = data.uiType === BlockUIType.AGENT;
|
||||||
|
const isMCPWithTool =
|
||||||
|
data.block_id === SpecialBlockID.MCP_TOOL &&
|
||||||
|
!!data.hardcodedValues?.tool_input_schema?.properties;
|
||||||
|
|
||||||
const currentInputSchema = isAgent
|
const currentInputSchema = isAgent
|
||||||
? (data.hardcodedValues.input_schema ?? {})
|
? (data.hardcodedValues.input_schema ?? {})
|
||||||
|
: isMCPWithTool
|
||||||
|
? buildMCPInputSchema(
|
||||||
|
data.hardcodedValues.tool_input_schema,
|
||||||
|
data.inputSchema,
|
||||||
|
)
|
||||||
: data.inputSchema;
|
: data.inputSchema;
|
||||||
const currentOutputSchema = isAgent
|
const currentOutputSchema = isAgent
|
||||||
? (data.hardcodedValues.output_schema ?? {})
|
? (data.hardcodedValues.output_schema ?? {})
|
||||||
@@ -54,5 +92,6 @@ export const useCustomNode = ({
|
|||||||
return {
|
return {
|
||||||
inputSchema,
|
inputSchema,
|
||||||
outputSchema,
|
outputSchema,
|
||||||
|
isMCPWithTool,
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -9,39 +9,72 @@ interface FormCreatorProps {
|
|||||||
jsonSchema: RJSFSchema;
|
jsonSchema: RJSFSchema;
|
||||||
nodeId: string;
|
nodeId: string;
|
||||||
uiType: BlockUIType;
|
uiType: BlockUIType;
|
||||||
|
/** When true the block is an MCP Tool with a selected tool. */
|
||||||
|
isMCPWithTool?: boolean;
|
||||||
showHandles?: boolean;
|
showHandles?: boolean;
|
||||||
className?: string;
|
className?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const FormCreator: React.FC<FormCreatorProps> = React.memo(
|
export const FormCreator: React.FC<FormCreatorProps> = React.memo(
|
||||||
({ jsonSchema, nodeId, uiType, showHandles = true, className }) => {
|
({
|
||||||
|
jsonSchema,
|
||||||
|
nodeId,
|
||||||
|
uiType,
|
||||||
|
isMCPWithTool = false,
|
||||||
|
showHandles = true,
|
||||||
|
className,
|
||||||
|
}) => {
|
||||||
const updateNodeData = useNodeStore((state) => state.updateNodeData);
|
const updateNodeData = useNodeStore((state) => state.updateNodeData);
|
||||||
|
|
||||||
const getHardCodedValues = useNodeStore(
|
const getHardCodedValues = useNodeStore(
|
||||||
(state) => state.getHardCodedValues,
|
(state) => state.getHardCodedValues,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const isAgent = uiType === BlockUIType.AGENT;
|
||||||
|
|
||||||
const handleChange = ({ formData }: any) => {
|
const handleChange = ({ formData }: any) => {
|
||||||
if ("credentials" in formData && !formData.credentials?.id) {
|
if ("credentials" in formData && !formData.credentials?.id) {
|
||||||
delete formData.credentials;
|
delete formData.credentials;
|
||||||
}
|
}
|
||||||
|
|
||||||
const updatedValues =
|
let updatedValues;
|
||||||
uiType === BlockUIType.AGENT
|
if (isAgent) {
|
||||||
? {
|
updatedValues = {
|
||||||
...getHardCodedValues(nodeId),
|
...getHardCodedValues(nodeId),
|
||||||
inputs: formData,
|
inputs: formData,
|
||||||
|
};
|
||||||
|
} else if (isMCPWithTool) {
|
||||||
|
// Separate credentials from tool arguments — credentials are stored
|
||||||
|
// at the top level of hardcodedValues, not inside tool_arguments.
|
||||||
|
const { credentials, ...toolArgs } = formData;
|
||||||
|
updatedValues = {
|
||||||
|
...getHardCodedValues(nodeId),
|
||||||
|
tool_arguments: toolArgs,
|
||||||
|
...(credentials?.id ? { credentials } : {}),
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
updatedValues = formData;
|
||||||
}
|
}
|
||||||
: formData;
|
|
||||||
|
|
||||||
updateNodeData(nodeId, { hardcodedValues: updatedValues });
|
updateNodeData(nodeId, { hardcodedValues: updatedValues });
|
||||||
};
|
};
|
||||||
|
|
||||||
const hardcodedValues = getHardCodedValues(nodeId);
|
const hardcodedValues = getHardCodedValues(nodeId);
|
||||||
const initialValues =
|
|
||||||
uiType === BlockUIType.AGENT
|
let initialValues;
|
||||||
? (hardcodedValues.inputs ?? {})
|
if (isAgent) {
|
||||||
: hardcodedValues;
|
initialValues = hardcodedValues.inputs ?? {};
|
||||||
|
} else if (isMCPWithTool) {
|
||||||
|
// Merge tool arguments with credentials for the form
|
||||||
|
initialValues = {
|
||||||
|
...(hardcodedValues.tool_arguments ?? {}),
|
||||||
|
...(hardcodedValues.credentials?.id
|
||||||
|
? { credentials: hardcodedValues.credentials }
|
||||||
|
: {}),
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
initialValues = hardcodedValues;
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import { Button } from "@/components/__legacy__/ui/button";
|
import { Button } from "@/components/__legacy__/ui/button";
|
||||||
import { Skeleton } from "@/components/__legacy__/ui/skeleton";
|
import { Skeleton } from "@/components/__legacy__/ui/skeleton";
|
||||||
import { beautifyString, cn } from "@/lib/utils";
|
import { beautifyString, cn } from "@/lib/utils";
|
||||||
import React, { ButtonHTMLAttributes } from "react";
|
import React, { ButtonHTMLAttributes, useCallback, useState } from "react";
|
||||||
import { highlightText } from "./helpers";
|
import { highlightText } from "./helpers";
|
||||||
import { PlusIcon } from "@phosphor-icons/react";
|
import { PlusIcon } from "@phosphor-icons/react";
|
||||||
import { BlockInfo } from "@/app/api/__generated__/models/blockInfo";
|
import { BlockInfo } from "@/app/api/__generated__/models/blockInfo";
|
||||||
@@ -9,6 +9,12 @@ import { useControlPanelStore } from "../../../stores/controlPanelStore";
|
|||||||
import { blockDragPreviewStyle } from "./style";
|
import { blockDragPreviewStyle } from "./style";
|
||||||
import { useReactFlow } from "@xyflow/react";
|
import { useReactFlow } from "@xyflow/react";
|
||||||
import { useNodeStore } from "../../../stores/nodeStore";
|
import { useNodeStore } from "../../../stores/nodeStore";
|
||||||
|
import { SpecialBlockID } from "@/lib/autogpt-server-api";
|
||||||
|
import {
|
||||||
|
MCPToolDialog,
|
||||||
|
type MCPToolDialogResult,
|
||||||
|
} from "@/app/(platform)/build/components/legacy-builder/MCPToolDialog";
|
||||||
|
|
||||||
interface Props extends ButtonHTMLAttributes<HTMLButtonElement> {
|
interface Props extends ButtonHTMLAttributes<HTMLButtonElement> {
|
||||||
title?: string;
|
title?: string;
|
||||||
description?: string;
|
description?: string;
|
||||||
@@ -33,9 +39,13 @@ export const Block: BlockComponent = ({
|
|||||||
);
|
);
|
||||||
const { setViewport } = useReactFlow();
|
const { setViewport } = useReactFlow();
|
||||||
const { addBlock } = useNodeStore();
|
const { addBlock } = useNodeStore();
|
||||||
|
const [mcpDialogOpen, setMcpDialogOpen] = useState(false);
|
||||||
|
|
||||||
const handleClick = () => {
|
const isMCPBlock = blockData.id === SpecialBlockID.MCP_TOOL;
|
||||||
const customNode = addBlock(blockData);
|
|
||||||
|
const addBlockAndCenter = useCallback(
|
||||||
|
(block: BlockInfo, hardcodedValues?: Record<string, any>) => {
|
||||||
|
const customNode = addBlock(block, hardcodedValues);
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
setViewport(
|
setViewport(
|
||||||
{
|
{
|
||||||
@@ -46,9 +56,69 @@ export const Block: BlockComponent = ({
|
|||||||
{ duration: 500 },
|
{ duration: 500 },
|
||||||
);
|
);
|
||||||
}, 50);
|
}, 50);
|
||||||
|
return customNode;
|
||||||
|
},
|
||||||
|
[addBlock, setViewport],
|
||||||
|
);
|
||||||
|
|
||||||
|
const updateNodeData = useNodeStore((state) => state.updateNodeData);
|
||||||
|
|
||||||
|
const handleMCPToolConfirm = useCallback(
|
||||||
|
(result: MCPToolDialogResult) => {
|
||||||
|
// Derive a display label: prefer server name, fall back to URL hostname.
|
||||||
|
let serverLabel = result.serverName;
|
||||||
|
if (!serverLabel) {
|
||||||
|
try {
|
||||||
|
serverLabel = new URL(result.serverUrl).hostname;
|
||||||
|
} catch {
|
||||||
|
serverLabel = "MCP";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const customNode = addBlockAndCenter(blockData, {
|
||||||
|
server_url: result.serverUrl,
|
||||||
|
server_name: serverLabel,
|
||||||
|
selected_tool: result.selectedTool,
|
||||||
|
tool_input_schema: result.toolInputSchema,
|
||||||
|
available_tools: result.availableTools,
|
||||||
|
credentials: result.credentials ?? undefined,
|
||||||
|
});
|
||||||
|
if (customNode) {
|
||||||
|
const title = result.selectedTool
|
||||||
|
? `${serverLabel}: ${beautifyString(result.selectedTool)}`
|
||||||
|
: undefined;
|
||||||
|
updateNodeData(customNode.id, {
|
||||||
|
metadata: {
|
||||||
|
...customNode.data.metadata,
|
||||||
|
credentials_optional: true,
|
||||||
|
...(title && { customized_name: title }),
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
setMcpDialogOpen(false);
|
||||||
|
},
|
||||||
|
[addBlockAndCenter, blockData, updateNodeData],
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleClick = () => {
|
||||||
|
if (isMCPBlock) {
|
||||||
|
setMcpDialogOpen(true);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const customNode = addBlockAndCenter(blockData);
|
||||||
|
// Set customized_name for agent blocks so the agent's name persists
|
||||||
|
if (customNode && blockData.id === SpecialBlockID.AGENT) {
|
||||||
|
updateNodeData(customNode.id, {
|
||||||
|
metadata: {
|
||||||
|
...customNode.data.metadata,
|
||||||
|
customized_name: blockData.name,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleDragStart = (e: React.DragEvent<HTMLButtonElement>) => {
|
const handleDragStart = (e: React.DragEvent<HTMLButtonElement>) => {
|
||||||
|
if (isMCPBlock) return;
|
||||||
e.dataTransfer.effectAllowed = "copy";
|
e.dataTransfer.effectAllowed = "copy";
|
||||||
e.dataTransfer.setData("application/reactflow", JSON.stringify(blockData));
|
e.dataTransfer.setData("application/reactflow", JSON.stringify(blockData));
|
||||||
|
|
||||||
@@ -71,12 +141,14 @@ export const Block: BlockComponent = ({
|
|||||||
: undefined;
|
: undefined;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
<>
|
||||||
<Button
|
<Button
|
||||||
draggable={true}
|
draggable={!isMCPBlock}
|
||||||
data-id={blockDataId}
|
data-id={blockDataId}
|
||||||
className={cn(
|
className={cn(
|
||||||
"group flex h-16 w-full min-w-[7.5rem] items-center justify-start space-x-3 whitespace-normal rounded-[0.75rem] bg-zinc-50 px-[0.875rem] py-[0.625rem] text-start shadow-none",
|
"group flex h-16 w-full min-w-[7.5rem] items-center justify-start space-x-3 whitespace-normal rounded-[0.75rem] bg-zinc-50 px-[0.875rem] py-[0.625rem] text-start shadow-none",
|
||||||
"hover:cursor-default hover:bg-zinc-100 focus:ring-0 active:bg-zinc-100 active:ring-1 active:ring-zinc-300 disabled:cursor-not-allowed",
|
"hover:cursor-default hover:bg-zinc-100 focus:ring-0 active:bg-zinc-100 active:ring-1 active:ring-zinc-300 disabled:cursor-not-allowed",
|
||||||
|
isMCPBlock && "hover:cursor-pointer",
|
||||||
className,
|
className,
|
||||||
)}
|
)}
|
||||||
onDragStart={handleDragStart}
|
onDragStart={handleDragStart}
|
||||||
@@ -111,6 +183,14 @@ export const Block: BlockComponent = ({
|
|||||||
<PlusIcon className="h-5 w-5 text-zinc-50" />
|
<PlusIcon className="h-5 w-5 text-zinc-50" />
|
||||||
</div>
|
</div>
|
||||||
</Button>
|
</Button>
|
||||||
|
{isMCPBlock && (
|
||||||
|
<MCPToolDialog
|
||||||
|
open={mcpDialogOpen}
|
||||||
|
onClose={() => setMcpDialogOpen(false)}
|
||||||
|
onConfirm={handleMCPToolConfirm}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { debounce } from "lodash";
|
import debounce from "lodash/debounce";
|
||||||
import { useCallback, useEffect, useRef, useState } from "react";
|
import { useCallback, useEffect, useRef, useState } from "react";
|
||||||
import { useBlockMenuStore } from "../../../../stores/blockMenuStore";
|
import { useBlockMenuStore } from "../../../../stores/blockMenuStore";
|
||||||
import { getQueryClient } from "@/lib/react-query/queryClient";
|
import { getQueryClient } from "@/lib/react-query/queryClient";
|
||||||
|
|||||||
@@ -70,10 +70,10 @@ export const HorizontalScroll: React.FC<HorizontalScrollAreaProps> = ({
|
|||||||
{children}
|
{children}
|
||||||
</div>
|
</div>
|
||||||
{canScrollLeft && (
|
{canScrollLeft && (
|
||||||
<div className="pointer-events-none absolute inset-y-0 left-0 w-8 bg-gradient-to-r from-white via-white/80 to-white/0" />
|
<div className="pointer-events-none absolute inset-y-0 left-0 w-8 bg-gradient-to-r from-background via-background/80 to-background/0" />
|
||||||
)}
|
)}
|
||||||
{canScrollRight && (
|
{canScrollRight && (
|
||||||
<div className="pointer-events-none absolute inset-y-0 right-0 w-8 bg-gradient-to-l from-white via-white/80 to-white/0" />
|
<div className="pointer-events-none absolute inset-y-0 right-0 w-8 bg-gradient-to-l from-background via-background/80 to-background/0" />
|
||||||
)}
|
)}
|
||||||
{canScrollLeft && (
|
{canScrollLeft && (
|
||||||
<button
|
<button
|
||||||
|
|||||||
@@ -29,6 +29,10 @@ import {
|
|||||||
TooltipTrigger,
|
TooltipTrigger,
|
||||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||||
import { GraphMeta } from "@/lib/autogpt-server-api";
|
import { GraphMeta } from "@/lib/autogpt-server-api";
|
||||||
|
import {
|
||||||
|
MCPToolDialog,
|
||||||
|
type MCPToolDialogResult,
|
||||||
|
} from "@/app/(platform)/build/components/legacy-builder/MCPToolDialog";
|
||||||
import jaro from "jaro-winkler";
|
import jaro from "jaro-winkler";
|
||||||
import { getV1GetSpecificGraph } from "@/app/api/__generated__/endpoints/graphs/graphs";
|
import { getV1GetSpecificGraph } from "@/app/api/__generated__/endpoints/graphs/graphs";
|
||||||
import { okData } from "@/app/api/helpers";
|
import { okData } from "@/app/api/helpers";
|
||||||
@@ -94,6 +98,7 @@ export function BlocksControl({
|
|||||||
const [searchQuery, setSearchQuery] = useState("");
|
const [searchQuery, setSearchQuery] = useState("");
|
||||||
const deferredSearchQuery = useDeferredValue(searchQuery);
|
const deferredSearchQuery = useDeferredValue(searchQuery);
|
||||||
const [selectedCategory, setSelectedCategory] = useState<string | null>(null);
|
const [selectedCategory, setSelectedCategory] = useState<string | null>(null);
|
||||||
|
const [mcpDialogOpen, setMcpDialogOpen] = useState(false);
|
||||||
|
|
||||||
const blocks = useSearchableBlocks(_blocks);
|
const blocks = useSearchableBlocks(_blocks);
|
||||||
|
|
||||||
@@ -186,11 +191,32 @@ export function BlocksControl({
|
|||||||
setSelectedCategory(null);
|
setSelectedCategory(null);
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
|
const handleMCPToolConfirm = useCallback(
|
||||||
|
(result: MCPToolDialogResult) => {
|
||||||
|
addBlock(SpecialBlockID.MCP_TOOL, "MCPToolBlock", {
|
||||||
|
server_url: result.serverUrl,
|
||||||
|
server_name: result.serverName,
|
||||||
|
selected_tool: result.selectedTool,
|
||||||
|
tool_input_schema: result.toolInputSchema,
|
||||||
|
available_tools: result.availableTools,
|
||||||
|
credentials: result.credentials ?? undefined,
|
||||||
|
});
|
||||||
|
setMcpDialogOpen(false);
|
||||||
|
},
|
||||||
|
[addBlock],
|
||||||
|
);
|
||||||
|
|
||||||
// Handler to add a block, fetching graph data on-demand for agent blocks
|
// Handler to add a block, fetching graph data on-demand for agent blocks
|
||||||
const handleAddBlock = useCallback(
|
const handleAddBlock = useCallback(
|
||||||
async (block: _Block & { notAvailable: string | null }) => {
|
async (block: _Block & { notAvailable: string | null }) => {
|
||||||
if (block.notAvailable) return;
|
if (block.notAvailable) return;
|
||||||
|
|
||||||
|
// For MCP blocks, open the configuration dialog instead of placing directly
|
||||||
|
if (block.id === SpecialBlockID.MCP_TOOL) {
|
||||||
|
setMcpDialogOpen(true);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// For agent blocks, fetch the full graph to get schemas
|
// For agent blocks, fetch the full graph to get schemas
|
||||||
if (block.uiType === BlockUIType.AGENT && block.hardcodedValues) {
|
if (block.uiType === BlockUIType.AGENT && block.hardcodedValues) {
|
||||||
const graphID = block.hardcodedValues.graph_id as string;
|
const graphID = block.hardcodedValues.graph_id as string;
|
||||||
@@ -230,6 +256,7 @@ export function BlocksControl({
|
|||||||
}, [blocks]);
|
}, [blocks]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
<>
|
||||||
<Popover
|
<Popover
|
||||||
open={pinBlocksPopover ? true : undefined}
|
open={pinBlocksPopover ? true : undefined}
|
||||||
onOpenChange={(open) => open || resetFilters()}
|
onOpenChange={(open) => open || resetFilters()}
|
||||||
@@ -322,12 +349,21 @@ export function BlocksControl({
|
|||||||
className={`m-2 my-4 flex h-20 shadow-none dark:border-slate-700 dark:bg-slate-800 dark:text-slate-100 dark:hover:bg-slate-700 ${
|
className={`m-2 my-4 flex h-20 shadow-none dark:border-slate-700 dark:bg-slate-800 dark:text-slate-100 dark:hover:bg-slate-700 ${
|
||||||
block.notAvailable
|
block.notAvailable
|
||||||
? "cursor-not-allowed opacity-50"
|
? "cursor-not-allowed opacity-50"
|
||||||
|
: block.id === SpecialBlockID.MCP_TOOL
|
||||||
|
? "cursor-pointer hover:shadow-lg"
|
||||||
: "cursor-move hover:shadow-lg"
|
: "cursor-move hover:shadow-lg"
|
||||||
}`}
|
}`}
|
||||||
data-id={`block-card-${block.id}`}
|
data-id={`block-card-${block.id}`}
|
||||||
draggable={!block.notAvailable}
|
draggable={
|
||||||
|
!block.notAvailable &&
|
||||||
|
block.id !== SpecialBlockID.MCP_TOOL
|
||||||
|
}
|
||||||
onDragStart={(e) => {
|
onDragStart={(e) => {
|
||||||
if (block.notAvailable) return;
|
if (
|
||||||
|
block.notAvailable ||
|
||||||
|
block.id === SpecialBlockID.MCP_TOOL
|
||||||
|
)
|
||||||
|
return;
|
||||||
e.dataTransfer.effectAllowed = "copy";
|
e.dataTransfer.effectAllowed = "copy";
|
||||||
e.dataTransfer.setData(
|
e.dataTransfer.setData(
|
||||||
"application/reactflow",
|
"application/reactflow",
|
||||||
@@ -386,6 +422,13 @@ export function BlocksControl({
|
|||||||
</Card>
|
</Card>
|
||||||
</PopoverContent>
|
</PopoverContent>
|
||||||
</Popover>
|
</Popover>
|
||||||
|
|
||||||
|
<MCPToolDialog
|
||||||
|
open={mcpDialogOpen}
|
||||||
|
onClose={() => setMcpDialogOpen(false)}
|
||||||
|
onConfirm={handleMCPToolConfirm}
|
||||||
|
/>
|
||||||
|
</>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import {
|
|||||||
GraphInputSchema,
|
GraphInputSchema,
|
||||||
GraphOutputSchema,
|
GraphOutputSchema,
|
||||||
NodeExecutionResult,
|
NodeExecutionResult,
|
||||||
|
SpecialBlockID,
|
||||||
} from "@/lib/autogpt-server-api";
|
} from "@/lib/autogpt-server-api";
|
||||||
import {
|
import {
|
||||||
beautifyString,
|
beautifyString,
|
||||||
@@ -215,6 +216,26 @@ export const CustomNode = React.memo(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MCP Tool block: display the selected tool's dynamic schema
|
||||||
|
const isMCPWithTool =
|
||||||
|
data.block_id === SpecialBlockID.MCP_TOOL &&
|
||||||
|
!!data.hardcodedValues?.tool_input_schema?.properties;
|
||||||
|
|
||||||
|
if (isMCPWithTool) {
|
||||||
|
// Show only the tool's input parameters. Credentials are NOT included
|
||||||
|
// because authentication is handled by the MCP dialog's OAuth flow
|
||||||
|
// and stored server-side.
|
||||||
|
const toolSchema = data.hardcodedValues.tool_input_schema;
|
||||||
|
|
||||||
|
data.inputSchema = {
|
||||||
|
type: "object",
|
||||||
|
properties: {
|
||||||
|
...(toolSchema.properties ?? {}),
|
||||||
|
},
|
||||||
|
required: [...(toolSchema.required ?? [])],
|
||||||
|
} as BlockIORootSchema;
|
||||||
|
}
|
||||||
|
|
||||||
const setHardcodedValues = useCallback(
|
const setHardcodedValues = useCallback(
|
||||||
(values: any) => {
|
(values: any) => {
|
||||||
updateNodeData(id, { hardcodedValues: values });
|
updateNodeData(id, { hardcodedValues: values });
|
||||||
@@ -375,7 +396,9 @@ export const CustomNode = React.memo(
|
|||||||
|
|
||||||
const displayTitle =
|
const displayTitle =
|
||||||
customTitle ||
|
customTitle ||
|
||||||
beautifyString(data.blockType?.replace(/Block$/, "") || data.title);
|
(isMCPWithTool
|
||||||
|
? `${data.hardcodedValues.server_name || "MCP"}: ${beautifyString(data.hardcodedValues.selected_tool || "")}`
|
||||||
|
: beautifyString(data.blockType?.replace(/Block$/, "") || data.title));
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
isInitialSetup.current = false;
|
isInitialSetup.current = false;
|
||||||
@@ -389,6 +412,15 @@ export const CustomNode = React.memo(
|
|||||||
data.inputSchema,
|
data.inputSchema,
|
||||||
),
|
),
|
||||||
});
|
});
|
||||||
|
} else if (isMCPWithTool) {
|
||||||
|
// MCP dialog already configured server_url, selected_tool, etc.
|
||||||
|
// Just ensure tool_arguments is initialized.
|
||||||
|
if (!data.hardcodedValues.tool_arguments) {
|
||||||
|
setHardcodedValues({
|
||||||
|
...data.hardcodedValues,
|
||||||
|
tool_arguments: {},
|
||||||
|
});
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
setHardcodedValues(
|
setHardcodedValues(
|
||||||
fillObjectDefaultsFromSchema(data.hardcodedValues, data.inputSchema),
|
fillObjectDefaultsFromSchema(data.hardcodedValues, data.inputSchema),
|
||||||
@@ -525,8 +557,11 @@ export const CustomNode = React.memo(
|
|||||||
);
|
);
|
||||||
|
|
||||||
default:
|
default:
|
||||||
const getInputPropKey = (key: string) =>
|
const getInputPropKey = (key: string) => {
|
||||||
nodeType == BlockUIType.AGENT ? `inputs.${key}` : key;
|
if (nodeType == BlockUIType.AGENT) return `inputs.${key}`;
|
||||||
|
if (isMCPWithTool) return `tool_arguments.${key}`;
|
||||||
|
return key;
|
||||||
|
};
|
||||||
|
|
||||||
return keys.map(([propKey, propSchema]) => {
|
return keys.map(([propKey, propSchema]) => {
|
||||||
const isRequired = data.inputSchema.required?.includes(propKey);
|
const isRequired = data.inputSchema.required?.includes(propKey);
|
||||||
|
|||||||
@@ -42,7 +42,11 @@ import { getV1GetSpecificGraph } from "@/app/api/__generated__/endpoints/graphs/
|
|||||||
import { okData } from "@/app/api/helpers";
|
import { okData } from "@/app/api/helpers";
|
||||||
import { IncompatibilityInfo } from "../../../hooks/useSubAgentUpdate/types";
|
import { IncompatibilityInfo } from "../../../hooks/useSubAgentUpdate/types";
|
||||||
import { Key, storage } from "@/services/storage/local-storage";
|
import { Key, storage } from "@/services/storage/local-storage";
|
||||||
import { findNewlyAddedBlockCoordinates, getTypeColor } from "@/lib/utils";
|
import {
|
||||||
|
beautifyString,
|
||||||
|
findNewlyAddedBlockCoordinates,
|
||||||
|
getTypeColor,
|
||||||
|
} from "@/lib/utils";
|
||||||
import { history } from "../history";
|
import { history } from "../history";
|
||||||
import { CustomEdge } from "../CustomEdge/CustomEdge";
|
import { CustomEdge } from "../CustomEdge/CustomEdge";
|
||||||
import ConnectionLine from "../ConnectionLine";
|
import ConnectionLine from "../ConnectionLine";
|
||||||
@@ -748,6 +752,27 @@ const FlowEditor: React.FC<{
|
|||||||
block_id: blockID,
|
block_id: blockID,
|
||||||
isOutputStatic: nodeSchema.staticOutput,
|
isOutputStatic: nodeSchema.staticOutput,
|
||||||
uiType: nodeSchema.uiType,
|
uiType: nodeSchema.uiType,
|
||||||
|
// Set customized_name at creation so it persists through save/load
|
||||||
|
...(blockID === SpecialBlockID.MCP_TOOL && {
|
||||||
|
metadata: {
|
||||||
|
credentials_optional: true,
|
||||||
|
...(finalHardcodedValues.selected_tool && {
|
||||||
|
customized_name: `${
|
||||||
|
finalHardcodedValues.server_name ||
|
||||||
|
(() => {
|
||||||
|
try {
|
||||||
|
return new URL(finalHardcodedValues.server_url).hostname;
|
||||||
|
} catch {
|
||||||
|
return "MCP";
|
||||||
|
}
|
||||||
|
})()
|
||||||
|
}: ${beautifyString(finalHardcodedValues.selected_tool)}`,
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
...(blockID === SpecialBlockID.AGENT && {
|
||||||
|
metadata: { customized_name: blockName },
|
||||||
|
}),
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -877,8 +902,6 @@ const FlowEditor: React.FC<{
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
node.data.metadata?.customized_name ||
|
node.data.metadata?.customized_name ||
|
||||||
(node.data.uiType == BlockUIType.AGENT &&
|
|
||||||
node.data.hardcodedValues.agent_name) ||
|
|
||||||
node.data.blockType.replace(/Block$/, "")
|
node.data.blockType.replace(/Block$/, "")
|
||||||
);
|
);
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -0,0 +1,534 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import React, {
|
||||||
|
useState,
|
||||||
|
useCallback,
|
||||||
|
useRef,
|
||||||
|
useEffect,
|
||||||
|
useContext,
|
||||||
|
} from "react";
|
||||||
|
import {
|
||||||
|
Dialog,
|
||||||
|
DialogContent,
|
||||||
|
DialogDescription,
|
||||||
|
DialogFooter,
|
||||||
|
DialogHeader,
|
||||||
|
DialogTitle,
|
||||||
|
} from "@/components/__legacy__/ui/dialog";
|
||||||
|
import { Button } from "@/components/__legacy__/ui/button";
|
||||||
|
import { Input } from "@/components/__legacy__/ui/input";
|
||||||
|
import { Label } from "@/components/__legacy__/ui/label";
|
||||||
|
import { LoadingSpinner } from "@/components/__legacy__/ui/loading";
|
||||||
|
import { Badge } from "@/components/__legacy__/ui/badge";
|
||||||
|
import { ScrollArea } from "@/components/__legacy__/ui/scroll-area";
|
||||||
|
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
||||||
|
import type { CredentialsMetaInput, MCPTool } from "@/lib/autogpt-server-api";
|
||||||
|
import { CaretDown } from "@phosphor-icons/react";
|
||||||
|
import { openOAuthPopup } from "@/lib/oauth-popup";
|
||||||
|
import { CredentialsProvidersContext } from "@/providers/agent-credentials/credentials-provider";
|
||||||
|
|
||||||
|
export type MCPToolDialogResult = {
|
||||||
|
serverUrl: string;
|
||||||
|
serverName: string | null;
|
||||||
|
selectedTool: string;
|
||||||
|
toolInputSchema: Record<string, any>;
|
||||||
|
availableTools: Record<string, any>;
|
||||||
|
/** Credentials meta from OAuth flow, null for public servers. */
|
||||||
|
credentials: CredentialsMetaInput | null;
|
||||||
|
};
|
||||||
|
|
||||||
|
interface MCPToolDialogProps {
|
||||||
|
open: boolean;
|
||||||
|
onClose: () => void;
|
||||||
|
onConfirm: (result: MCPToolDialogResult) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
type DialogStep = "url" | "tool";
|
||||||
|
|
||||||
|
export function MCPToolDialog({
|
||||||
|
open,
|
||||||
|
onClose,
|
||||||
|
onConfirm,
|
||||||
|
}: MCPToolDialogProps) {
|
||||||
|
const api = useBackendAPI();
|
||||||
|
const allProviders = useContext(CredentialsProvidersContext);
|
||||||
|
|
||||||
|
const [step, setStep] = useState<DialogStep>("url");
|
||||||
|
const [serverUrl, setServerUrl] = useState("");
|
||||||
|
const [tools, setTools] = useState<MCPTool[]>([]);
|
||||||
|
const [serverName, setServerName] = useState<string | null>(null);
|
||||||
|
const [loading, setLoading] = useState(false);
|
||||||
|
const [error, setError] = useState<string | null>(null);
|
||||||
|
const [authRequired, setAuthRequired] = useState(false);
|
||||||
|
const [oauthLoading, setOauthLoading] = useState(false);
|
||||||
|
const [showManualToken, setShowManualToken] = useState(false);
|
||||||
|
const [manualToken, setManualToken] = useState("");
|
||||||
|
const [selectedTool, setSelectedTool] = useState<MCPTool | null>(null);
|
||||||
|
const [credentials, setCredentials] = useState<CredentialsMetaInput | null>(
|
||||||
|
null,
|
||||||
|
);
|
||||||
|
|
||||||
|
const startOAuthRef = useRef(false);
|
||||||
|
const oauthAbortRef = useRef<((reason?: string) => void) | null>(null);
|
||||||
|
|
||||||
|
// Clean up on unmount
|
||||||
|
useEffect(() => {
|
||||||
|
return () => {
|
||||||
|
oauthAbortRef.current?.();
|
||||||
|
};
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const reset = useCallback(() => {
|
||||||
|
oauthAbortRef.current?.();
|
||||||
|
oauthAbortRef.current = null;
|
||||||
|
setStep("url");
|
||||||
|
setServerUrl("");
|
||||||
|
setManualToken("");
|
||||||
|
setTools([]);
|
||||||
|
setServerName(null);
|
||||||
|
setLoading(false);
|
||||||
|
setError(null);
|
||||||
|
setAuthRequired(false);
|
||||||
|
setOauthLoading(false);
|
||||||
|
setShowManualToken(false);
|
||||||
|
setSelectedTool(null);
|
||||||
|
setCredentials(null);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const handleClose = useCallback(() => {
|
||||||
|
reset();
|
||||||
|
onClose();
|
||||||
|
}, [reset, onClose]);
|
||||||
|
|
||||||
|
const discoverTools = useCallback(
|
||||||
|
async (url: string, authToken?: string) => {
|
||||||
|
setLoading(true);
|
||||||
|
setError(null);
|
||||||
|
try {
|
||||||
|
const result = await api.mcpDiscoverTools(url, authToken);
|
||||||
|
setTools(result.tools);
|
||||||
|
setServerName(result.server_name);
|
||||||
|
setAuthRequired(false);
|
||||||
|
setShowManualToken(false);
|
||||||
|
setStep("tool");
|
||||||
|
} catch (e: any) {
|
||||||
|
if (e?.status === 401 || e?.status === 403) {
|
||||||
|
setAuthRequired(true);
|
||||||
|
setError(null);
|
||||||
|
// Automatically start OAuth sign-in instead of requiring a second click
|
||||||
|
setLoading(false);
|
||||||
|
startOAuthRef.current = true;
|
||||||
|
return;
|
||||||
|
} else {
|
||||||
|
const message =
|
||||||
|
e?.message || e?.detail || "Failed to connect to MCP server";
|
||||||
|
setError(
|
||||||
|
typeof message === "string" ? message : JSON.stringify(message),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[api],
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleDiscoverTools = useCallback(() => {
|
||||||
|
if (!serverUrl.trim()) return;
|
||||||
|
discoverTools(serverUrl.trim(), manualToken.trim() || undefined);
|
||||||
|
}, [serverUrl, manualToken, discoverTools]);
|
||||||
|
|
||||||
|
const handleOAuthSignIn = useCallback(async () => {
|
||||||
|
if (!serverUrl.trim()) return;
|
||||||
|
setError(null);
|
||||||
|
|
||||||
|
// Abort any previous OAuth flow
|
||||||
|
oauthAbortRef.current?.();
|
||||||
|
|
||||||
|
setOauthLoading(true);
|
||||||
|
|
||||||
|
try {
|
||||||
|
const { login_url, state_token } = await api.mcpOAuthLogin(
|
||||||
|
serverUrl.trim(),
|
||||||
|
);
|
||||||
|
|
||||||
|
const { promise, cleanup } = openOAuthPopup(login_url, {
|
||||||
|
stateToken: state_token,
|
||||||
|
useCrossOriginListeners: true,
|
||||||
|
});
|
||||||
|
oauthAbortRef.current = cleanup.abort;
|
||||||
|
|
||||||
|
const result = await promise;
|
||||||
|
|
||||||
|
// Exchange code for tokens via the credentials provider (updates cache)
|
||||||
|
setLoading(true);
|
||||||
|
setOauthLoading(false);
|
||||||
|
|
||||||
|
const mcpProvider = allProviders?.["mcp"];
|
||||||
|
const callbackResult = mcpProvider
|
||||||
|
? await mcpProvider.mcpOAuthCallback(result.code, state_token)
|
||||||
|
: await api.mcpOAuthCallback(result.code, state_token);
|
||||||
|
|
||||||
|
setCredentials({
|
||||||
|
id: callbackResult.id,
|
||||||
|
provider: callbackResult.provider,
|
||||||
|
type: callbackResult.type,
|
||||||
|
title: callbackResult.title,
|
||||||
|
});
|
||||||
|
setAuthRequired(false);
|
||||||
|
|
||||||
|
// Discover tools now that we're authenticated
|
||||||
|
const toolsResult = await api.mcpDiscoverTools(serverUrl.trim());
|
||||||
|
setTools(toolsResult.tools);
|
||||||
|
setServerName(toolsResult.server_name);
|
||||||
|
setStep("tool");
|
||||||
|
} catch (e: any) {
|
||||||
|
// If server doesn't support OAuth → show manual token entry
|
||||||
|
if (e?.status === 400) {
|
||||||
|
setShowManualToken(true);
|
||||||
|
setError(
|
||||||
|
"This server does not support OAuth sign-in. Please enter a token manually.",
|
||||||
|
);
|
||||||
|
} else if (e?.message === "OAuth flow timed out") {
|
||||||
|
setError("OAuth sign-in timed out. Please try again.");
|
||||||
|
} else {
|
||||||
|
const status = e?.status;
|
||||||
|
let message: string;
|
||||||
|
if (status === 401 || status === 403) {
|
||||||
|
message =
|
||||||
|
"Authentication succeeded but the server still rejected the request. " +
|
||||||
|
"The token audience may not match. Please try again.";
|
||||||
|
} else {
|
||||||
|
message = e?.message || e?.detail || "Failed to complete sign-in";
|
||||||
|
}
|
||||||
|
setError(
|
||||||
|
typeof message === "string" ? message : JSON.stringify(message),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
setOauthLoading(false);
|
||||||
|
setLoading(false);
|
||||||
|
oauthAbortRef.current = null;
|
||||||
|
}
|
||||||
|
}, [api, serverUrl, allProviders]);
|
||||||
|
|
||||||
|
// Auto-start OAuth sign-in when server returns 401/403
|
||||||
|
useEffect(() => {
|
||||||
|
if (authRequired && startOAuthRef.current) {
|
||||||
|
startOAuthRef.current = false;
|
||||||
|
handleOAuthSignIn();
|
||||||
|
}
|
||||||
|
}, [authRequired, handleOAuthSignIn]);
|
||||||
|
|
||||||
|
const handleConfirm = useCallback(() => {
|
||||||
|
if (!selectedTool) return;
|
||||||
|
|
||||||
|
const availableTools: Record<string, any> = {};
|
||||||
|
for (const t of tools) {
|
||||||
|
availableTools[t.name] = {
|
||||||
|
description: t.description,
|
||||||
|
input_schema: t.input_schema,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
onConfirm({
|
||||||
|
serverUrl: serverUrl.trim(),
|
||||||
|
serverName,
|
||||||
|
selectedTool: selectedTool.name,
|
||||||
|
toolInputSchema: selectedTool.input_schema,
|
||||||
|
availableTools,
|
||||||
|
credentials,
|
||||||
|
});
|
||||||
|
reset();
|
||||||
|
}, [
|
||||||
|
selectedTool,
|
||||||
|
tools,
|
||||||
|
serverUrl,
|
||||||
|
serverName,
|
||||||
|
credentials,
|
||||||
|
onConfirm,
|
||||||
|
reset,
|
||||||
|
]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Dialog open={open} onOpenChange={(isOpen) => !isOpen && handleClose()}>
|
||||||
|
<DialogContent className="max-w-lg">
|
||||||
|
<DialogHeader>
|
||||||
|
<DialogTitle>
|
||||||
|
{step === "url"
|
||||||
|
? "Connect to MCP Server"
|
||||||
|
: `Select a Tool${serverName ? ` — ${serverName}` : ""}`}
|
||||||
|
</DialogTitle>
|
||||||
|
<DialogDescription>
|
||||||
|
{step === "url"
|
||||||
|
? "Enter the URL of an MCP server to discover its available tools."
|
||||||
|
: `Found ${tools.length} tool${tools.length !== 1 ? "s" : ""}. Select one to add to your agent.`}
|
||||||
|
</DialogDescription>
|
||||||
|
</DialogHeader>
|
||||||
|
|
||||||
|
{step === "url" && (
|
||||||
|
<div className="flex flex-col gap-4 py-2">
|
||||||
|
<div className="flex flex-col gap-2">
|
||||||
|
<Label htmlFor="mcp-server-url">Server URL</Label>
|
||||||
|
<Input
|
||||||
|
id="mcp-server-url"
|
||||||
|
type="url"
|
||||||
|
placeholder="https://mcp.example.com/mcp"
|
||||||
|
value={serverUrl}
|
||||||
|
onChange={(e) => setServerUrl(e.target.value)}
|
||||||
|
onKeyDown={(e) => e.key === "Enter" && handleDiscoverTools()}
|
||||||
|
autoFocus
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Auth required: show manual token option */}
|
||||||
|
{authRequired && !showManualToken && (
|
||||||
|
<button
|
||||||
|
onClick={() => setShowManualToken(true)}
|
||||||
|
className="text-xs text-gray-500 underline hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-300"
|
||||||
|
>
|
||||||
|
or enter a token manually
|
||||||
|
</button>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Manual token entry — only visible when expanded */}
|
||||||
|
{showManualToken && (
|
||||||
|
<div className="flex flex-col gap-2">
|
||||||
|
<Label htmlFor="mcp-auth-token" className="text-sm">
|
||||||
|
Bearer Token
|
||||||
|
</Label>
|
||||||
|
<Input
|
||||||
|
id="mcp-auth-token"
|
||||||
|
type="password"
|
||||||
|
placeholder="Paste your auth token here"
|
||||||
|
value={manualToken}
|
||||||
|
onChange={(e) => setManualToken(e.target.value)}
|
||||||
|
onKeyDown={(e) => e.key === "Enter" && handleDiscoverTools()}
|
||||||
|
autoFocus
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{error && <p className="text-sm text-red-500">{error}</p>}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{step === "tool" && (
|
||||||
|
<ScrollArea className="max-h-[50vh] py-2">
|
||||||
|
<div className="flex flex-col gap-2 pr-3">
|
||||||
|
{tools.map((tool) => (
|
||||||
|
<MCPToolCard
|
||||||
|
key={tool.name}
|
||||||
|
tool={tool}
|
||||||
|
selected={selectedTool?.name === tool.name}
|
||||||
|
onSelect={() => setSelectedTool(tool)}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</ScrollArea>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<DialogFooter>
|
||||||
|
{step === "tool" && (
|
||||||
|
<Button
|
||||||
|
variant="outline"
|
||||||
|
onClick={() => {
|
||||||
|
setStep("url");
|
||||||
|
setSelectedTool(null);
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
Back
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
<Button variant="outline" onClick={handleClose}>
|
||||||
|
Cancel
|
||||||
|
</Button>
|
||||||
|
{step === "url" && (
|
||||||
|
<Button
|
||||||
|
onClick={
|
||||||
|
authRequired && !showManualToken
|
||||||
|
? handleOAuthSignIn
|
||||||
|
: handleDiscoverTools
|
||||||
|
}
|
||||||
|
disabled={!serverUrl.trim() || loading || oauthLoading}
|
||||||
|
>
|
||||||
|
{loading || oauthLoading ? (
|
||||||
|
<span className="flex items-center gap-2">
|
||||||
|
<LoadingSpinner className="size-4" />
|
||||||
|
{oauthLoading ? "Waiting for sign-in..." : "Connecting..."}
|
||||||
|
</span>
|
||||||
|
) : authRequired && !showManualToken ? (
|
||||||
|
"Sign in & Connect"
|
||||||
|
) : (
|
||||||
|
"Discover Tools"
|
||||||
|
)}
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
{step === "tool" && (
|
||||||
|
<Button onClick={handleConfirm} disabled={!selectedTool}>
|
||||||
|
Add Block
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
</DialogFooter>
|
||||||
|
</DialogContent>
|
||||||
|
</Dialog>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------- Tool Card Component --------------- //
|
||||||
|
|
||||||
|
/** Truncate a description to a reasonable length for the collapsed view. */
|
||||||
|
function truncateDescription(text: string, maxLen = 120): string {
|
||||||
|
if (text.length <= maxLen) return text;
|
||||||
|
return text.slice(0, maxLen).trimEnd() + "…";
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Pretty-print a JSON Schema type for a parameter. */
|
||||||
|
function schemaTypeLabel(schema: Record<string, any>): string {
|
||||||
|
if (schema.type) return schema.type;
|
||||||
|
if (schema.anyOf)
|
||||||
|
return schema.anyOf.map((s: any) => s.type ?? "any").join(" | ");
|
||||||
|
if (schema.oneOf)
|
||||||
|
return schema.oneOf.map((s: any) => s.type ?? "any").join(" | ");
|
||||||
|
return "any";
|
||||||
|
}
|
||||||
|
|
||||||
|
function MCPToolCard({
|
||||||
|
tool,
|
||||||
|
selected,
|
||||||
|
onSelect,
|
||||||
|
}: {
|
||||||
|
tool: MCPTool;
|
||||||
|
selected: boolean;
|
||||||
|
onSelect: () => void;
|
||||||
|
}) {
|
||||||
|
const [expanded, setExpanded] = useState(false);
|
||||||
|
const properties = tool.input_schema?.properties ?? {};
|
||||||
|
const required = new Set<string>(tool.input_schema?.required ?? []);
|
||||||
|
const paramNames = Object.keys(properties);
|
||||||
|
|
||||||
|
// Strip XML-like tags from description for cleaner display.
|
||||||
|
// Loop to handle nested tags like <scr<script>ipt> (CodeQL fix).
|
||||||
|
let cleanDescription = tool.description ?? "";
|
||||||
|
let prev = "";
|
||||||
|
while (prev !== cleanDescription) {
|
||||||
|
prev = cleanDescription;
|
||||||
|
cleanDescription = cleanDescription.replace(/<[^>]*>/g, "");
|
||||||
|
}
|
||||||
|
cleanDescription = cleanDescription.trim();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<button
|
||||||
|
onClick={onSelect}
|
||||||
|
className={`group flex flex-col rounded-lg border text-left transition-colors ${
|
||||||
|
selected
|
||||||
|
? "border-blue-500 bg-blue-50 dark:border-blue-400 dark:bg-blue-950"
|
||||||
|
: "border-gray-200 hover:border-gray-300 hover:bg-gray-50 dark:border-slate-700 dark:hover:border-slate-600 dark:hover:bg-slate-800"
|
||||||
|
}`}
|
||||||
|
>
|
||||||
|
{/* Header */}
|
||||||
|
<div className="flex items-center gap-2 px-3 pb-1 pt-3">
|
||||||
|
<span className="flex-1 text-sm font-semibold dark:text-white">
|
||||||
|
{tool.name}
|
||||||
|
</span>
|
||||||
|
{paramNames.length > 0 && (
|
||||||
|
<Badge variant="secondary" className="text-[10px]">
|
||||||
|
{paramNames.length} param{paramNames.length !== 1 ? "s" : ""}
|
||||||
|
</Badge>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Description (collapsed: truncated) */}
|
||||||
|
{cleanDescription && (
|
||||||
|
<p className="px-3 pb-1 text-xs leading-relaxed text-gray-500 dark:text-gray-400">
|
||||||
|
{expanded ? cleanDescription : truncateDescription(cleanDescription)}
|
||||||
|
</p>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Parameter badges (collapsed view) */}
|
||||||
|
{!expanded && paramNames.length > 0 && (
|
||||||
|
<div className="flex flex-wrap gap-1 px-3 pb-2">
|
||||||
|
{paramNames.slice(0, 6).map((name) => (
|
||||||
|
<Badge
|
||||||
|
key={name}
|
||||||
|
variant="outline"
|
||||||
|
className="text-[10px] font-normal"
|
||||||
|
>
|
||||||
|
{name}
|
||||||
|
{required.has(name) && (
|
||||||
|
<span className="ml-0.5 text-red-400">*</span>
|
||||||
|
)}
|
||||||
|
</Badge>
|
||||||
|
))}
|
||||||
|
{paramNames.length > 6 && (
|
||||||
|
<Badge variant="outline" className="text-[10px] font-normal">
|
||||||
|
+{paramNames.length - 6} more
|
||||||
|
</Badge>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Expanded: full parameter details */}
|
||||||
|
{expanded && paramNames.length > 0 && (
|
||||||
|
<div className="mx-3 mb-2 rounded border border-gray-100 bg-gray-50/50 dark:border-slate-700 dark:bg-slate-800/50">
|
||||||
|
<table className="w-full text-xs">
|
||||||
|
<thead>
|
||||||
|
<tr className="border-b border-gray-100 dark:border-slate-700">
|
||||||
|
<th className="px-2 py-1 text-left font-medium text-gray-500 dark:text-gray-400">
|
||||||
|
Parameter
|
||||||
|
</th>
|
||||||
|
<th className="px-2 py-1 text-left font-medium text-gray-500 dark:text-gray-400">
|
||||||
|
Type
|
||||||
|
</th>
|
||||||
|
<th className="px-2 py-1 text-left font-medium text-gray-500 dark:text-gray-400">
|
||||||
|
Description
|
||||||
|
</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
{paramNames.map((name) => {
|
||||||
|
const prop = properties[name] ?? {};
|
||||||
|
return (
|
||||||
|
<tr
|
||||||
|
key={name}
|
||||||
|
className="border-b border-gray-50 last:border-0 dark:border-slate-700/50"
|
||||||
|
>
|
||||||
|
<td className="px-2 py-1 font-mono text-[11px] text-gray-700 dark:text-gray-300">
|
||||||
|
{name}
|
||||||
|
{required.has(name) && (
|
||||||
|
<span className="ml-0.5 text-red-400">*</span>
|
||||||
|
)}
|
||||||
|
</td>
|
||||||
|
<td className="px-2 py-1 text-gray-500 dark:text-gray-400">
|
||||||
|
{schemaTypeLabel(prop)}
|
||||||
|
</td>
|
||||||
|
<td className="max-w-[200px] truncate px-2 py-1 text-gray-500 dark:text-gray-400">
|
||||||
|
{prop.description ?? "—"}
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Toggle details */}
|
||||||
|
{(paramNames.length > 0 || cleanDescription.length > 120) && (
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
onClick={(e) => {
|
||||||
|
e.stopPropagation();
|
||||||
|
setExpanded((prev) => !prev);
|
||||||
|
}}
|
||||||
|
className="flex w-full items-center justify-center gap-1 border-t border-gray-100 py-1.5 text-[10px] text-gray-400 hover:text-gray-600 dark:border-slate-700 dark:text-gray-500 dark:hover:text-gray-300"
|
||||||
|
>
|
||||||
|
{expanded ? "Hide details" : "Show details"}
|
||||||
|
<CaretDown
|
||||||
|
className={`h-3 w-3 transition-transform ${expanded ? "rotate-180" : ""}`}
|
||||||
|
/>
|
||||||
|
</button>
|
||||||
|
)}
|
||||||
|
</button>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -0,0 +1,76 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||||
|
import { SidebarProvider } from "@/components/ui/sidebar";
|
||||||
|
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
|
||||||
|
import { ChatSidebar } from "./components/ChatSidebar/ChatSidebar";
|
||||||
|
import { MobileDrawer } from "./components/MobileDrawer/MobileDrawer";
|
||||||
|
import { MobileHeader } from "./components/MobileHeader/MobileHeader";
|
||||||
|
import { useCopilotPage } from "./useCopilotPage";
|
||||||
|
|
||||||
|
export function CopilotPage() {
|
||||||
|
const {
|
||||||
|
sessionId,
|
||||||
|
messages,
|
||||||
|
status,
|
||||||
|
error,
|
||||||
|
stop,
|
||||||
|
createSession,
|
||||||
|
onSend,
|
||||||
|
isLoadingSession,
|
||||||
|
isCreatingSession,
|
||||||
|
isUserLoading,
|
||||||
|
isLoggedIn,
|
||||||
|
// Mobile drawer
|
||||||
|
isMobile,
|
||||||
|
isDrawerOpen,
|
||||||
|
sessions,
|
||||||
|
isLoadingSessions,
|
||||||
|
handleOpenDrawer,
|
||||||
|
handleCloseDrawer,
|
||||||
|
handleDrawerOpenChange,
|
||||||
|
handleSelectSession,
|
||||||
|
handleNewChat,
|
||||||
|
} = useCopilotPage();
|
||||||
|
|
||||||
|
if (isUserLoading || !isLoggedIn) {
|
||||||
|
return <LoadingSpinner size="large" cover />;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<SidebarProvider
|
||||||
|
defaultOpen={true}
|
||||||
|
className="h-[calc(100vh-72px)] min-h-0"
|
||||||
|
>
|
||||||
|
{!isMobile && <ChatSidebar />}
|
||||||
|
<div className="relative flex h-full w-full flex-col overflow-hidden bg-[#f8f8f9] px-0">
|
||||||
|
{isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
|
||||||
|
<div className="flex-1 overflow-hidden">
|
||||||
|
<ChatContainer
|
||||||
|
messages={messages}
|
||||||
|
status={status}
|
||||||
|
error={error}
|
||||||
|
sessionId={sessionId}
|
||||||
|
isLoadingSession={isLoadingSession}
|
||||||
|
isCreatingSession={isCreatingSession}
|
||||||
|
onCreateSession={createSession}
|
||||||
|
onSend={onSend}
|
||||||
|
onStop={stop}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{isMobile && (
|
||||||
|
<MobileDrawer
|
||||||
|
isOpen={isDrawerOpen}
|
||||||
|
sessions={sessions}
|
||||||
|
currentSessionId={sessionId}
|
||||||
|
isLoading={isLoadingSessions}
|
||||||
|
onSelectSession={handleSelectSession}
|
||||||
|
onNewChat={handleNewChat}
|
||||||
|
onClose={handleCloseDrawer}
|
||||||
|
onOpenChange={handleDrawerOpenChange}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</SidebarProvider>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -0,0 +1,74 @@
|
|||||||
|
"use client";
|
||||||
|
import { ChatInput } from "@/app/(platform)/copilot/components/ChatInput/ChatInput";
|
||||||
|
import { UIDataTypes, UIMessage, UITools } from "ai";
|
||||||
|
import { LayoutGroup, motion } from "framer-motion";
|
||||||
|
import { ChatMessagesContainer } from "../ChatMessagesContainer/ChatMessagesContainer";
|
||||||
|
import { CopilotChatActionsProvider } from "../CopilotChatActionsProvider/CopilotChatActionsProvider";
|
||||||
|
import { EmptySession } from "../EmptySession/EmptySession";
|
||||||
|
|
||||||
|
export interface ChatContainerProps {
|
||||||
|
messages: UIMessage<unknown, UIDataTypes, UITools>[];
|
||||||
|
status: string;
|
||||||
|
error: Error | undefined;
|
||||||
|
sessionId: string | null;
|
||||||
|
isLoadingSession: boolean;
|
||||||
|
isCreatingSession: boolean;
|
||||||
|
onCreateSession: () => void | Promise<string>;
|
||||||
|
onSend: (message: string) => void | Promise<void>;
|
||||||
|
onStop: () => void;
|
||||||
|
}
|
||||||
|
export const ChatContainer = ({
|
||||||
|
messages,
|
||||||
|
status,
|
||||||
|
error,
|
||||||
|
sessionId,
|
||||||
|
isLoadingSession,
|
||||||
|
isCreatingSession,
|
||||||
|
onCreateSession,
|
||||||
|
onSend,
|
||||||
|
onStop,
|
||||||
|
}: ChatContainerProps) => {
|
||||||
|
const inputLayoutId = "copilot-2-chat-input";
|
||||||
|
|
||||||
|
return (
|
||||||
|
<CopilotChatActionsProvider onSend={onSend}>
|
||||||
|
<LayoutGroup id="copilot-2-chat-layout">
|
||||||
|
<div className="flex h-full min-h-0 w-full flex-col bg-[#f8f8f9] px-2 lg:px-0">
|
||||||
|
{sessionId ? (
|
||||||
|
<div className="mx-auto flex h-full min-h-0 w-full max-w-3xl flex-col">
|
||||||
|
<ChatMessagesContainer
|
||||||
|
messages={messages}
|
||||||
|
status={status}
|
||||||
|
error={error}
|
||||||
|
isLoading={isLoadingSession}
|
||||||
|
/>
|
||||||
|
<motion.div
|
||||||
|
initial={{ opacity: 0 }}
|
||||||
|
animate={{ opacity: 1 }}
|
||||||
|
transition={{ duration: 0.3 }}
|
||||||
|
className="relative px-3 pb-2 pt-2"
|
||||||
|
>
|
||||||
|
<div className="pointer-events-none absolute left-0 right-0 top-[-18px] z-10 h-6 bg-gradient-to-b from-transparent to-[#f8f8f9]" />
|
||||||
|
<ChatInput
|
||||||
|
inputId="chat-input-session"
|
||||||
|
onSend={onSend}
|
||||||
|
disabled={status === "streaming"}
|
||||||
|
isStreaming={status === "streaming"}
|
||||||
|
onStop={onStop}
|
||||||
|
placeholder="What else can I help with?"
|
||||||
|
/>
|
||||||
|
</motion.div>
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
<EmptySession
|
||||||
|
inputLayoutId={inputLayoutId}
|
||||||
|
isCreatingSession={isCreatingSession}
|
||||||
|
onCreateSession={onCreateSession}
|
||||||
|
onSend={onSend}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</LayoutGroup>
|
||||||
|
</CopilotChatActionsProvider>
|
||||||
|
);
|
||||||
|
};
|
||||||
@@ -6,17 +6,19 @@ import {
|
|||||||
MicrophoneIcon,
|
MicrophoneIcon,
|
||||||
StopIcon,
|
StopIcon,
|
||||||
} from "@phosphor-icons/react";
|
} from "@phosphor-icons/react";
|
||||||
|
import { ChangeEvent, useCallback } from "react";
|
||||||
import { RecordingIndicator } from "./components/RecordingIndicator";
|
import { RecordingIndicator } from "./components/RecordingIndicator";
|
||||||
import { useChatInput } from "./useChatInput";
|
import { useChatInput } from "./useChatInput";
|
||||||
import { useVoiceRecording } from "./useVoiceRecording";
|
import { useVoiceRecording } from "./useVoiceRecording";
|
||||||
|
|
||||||
export interface Props {
|
export interface Props {
|
||||||
onSend: (message: string) => void;
|
onSend: (message: string) => void | Promise<void>;
|
||||||
disabled?: boolean;
|
disabled?: boolean;
|
||||||
isStreaming?: boolean;
|
isStreaming?: boolean;
|
||||||
onStop?: () => void;
|
onStop?: () => void;
|
||||||
placeholder?: string;
|
placeholder?: string;
|
||||||
className?: string;
|
className?: string;
|
||||||
|
inputId?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function ChatInput({
|
export function ChatInput({
|
||||||
@@ -26,14 +28,14 @@ export function ChatInput({
|
|||||||
onStop,
|
onStop,
|
||||||
placeholder = "Type your message...",
|
placeholder = "Type your message...",
|
||||||
className,
|
className,
|
||||||
|
inputId = "chat-input",
|
||||||
}: Props) {
|
}: Props) {
|
||||||
const inputId = "chat-input";
|
|
||||||
const {
|
const {
|
||||||
value,
|
value,
|
||||||
setValue,
|
setValue,
|
||||||
handleKeyDown: baseHandleKeyDown,
|
handleKeyDown: baseHandleKeyDown,
|
||||||
handleSubmit,
|
handleSubmit,
|
||||||
handleChange,
|
handleChange: baseHandleChange,
|
||||||
hasMultipleLines,
|
hasMultipleLines,
|
||||||
} = useChatInput({
|
} = useChatInput({
|
||||||
onSend,
|
onSend,
|
||||||
@@ -60,6 +62,15 @@ export function ChatInput({
|
|||||||
inputId,
|
inputId,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Block text changes when recording
|
||||||
|
const handleChange = useCallback(
|
||||||
|
(e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||||
|
if (isRecording) return;
|
||||||
|
baseHandleChange(e);
|
||||||
|
},
|
||||||
|
[isRecording, baseHandleChange],
|
||||||
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<form onSubmit={handleSubmit} className={cn("relative flex-1", className)}>
|
<form onSubmit={handleSubmit} className={cn("relative flex-1", className)}>
|
||||||
<div className="relative">
|
<div className="relative">
|
||||||
@@ -21,6 +21,7 @@ export function useChatInput({
|
|||||||
}: Args) {
|
}: Args) {
|
||||||
const [value, setValue] = useState("");
|
const [value, setValue] = useState("");
|
||||||
const [hasMultipleLines, setHasMultipleLines] = useState(false);
|
const [hasMultipleLines, setHasMultipleLines] = useState(false);
|
||||||
|
const [isSending, setIsSending] = useState(false);
|
||||||
|
|
||||||
useEffect(
|
useEffect(
|
||||||
function focusOnMount() {
|
function focusOnMount() {
|
||||||
@@ -100,9 +101,12 @@ export function useChatInput({
|
|||||||
}
|
}
|
||||||
}, [value, maxRows, inputId]);
|
}, [value, maxRows, inputId]);
|
||||||
|
|
||||||
const handleSend = () => {
|
async function handleSend() {
|
||||||
if (disabled || !value.trim()) return;
|
if (disabled || isSending || !value.trim()) return;
|
||||||
onSend(value.trim());
|
|
||||||
|
setIsSending(true);
|
||||||
|
try {
|
||||||
|
await onSend(value.trim());
|
||||||
setValue("");
|
setValue("");
|
||||||
setHasMultipleLines(false);
|
setHasMultipleLines(false);
|
||||||
const textarea = document.getElementById(inputId) as HTMLTextAreaElement;
|
const textarea = document.getElementById(inputId) as HTMLTextAreaElement;
|
||||||
@@ -116,18 +120,21 @@ export function useChatInput({
|
|||||||
wrapper.style.height = "";
|
wrapper.style.height = "";
|
||||||
wrapper.style.maxHeight = "";
|
wrapper.style.maxHeight = "";
|
||||||
}
|
}
|
||||||
};
|
} finally {
|
||||||
|
setIsSending(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
function handleKeyDown(event: KeyboardEvent<HTMLTextAreaElement>) {
|
function handleKeyDown(event: KeyboardEvent<HTMLTextAreaElement>) {
|
||||||
if (event.key === "Enter" && !event.shiftKey) {
|
if (event.key === "Enter" && !event.shiftKey) {
|
||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
handleSend();
|
void handleSend();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function handleSubmit(e: FormEvent<HTMLFormElement>) {
|
function handleSubmit(e: FormEvent<HTMLFormElement>) {
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
handleSend();
|
void handleSend();
|
||||||
}
|
}
|
||||||
|
|
||||||
function handleChange(e: ChangeEvent<HTMLTextAreaElement>) {
|
function handleChange(e: ChangeEvent<HTMLTextAreaElement>) {
|
||||||
@@ -142,5 +149,6 @@ export function useChatInput({
|
|||||||
handleSubmit,
|
handleSubmit,
|
||||||
handleChange,
|
handleChange,
|
||||||
hasMultipleLines,
|
hasMultipleLines,
|
||||||
|
isSending,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@@ -38,9 +38,13 @@ export function useVoiceRecording({
|
|||||||
const streamRef = useRef<MediaStream | null>(null);
|
const streamRef = useRef<MediaStream | null>(null);
|
||||||
const isRecordingRef = useRef(false);
|
const isRecordingRef = useRef(false);
|
||||||
|
|
||||||
const isSupported =
|
const [isSupported, setIsSupported] = useState(false);
|
||||||
typeof window !== "undefined" &&
|
|
||||||
!!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia);
|
useEffect(() => {
|
||||||
|
setIsSupported(
|
||||||
|
!!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia),
|
||||||
|
);
|
||||||
|
}, []);
|
||||||
|
|
||||||
const clearTimer = useCallback(() => {
|
const clearTimer = useCallback(() => {
|
||||||
if (timerRef.current) {
|
if (timerRef.current) {
|
||||||
@@ -214,17 +218,33 @@ export function useVoiceRecording({
|
|||||||
|
|
||||||
const handleKeyDown = useCallback(
|
const handleKeyDown = useCallback(
|
||||||
(event: KeyboardEvent<HTMLTextAreaElement>) => {
|
(event: KeyboardEvent<HTMLTextAreaElement>) => {
|
||||||
if (event.key === " " && !value.trim() && !isTranscribing) {
|
// Allow space to toggle recording (start when empty, stop when recording)
|
||||||
|
if (event.key === " " && !isTranscribing) {
|
||||||
|
if (isRecordingRef.current) {
|
||||||
|
// Stop recording on space
|
||||||
|
event.preventDefault();
|
||||||
|
stopRecording();
|
||||||
|
return;
|
||||||
|
} else if (!value.trim()) {
|
||||||
|
// Start recording on space when input is empty
|
||||||
|
event.preventDefault();
|
||||||
|
void startRecording();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Block all key events when recording (except space handled above)
|
||||||
|
if (isRecordingRef.current) {
|
||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
toggleRecording();
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
baseHandleKeyDown(event);
|
baseHandleKeyDown(event);
|
||||||
},
|
},
|
||||||
[value, isTranscribing, toggleRecording, baseHandleKeyDown],
|
[value, isTranscribing, stopRecording, startRecording, baseHandleKeyDown],
|
||||||
);
|
);
|
||||||
|
|
||||||
const showMicButton = isSupported;
|
const showMicButton = isSupported;
|
||||||
|
// Don't include isRecording in disabled state - we need key events to work
|
||||||
|
// Text input is blocked via handleKeyDown instead
|
||||||
const isInputDisabled = disabled || isStreaming || isTranscribing;
|
const isInputDisabled = disabled || isStreaming || isTranscribing;
|
||||||
|
|
||||||
// Cleanup on unmount
|
// Cleanup on unmount
|
||||||
@@ -0,0 +1,274 @@
|
|||||||
|
import { getGetWorkspaceDownloadFileByIdUrl } from "@/app/api/__generated__/endpoints/workspace/workspace";
|
||||||
|
import {
|
||||||
|
Conversation,
|
||||||
|
ConversationContent,
|
||||||
|
ConversationScrollButton,
|
||||||
|
} from "@/components/ai-elements/conversation";
|
||||||
|
import {
|
||||||
|
Message,
|
||||||
|
MessageContent,
|
||||||
|
MessageResponse,
|
||||||
|
} from "@/components/ai-elements/message";
|
||||||
|
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||||
|
import { ToolUIPart, UIDataTypes, UIMessage, UITools } from "ai";
|
||||||
|
import { useEffect, useState } from "react";
|
||||||
|
import { CreateAgentTool } from "../../tools/CreateAgent/CreateAgent";
|
||||||
|
import { EditAgentTool } from "../../tools/EditAgent/EditAgent";
|
||||||
|
import { FindAgentsTool } from "../../tools/FindAgents/FindAgents";
|
||||||
|
import { FindBlocksTool } from "../../tools/FindBlocks/FindBlocks";
|
||||||
|
import { RunAgentTool } from "../../tools/RunAgent/RunAgent";
|
||||||
|
import { RunBlockTool } from "../../tools/RunBlock/RunBlock";
|
||||||
|
import { SearchDocsTool } from "../../tools/SearchDocs/SearchDocs";
|
||||||
|
import { ViewAgentOutputTool } from "../../tools/ViewAgentOutput/ViewAgentOutput";
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Workspace media support
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Resolve workspace:// URLs in markdown text to proxy download URLs.
|
||||||
|
* Detects MIME type from the hash fragment (e.g. workspace://id#video/mp4)
|
||||||
|
* and prefixes the alt text with "video:" so the custom img component can
|
||||||
|
* render a <video> element instead.
|
||||||
|
*/
|
||||||
|
function resolveWorkspaceUrls(text: string): string {
|
||||||
|
return text.replace(
|
||||||
|
/!\[([^\]]*)\]\(workspace:\/\/([^)#\s]+)(?:#([^)\s]*))?\)/g,
|
||||||
|
(_match, alt: string, fileId: string, mimeHint?: string) => {
|
||||||
|
const apiPath = getGetWorkspaceDownloadFileByIdUrl(fileId);
|
||||||
|
const url = `/api/proxy${apiPath}`;
|
||||||
|
if (mimeHint?.startsWith("video/")) {
|
||||||
|
return ``;
|
||||||
|
}
|
||||||
|
return ``;
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Custom img component for Streamdown that renders <video> elements
|
||||||
|
* for workspace video files (detected via "video:" alt-text prefix).
|
||||||
|
* Falls back to <video> when an <img> fails to load for workspace files.
|
||||||
|
*/
|
||||||
|
function WorkspaceMediaImage(props: React.JSX.IntrinsicElements["img"]) {
|
||||||
|
const { src, alt, ...rest } = props;
|
||||||
|
const [imgFailed, setImgFailed] = useState(false);
|
||||||
|
const isWorkspace = src?.includes("/workspace/files/") ?? false;
|
||||||
|
|
||||||
|
if (!src) return null;
|
||||||
|
|
||||||
|
if (alt?.startsWith("video:") || (imgFailed && isWorkspace)) {
|
||||||
|
return (
|
||||||
|
<span className="my-2 inline-block">
|
||||||
|
<video
|
||||||
|
controls
|
||||||
|
className="h-auto max-w-full rounded-md border border-zinc-200"
|
||||||
|
preload="metadata"
|
||||||
|
>
|
||||||
|
<source src={src} />
|
||||||
|
Your browser does not support the video tag.
|
||||||
|
</video>
|
||||||
|
</span>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
// eslint-disable-next-line @next/next/no-img-element
|
||||||
|
<img
|
||||||
|
src={src}
|
||||||
|
alt={alt || "Image"}
|
||||||
|
className="h-auto max-w-full rounded-md border border-zinc-200"
|
||||||
|
loading="lazy"
|
||||||
|
onError={() => {
|
||||||
|
if (isWorkspace) setImgFailed(true);
|
||||||
|
}}
|
||||||
|
{...rest}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Stable components override for Streamdown (avoids re-creating on every render). */
|
||||||
|
const STREAMDOWN_COMPONENTS = { img: WorkspaceMediaImage };
|
||||||
|
|
||||||
|
const THINKING_PHRASES = [
|
||||||
|
"Thinking...",
|
||||||
|
"Considering this...",
|
||||||
|
"Working through this...",
|
||||||
|
"Analyzing your request...",
|
||||||
|
"Reasoning...",
|
||||||
|
"Looking into it...",
|
||||||
|
"Processing your request...",
|
||||||
|
"Mulling this over...",
|
||||||
|
"Piecing it together...",
|
||||||
|
"On it...",
|
||||||
|
];
|
||||||
|
|
||||||
|
function getRandomPhrase() {
|
||||||
|
return THINKING_PHRASES[Math.floor(Math.random() * THINKING_PHRASES.length)];
|
||||||
|
}
|
||||||
|
|
||||||
|
interface ChatMessagesContainerProps {
|
||||||
|
messages: UIMessage<unknown, UIDataTypes, UITools>[];
|
||||||
|
status: string;
|
||||||
|
error: Error | undefined;
|
||||||
|
isLoading: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const ChatMessagesContainer = ({
|
||||||
|
messages,
|
||||||
|
status,
|
||||||
|
error,
|
||||||
|
isLoading,
|
||||||
|
}: ChatMessagesContainerProps) => {
|
||||||
|
const [thinkingPhrase, setThinkingPhrase] = useState(getRandomPhrase);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (status === "submitted") {
|
||||||
|
setThinkingPhrase(getRandomPhrase());
|
||||||
|
}
|
||||||
|
}, [status]);
|
||||||
|
|
||||||
|
const lastMessage = messages[messages.length - 1];
|
||||||
|
const lastAssistantHasVisibleContent =
|
||||||
|
lastMessage?.role === "assistant" &&
|
||||||
|
lastMessage.parts.some(
|
||||||
|
(p) =>
|
||||||
|
(p.type === "text" && p.text.trim().length > 0) ||
|
||||||
|
p.type.startsWith("tool-"),
|
||||||
|
);
|
||||||
|
|
||||||
|
const showThinking =
|
||||||
|
status === "submitted" ||
|
||||||
|
(status === "streaming" && !lastAssistantHasVisibleContent);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Conversation className="min-h-0 flex-1">
|
||||||
|
<ConversationContent className="gap-6 px-3 py-6">
|
||||||
|
{isLoading && messages.length === 0 && (
|
||||||
|
<div className="flex flex-1 items-center justify-center">
|
||||||
|
<LoadingSpinner size="large" className="text-neutral-400" />
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
{messages.map((message, messageIndex) => {
|
||||||
|
const isLastAssistant =
|
||||||
|
messageIndex === messages.length - 1 &&
|
||||||
|
message.role === "assistant";
|
||||||
|
const messageHasVisibleContent = message.parts.some(
|
||||||
|
(p) =>
|
||||||
|
(p.type === "text" && p.text.trim().length > 0) ||
|
||||||
|
p.type.startsWith("tool-"),
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Message from={message.role} key={message.id}>
|
||||||
|
<MessageContent
|
||||||
|
className={
|
||||||
|
"text-[1rem] leading-relaxed " +
|
||||||
|
"group-[.is-user]:rounded-xl group-[.is-user]:bg-purple-100 group-[.is-user]:px-3 group-[.is-user]:py-2.5 group-[.is-user]:text-slate-900 group-[.is-user]:[border-bottom-right-radius:0] " +
|
||||||
|
"group-[.is-assistant]:bg-transparent group-[.is-assistant]:text-slate-900"
|
||||||
|
}
|
||||||
|
>
|
||||||
|
{message.parts.map((part, i) => {
|
||||||
|
switch (part.type) {
|
||||||
|
case "text":
|
||||||
|
return (
|
||||||
|
<MessageResponse
|
||||||
|
key={`${message.id}-${i}`}
|
||||||
|
components={STREAMDOWN_COMPONENTS}
|
||||||
|
>
|
||||||
|
{resolveWorkspaceUrls(part.text)}
|
||||||
|
</MessageResponse>
|
||||||
|
);
|
||||||
|
case "tool-find_block":
|
||||||
|
return (
|
||||||
|
<FindBlocksTool
|
||||||
|
key={`${message.id}-${i}`}
|
||||||
|
part={part as ToolUIPart}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
case "tool-find_agent":
|
||||||
|
case "tool-find_library_agent":
|
||||||
|
return (
|
||||||
|
<FindAgentsTool
|
||||||
|
key={`${message.id}-${i}`}
|
||||||
|
part={part as ToolUIPart}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
case "tool-search_docs":
|
||||||
|
case "tool-get_doc_page":
|
||||||
|
return (
|
||||||
|
<SearchDocsTool
|
||||||
|
key={`${message.id}-${i}`}
|
||||||
|
part={part as ToolUIPart}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
case "tool-run_block":
|
||||||
|
return (
|
||||||
|
<RunBlockTool
|
||||||
|
key={`${message.id}-${i}`}
|
||||||
|
part={part as ToolUIPart}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
case "tool-run_agent":
|
||||||
|
case "tool-schedule_agent":
|
||||||
|
return (
|
||||||
|
<RunAgentTool
|
||||||
|
key={`${message.id}-${i}`}
|
||||||
|
part={part as ToolUIPart}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
case "tool-create_agent":
|
||||||
|
return (
|
||||||
|
<CreateAgentTool
|
||||||
|
key={`${message.id}-${i}`}
|
||||||
|
part={part as ToolUIPart}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
case "tool-edit_agent":
|
||||||
|
return (
|
||||||
|
<EditAgentTool
|
||||||
|
key={`${message.id}-${i}`}
|
||||||
|
part={part as ToolUIPart}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
case "tool-view_agent_output":
|
||||||
|
return (
|
||||||
|
<ViewAgentOutputTool
|
||||||
|
key={`${message.id}-${i}`}
|
||||||
|
part={part as ToolUIPart}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
default:
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
})}
|
||||||
|
{isLastAssistant &&
|
||||||
|
!messageHasVisibleContent &&
|
||||||
|
showThinking && (
|
||||||
|
<span className="inline-block animate-shimmer bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-[length:200%_100%] bg-clip-text text-transparent">
|
||||||
|
{thinkingPhrase}
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
</MessageContent>
|
||||||
|
</Message>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
{showThinking && lastMessage?.role !== "assistant" && (
|
||||||
|
<Message from="assistant">
|
||||||
|
<MessageContent className="text-[1rem] leading-relaxed">
|
||||||
|
<span className="inline-block animate-shimmer bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-[length:200%_100%] bg-clip-text text-transparent">
|
||||||
|
{thinkingPhrase}
|
||||||
|
</span>
|
||||||
|
</MessageContent>
|
||||||
|
</Message>
|
||||||
|
)}
|
||||||
|
{error && (
|
||||||
|
<div className="rounded-lg bg-red-50 p-3 text-red-600">
|
||||||
|
Error: {error.message}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</ConversationContent>
|
||||||
|
<ConversationScrollButton />
|
||||||
|
</Conversation>
|
||||||
|
);
|
||||||
|
};
|
||||||
@@ -0,0 +1,188 @@
|
|||||||
|
"use client";
|
||||||
|
import { useGetV2ListSessions } from "@/app/api/__generated__/endpoints/chat/chat";
|
||||||
|
import { Button } from "@/components/atoms/Button/Button";
|
||||||
|
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||||
|
import { Text } from "@/components/atoms/Text/Text";
|
||||||
|
import {
|
||||||
|
Sidebar,
|
||||||
|
SidebarContent,
|
||||||
|
SidebarFooter,
|
||||||
|
SidebarHeader,
|
||||||
|
SidebarTrigger,
|
||||||
|
useSidebar,
|
||||||
|
} from "@/components/ui/sidebar";
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
import { PlusCircleIcon, PlusIcon } from "@phosphor-icons/react";
|
||||||
|
import { motion } from "framer-motion";
|
||||||
|
import { parseAsString, useQueryState } from "nuqs";
|
||||||
|
|
||||||
|
/**
 * Collapsible chat-session sidebar.
 *
 * The active session id is stored in the `sessionId` URL query parameter
 * (via nuqs), so selecting a chat or starting a new one updates the URL.
 */
export function ChatSidebar() {
  const { state } = useSidebar();
  const isCollapsed = state === "collapsed";
  // null sessionId means "no active chat" (fresh conversation).
  const [sessionId, setSessionId] = useQueryState("sessionId", parseAsString);

  // NOTE(review): only the first 50 sessions are fetched — no pagination here.
  const { data: sessionsResponse, isLoading: isLoadingSessions } =
    useGetV2ListSessions({ limit: 50 });

  const sessions =
    sessionsResponse?.status === 200 ? sessionsResponse.data.sessions : [];

  // Clearing the query param switches to a fresh, unsaved chat.
  function handleNewChat() {
    setSessionId(null);
  }

  function handleSelectSession(id: string) {
    setSessionId(id);
  }

  // "Today" / "Yesterday" / "N days ago", falling back to e.g. "3rd Jan 2025".
  // NOTE(review): day deltas are measured as 24h periods from now, not
  // calendar days, so a chat from late yesterday evening can still read as
  // "Today" — confirm that's acceptable.
  function formatDate(dateString: string) {
    const date = new Date(dateString);
    const now = new Date();
    const diffMs = now.getTime() - date.getTime();
    const diffDays = Math.floor(diffMs / (1000 * 60 * 60 * 24));

    if (diffDays === 0) return "Today";
    if (diffDays === 1) return "Yesterday";
    if (diffDays < 7) return `${diffDays} days ago`;

    const day = date.getDate();
    // English ordinal suffix (1st/2nd/3rd/…, with 11th–13th special-cased).
    const ordinal =
      day % 10 === 1 && day !== 11
        ? "st"
        : day % 10 === 2 && day !== 12
          ? "nd"
          : day % 10 === 3 && day !== 13
            ? "rd"
            : "th";
    const month = date.toLocaleDateString("en-US", { month: "short" });
    const year = date.getFullYear();

    return `${day}${ordinal} ${month} ${year}`;
  }

  return (
    <Sidebar
      variant="inset"
      collapsible="icon"
      className="!top-[50px] !h-[calc(100vh-50px)] border-r border-zinc-100 px-0"
    >
      {/* Collapsed rail: just the expand trigger and a "new chat" icon. */}
      {isCollapsed && (
        <SidebarHeader
          className={cn(
            "flex",
            isCollapsed
              ? "flex-row items-center justify-between gap-y-4 md:flex-col md:items-start md:justify-start"
              : "flex-row items-center justify-between",
          )}
        >
          <motion.div
            key={isCollapsed ? "header-collapsed" : "header-expanded"}
            className="flex flex-col items-center gap-3 pt-4"
            initial={{ opacity: 0, filter: "blur(3px)" }}
            animate={{ opacity: 1, filter: "blur(0px)" }}
            transition={{ type: "spring", bounce: 0.2 }}
          >
            <div className="flex flex-col items-center gap-2">
              <SidebarTrigger />
              <Button
                variant="ghost"
                onClick={handleNewChat}
                style={{ minWidth: "auto", width: "auto" }}
              >
                <PlusCircleIcon className="!size-5" />
                <span className="sr-only">New Chat</span>
              </Button>
            </div>
          </motion.div>
        </SidebarHeader>
      )}
      <SidebarContent className="gap-4 overflow-y-auto px-4 py-4 [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
        {!isCollapsed && (
          <motion.div
            initial={{ opacity: 0 }}
            animate={{ opacity: 1 }}
            transition={{ duration: 0.2, delay: 0.1 }}
            className="flex items-center justify-between px-3"
          >
            <Text variant="h3" size="body-medium">
              Your chats
            </Text>
            <div className="relative left-6">
              <SidebarTrigger />
            </div>
          </motion.div>
        )}

        {/* Session list: spinner → empty state → clickable session rows. */}
        {!isCollapsed && (
          <motion.div
            initial={{ opacity: 0 }}
            animate={{ opacity: 1 }}
            transition={{ duration: 0.2, delay: 0.15 }}
            className="mt-4 flex flex-col gap-1"
          >
            {isLoadingSessions ? (
              <div className="flex items-center justify-center py-4">
                <LoadingSpinner size="small" className="text-neutral-400" />
              </div>
            ) : sessions.length === 0 ? (
              <p className="py-4 text-center text-sm text-neutral-500">
                No conversations yet
              </p>
            ) : (
              sessions.map((session) => (
                <button
                  key={session.id}
                  onClick={() => handleSelectSession(session.id)}
                  className={cn(
                    "w-full rounded-lg px-3 py-2.5 text-left transition-colors",
                    session.id === sessionId
                      ? "bg-zinc-100"
                      : "hover:bg-zinc-50",
                  )}
                >
                  <div className="flex min-w-0 max-w-full flex-col overflow-hidden">
                    <div className="min-w-0 max-w-full">
                      <Text
                        variant="body"
                        className={cn(
                          "truncate font-normal",
                          session.id === sessionId
                            ? "text-zinc-600"
                            : "text-zinc-800",
                        )}
                      >
                        {session.title || `Untitled chat`}
                      </Text>
                    </div>
                    <Text variant="small" className="text-neutral-400">
                      {formatDate(session.updated_at)}
                    </Text>
                  </div>
                </button>
              ))
            )}
          </motion.div>
        )}
      </SidebarContent>
      {/* Footer "New Chat" CTA — only when expanded with an active session. */}
      {!isCollapsed && sessionId && (
        <SidebarFooter className="shrink-0 bg-zinc-50 p-3 pb-1 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
          <motion.div
            initial={{ opacity: 0 }}
            animate={{ opacity: 1 }}
            transition={{ duration: 0.2, delay: 0.2 }}
          >
            <Button
              variant="primary"
              size="small"
              onClick={handleNewChat}
              className="w-full"
              leftIcon={<PlusIcon className="h-4 w-4" weight="bold" />}
            >
              New Chat
            </Button>
          </motion.div>
        </SidebarFooter>
      )}
    </Sidebar>
  );
}
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { useMemo } from "react";

import { CopilotChatActionsContext } from "./useCopilotChatActions";
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
onSend: (message: string) => void | Promise<void>;
|
||||||
|
children: React.ReactNode;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function CopilotChatActionsProvider({ onSend, children }: Props) {
|
||||||
|
return (
|
||||||
|
<CopilotChatActionsContext.Provider value={{ onSend }}>
|
||||||
|
{children}
|
||||||
|
</CopilotChatActionsContext.Provider>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { createContext, useContext } from "react";
|
||||||
|
|
||||||
|
interface CopilotChatActions {
|
||||||
|
onSend: (message: string) => void | Promise<void>;
|
||||||
|
}
|
||||||
|
|
||||||
|
const CopilotChatActionsContext = createContext<CopilotChatActions | null>(
|
||||||
|
null,
|
||||||
|
);
|
||||||
|
|
||||||
|
export function useCopilotChatActions(): CopilotChatActions {
|
||||||
|
const ctx = useContext(CopilotChatActionsContext);
|
||||||
|
if (!ctx) {
|
||||||
|
throw new Error(
|
||||||
|
"useCopilotChatActions must be used within CopilotChatActionsProvider",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return ctx;
|
||||||
|
}
|
||||||
|
|
||||||
|
export { CopilotChatActionsContext };
|
||||||
@@ -1,99 +0,0 @@
|
|||||||
"use client";
|
|
||||||
|
|
||||||
import { ChatLoader } from "@/components/contextual/Chat/components/ChatLoader/ChatLoader";
|
|
||||||
import { Text } from "@/components/atoms/Text/Text";
|
|
||||||
import { NAVBAR_HEIGHT_PX } from "@/lib/constants";
|
|
||||||
import type { ReactNode } from "react";
|
|
||||||
import { DesktopSidebar } from "./components/DesktopSidebar/DesktopSidebar";
|
|
||||||
import { MobileDrawer } from "./components/MobileDrawer/MobileDrawer";
|
|
||||||
import { MobileHeader } from "./components/MobileHeader/MobileHeader";
|
|
||||||
import { useCopilotShell } from "./useCopilotShell";
|
|
||||||
|
|
||||||
interface Props {
  children: ReactNode;
}

/**
 * Layout shell for the Copilot page: a sessions sidebar (desktop) or a
 * slide-in drawer (mobile) alongside the chat content. All state and
 * handlers come from useCopilotShell().
 */
export function CopilotShell({ children }: Props) {
  const {
    isMobile,
    isDrawerOpen,
    isLoading,
    isCreatingSession,
    isLoggedIn,
    hasActiveSession,
    sessions,
    currentSessionId,
    handleOpenDrawer,
    handleCloseDrawer,
    handleDrawerOpenChange,
    handleNewChatClick,
    handleSessionClick,
    hasNextPage,
    isFetchingNextPage,
    fetchNextPage,
  } = useCopilotShell();

  // NOTE(review): this branch also renders while auth state is still
  // resolving, not only for logged-out users — confirm that is intended.
  if (!isLoggedIn) {
    return (
      <div className="flex h-full items-center justify-center">
        <ChatLoader />
      </div>
    );
  }

  return (
    <div
      className="flex overflow-hidden bg-[#EFEFF0]"
      style={{ height: `calc(100vh - ${NAVBAR_HEIGHT_PX}px)` }}
    >
      {!isMobile && (
        <DesktopSidebar
          sessions={sessions}
          currentSessionId={currentSessionId}
          isLoading={isLoading}
          hasNextPage={hasNextPage}
          isFetchingNextPage={isFetchingNextPage}
          onSelectSession={handleSessionClick}
          onFetchNextPage={fetchNextPage}
          onNewChat={handleNewChatClick}
          hasActiveSession={Boolean(hasActiveSession)}
        />
      )}

      <div className="relative flex min-h-0 flex-1 flex-col">
        {isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
        <div className="flex min-h-0 flex-1 flex-col">
          {/* While a new session is being created, replace the chat area
              with a full-pane loader. */}
          {isCreatingSession ? (
            <div className="flex h-full flex-1 flex-col items-center justify-center bg-[#f8f8f9]">
              <div className="flex flex-col items-center gap-4">
                <ChatLoader />
                <Text variant="body" className="text-zinc-500">
                  Creating your chat...
                </Text>
              </div>
            </div>
          ) : (
            children
          )}
        </div>
      </div>

      {/* Mobile: the sessions list lives in a drawer instead of a sidebar. */}
      {isMobile && (
        <MobileDrawer
          isOpen={isDrawerOpen}
          sessions={sessions}
          currentSessionId={currentSessionId}
          isLoading={isLoading}
          hasNextPage={hasNextPage}
          isFetchingNextPage={isFetchingNextPage}
          onSelectSession={handleSessionClick}
          onFetchNextPage={fetchNextPage}
          onNewChat={handleNewChatClick}
          onClose={handleCloseDrawer}
          onOpenChange={handleDrawerOpenChange}
          hasActiveSession={Boolean(hasActiveSession)}
        />
      )}
    </div>
  );
}
|
|
||||||
@@ -1,70 +0,0 @@
|
|||||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
|
||||||
import { Button } from "@/components/atoms/Button/Button";
|
|
||||||
import { Text } from "@/components/atoms/Text/Text";
|
|
||||||
import { scrollbarStyles } from "@/components/styles/scrollbars";
|
|
||||||
import { cn } from "@/lib/utils";
|
|
||||||
import { Plus } from "@phosphor-icons/react";
|
|
||||||
import { SessionsList } from "../SessionsList/SessionsList";
|
|
||||||
|
|
||||||
interface Props {
  sessions: SessionSummaryResponse[];
  currentSessionId: string | null;
  isLoading: boolean;
  hasNextPage: boolean;
  isFetchingNextPage: boolean;
  onSelectSession: (sessionId: string) => void;
  onFetchNextPage: () => void;
  onNewChat: () => void;
  hasActiveSession: boolean;
}

/**
 * Fixed-width sessions sidebar for desktop layouts: a header, the infinite
 * SessionsList, and — while a chat is active — a "New Chat" footer button.
 * Purely presentational; all data and callbacks arrive via props.
 */
export function DesktopSidebar({
  sessions,
  currentSessionId,
  isLoading,
  hasNextPage,
  isFetchingNextPage,
  onSelectSession,
  onFetchNextPage,
  onNewChat,
  hasActiveSession,
}: Props) {
  return (
    <aside className="flex h-full w-80 flex-col border-r border-zinc-100 bg-zinc-50">
      <div className="shrink-0 px-6 py-4">
        <Text variant="h3" size="body-medium">
          Your chats
        </Text>
      </div>
      {/* Scrollable list area; min-h-0 lets the flex child actually shrink. */}
      <div
        className={cn(
          "flex min-h-0 flex-1 flex-col overflow-y-auto px-3 py-3",
          scrollbarStyles,
        )}
      >
        <SessionsList
          sessions={sessions}
          currentSessionId={currentSessionId}
          isLoading={isLoading}
          hasNextPage={hasNextPage}
          isFetchingNextPage={isFetchingNextPage}
          onSelectSession={onSelectSession}
          onFetchNextPage={onFetchNextPage}
        />
      </div>
      {hasActiveSession && (
        <div className="shrink-0 bg-zinc-50 p-3 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
          <Button
            variant="primary"
            size="small"
            onClick={onNewChat}
            className="w-full"
            leftIcon={<Plus width="1rem" height="1rem" />}
          >
            New Chat
          </Button>
        </div>
      )}
    </aside>
  );
}
|
|
||||||
@@ -1,91 +0,0 @@
|
|||||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
|
||||||
import { Button } from "@/components/atoms/Button/Button";
|
|
||||||
import { scrollbarStyles } from "@/components/styles/scrollbars";
|
|
||||||
import { cn } from "@/lib/utils";
|
|
||||||
import { PlusIcon, X } from "@phosphor-icons/react";
|
|
||||||
import { Drawer } from "vaul";
|
|
||||||
import { SessionsList } from "../SessionsList/SessionsList";
|
|
||||||
|
|
||||||
interface Props {
  isOpen: boolean;
  sessions: SessionSummaryResponse[];
  currentSessionId: string | null;
  isLoading: boolean;
  hasNextPage: boolean;
  isFetchingNextPage: boolean;
  onSelectSession: (sessionId: string) => void;
  onFetchNextPage: () => void;
  onNewChat: () => void;
  onClose: () => void;
  onOpenChange: (open: boolean) => void;
  hasActiveSession: boolean;
}

/**
 * Left slide-in drawer (vaul) that holds the sessions list on mobile.
 * Content mirrors DesktopSidebar; the open state is fully controlled by the
 * parent through isOpen / onOpenChange / onClose.
 */
export function MobileDrawer({
  isOpen,
  sessions,
  currentSessionId,
  isLoading,
  hasNextPage,
  isFetchingNextPage,
  onSelectSession,
  onFetchNextPage,
  onNewChat,
  onClose,
  onOpenChange,
  hasActiveSession,
}: Props) {
  return (
    <Drawer.Root open={isOpen} onOpenChange={onOpenChange} direction="left">
      <Drawer.Portal>
        <Drawer.Overlay className="fixed inset-0 z-[60] bg-black/10 backdrop-blur-sm" />
        <Drawer.Content className="fixed left-0 top-0 z-[70] flex h-full w-80 flex-col border-r border-zinc-200 bg-zinc-50">
          {/* Header: title plus an explicit close button. */}
          <div className="shrink-0 border-b border-zinc-200 p-4">
            <div className="flex items-center justify-between">
              <Drawer.Title className="text-lg font-semibold text-zinc-800">
                Your chats
              </Drawer.Title>
              <Button
                variant="icon"
                size="icon"
                aria-label="Close sessions"
                onClick={onClose}
              >
                <X width="1.25rem" height="1.25rem" />
              </Button>
            </div>
          </div>
          {/* Scrollable list area; min-h-0 lets the flex child shrink. */}
          <div
            className={cn(
              "flex min-h-0 flex-1 flex-col overflow-y-auto px-3 py-3",
              scrollbarStyles,
            )}
          >
            <SessionsList
              sessions={sessions}
              currentSessionId={currentSessionId}
              isLoading={isLoading}
              hasNextPage={hasNextPage}
              isFetchingNextPage={isFetchingNextPage}
              onSelectSession={onSelectSession}
              onFetchNextPage={onFetchNextPage}
            />
          </div>
          {hasActiveSession && (
            <div className="shrink-0 bg-white p-3 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
              <Button
                variant="primary"
                size="small"
                onClick={onNewChat}
                className="w-full"
                leftIcon={<PlusIcon width="1rem" height="1rem" />}
              >
                New Chat
              </Button>
            </div>
          )}
        </Drawer.Content>
      </Drawer.Portal>
    </Drawer.Root>
  );
}
|
|
||||||
@@ -1,24 +0,0 @@
|
|||||||
import { useState } from "react";
|
|
||||||
|
|
||||||
export function useMobileDrawer() {
|
|
||||||
const [isDrawerOpen, setIsDrawerOpen] = useState(false);
|
|
||||||
|
|
||||||
const handleOpenDrawer = () => {
|
|
||||||
setIsDrawerOpen(true);
|
|
||||||
};
|
|
||||||
|
|
||||||
const handleCloseDrawer = () => {
|
|
||||||
setIsDrawerOpen(false);
|
|
||||||
};
|
|
||||||
|
|
||||||
const handleDrawerOpenChange = (open: boolean) => {
|
|
||||||
setIsDrawerOpen(open);
|
|
||||||
};
|
|
||||||
|
|
||||||
return {
|
|
||||||
isDrawerOpen,
|
|
||||||
handleOpenDrawer,
|
|
||||||
handleCloseDrawer,
|
|
||||||
handleDrawerOpenChange,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
@@ -1,80 +0,0 @@
|
|||||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
|
||||||
import { Skeleton } from "@/components/__legacy__/ui/skeleton";
|
|
||||||
import { Text } from "@/components/atoms/Text/Text";
|
|
||||||
import { InfiniteList } from "@/components/molecules/InfiniteList/InfiniteList";
|
|
||||||
import { cn } from "@/lib/utils";
|
|
||||||
import { getSessionTitle } from "../../helpers";
|
|
||||||
|
|
||||||
interface Props {
  sessions: SessionSummaryResponse[];
  currentSessionId: string | null;
  isLoading: boolean;
  hasNextPage: boolean;
  isFetchingNextPage: boolean;
  onSelectSession: (sessionId: string) => void;
  onFetchNextPage: () => void;
}

/**
 * Scrollable list of chat sessions with infinite loading.
 * Renders skeleton rows while loading, an empty state when there are no
 * sessions, otherwise an InfiniteList of selectable session buttons.
 */
export function SessionsList({
  sessions,
  currentSessionId,
  isLoading,
  hasNextPage,
  isFetchingNextPage,
  onSelectSession,
  onFetchNextPage,
}: Props) {
  // Initial load: placeholder skeleton rows.
  if (isLoading) {
    return (
      <div className="space-y-1">
        {Array.from({ length: 5 }).map((_, i) => (
          <div key={i} className="rounded-lg px-3 py-2.5">
            <Skeleton className="h-5 w-full" />
          </div>
        ))}
      </div>
    );
  }

  if (sessions.length === 0) {
    return (
      <div className="flex h-full items-center justify-center">
        <Text variant="body" className="text-zinc-500">
          You don't have previous chats
        </Text>
      </div>
    );
  }

  return (
    <InfiniteList
      items={sessions}
      hasMore={hasNextPage}
      isFetchingMore={isFetchingNextPage}
      onEndReached={onFetchNextPage}
      className="space-y-1"
      renderItem={(session) => {
        const isActive = session.id === currentSessionId;
        return (
          <button
            onClick={() => onSelectSession(session.id)}
            className={cn(
              "w-full rounded-lg px-3 py-2.5 text-left transition-colors",
              isActive ? "bg-zinc-100" : "hover:bg-zinc-50",
            )}
          >
            <Text
              variant="body"
              className={cn(
                "font-normal",
                isActive ? "text-zinc-600" : "text-zinc-800",
              )}
            >
              {getSessionTitle(session)}
            </Text>
          </button>
        );
      }}
    />
  );
}
|
|
||||||
@@ -1,91 +0,0 @@
|
|||||||
import { useGetV2ListSessions } from "@/app/api/__generated__/endpoints/chat/chat";
|
|
||||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
|
||||||
import { okData } from "@/app/api/helpers";
|
|
||||||
import { useEffect, useState } from "react";
|
|
||||||
|
|
||||||
// Sessions fetched per request.
const PAGE_SIZE = 50;

export interface UseSessionsPaginationArgs {
  // When false, the list query is disabled and no sessions are fetched.
  enabled: boolean;
}

/**
 * Offset-based pagination over the chat sessions list.
 *
 * Accumulates pages into a single array. An effect auto-advances the offset
 * until the server-reported `total` is reached, so in practice ALL sessions
 * are loaded eagerly; `fetchNextPage` is an additional manual trigger.
 */
export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
  const [offset, setOffset] = useState(0);

  // All pages fetched so far, concatenated in fetch order.
  const [accumulatedSessions, setAccumulatedSessions] = useState<
    SessionSummaryResponse[]
  >([]);

  // Server-reported total; null until the first page arrives.
  const [totalCount, setTotalCount] = useState<number | null>(null);

  const { data, isLoading, isFetching, isError } = useGetV2ListSessions(
    { limit: PAGE_SIZE, offset },
    {
      query: {
        // NOTE(review): offset is never negative, so `offset >= 0` is always
        // true — this is effectively just `enabled`.
        enabled: enabled && offset >= 0,
      },
    },
  );

  // Fold each page into the accumulator: offset 0 replaces (fresh load /
  // reset), later offsets append.
  // NOTE(review): appending keys off `data` identity — a refetch at the same
  // non-zero offset would append duplicates; confirm refetches can't happen
  // at a non-zero offset.
  useEffect(() => {
    const responseData = okData(data);
    if (responseData) {
      const newSessions = responseData.sessions;
      const total = responseData.total;
      setTotalCount(total);

      if (offset === 0) {
        setAccumulatedSessions(newSessions);
      } else {
        setAccumulatedSessions((prev) => [...prev, ...newSessions]);
      }
    } else if (!enabled) {
      // Disabled with no data: clear accumulated state.
      setAccumulatedSessions([]);
      setTotalCount(null);
    }
  }, [data, offset, enabled]);

  const hasNextPage =
    totalCount !== null && accumulatedSessions.length < totalCount;

  // True once every session reported by the server has been fetched.
  const areAllSessionsLoaded =
    totalCount !== null &&
    accumulatedSessions.length >= totalCount &&
    !isFetching &&
    !isLoading;

  // Eagerly walk through the remaining pages until everything is loaded.
  useEffect(() => {
    if (
      hasNextPage &&
      !isFetching &&
      !isLoading &&
      !isError &&
      totalCount !== null
    ) {
      setOffset((prev) => prev + PAGE_SIZE);
    }
  }, [hasNextPage, isFetching, isLoading, isError, totalCount]);

  // Manual "load more" trigger (e.g. from an infinite list hitting its end).
  const fetchNextPage = () => {
    if (hasNextPage && !isFetching) {
      setOffset((prev) => prev + PAGE_SIZE);
    }
  };

  const reset = () => {
    // Only reset the offset - keep existing sessions visible during refetch
    // The effect will replace sessions when new data arrives at offset 0
    setOffset(0);
  };

  return {
    sessions: accumulatedSessions,
    isLoading,
    isFetching,
    hasNextPage,
    areAllSessionsLoaded,
    totalCount,
    fetchNextPage,
    reset,
  };
}
|
|
||||||
@@ -1,106 +0,0 @@
|
|||||||
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
|
|
||||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
|
||||||
import { format, formatDistanceToNow, isToday } from "date-fns";
|
|
||||||
|
|
||||||
export function convertSessionDetailToSummary(session: SessionDetailResponse) {
|
|
||||||
return {
|
|
||||||
id: session.id,
|
|
||||||
created_at: session.created_at,
|
|
||||||
updated_at: session.updated_at,
|
|
||||||
title: undefined,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
export function filterVisibleSessions(sessions: SessionSummaryResponse[]) {
|
|
||||||
const fiveMinutesAgo = Date.now() - 5 * 60 * 1000;
|
|
||||||
return sessions.filter((session) => {
|
|
||||||
const hasBeenUpdated = session.updated_at !== session.created_at;
|
|
||||||
|
|
||||||
if (hasBeenUpdated) return true;
|
|
||||||
|
|
||||||
const isRecentlyCreated =
|
|
||||||
new Date(session.created_at).getTime() > fiveMinutesAgo;
|
|
||||||
|
|
||||||
return isRecentlyCreated;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
export function getSessionTitle(session: SessionSummaryResponse) {
|
|
||||||
if (session.title) return session.title;
|
|
||||||
|
|
||||||
const isNewSession = session.updated_at === session.created_at;
|
|
||||||
|
|
||||||
if (isNewSession) {
|
|
||||||
const createdDate = new Date(session.created_at);
|
|
||||||
if (isToday(createdDate)) {
|
|
||||||
return "Today";
|
|
||||||
}
|
|
||||||
return format(createdDate, "MMM d, yyyy");
|
|
||||||
}
|
|
||||||
|
|
||||||
return "Untitled Chat";
|
|
||||||
}
|
|
||||||
|
|
||||||
export function getSessionUpdatedLabel(session: SessionSummaryResponse) {
|
|
||||||
if (!session.updated_at) return "";
|
|
||||||
return formatDistanceToNow(new Date(session.updated_at), { addSuffix: true });
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
 * Builds the final session list shown in the sidebar.
 *
 * Composition order:
 *   1. the visible subset of `accumulatedSessions` (per
 *      filterVisibleSessions), with the current session force-included at
 *      the front of this group even when the filter would drop it;
 *   2. the current session synthesized from `currentSessionData` when it is
 *      not in the accumulated list — prepended;
 *   3. locally tracked recently-created sessions the server has not returned
 *      yet — also prepended.
 * `addedIds` guarantees each session appears at most once.
 */
export function mergeCurrentSessionIntoList(
  accumulatedSessions: SessionSummaryResponse[],
  currentSessionId: string | null,
  currentSessionData: SessionDetailResponse | null | undefined,
  recentlyCreatedSessions?: Map<string, SessionSummaryResponse>,
) {
  const filteredSessions: SessionSummaryResponse[] = [];
  const addedIds = new Set<string>(); // de-duplication guard

  if (accumulatedSessions.length > 0) {
    const visibleSessions = filterVisibleSessions(accumulatedSessions);

    // Keep the active session listed even when the visibility filter would
    // have dropped it.
    if (currentSessionId) {
      const currentInAll = accumulatedSessions.find(
        (s) => s.id === currentSessionId,
      );
      if (currentInAll) {
        const isInVisible = visibleSessions.some(
          (s) => s.id === currentSessionId,
        );
        if (!isInVisible) {
          filteredSessions.push(currentInAll);
          addedIds.add(currentInAll.id);
        }
      }
    }

    for (const session of visibleSessions) {
      if (!addedIds.has(session.id)) {
        filteredSessions.push(session);
        addedIds.add(session.id);
      }
    }
  }

  // Current session known only from the detail endpoint (e.g. brand new):
  // surface it at the top of the list.
  if (currentSessionId && currentSessionData) {
    if (!addedIds.has(currentSessionId)) {
      const summarySession = convertSessionDetailToSummary(currentSessionData);
      filteredSessions.unshift(summarySession);
      addedIds.add(currentSessionId);
    }
  }

  // Locally created sessions the server list has not caught up with yet.
  if (recentlyCreatedSessions) {
    for (const [sessionId, sessionData] of recentlyCreatedSessions) {
      if (!addedIds.has(sessionId)) {
        filteredSessions.unshift(sessionData);
        addedIds.add(sessionId);
      }
    }
  }

  return filteredSessions;
}
|
|
||||||
|
|
||||||
export function getCurrentSessionId(searchParams: URLSearchParams) {
|
|
||||||
return searchParams.get("sessionId");
|
|
||||||
}
|
|
||||||
@@ -1,124 +0,0 @@
|
|||||||
"use client";
|
|
||||||
|
|
||||||
import {
|
|
||||||
getGetV2GetSessionQueryKey,
|
|
||||||
getGetV2ListSessionsQueryKey,
|
|
||||||
useGetV2GetSession,
|
|
||||||
} from "@/app/api/__generated__/endpoints/chat/chat";
|
|
||||||
import { okData } from "@/app/api/helpers";
|
|
||||||
import { useChatStore } from "@/components/contextual/Chat/chat-store";
|
|
||||||
import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
|
|
||||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
|
||||||
import { useQueryClient } from "@tanstack/react-query";
|
|
||||||
import { usePathname, useSearchParams } from "next/navigation";
|
|
||||||
import { useCopilotStore } from "../../copilot-page-store";
|
|
||||||
import { useCopilotSessionId } from "../../useCopilotSessionId";
|
|
||||||
import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer";
|
|
||||||
import { getCurrentSessionId } from "./helpers";
|
|
||||||
import { useShellSessionList } from "./useShellSessionList";
|
|
||||||
|
|
||||||
/**
 * State and handlers for the Copilot shell: responsive layout detection,
 * the mobile drawer, session-list pagination, and navigation between
 * sessions (including stopping the active SSE stream on switch).
 */
export function useCopilotShell() {
  const pathname = usePathname();
  const searchParams = useSearchParams();
  const queryClient = useQueryClient();
  const breakpoint = useBreakpoint();
  const { isLoggedIn } = useSupabase();
  // Drawer layout for the smaller breakpoints; sidebar otherwise.
  const isMobile =
    breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";

  const { urlSessionId, setUrlSessionId } = useCopilotSessionId();

  const isOnHomepage = pathname === "/copilot";
  const paramSessionId = searchParams.get("sessionId");

  const {
    isDrawerOpen,
    handleOpenDrawer,
    handleCloseDrawer,
    handleDrawerOpenChange,
  } = useMobileDrawer();

  // On mobile, defer fetching the session list until it is actually visible
  // (drawer open) or a session is already selected via the URL.
  const paginationEnabled = !isMobile || isDrawerOpen || !!paramSessionId;

  const currentSessionId = getCurrentSessionId(searchParams);

  const { data: currentSessionData } = useGetV2GetSession(
    currentSessionId || "",
    {
      query: {
        enabled: !!currentSessionId,
        select: okData,
      },
    },
  );

  const {
    sessions,
    isLoading,
    isSessionsFetching,
    hasNextPage,
    fetchNextPage,
    resetPagination,
    recentlyCreatedSessionsRef,
  } = useShellSessionList({
    paginationEnabled,
    currentSessionId,
    currentSessionData,
    isOnHomepage,
    paramSessionId,
  });

  const stopStream = useChatStore((s) => s.stopStream);
  const isCreatingSession = useCopilotStore((s) => s.isCreatingSession);

  // Switch to another session; no-op when it is already active.
  function handleSessionClick(sessionId: string) {
    if (sessionId === currentSessionId) return;

    // Stop current stream - SSE reconnection allows resuming later
    if (currentSessionId) {
      stopStream(currentSessionId);
    }

    // Locally created sessions may have stale cached detail data — refetch.
    if (recentlyCreatedSessionsRef.current.has(sessionId)) {
      queryClient.invalidateQueries({
        queryKey: getGetV2GetSessionQueryKey(sessionId),
      });
    }
    setUrlSessionId(sessionId, { shallow: false });
    if (isMobile) handleCloseDrawer();
  }

  // Leave the current session and return to a fresh, unsaved chat.
  function handleNewChatClick() {
    // Stop current stream - SSE reconnection allows resuming later
    if (currentSessionId) {
      stopStream(currentSessionId);
    }

    // Restart the list at page one and refetch it for the fresh view.
    resetPagination();
    queryClient.invalidateQueries({
      queryKey: getGetV2ListSessionsQueryKey(),
    });
    setUrlSessionId(null, { shallow: false });
    if (isMobile) handleCloseDrawer();
  }

  return {
    isMobile,
    isDrawerOpen,
    isLoggedIn,
    hasActiveSession:
      Boolean(currentSessionId) && (!isOnHomepage || Boolean(paramSessionId)),
    isLoading: isLoading || isCreatingSession,
    isCreatingSession,
    sessions,
    // NOTE(review): the exposed id comes from useCopilotSessionId while the
    // internal logic uses getCurrentSessionId(searchParams) — confirm the
    // two always agree.
    currentSessionId: urlSessionId,
    handleOpenDrawer,
    handleCloseDrawer,
    handleDrawerOpenChange,
    handleNewChatClick,
    handleSessionClick,
    hasNextPage,
    isFetchingNextPage: isSessionsFetching,
    fetchNextPage,
  };
}
|
|
||||||
@@ -1,113 +0,0 @@
|
|||||||
import { getGetV2ListSessionsQueryKey } from "@/app/api/__generated__/endpoints/chat/chat";
|
|
||||||
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
|
|
||||||
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
|
||||||
import { useChatStore } from "@/components/contextual/Chat/chat-store";
|
|
||||||
import { useQueryClient } from "@tanstack/react-query";
|
|
||||||
import { useEffect, useMemo, useRef } from "react";
|
|
||||||
import { useSessionsPagination } from "./components/SessionsList/useSessionsPagination";
|
|
||||||
import {
|
|
||||||
convertSessionDetailToSummary,
|
|
||||||
filterVisibleSessions,
|
|
||||||
mergeCurrentSessionIntoList,
|
|
||||||
} from "./helpers";
|
|
||||||
|
|
||||||
interface UseShellSessionListArgs {
|
|
||||||
paginationEnabled: boolean;
|
|
||||||
currentSessionId: string | null;
|
|
||||||
currentSessionData: SessionDetailResponse | null | undefined;
|
|
||||||
isOnHomepage: boolean;
|
|
||||||
paramSessionId: string | null;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function useShellSessionList({
|
|
||||||
paginationEnabled,
|
|
||||||
currentSessionId,
|
|
||||||
currentSessionData,
|
|
||||||
isOnHomepage,
|
|
||||||
paramSessionId,
|
|
||||||
}: UseShellSessionListArgs) {
|
|
||||||
const queryClient = useQueryClient();
|
|
||||||
const onStreamComplete = useChatStore((s) => s.onStreamComplete);
|
|
||||||
|
|
||||||
const {
|
|
||||||
sessions: accumulatedSessions,
|
|
||||||
isLoading: isSessionsLoading,
|
|
||||||
isFetching: isSessionsFetching,
|
|
||||||
hasNextPage,
|
|
||||||
fetchNextPage,
|
|
||||||
reset: resetPagination,
|
|
||||||
} = useSessionsPagination({
|
|
||||||
enabled: paginationEnabled,
|
|
||||||
});
|
|
||||||
|
|
||||||
const recentlyCreatedSessionsRef = useRef<
|
|
||||||
Map<string, SessionSummaryResponse>
|
|
||||||
>(new Map());
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
if (isOnHomepage && !paramSessionId) {
|
|
||||||
queryClient.invalidateQueries({
|
|
||||||
queryKey: getGetV2ListSessionsQueryKey(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}, [isOnHomepage, paramSessionId, queryClient]);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
if (currentSessionId && currentSessionData) {
|
|
||||||
const isNewSession =
|
|
||||||
currentSessionData.updated_at === currentSessionData.created_at;
|
|
||||||
const isNotInAccumulated = !accumulatedSessions.some(
|
|
||||||
(s) => s.id === currentSessionId,
|
|
||||||
);
|
|
||||||
if (isNewSession || isNotInAccumulated) {
|
|
||||||
const summary = convertSessionDetailToSummary(currentSessionData);
|
|
||||||
recentlyCreatedSessionsRef.current.set(currentSessionId, summary);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}, [currentSessionId, currentSessionData, accumulatedSessions]);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
for (const sessionId of recentlyCreatedSessionsRef.current.keys()) {
|
|
||||||
if (accumulatedSessions.some((s) => s.id === sessionId)) {
|
|
||||||
recentlyCreatedSessionsRef.current.delete(sessionId);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}, [accumulatedSessions]);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
const unsubscribe = onStreamComplete(() => {
|
|
||||||
queryClient.invalidateQueries({
|
|
||||||
queryKey: getGetV2ListSessionsQueryKey(),
|
|
||||||
});
|
|
||||||
});
|
|
||||||
return unsubscribe;
|
|
||||||
}, [onStreamComplete, queryClient]);
|
|
||||||
|
|
||||||
const sessions = useMemo(
|
|
||||||
() =>
|
|
||||||
mergeCurrentSessionIntoList(
|
|
||||||
accumulatedSessions,
|
|
||||||
currentSessionId,
|
|
||||||
currentSessionData,
|
|
||||||
recentlyCreatedSessionsRef.current,
|
|
||||||
),
|
|
||||||
[accumulatedSessions, currentSessionId, currentSessionData],
|
|
||||||
);
|
|
||||||
|
|
||||||
const visibleSessions = useMemo(
|
|
||||||
() => filterVisibleSessions(sessions),
|
|
||||||
[sessions],
|
|
||||||
);
|
|
||||||
|
|
||||||
const isLoading = isSessionsLoading && accumulatedSessions.length === 0;
|
|
||||||
|
|
||||||
return {
|
|
||||||
sessions: visibleSessions,
|
|
||||||
isLoading,
|
|
||||||
isSessionsFetching,
|
|
||||||
hasNextPage,
|
|
||||||
fetchNextPage,
|
|
||||||
resetPagination,
|
|
||||||
recentlyCreatedSessionsRef,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,111 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { ChatInput } from "@/app/(platform)/copilot/components/ChatInput/ChatInput";
|
||||||
|
import { Button } from "@/components/atoms/Button/Button";
|
||||||
|
import { Text } from "@/components/atoms/Text/Text";
|
||||||
|
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||||
|
import { SpinnerGapIcon } from "@phosphor-icons/react";
|
||||||
|
import { motion } from "framer-motion";
|
||||||
|
import { useEffect, useState } from "react";
|
||||||
|
import {
|
||||||
|
getGreetingName,
|
||||||
|
getInputPlaceholder,
|
||||||
|
getQuickActions,
|
||||||
|
} from "./helpers";
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
inputLayoutId: string;
|
||||||
|
isCreatingSession: boolean;
|
||||||
|
onCreateSession: () => void | Promise<string>;
|
||||||
|
onSend: (message: string) => void | Promise<void>;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function EmptySession({
|
||||||
|
inputLayoutId,
|
||||||
|
isCreatingSession,
|
||||||
|
onSend,
|
||||||
|
}: Props) {
|
||||||
|
const { user } = useSupabase();
|
||||||
|
const greetingName = getGreetingName(user);
|
||||||
|
const quickActions = getQuickActions();
|
||||||
|
const [loadingAction, setLoadingAction] = useState<string | null>(null);
|
||||||
|
const [inputPlaceholder, setInputPlaceholder] = useState(
|
||||||
|
getInputPlaceholder(),
|
||||||
|
);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
setInputPlaceholder(getInputPlaceholder(window.innerWidth));
|
||||||
|
}, [window.innerWidth]);
|
||||||
|
|
||||||
|
async function handleQuickActionClick(action: string) {
|
||||||
|
if (isCreatingSession || loadingAction) return;
|
||||||
|
|
||||||
|
setLoadingAction(action);
|
||||||
|
try {
|
||||||
|
await onSend(action);
|
||||||
|
} finally {
|
||||||
|
setLoadingAction(null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-0 py-5 md:px-6 md:py-10">
|
||||||
|
<motion.div
|
||||||
|
className="w-full max-w-3xl text-center"
|
||||||
|
initial={{ opacity: 0 }}
|
||||||
|
animate={{ opacity: 1 }}
|
||||||
|
transition={{ duration: 0.3 }}
|
||||||
|
>
|
||||||
|
<div className="mx-auto max-w-3xl">
|
||||||
|
<Text variant="h3" className="mb-1 !text-[1.375rem] text-zinc-700">
|
||||||
|
Hey, <span className="text-violet-600">{greetingName}</span>
|
||||||
|
</Text>
|
||||||
|
<Text variant="h3" className="mb-8 !font-normal">
|
||||||
|
Tell me about your work — I'll find what to automate.
|
||||||
|
</Text>
|
||||||
|
|
||||||
|
<div className="mb-6">
|
||||||
|
<motion.div
|
||||||
|
layoutId={inputLayoutId}
|
||||||
|
transition={{ type: "spring", bounce: 0.2, duration: 0.65 }}
|
||||||
|
className="w-full px-2"
|
||||||
|
>
|
||||||
|
<ChatInput
|
||||||
|
inputId="chat-input-empty"
|
||||||
|
onSend={onSend}
|
||||||
|
disabled={isCreatingSession}
|
||||||
|
placeholder={inputPlaceholder}
|
||||||
|
className="w-full"
|
||||||
|
/>
|
||||||
|
</motion.div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="flex flex-wrap items-center justify-center gap-3 overflow-x-auto [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
|
||||||
|
{quickActions.map((action) => (
|
||||||
|
<Button
|
||||||
|
key={action}
|
||||||
|
type="button"
|
||||||
|
variant="outline"
|
||||||
|
size="small"
|
||||||
|
onClick={() => void handleQuickActionClick(action)}
|
||||||
|
disabled={isCreatingSession || loadingAction !== null}
|
||||||
|
aria-busy={loadingAction === action}
|
||||||
|
leftIcon={
|
||||||
|
loadingAction === action ? (
|
||||||
|
<SpinnerGapIcon
|
||||||
|
className="h-4 w-4 animate-spin"
|
||||||
|
weight="bold"
|
||||||
|
/>
|
||||||
|
) : null
|
||||||
|
}
|
||||||
|
className="h-auto shrink-0 border-zinc-300 px-3 py-2 text-[.9rem] text-zinc-600"
|
||||||
|
>
|
||||||
|
{action}
|
||||||
|
</Button>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</motion.div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -1,6 +1,26 @@
|
|||||||
import type { User } from "@supabase/supabase-js";
|
import { User } from "@supabase/supabase-js";
|
||||||
|
|
||||||
export function getGreetingName(user?: User | null): string {
|
export function getInputPlaceholder(width?: number) {
|
||||||
|
if (!width) return "What's your role and what eats up most of your day?";
|
||||||
|
|
||||||
|
if (width < 500) {
|
||||||
|
return "I'm a chef and I hate...";
|
||||||
|
}
|
||||||
|
if (width <= 1080) {
|
||||||
|
return "What's your role and what eats up most of your day?";
|
||||||
|
}
|
||||||
|
return "What's your role and what eats up most of your day? e.g. 'I'm a recruiter and I hate...'";
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getQuickActions() {
|
||||||
|
return [
|
||||||
|
"I don't know where to start, just ask me stuff",
|
||||||
|
"I do the same thing every week and it's killing me",
|
||||||
|
"Help me find where I'm wasting my time",
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getGreetingName(user?: User | null) {
|
||||||
if (!user) return "there";
|
if (!user) return "there";
|
||||||
const metadata = user.user_metadata as Record<string, unknown> | undefined;
|
const metadata = user.user_metadata as Record<string, unknown> | undefined;
|
||||||
const fullName = metadata?.full_name;
|
const fullName = metadata?.full_name;
|
||||||
@@ -16,30 +36,3 @@ export function getGreetingName(user?: User | null): string {
|
|||||||
}
|
}
|
||||||
return "there";
|
return "there";
|
||||||
}
|
}
|
||||||
|
|
||||||
export function buildCopilotChatUrl(prompt: string): string {
|
|
||||||
const trimmed = prompt.trim();
|
|
||||||
if (!trimmed) return "/copilot/chat";
|
|
||||||
const encoded = encodeURIComponent(trimmed);
|
|
||||||
return `/copilot/chat?prompt=${encoded}`;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function getQuickActions(): string[] {
|
|
||||||
return [
|
|
||||||
"I don't know where to start, just ask me stuff",
|
|
||||||
"I do the same thing every week and it's killing me",
|
|
||||||
"Help me find where I'm wasting my time",
|
|
||||||
];
|
|
||||||
}
|
|
||||||
|
|
||||||
export function getInputPlaceholder(width?: number) {
|
|
||||||
if (!width) return "What's your role and what eats up most of your day?";
|
|
||||||
|
|
||||||
if (width < 500) {
|
|
||||||
return "I'm a chef and I hate...";
|
|
||||||
}
|
|
||||||
if (width <= 1080) {
|
|
||||||
return "What's your role and what eats up most of your day?";
|
|
||||||
}
|
|
||||||
return "What's your role and what eats up most of your day? e.g. 'I'm a recruiter and I hate...'";
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,140 @@
|
|||||||
|
import type { SessionSummaryResponse } from "@/app/api/__generated__/models/sessionSummaryResponse";
|
||||||
|
import { Button } from "@/components/atoms/Button/Button";
|
||||||
|
import { Text } from "@/components/atoms/Text/Text";
|
||||||
|
import { scrollbarStyles } from "@/components/styles/scrollbars";
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
import { PlusIcon, SpinnerGapIcon, X } from "@phosphor-icons/react";
|
||||||
|
import { Drawer } from "vaul";
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
isOpen: boolean;
|
||||||
|
sessions: SessionSummaryResponse[];
|
||||||
|
currentSessionId: string | null;
|
||||||
|
isLoading: boolean;
|
||||||
|
onSelectSession: (sessionId: string) => void;
|
||||||
|
onNewChat: () => void;
|
||||||
|
onClose: () => void;
|
||||||
|
onOpenChange: (open: boolean) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
function formatDate(dateString: string) {
|
||||||
|
const date = new Date(dateString);
|
||||||
|
const now = new Date();
|
||||||
|
const diffMs = now.getTime() - date.getTime();
|
||||||
|
const diffDays = Math.floor(diffMs / (1000 * 60 * 60 * 24));
|
||||||
|
|
||||||
|
if (diffDays === 0) return "Today";
|
||||||
|
if (diffDays === 1) return "Yesterday";
|
||||||
|
if (diffDays < 7) return `${diffDays} days ago`;
|
||||||
|
|
||||||
|
const day = date.getDate();
|
||||||
|
const ordinal =
|
||||||
|
day % 10 === 1 && day !== 11
|
||||||
|
? "st"
|
||||||
|
: day % 10 === 2 && day !== 12
|
||||||
|
? "nd"
|
||||||
|
: day % 10 === 3 && day !== 13
|
||||||
|
? "rd"
|
||||||
|
: "th";
|
||||||
|
const month = date.toLocaleDateString("en-US", { month: "short" });
|
||||||
|
const year = date.getFullYear();
|
||||||
|
|
||||||
|
return `${day}${ordinal} ${month} ${year}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function MobileDrawer({
|
||||||
|
isOpen,
|
||||||
|
sessions,
|
||||||
|
currentSessionId,
|
||||||
|
isLoading,
|
||||||
|
onSelectSession,
|
||||||
|
onNewChat,
|
||||||
|
onClose,
|
||||||
|
onOpenChange,
|
||||||
|
}: Props) {
|
||||||
|
return (
|
||||||
|
<Drawer.Root open={isOpen} onOpenChange={onOpenChange} direction="left">
|
||||||
|
<Drawer.Portal>
|
||||||
|
<Drawer.Overlay className="fixed inset-0 z-[60] bg-black/10 backdrop-blur-sm" />
|
||||||
|
<Drawer.Content className="fixed left-0 top-0 z-[70] flex h-full w-80 flex-col border-r border-zinc-200 bg-zinc-50">
|
||||||
|
<div className="shrink-0 border-b border-zinc-200 px-4 py-2">
|
||||||
|
<div className="flex items-center justify-between">
|
||||||
|
<Drawer.Title className="text-lg font-semibold text-zinc-800">
|
||||||
|
Your chats
|
||||||
|
</Drawer.Title>
|
||||||
|
<Button
|
||||||
|
variant="icon"
|
||||||
|
size="icon"
|
||||||
|
aria-label="Close sessions"
|
||||||
|
onClick={onClose}
|
||||||
|
>
|
||||||
|
<X width="1rem" height="1rem" />
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div
|
||||||
|
className={cn(
|
||||||
|
"flex min-h-0 flex-1 flex-col gap-1 overflow-y-auto px-3 py-3",
|
||||||
|
scrollbarStyles,
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
{isLoading ? (
|
||||||
|
<div className="flex items-center justify-center py-4">
|
||||||
|
<SpinnerGapIcon className="h-5 w-5 animate-spin text-neutral-400" />
|
||||||
|
</div>
|
||||||
|
) : sessions.length === 0 ? (
|
||||||
|
<p className="py-4 text-center text-sm text-neutral-500">
|
||||||
|
No conversations yet
|
||||||
|
</p>
|
||||||
|
) : (
|
||||||
|
sessions.map((session) => (
|
||||||
|
<button
|
||||||
|
key={session.id}
|
||||||
|
onClick={() => onSelectSession(session.id)}
|
||||||
|
className={cn(
|
||||||
|
"w-full rounded-lg px-3 py-2.5 text-left transition-colors",
|
||||||
|
session.id === currentSessionId
|
||||||
|
? "bg-zinc-100"
|
||||||
|
: "hover:bg-zinc-50",
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<div className="flex min-w-0 max-w-full flex-col overflow-hidden">
|
||||||
|
<div className="min-w-0 max-w-full">
|
||||||
|
<Text
|
||||||
|
variant="body"
|
||||||
|
className={cn(
|
||||||
|
"truncate font-normal",
|
||||||
|
session.id === currentSessionId
|
||||||
|
? "text-zinc-600"
|
||||||
|
: "text-zinc-800",
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
{session.title || "Untitled chat"}
|
||||||
|
</Text>
|
||||||
|
</div>
|
||||||
|
<Text variant="small" className="text-neutral-400">
|
||||||
|
{formatDate(session.updated_at)}
|
||||||
|
</Text>
|
||||||
|
</div>
|
||||||
|
</button>
|
||||||
|
))
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
{currentSessionId && (
|
||||||
|
<div className="shrink-0 bg-white p-3 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
|
||||||
|
<Button
|
||||||
|
variant="primary"
|
||||||
|
size="small"
|
||||||
|
onClick={onNewChat}
|
||||||
|
className="w-full"
|
||||||
|
leftIcon={<PlusIcon width="1rem" height="1rem" />}
|
||||||
|
>
|
||||||
|
New Chat
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</Drawer.Content>
|
||||||
|
</Drawer.Portal>
|
||||||
|
</Drawer.Root>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
import { AnimatePresence, motion } from "framer-motion";
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
text: string;
|
||||||
|
className?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function MorphingTextAnimation({ text, className }: Props) {
|
||||||
|
const letters = text.split("");
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className={cn(className)}>
|
||||||
|
<AnimatePresence mode="popLayout" initial={false}>
|
||||||
|
<motion.div key={text} className="whitespace-nowrap">
|
||||||
|
<motion.span className="inline-flex overflow-hidden">
|
||||||
|
{letters.map((char, index) => (
|
||||||
|
<motion.span
|
||||||
|
key={`${text}-${index}`}
|
||||||
|
initial={{
|
||||||
|
opacity: 0,
|
||||||
|
y: 8,
|
||||||
|
rotateX: "80deg",
|
||||||
|
filter: "blur(6px)",
|
||||||
|
}}
|
||||||
|
animate={{
|
||||||
|
opacity: 1,
|
||||||
|
y: 0,
|
||||||
|
rotateX: "0deg",
|
||||||
|
filter: "blur(0px)",
|
||||||
|
}}
|
||||||
|
exit={{
|
||||||
|
opacity: 0,
|
||||||
|
y: -8,
|
||||||
|
rotateX: "-80deg",
|
||||||
|
filter: "blur(6px)",
|
||||||
|
}}
|
||||||
|
style={{ willChange: "transform" }}
|
||||||
|
transition={{
|
||||||
|
delay: 0.015 * index,
|
||||||
|
type: "spring",
|
||||||
|
bounce: 0.5,
|
||||||
|
}}
|
||||||
|
className="inline-block"
|
||||||
|
>
|
||||||
|
{char === " " ? "\u00A0" : char}
|
||||||
|
</motion.span>
|
||||||
|
))}
|
||||||
|
</motion.span>
|
||||||
|
</motion.div>
|
||||||
|
</AnimatePresence>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -0,0 +1,69 @@
|
|||||||
|
.loader {
|
||||||
|
position: relative;
|
||||||
|
animation: rotate 1s infinite;
|
||||||
|
}
|
||||||
|
|
||||||
|
.loader::before,
|
||||||
|
.loader::after {
|
||||||
|
border-radius: 50%;
|
||||||
|
content: "";
|
||||||
|
display: block;
|
||||||
|
/* 40% of container size */
|
||||||
|
height: 40%;
|
||||||
|
width: 40%;
|
||||||
|
}
|
||||||
|
|
||||||
|
.loader::before {
|
||||||
|
animation: ball1 1s infinite;
|
||||||
|
background-color: #a1a1aa; /* zinc-400 */
|
||||||
|
box-shadow: calc(var(--spacing)) 0 0 #18181b; /* zinc-900 */
|
||||||
|
margin-bottom: calc(var(--gap));
|
||||||
|
}
|
||||||
|
|
||||||
|
.loader::after {
|
||||||
|
animation: ball2 1s infinite;
|
||||||
|
background-color: #18181b; /* zinc-900 */
|
||||||
|
box-shadow: calc(var(--spacing)) 0 0 #a1a1aa; /* zinc-400 */
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes rotate {
|
||||||
|
0% {
|
||||||
|
transform: rotate(0deg) scale(0.8);
|
||||||
|
}
|
||||||
|
50% {
|
||||||
|
transform: rotate(360deg) scale(1.2);
|
||||||
|
}
|
||||||
|
100% {
|
||||||
|
transform: rotate(720deg) scale(0.8);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes ball1 {
|
||||||
|
0% {
|
||||||
|
box-shadow: calc(var(--spacing)) 0 0 #18181b;
|
||||||
|
}
|
||||||
|
50% {
|
||||||
|
box-shadow: 0 0 0 #18181b;
|
||||||
|
margin-bottom: 0;
|
||||||
|
transform: translate(calc(var(--spacing) / 2), calc(var(--spacing) / 2));
|
||||||
|
}
|
||||||
|
100% {
|
||||||
|
box-shadow: calc(var(--spacing)) 0 0 #18181b;
|
||||||
|
margin-bottom: calc(var(--gap));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes ball2 {
|
||||||
|
0% {
|
||||||
|
box-shadow: calc(var(--spacing)) 0 0 #a1a1aa;
|
||||||
|
}
|
||||||
|
50% {
|
||||||
|
box-shadow: 0 0 0 #a1a1aa;
|
||||||
|
margin-top: calc(var(--ball-size) * -1);
|
||||||
|
transform: translate(calc(var(--spacing) / 2), calc(var(--spacing) / 2));
|
||||||
|
}
|
||||||
|
100% {
|
||||||
|
box-shadow: calc(var(--spacing)) 0 0 #a1a1aa;
|
||||||
|
margin-top: 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
import styles from "./OrbitLoader.module.css";
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
size?: number;
|
||||||
|
className?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function OrbitLoader({ size = 24, className }: Props) {
|
||||||
|
const ballSize = Math.round(size * 0.4);
|
||||||
|
const spacing = Math.round(size * 0.6);
|
||||||
|
const gap = Math.round(size * 0.2);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
className={cn(styles.loader, className)}
|
||||||
|
style={
|
||||||
|
{
|
||||||
|
width: size,
|
||||||
|
height: size,
|
||||||
|
"--ball-size": `${ballSize}px`,
|
||||||
|
"--spacing": `${spacing}px`,
|
||||||
|
"--gap": `${gap}px`,
|
||||||
|
} as React.CSSProperties
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -0,0 +1,26 @@
|
|||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
value: number;
|
||||||
|
label?: string;
|
||||||
|
className?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function ProgressBar({ value, label, className }: Props) {
|
||||||
|
const clamped = Math.min(100, Math.max(0, value));
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className={cn("flex flex-col gap-1.5", className)}>
|
||||||
|
<div className="flex items-center justify-between text-xs text-neutral-500">
|
||||||
|
<span>{label ?? "Working on it..."}</span>
|
||||||
|
<span>{Math.round(clamped)}%</span>
|
||||||
|
</div>
|
||||||
|
<div className="h-2 w-full overflow-hidden rounded-full bg-neutral-200">
|
||||||
|
<div
|
||||||
|
className="h-full rounded-full bg-neutral-900 transition-[width] duration-300 ease-out"
|
||||||
|
style={{ width: `${clamped}%` }}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
.loader {
|
||||||
|
position: relative;
|
||||||
|
display: inline-block;
|
||||||
|
flex-shrink: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.loader::before,
|
||||||
|
.loader::after {
|
||||||
|
content: "";
|
||||||
|
box-sizing: border-box;
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
border-radius: 50%;
|
||||||
|
background: currentColor;
|
||||||
|
position: absolute;
|
||||||
|
left: 0;
|
||||||
|
top: 0;
|
||||||
|
animation: ripple 2s linear infinite;
|
||||||
|
}
|
||||||
|
|
||||||
|
.loader::after {
|
||||||
|
animation-delay: 1s;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes ripple {
|
||||||
|
0% {
|
||||||
|
transform: scale(0);
|
||||||
|
opacity: 1;
|
||||||
|
}
|
||||||
|
100% {
|
||||||
|
transform: scale(1);
|
||||||
|
opacity: 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
import styles from "./PulseLoader.module.css";
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
size?: number;
|
||||||
|
className?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function PulseLoader({ size = 24, className }: Props) {
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
className={cn(styles.loader, className)}
|
||||||
|
style={{ width: size, height: size }}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -0,0 +1,57 @@
|
|||||||
|
.loader {
|
||||||
|
position: relative;
|
||||||
|
display: inline-block;
|
||||||
|
flex-shrink: 0;
|
||||||
|
transform: rotateZ(45deg);
|
||||||
|
perspective: 1000px;
|
||||||
|
border-radius: 50%;
|
||||||
|
color: currentColor;
|
||||||
|
}
|
||||||
|
|
||||||
|
.loader::before,
|
||||||
|
.loader::after {
|
||||||
|
content: "";
|
||||||
|
display: block;
|
||||||
|
position: absolute;
|
||||||
|
top: 0;
|
||||||
|
left: 0;
|
||||||
|
width: inherit;
|
||||||
|
height: inherit;
|
||||||
|
border-radius: 50%;
|
||||||
|
transform: rotateX(70deg);
|
||||||
|
animation: spin 1s linear infinite;
|
||||||
|
}
|
||||||
|
|
||||||
|
.loader::after {
|
||||||
|
color: var(--spinner-accent, #a855f7);
|
||||||
|
transform: rotateY(70deg);
|
||||||
|
animation-delay: 0.4s;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes spin {
|
||||||
|
0%,
|
||||||
|
100% {
|
||||||
|
box-shadow: 0.2em 0 0 0 currentColor;
|
||||||
|
}
|
||||||
|
12% {
|
||||||
|
box-shadow: 0.2em 0.2em 0 0 currentColor;
|
||||||
|
}
|
||||||
|
25% {
|
||||||
|
box-shadow: 0 0.2em 0 0 currentColor;
|
||||||
|
}
|
||||||
|
37% {
|
||||||
|
box-shadow: -0.2em 0.2em 0 0 currentColor;
|
||||||
|
}
|
||||||
|
50% {
|
||||||
|
box-shadow: -0.2em 0 0 0 currentColor;
|
||||||
|
}
|
||||||
|
62% {
|
||||||
|
box-shadow: -0.2em -0.2em 0 0 currentColor;
|
||||||
|
}
|
||||||
|
75% {
|
||||||
|
box-shadow: 0 -0.2em 0 0 currentColor;
|
||||||
|
}
|
||||||
|
87% {
|
||||||
|
box-shadow: 0.2em -0.2em 0 0 currentColor;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
import styles from "./SpinnerLoader.module.css";
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
size?: number;
|
||||||
|
className?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function SpinnerLoader({ size = 24, className }: Props) {
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
className={cn(styles.loader, className)}
|
||||||
|
style={{ width: size, height: size }}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -0,0 +1,235 @@
|
|||||||
|
import { Link } from "@/components/atoms/Link/Link";
|
||||||
|
import { Text } from "@/components/atoms/Text/Text";
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
|
||||||
|
/* ------------------------------------------------------------------ */
|
||||||
|
/* Layout */
|
||||||
|
/* ------------------------------------------------------------------ */
|
||||||
|
|
||||||
|
export function ContentGrid({
|
||||||
|
children,
|
||||||
|
className,
|
||||||
|
}: {
|
||||||
|
children: React.ReactNode;
|
||||||
|
className?: string;
|
||||||
|
}) {
|
||||||
|
return <div className={cn("grid gap-2", className)}>{children}</div>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ------------------------------------------------------------------ */
|
||||||
|
/* Card */
|
||||||
|
/* ------------------------------------------------------------------ */
|
||||||
|
|
||||||
|
export function ContentCard({
|
||||||
|
children,
|
||||||
|
className,
|
||||||
|
}: {
|
||||||
|
children: React.ReactNode;
|
||||||
|
className?: string;
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
className={cn(
|
||||||
|
"rounded-lg bg-gradient-to-r from-purple-500/30 to-blue-500/30 p-[1px]",
|
||||||
|
className,
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<div className="rounded-lg bg-neutral-100 p-3">{children}</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Flex row with a left content area (`children`) and an optional right‑side `action`. */
|
||||||
|
export function ContentCardHeader({
|
||||||
|
children,
|
||||||
|
action,
|
||||||
|
className,
|
||||||
|
}: {
|
||||||
|
children: React.ReactNode;
|
||||||
|
action?: React.ReactNode;
|
||||||
|
className?: string;
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<div className={cn("flex items-start justify-between gap-2", className)}>
|
||||||
|
<div className="min-w-0">{children}</div>
|
||||||
|
{action}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function ContentCardTitle({
|
||||||
|
children,
|
||||||
|
className,
|
||||||
|
}: {
|
||||||
|
children: React.ReactNode;
|
||||||
|
className?: string;
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<Text
|
||||||
|
variant="body-medium"
|
||||||
|
className={cn("truncate text-zinc-800", className)}
|
||||||
|
>
|
||||||
|
{children}
|
||||||
|
</Text>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function ContentCardSubtitle({
|
||||||
|
children,
|
||||||
|
className,
|
||||||
|
}: {
|
||||||
|
children: React.ReactNode;
|
||||||
|
className?: string;
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<Text
|
||||||
|
variant="small"
|
||||||
|
className={cn("mt-0.5 truncate font-mono text-zinc-800", className)}
|
||||||
|
>
|
||||||
|
{children}
|
||||||
|
</Text>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function ContentCardDescription({
|
||||||
|
children,
|
||||||
|
className,
|
||||||
|
}: {
|
||||||
|
children: React.ReactNode;
|
||||||
|
className?: string;
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<Text variant="body" className={cn("mt-2 text-zinc-800", className)}>
|
||||||
|
{children}
|
||||||
|
</Text>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ------------------------------------------------------------------ */
|
||||||
|
/* Text */
|
||||||
|
/* ------------------------------------------------------------------ */
|
||||||
|
|
||||||
|
export function ContentMessage({
|
||||||
|
children,
|
||||||
|
className,
|
||||||
|
}: {
|
||||||
|
children: React.ReactNode;
|
||||||
|
className?: string;
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<Text variant="body" className={cn("text-zinc-800", className)}>
|
||||||
|
{children}
|
||||||
|
</Text>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function ContentHint({
|
||||||
|
children,
|
||||||
|
className,
|
||||||
|
}: {
|
||||||
|
children: React.ReactNode;
|
||||||
|
className?: string;
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<Text variant="small" className={cn("text-neutral-500", className)}>
|
||||||
|
{children}
|
||||||
|
</Text>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ------------------------------------------------------------------ */
|
||||||
|
/* Code / data */
|
||||||
|
/* ------------------------------------------------------------------ */
|
||||||
|
|
||||||
|
export function ContentCodeBlock({
|
||||||
|
children,
|
||||||
|
className,
|
||||||
|
}: {
|
||||||
|
children: React.ReactNode;
|
||||||
|
className?: string;
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<pre
|
||||||
|
className={cn(
|
||||||
|
"whitespace-pre-wrap rounded-lg border bg-black p-3 text-xs text-neutral-200",
|
||||||
|
className,
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
{children}
|
||||||
|
</pre>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ------------------------------------------------------------------ */
|
||||||
|
/* Inline elements */
|
||||||
|
/* ------------------------------------------------------------------ */
|
||||||
|
|
||||||
|
export function ContentBadge({
|
||||||
|
children,
|
||||||
|
className,
|
||||||
|
}: {
|
||||||
|
children: React.ReactNode;
|
||||||
|
className?: string;
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<Text
|
||||||
|
variant="small"
|
||||||
|
as="span"
|
||||||
|
className={cn(
|
||||||
|
"shrink-0 rounded-full border bg-muted px-2 py-0.5 text-[11px] text-zinc-800",
|
||||||
|
className,
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
{children}
|
||||||
|
</Text>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function ContentLink({
|
||||||
|
href,
|
||||||
|
children,
|
||||||
|
className,
|
||||||
|
...rest
|
||||||
|
}: Omit<React.ComponentProps<typeof Link>, "className"> & {
|
||||||
|
className?: string;
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<Link
|
||||||
|
variant="primary"
|
||||||
|
isExternal
|
||||||
|
href={href}
|
||||||
|
className={cn("shrink-0 text-xs text-purple-500", className)}
|
||||||
|
{...rest}
|
||||||
|
>
|
||||||
|
{children}
|
||||||
|
</Link>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ------------------------------------------------------------------ */
|
||||||
|
/* Lists */
|
||||||
|
/* ------------------------------------------------------------------ */
|
||||||
|
|
||||||
|
export function ContentSuggestionsList({
|
||||||
|
items,
|
||||||
|
max = 5,
|
||||||
|
className,
|
||||||
|
}: {
|
||||||
|
items: string[];
|
||||||
|
max?: number;
|
||||||
|
className?: string;
|
||||||
|
}) {
|
||||||
|
if (items.length === 0) return null;
|
||||||
|
return (
|
||||||
|
<ul
|
||||||
|
className={cn(
|
||||||
|
"mt-2 list-disc space-y-1 pl-5 font-sans text-[0.75rem] leading-[1.125rem] text-zinc-800",
|
||||||
|
className,
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
{items.slice(0, max).map((s) => (
|
||||||
|
<li key={s}>{s}</li>
|
||||||
|
))}
|
||||||
|
</ul>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -0,0 +1,102 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
import { CaretDownIcon } from "@phosphor-icons/react";
|
||||||
|
import { AnimatePresence, motion, useReducedMotion } from "framer-motion";
|
||||||
|
import { useId } from "react";
|
||||||
|
import { useToolAccordion } from "./useToolAccordion";
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
icon: React.ReactNode;
|
||||||
|
title: React.ReactNode;
|
||||||
|
titleClassName?: string;
|
||||||
|
description?: React.ReactNode;
|
||||||
|
children: React.ReactNode;
|
||||||
|
className?: string;
|
||||||
|
defaultExpanded?: boolean;
|
||||||
|
expanded?: boolean;
|
||||||
|
onExpandedChange?: (expanded: boolean) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function ToolAccordion({
|
||||||
|
icon,
|
||||||
|
title,
|
||||||
|
titleClassName,
|
||||||
|
description,
|
||||||
|
children,
|
||||||
|
className,
|
||||||
|
defaultExpanded,
|
||||||
|
expanded,
|
||||||
|
onExpandedChange,
|
||||||
|
}: Props) {
|
||||||
|
const shouldReduceMotion = useReducedMotion();
|
||||||
|
const contentId = useId();
|
||||||
|
const { isExpanded, toggle } = useToolAccordion({
|
||||||
|
expanded,
|
||||||
|
defaultExpanded,
|
||||||
|
onExpandedChange,
|
||||||
|
});
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
className={cn(
|
||||||
|
"mt-2 w-full rounded-lg border border-slate-200 bg-slate-100 px-3 py-2",
|
||||||
|
className,
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
aria-expanded={isExpanded}
|
||||||
|
aria-controls={contentId}
|
||||||
|
onClick={toggle}
|
||||||
|
className="flex w-full items-center justify-between gap-3 py-1 text-left"
|
||||||
|
>
|
||||||
|
<div className="flex min-w-0 items-center gap-3">
|
||||||
|
<span className="flex shrink-0 items-center text-gray-800">
|
||||||
|
{icon}
|
||||||
|
</span>
|
||||||
|
<div className="min-w-0">
|
||||||
|
<p
|
||||||
|
className={cn(
|
||||||
|
"truncate text-sm font-medium text-gray-800",
|
||||||
|
titleClassName,
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
{title}
|
||||||
|
</p>
|
||||||
|
{description && (
|
||||||
|
<p className="truncate text-xs text-slate-800">{description}</p>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<CaretDownIcon
|
||||||
|
className={cn(
|
||||||
|
"h-4 w-4 shrink-0 text-slate-500 transition-transform",
|
||||||
|
isExpanded && "rotate-180",
|
||||||
|
)}
|
||||||
|
weight="bold"
|
||||||
|
/>
|
||||||
|
</button>
|
||||||
|
|
||||||
|
<AnimatePresence initial={false}>
|
||||||
|
{isExpanded && (
|
||||||
|
<motion.div
|
||||||
|
id={contentId}
|
||||||
|
initial={{ height: 0, opacity: 0, filter: "blur(10px)" }}
|
||||||
|
animate={{ height: "auto", opacity: 1, filter: "blur(0px)" }}
|
||||||
|
exit={{ height: 0, opacity: 0, filter: "blur(10px)" }}
|
||||||
|
transition={
|
||||||
|
shouldReduceMotion
|
||||||
|
? { duration: 0 }
|
||||||
|
: { type: "spring", bounce: 0.35, duration: 0.55 }
|
||||||
|
}
|
||||||
|
className="overflow-hidden"
|
||||||
|
style={{ willChange: "height, opacity, filter" }}
|
||||||
|
>
|
||||||
|
<div className="pb-2 pt-3">{children}</div>
|
||||||
|
</motion.div>
|
||||||
|
)}
|
||||||
|
</AnimatePresence>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -0,0 +1,32 @@
|
|||||||
|
import { useState } from "react";
|
||||||
|
|
||||||
|
interface UseToolAccordionOptions {
|
||||||
|
expanded?: boolean;
|
||||||
|
defaultExpanded?: boolean;
|
||||||
|
onExpandedChange?: (expanded: boolean) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface UseToolAccordionResult {
|
||||||
|
isExpanded: boolean;
|
||||||
|
toggle: () => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function useToolAccordion({
|
||||||
|
expanded,
|
||||||
|
defaultExpanded = false,
|
||||||
|
onExpandedChange,
|
||||||
|
}: UseToolAccordionOptions): UseToolAccordionResult {
|
||||||
|
const [uncontrolledExpanded, setUncontrolledExpanded] =
|
||||||
|
useState(defaultExpanded);
|
||||||
|
|
||||||
|
const isControlled = typeof expanded === "boolean";
|
||||||
|
const isExpanded = isControlled ? expanded : uncontrolledExpanded;
|
||||||
|
|
||||||
|
function toggle() {
|
||||||
|
const next = !isExpanded;
|
||||||
|
if (!isControlled) setUncontrolledExpanded(next);
|
||||||
|
onExpandedChange?.(next);
|
||||||
|
}
|
||||||
|
|
||||||
|
return { isExpanded, toggle };
|
||||||
|
}
|
||||||
@@ -1,56 +0,0 @@
|
|||||||
"use client";
|
|
||||||
|
|
||||||
import { create } from "zustand";
|
|
||||||
|
|
||||||
interface CopilotStoreState {
|
|
||||||
isStreaming: boolean;
|
|
||||||
isSwitchingSession: boolean;
|
|
||||||
isCreatingSession: boolean;
|
|
||||||
isInterruptModalOpen: boolean;
|
|
||||||
pendingAction: (() => void) | null;
|
|
||||||
}
|
|
||||||
|
|
||||||
interface CopilotStoreActions {
|
|
||||||
setIsStreaming: (isStreaming: boolean) => void;
|
|
||||||
setIsSwitchingSession: (isSwitchingSession: boolean) => void;
|
|
||||||
setIsCreatingSession: (isCreating: boolean) => void;
|
|
||||||
openInterruptModal: (onConfirm: () => void) => void;
|
|
||||||
confirmInterrupt: () => void;
|
|
||||||
cancelInterrupt: () => void;
|
|
||||||
}
|
|
||||||
|
|
||||||
type CopilotStore = CopilotStoreState & CopilotStoreActions;
|
|
||||||
|
|
||||||
export const useCopilotStore = create<CopilotStore>((set, get) => ({
|
|
||||||
isStreaming: false,
|
|
||||||
isSwitchingSession: false,
|
|
||||||
isCreatingSession: false,
|
|
||||||
isInterruptModalOpen: false,
|
|
||||||
pendingAction: null,
|
|
||||||
|
|
||||||
setIsStreaming(isStreaming) {
|
|
||||||
set({ isStreaming });
|
|
||||||
},
|
|
||||||
|
|
||||||
setIsSwitchingSession(isSwitchingSession) {
|
|
||||||
set({ isSwitchingSession });
|
|
||||||
},
|
|
||||||
|
|
||||||
setIsCreatingSession(isCreatingSession) {
|
|
||||||
set({ isCreatingSession });
|
|
||||||
},
|
|
||||||
|
|
||||||
openInterruptModal(onConfirm) {
|
|
||||||
set({ isInterruptModalOpen: true, pendingAction: onConfirm });
|
|
||||||
},
|
|
||||||
|
|
||||||
confirmInterrupt() {
|
|
||||||
const { pendingAction } = get();
|
|
||||||
set({ isInterruptModalOpen: false, pendingAction: null });
|
|
||||||
if (pendingAction) pendingAction();
|
|
||||||
},
|
|
||||||
|
|
||||||
cancelInterrupt() {
|
|
||||||
set({ isInterruptModalOpen: false, pendingAction: null });
|
|
||||||
},
|
|
||||||
}));
|
|
||||||
@@ -0,0 +1,128 @@
|
|||||||
|
import type { UIMessage, UIDataTypes, UITools } from "ai";
|
||||||
|
|
||||||
|
interface SessionChatMessage {
|
||||||
|
role: string;
|
||||||
|
content: string | null;
|
||||||
|
tool_call_id: string | null;
|
||||||
|
tool_calls: unknown[] | null;
|
||||||
|
}
|
||||||
|
|
||||||
|
function coerceSessionChatMessages(
|
||||||
|
rawMessages: unknown[],
|
||||||
|
): SessionChatMessage[] {
|
||||||
|
return rawMessages
|
||||||
|
.map((m) => {
|
||||||
|
if (!m || typeof m !== "object") return null;
|
||||||
|
const msg = m as Record<string, unknown>;
|
||||||
|
|
||||||
|
const role = typeof msg.role === "string" ? msg.role : null;
|
||||||
|
if (!role) return null;
|
||||||
|
|
||||||
|
return {
|
||||||
|
role,
|
||||||
|
content:
|
||||||
|
typeof msg.content === "string"
|
||||||
|
? msg.content
|
||||||
|
: msg.content == null
|
||||||
|
? null
|
||||||
|
: String(msg.content),
|
||||||
|
tool_call_id:
|
||||||
|
typeof msg.tool_call_id === "string"
|
||||||
|
? msg.tool_call_id
|
||||||
|
: msg.tool_call_id == null
|
||||||
|
? null
|
||||||
|
: String(msg.tool_call_id),
|
||||||
|
tool_calls: Array.isArray(msg.tool_calls) ? msg.tool_calls : null,
|
||||||
|
};
|
||||||
|
})
|
||||||
|
.filter((m): m is SessionChatMessage => m !== null);
|
||||||
|
}
|
||||||
|
|
||||||
|
function safeJsonParse(value: string): unknown {
|
||||||
|
try {
|
||||||
|
return JSON.parse(value) as unknown;
|
||||||
|
} catch {
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function toToolInput(rawArguments: unknown): unknown {
|
||||||
|
if (typeof rawArguments === "string") {
|
||||||
|
const trimmed = rawArguments.trim();
|
||||||
|
return trimmed ? safeJsonParse(trimmed) : {};
|
||||||
|
}
|
||||||
|
if (rawArguments && typeof rawArguments === "object") return rawArguments;
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
export function convertChatSessionMessagesToUiMessages(
|
||||||
|
sessionId: string,
|
||||||
|
rawMessages: unknown[],
|
||||||
|
): UIMessage<unknown, UIDataTypes, UITools>[] {
|
||||||
|
const messages = coerceSessionChatMessages(rawMessages);
|
||||||
|
const toolOutputsByCallId = new Map<string, unknown>();
|
||||||
|
|
||||||
|
for (const msg of messages) {
|
||||||
|
if (msg.role !== "tool") continue;
|
||||||
|
if (!msg.tool_call_id) continue;
|
||||||
|
if (msg.content == null) continue;
|
||||||
|
toolOutputsByCallId.set(msg.tool_call_id, msg.content);
|
||||||
|
}
|
||||||
|
|
||||||
|
const uiMessages: UIMessage<unknown, UIDataTypes, UITools>[] = [];
|
||||||
|
|
||||||
|
messages.forEach((msg, index) => {
|
||||||
|
if (msg.role === "tool") return;
|
||||||
|
if (msg.role !== "user" && msg.role !== "assistant") return;
|
||||||
|
|
||||||
|
const parts: UIMessage<unknown, UIDataTypes, UITools>["parts"] = [];
|
||||||
|
|
||||||
|
if (typeof msg.content === "string" && msg.content.trim()) {
|
||||||
|
parts.push({ type: "text", text: msg.content, state: "done" });
|
||||||
|
}
|
||||||
|
|
||||||
|
if (msg.role === "assistant" && Array.isArray(msg.tool_calls)) {
|
||||||
|
for (const rawToolCall of msg.tool_calls) {
|
||||||
|
if (!rawToolCall || typeof rawToolCall !== "object") continue;
|
||||||
|
const toolCall = rawToolCall as {
|
||||||
|
id?: unknown;
|
||||||
|
function?: { name?: unknown; arguments?: unknown };
|
||||||
|
};
|
||||||
|
|
||||||
|
const toolCallId = String(toolCall.id ?? "").trim();
|
||||||
|
const toolName = String(toolCall.function?.name ?? "").trim();
|
||||||
|
if (!toolCallId || !toolName) continue;
|
||||||
|
|
||||||
|
const input = toToolInput(toolCall.function?.arguments);
|
||||||
|
const output = toolOutputsByCallId.get(toolCallId);
|
||||||
|
|
||||||
|
if (output !== undefined) {
|
||||||
|
parts.push({
|
||||||
|
type: `tool-${toolName}`,
|
||||||
|
toolCallId,
|
||||||
|
state: "output-available",
|
||||||
|
input,
|
||||||
|
output: typeof output === "string" ? safeJsonParse(output) : output,
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
parts.push({
|
||||||
|
type: `tool-${toolName}`,
|
||||||
|
toolCallId,
|
||||||
|
state: "input-available",
|
||||||
|
input,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (parts.length === 0) return;
|
||||||
|
|
||||||
|
uiMessages.push({
|
||||||
|
id: `${sessionId}-${index}`,
|
||||||
|
role: msg.role,
|
||||||
|
parts,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
return uiMessages;
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user